moved authentication for the semantic index into the EmbeddingProvider
This commit is contained in:
parent
1e8b23d8fb
commit
a2c3971ad6
14 changed files with 200 additions and 206 deletions
|
@ -1,5 +1,5 @@
|
|||
use crate::{parsing::Span, JobHandle};
|
||||
use ai::{auth::ProviderCredential, embedding::EmbeddingProvider};
|
||||
use ai::embedding::EmbeddingProvider;
|
||||
use gpui::executor::Background;
|
||||
use parking_lot::Mutex;
|
||||
use smol::channel;
|
||||
|
@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
|
|||
pending_batch_token_count: usize,
|
||||
finished_files_tx: channel::Sender<FileToEmbed>,
|
||||
finished_files_rx: channel::Receiver<FileToEmbed>,
|
||||
pub provider_credential: ProviderCredential,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
|
|||
}
|
||||
|
||||
impl EmbeddingQueue {
|
||||
pub fn new(
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
executor: Arc<Background>,
|
||||
provider_credential: ProviderCredential,
|
||||
) -> Self {
|
||||
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
|
||||
let (finished_files_tx, finished_files_rx) = channel::unbounded();
|
||||
Self {
|
||||
embedding_provider,
|
||||
|
@ -64,14 +59,9 @@ impl EmbeddingQueue {
|
|||
pending_batch_token_count: 0,
|
||||
finished_files_tx,
|
||||
finished_files_rx,
|
||||
provider_credential,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_credential(&mut self, credential: ProviderCredential) {
|
||||
self.provider_credential = credential;
|
||||
}
|
||||
|
||||
pub fn push(&mut self, file: FileToEmbed) {
|
||||
if file.spans.is_empty() {
|
||||
self.finished_files_tx.try_send(file).unwrap();
|
||||
|
@ -118,7 +108,6 @@ impl EmbeddingQueue {
|
|||
|
||||
let finished_files_tx = self.finished_files_tx.clone();
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let credential = self.provider_credential.clone();
|
||||
|
||||
self.executor
|
||||
.spawn(async move {
|
||||
|
@ -143,7 +132,7 @@ impl EmbeddingQueue {
|
|||
return;
|
||||
};
|
||||
|
||||
match embedding_provider.embed_batch(spans, credential).await {
|
||||
match embedding_provider.embed_batch(spans).await {
|
||||
Ok(embeddings) => {
|
||||
let mut embeddings = embeddings.into_iter();
|
||||
for fragment in batch {
|
||||
|
|
|
@ -7,7 +7,6 @@ pub mod semantic_index_settings;
|
|||
mod semantic_index_tests;
|
||||
|
||||
use crate::semantic_index_settings::SemanticIndexSettings;
|
||||
use ai::auth::ProviderCredential;
|
||||
use ai::embedding::{Embedding, EmbeddingProvider};
|
||||
use ai::providers::open_ai::OpenAIEmbeddingProvider;
|
||||
use anyhow::{anyhow, Result};
|
||||
|
@ -125,8 +124,6 @@ pub struct SemanticIndex {
|
|||
_embedding_task: Task<()>,
|
||||
_parsing_files_tasks: Vec<Task<()>>,
|
||||
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
||||
provider_credential: ProviderCredential,
|
||||
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
|
||||
}
|
||||
|
||||
struct ProjectState {
|
||||
|
@ -281,24 +278,17 @@ impl SemanticIndex {
|
|||
}
|
||||
|
||||
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
|
||||
let existing_credential = self.provider_credential.clone();
|
||||
let credential = match existing_credential {
|
||||
ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx),
|
||||
_ => existing_credential,
|
||||
};
|
||||
if !self.embedding_provider.has_credentials() {
|
||||
self.embedding_provider.retrieve_credentials(cx);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
|
||||
self.provider_credential = credential.clone();
|
||||
self.embedding_queue.lock().set_credential(credential);
|
||||
self.is_authenticated()
|
||||
self.embedding_provider.has_credentials()
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
let credential = &self.provider_credential;
|
||||
match credential {
|
||||
&ProviderCredential::Credentials { .. } => true,
|
||||
&ProviderCredential::NotNeeded => true,
|
||||
_ => false,
|
||||
}
|
||||
self.embedding_provider.has_credentials()
|
||||
}
|
||||
|
||||
pub fn enabled(cx: &AppContext) -> bool {
|
||||
|
@ -348,7 +338,7 @@ impl SemanticIndex {
|
|||
Ok(cx.add_model(|cx| {
|
||||
let t0 = Instant::now();
|
||||
let embedding_queue =
|
||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials);
|
||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
|
||||
let _embedding_task = cx.background().spawn({
|
||||
let embedded_files = embedding_queue.finished_files();
|
||||
let db = db.clone();
|
||||
|
@ -413,8 +403,6 @@ impl SemanticIndex {
|
|||
_embedding_task,
|
||||
_parsing_files_tasks,
|
||||
projects: Default::default(),
|
||||
provider_credential: ProviderCredential::NoCredentials,
|
||||
embedding_queue
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
@ -729,14 +717,13 @@ impl SemanticIndex {
|
|||
|
||||
let index = self.index_project(project.clone(), cx);
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let credential = self.provider_credential.clone();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
index.await?;
|
||||
let t0 = Instant::now();
|
||||
|
||||
let query = embedding_provider
|
||||
.embed_batch(vec![query], credential)
|
||||
.embed_batch(vec![query])
|
||||
.await?
|
||||
.pop()
|
||||
.ok_or_else(|| anyhow!("could not embed query"))?;
|
||||
|
@ -954,7 +941,6 @@ impl SemanticIndex {
|
|||
let fs = self.fs.clone();
|
||||
let db_path = self.db.path().clone();
|
||||
let background = cx.background().clone();
|
||||
let credential = self.provider_credential.clone();
|
||||
cx.background().spawn(async move {
|
||||
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
|
||||
let mut results = Vec::<SearchResult>::new();
|
||||
|
@ -969,15 +955,10 @@ impl SemanticIndex {
|
|||
.parse_file_with_template(None, &snapshot.text(), language)
|
||||
.log_err()
|
||||
.unwrap_or_default();
|
||||
if Self::embed_spans(
|
||||
&mut spans,
|
||||
embedding_provider.as_ref(),
|
||||
&db,
|
||||
credential.clone(),
|
||||
)
|
||||
.await
|
||||
.log_err()
|
||||
.is_some()
|
||||
if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
|
||||
.await
|
||||
.log_err()
|
||||
.is_some()
|
||||
{
|
||||
for span in spans {
|
||||
let similarity = span.embedding.unwrap().similarity(&query);
|
||||
|
@ -1201,7 +1182,6 @@ impl SemanticIndex {
|
|||
spans: &mut [Span],
|
||||
embedding_provider: &dyn EmbeddingProvider,
|
||||
db: &VectorDatabase,
|
||||
credential: ProviderCredential,
|
||||
) -> Result<()> {
|
||||
let mut batch = Vec::new();
|
||||
let mut batch_tokens = 0;
|
||||
|
@ -1224,7 +1204,7 @@ impl SemanticIndex {
|
|||
|
||||
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
|
||||
let batch_embeddings = embedding_provider
|
||||
.embed_batch(mem::take(&mut batch), credential.clone())
|
||||
.embed_batch(mem::take(&mut batch))
|
||||
.await?;
|
||||
embeddings.extend(batch_embeddings);
|
||||
batch_tokens = 0;
|
||||
|
@ -1236,7 +1216,7 @@ impl SemanticIndex {
|
|||
|
||||
if !batch.is_empty() {
|
||||
let batch_embeddings = embedding_provider
|
||||
.embed_batch(mem::take(&mut batch), credential)
|
||||
.embed_batch(mem::take(&mut batch))
|
||||
.await?;
|
||||
|
||||
embeddings.extend(batch_embeddings);
|
||||
|
|
|
@ -220,11 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
|
|||
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
|
||||
let mut queue = EmbeddingQueue::new(
|
||||
embedding_provider.clone(),
|
||||
cx.background(),
|
||||
ai::auth::ProviderCredential::NoCredentials,
|
||||
);
|
||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
|
||||
for file in &files {
|
||||
queue.push(file.clone());
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue