parent bdd0cbb717
commit f07695c4cd
6 changed files with 0 additions and 796 deletions
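Summary: this commit deletes the evals crate, a headless GPUI harness that built Zed's semantic index over checked-out repositories and scored its search results against human-annotated queries from GitHub's CodeSearchNet dataset. The remaining hunks remove the crate's entries from Cargo.lock and the workspace Cargo.toml.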
Cargo.lock (generated): 30 deletions

@@ -4919,36 +4919,6 @@ dependencies = [
  "workspace-hack",
 ]

-[[package]]
-name = "evals"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "clap",
- "client",
- "clock",
- "collections",
- "dap",
- "env_logger 0.11.8",
- "feature_flags",
- "fs",
- "gpui",
- "http_client",
- "language",
- "languages",
- "node_runtime",
- "open_ai",
- "project",
- "reqwest_client",
- "semantic_index",
- "serde",
- "serde_json",
- "settings",
- "smol",
- "util",
- "workspace-hack",
-]
-
 [[package]]
 name = "event-listener"
 version = "2.5.3"
Cargo.toml: 2 deletions

@@ -47,7 +47,6 @@ members = [
     "crates/docs_preprocessor",
     "crates/editor",
     "crates/eval",
-    "crates/evals",
     "crates/extension",
     "crates/extension_api",
     "crates/extension_cli",
@@ -696,7 +695,6 @@ breadcrumbs = { codegen-units = 1 }
 collections = { codegen-units = 1 }
 command_palette = { codegen-units = 1 }
 command_palette_hooks = { codegen-units = 1 }
-evals = { codegen-units = 1 }
 extension_cli = { codegen-units = 1 }
 feature_flags = { codegen-units = 1 }
 file_icons = { codegen-units = 1 }
crates/evals/Cargo.toml: 40 deletions

@@ -1,40 +0,0 @@
-[package]
-name = "evals"
-description = "Evaluations for Zed's AI features"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[[bin]]
-name = "eval"
-path = "src/eval.rs"
-
-[dependencies]
-anyhow.workspace = true
-clap.workspace = true
-client.workspace = true
-clock.workspace = true
-collections.workspace = true
-dap.workspace = true
-env_logger.workspace = true
-feature_flags.workspace = true
-fs.workspace = true
-gpui.workspace = true
-http_client.workspace = true
-language.workspace = true
-languages.workspace = true
-node_runtime.workspace = true
-open_ai.workspace = true
-project.workspace = true
-reqwest_client.workspace = true
-semantic_index.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-settings.workspace = true
-smol.workspace = true
-util.workspace = true
-workspace-hack.workspace = true
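Given the [[bin]] section above, the crate produced a single binary named eval; it was presumably invoked as cargo run -p evals -- fetch and cargo run -p evals -- run --repo <owner/name>, matching the clap subcommands in src/eval.rs below. The exact invocation is an inference, not part of this diff.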
crates/evals/LICENSE-GPL: 1 deletion

@@ -1 +0,0 @@
-../../LICENSE-GPL
crates/evals/build.rs: 5 deletions

@@ -1,5 +0,0 @@
-fn main() {
-    if cfg!(target_os = "macos") {
-        println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
-    }
-}
crates/evals/src/eval.rs: 718 deletions

@@ -1,718 +0,0 @@
-use ::fs::{Fs, RealFs};
-use anyhow::Result;
-use clap::Parser;
-use client::{Client, UserStore};
-use clock::RealSystemClock;
-use collections::BTreeMap;
-use dap::DapRegistry;
-use feature_flags::FeatureFlagAppExt as _;
-use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Entity};
-use http_client::{HttpClient, Method};
-use language::LanguageRegistry;
-use node_runtime::NodeRuntime;
-use open_ai::OpenAiEmbeddingModel;
-use project::Project;
-use reqwest_client::ReqwestClient;
-use semantic_index::{
-    EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
-};
-use serde::{Deserialize, Serialize};
-use settings::SettingsStore;
-use smol::Timer;
-use smol::channel::bounded;
-use smol::io::AsyncReadExt;
-use std::ops::RangeInclusive;
-use std::path::PathBuf;
-use std::time::Duration;
-use std::{
-    fs,
-    path::Path,
-    process::{Stdio, exit},
-    sync::{
-        Arc,
-        atomic::{AtomicUsize, Ordering::SeqCst},
-    },
-};
-
-const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
-const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
-const EVAL_DB_PATH: &'static str = "target/eval_db";
-const SEARCH_RESULT_LIMIT: usize = 8;
-const SKIP_EVAL_PATH: &'static str = ".skip_eval";
-
-#[derive(clap::Parser)]
-#[command(author, version, about, long_about = None)]
-struct Cli {
-    #[command(subcommand)]
-    command: Commands,
-}
-
-#[derive(clap::Subcommand)]
-enum Commands {
-    Fetch {},
-    Run {
-        #[arg(long)]
-        repo: Option<String>,
-    },
-}
-
-#[derive(Clone, Deserialize, Serialize)]
-struct EvaluationProject {
-    repo: String,
-    sha: String,
-    queries: Vec<EvaluationQuery>,
-}
-
-#[derive(Clone, Debug, Deserialize, Serialize)]
-struct EvaluationQuery {
-    query: String,
-    expected_results: Vec<EvaluationSearchResult>,
-}
-
-#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
-struct EvaluationSearchResult {
-    file: String,
-    lines: RangeInclusive<u32>,
-}
-
-#[derive(Clone, Deserialize, Serialize)]
-struct EvaluationProjectOutcome {
-    repo: String,
-    sha: String,
-    queries: Vec<EvaluationQueryOutcome>,
-}
-
-#[derive(Clone, Debug, Deserialize, Serialize)]
-struct EvaluationQueryOutcome {
-    repo: String,
-    query: String,
-    expected_results: Vec<EvaluationSearchResult>,
-    actual_results: Vec<EvaluationSearchResult>,
-    covered_file_count: usize,
-    overlapped_result_count: usize,
-    covered_result_count: usize,
-    total_result_count: usize,
-    covered_result_indices: Vec<usize>,
-}
-
-fn main() -> Result<()> {
-    let cli = Cli::parse();
-    env_logger::init();
-
-    gpui::Application::headless().run(move |cx| {
-        let executor = cx.background_executor().clone();
-        let client = Arc::new(ReqwestClient::user_agent("Zed LLM evals").unwrap());
-        cx.set_http_client(client.clone());
-        match cli.command {
-            Commands::Fetch {} => {
-                executor
-                    .clone()
-                    .spawn(async move {
-                        if let Err(err) = fetch_evaluation_resources(client, &executor).await {
-                            eprintln!("Error: {}", err);
-                            exit(1);
-                        }
-                        exit(0);
-                    })
-                    .detach();
-            }
-            Commands::Run { repo } => {
-                cx.spawn(async move |cx| {
-                    if let Err(err) = run_evaluation(repo, &executor, cx).await {
-                        eprintln!("Error: {}", err);
-                        exit(1);
-                    }
-                    exit(0);
-                })
-                .detach();
-            }
-        }
-    });
-
-    Ok(())
-}
-
-async fn fetch_evaluation_resources(
-    http_client: Arc<dyn HttpClient>,
-    executor: &BackgroundExecutor,
-) -> Result<()> {
-    fetch_code_search_net_resources(&*http_client).await?;
-    fetch_eval_repos(executor, &*http_client).await?;
-    Ok(())
-}
-
-async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
-    eprintln!("Fetching CodeSearchNet evaluations...");
-
-    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
-
-    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
-    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
-
-    // Fetch the annotations CSV, which contains the human-annotated search relevances
-    let annotations_path = dataset_dir.join("annotations.csv");
-    let annotations_csv_content = if annotations_path.exists() {
-        fs::read_to_string(&annotations_path).expect("failed to read annotations")
-    } else {
-        let response = http_client
-            .get(annotations_url, Default::default(), true)
-            .await
-            .expect("failed to fetch annotations csv");
-        let mut body = String::new();
-        response
-            .into_body()
-            .read_to_string(&mut body)
-            .await
-            .expect("failed to read annotations.csv response");
-        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
-        body
-    };
-
-    // Parse the annotations CSV. Skip over queries with zero relevance.
-    let rows = annotations_csv_content.lines().filter_map(|line| {
-        let mut values = line.split(',');
-        let _language = values.next()?;
-        let query = values.next()?;
-        let github_url = values.next()?;
-        let score = values.next()?;
-
-        if score == "0" {
-            return None;
-        }
-
-        let url_path = github_url.strip_prefix("https://github.com/")?;
-        let (url_path, hash) = url_path.split_once('#')?;
-        let (repo_name, url_path) = url_path.split_once("/blob/")?;
-        let (sha, file_path) = url_path.split_once('/')?;
-        let line_range = if let Some((start, end)) = hash.split_once('-') {
-            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
-        } else {
-            let row = hash.strip_prefix("L")?.parse().ok()?;
-            row..=row
-        };
-        Some((repo_name, sha, query, file_path, line_range))
-    });
-
-    // Group the annotations by repo and sha.
-    let mut evaluations_by_repo = BTreeMap::new();
-    for (repo_name, sha, query, file_path, lines) in rows {
-        let evaluation_project = evaluations_by_repo
-            .entry((repo_name, sha))
-            .or_insert_with(|| EvaluationProject {
-                repo: repo_name.to_string(),
-                sha: sha.to_string(),
-                queries: Vec::new(),
-            });
-
-        let ix = evaluation_project
-            .queries
-            .iter()
-            .position(|entry| entry.query == query)
-            .unwrap_or_else(|| {
-                evaluation_project.queries.push(EvaluationQuery {
-                    query: query.to_string(),
-                    expected_results: Vec::new(),
-                });
-                evaluation_project.queries.len() - 1
-            });
-        let results = &mut evaluation_project.queries[ix].expected_results;
-        let result = EvaluationSearchResult {
-            file: file_path.to_string(),
-            lines,
-        };
-        if !results.contains(&result) {
-            results.push(result);
-        }
-    }
-
-    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
-    let evaluations_path = dataset_dir.join("evaluations.json");
-    fs::write(
-        &evaluations_path,
-        serde_json::to_vec_pretty(&evaluations).unwrap(),
-    )
-    .unwrap();
-
-    eprintln!(
-        "Fetched CodeSearchNet evaluations into {}",
-        evaluations_path.display()
-    );
-
-    Ok(())
-}
-
-#[derive(Default, Debug)]
-struct Counts {
-    covered_results: usize,
-    overlapped_results: usize,
-    covered_files: usize,
-    total_results: usize,
-}
-
-async fn run_evaluation(
-    only_repo: Option<String>,
-    executor: &BackgroundExecutor,
-    cx: &mut AsyncApp,
-) -> Result<()> {
-    let mut http_client = None;
-    cx.update(|cx| {
-        let mut store = SettingsStore::new(cx);
-        store
-            .set_default_settings(settings::default_settings().as_ref(), cx)
-            .unwrap();
-        cx.set_global(store);
-        client::init_settings(cx);
-        language::init(cx);
-        Project::init_settings(cx);
-        http_client = Some(cx.http_client());
-        cx.update_flags(false, vec![]);
-    })
-    .unwrap();
-    let http_client = http_client.unwrap();
-    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
-    let evaluations_path = dataset_dir.join("evaluations.json");
-    let repos_dir = Path::new(EVAL_REPOS_DIR);
-    let db_path = Path::new(EVAL_DB_PATH);
-    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
-    let fs = Arc::new(RealFs::new(None, cx.background_executor().clone())) as Arc<dyn Fs>;
-    let clock = Arc::new(RealSystemClock);
-    let client = cx
-        .update(|cx| {
-            Client::new(
-                clock,
-                Arc::new(http_client::HttpClientWithUrl::new(
-                    http_client.clone(),
-                    "https://zed.dev",
-                    None,
-                )),
-                cx,
-            )
-        })
-        .unwrap();
-    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)).unwrap();
-    let node_runtime = NodeRuntime::unavailable();
-
-    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
-    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
-
-    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
-        http_client.clone(),
-        OpenAiEmbeddingModel::TextEmbedding3Small,
-        open_ai::OPEN_AI_API_URL.to_string(),
-        api_key,
-    ));
-
-    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
-    let debug_adapters = Arc::new(DapRegistry::default());
-    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
-        .unwrap();
-
-    let mut counts = Counts::default();
-    eprint!("Running evals.");
-
-    let mut failures = Vec::new();
-
-    for evaluation_project in evaluations {
-        if only_repo
-            .as_ref()
-            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
-        {
-            continue;
-        }
-
-        eprint!("\r\x1B[2K");
-        eprint!(
-            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
-            counts.covered_results,
-            counts.total_results,
-            counts.overlapped_results,
-            counts.total_results,
-            counts.covered_files,
-            counts.total_results,
-            evaluation_project.repo
-        );
-
-        let repo_dir = repos_dir.join(&evaluation_project.repo);
-        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
-            eprintln!("Skipping {}: directory not found", evaluation_project.repo);
-            continue;
-        }
-
-        let repo_db_path =
-            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
-
-        let project = cx
-            .update(|cx| {
-                Project::local(
-                    client.clone(),
-                    node_runtime.clone(),
-                    user_store.clone(),
-                    language_registry.clone(),
-                    debug_adapters.clone(),
-                    fs.clone(),
-                    None,
-                    cx,
-                )
-            })
-            .unwrap();
-
-        let repo = evaluation_project.repo.clone();
-        if let Err(err) = run_eval_project(
-            evaluation_project,
-            &user_store,
-            repo_db_path,
-            &repo_dir,
-            &mut counts,
-            project,
-            embedding_provider.clone(),
-            fs.clone(),
-            cx,
-        )
-        .await
-        {
-            eprintln!("{repo} eval failed with error: {:?}", err);
-
-            failures.push((repo, err));
-        }
-    }
-
-    eprintln!(
-        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. {} failed.",
-        counts.covered_results,
-        counts.total_results,
-        counts.overlapped_results,
-        counts.total_results,
-        counts.covered_files,
-        counts.total_results,
-        failures.len(),
-    );
-
-    if failures.is_empty() {
-        Ok(())
-    } else {
-        eprintln!("Failures:\n");
-
-        for (index, (repo, failure)) in failures.iter().enumerate() {
-            eprintln!("Failure #{} - {repo}\n{:?}", index + 1, failure);
-        }
-
-        Err(anyhow::anyhow!("Some evals failed."))
-    }
-}
-
-async fn run_eval_project(
-    evaluation_project: EvaluationProject,
-    user_store: &Entity<UserStore>,
-    repo_db_path: PathBuf,
-    repo_dir: &Path,
-    counts: &mut Counts,
-    project: Entity<Project>,
-    embedding_provider: Arc<dyn EmbeddingProvider>,
-    fs: Arc<dyn Fs>,
-    cx: &mut AsyncApp,
-) -> Result<(), anyhow::Error> {
-    let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider, cx).await?;
-
-    let (worktree, _) = project
-        .update(cx, |project, cx| {
-            project.find_or_create_worktree(repo_dir, true, cx)
-        })?
-        .await?;
-
-    worktree
-        .update(cx, |worktree, _| {
-            worktree.as_local().unwrap().scan_complete()
-        })?
-        .await;
-
-    let project_index = cx.update(|cx| semantic_index.create_project_index(project.clone(), cx))?;
-    wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;
-
-    for query in evaluation_project.queries {
-        let results = {
-            // Retry search up to 3 times in case of timeout, network failure, etc.
-            let mut retries_remaining = 3;
-            let mut result;
-
-            loop {
-                match cx.update(|cx| {
-                    let project_index = project_index.read(cx);
-                    project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
-                }) {
-                    Ok(task) => match task.await {
-                        Ok(answer) => {
-                            result = Ok(answer);
-                            break;
-                        }
-                        Err(err) => {
-                            result = Err(err);
-                        }
-                    },
-                    Err(err) => {
-                        result = Err(err);
-                    }
-                }
-
-                if retries_remaining > 0 {
-                    eprintln!(
-                        "Retrying search after it failed on query {:?} with {:?}",
-                        query, result
-                    );
-                    retries_remaining -= 1;
-                } else {
-                    eprintln!(
-                        "Ran out of retries; giving up on search which failed on query {:?} with {:?}",
-                        query, result
-                    );
-                    break;
-                }
-            }
-
-            SemanticDb::load_results(result?, &fs.clone(), &cx).await?
-        };
-
-        let mut project_covered_result_count = 0;
-        let mut project_overlapped_result_count = 0;
-        let mut project_covered_file_count = 0;
-        let mut covered_result_indices = Vec::new();
-        for expected_result in &query.expected_results {
-            let mut file_matched = false;
-            let mut range_overlapped = false;
-            let mut range_covered = false;
-
-            for (ix, result) in results.iter().enumerate() {
-                if result.path.as_ref() == Path::new(&expected_result.file) {
-                    file_matched = true;
-                    let start_matched = result.row_range.contains(expected_result.lines.start());
-                    let end_matched = result.row_range.contains(expected_result.lines.end());
-
-                    if start_matched || end_matched {
-                        range_overlapped = true;
-                    }
-
-                    if start_matched && end_matched {
-                        range_covered = true;
-                        covered_result_indices.push(ix);
-                        break;
-                    }
-                }
-            }
-
-            if range_covered {
-                project_covered_result_count += 1
-            };
-            if range_overlapped {
-                project_overlapped_result_count += 1
-            };
-            if file_matched {
-                project_covered_file_count += 1
-            };
-        }
-        let outcome_repo = evaluation_project.repo.clone();
-
-        let query_results = EvaluationQueryOutcome {
-            repo: outcome_repo,
-            query: query.query,
-            total_result_count: query.expected_results.len(),
-            covered_result_count: project_covered_result_count,
-            overlapped_result_count: project_overlapped_result_count,
-            covered_file_count: project_covered_file_count,
-            expected_results: query.expected_results,
-            actual_results: results
-                .iter()
-                .map(|result| EvaluationSearchResult {
-                    file: result.path.to_string_lossy().to_string(),
-                    lines: result.row_range.clone(),
-                })
-                .collect(),
-            covered_result_indices,
-        };
-
-        counts.overlapped_results += query_results.overlapped_result_count;
-        counts.covered_results += query_results.covered_result_count;
-        counts.covered_files += query_results.covered_file_count;
-        counts.total_results += query_results.total_result_count;
-
-        println!("{}", serde_json::to_string(&query_results)?);
-    }
-
-    user_store.update(cx, |_, _| {
-        drop(semantic_index);
-        drop(project);
-        drop(worktree);
-        drop(project_index);
-    })
-}
-
-async fn wait_for_indexing_complete(
-    project_index: &Entity<ProjectIndex>,
-    cx: &mut AsyncApp,
-    timeout: Option<Duration>,
-) {
-    let (tx, rx) = bounded(1);
-    let subscription = cx.update(|cx| {
-        cx.subscribe(project_index, move |_, event, _| {
-            if let Status::Idle = event {
-                let _ = tx.try_send(*event);
-            }
-        })
-    });
-
-    let result = match timeout {
-        Some(timeout_duration) => {
-            smol::future::or(
-                async {
-                    rx.recv().await.map_err(|_| ())?;
-                    Ok(())
-                },
-                async {
-                    Timer::after(timeout_duration).await;
-                    Err(())
-                },
-            )
-            .await
-        }
-        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
-    };
-
-    match result {
-        Ok(_) => (),
-        Err(_) => {
-            if let Some(timeout) = timeout {
-                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
-            }
-        }
-    }
-
-    drop(subscription);
-}
-
-async fn fetch_eval_repos(
-    executor: &BackgroundExecutor,
-    http_client: &dyn HttpClient,
-) -> Result<()> {
-    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
-    let evaluations_path = dataset_dir.join("evaluations.json");
-    let repos_dir = Path::new(EVAL_REPOS_DIR);
-
-    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
-    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
-
-    eprintln!("Fetching evaluation repositories...");
-
-    executor
-        .scoped(move |scope| {
-            let done_count = Arc::new(AtomicUsize::new(0));
-            let len = evaluations.len();
-            for chunk in evaluations.chunks(evaluations.len() / 8) {
-                let chunk = chunk.to_vec();
-                let done_count = done_count.clone();
-                scope.spawn(async move {
-                    for EvaluationProject { repo, sha, .. } in chunk {
-                        eprint!(
-                            "\rFetching evaluation repositories ({}/{})...",
-                            done_count.load(SeqCst),
-                            len,
-                        );
-
-                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
-                        done_count.fetch_add(1, SeqCst);
-                    }
-                });
-            }
-        })
-        .await;
-
-    Ok(())
-}
-
-async fn fetch_eval_repo(
-    repo: String,
-    sha: String,
-    repos_dir: &Path,
-    http_client: &dyn HttpClient,
-) {
-    let Some((owner, repo_name)) = repo.split_once('/') else {
-        return;
-    };
-    let repo_dir = repos_dir.join(owner).join(repo_name);
-    fs::create_dir_all(&repo_dir).unwrap();
-    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
-    if skip_eval_path.exists() {
-        return;
-    }
-    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
-        if head_content.trim() == sha {
-            return;
-        }
-    }
-    let repo_response = http_client
-        .send(
-            http_client::Request::builder()
-                .method(Method::HEAD)
-                .uri(format!("https://github.com/{}", repo))
-                .body(Default::default())
-                .expect(""),
-        )
-        .await
-        .expect("failed to check github repo");
-    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
-        fs::write(&skip_eval_path, "").unwrap();
-        eprintln!(
-            "Repo {repo} is no longer public ({:?}). Skipping",
-            repo_response.status()
-        );
-        return;
-    }
-    if !repo_dir.join(".git").exists() {
-        let init_output = util::command::new_std_command("git")
-            .current_dir(&repo_dir)
-            .args(&["init"])
-            .output()
-            .unwrap();
-        if !init_output.status.success() {
-            eprintln!(
-                "Failed to initialize git repository for {}: {}",
-                repo,
-                String::from_utf8_lossy(&init_output.stderr)
-            );
-            return;
-        }
-    }
-    let url = format!("https://github.com/{}.git", repo);
-    util::command::new_std_command("git")
-        .current_dir(&repo_dir)
-        .args(&["remote", "add", "-f", "origin", &url])
-        .stdin(Stdio::null())
-        .output()
-        .unwrap();
-    let fetch_output = util::command::new_std_command("git")
-        .current_dir(&repo_dir)
-        .args(&["fetch", "--depth", "1", "origin", &sha])
-        .stdin(Stdio::null())
-        .output()
-        .unwrap();
-    if !fetch_output.status.success() {
-        eprintln!(
-            "Failed to fetch {} for {}: {}",
-            sha,
-            repo,
-            String::from_utf8_lossy(&fetch_output.stderr)
-        );
-        return;
-    }
-    let checkout_output = util::command::new_std_command("git")
-        .current_dir(&repo_dir)
-        .args(&["checkout", &sha])
-        .output()
-        .unwrap();
-
-    if !checkout_output.status.success() {
-        eprintln!(
-            "Failed to checkout {} for {}: {}",
-            sha,
-            repo,
-            String::from_utf8_lossy(&checkout_output.stderr)
-        );
-    }
-}
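For reference, the annotation parsing in fetch_code_search_net_resources above can be exercised in isolation. The sketch below extracts the same (repo, sha, path, line-range) tuple from a CodeSearchNet-style GitHub URL; the function name and sample URL are illustrative, not from the commit.

use std::ops::RangeInclusive;

fn parse_annotation_url(github_url: &str) -> Option<(&str, &str, &str, RangeInclusive<u32>)> {
    // Same steps as the filter_map in fetch_code_search_net_resources:
    // strip the host, split off the #L... fragment, then split "/blob/" and the sha.
    let url_path = github_url.strip_prefix("https://github.com/")?;
    let (url_path, hash) = url_path.split_once('#')?;
    let (repo_name, url_path) = url_path.split_once("/blob/")?;
    let (sha, file_path) = url_path.split_once('/')?;
    let lines = if let Some((start, end)) = hash.split_once('-') {
        // "#L10-L20" becomes 10..=20.
        start.strip_prefix('L')?.parse::<u32>().ok()?..=end.strip_prefix('L')?.parse().ok()?
    } else {
        // "#L10" becomes the single-line range 10..=10.
        let row = hash.strip_prefix('L')?.parse().ok()?;
        row..=row
    };
    Some((repo_name, sha, file_path, lines))
}

fn main() {
    // Hypothetical annotation URL, for illustration only.
    let url = "https://github.com/octocat/Hello-World/blob/7fd1a60b/README.md#L2-L5";
    assert_eq!(
        parse_annotation_url(url),
        Some(("octocat/Hello-World", "7fd1a60b", "README.md", 2..=5))
    );
}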
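The per-result bookkeeping in run_eval_project reduces to a three-way classification: a search hit covers an expected result when its row range contains both endpoints, and merely overlaps it when it contains exactly one (in the original, a covered hit also counts toward the overlap total). A minimal sketch, with the enum and function names assumed for illustration:

use std::ops::RangeInclusive;

#[derive(Debug, PartialEq)]
enum Match {
    Covered,
    Overlapped,
    Miss,
}

// Mirrors the start_matched/end_matched checks in run_eval_project.
fn classify(result_rows: &RangeInclusive<u32>, expected: &RangeInclusive<u32>) -> Match {
    let start_matched = result_rows.contains(expected.start());
    let end_matched = result_rows.contains(expected.end());
    match (start_matched, end_matched) {
        (true, true) => Match::Covered,
        (false, false) => Match::Miss,
        _ => Match::Overlapped,
    }
}

fn main() {
    assert_eq!(classify(&(10..=30), &(12..=20)), Match::Covered);
    assert_eq!(classify(&(10..=15), &(12..=20)), Match::Overlapped);
    assert_eq!(classify(&(1..=5), &(12..=20)), Match::Miss);
}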
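Finally, the timeout in wait_for_indexing_complete is just smol's future-racing primitive. A self-contained sketch of that pattern, assuming nothing beyond the smol crate: race a channel receive against a timer, and whichever future finishes first decides the outcome.

use std::time::Duration;

fn main() {
    smol::block_on(async {
        let (tx, rx) = smol::channel::bounded::<()>(1);

        // Stand-in for the Status::Idle subscription: signal readiness after 50ms.
        smol::spawn(async move {
            smol::Timer::after(Duration::from_millis(50)).await;
            let _ = tx.try_send(());
        })
        .detach();

        // Same shape as wait_for_indexing_complete: Ok if the signal arrives
        // first, Err if the timer wins.
        let result: Result<(), ()> = smol::future::or(
            async { rx.recv().await.map_err(|_| ()) },
            async {
                smol::Timer::after(Duration::from_secs(2)).await;
                Err(())
            },
        )
        .await;

        assert!(result.is_ok());
    });
}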