Refactor semantic searching of modified buffers

This commit is contained in:
Antonio Scandurra 2023-09-15 12:12:20 +02:00
parent 796bdd3da7
commit ae85a520f2
5 changed files with 215 additions and 228 deletions

1
Cargo.lock generated
View file

@ -6739,6 +6739,7 @@ dependencies = [
"lazy_static", "lazy_static",
"log", "log",
"matrixmultiply", "matrixmultiply",
"ordered-float",
"parking_lot 0.11.2", "parking_lot 0.11.2",
"parse_duration", "parse_duration",
"picker", "picker",

View file

@ -23,6 +23,7 @@ settings = { path = "../settings" }
anyhow.workspace = true anyhow.workspace = true
postage.workspace = true postage.workspace = true
futures.workspace = true futures.workspace = true
ordered-float.workspace = true
smol.workspace = true smol.workspace = true
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] } rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
isahc.workspace = true isahc.workspace = true

View file

@ -7,12 +7,13 @@ use anyhow::{anyhow, Context, Result};
use collections::HashMap; use collections::HashMap;
use futures::channel::oneshot; use futures::channel::oneshot;
use gpui::executor; use gpui::executor;
use ordered_float::OrderedFloat;
use project::{search::PathMatcher, Fs}; use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp; use rpc::proto::Timestamp;
use rusqlite::params; use rusqlite::params;
use rusqlite::types::Value; use rusqlite::types::Value;
use std::{ use std::{
cmp::Ordering, cmp::Reverse,
future::Future, future::Future,
ops::Range, ops::Range,
path::{Path, PathBuf}, path::{Path, PathBuf},
@ -407,16 +408,16 @@ impl VectorDatabase {
query_embedding: &Embedding, query_embedding: &Embedding,
limit: usize, limit: usize,
file_ids: &[i64], file_ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, f32)>>> { ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
let query_embedding = query_embedding.clone(); let query_embedding = query_embedding.clone();
let file_ids = file_ids.to_vec(); let file_ids = file_ids.to_vec();
self.transact(move |db| { self.transact(move |db| {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1); let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
Self::for_each_span(db, &file_ids, |id, embedding| { Self::for_each_span(db, &file_ids, |id, embedding| {
let similarity = embedding.similarity(&query_embedding); let similarity = embedding.similarity(&query_embedding);
let ix = match results.binary_search_by(|(_, s)| { let ix = match results
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
}) { {
Ok(ix) => ix, Ok(ix) => ix,
Err(ix) => ix, Err(ix) => ix,
}; };

View file

@ -7,6 +7,7 @@ use isahc::http::StatusCode;
use isahc::prelude::Configurable; use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response}; use isahc::{AsyncBody, Response};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
use parking_lot::Mutex; use parking_lot::Mutex;
use parse_duration::parse; use parse_duration::parse;
use postage::watch; use postage::watch;
@ -35,7 +36,7 @@ impl From<Vec<f32>> for Embedding {
} }
impl Embedding { impl Embedding {
pub fn similarity(&self, other: &Self) -> f32 { pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
let len = self.0.len(); let len = self.0.len();
assert_eq!(len, other.0.len()); assert_eq!(len, other.0.len());
@ -58,7 +59,7 @@ impl Embedding {
1, 1,
); );
} }
result OrderedFloat(result)
} }
} }
@ -379,13 +380,13 @@ mod tests {
); );
} }
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 { fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places); let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor (n * factor).round() / factor
} }
fn reference_dot(a: &[f32], b: &[f32]) -> f32 { fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum() OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
} }
} }
} }

View file

