Add a slash command for automatically retrieving relevant context (#17972)

* [x] put this slash command behind a feature flag until we release
embedding access to the general population
* [x] choose a name for this slash command and name the Rust module to
match

Release Notes:

- N/A

---------

Co-authored-by: Jason <jason@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Jason Mancuso <7891333+jvmncs@users.noreply.github.com>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
commit e309fbda2a
parent 5905fbb9ac
Author: Max Brunsfeld
Date: 2024-09-20 15:09:18 -07:00 (committed by GitHub)
14 changed files with 683 additions and 223 deletions


@@ -31,20 +31,23 @@ pub struct SearchResult {
     pub path: Arc<Path>,
     pub range: Range<usize>,
     pub score: f32,
+    pub query_index: usize,
 }
 
 #[derive(Debug, PartialEq, Eq)]
 pub struct LoadedSearchResult {
     pub path: Arc<Path>,
     pub range: Range<usize>,
     pub full_path: PathBuf,
     pub file_content: String,
     pub excerpt_content: String,
     pub row_range: RangeInclusive<u32>,
+    pub query_index: usize,
 }
 
 pub struct WorktreeSearchResult {
     pub worktree_id: WorktreeId,
     pub path: Arc<Path>,
     pub range: Range<usize>,
+    pub query_index: usize,
     pub score: f32,
 }
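
The new `query_index` field threads through all three result structs so a caller can tell which of its submitted queries produced a given hit. Below is a minimal sketch of how that index is presumably computed: the hunks that follow destructure `chunk.embedding.similarity(&query_embeddings)` into `(score, query_index)`, but the stand-in `Embedding` type and dot-product scoring here are assumptions, not the crate's actual implementation.

use std::cmp::Ordering;

// Stand-in for the crate's Embedding type (assumed to wrap an f32 vector).
pub struct Embedding(Vec<f32>);

impl Embedding {
    // Assumed scoring function; the real crate may use cosine similarity.
    fn dot(&self, other: &Embedding) -> f32 {
        self.0.iter().zip(&other.0).map(|(a, b)| a * b).sum()
    }

    // Score this chunk embedding against every query embedding; return the
    // best score plus the index of the winning query, mirroring the
    // `(score, query_index)` destructuring in the diff below.
    pub fn similarity(&self, queries: &[Embedding]) -> (f32, usize) {
        queries
            .iter()
            .enumerate()
            .map(|(ix, query)| (self.dot(query), ix))
            .max_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
            .unwrap_or((f32::MIN, 0))
    }
}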
@@ -227,7 +230,7 @@ impl ProjectIndex {
     pub fn search(
         &self,
-        query: String,
+        queries: Vec<String>,
         limit: usize,
         cx: &AppContext,
     ) -> Task<Result<Vec<SearchResult>>> {
@@ -275,15 +278,18 @@
         cx.spawn(|cx| async move {
             #[cfg(debug_assertions)]
             let embedding_query_start = std::time::Instant::now();
-            log::info!("Searching for {query}");
+            log::info!("Searching for {queries:?}");
+            let queries: Vec<TextToEmbed> = queries
+                .iter()
+                .map(|s| TextToEmbed::new(s.as_str()))
+                .collect();
 
-            let query_embeddings = embedding_provider
-                .embed(&[TextToEmbed::new(&query)])
-                .await?;
-            let query_embedding = query_embeddings
-                .into_iter()
-                .next()
-                .ok_or_else(|| anyhow!("no embedding for query"))?;
+            let query_embeddings = embedding_provider.embed(&queries[..]).await?;
+            if query_embeddings.len() != queries.len() {
+                return Err(anyhow!(
+                    "The number of query embeddings does not match the number of queries"
+                ));
+            }
 
             let mut results_by_worker = Vec::new();
             for _ in 0..cx.background_executor().num_cpus() {
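
For reference, the batched `embed` call above implies a provider interface along these lines. This is an assumed shape inferred from the call site (a slice of `TextToEmbed` in, one embedding back per input), not necessarily the crate's exact trait:

use anyhow::Result;
use futures::future::BoxFuture;

pub struct TextToEmbed<'a> {
    pub text: &'a str,
}

impl<'a> TextToEmbed<'a> {
    pub fn new(text: &'a str) -> Self {
        Self { text }
    }
}

pub trait EmbeddingProvider: Send + Sync {
    // One request embeds the whole batch; callers expect exactly one
    // embedding per input, which the length check above enforces.
    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
}

(`Embedding` as sketched earlier.)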
@@ -292,28 +298,34 @@
         #[cfg(debug_assertions)]
         let search_start = std::time::Instant::now();
 
         cx.background_executor()
             .scoped(|cx| {
                 for results in results_by_worker.iter_mut() {
                     cx.spawn(async {
                         while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
-                            let score = chunk.embedding.similarity(&query_embedding);
+                            let (score, query_index) =
+                                chunk.embedding.similarity(&query_embeddings);
                             let ix = match results.binary_search_by(|probe| {
                                 score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                             }) {
                                 Ok(ix) | Err(ix) => ix,
                             };
-                            results.insert(
-                                ix,
-                                WorktreeSearchResult {
-                                    worktree_id,
-                                    path: path.clone(),
-                                    range: chunk.chunk.range.clone(),
-                                    score,
-                                },
-                            );
-                            results.truncate(limit);
+                            if ix < limit {
+                                results.insert(
+                                    ix,
+                                    WorktreeSearchResult {
+                                        worktree_id,
+                                        path: path.clone(),
+                                        range: chunk.chunk.range.clone(),
+                                        query_index,
+                                        score,
+                                    },
+                                );
+                                if results.len() > limit {
+                                    results.pop();
+                                }
+                            }
                         }
                     });
                 }
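
The rewritten loop body keeps each worker's list sorted by descending score and capped at `limit`, instead of inserting unconditionally and truncating afterwards; candidates that would land past the cutoff are now skipped outright, avoiding a shift of up to `limit` elements per discarded result. The same logic, extracted into a hypothetical standalone helper for clarity (the PR itself keeps it inline):

use std::cmp::Ordering;

// Insert `item` into a descending-by-score list only if it ranks within
// the top `limit`, evicting the current lowest-scoring entry on overflow.
fn insert_top_k(
    results: &mut Vec<WorktreeSearchResult>,
    item: WorktreeSearchResult,
    limit: usize,
) {
    let ix = match results.binary_search_by(|probe| {
        item.score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
    }) {
        Ok(ix) | Err(ix) => ix,
    };
    if ix < limit {
        results.insert(ix, item);
        if results.len() > limit {
            results.pop();
        }
    }
}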
@@ -333,6 +345,7 @@ impl ProjectIndex {
                     path: result.path,
                     range: result.range,
                     score: result.score,
+                    query_index: result.query_index,
                 })
             }));
         }
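
Taken together, a caller can now submit several phrasings of a search in one call and attribute each hit to the phrasing that matched. A hypothetical call site (the slash command added by this PR is the real consumer; the queries and limit below are illustrative):

let task = project_index.search(
    vec![
        "how are files indexed".to_string(),
        "embedding generation".to_string(),
    ],
    16,
    cx,
);

// Later, on an async executor:
let results = task.await?;
for result in &results {
    // query_index identifies which of the two strings above produced the hit.
    log::info!(
        "{:?} {:?} via query #{}, score {:.3}",
        result.path,
        result.range,
        result.query_index,
        result.score
    );
}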