replace api_key with ProviderCredential throughout the AssistantPanel

This commit is contained in:
KCaverly 2023-10-28 18:16:45 -04:00
parent 558f54c424
commit 1e8b23d8fb
5 changed files with 208 additions and 121 deletions

View file

@ -9,6 +9,8 @@ pub enum ProviderCredential {
pub trait CredentialProvider: Send + Sync {
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
fn delete_credentials(&self, cx: &AppContext);
}
#[derive(Clone)]
@ -17,4 +19,6 @@ impl CredentialProvider for NullCredentialProvider {
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {}
fn delete_credentials(&self, cx: &AppContext) {}
}

View file

@ -17,6 +17,12 @@ pub trait CompletionProvider {
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
self.credential_provider().retrieve_credentials(cx)
}
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
self.credential_provider().save_credentials(cx, credential);
}
fn delete_credentials(&self, cx: &AppContext) {
self.credential_provider().delete_credentials(cx);
}
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,

View file

@ -30,4 +30,17 @@ impl CredentialProvider for OpenAICredentialProvider {
ProviderCredential::NoCredentials
}
}
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
match credential {
ProviderCredential::Credentials { api_key } => {
cx.platform()
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
}
}
fn delete_credentials(&self, cx: &AppContext) {
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
}
}

View file

@ -13,7 +13,7 @@ use std::{
};
use crate::{
auth::CredentialProvider,
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
@ -102,10 +102,17 @@ pub struct OpenAIResponseStreamEvent {
}
pub async fn stream_completion(
api_key: String,
credential: ProviderCredential,
executor: Arc<Background>,
request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
let api_key = match credential {
ProviderCredential::Credentials { api_key } => api_key,
_ => {
return Err(anyhow!("no credentials provider for completion"));
}
};
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = request.data()?;
@ -188,18 +195,22 @@ pub async fn stream_completion(
pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
credential_provider: OpenAICredentialProvider,
api_key: String,
credential: ProviderCredential,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
pub fn new(
model_name: &str,
credential: ProviderCredential,
executor: Arc<Background>,
) -> Self {
let model = OpenAILanguageModel::load(model_name);
let credential_provider = OpenAICredentialProvider {};
Self {
model,
credential_provider,
api_key,
credential,
executor,
}
}
@ -218,7 +229,8 @@ impl CompletionProvider for OpenAICompletionProvider {
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
let credential = self.credential.clone();
let request = stream_completion(credential, self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response