replace api_key with ProviderCredential throughout the AssistantPanel
This commit is contained in:
parent
558f54c424
commit
1e8b23d8fb
5 changed files with 208 additions and 121 deletions
|
@ -9,6 +9,8 @@ pub enum ProviderCredential {
|
|||
|
||||
pub trait CredentialProvider: Send + Sync {
|
||||
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
|
||||
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
|
||||
fn delete_credentials(&self, cx: &AppContext);
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -17,4 +19,6 @@ impl CredentialProvider for NullCredentialProvider {
|
|||
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
|
||||
ProviderCredential::NotNeeded
|
||||
}
|
||||
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {}
|
||||
fn delete_credentials(&self, cx: &AppContext) {}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,12 @@ pub trait CompletionProvider {
|
|||
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
|
||||
self.credential_provider().retrieve_credentials(cx)
|
||||
}
|
||||
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
|
||||
self.credential_provider().save_credentials(cx, credential);
|
||||
}
|
||||
fn delete_credentials(&self, cx: &AppContext) {
|
||||
self.credential_provider().delete_credentials(cx);
|
||||
}
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: Box<dyn CompletionRequest>,
|
||||
|
|
|
@ -30,4 +30,17 @@ impl CredentialProvider for OpenAICredentialProvider {
|
|||
ProviderCredential::NoCredentials
|
||||
}
|
||||
}
|
||||
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
|
||||
match credential {
|
||||
ProviderCredential::Credentials { api_key } => {
|
||||
cx.platform()
|
||||
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
||||
.log_err();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
fn delete_credentials(&self, cx: &AppContext) {
|
||||
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ use std::{
|
|||
};
|
||||
|
||||
use crate::{
|
||||
auth::CredentialProvider,
|
||||
auth::{CredentialProvider, ProviderCredential},
|
||||
completion::{CompletionProvider, CompletionRequest},
|
||||
models::LanguageModel,
|
||||
};
|
||||
|
@ -102,10 +102,17 @@ pub struct OpenAIResponseStreamEvent {
|
|||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
api_key: String,
|
||||
credential: ProviderCredential,
|
||||
executor: Arc<Background>,
|
||||
request: Box<dyn CompletionRequest>,
|
||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||
let api_key = match credential {
|
||||
ProviderCredential::Credentials { api_key } => api_key,
|
||||
_ => {
|
||||
return Err(anyhow!("no credentials provider for completion"));
|
||||
}
|
||||
};
|
||||
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||
|
||||
let json_data = request.data()?;
|
||||
|
@ -188,18 +195,22 @@ pub async fn stream_completion(
|
|||
pub struct OpenAICompletionProvider {
|
||||
model: OpenAILanguageModel,
|
||||
credential_provider: OpenAICredentialProvider,
|
||||
api_key: String,
|
||||
credential: ProviderCredential,
|
||||
executor: Arc<Background>,
|
||||
}
|
||||
|
||||
impl OpenAICompletionProvider {
|
||||
pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
|
||||
pub fn new(
|
||||
model_name: &str,
|
||||
credential: ProviderCredential,
|
||||
executor: Arc<Background>,
|
||||
) -> Self {
|
||||
let model = OpenAILanguageModel::load(model_name);
|
||||
let credential_provider = OpenAICredentialProvider {};
|
||||
Self {
|
||||
model,
|
||||
credential_provider,
|
||||
api_key,
|
||||
credential,
|
||||
executor,
|
||||
}
|
||||
}
|
||||
|
@ -218,7 +229,8 @@ impl CompletionProvider for OpenAICompletionProvider {
|
|||
&self,
|
||||
prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
|
||||
let credential = self.credential.clone();
|
||||
let request = stream_completion(credential, self.executor.clone(), prompt);
|
||||
async move {
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue