Add an eval binary that evaluates our semantic index against CodeSearchNet (#17375)

This PR is the beginning of an evaluation framework for our AI features.
Right now, we're evaluating our semantic search feature against the
[CodeSearchNet](https://github.com/github/CodeSearchNet) code search
dataset. This dataset is very limited (for the most part, only one known
good search result per repo), but it has already surfaced some problems
with our search.
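
To make the shape of the eval concrete: it boils down to running each dataset query through the semantic index and checking whether the known good result lands in the top hits. Below is a minimal sketch of that loop in plain Rust; `EvalExample`, the inline dataset, and the `search` callback are illustrative stand-ins, not the actual API of the eval binary.

```rust
struct EvalExample {
    query: String,
    expected_path: String,
}

/// Fraction of examples whose known good result appears in the top `k` hits.
fn recall_at_k(examples: &[EvalExample], k: usize, search: impl Fn(&str) -> Vec<String>) -> f32 {
    let hits = examples
        .iter()
        .filter(|ex| search(&ex.query).iter().take(k).any(|p| *p == ex.expected_path))
        .count();
    hits as f32 / examples.len() as f32
}

fn main() {
    let examples = vec![EvalExample {
        query: "parse a json document".into(),
        expected_path: "src/json.rs".into(),
    }];
    // Stub standing in for a real search against the semantic index.
    let search = |_query: &str| vec!["src/json.rs".to_string()];
    println!("recall@1 = {:.2}", recall_at_k(&examples, 1, search));
}
```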

Release Notes:

- N/A

---------

Co-authored-by: Jason <jason@zed.dev>
Co-authored-by: Jason Mancuso <7891333+jvmncs@users.noreply.github.com>
Co-authored-by: Nathan <nathan@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Max Brunsfeld, 2024-09-17 12:44:33 -07:00 (committed by GitHub)
parent 06a13c2983
commit d3d3a093b4
14 changed files with 881 additions and 144 deletions


@@ -234,30 +234,25 @@ impl EmbeddingIndex {
         cx.spawn(async {
             while let Ok((entry, handle)) = entries.recv().await {
                 let entry_abs_path = worktree_abs_path.join(&entry.path);
-                match fs.load(&entry_abs_path).await {
-                    Ok(text) => {
-                        let language = language_registry
-                            .language_for_file_path(&entry.path)
-                            .await
-                            .ok();
-                        let chunked_file = ChunkedFile {
-                            chunks: chunking::chunk_text(
-                                &text,
-                                language.as_ref(),
-                                &entry.path,
-                            ),
-                            handle,
-                            path: entry.path,
-                            mtime: entry.mtime,
-                            text,
-                        };
-                        if chunked_files_tx.send(chunked_file).await.is_err() {
-                            return;
-                        }
-                    }
-                    Err(_) => {
-                        log::error!("Failed to read contents into a UTF-8 string: {entry_abs_path:?}");
-                    }
-                }
+                if let Some(text) = fs.load(&entry_abs_path).await.ok() {
+                    let language = language_registry
+                        .language_for_file_path(&entry.path)
+                        .await
+                        .ok();
+                    let chunked_file = ChunkedFile {
+                        chunks: chunking::chunk_text(
+                            &text,
+                            language.as_ref(),
+                            &entry.path,
+                        ),
+                        handle,
+                        path: entry.path,
+                        mtime: entry.mtime,
+                        text,
+                    };
+                    if chunked_files_tx.send(chunked_file).await.is_err() {
+                        return;
+                    }
+                }
             }
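
The hunk above sits inside a producer/consumer pipeline: entries arrive on one smol channel, each file is loaded and chunked, and the chunks are sent down a second channel, stopping if the receiver is dropped. Here is a self-contained sketch of that pattern (not the PR's code; strings stand in for real entries and `ChunkedFile`s):

```rust
use smol::channel;

fn main() {
    smol::block_on(async {
        let (entries_tx, entries_rx) = channel::bounded::<String>(16);
        let (chunks_tx, chunks_rx) = channel::bounded::<String>(16);

        // Background "chunker": consumes entries, produces chunks, and bails
        // out if the downstream receiver goes away (the `is_err` check above).
        let chunker = smol::spawn(async move {
            while let Ok(path) = entries_rx.recv().await {
                let chunk = format!("chunk of {path}");
                if chunks_tx.send(chunk).await.is_err() {
                    return;
                }
            }
        });

        entries_tx.send("a.rs".to_string()).await.unwrap();
        entries_tx.send("b.rs".to_string()).await.unwrap();
        drop(entries_tx); // close the channel so the chunker's loop ends

        while let Ok(chunk) = chunks_rx.recv().await {
            println!("{chunk}");
        }
        chunker.await;
    });
}
```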
@@ -358,33 +353,37 @@ impl EmbeddingIndex {
     fn persist_embeddings(
         &self,
         mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
-        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
+        mut embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
         cx: &AppContext,
     ) -> Task<Result<()>> {
         let db_connection = self.db_connection.clone();
         let db = self.db;
         cx.background_executor().spawn(async move {
-            while let Some(deletion_range) = deleted_entry_ranges.next().await {
-                let mut txn = db_connection.write_txn()?;
-                let start = deletion_range.0.as_ref().map(|start| start.as_str());
-                let end = deletion_range.1.as_ref().map(|end| end.as_str());
-                log::debug!("deleting embeddings in range {:?}", &(start, end));
-                db.delete_range(&mut txn, &(start, end))?;
-                txn.commit()?;
-            }
-
-            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
-            while let Some(embedded_files) = embedded_files.next().await {
-                let mut txn = db_connection.write_txn()?;
-                for (file, _) in &embedded_files {
-                    log::debug!("saving embedding for file {:?}", file.path);
-                    let key = db_key_for_path(&file.path);
-                    db.put(&mut txn, &key, file)?;
-                }
-                txn.commit()?;
-                drop(embedded_files);
-                log::debug!("committed");
+            loop {
+                // Interleave deletions and persists of embedded files
+                futures::select_biased! {
+                    deletion_range = deleted_entry_ranges.next() => {
+                        if let Some(deletion_range) = deletion_range {
+                            let mut txn = db_connection.write_txn()?;
+                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
+                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
+                            log::debug!("deleting embeddings in range {:?}", &(start, end));
+                            db.delete_range(&mut txn, &(start, end))?;
+                            txn.commit()?;
+                        }
+                    },
+                    file = embedded_files.next() => {
+                        if let Some((file, _)) = file {
+                            let mut txn = db_connection.write_txn()?;
+                            log::debug!("saving embedding for file {:?}", file.path);
+                            let key = db_key_for_path(&file.path);
+                            db.put(&mut txn, &key, &file)?;
+                            txn.commit()?;
+                        }
+                    },
+                    complete => break,
+                }
             }
             Ok(())
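
The change above replaces two sequential drains (all deletions, then all writes) with a single loop that commits whichever kind of work is ready, preferring deletions. A minimal, self-contained illustration of the `futures::select_biased!` pattern (not the PR's code; `futures` mpsc channels stand in for the smol channels):

```rust
use futures::{channel::mpsc, executor::block_on, SinkExt, StreamExt};

fn main() {
    block_on(async {
        let (mut del_tx, mut deletions) = mpsc::channel::<u32>(8);
        let (mut put_tx, mut writes) = mpsc::channel::<u32>(8);

        // Queue some work, then close both channels.
        del_tx.send(1).await.unwrap();
        put_tx.send(10).await.unwrap();
        put_tx.send(11).await.unwrap();
        drop(del_tx);
        drop(put_tx);

        loop {
            futures::select_biased! {
                // Biased: deletions win when both channels are ready.
                range = deletions.next() => if let Some(range) = range {
                    println!("delete {range}"); // one transaction per deletion
                },
                file = writes.next() => if let Some(file) = file {
                    println!("persist {file}"); // one transaction per file
                },
                // Fires once every branch's stream is exhausted.
                complete => break,
            }
        }
    });
}
```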


@@ -15,7 +15,14 @@ use log;
 use project::{Project, Worktree, WorktreeId};
 use serde::{Deserialize, Serialize};
 use smol::channel;
-use std::{cmp::Ordering, future::Future, num::NonZeroUsize, ops::Range, path::Path, sync::Arc};
+use std::{
+    cmp::Ordering,
+    future::Future,
+    num::NonZeroUsize,
+    ops::{Range, RangeInclusive},
+    path::{Path, PathBuf},
+    sync::Arc,
+};
 use util::ResultExt;

 #[derive(Debug)]

@@ -26,6 +33,14 @@ pub struct SearchResult {
     pub score: f32,
 }

+pub struct LoadedSearchResult {
+    pub path: Arc<Path>,
+    pub range: Range<usize>,
+    pub full_path: PathBuf,
+    pub file_content: String,
+    pub row_range: RangeInclusive<u32>,
+}
+
 pub struct WorktreeSearchResult {
     pub worktree_id: WorktreeId,
     pub path: Arc<Path>,


@@ -10,14 +10,16 @@ mod worktree_index;

 use anyhow::{Context as _, Result};
 use collections::HashMap;
-use gpui::{AppContext, BorrowAppContext, Context, Global, Model, WeakModel};
+use fs::Fs;
+use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
 use project::Project;
 use project_index::ProjectIndex;
+use std::{path::PathBuf, sync::Arc};
 use ui::ViewContext;
 use util::ResultExt as _;
 use workspace::Workspace;

 pub use embedding::*;
-pub use project_index::{ProjectIndex, SearchResult, Status};
+pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
 pub use project_index_debug_view::ProjectIndexDebugView;
 pub use summary_index::FileSummary;
@@ -56,27 +58,7 @@ impl SemanticDb {
         if cx.has_global::<SemanticDb>() {
             cx.update_global::<SemanticDb, _>(|this, cx| {
-                let project_index = cx.new_model(|cx| {
-                    ProjectIndex::new(
-                        project.clone(),
-                        this.db_connection.clone(),
-                        this.embedding_provider.clone(),
-                        cx,
-                    )
-                });
-
-                let project_weak = project.downgrade();
-                this.project_indices
-                    .insert(project_weak.clone(), project_index);
-
-                cx.on_release(move |_, _, cx| {
-                    if cx.has_global::<SemanticDb>() {
-                        cx.update_global::<SemanticDb, _>(|this, _| {
-                            this.project_indices.remove(&project_weak);
-                        })
-                    }
-                })
-                .detach();
+                this.create_project_index(project, cx);
             })
         } else {
             log::info!("No SemanticDb, skipping project index")
@@ -94,6 +76,50 @@ impl SemanticDb {
         })
     }

+    pub async fn load_results(
+        results: Vec<SearchResult>,
+        fs: &Arc<dyn Fs>,
+        cx: &AsyncAppContext,
+    ) -> Result<Vec<LoadedSearchResult>> {
+        let mut loaded_results = Vec::new();
+        for result in results {
+            let (full_path, file_content) = result.worktree.read_with(cx, |worktree, _cx| {
+                let entry_abs_path = worktree.abs_path().join(&result.path);
+                let mut entry_full_path = PathBuf::from(worktree.root_name());
+                entry_full_path.push(&result.path);
+                let file_content = async {
+                    let entry_abs_path = entry_abs_path;
+                    fs.load(&entry_abs_path).await
+                };
+                (entry_full_path, file_content)
+            })?;
+            if let Some(file_content) = file_content.await.log_err() {
+                let range_start = result.range.start.min(file_content.len());
+                let range_end = result.range.end.min(file_content.len());
+                let start_row = file_content[0..range_start].matches('\n').count() as u32;
+                let end_row = file_content[0..range_end].matches('\n').count() as u32;
+                let start_line_byte_offset = file_content[0..range_start]
+                    .rfind('\n')
+                    .map(|pos| pos + 1)
+                    .unwrap_or_default();
+                let end_line_byte_offset = file_content[range_end..]
+                    .find('\n')
+                    .map(|pos| range_end + pos)
+                    .unwrap_or_else(|| file_content.len());
+                loaded_results.push(LoadedSearchResult {
+                    path: result.path,
+                    range: start_line_byte_offset..end_line_byte_offset,
+                    full_path,
+                    file_content,
+                    row_range: start_row..=end_row,
+                });
+            }
+        }
+        Ok(loaded_results)
+    }
+
     pub fn project_index(
         &mut self,
         project: Model<Project>,
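
The byte arithmetic in `load_results` snaps a hit's byte range outward to whole lines so the result can be displayed as a complete excerpt. The same computation as a standalone function, with a worked example (a sketch mirroring the logic above, not code from this PR):

```rust
/// Expand `start..end` so it begins just after the previous newline and ends
/// just before the next one. Assumes the offsets lie on char boundaries.
fn expand_to_line_boundaries(text: &str, start: usize, end: usize) -> (usize, usize) {
    let start = start.min(text.len());
    let end = end.min(text.len());
    let line_start = text[..start].rfind('\n').map(|pos| pos + 1).unwrap_or(0);
    let line_end = text[end..].find('\n').map(|pos| end + pos).unwrap_or(text.len());
    (line_start, line_end)
}

fn main() {
    let text = "fn main() {\n    println!(\"hi\");\n}\n";
    // A match somewhere inside the second line...
    let (start, end) = expand_to_line_boundaries(text, 16, 24);
    // ...comes back as the whole line, ready to show as an excerpt.
    assert_eq!(&text[start..end], "    println!(\"hi\");");
}
```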
@@ -113,6 +139,36 @@ impl SemanticDb {
             })
         })
     }
+
+    pub fn create_project_index(
+        &mut self,
+        project: Model<Project>,
+        cx: &mut AppContext,
+    ) -> Model<ProjectIndex> {
+        let project_index = cx.new_model(|cx| {
+            ProjectIndex::new(
+                project.clone(),
+                self.db_connection.clone(),
+                self.embedding_provider.clone(),
+                cx,
+            )
+        });
+
+        let project_weak = project.downgrade();
+        self.project_indices
+            .insert(project_weak.clone(), project_index.clone());
+
+        cx.observe_release(&project, move |_, cx| {
+            if cx.has_global::<SemanticDb>() {
+                cx.update_global::<SemanticDb, _>(|this, _| {
+                    this.project_indices.remove(&project_weak);
+                })
+            }
+        })
+        .detach();
+
+        project_index
+    }
 }

 #[cfg(test)]
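
`create_project_index` keys the registry by a weak project handle and relies on a release observer to evict the entry once the project goes away, so the map never keeps a project alive. Setting gpui aside, the ownership pattern looks like this plain-Rust sketch, where `prune` plays the role of the `observe_release` callback:

```rust
use std::collections::HashMap;
use std::rc::{Rc, Weak};

struct Project {
    name: String,
}

struct Registry {
    // Weak handles: holding an index must not keep its project alive.
    indices: HashMap<String, (Weak<Project>, String)>,
}

impl Registry {
    // Stand-in for the release callback: drop entries whose project is gone.
    fn prune(&mut self) {
        self.indices.retain(|_, (project, _)| project.upgrade().is_some());
    }
}

fn main() {
    let project = Rc::new(Project { name: "zed".into() });
    let mut registry = Registry { indices: HashMap::new() };
    registry.indices.insert(
        project.name.clone(),
        (Rc::downgrade(&project), "embedding index".into()),
    );

    drop(project); // the project is released...
    registry.prune(); // ...and its registry entry goes with it
    assert!(registry.indices.is_empty());
}
```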
@@ -230,34 +286,13 @@ mod tests {
         let project = Project::test(fs, [project_path], cx).await;

-        cx.update(|cx| {
+        let project_index = cx.update(|cx| {
             let language_registry = project.read(cx).languages().clone();
             let node_runtime = project.read(cx).node_runtime().unwrap().clone();
             languages::init(language_registry, node_runtime, cx);
-
-            // Manually create and insert the ProjectIndex
-            let project_index = cx.new_model(|cx| {
-                ProjectIndex::new(
-                    project.clone(),
-                    semantic_index.db_connection.clone(),
-                    semantic_index.embedding_provider.clone(),
-                    cx,
-                )
-            });
-            semantic_index
-                .project_indices
-                .insert(project.downgrade(), project_index);
+            semantic_index.create_project_index(project.clone(), cx)
         });
-
-        let project_index = cx
-            .update(|_cx| {
-                semantic_index
-                    .project_indices
-                    .get(&project.downgrade())
-                    .cloned()
-            })
-            .unwrap();

         cx.run_until_parked();
         while cx
             .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))