implement new search strategy (#3029)

Augment the current search strategy in semantic search, reducing search times by ~60%.

Release Notes:
- Implemented minimum batch sizes for concurrent database reads.
- Batched the embedding similarity scoring as a matrix multiplication.
- Calculate the matmul with ndarray.

Commit: edf29aa67d
5 changed files with 128 additions and 47 deletions
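Before the per-file diffs, a note on the core idea. Rather than scoring each span's embedding against the query one dot product at a time (the old binary-search insertion loop in db.rs below), the new strategy stacks the embeddings into an (n x d) ndarray matrix and scores all n rows with a single matrix-vector product. The following is a minimal, self-contained sketch of that idea, not the commit's code; the names are hypothetical, and embeddings are assumed unit-normalized so the dot product is a cosine similarity:

use ndarray::{Array1, Array2};

// Hypothetical helper: score n embeddings against one query in a single matmul.
fn score_batch(embeddings: Vec<Vec<f32>>, query: Vec<f32>) -> Vec<(usize, f32)> {
    let n = embeddings.len();
    let d = query.len();
    // Flatten row-major so that row i of the matrix is embedding i.
    let flat: Vec<f32> = embeddings.into_iter().flatten().collect();
    let matrix = Array2::from_shape_vec((n, d), flat).expect("every row must have length d");
    let query = Array1::from_vec(query);
    // One matrix-vector product replaces n separate dot products.
    let scores = matrix.dot(&query);
    let mut ranked: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
    // Sort descending by similarity.
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    ranked
}

fn main() {
    let embeddings = vec![vec![1.0, 0.0], vec![0.6, 0.8]];
    let ranked = score_batch(embeddings, vec![1.0, 0.0]);
    assert_eq!(ranked[0].0, 0); // the first embedding matches the query best
    println!("{:?}", ranked);
}

One dot call keeps the inner loop over contiguous memory instead of paying per-row closure and binary-search insertion costs on every span.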
Cargo.lock (generated, 25 changes)
@@ -4580,6 +4580,19 @@ dependencies = [
  "tempfile",
 ]
 
+[[package]]
+name = "ndarray"
+version = "0.15.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
+dependencies = [
+ "matrixmultiply",
+ "num-complex 0.4.4",
+ "num-integer",
+ "num-traits",
+ "rawpointer",
+]
+
 [[package]]
 name = "ndk"
 version = "0.7.0"
@@ -4706,7 +4719,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
 dependencies = [
  "num-bigint 0.2.6",
- "num-complex",
+ "num-complex 0.2.4",
  "num-integer",
  "num-iter",
  "num-rational 0.2.4",
@@ -4762,6 +4775,15 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "num-complex"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "num-derive"
 version = "0.3.3"
@@ -6751,6 +6773,7 @@ dependencies = [
  "language",
  "lazy_static",
  "log",
+ "ndarray",
  "node_runtime",
  "ordered-float",
  "parking_lot 0.11.2",
Cargo.toml (semantic_index crate)

@@ -39,6 +39,7 @@ rand.workspace = true
 schemars.workspace = true
 globset.workspace = true
 sha1 = "0.10.5"
+ndarray = { version = "0.15.0" }
 
 [dev-dependencies]
 collections = { path = "../collections", features = ["test-support"] }
src/db.rs (VectorDatabase)

@@ -7,13 +7,13 @@ use anyhow::{anyhow, Context, Result};
 use collections::HashMap;
 use futures::channel::oneshot;
 use gpui::executor;
+use ndarray::{Array1, Array2};
 use ordered_float::OrderedFloat;
 use project::{search::PathMatcher, Fs};
 use rpc::proto::Timestamp;
 use rusqlite::params;
 use rusqlite::types::Value;
 use std::{
-    cmp::Reverse,
     future::Future,
     ops::Range,
     path::{Path, PathBuf},
@@ -23,6 +23,13 @@ use std::{
 };
 use util::TryFutureExt;
 
+pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
+    let mut indices = (0..data.len()).collect::<Vec<_>>();
+    indices.sort_by_key(|&i| &data[i]);
+    indices.reverse();
+    indices
+}
+
 #[derive(Debug)]
 pub struct FileRecord {
     pub id: usize,
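A quick illustration of the new argsort helper's contract, with hypothetical values (plain integers stand in for the OrderedFloat<f32> scores it is applied to below): the returned indices order the input from largest to smallest.

fn main() {
    let data = vec![10, 30, 20];
    // Ascending by value the indices are [0, 2, 1]; reversing them
    // yields a descending ranking: data[1] >= data[2] >= data[0].
    assert_eq!(argsort(&data), vec![1, 2, 0]);
}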
@@ -409,23 +416,91 @@ impl VectorDatabase {
         limit: usize,
         file_ids: &[i64],
     ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
-        let query_embedding = query_embedding.clone();
         let file_ids = file_ids.to_vec();
+        let query = query_embedding.clone().0;
+        let query = Array1::from_vec(query);
         self.transact(move |db| {
-            let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
-            Self::for_each_span(db, &file_ids, |id, embedding| {
-                let similarity = embedding.similarity(&query_embedding);
-                let ix = match results
-                    .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
-                {
-                    Ok(ix) => ix,
-                    Err(ix) => ix,
-                };
-                results.insert(ix, (id, similarity));
-                results.truncate(limit);
-            })?;
-
-            anyhow::Ok(results)
+            let mut query_statement = db.prepare(
+                "
+                SELECT
+                    id, embedding
+                FROM
+                    spans
+                WHERE
+                    file_id IN rarray(?)
+                ",
+            )?;
+
+            let deserialized_rows = query_statement
+                .query_map(params![ids_to_sql(&file_ids)], |row| {
+                    Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
+                })?
+                .filter_map(|row| row.ok())
+                .collect::<Vec<(usize, Embedding)>>();
+
+            if deserialized_rows.len() == 0 {
+                return Ok(Vec::new());
+            }
+
+            // Get Length of Embeddings Returned
+            let embedding_len = deserialized_rows[0].1 .0.len();
+
+            let batch_n = 1000;
+            let mut batches = Vec::new();
+            let mut batch_ids = Vec::new();
+            let mut batch_embeddings: Vec<f32> = Vec::new();
+            deserialized_rows.iter().for_each(|(id, embedding)| {
+                batch_ids.push(id);
+                batch_embeddings.extend(&embedding.0);
+
+                if batch_ids.len() == batch_n {
+                    let embeddings = std::mem::take(&mut batch_embeddings);
+                    let ids = std::mem::take(&mut batch_ids);
+                    let array =
+                        Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings);
+                    match array {
+                        Ok(array) => {
+                            batches.push((ids, array));
+                        }
+                        Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
+                    }
+                }
+            });
+
+            if batch_ids.len() > 0 {
+                let array = Array2::from_shape_vec(
+                    (batch_ids.len(), embedding_len),
+                    batch_embeddings.clone(),
+                );
+                match array {
+                    Ok(array) => {
+                        batches.push((batch_ids.clone(), array));
+                    }
+                    Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
+                }
+            }
+
+            let mut ids: Vec<usize> = Vec::new();
+            let mut results = Vec::new();
+            for (batch_ids, array) in batches {
+                let scores = array
+                    .dot(&query.t())
+                    .to_vec()
+                    .iter()
+                    .map(|score| OrderedFloat(*score))
+                    .collect::<Vec<OrderedFloat<f32>>>();
+                results.extend(scores);
+                ids.extend(batch_ids);
+            }
+
+            let sorted_idx = argsort(&results);
+            let mut sorted_results = Vec::new();
+            let last_idx = limit.min(sorted_idx.len());
+            for idx in &sorted_idx[0..last_idx] {
+                sorted_results.push((ids[*idx] as i64, results[*idx]))
+            }
+
+            Ok(sorted_results)
         })
     }
 
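Two details of the hunk above are worth spelling out. First, Array2::from_shape_vec interprets its flat buffer in row-major order, which is exactly the layout that appending each embedding onto batch_embeddings produces. A small standalone sketch of that contract, with made-up numbers:

use ndarray::Array2;

fn main() {
    // Two embeddings of dimension 3, appended back to back:
    let flat = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    // Row-major interpretation turns each appended embedding into one row
    // of the resulting (2 x 3) matrix.
    let m = Array2::from_shape_vec((2, 3), flat).unwrap();
    assert_eq!(m.row(0).to_vec(), vec![1.0, 2.0, 3.0]);
    assert_eq!(m.row(1).to_vec(), vec![4.0, 5.0, 6.0]);
}

Second, the fixed batch_n of 1000 rows bounds the size of any single matrix, so peak memory stays flat no matter how many spans match the file filter.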
@@ -468,31 +543,6 @@ impl VectorDatabase {
         })
     }
 
-    fn for_each_span(
-        db: &rusqlite::Connection,
-        file_ids: &[i64],
-        mut f: impl FnMut(i64, Embedding),
-    ) -> Result<()> {
-        let mut query_statement = db.prepare(
-            "
-            SELECT
-                id, embedding
-            FROM
-                spans
-            WHERE
-                file_id IN rarray(?)
-            ",
-        )?;
-
-        query_statement
-            .query_map(params![ids_to_sql(&file_ids)], |row| {
-                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
-            })?
-            .filter_map(|row| row.ok())
-            .for_each(|(id, embedding)| f(id, embedding));
-        Ok(())
-    }
-
     pub fn spans_for_ids(
         &self,
         ids: &[i64],

src/semantic_index.rs (SemanticIndex)
@@ -705,11 +705,13 @@ impl SemanticIndex {
 
         cx.spawn(|this, mut cx| async move {
             index.await?;
+            let t0 = Instant::now();
             let query = embedding_provider
                 .embed_batch(vec![query])
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;
+            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
 
             let search_start = Instant::now();
             let modified_buffer_results = this.update(&mut cx, |this, cx| {
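The t0/log::trace! pair above is plain std::time::Instant bracketing. A minimal standalone sketch of the pattern, where env_logger stands in for Zed's own logger wiring and the timed function is hypothetical:

use std::time::Instant;

fn embed_query() {
    std::thread::sleep(std::time::Duration::from_millis(5)); // stand-in work
}

fn main() {
    env_logger::init(); // honors RUST_LOG, e.g. RUST_LOG=trace
    let t0 = Instant::now();
    embed_query();
    log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
}

The eval script change at the bottom of this commit (RUST_LOG=semantic_index=trace) is what makes these trace lines visible during evaluation runs.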
@@ -787,10 +789,15 @@ impl SemanticIndex {
 
         let batch_n = cx.background().num_cpus();
         let ids_len = file_ids.clone().len();
-        let batch_size = if ids_len <= batch_n {
-            ids_len
-        } else {
-            ids_len / batch_n
+        let minimum_batch_size = 50;
+
+        let batch_size = {
+            let size = ids_len / batch_n;
+            if size < minimum_batch_size {
+                minimum_batch_size
+            } else {
+                size
+            }
         };
 
         let mut batch_results = Vec::new();
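The new expression clamps each concurrent reader's batch to a floor of 50 ids, so spinning up num_cpus readers never degenerates into many tiny database queries. It is equivalent to a max(); a sketch with hypothetical numbers:

fn batch_size(ids_len: usize, batch_n: usize) -> usize {
    const MINIMUM_BATCH_SIZE: usize = 50;
    // Same result as the if/else in the diff above.
    (ids_len / batch_n).max(MINIMUM_BATCH_SIZE)
}

fn main() {
    assert_eq!(batch_size(1_000, 8), 125); // plenty of ids: split across CPUs
    assert_eq!(batch_size(200, 8), 50);    // few ids: the floor wins, fewer readers
}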
@@ -822,6 +829,7 @@ impl SemanticIndex {
                     Ok(ix) => ix,
                     Err(ix) => ix,
                 };
+
                 results.insert(ix, (id, similarity));
                 results.truncate(limit);
             }
@@ -856,7 +864,6 @@ impl SemanticIndex {
             })?;
-
             let buffers = futures::future::join_all(tasks).await;
 
             Ok(buffers
                 .into_iter()
                 .zip(ranges)

eval script (bash)
@@ -1,3 +1,3 @@
 #!/bin/bash
 
-cargo run -p semantic_index --example eval
+RUST_LOG=semantic_index=trace cargo run -p semantic_index --example eval --release