diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs index cb3f2beabb..c6256df216 100644 --- a/crates/ai/src/auth.rs +++ b/crates/ai/src/auth.rs @@ -8,17 +8,8 @@ pub enum ProviderCredential { } pub trait CredentialProvider: Send + Sync { + fn has_credentials(&self) -> bool; fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential; fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential); fn delete_credentials(&self, cx: &AppContext); } - -#[derive(Clone)] -pub struct NullCredentialProvider; -impl CredentialProvider for NullCredentialProvider { - fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { - ProviderCredential::NotNeeded - } - fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {} - fn delete_credentials(&self, cx: &AppContext) {} -} diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index 6a2806a5cb..7fdc49e918 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -1,28 +1,14 @@ use anyhow::Result; use futures::{future::BoxFuture, stream::BoxStream}; -use gpui::AppContext; -use crate::{ - auth::{CredentialProvider, ProviderCredential}, - models::LanguageModel, -}; +use crate::{auth::CredentialProvider, models::LanguageModel}; pub trait CompletionRequest: Send + Sync { fn data(&self) -> serde_json::Result; } -pub trait CompletionProvider { +pub trait CompletionProvider: CredentialProvider { fn base_model(&self) -> Box; - fn credential_provider(&self) -> Box; - fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { - self.credential_provider().retrieve_credentials(cx) - } - fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { - self.credential_provider().save_credentials(cx, credential); - } - fn delete_credentials(&self, cx: &AppContext) { - self.credential_provider().delete_credentials(cx); - } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index 50f04232ab..6768b7ce7b 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -2,12 +2,11 @@ use std::time::Instant; use anyhow::Result; use async_trait::async_trait; -use gpui::AppContext; use ordered_float::OrderedFloat; use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; use rusqlite::ToSql; -use crate::auth::{CredentialProvider, ProviderCredential}; +use crate::auth::CredentialProvider; use crate::models::LanguageModel; #[derive(Debug, PartialEq, Clone)] @@ -70,17 +69,9 @@ impl Embedding { } #[async_trait] -pub trait EmbeddingProvider: Sync + Send { +pub trait EmbeddingProvider: CredentialProvider { fn base_model(&self) -> Box; - fn credential_provider(&self) -> Box; - fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { - self.credential_provider().retrieve_credentials(cx) - } - async fn embed_batch( - &self, - spans: Vec, - credential: ProviderCredential, - ) -> Result>; + async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; fn rate_limit_expiration(&self) -> Option; } diff --git a/crates/ai/src/providers/open_ai/auth.rs b/crates/ai/src/providers/open_ai/auth.rs deleted file mode 100644 index 7cb51ab449..0000000000 --- a/crates/ai/src/providers/open_ai/auth.rs +++ /dev/null @@ -1,46 +0,0 @@ -use std::env; - -use gpui::AppContext; -use util::ResultExt; - -use crate::auth::{CredentialProvider, ProviderCredential}; -use crate::providers::open_ai::OPENAI_API_URL; - -#[derive(Clone)] -pub struct OpenAICredentialProvider {} - -impl CredentialProvider for OpenAICredentialProvider { - fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - }; - - if let Some(api_key) = api_key { - ProviderCredential::Credentials { api_key } - } else { - ProviderCredential::NoCredentials - } - } - fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { - match credential { - ProviderCredential::Credentials { api_key } => { - cx.platform() - .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) - .log_err(); - } - _ => {} - } - } - fn delete_credentials(&self, cx: &AppContext) { - cx.platform().delete_credentials(OPENAI_API_URL).log_err(); - } -} diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index febe491123..02d25a7eec 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -3,14 +3,17 @@ use futures::{ future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, Stream, StreamExt, }; -use gpui::executor::Background; +use gpui::{executor::Background, AppContext}; use isahc::{http::StatusCode, Request, RequestExt}; +use parking_lot::RwLock; use serde::{Deserialize, Serialize}; use std::{ + env, fmt::{self, Display}, io, sync::Arc, }; +use util::ResultExt; use crate::{ auth::{CredentialProvider, ProviderCredential}, @@ -18,9 +21,7 @@ use crate::{ models::LanguageModel, }; -use super::{auth::OpenAICredentialProvider, OpenAILanguageModel}; - -pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; +use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL}; #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] @@ -194,42 +195,83 @@ pub async fn stream_completion( pub struct OpenAICompletionProvider { model: OpenAILanguageModel, - credential_provider: OpenAICredentialProvider, - credential: ProviderCredential, + credential: Arc>, executor: Arc, } impl OpenAICompletionProvider { - pub fn new( - model_name: &str, - credential: ProviderCredential, - executor: Arc, - ) -> Self { + pub fn new(model_name: &str, executor: Arc) -> Self { let model = OpenAILanguageModel::load(model_name); - let credential_provider = OpenAICredentialProvider {}; + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); Self { model, - credential_provider, credential, executor, } } } +impl CredentialProvider for OpenAICompletionProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + let mut credential = self.credential.write(); + match *credential { + ProviderCredential::Credentials { .. } => { + return credential.clone(); + } + _ => { + if let Ok(api_key) = env::var("OPENAI_API_KEY") { + *credential = ProviderCredential::Credentials { api_key }; + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + *credential = ProviderCredential::Credentials { api_key }; + } + } else { + }; + } + } + + credential.clone() + } + + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + match credential.clone() { + ProviderCredential::Credentials { api_key } => { + cx.platform() + .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + } + + *self.credential.write() = credential; + } + fn delete_credentials(&self, cx: &AppContext) { + cx.platform().delete_credentials(OPENAI_API_URL).log_err(); + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + impl CompletionProvider for OpenAICompletionProvider { fn base_model(&self) -> Box { let model: Box = Box::new(self.model.clone()); model } - fn credential_provider(&self) -> Box { - let provider: Box = Box::new(self.credential_provider.clone()); - provider - } fn complete( &self, prompt: Box, ) -> BoxFuture<'static, Result>>> { - let credential = self.credential.clone(); + let credential = self.credential.read().clone(); let request = stream_completion(credential, self.executor.clone(), prompt); async move { let response = request.await?; diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index dafc94580d..fbfd0028f9 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -2,27 +2,29 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::AsyncReadExt; use gpui::executor::Background; -use gpui::serde_json; +use gpui::{serde_json, AppContext}; use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; -use parking_lot::Mutex; +use parking_lot::{Mutex, RwLock}; use parse_duration::parse; use postage::watch; use serde::{Deserialize, Serialize}; +use std::env; use std::ops::Add; use std::sync::Arc; use std::time::{Duration, Instant}; use tiktoken_rs::{cl100k_base, CoreBPE}; use util::http::{HttpClient, Request}; +use util::ResultExt; use crate::auth::{CredentialProvider, ProviderCredential}; use crate::embedding::{Embedding, EmbeddingProvider}; use crate::models::LanguageModel; use crate::providers::open_ai::OpenAILanguageModel; -use crate::providers::open_ai::auth::OpenAICredentialProvider; +use crate::providers::open_ai::OPENAI_API_URL; lazy_static! { static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); @@ -31,7 +33,7 @@ lazy_static! { #[derive(Clone)] pub struct OpenAIEmbeddingProvider { model: OpenAILanguageModel, - credential_provider: OpenAICredentialProvider, + credential: Arc>, pub client: Arc, pub executor: Arc, rate_limit_count_rx: watch::Receiver>, @@ -69,10 +71,11 @@ impl OpenAIEmbeddingProvider { let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); let model = OpenAILanguageModel::load("text-embedding-ada-002"); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); OpenAIEmbeddingProvider { model, - credential_provider: OpenAICredentialProvider {}, + credential, client, executor, rate_limit_count_rx, @@ -80,6 +83,13 @@ impl OpenAIEmbeddingProvider { } } + fn get_api_key(&self) -> Result { + match self.credential.read().clone() { + ProviderCredential::Credentials { api_key } => Ok(api_key), + _ => Err(anyhow!("api credentials not provided")), + } + } + fn resolve_rate_limit(&self) { let reset_time = *self.rate_limit_count_tx.lock().borrow(); @@ -136,6 +146,57 @@ impl OpenAIEmbeddingProvider { } } +impl CredentialProvider for OpenAIEmbeddingProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + let mut credential = self.credential.write(); + match *credential { + ProviderCredential::Credentials { .. } => { + return credential.clone(); + } + _ => { + if let Ok(api_key) = env::var("OPENAI_API_KEY") { + *credential = ProviderCredential::Credentials { api_key }; + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + *credential = ProviderCredential::Credentials { api_key }; + } + } else { + }; + } + } + + credential.clone() + } + + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + match credential.clone() { + ProviderCredential::Credentials { api_key } => { + cx.platform() + .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + } + + *self.credential.write() = credential; + } + fn delete_credentials(&self, cx: &AppContext) { + cx.platform().delete_credentials(OPENAI_API_URL).log_err(); + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + #[async_trait] impl EmbeddingProvider for OpenAIEmbeddingProvider { fn base_model(&self) -> Box { @@ -143,12 +204,6 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { model } - fn credential_provider(&self) -> Box { - let credential_provider: Box = - Box::new(self.credential_provider.clone()); - credential_provider - } - fn max_tokens_per_batch(&self) -> usize { 50000 } @@ -157,18 +212,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { *self.rate_limit_count_rx.borrow() } - async fn embed_batch( - &self, - spans: Vec, - credential: ProviderCredential, - ) -> Result> { + async fn embed_batch(&self, spans: Vec) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; - let api_key = match credential { - ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key), - _ => Err(anyhow!("no api key provided")), - }?; + let api_key = self.get_api_key()?; let mut request_number = 0; let mut rate_limiting = false; diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs index 49e29fbc8c..7d2f86045d 100644 --- a/crates/ai/src/providers/open_ai/mod.rs +++ b/crates/ai/src/providers/open_ai/mod.rs @@ -1,4 +1,3 @@ -pub mod auth; pub mod completion; pub mod embedding; pub mod model; @@ -6,3 +5,5 @@ pub mod model; pub use completion::*; pub use embedding::*; pub use model::OpenAILanguageModel; + +pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; diff --git a/crates/ai/src/providers/open_ai/new.rs b/crates/ai/src/providers/open_ai/new.rs new file mode 100644 index 0000000000..c7d67f2ba1 --- /dev/null +++ b/crates/ai/src/providers/open_ai/new.rs @@ -0,0 +1,11 @@ +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index b8f99af400..bc9a6a3e43 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -5,10 +5,11 @@ use std::{ use async_trait::async_trait; use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui::AppContext; use parking_lot::Mutex; use crate::{ - auth::{CredentialProvider, NullCredentialProvider, ProviderCredential}, + auth::{CredentialProvider, ProviderCredential}, completion::{CompletionProvider, CompletionRequest}, embedding::{Embedding, EmbeddingProvider}, models::{LanguageModel, TruncationDirection}, @@ -52,14 +53,12 @@ impl LanguageModel for FakeLanguageModel { pub struct FakeEmbeddingProvider { pub embedding_count: AtomicUsize, - pub credential_provider: NullCredentialProvider, } impl Clone for FakeEmbeddingProvider { fn clone(&self) -> Self { FakeEmbeddingProvider { embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), - credential_provider: self.credential_provider.clone(), } } } @@ -68,7 +67,6 @@ impl Default for FakeEmbeddingProvider { fn default() -> Self { FakeEmbeddingProvider { embedding_count: AtomicUsize::default(), - credential_provider: NullCredentialProvider {}, } } } @@ -99,16 +97,22 @@ impl FakeEmbeddingProvider { } } +impl CredentialProvider for FakeEmbeddingProvider { + fn has_credentials(&self) -> bool { + true + } + fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {} + fn delete_credentials(&self, _cx: &AppContext) {} +} + #[async_trait] impl EmbeddingProvider for FakeEmbeddingProvider { fn base_model(&self) -> Box { Box::new(FakeLanguageModel { capacity: 1000 }) } - fn credential_provider(&self) -> Box { - let credential_provider: Box = - Box::new(self.credential_provider.clone()); - credential_provider - } fn max_tokens_per_batch(&self) -> usize { 1000 } @@ -117,11 +121,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider { None } - async fn embed_batch( - &self, - spans: Vec, - _credential: ProviderCredential, - ) -> anyhow::Result> { + async fn embed_batch(&self, spans: Vec) -> anyhow::Result> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); @@ -129,11 +129,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider { } } -pub struct TestCompletionProvider { +pub struct FakeCompletionProvider { last_completion_tx: Mutex>>, } -impl TestCompletionProvider { +impl FakeCompletionProvider { pub fn new() -> Self { Self { last_completion_tx: Mutex::new(None), @@ -150,14 +150,22 @@ impl TestCompletionProvider { } } -impl CompletionProvider for TestCompletionProvider { +impl CredentialProvider for FakeCompletionProvider { + fn has_credentials(&self) -> bool { + true + } + fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {} + fn delete_credentials(&self, _cx: &AppContext) {} +} + +impl CompletionProvider for FakeCompletionProvider { fn base_model(&self) -> Box { let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); model } - fn credential_provider(&self) -> Box { - Box::new(NullCredentialProvider {}) - } fn complete( &self, _prompt: Box, diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index f9187b8785..c10ad2c362 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -10,7 +10,7 @@ use ai::{ auth::ProviderCredential, completion::{CompletionProvider, CompletionRequest}, providers::open_ai::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, + stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, }, }; @@ -48,7 +48,7 @@ use semantic_index::{SemanticIndex, SemanticIndexStatus}; use settings::SettingsStore; use std::{ cell::{Cell, RefCell}, - cmp, env, + cmp, fmt::Write, iter, ops::Range, @@ -210,7 +210,6 @@ impl AssistantPanel { // Defaulting currently to GPT4, allow for this to be set via config. let completion_provider = Box::new(OpenAICompletionProvider::new( "gpt-4", - ProviderCredential::NoCredentials, cx.background().clone(), )); @@ -298,7 +297,6 @@ impl AssistantPanel { cx: &mut ViewContext, project: &ModelHandle, ) { - let credential = self.credential.borrow().clone(); let selection = editor.read(cx).selections.newest_anchor().clone(); if selection.start.excerpt_id() != selection.end.excerpt_id() { return; @@ -330,7 +328,6 @@ impl AssistantPanel { let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let provider = Arc::new(OpenAICompletionProvider::new( "gpt-4", - credential, cx.background().clone(), )); diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 7f4c95f655..8d8e49902f 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -335,7 +335,7 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; - use ai::test::TestCompletionProvider; + use ai::test::FakeCompletionProvider; use futures::stream::{self}; use gpui::{executor::Deterministic, TestAppContext}; use indoc::indoc; @@ -379,7 +379,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); - let provider = Arc::new(TestCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -445,7 +445,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 6)) }); - let provider = Arc::new(TestCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -511,7 +511,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 2)) }); - let provider = Arc::new(TestCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 6f792c78e2..6ae8faa4cd 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -1,5 +1,5 @@ use crate::{parsing::Span, JobHandle}; -use ai::{auth::ProviderCredential, embedding::EmbeddingProvider}; +use ai::embedding::EmbeddingProvider; use gpui::executor::Background; use parking_lot::Mutex; use smol::channel; @@ -41,7 +41,6 @@ pub struct EmbeddingQueue { pending_batch_token_count: usize, finished_files_tx: channel::Sender, finished_files_rx: channel::Receiver, - pub provider_credential: ProviderCredential, } #[derive(Clone)] @@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed { } impl EmbeddingQueue { - pub fn new( - embedding_provider: Arc, - executor: Arc, - provider_credential: ProviderCredential, - ) -> Self { + pub fn new(embedding_provider: Arc, executor: Arc) -> Self { let (finished_files_tx, finished_files_rx) = channel::unbounded(); Self { embedding_provider, @@ -64,14 +59,9 @@ impl EmbeddingQueue { pending_batch_token_count: 0, finished_files_tx, finished_files_rx, - provider_credential, } } - pub fn set_credential(&mut self, credential: ProviderCredential) { - self.provider_credential = credential; - } - pub fn push(&mut self, file: FileToEmbed) { if file.spans.is_empty() { self.finished_files_tx.try_send(file).unwrap(); @@ -118,7 +108,6 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); - let credential = self.provider_credential.clone(); self.executor .spawn(async move { @@ -143,7 +132,7 @@ impl EmbeddingQueue { return; }; - match embedding_provider.embed_batch(spans, credential).await { + match embedding_provider.embed_batch(spans).await { Ok(embeddings) => { let mut embeddings = embeddings.into_iter(); for fragment in batch { diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 7fb5f749b4..818faa0444 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -7,7 +7,6 @@ pub mod semantic_index_settings; mod semantic_index_tests; use crate::semantic_index_settings::SemanticIndexSettings; -use ai::auth::ProviderCredential; use ai::embedding::{Embedding, EmbeddingProvider}; use ai::providers::open_ai::OpenAIEmbeddingProvider; use anyhow::{anyhow, Result}; @@ -125,8 +124,6 @@ pub struct SemanticIndex { _embedding_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, - provider_credential: ProviderCredential, - embedding_queue: Arc>, } struct ProjectState { @@ -281,24 +278,17 @@ impl SemanticIndex { } pub fn authenticate(&mut self, cx: &AppContext) -> bool { - let existing_credential = self.provider_credential.clone(); - let credential = match existing_credential { - ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx), - _ => existing_credential, - }; + if !self.embedding_provider.has_credentials() { + self.embedding_provider.retrieve_credentials(cx); + } else { + return true; + } - self.provider_credential = credential.clone(); - self.embedding_queue.lock().set_credential(credential); - self.is_authenticated() + self.embedding_provider.has_credentials() } pub fn is_authenticated(&self) -> bool { - let credential = &self.provider_credential; - match credential { - &ProviderCredential::Credentials { .. } => true, - &ProviderCredential::NotNeeded => true, - _ => false, - } + self.embedding_provider.has_credentials() } pub fn enabled(cx: &AppContext) -> bool { @@ -348,7 +338,7 @@ impl SemanticIndex { Ok(cx.add_model(|cx| { let t0 = Instant::now(); let embedding_queue = - EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials); + EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone()); let _embedding_task = cx.background().spawn({ let embedded_files = embedding_queue.finished_files(); let db = db.clone(); @@ -413,8 +403,6 @@ impl SemanticIndex { _embedding_task, _parsing_files_tasks, projects: Default::default(), - provider_credential: ProviderCredential::NoCredentials, - embedding_queue } })) } @@ -729,14 +717,13 @@ impl SemanticIndex { let index = self.index_project(project.clone(), cx); let embedding_provider = self.embedding_provider.clone(); - let credential = self.provider_credential.clone(); cx.spawn(|this, mut cx| async move { index.await?; let t0 = Instant::now(); let query = embedding_provider - .embed_batch(vec![query], credential) + .embed_batch(vec![query]) .await? .pop() .ok_or_else(|| anyhow!("could not embed query"))?; @@ -954,7 +941,6 @@ impl SemanticIndex { let fs = self.fs.clone(); let db_path = self.db.path().clone(); let background = cx.background().clone(); - let credential = self.provider_credential.clone(); cx.background().spawn(async move { let db = VectorDatabase::new(fs, db_path.clone(), background).await?; let mut results = Vec::::new(); @@ -969,15 +955,10 @@ impl SemanticIndex { .parse_file_with_template(None, &snapshot.text(), language) .log_err() .unwrap_or_default(); - if Self::embed_spans( - &mut spans, - embedding_provider.as_ref(), - &db, - credential.clone(), - ) - .await - .log_err() - .is_some() + if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db) + .await + .log_err() + .is_some() { for span in spans { let similarity = span.embedding.unwrap().similarity(&query); @@ -1201,7 +1182,6 @@ impl SemanticIndex { spans: &mut [Span], embedding_provider: &dyn EmbeddingProvider, db: &VectorDatabase, - credential: ProviderCredential, ) -> Result<()> { let mut batch = Vec::new(); let mut batch_tokens = 0; @@ -1224,7 +1204,7 @@ impl SemanticIndex { if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() { let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch), credential.clone()) + .embed_batch(mem::take(&mut batch)) .await?; embeddings.extend(batch_embeddings); batch_tokens = 0; @@ -1236,7 +1216,7 @@ impl SemanticIndex { if !batch.is_empty() { let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch), credential) + .embed_batch(mem::take(&mut batch)) .await?; embeddings.extend(batch_embeddings); diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 7d5a4e22e8..7a91d1e100 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -220,11 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut queue = EmbeddingQueue::new( - embedding_provider.clone(), - cx.background(), - ai::auth::ProviderCredential::NoCredentials, - ); + let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background()); for file in &files { queue.push(file.clone()); }