Optimize glob filtering of semantic search

Co-authored-by: Kyle <kyle@zed.dev>
Max Brunsfeld 2023-07-20 14:23:11 -07:00
parent e02d6bc0d4
commit 81b05f2a08
4 changed files with 109 additions and 63 deletions
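In outline: the include/exclude glob checks move out of the embedding-scoring closure and into `for_each_document` itself, which now pre-filters the `files` table down to matching `file_ids` before touching any embeddings, so each relative path is matched once per file rather than once per document. The commit uses the `globset` crate throughout; as a minimal, self-contained sketch of its compile-once/match-many API (the paths here are illustrative, not from the commit):

```rust
use globset::{Glob, GlobMatcher};

fn main() {
    // Compile the pattern once; GlobMatcher::is_match accepts anything
    // that converts to a Path. By default globset's `*` also crosses
    // path separators, so "*.rs" matches nested files.
    let rust_files: GlobMatcher = Glob::new("*.rs").unwrap().compile_matcher();

    assert!(rust_files.is_match("src/db.rs"));
    assert!(!rust_files.is_match("Cargo.toml"));
}
```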


@@ -669,7 +669,6 @@ impl ProjectSearchView
         &mut self,
         cx: &mut ViewContext<Self>,
     ) -> Option<(Vec<GlobMatcher>, Vec<GlobMatcher>)> {
-        let text = self.query_editor.read(cx).text(cx);
         let included_files =
             match Self::load_glob_set(&self.included_files_editor.read(cx).text(cx)) {
                 Ok(included_files) => {


@@ -1,6 +1,6 @@
 use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
 use anyhow::{anyhow, Context, Result};
-use globset::{Glob, GlobMatcher};
+use globset::GlobMatcher;
 use project::Fs;
 use rpc::proto::Timestamp;
 use rusqlite::{
@@ -257,16 +257,11 @@ impl VectorDatabase
         exclude_globs: Vec<GlobMatcher>,
     ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
         let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-        self.for_each_document(&worktree_ids, |relative_path, id, embedding| {
-            if (include_globs.is_empty()
-                || include_globs
-                    .iter()
-                    .any(|include_glob| include_glob.is_match(relative_path.clone())))
-                && (exclude_globs.is_empty()
-                    || !exclude_globs
-                        .iter()
-                        .any(|exclude_glob| exclude_glob.is_match(relative_path.clone())))
-            {
+        self.for_each_document(
+            &worktree_ids,
+            include_globs,
+            exclude_globs,
+            |id, embedding| {
                 let similarity = dot(&embedding, &query_embedding);
                 let ix = match results.binary_search_by(|(_, s)| {
                     similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
@@ -276,8 +271,8 @@ impl VectorDatabase
                 };
                 results.insert(ix, (id, similarity));
                 results.truncate(limit);
-            }
-        })?;
+            },
+        )?;

         let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
         self.get_documents_by_ids(&ids)
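For context, the closure above maintains the running top-`limit` hits by binary-searching a vec kept sorted by descending similarity, inserting, and truncating. A standalone sketch of that pattern, with hypothetical names (`insert_top_k` is not a function in the codebase):

```rust
use std::cmp::Ordering;

/// Insert (id, score) into `results`, kept sorted by descending score,
/// and drop anything beyond the top `limit` entries.
fn insert_top_k(results: &mut Vec<(i64, f32)>, id: i64, score: f32, limit: usize) {
    // Comparing the incoming `score` against each element's score yields
    // Less near the front (large scores) and Greater near the back, the
    // monotone ordering binary_search_by needs for a descending vec.
    let ix = match results
        .binary_search_by(|&(_, s)| score.partial_cmp(&s).unwrap_or(Ordering::Equal))
    {
        Ok(ix) | Err(ix) => ix,
    };
    results.insert(ix, (id, score));
    results.truncate(limit);
}

fn main() {
    let mut top = Vec::with_capacity(3);
    for (id, score) in [(1, 0.2), (2, 0.9), (3, 0.5), (4, 0.7)] {
        insert_top_k(&mut top, id, score, 3);
    }
    // Highest three scores survive, best first.
    assert_eq!(top.iter().map(|&(id, _)| id).collect::<Vec<_>>(), vec![2, 4, 3]);
}
```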
@@ -286,26 +281,55 @@ impl VectorDatabase
     fn for_each_document(
         &self,
         worktree_ids: &[i64],
-        mut f: impl FnMut(String, i64, Vec<f32>),
+        include_globs: Vec<GlobMatcher>,
+        exclude_globs: Vec<GlobMatcher>,
+        mut f: impl FnMut(i64, Vec<f32>),
     ) -> Result<()> {
+        let mut file_query = self.db.prepare(
+            "
+            SELECT
+                id, relative_path
+            FROM
+                files
+            WHERE
+                worktree_id IN rarray(?)
+            ",
+        )?;
+        let mut file_ids = Vec::<i64>::new();
+        let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
+        while let Some(row) = rows.next()? {
+            let file_id = row.get(0)?;
+            let relative_path = row.get_ref(1)?.as_str()?;
+            let included = include_globs.is_empty()
+                || include_globs
+                    .iter()
+                    .any(|glob| glob.is_match(relative_path));
+            let excluded = exclude_globs
+                .iter()
+                .any(|glob| glob.is_match(relative_path));
+            if included && !excluded {
+                file_ids.push(file_id);
+            }
+        }
+
         let mut query_statement = self.db.prepare(
             "
             SELECT
-                files.relative_path, documents.id, documents.embedding
+                id, embedding
             FROM
-                documents, files
+                documents
             WHERE
-                documents.file_id = files.id AND
-                files.worktree_id IN rarray(?)
+                file_id IN rarray(?)
             ",
         )?;
         query_statement
-            .query_map(params![ids_to_sql(worktree_ids)], |row| {
-                Ok((row.get(0)?, row.get(1)?, row.get::<_, Embedding>(2)?))
+            .query_map(params![ids_to_sql(&file_ids)], |row| {
+                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
             })?
             .filter_map(|row| row.ok())
-            .for_each(|(relative_path, id, embedding)| f(relative_path, id, embedding.0));
+            .for_each(|(id, embedding)| f(id, embedding.0));
         Ok(())
     }
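The filtering rule `for_each_document` now applies while collecting `file_ids` — a path passes when the include set is empty or some include glob matches, and no exclude glob matches — can be stated as a small pure function. A sketch under that reading, with hypothetical names:

```rust
use globset::{Glob, GlobMatcher};

/// A path survives when it matches at least one include glob (or there
/// are no includes) and matches no exclude glob — the same predicate the
/// commit evaluates once per file instead of once per document.
fn path_allowed(path: &str, includes: &[GlobMatcher], excludes: &[GlobMatcher]) -> bool {
    let included = includes.is_empty() || includes.iter().any(|g| g.is_match(path));
    let excluded = excludes.iter().any(|g| g.is_match(path));
    included && !excluded
}

fn main() {
    let includes = vec![Glob::new("*.rs").unwrap().compile_matcher()];
    let excludes = vec![Glob::new("*_test.rs").unwrap().compile_matcher()];

    assert!(path_allowed("src/db.rs", &includes, &excludes));
    assert!(!path_allowed("src/db_test.rs", &includes, &excludes));
    assert!(!path_allowed("README.md", &includes, &excludes));
}
```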


@@ -11,7 +11,7 @@ use anyhow::{anyhow, Result};
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use futures::{channel::oneshot, Future};
-use globset::{Glob, GlobMatcher};
+use globset::GlobMatcher;
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;


@@ -3,7 +3,7 @@ use crate::{
     embedding::EmbeddingProvider,
     parsing::{subtract_ranges, CodeContextRetriever, Document},
     semantic_index_settings::SemanticIndexSettings,
-    SemanticIndex,
+    SearchResult, SemanticIndex,
 };
 use anyhow::Result;
 use async_trait::async_trait;
@@ -46,21 +46,21 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
             "src": {
                 "file1.rs": "
                     fn aaa() {
-                        println!(\"aaaa!\");
+                        println!(\"aaaaaaaaaaaa!\");
                     }
-                    fn zzzzzzzzz() {
+                    fn zzzzz() {
                         println!(\"SLEEPING\");
                     }
                 ".unindent(),
                 "file2.rs": "
                     fn bbb() {
-                        println!(\"bbbb!\");
+                        println!(\"bbbbbbbbbbbbb!\");
                     }
                 ".unindent(),
                 "file3.toml": "
-                    ZZZZZZZ = 5
+                    ZZZZZZZZZZZZZZZZZZ = 5
                 ".unindent(),
             }
         }),
     )
@@ -97,27 +97,37 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     let search_results = store
         .update(cx, |store, cx| {
-            store.search_project(project.clone(), "aaaa".to_string(), 5, vec![], vec![], cx)
+            store.search_project(
+                project.clone(),
+                "aaaaaabbbbzz".to_string(),
+                5,
+                vec![],
+                vec![],
+                cx,
+            )
         })
         .await
         .unwrap();
-    search_results[0].buffer.read_with(cx, |buffer, _cx| {
-        assert_eq!(search_results[0].range.start.to_offset(buffer), 0);
-        assert_eq!(
-            buffer.file().unwrap().path().as_ref(),
-            Path::new("src/file1.rs")
-        );
-    });
+    assert_search_results(
+        &search_results,
+        &[
+            (Path::new("src/file1.rs").into(), 0),
+            (Path::new("src/file2.rs").into(), 0),
+            (Path::new("src/file3.toml").into(), 0),
+            (Path::new("src/file1.rs").into(), 45),
+        ],
+        cx,
+    );

     // Test Include Files Functionality
     let include_files = vec![Glob::new("*.rs").unwrap().compile_matcher()];
     let exclude_files = vec![Glob::new("*.rs").unwrap().compile_matcher()];
-    let search_results = store
+    let rust_only_search_results = store
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
-                "aaaa".to_string(),
+                "aaaaaabbbbzz".to_string(),
                 5,
                 include_files,
                 vec![],
@@ -127,23 +137,21 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         .await
         .unwrap();
-    for res in &search_results {
-        res.buffer.read_with(cx, |buffer, _cx| {
-            assert!(buffer
-                .file()
-                .unwrap()
-                .path()
-                .to_str()
-                .unwrap()
-                .ends_with("rs"));
-        });
-    }
+    assert_search_results(
+        &rust_only_search_results,
+        &[
+            (Path::new("src/file1.rs").into(), 0),
+            (Path::new("src/file2.rs").into(), 0),
+            (Path::new("src/file1.rs").into(), 45),
+        ],
+        cx,
+    );

-    let search_results = store
+    let no_rust_search_results = store
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
-                "aaaa".to_string(),
+                "aaaaaabbbbzz".to_string(),
                 5,
                 vec![],
                 exclude_files,
@@ -153,17 +161,12 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         .await
         .unwrap();
-    for res in &search_results {
-        res.buffer.read_with(cx, |buffer, _cx| {
-            assert!(!buffer
-                .file()
-                .unwrap()
-                .path()
-                .to_str()
-                .unwrap()
-                .ends_with("rs"));
-        });
-    }
+    assert_search_results(
+        &no_rust_search_results,
+        &[(Path::new("src/file3.toml").into(), 0)],
+        cx,
+    );

     fs.save(
         "/the-root/src/file2.rs".as_ref(),
         &"
@@ -195,6 +198,26 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     );
 }

+#[track_caller]
+fn assert_search_results(
+    actual: &[SearchResult],
+    expected: &[(Arc<Path>, usize)],
+    cx: &TestAppContext,
+) {
+    let actual = actual
+        .iter()
+        .map(|search_result| {
+            search_result.buffer.read_with(cx, |buffer, _cx| {
+                (
+                    buffer.file().unwrap().path().clone(),
+                    search_result.range.start.to_offset(buffer),
+                )
+            })
+        })
+        .collect::<Vec<_>>();
+    assert_eq!(actual, expected);
+}
+
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();
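A side note on the new `assert_search_results` helper: `#[track_caller]` makes a failing `assert_eq!` report the test line that invoked the helper rather than a line inside the helper's body. A minimal illustration of the attribute (this helper is invented for the example, not from the commit):

```rust
#[track_caller]
fn assert_all_positive(values: &[i64]) {
    // Because of #[track_caller], a panic raised here is attributed to
    // the caller's file and line — exactly what you want from a shared
    // test assertion helper.
    for &v in values {
        assert!(v > 0, "expected all positive values, found {v}");
    }
}

fn main() {
    assert_all_positive(&[1, 2, 3]);
    // assert_all_positive(&[1, -2]); // would panic, blaming this line
}
```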