moved authentication for the semantic index into the EmbeddingProvider
This commit is contained in:
parent
1e8b23d8fb
commit
a2c3971ad6
14 changed files with 200 additions and 206 deletions
|
@ -3,14 +3,17 @@ use futures::{
|
|||
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
||||
Stream, StreamExt,
|
||||
};
|
||||
use gpui::executor::Background;
|
||||
use gpui::{executor::Background, AppContext};
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
env,
|
||||
fmt::{self, Display},
|
||||
io,
|
||||
sync::Arc,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{
|
||||
auth::{CredentialProvider, ProviderCredential},
|
||||
|
@ -18,9 +21,7 @@ use crate::{
|
|||
models::LanguageModel,
|
||||
};
|
||||
|
||||
use super::{auth::OpenAICredentialProvider, OpenAILanguageModel};
|
||||
|
||||
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
||||
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
|
@ -194,42 +195,83 @@ pub async fn stream_completion(
|
|||
|
||||
pub struct OpenAICompletionProvider {
|
||||
model: OpenAILanguageModel,
|
||||
credential_provider: OpenAICredentialProvider,
|
||||
credential: ProviderCredential,
|
||||
credential: Arc<RwLock<ProviderCredential>>,
|
||||
executor: Arc<Background>,
|
||||
}
|
||||
|
||||
impl OpenAICompletionProvider {
|
||||
pub fn new(
|
||||
model_name: &str,
|
||||
credential: ProviderCredential,
|
||||
executor: Arc<Background>,
|
||||
) -> Self {
|
||||
pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
|
||||
let model = OpenAILanguageModel::load(model_name);
|
||||
let credential_provider = OpenAICredentialProvider {};
|
||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||
Self {
|
||||
model,
|
||||
credential_provider,
|
||||
credential,
|
||||
executor,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for OpenAICompletionProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
match *self.credential.read() {
|
||||
ProviderCredential::Credentials { .. } => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
|
||||
let mut credential = self.credential.write();
|
||||
match *credential {
|
||||
ProviderCredential::Credentials { .. } => {
|
||||
return credential.clone();
|
||||
}
|
||||
_ => {
|
||||
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
*credential = ProviderCredential::Credentials { api_key };
|
||||
} else if let Some((_, api_key)) = cx
|
||||
.platform()
|
||||
.read_credentials(OPENAI_API_URL)
|
||||
.log_err()
|
||||
.flatten()
|
||||
{
|
||||
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||
*credential = ProviderCredential::Credentials { api_key };
|
||||
}
|
||||
} else {
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
credential.clone()
|
||||
}
|
||||
|
||||
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
|
||||
match credential.clone() {
|
||||
ProviderCredential::Credentials { api_key } => {
|
||||
cx.platform()
|
||||
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
||||
.log_err();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
*self.credential.write() = credential;
|
||||
}
|
||||
fn delete_credentials(&self, cx: &AppContext) {
|
||||
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
|
||||
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionProvider for OpenAICompletionProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||
model
|
||||
}
|
||||
fn credential_provider(&self) -> Box<dyn CredentialProvider> {
|
||||
let provider: Box<dyn CredentialProvider> = Box::new(self.credential_provider.clone());
|
||||
provider
|
||||
}
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let credential = self.credential.clone();
|
||||
let credential = self.credential.read().clone();
|
||||
let request = stream_completion(credential, self.executor.clone(), prompt);
|
||||
async move {
|
||||
let response = request.await?;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue