updated authentication for embedding provider
This commit is contained in:
parent
71bc35d241
commit
3447a9478c
16 changed files with 277 additions and 271 deletions
|
@ -7,6 +7,7 @@ pub mod semantic_index_settings;
|
|||
mod semantic_index_tests;
|
||||
|
||||
use crate::semantic_index_settings::SemanticIndexSettings;
|
||||
use ai::auth::ProviderCredential;
|
||||
use ai::embedding::{Embedding, EmbeddingProvider};
|
||||
use ai::providers::open_ai::OpenAIEmbeddingProvider;
|
||||
use anyhow::{anyhow, Result};
|
||||
|
@ -124,7 +125,7 @@ pub struct SemanticIndex {
|
|||
_embedding_task: Task<()>,
|
||||
_parsing_files_tasks: Vec<Task<()>>,
|
||||
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
||||
api_key: Option<String>,
|
||||
provider_credential: ProviderCredential,
|
||||
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
|
||||
}
|
||||
|
||||
|
@ -279,18 +280,27 @@ impl SemanticIndex {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn authenticate(&mut self, cx: &AppContext) {
|
||||
if self.api_key.is_none() {
|
||||
self.api_key = self.embedding_provider.retrieve_credentials(cx);
|
||||
|
||||
self.embedding_queue
|
||||
.lock()
|
||||
.set_api_key(self.api_key.clone());
|
||||
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
|
||||
let credential = self.provider_credential.clone();
|
||||
match credential {
|
||||
ProviderCredential::NoCredentials => {
|
||||
let credential = self.embedding_provider.retrieve_credentials(cx);
|
||||
self.provider_credential = credential;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
self.embedding_queue.lock().set_credential(credential);
|
||||
|
||||
self.is_authenticated()
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
let credential = &self.provider_credential;
|
||||
match credential {
|
||||
&ProviderCredential::Credentials { .. } => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn enabled(cx: &AppContext) -> bool {
|
||||
|
@ -340,7 +350,7 @@ impl SemanticIndex {
|
|||
Ok(cx.add_model(|cx| {
|
||||
let t0 = Instant::now();
|
||||
let embedding_queue =
|
||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
|
||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials);
|
||||
let _embedding_task = cx.background().spawn({
|
||||
let embedded_files = embedding_queue.finished_files();
|
||||
let db = db.clone();
|
||||
|
@ -405,7 +415,7 @@ impl SemanticIndex {
|
|||
_embedding_task,
|
||||
_parsing_files_tasks,
|
||||
projects: Default::default(),
|
||||
api_key: None,
|
||||
provider_credential: ProviderCredential::NoCredentials,
|
||||
embedding_queue
|
||||
}
|
||||
}))
|
||||
|
@ -721,13 +731,14 @@ impl SemanticIndex {
|
|||
|
||||
let index = self.index_project(project.clone(), cx);
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let credential = self.provider_credential.clone();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
index.await?;
|
||||
let t0 = Instant::now();
|
||||
|
||||
let query = embedding_provider
|
||||
.embed_batch(vec![query], api_key)
|
||||
.embed_batch(vec![query], credential)
|
||||
.await?
|
||||
.pop()
|
||||
.ok_or_else(|| anyhow!("could not embed query"))?;
|
||||
|
@ -945,7 +956,7 @@ impl SemanticIndex {
|
|||
let fs = self.fs.clone();
|
||||
let db_path = self.db.path().clone();
|
||||
let background = cx.background().clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let credential = self.provider_credential.clone();
|
||||
cx.background().spawn(async move {
|
||||
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
|
||||
let mut results = Vec::<SearchResult>::new();
|
||||
|
@ -964,7 +975,7 @@ impl SemanticIndex {
|
|||
&mut spans,
|
||||
embedding_provider.as_ref(),
|
||||
&db,
|
||||
api_key.clone(),
|
||||
credential.clone(),
|
||||
)
|
||||
.await
|
||||
.log_err()
|
||||
|
@ -1008,9 +1019,8 @@ impl SemanticIndex {
|
|||
project: ModelHandle<Project>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
if self.api_key.is_none() {
|
||||
self.authenticate(cx);
|
||||
if self.api_key.is_none() {
|
||||
if !self.is_authenticated() {
|
||||
if !self.authenticate(cx) {
|
||||
return Task::ready(Err(anyhow!("user is not authenticated")));
|
||||
}
|
||||
}
|
||||
|
@ -1193,7 +1203,7 @@ impl SemanticIndex {
|
|||
spans: &mut [Span],
|
||||
embedding_provider: &dyn EmbeddingProvider,
|
||||
db: &VectorDatabase,
|
||||
api_key: Option<String>,
|
||||
credential: ProviderCredential,
|
||||
) -> Result<()> {
|
||||
let mut batch = Vec::new();
|
||||
let mut batch_tokens = 0;
|
||||
|
@ -1216,7 +1226,7 @@ impl SemanticIndex {
|
|||
|
||||
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
|
||||
let batch_embeddings = embedding_provider
|
||||
.embed_batch(mem::take(&mut batch), api_key.clone())
|
||||
.embed_batch(mem::take(&mut batch), credential.clone())
|
||||
.await?;
|
||||
embeddings.extend(batch_embeddings);
|
||||
batch_tokens = 0;
|
||||
|
@ -1228,7 +1238,7 @@ impl SemanticIndex {
|
|||
|
||||
if !batch.is_empty() {
|
||||
let batch_embeddings = embedding_provider
|
||||
.embed_batch(mem::take(&mut batch), api_key)
|
||||
.embed_batch(mem::take(&mut batch), credential)
|
||||
.await?;
|
||||
|
||||
embeddings.extend(batch_embeddings);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue