replace api_key with ProviderCredential throughout the AssistantPanel

This commit is contained in:
KCaverly 2023-10-28 18:16:45 -04:00
parent 558f54c424
commit 1e8b23d8fb
5 changed files with 208 additions and 121 deletions

View file

@ -13,7 +13,7 @@ use std::{
};
use crate::{
auth::CredentialProvider,
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
@ -102,10 +102,17 @@ pub struct OpenAIResponseStreamEvent {
}
pub async fn stream_completion(
api_key: String,
credential: ProviderCredential,
executor: Arc<Background>,
request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
let api_key = match credential {
ProviderCredential::Credentials { api_key } => api_key,
_ => {
return Err(anyhow!("no credentials provider for completion"));
}
};
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = request.data()?;
@ -188,18 +195,22 @@ pub async fn stream_completion(
pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
credential_provider: OpenAICredentialProvider,
api_key: String,
credential: ProviderCredential,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
pub fn new(
model_name: &str,
credential: ProviderCredential,
executor: Arc<Background>,
) -> Self {
let model = OpenAILanguageModel::load(model_name);
let credential_provider = OpenAICredentialProvider {};
Self {
model,
credential_provider,
api_key,
credential,
executor,
}
}
@ -218,7 +229,8 @@ impl CompletionProvider for OpenAICompletionProvider {
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
let credential = self.credential.clone();
let request = stream_completion(credential, self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response