moved authentication for the semantic index into the EmbeddingProvider

This commit is contained in:
KCaverly 2023-10-30 10:02:27 -04:00
parent 1e8b23d8fb
commit a2c3971ad6
14 changed files with 200 additions and 206 deletions

View file

@ -3,14 +3,17 @@ use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui::executor::Background;
use gpui::{executor::Background, AppContext};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
env,
fmt::{self, Display},
io,
sync::Arc,
};
use util::ResultExt;
use crate::{
auth::{CredentialProvider, ProviderCredential},
@ -18,9 +21,7 @@ use crate::{
models::LanguageModel,
};
use super::{auth::OpenAICredentialProvider, OpenAILanguageModel};
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
@ -194,42 +195,83 @@ pub async fn stream_completion(
pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
credential_provider: OpenAICredentialProvider,
credential: ProviderCredential,
credential: Arc<RwLock<ProviderCredential>>,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
pub fn new(
model_name: &str,
credential: ProviderCredential,
executor: Arc<Background>,
) -> Self {
pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
let model = OpenAILanguageModel::load(model_name);
let credential_provider = OpenAICredentialProvider {};
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
Self {
model,
credential_provider,
credential,
executor,
}
}
}
impl CredentialProvider for OpenAICompletionProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
let mut credential = self.credential.write();
match *credential {
ProviderCredential::Credentials { .. } => {
return credential.clone();
}
_ => {
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
*credential = ProviderCredential::Credentials { api_key };
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
*credential = ProviderCredential::Credentials { api_key };
}
} else {
};
}
}
credential.clone()
}
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
match credential.clone() {
ProviderCredential::Credentials { api_key } => {
cx.platform()
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
}
*self.credential.write() = credential;
}
fn delete_credentials(&self, cx: &AppContext) {
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
*self.credential.write() = ProviderCredential::NoCredentials;
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn credential_provider(&self) -> Box<dyn CredentialProvider> {
let provider: Box<dyn CredentialProvider> = Box::new(self.credential_provider.clone());
provider
}
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let credential = self.credential.clone();
let credential = self.credential.read().clone();
let request = stream_completion(credential, self.executor.clone(), prompt);
async move {
let response = request.await?;