Refactor semantic searching of modified buffers
This commit is contained in:
parent
796bdd3da7
commit
ae85a520f2
5 changed files with 215 additions and 228 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -6739,6 +6739,7 @@ dependencies = [
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"log",
|
"log",
|
||||||
"matrixmultiply",
|
"matrixmultiply",
|
||||||
|
"ordered-float",
|
||||||
"parking_lot 0.11.2",
|
"parking_lot 0.11.2",
|
||||||
"parse_duration",
|
"parse_duration",
|
||||||
"picker",
|
"picker",
|
||||||
|
|
|
@ -23,6 +23,7 @@ settings = { path = "../settings" }
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
postage.workspace = true
|
postage.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
|
ordered-float.workspace = true
|
||||||
smol.workspace = true
|
smol.workspace = true
|
||||||
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
|
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
|
||||||
isahc.workspace = true
|
isahc.workspace = true
|
||||||
|
|
|
@ -7,12 +7,13 @@ use anyhow::{anyhow, Context, Result};
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use futures::channel::oneshot;
|
use futures::channel::oneshot;
|
||||||
use gpui::executor;
|
use gpui::executor;
|
||||||
|
use ordered_float::OrderedFloat;
|
||||||
use project::{search::PathMatcher, Fs};
|
use project::{search::PathMatcher, Fs};
|
||||||
use rpc::proto::Timestamp;
|
use rpc::proto::Timestamp;
|
||||||
use rusqlite::params;
|
use rusqlite::params;
|
||||||
use rusqlite::types::Value;
|
use rusqlite::types::Value;
|
||||||
use std::{
|
use std::{
|
||||||
cmp::Ordering,
|
cmp::Reverse,
|
||||||
future::Future,
|
future::Future,
|
||||||
ops::Range,
|
ops::Range,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
|
@ -407,16 +408,16 @@ impl VectorDatabase {
|
||||||
query_embedding: &Embedding,
|
query_embedding: &Embedding,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
file_ids: &[i64],
|
file_ids: &[i64],
|
||||||
) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
|
) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
|
||||||
let query_embedding = query_embedding.clone();
|
let query_embedding = query_embedding.clone();
|
||||||
let file_ids = file_ids.to_vec();
|
let file_ids = file_ids.to_vec();
|
||||||
self.transact(move |db| {
|
self.transact(move |db| {
|
||||||
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
|
let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
|
||||||
Self::for_each_span(db, &file_ids, |id, embedding| {
|
Self::for_each_span(db, &file_ids, |id, embedding| {
|
||||||
let similarity = embedding.similarity(&query_embedding);
|
let similarity = embedding.similarity(&query_embedding);
|
||||||
let ix = match results.binary_search_by(|(_, s)| {
|
let ix = match results
|
||||||
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
|
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
|
||||||
}) {
|
{
|
||||||
Ok(ix) => ix,
|
Ok(ix) => ix,
|
||||||
Err(ix) => ix,
|
Err(ix) => ix,
|
||||||
};
|
};
|
||||||
|
|
|
@ -7,6 +7,7 @@ use isahc::http::StatusCode;
|
||||||
use isahc::prelude::Configurable;
|
use isahc::prelude::Configurable;
|
||||||
use isahc::{AsyncBody, Response};
|
use isahc::{AsyncBody, Response};
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
|
use ordered_float::OrderedFloat;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use parse_duration::parse;
|
use parse_duration::parse;
|
||||||
use postage::watch;
|
use postage::watch;
|
||||||
|
@ -35,7 +36,7 @@ impl From<Vec<f32>> for Embedding {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Embedding {
|
impl Embedding {
|
||||||
pub fn similarity(&self, other: &Self) -> f32 {
|
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
|
||||||
let len = self.0.len();
|
let len = self.0.len();
|
||||||
assert_eq!(len, other.0.len());
|
assert_eq!(len, other.0.len());
|
||||||
|
|
||||||
|
@ -58,7 +59,7 @@ impl Embedding {
|
||||||
1,
|
1,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
result
|
OrderedFloat(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,13 +380,13 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
|
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
|
||||||
let factor = (10.0 as f32).powi(decimal_places);
|
let factor = (10.0 as f32).powi(decimal_places);
|
||||||
(n * factor).round() / factor
|
(n * factor).round() / factor
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
|
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
|
||||||
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
|
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,13 +16,14 @@ use embedding_queue::{EmbeddingQueue, FileToEmbed};
|
||||||
use futures::{future, FutureExt, StreamExt};
|
use futures::{future, FutureExt, StreamExt};
|
||||||
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
|
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
|
||||||
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
|
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
|
||||||
|
use ordered_float::OrderedFloat;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
|
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
|
||||||
use postage::watch;
|
use postage::watch;
|
||||||
use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
|
use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
|
||||||
use smol::channel;
|
use smol::channel;
|
||||||
use std::{
|
use std::{
|
||||||
cmp::Ordering,
|
cmp::Reverse,
|
||||||
future::Future,
|
future::Future,
|
||||||
mem,
|
mem,
|
||||||
ops::Range,
|
ops::Range,
|
||||||
|
@ -267,7 +268,7 @@ pub struct PendingFile {
|
||||||
pub struct SearchResult {
|
pub struct SearchResult {
|
||||||
pub buffer: ModelHandle<Buffer>,
|
pub buffer: ModelHandle<Buffer>,
|
||||||
pub range: Range<Anchor>,
|
pub range: Range<Anchor>,
|
||||||
pub similarity: f32,
|
pub similarity: OrderedFloat<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SemanticIndex {
|
impl SemanticIndex {
|
||||||
|
@ -690,39 +691,71 @@ impl SemanticIndex {
|
||||||
pub fn search_project(
|
pub fn search_project(
|
||||||
&mut self,
|
&mut self,
|
||||||
project: ModelHandle<Project>,
|
project: ModelHandle<Project>,
|
||||||
phrase: String,
|
query: String,
|
||||||
limit: usize,
|
limit: usize,
|
||||||
includes: Vec<PathMatcher>,
|
includes: Vec<PathMatcher>,
|
||||||
mut excludes: Vec<PathMatcher>,
|
excludes: Vec<PathMatcher>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<Vec<SearchResult>>> {
|
) -> Task<Result<Vec<SearchResult>>> {
|
||||||
|
if query.is_empty() {
|
||||||
|
return Task::ready(Ok(Vec::new()));
|
||||||
|
}
|
||||||
|
|
||||||
let index = self.index_project(project.clone(), cx);
|
let index = self.index_project(project.clone(), cx);
|
||||||
let embedding_provider = self.embedding_provider.clone();
|
let embedding_provider = self.embedding_provider.clone();
|
||||||
|
|
||||||
|
cx.spawn(|this, mut cx| async move {
|
||||||
|
let query = embedding_provider
|
||||||
|
.embed_batch(vec![query])
|
||||||
|
.await?
|
||||||
|
.pop()
|
||||||
|
.ok_or_else(|| anyhow!("could not embed query"))?;
|
||||||
|
index.await?;
|
||||||
|
|
||||||
|
let search_start = Instant::now();
|
||||||
|
let modified_buffer_results = this.update(&mut cx, |this, cx| {
|
||||||
|
this.search_modified_buffers(&project, query.clone(), limit, &excludes, cx)
|
||||||
|
});
|
||||||
|
let file_results = this.update(&mut cx, |this, cx| {
|
||||||
|
this.search_files(project, query, limit, includes, excludes, cx)
|
||||||
|
});
|
||||||
|
let (modified_buffer_results, file_results) =
|
||||||
|
futures::join!(modified_buffer_results, file_results);
|
||||||
|
|
||||||
|
// Weave together the results from modified buffers and files.
|
||||||
|
let mut results = Vec::new();
|
||||||
|
let mut modified_buffers = HashSet::default();
|
||||||
|
for result in modified_buffer_results.log_err().unwrap_or_default() {
|
||||||
|
modified_buffers.insert(result.buffer.clone());
|
||||||
|
results.push(result);
|
||||||
|
}
|
||||||
|
for result in file_results.log_err().unwrap_or_default() {
|
||||||
|
if !modified_buffers.contains(&result.buffer) {
|
||||||
|
results.push(result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.sort_by_key(|result| Reverse(result.similarity));
|
||||||
|
results.truncate(limit);
|
||||||
|
log::trace!("Semantic search took {:?}", search_start.elapsed());
|
||||||
|
Ok(results)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn search_files(
|
||||||
|
&mut self,
|
||||||
|
project: ModelHandle<Project>,
|
||||||
|
query: Embedding,
|
||||||
|
limit: usize,
|
||||||
|
includes: Vec<PathMatcher>,
|
||||||
|
excludes: Vec<PathMatcher>,
|
||||||
|
cx: &mut ModelContext<Self>,
|
||||||
|
) -> Task<Result<Vec<SearchResult>>> {
|
||||||
let db_path = self.db.path().clone();
|
let db_path = self.db.path().clone();
|
||||||
let fs = self.fs.clone();
|
let fs = self.fs.clone();
|
||||||
cx.spawn(|this, mut cx| async move {
|
cx.spawn(|this, mut cx| async move {
|
||||||
index.await?;
|
|
||||||
|
|
||||||
let t0 = Instant::now();
|
|
||||||
let database =
|
let database =
|
||||||
VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
|
VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
|
||||||
|
|
||||||
if phrase.len() == 0 {
|
|
||||||
return Ok(Vec::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
let phrase_embedding = embedding_provider
|
|
||||||
.embed_batch(vec![phrase])
|
|
||||||
.await?
|
|
||||||
.into_iter()
|
|
||||||
.next()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
log::trace!(
|
|
||||||
"Embedding search phrase took: {:?} milliseconds",
|
|
||||||
t0.elapsed().as_millis()
|
|
||||||
);
|
|
||||||
|
|
||||||
let worktree_db_ids = this.read_with(&cx, |this, _| {
|
let worktree_db_ids = this.read_with(&cx, |this, _| {
|
||||||
let project_state = this
|
let project_state = this
|
||||||
.projects
|
.projects
|
||||||
|
@ -742,42 +775,6 @@ impl SemanticIndex {
|
||||||
anyhow::Ok(worktree_db_ids)
|
anyhow::Ok(worktree_db_ids)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let (dirty_buffers, dirty_paths) = project.read_with(&cx, |project, cx| {
|
|
||||||
let mut dirty_paths = Vec::new();
|
|
||||||
let dirty_buffers = project
|
|
||||||
.opened_buffers(cx)
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|buffer_handle| {
|
|
||||||
let buffer = buffer_handle.read(cx);
|
|
||||||
if buffer.is_dirty() {
|
|
||||||
let snapshot = buffer.snapshot();
|
|
||||||
if let Some(file_pathbuf) = snapshot.resolve_file_path(cx, false) {
|
|
||||||
let file_path = file_pathbuf.as_path();
|
|
||||||
|
|
||||||
if excludes.iter().any(|glob| glob.is_match(file_path)) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
file_pathbuf
|
|
||||||
.to_str()
|
|
||||||
.and_then(|path| PathMatcher::new(path).log_err())
|
|
||||||
.and_then(|path_matcher| {
|
|
||||||
dirty_paths.push(path_matcher);
|
|
||||||
Some(())
|
|
||||||
});
|
|
||||||
}
|
|
||||||
// TOOD: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly
|
|
||||||
Some((buffer_handle, buffer.snapshot()))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<HashMap<_, _>>();
|
|
||||||
|
|
||||||
(dirty_buffers, dirty_paths)
|
|
||||||
});
|
|
||||||
|
|
||||||
excludes.extend(dirty_paths);
|
|
||||||
let file_ids = database
|
let file_ids = database
|
||||||
.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
|
.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
|
||||||
.await?;
|
.await?;
|
||||||
|
@ -796,155 +793,26 @@ impl SemanticIndex {
|
||||||
let limit = limit.clone();
|
let limit = limit.clone();
|
||||||
let fs = fs.clone();
|
let fs = fs.clone();
|
||||||
let db_path = db_path.clone();
|
let db_path = db_path.clone();
|
||||||
let phrase_embedding = phrase_embedding.clone();
|
let query = query.clone();
|
||||||
if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
|
if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
|
||||||
.await
|
.await
|
||||||
.log_err()
|
.log_err()
|
||||||
{
|
{
|
||||||
batch_results.push(async move {
|
batch_results.push(async move {
|
||||||
db.top_k_search(&phrase_embedding, limit, batch.as_slice())
|
db.top_k_search(&query, limit, batch.as_slice()).await
|
||||||
.await
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let buffer_results = if let Some(db) =
|
|
||||||
VectorDatabase::new(fs, db_path.clone(), cx.background())
|
|
||||||
.await
|
|
||||||
.log_err()
|
|
||||||
{
|
|
||||||
cx.background()
|
|
||||||
.spawn({
|
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
|
|
||||||
let embedding_provider = embedding_provider.clone();
|
|
||||||
let phrase_embedding = phrase_embedding.clone();
|
|
||||||
async move {
|
|
||||||
let mut results = Vec::<SearchResult>::new();
|
|
||||||
'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
|
|
||||||
let language = buffer_snapshot
|
|
||||||
.language_at(0)
|
|
||||||
.cloned()
|
|
||||||
.unwrap_or_else(|| language::PLAIN_TEXT.clone());
|
|
||||||
if let Some(spans) = retriever
|
|
||||||
.parse_file_with_template(
|
|
||||||
None,
|
|
||||||
&buffer_snapshot.text(),
|
|
||||||
language,
|
|
||||||
)
|
|
||||||
.log_err()
|
|
||||||
{
|
|
||||||
let mut batch = Vec::new();
|
|
||||||
let mut batch_tokens = 0;
|
|
||||||
let mut embeddings = Vec::new();
|
|
||||||
|
|
||||||
let digests = spans
|
|
||||||
.iter()
|
|
||||||
.map(|span| span.digest.clone())
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let embeddings_for_digests = db
|
|
||||||
.embeddings_for_digests(digests)
|
|
||||||
.await
|
|
||||||
.map_or(Default::default(), |m| m);
|
|
||||||
|
|
||||||
for span in &spans {
|
|
||||||
if embeddings_for_digests.contains_key(&span.digest) {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
if batch_tokens + span.token_count
|
|
||||||
> embedding_provider.max_tokens_per_batch()
|
|
||||||
{
|
|
||||||
if let Some(batch_embeddings) = embedding_provider
|
|
||||||
.embed_batch(mem::take(&mut batch))
|
|
||||||
.await
|
|
||||||
.log_err()
|
|
||||||
{
|
|
||||||
embeddings.extend(batch_embeddings);
|
|
||||||
batch_tokens = 0;
|
|
||||||
} else {
|
|
||||||
continue 'buffers;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
batch_tokens += span.token_count;
|
|
||||||
batch.push(span.content.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(batch_embeddings) = embedding_provider
|
|
||||||
.embed_batch(mem::take(&mut batch))
|
|
||||||
.await
|
|
||||||
.log_err()
|
|
||||||
{
|
|
||||||
embeddings.extend(batch_embeddings);
|
|
||||||
} else {
|
|
||||||
continue 'buffers;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut embeddings = embeddings.into_iter();
|
|
||||||
for span in spans {
|
|
||||||
let embedding = if let Some(embedding) =
|
|
||||||
embeddings_for_digests.get(&span.digest)
|
|
||||||
{
|
|
||||||
Some(embedding.clone())
|
|
||||||
} else {
|
|
||||||
embeddings.next()
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(embedding) = embedding {
|
|
||||||
let similarity =
|
|
||||||
embedding.similarity(&phrase_embedding);
|
|
||||||
|
|
||||||
let ix = match results.binary_search_by(|s| {
|
|
||||||
similarity
|
|
||||||
.partial_cmp(&s.similarity)
|
|
||||||
.unwrap_or(Ordering::Equal)
|
|
||||||
}) {
|
|
||||||
Ok(ix) => ix,
|
|
||||||
Err(ix) => ix,
|
|
||||||
};
|
|
||||||
|
|
||||||
let range = {
|
|
||||||
let start = buffer_snapshot
|
|
||||||
.clip_offset(span.range.start, Bias::Left);
|
|
||||||
let end = buffer_snapshot
|
|
||||||
.clip_offset(span.range.end, Bias::Right);
|
|
||||||
buffer_snapshot.anchor_before(start)
|
|
||||||
..buffer_snapshot.anchor_after(end)
|
|
||||||
};
|
|
||||||
|
|
||||||
results.insert(
|
|
||||||
ix,
|
|
||||||
SearchResult {
|
|
||||||
buffer: buffer_handle.clone(),
|
|
||||||
range,
|
|
||||||
similarity,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
results.truncate(limit);
|
|
||||||
} else {
|
|
||||||
log::error!("failed to embed span");
|
|
||||||
continue 'buffers;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
anyhow::Ok(results)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
} else {
|
|
||||||
Ok(Vec::new())
|
|
||||||
};
|
|
||||||
|
|
||||||
let batch_results = futures::future::join_all(batch_results).await;
|
let batch_results = futures::future::join_all(batch_results).await;
|
||||||
|
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
for batch_result in batch_results {
|
for batch_result in batch_results {
|
||||||
if batch_result.is_ok() {
|
if batch_result.is_ok() {
|
||||||
for (id, similarity) in batch_result.unwrap() {
|
for (id, similarity) in batch_result.unwrap() {
|
||||||
let ix = match results.binary_search_by(|(_, s)| {
|
let ix = match results
|
||||||
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
|
.binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
|
||||||
}) {
|
{
|
||||||
Ok(ix) => ix,
|
Ok(ix) => ix,
|
||||||
Err(ix) => ix,
|
Err(ix) => ix,
|
||||||
};
|
};
|
||||||
|
@ -958,7 +826,7 @@ impl SemanticIndex {
|
||||||
let scores = results
|
let scores = results
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(_, score)| score)
|
.map(|(_, score)| score)
|
||||||
.collect::<Vec<f32>>();
|
.collect::<Vec<_>>();
|
||||||
let spans = database.spans_for_ids(ids.as_slice()).await?;
|
let spans = database.spans_for_ids(ids.as_slice()).await?;
|
||||||
|
|
||||||
let mut tasks = Vec::new();
|
let mut tasks = Vec::new();
|
||||||
|
@ -983,12 +851,7 @@ impl SemanticIndex {
|
||||||
|
|
||||||
let buffers = futures::future::join_all(tasks).await;
|
let buffers = futures::future::join_all(tasks).await;
|
||||||
|
|
||||||
log::trace!(
|
Ok(buffers
|
||||||
"Semantic Searching took: {:?} milliseconds in total",
|
|
||||||
t0.elapsed().as_millis()
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut database_results = buffers
|
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(ranges)
|
.zip(ranges)
|
||||||
.zip(scores)
|
.zip(scores)
|
||||||
|
@ -1005,26 +868,89 @@ impl SemanticIndex {
|
||||||
similarity,
|
similarity,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Stitch Together Database Results & Buffer Results
|
fn search_modified_buffers(
|
||||||
if let Ok(buffer_results) = buffer_results {
|
&self,
|
||||||
for buffer_result in buffer_results {
|
project: &ModelHandle<Project>,
|
||||||
let ix = match database_results.binary_search_by(|search_result| {
|
query: Embedding,
|
||||||
buffer_result
|
limit: usize,
|
||||||
.similarity
|
excludes: &[PathMatcher],
|
||||||
.partial_cmp(&search_result.similarity)
|
cx: &mut ModelContext<Self>,
|
||||||
.unwrap_or(Ordering::Equal)
|
) -> Task<Result<Vec<SearchResult>>> {
|
||||||
}) {
|
let modified_buffers = project
|
||||||
Ok(ix) => ix,
|
.read(cx)
|
||||||
Err(ix) => ix,
|
.opened_buffers(cx)
|
||||||
};
|
.into_iter()
|
||||||
database_results.insert(ix, buffer_result);
|
.filter_map(|buffer_handle| {
|
||||||
database_results.truncate(limit);
|
let buffer = buffer_handle.read(cx);
|
||||||
|
let snapshot = buffer.snapshot();
|
||||||
|
let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
|
||||||
|
excludes.iter().any(|matcher| matcher.is_match(&path))
|
||||||
|
});
|
||||||
|
if buffer.is_dirty() && !excluded {
|
||||||
|
Some((buffer_handle, snapshot))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<HashMap<_, _>>();
|
||||||
|
|
||||||
|
let embedding_provider = self.embedding_provider.clone();
|
||||||
|
let fs = self.fs.clone();
|
||||||
|
let db_path = self.db.path().clone();
|
||||||
|
let background = cx.background().clone();
|
||||||
|
cx.background().spawn(async move {
|
||||||
|
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
|
||||||
|
let mut results = Vec::<SearchResult>::new();
|
||||||
|
|
||||||
|
let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
|
||||||
|
for (buffer, snapshot) in modified_buffers {
|
||||||
|
let language = snapshot
|
||||||
|
.language_at(0)
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| language::PLAIN_TEXT.clone());
|
||||||
|
let mut spans = retriever
|
||||||
|
.parse_file_with_template(None, &snapshot.text(), language)
|
||||||
|
.log_err()
|
||||||
|
.unwrap_or_default();
|
||||||
|
if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
|
||||||
|
.await
|
||||||
|
.log_err()
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
for span in spans {
|
||||||
|
let similarity = span.embedding.unwrap().similarity(&query);
|
||||||
|
let ix = match results
|
||||||
|
.binary_search_by_key(&Reverse(similarity), |result| {
|
||||||
|
Reverse(result.similarity)
|
||||||
|
}) {
|
||||||
|
Ok(ix) => ix,
|
||||||
|
Err(ix) => ix,
|
||||||
|
};
|
||||||
|
|
||||||
|
let range = {
|
||||||
|
let start = snapshot.clip_offset(span.range.start, Bias::Left);
|
||||||
|
let end = snapshot.clip_offset(span.range.end, Bias::Right);
|
||||||
|
snapshot.anchor_before(start)..snapshot.anchor_after(end)
|
||||||
|
};
|
||||||
|
|
||||||
|
results.insert(
|
||||||
|
ix,
|
||||||
|
SearchResult {
|
||||||
|
buffer: buffer.clone(),
|
||||||
|
range,
|
||||||
|
similarity,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
results.truncate(limit);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(database_results)
|
Ok(results)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1208,6 +1134,63 @@ impl SemanticIndex {
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn embed_spans(
|
||||||
|
spans: &mut [Span],
|
||||||
|
embedding_provider: &dyn EmbeddingProvider,
|
||||||
|
db: &VectorDatabase,
|
||||||
|
) -> Result<()> {
|
||||||
|
let mut batch = Vec::new();
|
||||||
|
let mut batch_tokens = 0;
|
||||||
|
let mut embeddings = Vec::new();
|
||||||
|
|
||||||
|
let digests = spans
|
||||||
|
.iter()
|
||||||
|
.map(|span| span.digest.clone())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let embeddings_for_digests = db
|
||||||
|
.embeddings_for_digests(digests)
|
||||||
|
.await
|
||||||
|
.log_err()
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
for span in &*spans {
|
||||||
|
if embeddings_for_digests.contains_key(&span.digest) {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
|
||||||
|
let batch_embeddings = embedding_provider
|
||||||
|
.embed_batch(mem::take(&mut batch))
|
||||||
|
.await?;
|
||||||
|
embeddings.extend(batch_embeddings);
|
||||||
|
batch_tokens = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
batch_tokens += span.token_count;
|
||||||
|
batch.push(span.content.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !batch.is_empty() {
|
||||||
|
let batch_embeddings = embedding_provider
|
||||||
|
.embed_batch(mem::take(&mut batch))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
embeddings.extend(batch_embeddings);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut embeddings = embeddings.into_iter();
|
||||||
|
for span in spans {
|
||||||
|
let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
|
||||||
|
Some(embedding.clone())
|
||||||
|
} else {
|
||||||
|
embeddings.next()
|
||||||
|
};
|
||||||
|
let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
|
||||||
|
span.embedding = Some(embedding);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Entity for SemanticIndex {
|
impl Entity for SemanticIndex {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue