moved authentication for the semantic index into the EmbeddingProvider

This commit is contained in:
KCaverly 2023-10-30 10:02:27 -04:00
parent 1e8b23d8fb
commit a2c3971ad6
14 changed files with 200 additions and 206 deletions

View file

@ -1,5 +1,5 @@
use crate::{parsing::Span, JobHandle};
use ai::{auth::ProviderCredential, embedding::EmbeddingProvider};
use ai::embedding::EmbeddingProvider;
use gpui::executor::Background;
use parking_lot::Mutex;
use smol::channel;
@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
pub provider_credential: ProviderCredential,
}
#[derive(Clone)]
@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
}
impl EmbeddingQueue {
pub fn new(
embedding_provider: Arc<dyn EmbeddingProvider>,
executor: Arc<Background>,
provider_credential: ProviderCredential,
) -> Self {
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,
@ -64,14 +59,9 @@ impl EmbeddingQueue {
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
provider_credential,
}
}
pub fn set_credential(&mut self, credential: ProviderCredential) {
self.provider_credential = credential;
}
pub fn push(&mut self, file: FileToEmbed) {
if file.spans.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
@ -118,7 +108,6 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
let credential = self.provider_credential.clone();
self.executor
.spawn(async move {
@ -143,7 +132,7 @@ impl EmbeddingQueue {
return;
};
match embedding_provider.embed_batch(spans, credential).await {
match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {

View file

@ -7,7 +7,6 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
use ai::auth::ProviderCredential;
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
@ -125,8 +124,6 @@ pub struct SemanticIndex {
_embedding_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
provider_credential: ProviderCredential,
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
}
struct ProjectState {
@ -281,24 +278,17 @@ impl SemanticIndex {
}
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
let existing_credential = self.provider_credential.clone();
let credential = match existing_credential {
ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx),
_ => existing_credential,
};
if !self.embedding_provider.has_credentials() {
self.embedding_provider.retrieve_credentials(cx);
} else {
return true;
}
self.provider_credential = credential.clone();
self.embedding_queue.lock().set_credential(credential);
self.is_authenticated()
self.embedding_provider.has_credentials()
}
pub fn is_authenticated(&self) -> bool {
let credential = &self.provider_credential;
match credential {
&ProviderCredential::Credentials { .. } => true,
&ProviderCredential::NotNeeded => true,
_ => false,
}
self.embedding_provider.has_credentials()
}
pub fn enabled(cx: &AppContext) -> bool {
@ -348,7 +338,7 @@ impl SemanticIndex {
Ok(cx.add_model(|cx| {
let t0 = Instant::now();
let embedding_queue =
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials);
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
let _embedding_task = cx.background().spawn({
let embedded_files = embedding_queue.finished_files();
let db = db.clone();
@ -413,8 +403,6 @@ impl SemanticIndex {
_embedding_task,
_parsing_files_tasks,
projects: Default::default(),
provider_credential: ProviderCredential::NoCredentials,
embedding_queue
}
}))
}
@ -729,14 +717,13 @@ impl SemanticIndex {
let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone();
let credential = self.provider_credential.clone();
cx.spawn(|this, mut cx| async move {
index.await?;
let t0 = Instant::now();
let query = embedding_provider
.embed_batch(vec![query], credential)
.embed_batch(vec![query])
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
@ -954,7 +941,6 @@ impl SemanticIndex {
let fs = self.fs.clone();
let db_path = self.db.path().clone();
let background = cx.background().clone();
let credential = self.provider_credential.clone();
cx.background().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new();
@ -969,15 +955,10 @@ impl SemanticIndex {
.parse_file_with_template(None, &snapshot.text(), language)
.log_err()
.unwrap_or_default();
if Self::embed_spans(
&mut spans,
embedding_provider.as_ref(),
&db,
credential.clone(),
)
.await
.log_err()
.is_some()
if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
.await
.log_err()
.is_some()
{
for span in spans {
let similarity = span.embedding.unwrap().similarity(&query);
@ -1201,7 +1182,6 @@ impl SemanticIndex {
spans: &mut [Span],
embedding_provider: &dyn EmbeddingProvider,
db: &VectorDatabase,
credential: ProviderCredential,
) -> Result<()> {
let mut batch = Vec::new();
let mut batch_tokens = 0;
@ -1224,7 +1204,7 @@ impl SemanticIndex {
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch), credential.clone())
.embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
batch_tokens = 0;
@ -1236,7 +1216,7 @@ impl SemanticIndex {
if !batch.is_empty() {
let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch), credential)
.embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);

View file

@ -220,11 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut queue = EmbeddingQueue::new(
embedding_provider.clone(),
cx.background(),
ai::auth::ProviderCredential::NoCredentials,
);
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
for file in &files {
queue.push(file.clone());
}