implement new search strategy

KCaverly 2023-09-25 13:44:19 -04:00
parent 0697d08e54
commit 86ec0b1d9f
5 changed files with 227 additions and 50 deletions


@@ -39,6 +39,8 @@ rand.workspace = true
schemars.workspace = true
globset.workspace = true
sha1 = "0.10.5"
ndarray = { version = "0.15.0", features = ["blas"] }
blas-src = { version = "0.8", features = ["openblas"] }
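# ndarray's "blas" feature routes Array2::dot through the BLAS backend
# supplied by blas-src (OpenBLAS here), which is what makes the batched
# similarity matmul in the new search path fast.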
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }


@@ -1,3 +1,5 @@
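// blas-src recommends this otherwise-unused `extern crate` so the OpenBLAS
// symbols actually get linked for ndarray's BLAS-backed kernels.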
extern crate blas_src;
use crate::{
parsing::{Span, SpanDigest},
SEMANTIC_INDEX_VERSION,
@@ -7,6 +9,7 @@ use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::channel::oneshot;
use gpui::executor;
use ndarray::{Array1, Array2};
use ordered_float::OrderedFloat;
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
@@ -19,10 +22,16 @@ use std::{
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
time::SystemTime,
time::{Instant, SystemTime},
};
use util::TryFutureExt;
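/// Indices that would sort `data` in descending order, so the best scores
/// come first for top-k selection, e.g.
/// argsort(&[OrderedFloat(0.1), OrderedFloat(0.9), OrderedFloat(0.5)]) == vec![1, 2, 0].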
pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
    let mut indices = (0..data.len()).collect::<Vec<_>>();
    indices.sort_by_key(|&i| &data[i]);
    // Reverse so the highest values come first; the caller keeps only the
    // top `limit` indices.
    indices.reverse();
    indices
}
#[derive(Debug)]
pub struct FileRecord {
pub id: usize,
@@ -409,23 +418,82 @@ impl VectorDatabase {
limit: usize,
file_ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
let query_embedding = query_embedding.clone();
let file_ids = file_ids.to_vec();
let query = query_embedding.clone().0;
let query = Array1::from_vec(query);
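// Materialize the query once as an ndarray vector so every batch can be
// scored against it with a single BLAS dot product.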
self.transact(move |db| {
let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
Self::for_each_span(db, &file_ids, |id, embedding| {
let similarity = embedding.similarity(&query_embedding);
let ix = match results
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
{
Ok(ix) => ix,
Err(ix) => ix,
};
results.insert(ix, (id, similarity));
results.truncate(limit);
})?;
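// rusqlite's `array` feature supplies the `rarray` table-valued function,
// letting the bound list of file ids act as the right-hand side of IN.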
let mut query_statement = db.prepare(
"
SELECT
id, embedding
FROM
spans
WHERE
file_id IN rarray(?)
",
)?;
anyhow::Ok(results)
let deserialized_rows = query_statement
.query_map(params![ids_to_sql(&file_ids)], |row| {
Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
.collect::<Vec<(usize, Embedding)>>();
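// Pack rows into flat buffers and reshape each chunk of up to 250 spans
// into a (rows x 1536) matrix; 1536 matches the stored embedding width, and
// scoring a whole batch then becomes one matrix-vector product.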
let batch_n = 250;
let mut batches = Vec::new();
let mut batch_ids = Vec::new();
let mut batch_embeddings: Vec<f32> = Vec::new();
deserialized_rows.iter().for_each(|(id, embedding)| {
batch_ids.push(id);
batch_embeddings.extend(&embedding.0);
if batch_ids.len() == batch_n {
let array =
Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
match array {
Ok(array) => {
batches.push((batch_ids.clone(), array));
}
Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
}
batch_ids = Vec::new();
batch_embeddings = Vec::new();
}
});
if !batch_ids.is_empty() {
let array =
Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
match array {
Ok(array) => {
batches.push((batch_ids.clone(), array));
}
Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
}
}
let mut ids: Vec<usize> = Vec::new();
let mut results = Vec::new();
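// One dot product per batch scores every row against the query at once;
// for normalized embeddings this equals cosine similarity.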
for (batch_ids, array) in batches {
let scores = array
.dot(&query.t())
.to_vec()
.iter()
.map(|score| OrderedFloat(*score))
.collect::<Vec<OrderedFloat<f32>>>();
results.extend(scores);
ids.extend(batch_ids);
}
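// argsort puts the highest scores first, so the first `limit` indices are
// the global top-k across all batches.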
let sorted_idx = argsort(&results);
let mut sorted_results = Vec::new();
let last_idx = limit.min(sorted_idx.len());
for idx in &sorted_idx[0..last_idx] {
sorted_results.push((ids[*idx] as i64, results[*idx]));
}
Ok(sorted_results)
})
}
@@ -468,31 +536,6 @@ impl VectorDatabase {
})
}
fn for_each_span(
db: &rusqlite::Connection,
file_ids: &[i64],
mut f: impl FnMut(i64, Embedding),
) -> Result<()> {
let mut query_statement = db.prepare(
"
SELECT
id, embedding
FROM
spans
WHERE
file_id IN rarray(?)
",
)?;
query_statement
.query_map(params![ids_to_sql(&file_ids)], |row| {
Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
.for_each(|(id, embedding)| f(id, embedding));
Ok(())
}
pub fn spans_for_ids(
&self,
ids: &[i64],


@@ -705,11 +705,13 @@ impl SemanticIndex {
cx.spawn(|this, mut cx| async move {
index.await?;
let t0 = Instant::now();
let query = embedding_provider
.embed_batch(vec![query])
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
log::trace!("Embedding Search Query: {:?}", t0.elapsed().as_millis());
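// Time the query-embedding round-trip to the provider separately from the
// local similarity search that follows.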
let search_start = Instant::now();
let modified_buffer_results = this.update(&mut cx, |this, cx| {
@@ -787,10 +789,15 @@ impl SemanticIndex {
let batch_n = cx.background().num_cpus();
let ids_len = file_ids.len();
let batch_size = if ids_len <= batch_n {
ids_len
} else {
ids_len / batch_n
let minimum_batch_size = 50;
let batch_size = {
let size = ids_len / batch_n;
if size < minimum_batch_size {
minimum_batch_size
} else {
size
}
};
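// e.g. 1,000 file ids across 8 CPUs yields batches of 125, while 200 ids
// would yield 25 and gets clamped up to the 50-id floor, keeping per-task
// overhead from dominating small searches.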
let mut batch_results = Vec::new();
@@ -813,17 +820,26 @@ impl SemanticIndex {
let batch_results = futures::future::join_all(batch_results).await;
let mut results = Vec::new();
let mut min_similarity = None;
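// Lowest score retained so far; once `results` is full, any candidate
// below it can be skipped before the binary search.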
for batch_result in batch_results {
if let Ok(batch_result) = batch_result {
    for (id, similarity) in batch_result {
if min_similarity.map_or(false, |min_sim| min_sim > similarity) {
continue;
}
let ix = match results
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
{
Ok(ix) => ix,
Err(ix) => ix,
};
results.insert(ix, (id, similarity));
results.truncate(limit);
if ix < limit {
    results.insert(ix, (id, similarity));
    results.truncate(limit);
    // A pruning threshold is only meaningful once `results` is full;
    // its last element is then the smallest retained score.
    if results.len() == limit {
        min_similarity = results.last().map(|(_, s)| *s);
    }
}
}
}
}
@@ -856,7 +872,6 @@ impl SemanticIndex {
})?;
let buffers = futures::future::join_all(tasks).await;
Ok(buffers
.into_iter()
.zip(ranges)