Add an eval binary that evaluates our semantic index against CodeSearchNet (#17375)

This PR is the beginning of an evaluation framework for our AI features. Right now, we're evaluating our semantic search feature against the [CodeSearchNet](https://github.com/github/CodeSearchNet) code search dataset. This dataset is very limited (for the most part, only 1 known good search result per repo), but it has already surfaced some problems with our search.

Release Notes:

- N/A

---------

Co-authored-by: Jason <jason@zed.dev>
Co-authored-by: Jason Mancuso <7891333+jvmncs@users.noreply.github.com>
Co-authored-by: Nathan <nathan@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Parent: 06a13c2983
Commit: d3d3a093b4

14 changed files with 881 additions and 144 deletions
.github/workflows/ci.yml (vendored, 17 changes)
@@ -101,7 +101,7 @@ jobs:
     timeout-minutes: 60
     name: (Linux) Run Clippy and tests
     runs-on:
-      - hosted-linux-x86-1
+      - buildjet-16vcpu-ubuntu-2204
     steps:
       - name: Add Rust to the PATH
         run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
@@ -111,6 +111,11 @@ jobs:
         with:
           clean: false

+      - name: Cache dependencies
+        uses: swatinem/rust-cache@23bce251a8cd2ffc3c1075eaa2367cf899916d84 # v2
+        with:
+          save-if: ${{ github.ref == 'refs/heads/main' }}
+
       - name: Install Linux dependencies
         run: ./script/linux

@@ -264,7 +269,7 @@ jobs:
     timeout-minutes: 60
     name: Create a Linux bundle
     runs-on:
-      - hosted-linux-x86-1
+      - buildjet-16vcpu-ubuntu-2204
     if: ${{ startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
     needs: [linux_tests]
     env:
@@ -279,9 +284,6 @@ jobs:
       - name: Install Linux dependencies
         run: ./script/linux

-      - name: Limit target directory size
-        run: script/clear-target-dir-if-larger-than 100
-
       - name: Determine version and release channel
         if: ${{ startsWith(github.ref, 'refs/tags/v') }}
         run: |
@@ -335,7 +337,7 @@ jobs:
     timeout-minutes: 60
     name: Create arm64 Linux bundle
     runs-on:
-      - hosted-linux-arm-1
+      - buildjet-16vcpu-ubuntu-2204-arm
     if: ${{ startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
     needs: [linux_tests]
     env:
@@ -350,9 +352,6 @@ jobs:
      - name: Install Linux dependencies
        run: ./script/linux

-      - name: Limit target directory size
-        run: script/clear-target-dir-if-larger-than 100
-
      - name: Determine version and release channel
        if: ${{ startsWith(github.ref, 'refs/tags/v') }}
        run: |
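Both Linux jobs move from the self-hosted hosted-linux-x86-1 runner to BuildJet's buildjet-16vcpu-ubuntu-2204 (and the arm64 bundle job to its -arm counterpart), which is presumably why the "Limit target directory size" steps are dropped: ephemeral runners start from a clean disk. The new "Cache dependencies" step uses swatinem/rust-cache pinned to a commit; if save-if behaves as documented, the cache is written only on main, so pull-request runs restore it without churning it.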
Cargo.lock (generated, 27 changes)
@@ -4000,6 +4000,33 @@ dependencies = [
  "num-traits",
 ]

+[[package]]
+name = "evals"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "clap",
+ "client",
+ "clock",
+ "collections",
+ "env_logger",
+ "feature_flags",
+ "fs",
+ "git",
+ "gpui",
+ "http_client",
+ "language",
+ "languages",
+ "node_runtime",
+ "open_ai",
+ "project",
+ "semantic_index",
+ "serde",
+ "serde_json",
+ "settings",
+ "smol",
+]
+
 [[package]]
 name = "event-listener"
 version = "2.5.3"
Cargo.toml

@@ -27,6 +27,7 @@ members = [
     "crates/diagnostics",
     "crates/docs_preprocessor",
     "crates/editor",
+    "crates/evals",
     "crates/extension",
     "crates/extension_api",
     "crates/extension_cli",
@@ -3282,7 +3282,7 @@ impl ContextEditor {

                 let fence = codeblock_fence_for_path(
                     filename.as_deref(),
-                    Some(selection.start.row..selection.end.row),
+                    Some(selection.start.row..=selection.end.row),
                 );

                 if let Some((line_comment_prefix, outline_text)) =
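This one-character change (`..` to `..=`) encodes that a selection's end row is itself part of the selection. A standalone sketch (plain Rust, not Zed code) of why the inclusive type suits the 1-based line labels produced by the codeblock_fence_for_path hunks below:

use std::ops::RangeInclusive;

// Build a ":{start}-{end}" label from a zero-based, inclusive row range,
// mirroring the `+ 1` conversion visible in the diff.
fn line_label(rows: &RangeInclusive<u32>) -> String {
    format!(":{}-{}", rows.start() + 1, rows.end() + 1)
}

fn main() {
    // Zero-based rows 9..=19 correspond to displayed lines 10 through 20.
    let rows = 9..=19;
    assert_eq!(line_label(&rows), ":10-20");
    // With a half-open 9..20, row 20 is not actually selected, so the last
    // selected row cannot be read off the endpoint directly; the inclusive
    // type makes "last selected row" explicit.
    assert!(rows.contains(&19) && !rows.contains(&20));
}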
@@ -8,7 +8,7 @@ use project::{PathMatchCandidateSet, Project};
 use serde::{Deserialize, Serialize};
 use std::{
     fmt::Write,
-    ops::Range,
+    ops::{Range, RangeInclusive},
     path::{Path, PathBuf},
     sync::{atomic::AtomicBool, Arc},
 };
@@ -342,7 +342,10 @@ fn collect_files(
     })
 }

-pub fn codeblock_fence_for_path(path: Option<&Path>, row_range: Option<Range<u32>>) -> String {
+pub fn codeblock_fence_for_path(
+    path: Option<&Path>,
+    row_range: Option<RangeInclusive<u32>>,
+) -> String {
     let mut text = String::new();
     write!(text, "```").unwrap();

@@ -357,7 +360,7 @@ pub fn codeblock_fence_for_path(path: Option<&Path>, row_range: Option<Range<u32
     }

     if let Some(row_range) = row_range {
-        write!(text, ":{}-{}", row_range.start + 1, row_range.end + 1).unwrap();
+        write!(text, ":{}-{}", row_range.start() + 1, row_range.end() + 1).unwrap();
     }

     text.push('\n');
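Taken together, these hunks give the fence builder an inclusive row range. A self-contained approximation of the resulting behavior (the real function also derives a language id from the path, which is elided here):

use std::fmt::Write;
use std::ops::RangeInclusive;
use std::path::Path;

// Approximation of the updated fence builder.
fn codeblock_fence_for_path(
    path: Option<&Path>,
    row_range: Option<RangeInclusive<u32>>,
) -> String {
    let mut text = String::new();
    write!(text, "```").unwrap();
    if let Some(path) = path {
        write!(text, "{}", path.display()).unwrap();
    }
    if let Some(row_range) = row_range {
        // Rows are zero-based internally; labels are 1-based.
        write!(text, ":{}-{}", row_range.start() + 1, row_range.end() + 1).unwrap();
    }
    text.push('\n');
    text
}

fn main() {
    let fence = codeblock_fence_for_path(Some(Path::new("src/eval.rs")), Some(41..=57));
    assert_eq!(fence, "```src/eval.rs:42-58\n");
}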
@@ -8,14 +8,12 @@ use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
 use feature_flags::FeatureFlag;
 use gpui::{AppContext, Task, WeakView};
 use language::{CodeLabel, LineEnding, LspAdapterDelegate};
-use semantic_index::SemanticDb;
+use semantic_index::{LoadedSearchResult, SemanticDb};
 use std::{
     fmt::Write,
-    path::PathBuf,
     sync::{atomic::AtomicBool, Arc},
 };
 use ui::{prelude::*, IconName};
-use util::ResultExt;
 use workspace::Workspace;

 pub(crate) struct SearchSlashCommandFeatureFlag;
@@ -107,52 +105,28 @@ impl SlashCommand for SearchSlashCommand {
                 })?
                 .await?;

-            let mut loaded_results = Vec::new();
-            for result in results {
-                let (full_path, file_content) =
-                    result.worktree.read_with(&cx, |worktree, _cx| {
-                        let entry_abs_path = worktree.abs_path().join(&result.path);
-                        let mut entry_full_path = PathBuf::from(worktree.root_name());
-                        entry_full_path.push(&result.path);
-                        let file_content = async {
-                            let entry_abs_path = entry_abs_path;
-                            fs.load(&entry_abs_path).await
-                        };
-                        (entry_full_path, file_content)
-                    })?;
-                if let Some(file_content) = file_content.await.log_err() {
-                    loaded_results.push((result, full_path, file_content));
-                }
-            }
+            let loaded_results = SemanticDb::load_results(results, &fs, &cx).await?;

             let output = cx
                 .background_executor()
                 .spawn(async move {
                     let mut text = format!("Search results for {query}:\n");
                     let mut sections = Vec::new();
-                    for (result, full_path, file_content) in loaded_results {
-                        let range_start = result.range.start.min(file_content.len());
-                        let range_end = result.range.end.min(file_content.len());
-
-                        let start_row = file_content[0..range_start].matches('\n').count() as u32;
-                        let end_row = file_content[0..range_end].matches('\n').count() as u32;
-                        let start_line_byte_offset = file_content[0..range_start]
-                            .rfind('\n')
-                            .map(|pos| pos + 1)
-                            .unwrap_or_default();
-                        let end_line_byte_offset = file_content[range_end..]
-                            .find('\n')
-                            .map(|pos| range_end + pos)
-                            .unwrap_or_else(|| file_content.len());
+                    for LoadedSearchResult {
+                        path,
+                        range,
+                        full_path,
+                        file_content,
+                        row_range,
+                    } in loaded_results
+                    {
                         let section_start_ix = text.len();
                         text.push_str(&codeblock_fence_for_path(
-                            Some(&result.path),
-                            Some(start_row..end_row),
+                            Some(&path),
+                            Some(row_range.clone()),
                         ));

-                        let mut excerpt =
-                            file_content[start_line_byte_offset..end_line_byte_offset].to_string();
+                        let mut excerpt = file_content[range].to_string();
                         LineEnding::normalize(&mut excerpt);
                         text.push_str(&excerpt);
                         writeln!(text, "\n```\n").unwrap();
@@ -161,7 +135,7 @@ impl SlashCommand for SearchSlashCommand {
                             section_start_ix..section_end_ix,
                             Some(&full_path),
                             false,
-                            Some(start_row + 1..end_row + 1),
+                            Some(row_range.start() + 1..row_range.end() + 1),
                         ));
                     }
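The removed loop's byte-to-line arithmetic did not disappear; it moved into SemanticDb::load_results (see the crates/semantic_index hunks below). Distilled into a standalone helper for clarity, with made-up sample data:

use std::ops::{Range, RangeInclusive};

// Expand a byte range to whole lines and compute its zero-based rows,
// mirroring the arithmetic that moved into `SemanticDb::load_results`.
fn expand_to_lines(content: &str, range: Range<usize>) -> (Range<usize>, RangeInclusive<u32>) {
    let range_start = range.start.min(content.len());
    let range_end = range.end.min(content.len());
    let start_row = content[..range_start].matches('\n').count() as u32;
    let end_row = content[..range_end].matches('\n').count() as u32;
    let start_byte = content[..range_start]
        .rfind('\n')
        .map(|pos| pos + 1)
        .unwrap_or_default();
    let end_byte = content[range_end..]
        .find('\n')
        .map(|pos| range_end + pos)
        .unwrap_or(content.len());
    (start_byte..end_byte, start_row..=end_row)
}

fn main() {
    let text = "alpha\nbravo\ncharlie\n";
    // A match starting mid-"bravo" and ending mid-"charlie"...
    let (bytes, rows) = expand_to_lines(text, 8..15);
    // ...snaps to the enclosing full lines (rows 1 through 2).
    assert_eq!(&text[bytes], "bravo\ncharlie");
    assert_eq!(rows, 1..=2);
}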
crates/evals/Cargo.toml (new file, 37 lines)
@@ -0,0 +1,37 @@
[package]
name = "evals"
description = "Evaluations for Zed's AI features"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"

[lints]
workspace = true

[[bin]]
name = "eval"
path = "src/eval.rs"

[dependencies]
clap.workspace = true
anyhow.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
env_logger.workspace = true
feature_flags.workspace = true
fs.workspace = true
git.workspace = true
gpui.workspace = true
language.workspace = true
languages.workspace = true
http_client.workspace = true
open_ai.workspace = true
project.workspace = true
settings.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
semantic_index.workspace = true
node_runtime.workspace = true
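Given the [[bin]] stanza, the workflow is presumably `cargo run -p evals --bin eval -- fetch` to download the CodeSearchNet annotations and clone the referenced repositories, followed by `cargo run -p evals --bin eval -- run [--repo OWNER/NAME]` with OPENAI_API_KEY set; the subcommands and flag come from the clap definitions in eval.rs below.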
crates/evals/LICENSE-GPL (new symbolic link)

@@ -0,0 +1 @@
../../LICENSE-GPL
crates/evals/build.rs (new file, 14 lines)
@@ -0,0 +1,14 @@
fn main() {
    if cfg!(target_os = "macos") {
        println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");

        println!("cargo:rerun-if-env-changed=ZED_BUNDLE");
        if std::env::var("ZED_BUNDLE").ok().as_deref() == Some("true") {
            // Find WebRTC.framework in the Frameworks folder when running as part of an application bundle.
            println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../Frameworks");
        } else {
            // Find WebRTC.framework as a sibling of the executable when running outside of an application bundle.
            println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
        }
    }
}
crates/evals/src/eval.rs (new file, 631 lines)
@@ -0,0 +1,631 @@
use ::fs::{Fs, RealFs};
use anyhow::Result;
use clap::Parser;
use client::{Client, UserStore};
use clock::RealSystemClock;
use collections::BTreeMap;
use feature_flags::FeatureFlagAppExt as _;
use git::GitHostingProviderRegistry;
use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
use http_client::{HttpClient, Method};
use language::LanguageRegistry;
use node_runtime::FakeNodeRuntime;
use open_ai::OpenAiEmbeddingModel;
use project::Project;
use semantic_index::{OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status};
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use smol::channel::bounded;
use smol::io::AsyncReadExt;
use smol::Timer;
use std::ops::RangeInclusive;
use std::time::Duration;
use std::{
    fs,
    path::Path,
    process::{exit, Command, Stdio},
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
};

const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
const EVAL_DB_PATH: &'static str = "target/eval_db";
const SEARCH_RESULT_LIMIT: usize = 8;
const SKIP_EVAL_PATH: &'static str = ".skip_eval";

#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

#[derive(clap::Subcommand)]
enum Commands {
    Fetch {},
    Run {
        #[arg(long)]
        repo: Option<String>,
    },
}

#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProject {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQuery>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQuery {
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
struct EvaluationSearchResult {
    file: String,
    lines: RangeInclusive<u32>,
}

#[derive(Clone, Deserialize, Serialize)]
struct EvaluationProjectOutcome {
    repo: String,
    sha: String,
    queries: Vec<EvaluationQueryOutcome>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
struct EvaluationQueryOutcome {
    repo: String,
    query: String,
    expected_results: Vec<EvaluationSearchResult>,
    actual_results: Vec<EvaluationSearchResult>,
    covered_file_count: usize,
    overlapped_result_count: usize,
    covered_result_count: usize,
    total_result_count: usize,
    covered_result_indices: Vec<usize>,
}

fn main() -> Result<()> {
    let cli = Cli::parse();
    env_logger::init();

    gpui::App::headless().run(move |cx| {
        let executor = cx.background_executor().clone();

        match cli.command {
            Commands::Fetch {} => {
                executor
                    .clone()
                    .spawn(async move {
                        if let Err(err) = fetch_evaluation_resources(&executor).await {
                            eprintln!("Error: {}", err);
                            exit(1);
                        }
                        exit(0);
                    })
                    .detach();
            }
            Commands::Run { repo } => {
                cx.spawn(|mut cx| async move {
                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
                        eprintln!("Error: {}", err);
                        exit(1);
                    }
                    exit(0);
                })
                .detach();
            }
        }
    });

    Ok(())
}

async fn fetch_evaluation_resources(executor: &BackgroundExecutor) -> Result<()> {
    let http_client = http_client::HttpClientWithProxy::new(None, None);
    fetch_code_search_net_resources(&http_client).await?;
    fetch_eval_repos(executor, &http_client).await?;
    Ok(())
}

async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
    eprintln!("Fetching CodeSearchNet evaluations...");

    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";

    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");

    // Fetch the annotations CSV, which contains the human-annotated search relevances
    let annotations_path = dataset_dir.join("annotations.csv");
    let annotations_csv_content = if annotations_path.exists() {
        fs::read_to_string(&annotations_path).expect("failed to read annotations")
    } else {
        let response = http_client
            .get(annotations_url, Default::default(), true)
            .await
            .expect("failed to fetch annotations csv");
        let mut body = String::new();
        response
            .into_body()
            .read_to_string(&mut body)
            .await
            .expect("failed to read annotations.csv response");
        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
        body
    };

    // Parse the annotations CSV. Skip over queries with zero relevance.
    let rows = annotations_csv_content.lines().filter_map(|line| {
        let mut values = line.split(',');
        let _language = values.next()?;
        let query = values.next()?;
        let github_url = values.next()?;
        let score = values.next()?;

        if score == "0" {
            return None;
        }

        let url_path = github_url.strip_prefix("https://github.com/")?;
        let (url_path, hash) = url_path.split_once('#')?;
        let (repo_name, url_path) = url_path.split_once("/blob/")?;
        let (sha, file_path) = url_path.split_once('/')?;
        let line_range = if let Some((start, end)) = hash.split_once('-') {
            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
        } else {
            let row = hash.strip_prefix("L")?.parse().ok()?;
            row..=row
        };
        Some((repo_name, sha, query, file_path, line_range))
    });

    // Group the annotations by repo and sha.
    let mut evaluations_by_repo = BTreeMap::new();
    for (repo_name, sha, query, file_path, lines) in rows {
        let evaluation_project = evaluations_by_repo
            .entry((repo_name, sha))
            .or_insert_with(|| EvaluationProject {
                repo: repo_name.to_string(),
                sha: sha.to_string(),
                queries: Vec::new(),
            });

        let ix = evaluation_project
            .queries
            .iter()
            .position(|entry| entry.query == query)
            .unwrap_or_else(|| {
                evaluation_project.queries.push(EvaluationQuery {
                    query: query.to_string(),
                    expected_results: Vec::new(),
                });
                evaluation_project.queries.len() - 1
            });
        let results = &mut evaluation_project.queries[ix].expected_results;
        let result = EvaluationSearchResult {
            file: file_path.to_string(),
            lines,
        };
        if !results.contains(&result) {
            results.push(result);
        }
    }

    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
    let evaluations_path = dataset_dir.join("evaluations.json");
    fs::write(
        &evaluations_path,
        serde_json::to_vec_pretty(&evaluations).unwrap(),
    )
    .unwrap();

    eprintln!(
        "Fetched CodeSearchNet evaluations into {}",
        evaluations_path.display()
    );

    Ok(())
}

async fn run_evaluation(
    only_repo: Option<String>,
    executor: &BackgroundExecutor,
    cx: &mut AsyncAppContext,
) -> Result<()> {
    cx.update(|cx| {
        let mut store = SettingsStore::new(cx);
        store
            .set_default_settings(settings::default_settings().as_ref(), cx)
            .unwrap();
        cx.set_global(store);
        client::init_settings(cx);
        language::init(cx);
        Project::init_settings(cx);
        cx.update_flags(false, vec![]);
    })
    .unwrap();

    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);
    let db_path = Path::new(EVAL_DB_PATH);
    let http_client = http_client::HttpClientWithProxy::new(None, None);
    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
    let clock = Arc::new(RealSystemClock);
    let client = cx
        .update(|cx| {
            Client::new(
                clock,
                Arc::new(http_client::HttpClientWithUrl::new(
                    "https://zed.dev",
                    None,
                    None,
                )),
                cx,
            )
        })
        .unwrap();
    let user_store = cx
        .new_model(|cx| UserStore::new(client.clone(), cx))
        .unwrap();
    let node_runtime = Arc::new(FakeNodeRuntime {});

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
        http_client.clone(),
        OpenAiEmbeddingModel::TextEmbedding3Small,
        open_ai::OPEN_AI_API_URL.to_string(),
        api_key,
    ));

    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
        .unwrap();

    let mut covered_result_count = 0;
    let mut overlapped_result_count = 0;
    let mut covered_file_count = 0;
    let mut total_result_count = 0;
    eprint!("Running evals.");

    for evaluation_project in evaluations {
        if only_repo
            .as_ref()
            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
        {
            continue;
        }

        eprint!("\r\x1B[2K");
        eprint!(
            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
            covered_result_count,
            total_result_count,
            overlapped_result_count,
            total_result_count,
            covered_file_count,
            total_result_count,
            evaluation_project.repo
        );

        let repo_db_path =
            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
        let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider.clone(), cx)
            .await
            .unwrap();

        let repo_dir = repos_dir.join(&evaluation_project.repo);
        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
            eprintln!("Skipping {}: directory not found", evaluation_project.repo);
            continue;
        }

        let project = cx
            .update(|cx| {
                Project::local(
                    client.clone(),
                    node_runtime.clone(),
                    user_store.clone(),
                    language_registry.clone(),
                    fs.clone(),
                    None,
                    cx,
                )
            })
            .unwrap();

        let (worktree, _) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(repo_dir, true, cx)
            })?
            .await?;

        worktree
            .update(cx, |worktree, _| {
                worktree.as_local().unwrap().scan_complete()
            })
            .unwrap()
            .await;

        let project_index = cx
            .update(|cx| semantic_index.create_project_index(project.clone(), cx))
            .unwrap();
        wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;

        for query in evaluation_project.queries {
            let results = cx
                .update(|cx| {
                    let project_index = project_index.read(cx);
                    project_index.search(query.query.clone(), SEARCH_RESULT_LIMIT, cx)
                })
                .unwrap()
                .await
                .unwrap();

            let results = SemanticDb::load_results(results, &fs.clone(), &cx)
                .await
                .unwrap();

            let mut project_covered_result_count = 0;
            let mut project_overlapped_result_count = 0;
            let mut project_covered_file_count = 0;
            let mut covered_result_indices = Vec::new();
            for expected_result in &query.expected_results {
                let mut file_matched = false;
                let mut range_overlapped = false;
                let mut range_covered = false;

                for (ix, result) in results.iter().enumerate() {
                    if result.path.as_ref() == Path::new(&expected_result.file) {
                        file_matched = true;
                        let start_matched =
                            result.row_range.contains(&expected_result.lines.start());
                        let end_matched = result.row_range.contains(&expected_result.lines.end());

                        if start_matched || end_matched {
                            range_overlapped = true;
                        }

                        if start_matched && end_matched {
                            range_covered = true;
                            covered_result_indices.push(ix);
                            break;
                        }
                    }
                }

                if range_covered {
                    project_covered_result_count += 1
                };
                if range_overlapped {
                    project_overlapped_result_count += 1
                };
                if file_matched {
                    project_covered_file_count += 1
                };
            }
            let outcome_repo = evaluation_project.repo.clone();

            let query_results = EvaluationQueryOutcome {
                repo: outcome_repo,
                query: query.query,
                total_result_count: query.expected_results.len(),
                covered_result_count: project_covered_result_count,
                overlapped_result_count: project_overlapped_result_count,
                covered_file_count: project_covered_file_count,
                expected_results: query.expected_results,
                actual_results: results
                    .iter()
                    .map(|result| EvaluationSearchResult {
                        file: result.path.to_string_lossy().to_string(),
                        lines: result.row_range.clone(),
                    })
                    .collect(),
                covered_result_indices,
            };

            overlapped_result_count += query_results.overlapped_result_count;
            covered_result_count += query_results.covered_result_count;
            covered_file_count += query_results.covered_file_count;
            total_result_count += query_results.total_result_count;

            println!("{}", serde_json::to_string(&query_results).unwrap());
        }
    }

    eprint!(
        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured.",
        covered_result_count,
        total_result_count,
        overlapped_result_count,
        total_result_count,
        covered_file_count,
        total_result_count,
    );

    Ok(())
}

async fn wait_for_indexing_complete(
    project_index: &Model<ProjectIndex>,
    cx: &mut AsyncAppContext,
    timeout: Option<Duration>,
) {
    let (tx, rx) = bounded(1);
    let subscription = cx.update(|cx| {
        cx.subscribe(project_index, move |_, event, _| {
            if let Status::Idle = event {
                let _ = tx.try_send(*event);
            }
        })
    });

    let result = match timeout {
        Some(timeout_duration) => {
            smol::future::or(
                async {
                    rx.recv().await.map_err(|_| ())?;
                    Ok(())
                },
                async {
                    Timer::after(timeout_duration).await;
                    Err(())
                },
            )
            .await
        }
        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
    };

    match result {
        Ok(_) => (),
        Err(_) => {
            if let Some(timeout) = timeout {
                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
            }
        }
    }

    drop(subscription);
}

async fn fetch_eval_repos(
    executor: &BackgroundExecutor,
    http_client: &dyn HttpClient,
) -> Result<()> {
    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
    let evaluations_path = dataset_dir.join("evaluations.json");
    let repos_dir = Path::new(EVAL_REPOS_DIR);

    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();

    eprint!("Fetching evaluation repositories...");

    executor
        .scoped(move |scope| {
            let done_count = Arc::new(AtomicUsize::new(0));
            let len = evaluations.len();
            for chunk in evaluations.chunks(evaluations.len() / 8) {
                let chunk = chunk.to_vec();
                let done_count = done_count.clone();
                scope.spawn(async move {
                    for EvaluationProject { repo, sha, .. } in chunk {
                        eprint!(
                            "\rFetching evaluation repositories ({}/{})...",
                            done_count.load(SeqCst),
                            len,
                        );

                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
                        done_count.fetch_add(1, SeqCst);
                    }
                });
            }
        })
        .await;

    Ok(())
}

async fn fetch_eval_repo(
    repo: String,
    sha: String,
    repos_dir: &Path,
    http_client: &dyn HttpClient,
) {
    let Some((owner, repo_name)) = repo.split_once('/') else {
        return;
    };
    let repo_dir = repos_dir.join(owner).join(repo_name);
    fs::create_dir_all(&repo_dir).unwrap();
    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
    if skip_eval_path.exists() {
        return;
    }
    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
        if head_content.trim() == sha {
            return;
        }
    }
    let repo_response = http_client
        .send(
            http_client::Request::builder()
                .method(Method::HEAD)
                .uri(format!("https://github.com/{}", repo))
                .body(Default::default())
                .expect(""),
        )
        .await
        .expect("failed to check github repo");
    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
        fs::write(&skip_eval_path, "").unwrap();
        eprintln!(
            "Repo {repo} is no longer public ({:?}). Skipping",
            repo_response.status()
        );
        return;
    }
    if !repo_dir.join(".git").exists() {
        let init_output = Command::new("git")
            .current_dir(&repo_dir)
            .args(&["init"])
            .output()
            .unwrap();
        if !init_output.status.success() {
            eprintln!(
                "Failed to initialize git repository for {}: {}",
                repo,
                String::from_utf8_lossy(&init_output.stderr)
            );
            return;
        }
    }
    let url = format!("https://github.com/{}.git", repo);
    Command::new("git")
        .current_dir(&repo_dir)
        .args(&["remote", "add", "-f", "origin", &url])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    let fetch_output = Command::new("git")
        .current_dir(&repo_dir)
        .args(&["fetch", "--depth", "1", "origin", &sha])
        .stdin(Stdio::null())
        .output()
        .unwrap();
    if !fetch_output.status.success() {
        eprintln!(
            "Failed to fetch {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&fetch_output.stderr)
        );
        return;
    }
    let checkout_output = Command::new("git")
        .current_dir(&repo_dir)
        .args(&["checkout", &sha])
        .output()
        .unwrap();

    if !checkout_output.status.success() {
        eprintln!(
            "Failed to checkout {} for {}: {}",
            sha,
            repo,
            String::from_utf8_lossy(&checkout_output.stderr)
        );
    }
}
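The heart of the eval is the per-query scoring loop above. Restated as a standalone function for clarity (the names mirror the counters in run_evaluation; the sample hit data is invented):

use std::ops::RangeInclusive;
use std::path::Path;

struct Hit<'a> {
    path: &'a Path,
    rows: RangeInclusive<u32>,
}

// Classify one expected result against a list of search hits, using the
// same rules as `run_evaluation`: "covered" means both endpoints of the
// expected range fall inside a hit on the right file, "overlapped" means
// at least one endpoint does, and "file matched" only compares paths.
fn classify(expected_file: &Path, expected: &RangeInclusive<u32>, hits: &[Hit]) -> (bool, bool, bool) {
    let (mut file_matched, mut overlapped, mut covered) = (false, false, false);
    for hit in hits {
        if hit.path == expected_file {
            file_matched = true;
            let start_in = hit.rows.contains(expected.start());
            let end_in = hit.rows.contains(expected.end());
            overlapped |= start_in || end_in;
            covered |= start_in && end_in;
        }
    }
    (file_matched, overlapped, covered)
}

fn main() {
    let hits = [Hit { path: Path::new("src/lib.rs"), rows: 10..=40 }];
    // Fully inside the hit: counts toward all three metrics.
    assert_eq!(classify(Path::new("src/lib.rs"), &(12..=20), &hits), (true, true, true));
    // Hanging off the end: overlapped but not covered.
    assert_eq!(classify(Path::new("src/lib.rs"), &(35..=50), &hits), (true, true, false));
    // Right file, wrong region: only the file-level metric counts.
    assert_eq!(classify(Path::new("src/lib.rs"), &(90..=95), &hits), (true, false, false));
}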
@@ -5,6 +5,7 @@ use derive_more::Deref;
 use futures::future::BoxFuture;
 use futures_lite::FutureExt;
 use isahc::config::{Configurable, RedirectPolicy};
+pub use isahc::http;
 pub use isahc::{
     http::{Method, StatusCode, Uri},
     AsyncBody, Error, HttpClient as IsahcHttpClient, Request, Response,
@@ -226,7 +227,7 @@ pub fn client(user_agent: Option<String>, proxy: Option<Uri>) -> Arc<dyn HttpCli
         // those requests use a different http client, because global timeouts
         // of 50 and 60 seconds, respectively, would be very high!
         .connect_timeout(Duration::from_secs(5))
-        .low_speed_timeout(100, Duration::from_secs(5))
+        .low_speed_timeout(100, Duration::from_secs(30))
         .proxy(proxy.clone());
     if let Some(user_agent) = user_agent {
         builder = builder.default_header("User-Agent", user_agent);
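The low_speed_timeout change relaxes isahc's stall detection: as with curl's CURLOPT_LOW_SPEED_* options, the transfer is aborted only when throughput stays under 100 bytes per second for the whole window, so widening it from 5 to 30 seconds tolerates the long quiet stretches that bulk embedding requests and dataset downloads can produce. The added `pub use isahc::http;` re-exports the http module itself, presumably so new consumers such as the evals crate can build requests through http_client without depending on isahc directly.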
@@ -234,30 +234,25 @@ impl EmbeddingIndex {
         cx.spawn(async {
             while let Ok((entry, handle)) = entries.recv().await {
                 let entry_abs_path = worktree_abs_path.join(&entry.path);
-                match fs.load(&entry_abs_path).await {
-                    Ok(text) => {
-                        let language = language_registry
-                            .language_for_file_path(&entry.path)
-                            .await
-                            .ok();
-                        let chunked_file = ChunkedFile {
-                            chunks: chunking::chunk_text(
-                                &text,
-                                language.as_ref(),
-                                &entry.path,
-                            ),
-                            handle,
-                            path: entry.path,
-                            mtime: entry.mtime,
-                            text,
-                        };
-
-                        if chunked_files_tx.send(chunked_file).await.is_err() {
-                            return;
-                        }
-                    }
-                    Err(_) => {
-                        log::error!("Failed to read contents into a UTF-8 string: {entry_abs_path:?}");
-                    }
-                }
+                if let Some(text) = fs.load(&entry_abs_path).await.ok() {
+                    let language = language_registry
+                        .language_for_file_path(&entry.path)
+                        .await
+                        .ok();
+                    let chunked_file = ChunkedFile {
+                        chunks: chunking::chunk_text(
+                            &text,
+                            language.as_ref(),
+                            &entry.path,
+                        ),
+                        handle,
+                        path: entry.path,
+                        mtime: entry.mtime,
+                        text,
+                    };
+
+                    if chunked_files_tx.send(chunked_file).await.is_err() {
+                        return;
+                    }
+                }
             }
         }
@@ -358,33 +353,37 @@ impl EmbeddingIndex {
     fn persist_embeddings(
         &self,
         mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
-        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
+        mut embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
         cx: &AppContext,
     ) -> Task<Result<()>> {
         let db_connection = self.db_connection.clone();
         let db = self.db;

         cx.background_executor().spawn(async move {
-            while let Some(deletion_range) = deleted_entry_ranges.next().await {
-                let mut txn = db_connection.write_txn()?;
-                let start = deletion_range.0.as_ref().map(|start| start.as_str());
-                let end = deletion_range.1.as_ref().map(|end| end.as_str());
-                log::debug!("deleting embeddings in range {:?}", &(start, end));
-                db.delete_range(&mut txn, &(start, end))?;
-                txn.commit()?;
-            }
-
-            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
-            while let Some(embedded_files) = embedded_files.next().await {
-                let mut txn = db_connection.write_txn()?;
-                for (file, _) in &embedded_files {
-                    log::debug!("saving embedding for file {:?}", file.path);
-                    let key = db_key_for_path(&file.path);
-                    db.put(&mut txn, &key, file)?;
+            loop {
+                // Interleave deletions and persists of embedded files
+                futures::select_biased! {
+                    deletion_range = deleted_entry_ranges.next() => {
+                        if let Some(deletion_range) = deletion_range {
+                            let mut txn = db_connection.write_txn()?;
+                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
+                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
+                            log::debug!("deleting embeddings in range {:?}", &(start, end));
+                            db.delete_range(&mut txn, &(start, end))?;
+                            txn.commit()?;
+                        }
+                    },
+                    file = embedded_files.next() => {
+                        if let Some((file, _)) = file {
+                            let mut txn = db_connection.write_txn()?;
+                            log::debug!("saving embedding for file {:?}", file.path);
+                            let key = db_key_for_path(&file.path);
+                            db.put(&mut txn, &key, &file)?;
+                            txn.commit()?;
+                        }
+                    },
+                    complete => break,
                 }
-                txn.commit()?;
-
-                drop(embedded_files);
-                log::debug!("committed");
             }

             Ok(())
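The rewrite replaces "drain all deletions, then batch writes" with a single loop that interleaves the two channels, so a file embedding can be persisted while deletions are still streaming in. A minimal, runnable sketch of the futures::select_biased! pattern used here, substituting futures::channel::mpsc for Zed's smol channels:

use futures::{channel::mpsc, SinkExt, StreamExt};

fn main() {
    futures::executor::block_on(async {
        let (mut del_tx, mut deletions) = mpsc::unbounded::<&str>();
        let (mut file_tx, mut files) = mpsc::unbounded::<&str>();
        del_tx.send("range a..b").await.unwrap();
        file_tx.send("src/lib.rs").await.unwrap();
        // Closing both senders lets the `complete` arm fire.
        drop(del_tx);
        drop(file_tx);

        loop {
            // `select_biased!` polls the arms in order, so pending deletions
            // win ties, but neither channel can starve: whichever has an
            // item ready makes progress.
            futures::select_biased! {
                range = deletions.next() => {
                    if let Some(range) = range {
                        println!("delete embeddings in {range}");
                    }
                },
                file = files.next() => {
                    if let Some(file) = file {
                        println!("persist embedding for {file}");
                    }
                },
                // Runs once both streams are exhausted.
                complete => break,
            }
        }
    });
}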
@@ -15,7 +15,14 @@ use log;
 use project::{Project, Worktree, WorktreeId};
 use serde::{Deserialize, Serialize};
 use smol::channel;
-use std::{cmp::Ordering, future::Future, num::NonZeroUsize, ops::Range, path::Path, sync::Arc};
+use std::{
+    cmp::Ordering,
+    future::Future,
+    num::NonZeroUsize,
+    ops::{Range, RangeInclusive},
+    path::{Path, PathBuf},
+    sync::Arc,
+};
 use util::ResultExt;

 #[derive(Debug)]
@@ -26,6 +33,14 @@ pub struct SearchResult {
     pub score: f32,
 }

+pub struct LoadedSearchResult {
+    pub path: Arc<Path>,
+    pub range: Range<usize>,
+    pub full_path: PathBuf,
+    pub file_content: String,
+    pub row_range: RangeInclusive<u32>,
+}
+
 pub struct WorktreeSearchResult {
     pub worktree_id: WorktreeId,
     pub path: Arc<Path>,
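LoadedSearchResult deliberately carries both a byte `range` and a `row_range`: callers slice `file_content[range]` to get the excerpt (as the search slash command now does) while `row_range` feeds the 1-based line labels in the codeblock fence, so neither needs recomputing from the other.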
@@ -10,14 +10,16 @@ mod worktree_index;

 use anyhow::{Context as _, Result};
 use collections::HashMap;
+use fs::Fs;
 use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
 use project::Project;
-use project_index::ProjectIndex;
 use std::{path::PathBuf, sync::Arc};
 use ui::ViewContext;
+use util::ResultExt as _;
 use workspace::Workspace;

 pub use embedding::*;
+pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
 pub use project_index_debug_view::ProjectIndexDebugView;
 pub use summary_index::FileSummary;
@@ -56,27 +58,7 @@ impl SemanticDb {

         if cx.has_global::<SemanticDb>() {
             cx.update_global::<SemanticDb, _>(|this, cx| {
-                let project_index = cx.new_model(|cx| {
-                    ProjectIndex::new(
-                        project.clone(),
-                        this.db_connection.clone(),
-                        this.embedding_provider.clone(),
-                        cx,
-                    )
-                });
-
-                let project_weak = project.downgrade();
-                this.project_indices
-                    .insert(project_weak.clone(), project_index);
-
-                cx.on_release(move |_, _, cx| {
-                    if cx.has_global::<SemanticDb>() {
-                        cx.update_global::<SemanticDb, _>(|this, _| {
-                            this.project_indices.remove(&project_weak);
-                        })
-                    }
-                })
-                .detach();
+                this.create_project_index(project, cx);
             })
         } else {
             log::info!("No SemanticDb, skipping project index")
@@ -94,6 +76,50 @@ impl SemanticDb {
         })
     }

+    pub async fn load_results(
+        results: Vec<SearchResult>,
+        fs: &Arc<dyn Fs>,
+        cx: &AsyncAppContext,
+    ) -> Result<Vec<LoadedSearchResult>> {
+        let mut loaded_results = Vec::new();
+        for result in results {
+            let (full_path, file_content) = result.worktree.read_with(cx, |worktree, _cx| {
+                let entry_abs_path = worktree.abs_path().join(&result.path);
+                let mut entry_full_path = PathBuf::from(worktree.root_name());
+                entry_full_path.push(&result.path);
+                let file_content = async {
+                    let entry_abs_path = entry_abs_path;
+                    fs.load(&entry_abs_path).await
+                };
+                (entry_full_path, file_content)
+            })?;
+            if let Some(file_content) = file_content.await.log_err() {
+                let range_start = result.range.start.min(file_content.len());
+                let range_end = result.range.end.min(file_content.len());
+
+                let start_row = file_content[0..range_start].matches('\n').count() as u32;
+                let end_row = file_content[0..range_end].matches('\n').count() as u32;
+                let start_line_byte_offset = file_content[0..range_start]
+                    .rfind('\n')
+                    .map(|pos| pos + 1)
+                    .unwrap_or_default();
+                let end_line_byte_offset = file_content[range_end..]
+                    .find('\n')
+                    .map(|pos| range_end + pos)
+                    .unwrap_or_else(|| file_content.len());
+
+                loaded_results.push(LoadedSearchResult {
+                    path: result.path,
+                    range: start_line_byte_offset..end_line_byte_offset,
+                    full_path,
+                    file_content,
+                    row_range: start_row..=end_row,
+                });
+            }
+        }
+        Ok(loaded_results)
+    }
+
     pub fn project_index(
         &mut self,
         project: Model<Project>,
@@ -113,6 +139,36 @@ impl SemanticDb {
             })
         })
     }
+
+    pub fn create_project_index(
+        &mut self,
+        project: Model<Project>,
+        cx: &mut AppContext,
+    ) -> Model<ProjectIndex> {
+        let project_index = cx.new_model(|cx| {
+            ProjectIndex::new(
+                project.clone(),
+                self.db_connection.clone(),
+                self.embedding_provider.clone(),
+                cx,
+            )
+        });
+
+        let project_weak = project.downgrade();
+        self.project_indices
+            .insert(project_weak.clone(), project_index.clone());
+
+        cx.observe_release(&project, move |_, cx| {
+            if cx.has_global::<SemanticDb>() {
+                cx.update_global::<SemanticDb, _>(|this, _| {
+                    this.project_indices.remove(&project_weak);
+                })
+            }
+        })
+        .detach();
+
+        project_index
+    }
 }

 #[cfg(test)]
@@ -230,34 +286,13 @@ mod tests {

         let project = Project::test(fs, [project_path], cx).await;

-        cx.update(|cx| {
+        let project_index = cx.update(|cx| {
             let language_registry = project.read(cx).languages().clone();
             let node_runtime = project.read(cx).node_runtime().unwrap().clone();
             languages::init(language_registry, node_runtime, cx);
-
-            // Manually create and insert the ProjectIndex
-            let project_index = cx.new_model(|cx| {
-                ProjectIndex::new(
-                    project.clone(),
-                    semantic_index.db_connection.clone(),
-                    semantic_index.embedding_provider.clone(),
-                    cx,
-                )
-            });
-            semantic_index
-                .project_indices
-                .insert(project.downgrade(), project_index);
+            semantic_index.create_project_index(project.clone(), cx)
         });

-        let project_index = cx
-            .update(|_cx| {
-                semantic_index
-                    .project_indices
-                    .get(&project.downgrade())
-                    .cloned()
-            })
-            .unwrap();
-
         cx.run_until_parked();
         while cx
             .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))