implement new search strategy
This commit is contained in:
parent
0697d08e54
commit
86ec0b1d9f
5 changed files with 227 additions and 50 deletions
|
@ -705,11 +705,13 @@ impl SemanticIndex {
|
|||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
index.await?;
|
||||
let t0 = Instant::now();
|
||||
let query = embedding_provider
|
||||
.embed_batch(vec![query])
|
||||
.await?
|
||||
.pop()
|
||||
.ok_or_else(|| anyhow!("could not embed query"))?;
|
||||
log::trace!("Embedding Search Query: {:?}", t0.elapsed().as_millis());
|
||||
|
||||
let search_start = Instant::now();
|
||||
let modified_buffer_results = this.update(&mut cx, |this, cx| {
|
||||
|
@ -787,10 +789,15 @@ impl SemanticIndex {
|
|||
|
||||
let batch_n = cx.background().num_cpus();
|
||||
let ids_len = file_ids.clone().len();
|
||||
let batch_size = if ids_len <= batch_n {
|
||||
ids_len
|
||||
} else {
|
||||
ids_len / batch_n
|
||||
let minimum_batch_size = 50;
|
||||
|
||||
let batch_size = {
|
||||
let size = ids_len / batch_n;
|
||||
if size < minimum_batch_size {
|
||||
minimum_batch_size
|
||||
} else {
|
||||
size
|
||||
}
|
||||
};
|
||||
|
||||
let mut batch_results = Vec::new();
|
||||
|
@ -813,17 +820,26 @@ impl SemanticIndex {
|
|||
let batch_results = futures::future::join_all(batch_results).await;
|
||||
|
||||
let mut results = Vec::new();
|
||||
let mut min_similarity = None;
|
||||
for batch_result in batch_results {
|
||||
if batch_result.is_ok() {
|
||||
for (id, similarity) in batch_result.unwrap() {
|
||||
if min_similarity.map_or_else(|| false, |min_sim| min_sim > similarity) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let ix = match results
|
||||
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
|
||||
{
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
results.insert(ix, (id, similarity));
|
||||
results.truncate(limit);
|
||||
|
||||
if ix <= limit {
|
||||
min_similarity = Some(similarity);
|
||||
results.insert(ix, (id, similarity));
|
||||
results.truncate(limit);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -856,7 +872,6 @@ impl SemanticIndex {
|
|||
})?;
|
||||
|
||||
let buffers = futures::future::join_all(tasks).await;
|
||||
|
||||
Ok(buffers
|
||||
.into_iter()
|
||||
.zip(ranges)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue