From eff44f9aa4412399c1c642eb271c4e5ec8297cec Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 13 Sep 2023 20:02:15 -0400 Subject: [PATCH] semantic index eval, indexing appropriately --- Cargo.lock | 4 + crates/semantic_index/Cargo.toml | 4 + crates/semantic_index/eval/tree-sitter.json | 6 +- crates/semantic_index/examples/eval.rs | 194 ++++++++++++++++---- crates/semantic_index/src/semantic_index.rs | 6 +- 5 files changed, 168 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a66391ed07..b0f46a90d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6744,6 +6744,7 @@ dependencies = [ "anyhow", "async-trait", "bincode", + "client", "collections", "ctor", "editor", @@ -6757,6 +6758,7 @@ dependencies = [ "lazy_static", "log", "matrixmultiply", + "node_runtime", "parking_lot 0.11.2", "parse_duration", "picker", @@ -6766,6 +6768,7 @@ dependencies = [ "rand 0.8.5", "rpc", "rusqlite", + "rust-embed", "schemars", "serde", "serde_json", @@ -6788,6 +6791,7 @@ dependencies = [ "unindent", "util", "workspace", + "zed", ] [[package]] diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index b5537dd2fa..a20f29fd68 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -51,6 +51,10 @@ rpc = { path = "../rpc", features = ["test-support"] } workspace = { path = "../workspace", features = ["test-support"] } settings = { path = "../settings", features = ["test-support"]} git2 = { version = "0.15"} +rust-embed = { version = "8.0", features = ["include-exclude"] } +client = { path = "../client" } +zed = { path = "../zed"} +node_runtime = { path = "../node_runtime"} pretty_assertions.workspace = true rand.workspace = true diff --git a/crates/semantic_index/eval/tree-sitter.json b/crates/semantic_index/eval/tree-sitter.json index 52d1e9df16..d3dcc86937 100644 --- a/crates/semantic_index/eval/tree-sitter.json +++ b/crates/semantic_index/eval/tree-sitter.json @@ -17,7 +17,7 @@ { "query": "generate tags based on config", "matches": [ - "tags/src/lib.rs:261", + "tags/src/lib.rs:261" ] }, { @@ -54,13 +54,13 @@ { "query": "Match based on associativity of actions", "matches": [ - "cri/src/generate/build_tables/build_parse_table.rs:542", + "cri/src/generate/build_tables/build_parse_table.rs:542" ] }, { "query": "Format token set display", "matches": [ - "cli/src/generate/build_tables/item.rs:246", + "cli/src/generate/build_tables/item.rs:246" ] }, { diff --git a/crates/semantic_index/examples/eval.rs b/crates/semantic_index/examples/eval.rs index f666f5c281..67ee52e28c 100644 --- a/crates/semantic_index/examples/eval.rs +++ b/crates/semantic_index/examples/eval.rs @@ -1,8 +1,46 @@ +use anyhow::{anyhow, Result}; +use client::{self, UserStore}; use git2::{Object, Oid, Repository}; -use semantic_index::SearchResult; +use gpui::{AppContext, AssetSource, ModelHandle, Task}; +use language::LanguageRegistry; +use node_runtime::RealNodeRuntime; +use project::{Fs, Project, RealFs}; +use rust_embed::RustEmbed; +use semantic_index::embedding::OpenAIEmbeddings; +use semantic_index::semantic_index_settings::SemanticIndexSettings; +use semantic_index::{SearchResult, SemanticIndex}; use serde::Deserialize; -use std::path::{Path, PathBuf}; -use std::{env, fs}; +use settings::{default_settings, handle_settings_file_changes, watch_config_file, SettingsStore}; +use std::path::{self, Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; +use std::{cmp, env, fs}; +use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}; +use util::http::{self, HttpClient}; +use util::paths::{self, EMBEDDINGS_DIR}; +use zed::languages; + +#[derive(RustEmbed)] +#[folder = "../../assets"] +#[include = "fonts/**/*"] +#[include = "icons/**/*"] +#[include = "themes/**/*"] +#[include = "sounds/**/*"] +#[include = "*.md"] +#[exclude = "*.DS_Store"] +pub struct Assets; + +impl AssetSource for Assets { + fn load(&self, path: &str) -> Result> { + Self::get(path) + .map(|f| f.data) + .ok_or_else(|| anyhow!("could not find asset at path \"{}\"", path)) + } + + fn list(&self, path: &str) -> Vec> { + Self::iter().filter(|p| p.starts_with(path)).collect() + } +} #[derive(Deserialize, Clone)] struct EvaluationQuery { @@ -13,15 +51,18 @@ struct EvaluationQuery { impl EvaluationQuery { fn match_pairs(&self) -> Vec<(PathBuf, usize)> { let mut pairs = Vec::new(); - for match_identifier in self.matches { - let match_parts = match_identifier.split(":"); + for match_identifier in self.matches.iter() { + let mut match_parts = match_identifier.split(":"); if let Some(file_path) = match_parts.next() { if let Some(row_number) = match_parts.next() { - pairs.push((PathBuf::from(file_path), from_str::(row_number))); + pairs.push(( + PathBuf::from(file_path), + row_number.parse::().unwrap(), + )); } } - + } pairs } } @@ -33,7 +74,7 @@ struct RepoEval { assertions: Vec, } -const TMP_REPO_PATH: &str = "./target/eval_repos"; +const TMP_REPO_PATH: &str = "eval_repos"; fn parse_eval() -> anyhow::Result> { let eval_folder = env::current_dir()? @@ -74,7 +115,12 @@ fn clone_repo(repo_eval: RepoEval) -> anyhow::Result { .unwrap() .to_owned() .replace(".git", ""); - let clone_path = Path::new(TMP_REPO_PATH).join(&repo_name).to_path_buf(); + + let clone_path = fs::canonicalize(env::current_dir()?)? + .parent() + .ok_or(anyhow!("path canonicalization failed"))? + .join(TMP_REPO_PATH) + .join(&repo_name); // Delete Clone Path if already exists let _ = fs::remove_dir_all(&clone_path); @@ -105,7 +151,6 @@ fn dcg(hits: Vec) -> f32 { } fn evaluate_ndcg(eval_query: EvaluationQuery, search_results: Vec, k: usize) -> f32 { - // NDCG or Normalized Discounted Cumulative Gain, is determined by comparing the relevance of // items returned by the search engine relative to the hypothetical ideal. // Relevance is represented as a series of booleans, in which each search result returned @@ -125,47 +170,118 @@ fn evaluate_ndcg(eval_query: EvaluationQuery, search_results: Vec, // very high quality, whereas rank results quickly drop off after the first result. let ideal = vec![1; cmp::min(eval_query.matches.len(), k)]; + let hits = vec![1]; return dcg(hits) / dcg(ideal); } -fn evaluate_map(eval_query: EvaluationQuery, search_results: Vec, k: usize) -> f32 { - -} - -fn evaluate_repo(repo_eval: RepoEval, clone_path: PathBuf) { - - // Launch new repo as a new Zed workspace/project - // Index the project - // Search each eval_query - // Calculate Statistics +// fn evaluate_map(eval_query: EvaluationQuery, search_results: Vec, k: usize) -> f32 {} +fn init_logger() { + env_logger::init(); } fn main() { + // Launch new repo as a new Zed workspace/project + let app = gpui::App::new(Assets).unwrap(); + let fs = Arc::new(RealFs); + let http = http::client(); + let user_settings_file_rx = + watch_config_file(app.background(), fs.clone(), paths::SETTINGS.clone()); + let http_client = http::client(); + init_logger(); - // zed/main.rs - // creating an app and running it, gives you the context. - // create a project, find_or_create_local_worktree. + app.run(move |cx| { + cx.set_global(*RELEASE_CHANNEL); - if let Ok(repo_evals) = parse_eval() { - for repo in repo_evals { - let cloned = clone_repo(repo.clone()); - match cloned { - Ok(clone_path) => { - println!( - "Cloned {:?} @ {:?} into {:?}", - repo.repo, repo.commit, &clone_path - ); + let client = client::Client::new(http.clone(), cx); + let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client.clone(), cx)); - // Evaluate Repo - evaluate_repo(repo, clone_path); + // Initialize Settings + let mut store = SettingsStore::default(); + store + .set_default_settings(default_settings().as_ref(), cx) + .unwrap(); + cx.set_global(store); + handle_settings_file_changes(user_settings_file_rx, cx); - } - Err(err) => { - println!("Error Cloning: {:?}", err); + // Initialize Languages + let login_shell_env_loaded = Task::ready(()); + let mut languages = LanguageRegistry::new(login_shell_env_loaded); + languages.set_executor(cx.background().clone()); + let languages = Arc::new(languages); + + let node_runtime = RealNodeRuntime::new(http.clone()); + languages::init(languages.clone(), node_runtime.clone()); + + project::Project::init(&client, cx); + semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx); + + settings::register::(cx); + + let db_file_path = EMBEDDINGS_DIR + .join(Path::new(RELEASE_CHANNEL_NAME.as_str())) + .join("embeddings_db"); + + let languages = languages.clone(); + let fs = fs.clone(); + cx.spawn(|mut cx| async move { + let semantic_index = SemanticIndex::new( + fs.clone(), + db_file_path, + Arc::new(OpenAIEmbeddings::new(http_client, cx.background())), + languages.clone(), + cx.clone(), + ) + .await?; + + if let Ok(repo_evals) = parse_eval() { + for repo in repo_evals { + let cloned = clone_repo(repo.clone()); + match cloned { + Ok(clone_path) => { + log::trace!( + "Cloned {:?} @ {:?} into {:?}", + repo.repo, + repo.commit, + &clone_path + ); + + // Create Project + let project = cx.update(|cx| { + Project::local( + client.clone(), + user_store.clone(), + languages.clone(), + fs.clone(), + cx, + ) + }); + + // Register Worktree + let _ = project + .update(&mut cx, |project, cx| { + println!( + "Creating worktree in project: {:?}", + clone_path.clone() + ); + project.find_or_create_local_worktree(clone_path, true, cx) + }) + .await; + + let _ = semantic_index + .update(&mut cx, |index, cx| index.index_project(project, cx)) + .await; + } + Err(err) => { + log::trace!("Error cloning: {:?}", err); + } + } } } - } - } + + anyhow::Ok(()) + }) + .detach(); + }); } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 115bf5d7a8..63bcc900f2 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -1,5 +1,5 @@ mod db; -mod embedding; +pub mod embedding; mod embedding_queue; mod parsing; pub mod semantic_index_settings; @@ -301,7 +301,7 @@ impl SemanticIndex { } } - async fn new( + pub async fn new( fs: Arc, database_path: PathBuf, embedding_provider: Arc, @@ -837,8 +837,6 @@ impl SemanticIndex { cx: &mut ModelContext, ) -> Task> { if !self.projects.contains_key(&project.downgrade()) { - log::trace!("Registering Project for Semantic Index"); - let subscription = cx.subscribe(&project, |this, project, event, cx| match event { project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { this.project_worktrees_changed(project.clone(), cx);