batch search queries in the vector database
This commit is contained in:
parent
6cd10f3d5e
commit
98fde36834
3 changed files with 106 additions and 45 deletions
|
@ -267,41 +267,56 @@ impl VectorDatabase {
|
|||
|
||||
pub fn top_k_search(
|
||||
&self,
|
||||
worktree_ids: &[i64],
|
||||
query_embedding: &Vec<f32>,
|
||||
limit: usize,
|
||||
include_globs: Vec<GlobMatcher>,
|
||||
exclude_globs: Vec<GlobMatcher>,
|
||||
) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
|
||||
file_ids: &[i64],
|
||||
) -> Result<Vec<(i64, f32)>> {
|
||||
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
|
||||
self.for_each_document(
|
||||
&worktree_ids,
|
||||
include_globs,
|
||||
exclude_globs,
|
||||
|id, embedding| {
|
||||
let similarity = dot(&embedding, &query_embedding);
|
||||
let ix = match results.binary_search_by(|(_, s)| {
|
||||
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
|
||||
}) {
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
results.insert(ix, (id, similarity));
|
||||
results.truncate(limit);
|
||||
},
|
||||
)?;
|
||||
self.for_each_document(file_ids, |id, embedding| {
|
||||
let similarity = dot(&embedding, &query_embedding);
|
||||
let ix = match results
|
||||
.binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
|
||||
{
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
results.insert(ix, (id, similarity));
|
||||
results.truncate(limit);
|
||||
})?;
|
||||
|
||||
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
|
||||
self.get_documents_by_ids(&ids)
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn for_each_document(
|
||||
// pub fn top_k_search(
|
||||
// &self,
|
||||
// worktree_ids: &[i64],
|
||||
// query_embedding: &Vec<f32>,
|
||||
// limit: usize,
|
||||
// file_ids: Vec<i64>,
|
||||
// ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
|
||||
// let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
|
||||
// self.for_each_document(&worktree_ids, file_ids, |id, embedding| {
|
||||
// let similarity = dot(&embedding, &query_embedding);
|
||||
// let ix = match results
|
||||
// .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
|
||||
// {
|
||||
// Ok(ix) => ix,
|
||||
// Err(ix) => ix,
|
||||
// };
|
||||
// results.insert(ix, (id, similarity));
|
||||
// results.truncate(limit);
|
||||
// })?;
|
||||
|
||||
// let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
|
||||
// self.get_documents_by_ids(&ids)
|
||||
// }
|
||||
|
||||
pub fn retrieve_included_file_ids(
|
||||
&self,
|
||||
worktree_ids: &[i64],
|
||||
include_globs: Vec<GlobMatcher>,
|
||||
exclude_globs: Vec<GlobMatcher>,
|
||||
mut f: impl FnMut(i64, Vec<f32>),
|
||||
) -> Result<()> {
|
||||
) -> Result<Vec<i64>> {
|
||||
let mut file_query = self.db.prepare(
|
||||
"
|
||||
SELECT
|
||||
|
@ -315,6 +330,7 @@ impl VectorDatabase {
|
|||
|
||||
let mut file_ids = Vec::<i64>::new();
|
||||
let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
|
||||
|
||||
while let Some(row) = rows.next()? {
|
||||
let file_id = row.get(0)?;
|
||||
let relative_path = row.get_ref(1)?.as_str()?;
|
||||
|
@ -330,6 +346,10 @@ impl VectorDatabase {
|
|||
}
|
||||
}
|
||||
|
||||
Ok(file_ids)
|
||||
}
|
||||
|
||||
fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
|
||||
let mut query_statement = self.db.prepare(
|
||||
"
|
||||
SELECT
|
||||
|
@ -350,7 +370,7 @@ impl VectorDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
|
||||
pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
|
||||
let mut statement = self.db.prepare(
|
||||
"
|
||||
SELECT
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue