diff --git a/assets/keymaps/default.json b/assets/keymaps/default.json index 7553c19925..a2be44cbce 100644 --- a/assets/keymaps/default.json +++ b/assets/keymaps/default.json @@ -411,7 +411,6 @@ "cmd-k cmd-t": "theme_selector::Toggle", "cmd-k cmd-s": "zed::OpenKeymap", "cmd-t": "project_symbols::Toggle", - "cmd-ctrl-t": "semantic_search::Toggle", "cmd-p": "file_finder::Toggle", "cmd-shift-p": "command_palette::Toggle", "cmd-shift-m": "diagnostics::Deploy", diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 4bc97da0f0..d180f5e831 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -267,41 +267,32 @@ impl VectorDatabase { pub fn top_k_search( &self, - worktree_ids: &[i64], query_embedding: &Vec<f32>, limit: usize, - include_globs: Vec<GlobMatcher>, - exclude_globs: Vec<GlobMatcher>, - ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> { + file_ids: &[i64], + ) -> Result<Vec<(i64, f32)>> { let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); - self.for_each_document( - &worktree_ids, - include_globs, - exclude_globs, - |id, embedding| { - let similarity = dot(&embedding, &query_embedding); - let ix = match results.binary_search_by(|(_, s)| { - similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) - }) { - Ok(ix) => ix, - Err(ix) => ix, - }; - results.insert(ix, (id, similarity)); - results.truncate(limit); - }, - )?; + self.for_each_document(file_ids, |id, embedding| { + let similarity = dot(&embedding, &query_embedding); + let ix = match results + .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)) + { + Ok(ix) => ix, + Err(ix) => ix, + }; + results.insert(ix, (id, similarity)); + results.truncate(limit); + })?; - let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>(); - self.get_documents_by_ids(&ids) + Ok(results) } - fn for_each_document( + pub fn retrieve_included_file_ids( &self, worktree_ids: &[i64], include_globs: Vec<GlobMatcher>, exclude_globs: Vec<GlobMatcher>, - mut f: impl FnMut(i64, Vec<f32>), - ) -> Result<()> { + ) -> Result<Vec<i64>> { let mut file_query = 
self.db.prepare( " SELECT @@ -315,6 +306,7 @@ impl VectorDatabase { let mut file_ids = Vec::<i64>::new(); let mut rows = file_query.query([ids_to_sql(worktree_ids)])?; + while let Some(row) = rows.next()? { let file_id = row.get(0)?; let relative_path = row.get_ref(1)?.as_str()?; @@ -330,6 +322,10 @@ impl VectorDatabase { } } + Ok(file_ids) + } + + fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> { let mut query_statement = self.db.prepare( " SELECT @@ -350,7 +346,7 @@ impl VectorDatabase { Ok(()) } - fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> { + pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> { let mut statement = self.db.prepare( " SELECT diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index e4a307573a..bd114de216 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -20,6 +20,7 @@ use postage::watch; use project::{Fs, Project, WorktreeId}; use smol::channel; use std::{ + cmp::Ordering, collections::HashMap, mem, ops::Range, @@ -704,27 +705,69 @@ impl SemanticIndex { let database_url = self.database_url.clone(); let fs = self.fs.clone(); cx.spawn(|this, mut cx| async move { - let documents = cx - .background() - .spawn(async move { - let database = VectorDatabase::new(fs, database_url).await?; + let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?; - let phrase_embedding = embedding_provider - .embed_batch(vec![&phrase]) - .await? - .into_iter() - .next() - .unwrap(); + let phrase_embedding = embedding_provider + .embed_batch(vec![&phrase]) + .await? 
+ .into_iter() + .next() + .unwrap(); - database.top_k_search( - &worktree_db_ids, - &phrase_embedding, - limit, - include_globs, - exclude_globs, - ) - }) - .await?; + let file_ids = database.retrieve_included_file_ids( + &worktree_db_ids, + include_globs, + exclude_globs, + )?; + + let batch_n = cx.background().num_cpus(); + let ids_len = file_ids.clone().len(); + let batch_size = if ids_len <= batch_n { + ids_len + } else { + ids_len / batch_n + }; + + let mut result_tasks = Vec::new(); + for batch in file_ids.chunks(batch_size) { + let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>(); + let limit = limit.clone(); + let fs = fs.clone(); + let database_url = database_url.clone(); + let phrase_embedding = phrase_embedding.clone(); + let task = cx.background().spawn(async move { + let database = VectorDatabase::new(fs, database_url).await.log_err(); + if database.is_none() { + return Err(anyhow!("failed to acquire database connection")); + } else { + database + .unwrap() + .top_k_search(&phrase_embedding, limit, batch.as_slice()) + } + }); + result_tasks.push(task); + } + + let batch_results = futures::future::join_all(result_tasks).await; + + let mut results = Vec::new(); + for batch_result in batch_results { + if batch_result.is_ok() { + for (id, similarity) in batch_result.unwrap() { + let ix = match results.binary_search_by(|(_, s)| { + similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) + }) { + Ok(ix) => ix, + Err(ix) => ix, + }; + results.insert(ix, (id, similarity)); + results.truncate(limit); + } + } + } + + let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>(); + let documents = database.get_documents_by_ids(ids.as_slice())?; let mut tasks = Vec::new(); let mut ranges = Vec::new();