Add a slash command for automatically retrieving relevant context (#17972)
* [x] put this slash command behind a feature flag until we release embedding access to the general population * [x] choose a name for this slash command and name the rust module to match Release Notes: - N/A --------- Co-authored-by: Jason <jason@zed.dev> Co-authored-by: Richard <richard@zed.dev> Co-authored-by: Jason Mancuso <7891333+jvmncs@users.noreply.github.com> Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
parent
5905fbb9ac
commit
e309fbda2a
14 changed files with 683 additions and 223 deletions
|
@ -98,7 +98,7 @@ fn main() {
|
|||
.update(|cx| {
|
||||
let project_index = project_index.read(cx);
|
||||
let query = "converting an anchor to a point";
|
||||
project_index.search(query.into(), 4, cx)
|
||||
project_index.search(vec![query.into()], 4, cx)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
|
|
|
@ -42,14 +42,23 @@ impl Embedding {
|
|||
self.0.len()
|
||||
}
|
||||
|
||||
pub fn similarity(self, other: &Embedding) -> f32 {
|
||||
debug_assert_eq!(self.0.len(), other.0.len());
|
||||
self.0
|
||||
pub fn similarity(&self, others: &[Embedding]) -> (f32, usize) {
|
||||
debug_assert!(others.iter().all(|other| self.0.len() == other.0.len()));
|
||||
others
|
||||
.iter()
|
||||
.copied()
|
||||
.zip(other.0.iter().copied())
|
||||
.map(|(a, b)| a * b)
|
||||
.sum()
|
||||
.enumerate()
|
||||
.map(|(index, other)| {
|
||||
let dot_product: f32 = self
|
||||
.0
|
||||
.iter()
|
||||
.copied()
|
||||
.zip(other.0.iter().copied())
|
||||
.map(|(a, b)| a * b)
|
||||
.sum();
|
||||
(dot_product, index)
|
||||
})
|
||||
.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap_or((0.0, 0))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -31,20 +31,23 @@ pub struct SearchResult {
|
|||
pub path: Arc<Path>,
|
||||
pub range: Range<usize>,
|
||||
pub score: f32,
|
||||
pub query_index: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct LoadedSearchResult {
|
||||
pub path: Arc<Path>,
|
||||
pub range: Range<usize>,
|
||||
pub full_path: PathBuf,
|
||||
pub file_content: String,
|
||||
pub excerpt_content: String,
|
||||
pub row_range: RangeInclusive<u32>,
|
||||
pub query_index: usize,
|
||||
}
|
||||
|
||||
pub struct WorktreeSearchResult {
|
||||
pub worktree_id: WorktreeId,
|
||||
pub path: Arc<Path>,
|
||||
pub range: Range<usize>,
|
||||
pub query_index: usize,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
|
@ -227,7 +230,7 @@ impl ProjectIndex {
|
|||
|
||||
pub fn search(
|
||||
&self,
|
||||
query: String,
|
||||
queries: Vec<String>,
|
||||
limit: usize,
|
||||
cx: &AppContext,
|
||||
) -> Task<Result<Vec<SearchResult>>> {
|
||||
|
@ -275,15 +278,18 @@ impl ProjectIndex {
|
|||
cx.spawn(|cx| async move {
|
||||
#[cfg(debug_assertions)]
|
||||
let embedding_query_start = std::time::Instant::now();
|
||||
log::info!("Searching for {query}");
|
||||
log::info!("Searching for {queries:?}");
|
||||
let queries: Vec<TextToEmbed> = queries
|
||||
.iter()
|
||||
.map(|s| TextToEmbed::new(s.as_str()))
|
||||
.collect();
|
||||
|
||||
let query_embeddings = embedding_provider
|
||||
.embed(&[TextToEmbed::new(&query)])
|
||||
.await?;
|
||||
let query_embedding = query_embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("no embedding for query"))?;
|
||||
let query_embeddings = embedding_provider.embed(&queries[..]).await?;
|
||||
if query_embeddings.len() != queries.len() {
|
||||
return Err(anyhow!(
|
||||
"The number of query embeddings does not match the number of queries"
|
||||
));
|
||||
}
|
||||
|
||||
let mut results_by_worker = Vec::new();
|
||||
for _ in 0..cx.background_executor().num_cpus() {
|
||||
|
@ -292,28 +298,34 @@ impl ProjectIndex {
|
|||
|
||||
#[cfg(debug_assertions)]
|
||||
let search_start = std::time::Instant::now();
|
||||
|
||||
cx.background_executor()
|
||||
.scoped(|cx| {
|
||||
for results in results_by_worker.iter_mut() {
|
||||
cx.spawn(async {
|
||||
while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
|
||||
let score = chunk.embedding.similarity(&query_embedding);
|
||||
let (score, query_index) =
|
||||
chunk.embedding.similarity(&query_embeddings);
|
||||
|
||||
let ix = match results.binary_search_by(|probe| {
|
||||
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
|
||||
}) {
|
||||
Ok(ix) | Err(ix) => ix,
|
||||
};
|
||||
results.insert(
|
||||
ix,
|
||||
WorktreeSearchResult {
|
||||
worktree_id,
|
||||
path: path.clone(),
|
||||
range: chunk.chunk.range.clone(),
|
||||
score,
|
||||
},
|
||||
);
|
||||
results.truncate(limit);
|
||||
if ix < limit {
|
||||
results.insert(
|
||||
ix,
|
||||
WorktreeSearchResult {
|
||||
worktree_id,
|
||||
path: path.clone(),
|
||||
range: chunk.chunk.range.clone(),
|
||||
query_index,
|
||||
score,
|
||||
},
|
||||
);
|
||||
if results.len() > limit {
|
||||
results.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -333,6 +345,7 @@ impl ProjectIndex {
|
|||
path: result.path,
|
||||
range: result.range,
|
||||
score: result.score,
|
||||
query_index: result.query_index,
|
||||
})
|
||||
}));
|
||||
}
|
||||
|
|
|
@ -12,8 +12,13 @@ use anyhow::{Context as _, Result};
|
|||
use collections::HashMap;
|
||||
use fs::Fs;
|
||||
use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
|
||||
use project::Project;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use language::LineEnding;
|
||||
use project::{Project, Worktree};
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use ui::ViewContext;
|
||||
use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
|
@ -77,46 +82,127 @@ impl SemanticDb {
|
|||
}
|
||||
|
||||
pub async fn load_results(
|
||||
results: Vec<SearchResult>,
|
||||
mut results: Vec<SearchResult>,
|
||||
fs: &Arc<dyn Fs>,
|
||||
cx: &AsyncAppContext,
|
||||
) -> Result<Vec<LoadedSearchResult>> {
|
||||
let mut loaded_results = Vec::new();
|
||||
for result in results {
|
||||
let (full_path, file_content) = result.worktree.read_with(cx, |worktree, _cx| {
|
||||
let entry_abs_path = worktree.abs_path().join(&result.path);
|
||||
let mut entry_full_path = PathBuf::from(worktree.root_name());
|
||||
entry_full_path.push(&result.path);
|
||||
let file_content = async {
|
||||
let entry_abs_path = entry_abs_path;
|
||||
fs.load(&entry_abs_path).await
|
||||
};
|
||||
(entry_full_path, file_content)
|
||||
})?;
|
||||
if let Some(file_content) = file_content.await.log_err() {
|
||||
let range_start = result.range.start.min(file_content.len());
|
||||
let range_end = result.range.end.min(file_content.len());
|
||||
|
||||
let start_row = file_content[0..range_start].matches('\n').count() as u32;
|
||||
let end_row = file_content[0..range_end].matches('\n').count() as u32;
|
||||
let start_line_byte_offset = file_content[0..range_start]
|
||||
.rfind('\n')
|
||||
.map(|pos| pos + 1)
|
||||
.unwrap_or_default();
|
||||
let end_line_byte_offset = file_content[range_end..]
|
||||
.find('\n')
|
||||
.map(|pos| range_end + pos)
|
||||
.unwrap_or_else(|| file_content.len());
|
||||
|
||||
loaded_results.push(LoadedSearchResult {
|
||||
path: result.path,
|
||||
range: start_line_byte_offset..end_line_byte_offset,
|
||||
full_path,
|
||||
file_content,
|
||||
row_range: start_row..=end_row,
|
||||
});
|
||||
let mut max_scores_by_path = HashMap::<_, (f32, usize)>::default();
|
||||
for result in &results {
|
||||
let (score, query_index) = max_scores_by_path
|
||||
.entry((result.worktree.clone(), result.path.clone()))
|
||||
.or_default();
|
||||
if result.score > *score {
|
||||
*score = result.score;
|
||||
*query_index = result.query_index;
|
||||
}
|
||||
}
|
||||
|
||||
results.sort_by(|a, b| {
|
||||
let max_score_a = max_scores_by_path[&(a.worktree.clone(), a.path.clone())].0;
|
||||
let max_score_b = max_scores_by_path[&(b.worktree.clone(), b.path.clone())].0;
|
||||
max_score_b
|
||||
.partial_cmp(&max_score_a)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| a.worktree.entity_id().cmp(&b.worktree.entity_id()))
|
||||
.then_with(|| a.path.cmp(&b.path))
|
||||
.then_with(|| a.range.start.cmp(&b.range.start))
|
||||
});
|
||||
|
||||
let mut last_loaded_file: Option<(Model<Worktree>, Arc<Path>, PathBuf, String)> = None;
|
||||
let mut loaded_results = Vec::<LoadedSearchResult>::new();
|
||||
for result in results {
|
||||
let full_path;
|
||||
let file_content;
|
||||
if let Some(last_loaded_file) =
|
||||
last_loaded_file
|
||||
.as_ref()
|
||||
.filter(|(last_worktree, last_path, _, _)| {
|
||||
last_worktree == &result.worktree && last_path == &result.path
|
||||
})
|
||||
{
|
||||
full_path = last_loaded_file.2.clone();
|
||||
file_content = &last_loaded_file.3;
|
||||
} else {
|
||||
let output = result.worktree.read_with(cx, |worktree, _cx| {
|
||||
let entry_abs_path = worktree.abs_path().join(&result.path);
|
||||
let mut entry_full_path = PathBuf::from(worktree.root_name());
|
||||
entry_full_path.push(&result.path);
|
||||
let file_content = async {
|
||||
let entry_abs_path = entry_abs_path;
|
||||
fs.load(&entry_abs_path).await
|
||||
};
|
||||
(entry_full_path, file_content)
|
||||
})?;
|
||||
full_path = output.0;
|
||||
let Some(content) = output.1.await.log_err() else {
|
||||
continue;
|
||||
};
|
||||
last_loaded_file = Some((
|
||||
result.worktree.clone(),
|
||||
result.path.clone(),
|
||||
full_path.clone(),
|
||||
content,
|
||||
));
|
||||
file_content = &last_loaded_file.as_ref().unwrap().3;
|
||||
};
|
||||
|
||||
let query_index = max_scores_by_path[&(result.worktree.clone(), result.path.clone())].1;
|
||||
|
||||
let mut range_start = result.range.start.min(file_content.len());
|
||||
let mut range_end = result.range.end.min(file_content.len());
|
||||
while !file_content.is_char_boundary(range_start) {
|
||||
range_start += 1;
|
||||
}
|
||||
while !file_content.is_char_boundary(range_end) {
|
||||
range_end += 1;
|
||||
}
|
||||
|
||||
let start_row = file_content[0..range_start].matches('\n').count() as u32;
|
||||
let mut end_row = file_content[0..range_end].matches('\n').count() as u32;
|
||||
let start_line_byte_offset = file_content[0..range_start]
|
||||
.rfind('\n')
|
||||
.map(|pos| pos + 1)
|
||||
.unwrap_or_default();
|
||||
let mut end_line_byte_offset = range_end;
|
||||
if file_content[..end_line_byte_offset].ends_with('\n') {
|
||||
end_row -= 1;
|
||||
} else {
|
||||
end_line_byte_offset = file_content[range_end..]
|
||||
.find('\n')
|
||||
.map(|pos| range_end + pos + 1)
|
||||
.unwrap_or_else(|| file_content.len());
|
||||
}
|
||||
let mut excerpt_content =
|
||||
file_content[start_line_byte_offset..end_line_byte_offset].to_string();
|
||||
LineEnding::normalize(&mut excerpt_content);
|
||||
|
||||
if let Some(prev_result) = loaded_results.last_mut() {
|
||||
if prev_result.full_path == full_path {
|
||||
if *prev_result.row_range.end() + 1 == start_row {
|
||||
prev_result.row_range = *prev_result.row_range.start()..=end_row;
|
||||
prev_result.excerpt_content.push_str(&excerpt_content);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loaded_results.push(LoadedSearchResult {
|
||||
path: result.path,
|
||||
full_path,
|
||||
excerpt_content,
|
||||
row_range: start_row..=end_row,
|
||||
query_index,
|
||||
});
|
||||
}
|
||||
|
||||
for result in &mut loaded_results {
|
||||
while result.excerpt_content.ends_with("\n\n") {
|
||||
result.excerpt_content.pop();
|
||||
result.row_range =
|
||||
*result.row_range.start()..=result.row_range.end().saturating_sub(1)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(loaded_results)
|
||||
}
|
||||
|
||||
|
@ -312,7 +398,7 @@ mod tests {
|
|||
.update(|cx| {
|
||||
let project_index = project_index.read(cx);
|
||||
let query = "garbage in, garbage out";
|
||||
project_index.search(query.into(), 4, cx)
|
||||
project_index.search(vec![query.into()], 4, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -426,4 +512,117 @@ mod tests {
|
|||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_load_search_results(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project_path = Path::new("/fake_project");
|
||||
|
||||
let file1_content = "one\ntwo\nthree\nfour\nfive\n";
|
||||
let file2_content = "aaa\nbbb\nccc\nddd\neee\n";
|
||||
|
||||
fs.insert_tree(
|
||||
project_path,
|
||||
json!({
|
||||
"file1.txt": file1_content,
|
||||
"file2.txt": file2_content,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let fs = fs as Arc<dyn Fs>;
|
||||
let project = Project::test(fs.clone(), [project_path], cx).await;
|
||||
let worktree = project.read_with(cx, |project, cx| project.worktrees(cx).next().unwrap());
|
||||
|
||||
// chunk that is already newline-aligned
|
||||
let search_results = vec![SearchResult {
|
||||
worktree: worktree.clone(),
|
||||
path: Path::new("file1.txt").into(),
|
||||
range: 0..file1_content.find("four").unwrap(),
|
||||
score: 0.5,
|
||||
query_index: 0,
|
||||
}];
|
||||
assert_eq!(
|
||||
SemanticDb::load_results(search_results, &fs, &cx.to_async())
|
||||
.await
|
||||
.unwrap(),
|
||||
&[LoadedSearchResult {
|
||||
path: Path::new("file1.txt").into(),
|
||||
full_path: "fake_project/file1.txt".into(),
|
||||
excerpt_content: "one\ntwo\nthree\n".into(),
|
||||
row_range: 0..=2,
|
||||
query_index: 0,
|
||||
}]
|
||||
);
|
||||
|
||||
// chunk that is *not* newline-aligned
|
||||
let search_results = vec![SearchResult {
|
||||
worktree: worktree.clone(),
|
||||
path: Path::new("file1.txt").into(),
|
||||
range: file1_content.find("two").unwrap() + 1..file1_content.find("four").unwrap() + 2,
|
||||
score: 0.5,
|
||||
query_index: 0,
|
||||
}];
|
||||
assert_eq!(
|
||||
SemanticDb::load_results(search_results, &fs, &cx.to_async())
|
||||
.await
|
||||
.unwrap(),
|
||||
&[LoadedSearchResult {
|
||||
path: Path::new("file1.txt").into(),
|
||||
full_path: "fake_project/file1.txt".into(),
|
||||
excerpt_content: "two\nthree\nfour\n".into(),
|
||||
row_range: 1..=3,
|
||||
query_index: 0,
|
||||
}]
|
||||
);
|
||||
|
||||
// chunks that are adjacent
|
||||
|
||||
let search_results = vec![
|
||||
SearchResult {
|
||||
worktree: worktree.clone(),
|
||||
path: Path::new("file1.txt").into(),
|
||||
range: file1_content.find("two").unwrap()..file1_content.len(),
|
||||
score: 0.6,
|
||||
query_index: 0,
|
||||
},
|
||||
SearchResult {
|
||||
worktree: worktree.clone(),
|
||||
path: Path::new("file1.txt").into(),
|
||||
range: 0..file1_content.find("two").unwrap(),
|
||||
score: 0.5,
|
||||
query_index: 1,
|
||||
},
|
||||
SearchResult {
|
||||
worktree: worktree.clone(),
|
||||
path: Path::new("file2.txt").into(),
|
||||
range: 0..file2_content.len(),
|
||||
score: 0.8,
|
||||
query_index: 1,
|
||||
},
|
||||
];
|
||||
assert_eq!(
|
||||
SemanticDb::load_results(search_results, &fs, &cx.to_async())
|
||||
.await
|
||||
.unwrap(),
|
||||
&[
|
||||
LoadedSearchResult {
|
||||
path: Path::new("file2.txt").into(),
|
||||
full_path: "fake_project/file2.txt".into(),
|
||||
excerpt_content: file2_content.into(),
|
||||
row_range: 0..=4,
|
||||
query_index: 1,
|
||||
},
|
||||
LoadedSearchResult {
|
||||
path: Path::new("file1.txt").into(),
|
||||
full_path: "fake_project/file1.txt".into(),
|
||||
excerpt_content: file1_content.into(),
|
||||
row_range: 0..=4,
|
||||
query_index: 0,
|
||||
}
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue