diff --git a/Cargo.lock b/Cargo.lock index e53237b8c3..3342bf39b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4580,6 +4580,19 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex 0.4.4", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "ndk" version = "0.7.0" @@ -4706,7 +4719,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36" dependencies = [ "num-bigint 0.2.6", - "num-complex", + "num-complex 0.2.4", "num-integer", "num-iter", "num-rational 0.2.4", @@ -4762,6 +4775,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "num-traits", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -6751,6 +6773,7 @@ dependencies = [ "language", "lazy_static", "log", + "ndarray", "node_runtime", "ordered-float", "parking_lot 0.11.2", diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index e38ae1f06d..efda311633 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -39,6 +39,7 @@ rand.workspace = true schemars.workspace = true globset.workspace = true sha1 = "0.10.5" +ndarray = { version = "0.15.0" } [dev-dependencies] collections = { path = "../collections", features = ["test-support"] } diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs index 8280dc7d65..63527cea1c 100644 --- a/crates/semantic_index/src/db.rs +++ b/crates/semantic_index/src/db.rs @@ -7,13 +7,13 @@ use anyhow::{anyhow, Context, Result}; use collections::HashMap; use futures::channel::oneshot; use gpui::executor; +use ndarray::{Array1, Array2}; use ordered_float::OrderedFloat; use project::{search::PathMatcher, Fs}; use rpc::proto::Timestamp; use rusqlite::params; use rusqlite::types::Value; use std::{ - cmp::Reverse, future::Future, ops::Range, path::{Path, PathBuf}, @@ -23,6 +23,13 @@ use std::{ }; use util::TryFutureExt; +pub fn argsort(data: &[T]) -> Vec { + let mut indices = (0..data.len()).collect::>(); + indices.sort_by_key(|&i| &data[i]); + indices.reverse(); + indices +} + #[derive(Debug)] pub struct FileRecord { pub id: usize, @@ -409,23 +416,91 @@ impl VectorDatabase { limit: usize, file_ids: &[i64], ) -> impl Future)>>> { - let query_embedding = query_embedding.clone(); let file_ids = file_ids.to_vec(); + let query = query_embedding.clone().0; + let query = Array1::from_vec(query); self.transact(move |db| { - let mut results = Vec::<(i64, OrderedFloat)>::with_capacity(limit + 1); - Self::for_each_span(db, &file_ids, |id, embedding| { - let similarity = embedding.similarity(&query_embedding); - let ix = match results - .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s)) - { - Ok(ix) => ix, - Err(ix) => ix, - }; - results.insert(ix, (id, similarity)); - results.truncate(limit); - })?; + let mut query_statement = db.prepare( + " + SELECT + id, embedding + FROM + spans + WHERE + file_id IN rarray(?) + ", + )?; - anyhow::Ok(results) + let deserialized_rows = query_statement + .query_map(params![ids_to_sql(&file_ids)], |row| { + Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?)) + })? + .filter_map(|row| row.ok()) + .collect::>(); + + if deserialized_rows.len() == 0 { + return Ok(Vec::new()); + } + + // Get Length of Embeddings Returned + let embedding_len = deserialized_rows[0].1 .0.len(); + + let batch_n = 1000; + let mut batches = Vec::new(); + let mut batch_ids = Vec::new(); + let mut batch_embeddings: Vec = Vec::new(); + deserialized_rows.iter().for_each(|(id, embedding)| { + batch_ids.push(id); + batch_embeddings.extend(&embedding.0); + + if batch_ids.len() == batch_n { + let embeddings = std::mem::take(&mut batch_embeddings); + let ids = std::mem::take(&mut batch_ids); + let array = + Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings); + match array { + Ok(array) => { + batches.push((ids, array)); + } + Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err), + } + } + }); + + if batch_ids.len() > 0 { + let array = Array2::from_shape_vec( + (batch_ids.len(), embedding_len), + batch_embeddings.clone(), + ); + match array { + Ok(array) => { + batches.push((batch_ids.clone(), array)); + } + Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err), + } + } + + let mut ids: Vec = Vec::new(); + let mut results = Vec::new(); + for (batch_ids, array) in batches { + let scores = array + .dot(&query.t()) + .to_vec() + .iter() + .map(|score| OrderedFloat(*score)) + .collect::>>(); + results.extend(scores); + ids.extend(batch_ids); + } + + let sorted_idx = argsort(&results); + let mut sorted_results = Vec::new(); + let last_idx = limit.min(sorted_idx.len()); + for idx in &sorted_idx[0..last_idx] { + sorted_results.push((ids[*idx] as i64, results[*idx])) + } + + Ok(sorted_results) }) } @@ -468,31 +543,6 @@ impl VectorDatabase { }) } - fn for_each_span( - db: &rusqlite::Connection, - file_ids: &[i64], - mut f: impl FnMut(i64, Embedding), - ) -> Result<()> { - let mut query_statement = db.prepare( - " - SELECT - id, embedding - FROM - spans - WHERE - file_id IN rarray(?) - ", - )?; - - query_statement - .query_map(params![ids_to_sql(&file_ids)], |row| { - Ok((row.get(0)?, row.get::<_, Embedding>(1)?)) - })? - .filter_map(|row| row.ok()) - .for_each(|(id, embedding)| f(id, embedding)); - Ok(()) - } - pub fn spans_for_ids( &self, ids: &[i64], diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index fd41cb1500..ecdba43643 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -705,11 +705,13 @@ impl SemanticIndex { cx.spawn(|this, mut cx| async move { index.await?; + let t0 = Instant::now(); let query = embedding_provider .embed_batch(vec![query]) .await? .pop() .ok_or_else(|| anyhow!("could not embed query"))?; + log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis()); let search_start = Instant::now(); let modified_buffer_results = this.update(&mut cx, |this, cx| { @@ -787,10 +789,15 @@ impl SemanticIndex { let batch_n = cx.background().num_cpus(); let ids_len = file_ids.clone().len(); - let batch_size = if ids_len <= batch_n { - ids_len - } else { - ids_len / batch_n + let minimum_batch_size = 50; + + let batch_size = { + let size = ids_len / batch_n; + if size < minimum_batch_size { + minimum_batch_size + } else { + size + } }; let mut batch_results = Vec::new(); @@ -822,6 +829,7 @@ impl SemanticIndex { Ok(ix) => ix, Err(ix) => ix, }; + results.insert(ix, (id, similarity)); results.truncate(limit); } @@ -856,7 +864,6 @@ impl SemanticIndex { })?; let buffers = futures::future::join_all(tasks).await; - Ok(buffers .into_iter() .zip(ranges) diff --git a/script/evaluate_semantic_index b/script/evaluate_semantic_index index e9a96a02b4..8dcb53c399 100755 --- a/script/evaluate_semantic_index +++ b/script/evaluate_semantic_index @@ -1,3 +1,3 @@ #!/bin/bash -cargo run -p semantic_index --example eval +RUST_LOG=semantic_index=trace cargo run -p semantic_index --example eval --release