@ -16,13 +16,14 @@ use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt}; use futures::{future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry}; use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use ordered_float::OrderedFloat;
use parking_lot::Mutex; use parking_lot::Mutex;
use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES}; use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch; use postage::watch;
use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId}; use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use smol::channel; use smol::channel;
use std::{ use std::{
cmp::Ordering, cmp::Reverse,
future::Future, future::Future,
mem, mem,
ops::Range, ops::Range,
@ -267,7 +268,7 @@ pub struct PendingFile {
pub struct SearchResult { pub struct SearchResult {
pub buffer: ModelHandle<Buffer>, pub buffer: ModelHandle<Buffer>,
pub range: Range<Anchor>, pub range: Range<Anchor>,
pub similarity: f32, pub similarity: OrderedFloat<f32>,
} }
impl SemanticIndex { impl SemanticIndex {
@ -690,39 +691,71 @@ impl SemanticIndex {
pub fn search_project( pub fn search_project(
&mut self, &mut self,
project: ModelHandle<Project>, project: ModelHandle<Project>,
phrase: String, query: String,
limit: usize, limit: usize,
includes: Vec<PathMatcher>, includes: Vec<PathMatcher>,
mut excludes: Vec<PathMatcher>, excludes: Vec<PathMatcher>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<SearchResult>>> { ) -> Task<Result<Vec<SearchResult>>> {
if query.is_empty() {
return Task::ready(Ok(Vec::new()));
}
let index = self.index_project(project.clone(), cx); let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone(); let embedding_provider = self.embedding_provider.clone();
cx.spawn(|this, mut cx| async move {
let query = embedding_provider
.embed_batch(vec![query])
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
index.await?;
let search_start = Instant::now();
let modified_buffer_results = this.update(&mut cx, |this, cx| {
this.search_modified_buffers(&project, query.clone(), limit, &excludes, cx)
});
let file_results = this.update(&mut cx, |this, cx| {
this.search_files(project, query, limit, includes, excludes, cx)
});
let (modified_buffer_results, file_results) =
futures::join!(modified_buffer_results, file_results);
// Weave together the results from modified buffers and files.
let mut results = Vec::new();
let mut modified_buffers = HashSet::default();
for result in modified_buffer_results.log_err().unwrap_or_default() {
modified_buffers.insert(result.buffer.clone());
results.push(result);
}
for result in file_results.log_err().unwrap_or_default() {
if !modified_buffers.contains(&result.buffer) {
results.push(result);
}
}
results.sort_by_key(|result| Reverse(result.similarity));
results.truncate(limit);
log::trace!("Semantic search took {:?}", search_start.elapsed());
Ok(results)
})
}
pub fn search_files(
&mut self,
project: ModelHandle<Project>,
query: Embedding,
limit: usize,
includes: Vec<PathMatcher>,
excludes: Vec<PathMatcher>,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<SearchResult>>> {
let db_path = self.db.path().clone(); let db_path = self.db.path().clone();
let fs = self.fs.clone(); let fs = self.fs.clone();
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
index.await?;
let t0 = Instant::now();
let database = let database =
VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?; VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
if phrase.len() == 0 {
return Ok(Vec::new());
}
let phrase_embedding = embedding_provider
.embed_batch(vec![phrase])
.await?
.into_iter()
.next()
.unwrap();
log::trace!(
"Embedding search phrase took: {:?} milliseconds",
t0.elapsed().as_millis()
);
let worktree_db_ids = this.read_with(&cx, |this, _| { let worktree_db_ids = this.read_with(&cx, |this, _| {
let project_state = this let project_state = this
.projects .projects
@ -742,42 +775,6 @@ impl SemanticIndex {
anyhow::Ok(worktree_db_ids) anyhow::Ok(worktree_db_ids)
})?; })?;
let (dirty_buffers, dirty_paths) = project.read_with(&cx, |project, cx| {
let mut dirty_paths = Vec::new();
let dirty_buffers = project
.opened_buffers(cx)
.into_iter()
.filter_map(|buffer_handle| {
let buffer = buffer_handle.read(cx);
if buffer.is_dirty() {
let snapshot = buffer.snapshot();
if let Some(file_pathbuf) = snapshot.resolve_file_path(cx, false) {
let file_path = file_pathbuf.as_path();
if excludes.iter().any(|glob| glob.is_match(file_path)) {
return None;
}
file_pathbuf
.to_str()
.and_then(|path| PathMatcher::new(path).log_err())
.and_then(|path_matcher| {
dirty_paths.push(path_matcher);
Some(())
});
}
// TOOD: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly
Some((buffer_handle, buffer.snapshot()))
} else {
None
}
})
.collect::<HashMap<_, _>>();
(dirty_buffers, dirty_paths)
});
excludes.extend(dirty_paths);
let file_ids = database let file_ids = database
.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes) .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
.await?; .await?;
@ -796,155 +793,26 @@ impl SemanticIndex {
let limit = limit.clone(); let limit = limit.clone();
let fs = fs.clone(); let fs = fs.clone();
let db_path = db_path.clone(); let db_path = db_path.clone();
let phrase_embedding = phrase_embedding.clone(); let query = query.clone();
if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background()) if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
.await .await
.log_err() .log_err()
{ {
batch_results.push(async move { batch_results.push(async move {
db.top_k_search(&phrase_embedding, limit, batch.as_slice()) db.top_k_search(&query, limit, batch.as_slice()).await
.await
}); });
} }
} }
let buffer_results = if let Some(db) =
VectorDatabase::new(fs, db_path.clone(), cx.background())
.await
.log_err()
{
cx.background()
.spawn({
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
let embedding_provider = embedding_provider.clone();
let phrase_embedding = phrase_embedding.clone();
async move {
let mut results = Vec::<SearchResult>::new();
'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
let language = buffer_snapshot
.language_at(0)
.cloned()
.unwrap_or_else(|| language::PLAIN_TEXT.clone());
if let Some(spans) = retriever
.parse_file_with_template(
None,
&buffer_snapshot.text(),
language,
)
.log_err()
{
let mut batch = Vec::new();
let mut batch_tokens = 0;
let mut embeddings = Vec::new();
let digests = spans
.iter()
.map(|span| span.digest.clone())
.collect::<Vec<_>>();
let embeddings_for_digests = db
.embeddings_for_digests(digests)
.await
.map_or(Default::default(), |m| m);
for span in &spans {
if embeddings_for_digests.contains_key(&span.digest) {
continue;
};
if batch_tokens + span.token_count
> embedding_provider.max_tokens_per_batch()
{
if let Some(batch_embeddings) = embedding_provider
.embed_batch(mem::take(&mut batch))
.await
.log_err()
{
embeddings.extend(batch_embeddings);
batch_tokens = 0;
} else {
continue 'buffers;
}
}
batch_tokens += span.token_count;
batch.push(span.content.clone());
}
if let Some(batch_embeddings) = embedding_provider
.embed_batch(mem::take(&mut batch))
.await
.log_err()
{
embeddings.extend(batch_embeddings);
} else {
continue 'buffers;
}
let mut embeddings = embeddings.into_iter();
for span in spans {
let embedding = if let Some(embedding) =
embeddings_for_digests.get(&span.digest)
{
Some(embedding.clone())
} else {
embeddings.next()
};
if let Some(embedding) = embedding {
let similarity =
embedding.similarity(&phrase_embedding);
let ix = match results.binary_search_by(|s| {
similarity
.partial_cmp(&s.similarity)
.unwrap_or(Ordering::Equal)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
let range = {
let start = buffer_snapshot
.clip_offset(span.range.start, Bias::Left);
let end = buffer_snapshot
.clip_offset(span.range.end, Bias::Right);
buffer_snapshot.anchor_before(start)
..buffer_snapshot.anchor_after(end)
};
results.insert(
ix,
SearchResult {
buffer: buffer_handle.clone(),
range,
similarity,
},
);
results.truncate(limit);
} else {
log::error!("failed to embed span");
continue 'buffers;
}
}
}
}
anyhow::Ok(results)
}
})
.await
} else {
Ok(Vec::new())
};
let batch_results = futures::future::join_all(batch_results).await; let batch_results = futures::future::join_all(batch_results).await;
let mut results = Vec::new(); let mut results = Vec::new();
for batch_result in batch_results { for batch_result in batch_results {
if batch_result.is_ok() { if batch_result.is_ok() {
for (id, similarity) in batch_result.unwrap() { for (id, similarity) in batch_result.unwrap() {
let ix = match results.binary_search_by(|(_, s)| { let ix = match results
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal) .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
}) { {
Ok(ix) => ix, Ok(ix) => ix,
Err(ix) => ix, Err(ix) => ix,
}; };
@ -958,7 +826,7 @@ impl SemanticIndex {
let scores = results let scores = results
.into_iter() .into_iter()
.map(|(_, score)| score) .map(|(_, score)| score)
.collect::<Vec<f32>>(); .collect::<Vec<_>>();
let spans = database.spans_for_ids(ids.as_slice()).await?; let spans = database.spans_for_ids(ids.as_slice()).await?;
let mut tasks = Vec::new(); let mut tasks = Vec::new();
@ -983,12 +851,7 @@ impl SemanticIndex {
let buffers = futures::future::join_all(tasks).await; let buffers = futures::future::join_all(tasks).await;
log::trace!( Ok(buffers
"Semantic Searching took: {:?} milliseconds in total",
t0.elapsed().as_millis()
);
let mut database_results = buffers
.into_iter() .into_iter()
.zip(ranges) .zip(ranges)
.zip(scores) .zip(scores)
@ -1005,26 +868,89 @@ impl SemanticIndex {
similarity, similarity,
}) })
}) })
.collect::<Vec<_>>(); .collect())
})
}
// Stitch Together Database Results & Buffer Results fn search_modified_buffers(
if let Ok(buffer_results) = buffer_results { &self,
for buffer_result in buffer_results { project: &ModelHandle<Project>,
let ix = match database_results.binary_search_by(|search_result| { query: Embedding,
buffer_result limit: usize,
.similarity excludes: &[PathMatcher],
.partial_cmp(&search_result.similarity) cx: &mut ModelContext<Self>,
.unwrap_or(Ordering::Equal) ) -> Task<Result<Vec<SearchResult>>> {
}) { let modified_buffers = project
Ok(ix) => ix, .read(cx)
Err(ix) => ix, .opened_buffers(cx)
}; .into_iter()
database_results.insert(ix, buffer_result); .filter_map(|buffer_handle| {
database_results.truncate(limit); let buffer = buffer_handle.read(cx);
let snapshot = buffer.snapshot();
let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
excludes.iter().any(|matcher| matcher.is_match(&path))
});
if buffer.is_dirty() && !excluded {
Some((buffer_handle, snapshot))
} else {
None
}
})
.collect::<HashMap<_, _>>();
let embedding_provider = self.embedding_provider.clone();
let fs = self.fs.clone();
let db_path = self.db.path().clone();
let background = cx.background().clone();
cx.background().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new();
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
for (buffer, snapshot) in modified_buffers {
let language = snapshot
.language_at(0)
.cloned()
.unwrap_or_else(|| language::PLAIN_TEXT.clone());
let mut spans = retriever
.parse_file_with_template(None, &snapshot.text(), language)
.log_err()
.unwrap_or_default();
if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
.await
.log_err()
.is_some()
{
for span in spans {
let similarity = span.embedding.unwrap().similarity(&query);
let ix = match results
.binary_search_by_key(&Reverse(similarity), |result| {
Reverse(result.similarity)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
let range = {
let start = snapshot.clip_offset(span.range.start, Bias::Left);
let end = snapshot.clip_offset(span.range.end, Bias::Right);
snapshot.anchor_before(start)..snapshot.anchor_after(end)
};
results.insert(
ix,
SearchResult {
buffer: buffer.clone(),
range,
similarity,
},
);
results.truncate(limit);
}
} }
} }
Ok(database_results) Ok(results)
}) })
} }
@ -1208,6 +1134,63 @@ impl SemanticIndex {
Ok(()) Ok(())
}) })
} }
async fn embed_spans(
spans: &mut [Span],
embedding_provider: &dyn EmbeddingProvider,
db: &VectorDatabase,
) -> Result<()> {
let mut batch = Vec::new();
let mut batch_tokens = 0;
let mut embeddings = Vec::new();
let digests = spans
.iter()
.map(|span| span.digest.clone())
.collect::<Vec<_>>();
let embeddings_for_digests = db
.embeddings_for_digests(digests)
.await
.log_err()
.unwrap_or_default();
for span in &*spans {
if embeddings_for_digests.contains_key(&span.digest) {
continue;
};
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
batch_tokens = 0;
}
batch_tokens += span.token_count;
batch.push(span.content.clone());
}
if !batch.is_empty() {
let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
}
let mut embeddings = embeddings.into_iter();
for span in spans {
let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
Some(embedding.clone())
} else {
embeddings.next()
};
let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
span.embedding = Some(embedding);
}
Ok(())
}
} }
impl Entity for SemanticIndex { impl Entity for SemanticIndex {