From 6c8bb4b05e62aab88d5487cecb215c2fe8863a49 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 27 Oct 2023 08:33:35 +0200 Subject: [PATCH] ensure OpenAIEmbeddingProvider is using the provider credentials --- crates/ai/src/providers/open_ai/embedding.rs | 11 ++++++----- crates/semantic_index/src/embedding_queue.rs | 2 +- crates/semantic_index/src/semantic_index.rs | 17 ++++++----------- 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index 1385b32b4d..9806877660 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -162,14 +162,15 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { async fn embed_batch( &self, spans: Vec, - _credential: ProviderCredential, + credential: ProviderCredential, ) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; - let api_key = OPENAI_API_KEY - .as_ref() - .ok_or_else(|| anyhow!("no api key"))?; + let api_key = match credential { + ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key), + _ => Err(anyhow!("no api key provided")), + }?; let mut request_number = 0; let mut rate_limiting = false; @@ -178,7 +179,7 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { while request_number < MAX_RETRIES { response = self .send_request( - api_key, + &api_key, spans.iter().map(|x| &**x).collect(), request_timeout, ) diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 299aa328b5..6f792c78e2 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -41,7 +41,7 @@ pub struct EmbeddingQueue { pending_batch_token_count: usize, finished_files_tx: channel::Sender, finished_files_rx: channel::Receiver, - provider_credential: ProviderCredential, + pub provider_credential: ProviderCredential, } #[derive(Clone)] diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index f420e0503b..7fb5f749b4 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -281,15 +281,13 @@ impl SemanticIndex { } pub fn authenticate(&mut self, cx: &AppContext) -> bool { - let credential = self.provider_credential.clone(); - match credential { - ProviderCredential::NoCredentials => { - let credential = self.embedding_provider.retrieve_credentials(cx); - self.provider_credential = credential; - } - _ => {} - } + let existing_credential = self.provider_credential.clone(); + let credential = match existing_credential { + ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx), + _ => existing_credential, + }; + self.provider_credential = credential.clone(); self.embedding_queue.lock().set_credential(credential); self.is_authenticated() } @@ -1020,14 +1018,11 @@ impl SemanticIndex { cx: &mut ModelContext, ) -> Task> { if !self.is_authenticated() { - println!("Authenticating"); if !self.authenticate(cx) { return Task::ready(Err(anyhow!("user is not authenticated"))); } } - println!("SHOULD NOW BE AUTHENTICATED"); - if !self.projects.contains_key(&project.downgrade()) { let subscription = cx.subscribe(&project, |this, project, event, cx| match event { project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {