add initial search inside modified buffers
parent f86e5a987f
commit c19c8899fe
3 changed files with 216 additions and 65 deletions
@@ -263,9 +263,11 @@ pub struct PendingFile {
    job_handle: JobHandle,
}

#[derive(Clone)]
pub struct SearchResult {
    pub buffer: ModelHandle<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: f32,
}

impl SemanticIndex {
@@ -775,7 +777,8 @@ impl SemanticIndex {
    .filter_map(|buffer_handle| {
        let buffer = buffer_handle.read(cx);
        if buffer.is_dirty() {
            Some((buffer_handle.downgrade(), buffer.snapshot()))
            // TODO: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly
            Some((buffer_handle, buffer.snapshot()))
        } else {
            None
        }
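The hunk above gathers only the dirty (unsaved) open buffers, pairing each with a snapshot of its current text so the search can run against content the database has not indexed yet. Below is a minimal sketch of the same filter-and-snapshot pattern, using hypothetical stand-in types rather than Zed's ModelHandle<Buffer>:

use std::collections::HashMap;

// Hypothetical stand-ins for the editor's buffer types; not the crate's real API.
#[derive(Clone, Debug)]
struct Snapshot {
    text: String,
}

struct Buffer {
    id: usize,
    dirty: bool,
    text: String,
}

impl Buffer {
    fn is_dirty(&self) -> bool {
        self.dirty
    }
    fn snapshot(&self) -> Snapshot {
        Snapshot { text: self.text.clone() }
    }
}

// Keep only buffers with unsaved edits, keyed by id and paired with a snapshot,
// mirroring the filter_map over open buffers in the hunk above.
fn dirty_snapshots(buffers: &[Buffer]) -> HashMap<usize, Snapshot> {
    buffers
        .iter()
        .filter_map(|buffer| {
            if buffer.is_dirty() {
                Some((buffer.id, buffer.snapshot()))
            } else {
                None
            }
        })
        .collect()
}

fn main() {
    let buffers = vec![
        Buffer { id: 1, dirty: true, text: "fn edited() {}".into() },
        Buffer { id: 2, dirty: false, text: "fn saved() {}".into() },
    ];
    // Only buffer 1 is searched in memory; buffer 2 is covered by the index.
    println!("{:?}", dirty_snapshots(&buffers));
}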
@@ -783,77 +786,133 @@ impl SemanticIndex {
.collect::<HashMap<_, _>>()
});

cx.background()
.spawn({
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
let embedding_provider = embedding_provider.clone();
let phrase_embedding = phrase_embedding.clone();
async move {
let mut results = Vec::new();
'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
let language = buffer_snapshot
.language_at(0)
.cloned()
.unwrap_or_else(|| language::PLAIN_TEXT.clone());
if let Some(spans) = retriever
.parse_file_with_template(None, &buffer_snapshot.text(), language)
.log_err()
{
let mut batch = Vec::new();
let mut batch_tokens = 0;
let mut embeddings = Vec::new();
let buffer_results = if let Some(db) =
VectorDatabase::new(fs, db_path.clone(), cx.background())
.await
.log_err()
{
cx.background()
.spawn({
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
let embedding_provider = embedding_provider.clone();
let phrase_embedding = phrase_embedding.clone();
async move {
let mut results = Vec::<SearchResult>::new();
'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
let language = buffer_snapshot
.language_at(0)
.cloned()
.unwrap_or_else(|| language::PLAIN_TEXT.clone());
if let Some(spans) = retriever
.parse_file_with_template(
None,
&buffer_snapshot.text(),
language,
)
.log_err()
{
let mut batch = Vec::new();
let mut batch_tokens = 0;
let mut embeddings = Vec::new();

// TODO: query span digests in the database to avoid embedding them again.
let digests = spans
.iter()
.map(|span| span.digest.clone())
.collect::<Vec<_>>();
let embeddings_for_digests = db
.embeddings_for_digests(digests)
.await
.map_or(Default::default(), |m| m);

for span in &spans {
if span.embedding.is_some() {
continue;
for span in &spans {
if embeddings_for_digests.contains_key(&span.digest) {
continue;
};

if batch_tokens + span.token_count
> embedding_provider.max_tokens_per_batch()
{
if let Some(batch_embeddings) = embedding_provider
.embed_batch(mem::take(&mut batch))
.await
.log_err()
{
embeddings.extend(batch_embeddings);
batch_tokens = 0;
} else {
continue 'buffers;
}
}

batch_tokens += span.token_count;
batch.push(span.content.clone());
}

if batch_tokens + span.token_count
> embedding_provider.max_tokens_per_batch()
if let Some(batch_embeddings) = embedding_provider
.embed_batch(mem::take(&mut batch))
.await
.log_err()
{
if let Some(batch_embeddings) = embedding_provider
.embed_batch(mem::take(&mut batch))
.await
.log_err()
embeddings.extend(batch_embeddings);
} else {
continue 'buffers;
}

let mut embeddings = embeddings.into_iter();
for span in spans {
let embedding = if let Some(embedding) =
embeddings_for_digests.get(&span.digest)
{
embeddings.extend(batch_embeddings);
batch_tokens = 0;
Some(embedding.clone())
} else {
embeddings.next()
};

if let Some(embedding) = embedding {
let similarity =
embedding.similarity(&phrase_embedding);

let ix = match results.binary_search_by(|s| {
similarity
.partial_cmp(&s.similarity)
.unwrap_or(Ordering::Equal)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};

let range = {
let start = buffer_snapshot
.clip_offset(span.range.start, Bias::Left);
let end = buffer_snapshot
.clip_offset(span.range.end, Bias::Right);
buffer_snapshot.anchor_before(start)
..buffer_snapshot.anchor_after(end)
};

results.insert(
ix,
SearchResult {
buffer: buffer_handle.clone(),
range,
similarity,
},
);
results.truncate(limit);
} else {
log::error!("failed to embed span");
continue 'buffers;
}
}

batch_tokens += span.token_count;
batch.push(span.content.clone());
}

if let Some(batch_embeddings) = embedding_provider
.embed_batch(mem::take(&mut batch))
.await
.log_err()
{
embeddings.extend(batch_embeddings);
} else {
continue 'buffers;
}

let mut embeddings = embeddings.into_iter();
for span in spans {
let embedding = span.embedding.or_else(|| embeddings.next());
if let Some(embedding) = embedding {
todo!()
} else {
log::error!("failed to embed span");
continue 'buffers;
}
}
}
anyhow::Ok(results)
}
}
})
.await;
})
.await
} else {
Ok(Vec::new())
};

let batch_results = futures::future::join_all(batch_results).await;
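One technique worth isolating from the new code above is how spans are batched before embed_batch is called: spans accumulate until the next one would exceed the provider's token budget, the pending batch is flushed, and a final flush handles the remainder. The sketch below shows just that pattern under simplified assumptions (a plain Span struct and a fixed max_tokens_per_batch value rather than the crate's EmbeddingProvider):

// Sketch only: batches `spans` so no single request exceeds `max_tokens_per_batch`,
// mirroring the flush-then-accumulate loop in the hunk above.
struct Span {
    content: String,
    token_count: usize,
}

fn batch_spans(spans: &[Span], max_tokens_per_batch: usize) -> Vec<Vec<String>> {
    let mut batches = Vec::new();
    let mut batch = Vec::new();
    let mut batch_tokens = 0;

    for span in spans {
        // Flush the pending batch before it would overflow the token budget.
        if batch_tokens + span.token_count > max_tokens_per_batch && !batch.is_empty() {
            batches.push(std::mem::take(&mut batch));
            batch_tokens = 0;
        }
        batch_tokens += span.token_count;
        batch.push(span.content.clone());
    }
    // Final flush for whatever remains, like the trailing embed_batch call above.
    if !batch.is_empty() {
        batches.push(batch);
    }
    batches
}

fn main() {
    let spans = vec![
        Span { content: "fn a() {}".into(), token_count: 60 },
        Span { content: "fn b() {}".into(), token_count: 50 },
        Span { content: "fn c() {}".into(), token_count: 40 },
    ];
    // With a 100-token budget this yields [["fn a() {}"], ["fn b() {}", "fn c() {}"]].
    println!("{:?}", batch_spans(&spans, 100));
}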
@@ -873,7 +932,11 @@ impl SemanticIndex {
    }
}

let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
let scores = results
    .into_iter()
    .map(|(_, score)| score)
    .collect::<Vec<f32>>();
let spans = database.spans_for_ids(ids.as_slice()).await?;

let mut tasks = Vec::new();
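The replacement lines above need both columns of the (id, score) pairs, so the first pass now borrows (iter plus *id) and the second pass consumes the vector to pull out the scores. Under the same assumption that results is a Vec<(i64, f32)>, a single-pass unzip would be an equivalent, slightly more compact option:

fn split_ids_and_scores(results: Vec<(i64, f32)>) -> (Vec<i64>, Vec<f32>) {
    // Single pass instead of the iter()/into_iter() pair used in the hunk above.
    results.into_iter().unzip()
}

fn main() {
    let (ids, scores) = split_ids_and_scores(vec![(7, 0.91), (3, 0.64)]);
    assert_eq!(ids, vec![7, 3]);
    assert_eq!(scores, vec![0.91, 0.64]);
}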
@@ -903,19 +966,74 @@ impl SemanticIndex {
t0.elapsed().as_millis()
);

Ok(buffers
let database_results = buffers
.into_iter()
.zip(ranges)
.filter_map(|(buffer, range)| {
.zip(scores)
.filter_map(|((buffer, range), similarity)| {
let buffer = buffer.log_err()?;
let range = buffer.read_with(&cx, |buffer, _| {
let start = buffer.clip_offset(range.start, Bias::Left);
let end = buffer.clip_offset(range.end, Bias::Right);
buffer.anchor_before(start)..buffer.anchor_after(end)
});
Some(SearchResult { buffer, range })
Some(SearchResult {
buffer,
range,
similarity,
})
})
.collect::<Vec<_>>())
.collect::<Vec<_>>();

// Stitch Together Database Results & Buffer Results
if let Ok(buffer_results) = buffer_results {
let mut buffer_map = HashMap::default();
for buffer_result in buffer_results {
buffer_map
.entry(buffer_result.clone().buffer)
.or_insert(Vec::new())
.push(buffer_result);
}

for db_result in database_results {
if !buffer_map.contains_key(&db_result.buffer) {
buffer_map
.entry(db_result.clone().buffer)
.or_insert(Vec::new())
.push(db_result);
}
}

let mut full_results = Vec::<SearchResult>::new();

for (_, results) in buffer_map {
for res in results.into_iter() {
let ix = match full_results.binary_search_by(|search_result| {
res.similarity
.partial_cmp(&search_result.similarity)
.unwrap_or(Ordering::Equal)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
full_results.insert(ix, res);
full_results.truncate(limit);
}
}

return Ok(full_results);
} else {
return Ok(database_results);
}

// let ix = match results.binary_search_by(|(_, s)| {
//     similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
// }) {
//     Ok(ix) => ix,
//     Err(ix) => ix,
// };
// results.insert(ix, (id, similarity));
// results.truncate(limit);
})
}
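Below is a compact sketch of the stitching step above, under simplified assumptions: hits from dirty in-memory buffers take precedence over database hits for the same buffer, and the merged set is re-ranked by similarity and capped at limit. Hit and its usize buffer id are stand-ins; the real code groups by ModelHandle<Buffer> and keeps the list ordered incrementally with binary_search_by rather than a final sort.

use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};

// Hypothetical, simplified stand-in for SearchResult.
#[derive(Debug, Clone)]
struct Hit {
    buffer_id: usize,
    similarity: f32,
}

fn merge(buffer_hits: Vec<Hit>, db_hits: Vec<Hit>, limit: usize) -> Vec<Hit> {
    // Group in-memory (dirty buffer) hits by buffer; these take precedence.
    let mut by_buffer: HashMap<usize, Vec<Hit>> = HashMap::new();
    for hit in buffer_hits {
        by_buffer.entry(hit.buffer_id).or_default().push(hit);
    }

    // Database hits only count for buffers that produced no in-memory hits.
    let covered: HashSet<usize> = by_buffer.keys().copied().collect();
    for hit in db_hits {
        if !covered.contains(&hit.buffer_id) {
            by_buffer.entry(hit.buffer_id).or_default().push(hit);
        }
    }

    // Re-rank everything by descending similarity and keep at most `limit` hits.
    let mut merged: Vec<Hit> = by_buffer.into_values().flatten().collect();
    merged.sort_by(|a, b| {
        b.similarity
            .partial_cmp(&a.similarity)
            .unwrap_or(Ordering::Equal)
    });
    merged.truncate(limit);
    merged
}

fn main() {
    let buffer_hits = vec![Hit { buffer_id: 1, similarity: 0.8 }];
    let db_hits = vec![
        Hit { buffer_id: 1, similarity: 0.9 }, // stale: superseded by the in-memory hit
        Hit { buffer_id: 2, similarity: 0.6 },
    ];
    // Prints the buffer-1 in-memory hit first, then the buffer-2 database hit.
    println!("{:?}", merge(buffer_hits, db_hits, 10));
}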