diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 3ee4c04580..c55a3a9907 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -101,7 +101,7 @@ jobs:
     timeout-minutes: 60
     name: (Linux) Run Clippy and tests
     runs-on:
-      - hosted-linux-x86-1
+      - buildjet-16vcpu-ubuntu-2204
     steps:
       - name: Add Rust to the PATH
         run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
@@ -111,6 +111,11 @@
         with:
           clean: false
 
+      - name: Cache dependencies
+        uses: swatinem/rust-cache@23bce251a8cd2ffc3c1075eaa2367cf899916d84 # v2
+        with:
+          save-if: ${{ github.ref == 'refs/heads/main' }}
+
       - name: Install Linux dependencies
         run: ./script/linux
 
@@ -264,7 +269,7 @@ jobs:
     timeout-minutes: 60
     name: Create a Linux bundle
     runs-on:
-      - hosted-linux-x86-1
+      - buildjet-16vcpu-ubuntu-2204
     if: ${{ startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
     needs: [linux_tests]
    env:
@@ -279,9 +284,6 @@
       - name: Install Linux dependencies
         run: ./script/linux
 
-      - name: Limit target directory size
-        run: script/clear-target-dir-if-larger-than 100
-
       - name: Determine version and release channel
         if: ${{ startsWith(github.ref, 'refs/tags/v') }}
         run: |
@@ -335,7 +337,7 @@ jobs:
     timeout-minutes: 60
     name: Create arm64 Linux bundle
     runs-on:
-      - hosted-linux-arm-1
+      - buildjet-16vcpu-ubuntu-2204-arm
     if: ${{ startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
     needs: [linux_tests]
     env:
@@ -350,9 +352,6 @@
       - name: Install Linux dependencies
         run: ./script/linux
 
-      - name: Limit target directory size
-        run: script/clear-target-dir-if-larger-than 100
-
       - name: Determine version and release channel
         if: ${{ startsWith(github.ref, 'refs/tags/v') }}
         run: |
diff --git a/Cargo.lock b/Cargo.lock
index 3b3a370c36..9cff895393 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4000,6 +4000,33 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "evals"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "clap",
+ "client",
+ "clock",
+ "collections",
+ "env_logger",
+ "feature_flags",
+ "fs",
+ "git",
+ "gpui",
+ "http_client",
+ "language",
+ "languages",
+ "node_runtime",
+ "open_ai",
+ "project",
+ "semantic_index",
+ "serde",
+ "serde_json",
+ "settings",
+ "smol",
+]
+
 [[package]]
 name = "event-listener"
 version = "2.5.3"
diff --git a/Cargo.toml b/Cargo.toml
index e1af231c7e..eea510edf2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -27,6 +27,7 @@ members = [
     "crates/diagnostics",
     "crates/docs_preprocessor",
     "crates/editor",
+    "crates/evals",
     "crates/extension",
     "crates/extension_api",
     "crates/extension_cli",
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 52838b5c77..6eaa86f4a7 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -3282,7 +3282,7 @@ impl ContextEditor {
 
                 let fence = codeblock_fence_for_path(
                     filename.as_deref(),
-                    Some(selection.start.row..selection.end.row),
+                    Some(selection.start.row..=selection.end.row),
                 );
 
                 if let Some((line_comment_prefix, outline_text)) =
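Reviewer note: the `..` → `..=` switch above changes the fence's row metadata from a half-open `Range<u32>` to a closed `RangeInclusive<u32>`, whose bounds are read through methods rather than fields. A minimal sketch of the difference (illustrative only, not part of the patch):

```rust
fn main() {
    // Half-open range: `end` is excluded; bounds are plain fields.
    let exclusive = 10u32..20;
    assert!(exclusive.contains(&10) && !exclusive.contains(&20));
    assert_eq!((exclusive.start, exclusive.end), (10, 20));

    // Inclusive range: `end` is included; bounds are accessor methods.
    let inclusive = 10u32..=20;
    assert!(inclusive.contains(&10) && inclusive.contains(&20));
    assert_eq!((*inclusive.start(), *inclusive.end()), (10, 20));
}
```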
diff --git a/crates/assistant/src/slash_command/file_command.rs b/crates/assistant/src/slash_command/file_command.rs
index 0df8b5d4e0..260c6b0e2a 100644
--- a/crates/assistant/src/slash_command/file_command.rs
+++ b/crates/assistant/src/slash_command/file_command.rs
@@ -8,7 +8,7 @@ use project::{PathMatchCandidateSet, Project};
 use serde::{Deserialize, Serialize};
 use std::{
     fmt::Write,
-    ops::Range,
+    ops::{Range, RangeInclusive},
     path::{Path, PathBuf},
     sync::{atomic::AtomicBool, Arc},
 };
@@ -342,7 +342,10 @@ fn collect_files(
     })
 }
 
-pub fn codeblock_fence_for_path(path: Option<&Path>, row_range: Option<Range<u32>>) -> String {
+pub fn codeblock_fence_for_path(
+    path: Option<&Path>,
+    row_range: Option<RangeInclusive<u32>>,
+) -> String {
     let mut text = String::new();
     write!(text, "```").unwrap();
 
@@ -357,7 +360,7 @@ pub fn codeblock_fence_for_path(path: Option<&Path>, row_range: Option<Range<u3
     }
 
     if let Some(row_range) = row_range {
-        write!(text, ":{}-{}", row_range.start + 1, row_range.end + 1).unwrap();
+        write!(text, ":{}-{}", row_range.start() + 1, row_range.end() + 1).unwrap();
     }
 
     text.push('\n');
diff --git a/crates/evals/src/eval.rs b/crates/evals/src/eval.rs
new file mode 100644
--- /dev/null
+++ b/crates/evals/src/eval.rs
+use ::fs::{Fs, RealFs};
+use anyhow::Result;
+use clap::{Parser, Subcommand};
+use client::{Client, UserStore};
+use clock::RealSystemClock;
+use collections::BTreeMap;
+use feature_flags::FeatureFlagAppExt as _;
+use git::GitHostingProviderRegistry;
+use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
+use http_client::{HttpClient, Method};
+use language::LanguageRegistry;
+use node_runtime::FakeNodeRuntime;
+use open_ai::OpenAiEmbeddingModel;
+use project::Project;
+use semantic_index::{OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status};
+use serde::{Deserialize, Serialize};
+use settings::SettingsStore;
+use smol::channel::bounded;
+use smol::io::AsyncReadExt;
+use smol::Timer;
+use std::ops::RangeInclusive;
+use std::path::Path;
+use std::process::exit;
+use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
+use std::time::Duration;
+use std::{
+    fs,
+    path::PathBuf,
+    process::{Command, Stdio},
+    sync::Arc,
+};
+
+const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
+const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
+const EVAL_DB_PATH: &'static str = "target/eval_db";
+const SEARCH_RESULT_LIMIT: usize = 8;
+const SKIP_EVAL_PATH: &'static str = ".skip_eval";
+
+#[derive(Parser)]
+#[command(author, version, about, long_about = None)]
+struct Cli {
+    #[command(subcommand)]
+    command: Commands,
+}
+
+#[derive(Subcommand)]
+enum Commands {
+    Fetch {},
+    Run {
+        #[arg(long)]
+        repo: Option<String>,
+    },
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+struct EvaluationProject {
+    repo: String,
+    sha: String,
+    queries: Vec<EvaluationQuery>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+struct EvaluationQuery {
+    query: String,
+    expected_results: Vec<EvaluationSearchResult>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
+struct EvaluationSearchResult {
+    file: String,
+    lines: RangeInclusive<u32>,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+struct EvaluationProjectOutcome {
+    repo: String,
+    sha: String,
+    queries: Vec<EvaluationQueryOutcome>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+struct EvaluationQueryOutcome {
+    repo: String,
+    query: String,
+    expected_results: Vec<EvaluationSearchResult>,
+    actual_results: Vec<EvaluationSearchResult>,
+    covered_file_count: usize,
+    overlapped_result_count: usize,
+    covered_result_count: usize,
+    total_result_count: usize,
+    covered_result_indices: Vec<usize>,
+}
+
+fn main() -> Result<()> {
+    let cli = Cli::parse();
+    env_logger::init();
+
+    gpui::App::headless().run(move |cx| {
+        let executor = cx.background_executor().clone();
+
+        match cli.command {
+            Commands::Fetch {} => {
+                executor
+                    .clone()
+                    .spawn(async move {
+                        if let Err(err) = fetch_evaluation_resources(&executor).await {
+                            eprintln!("Error: {}", err);
+                            exit(1);
+                        }
+                        exit(0);
+                    })
+                    .detach();
+            }
+            Commands::Run { repo } => {
+                cx.spawn(|mut cx| async move {
+                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
+                        eprintln!("Error: {}", err);
+                        exit(1);
+                    }
+                    exit(0);
+                })
+                .detach();
+            }
+        }
+    });
+
+    Ok(())
+}
+
+async fn fetch_evaluation_resources(executor: &BackgroundExecutor) -> Result<()> {
+    let http_client = http_client::HttpClientWithProxy::new(None, None);
+    fetch_code_search_net_resources(&http_client).await?;
+    fetch_eval_repos(executor, &http_client).await?;
+    Ok(())
+}
+
+async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
+    eprintln!("Fetching CodeSearchNet evaluations...");
+
+    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
+
+    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
+    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
+
+    // Fetch the annotations CSV, which contains the human-annotated search relevances
+    let annotations_path = dataset_dir.join("annotations.csv");
+    let annotations_csv_content = if annotations_path.exists() {
+        fs::read_to_string(&annotations_path).expect("failed to read annotations")
+    } else {
+        let response = http_client
+            .get(annotations_url, Default::default(), true)
+            .await
+            .expect("failed to fetch annotations csv");
+        let mut body = String::new();
+        response
+            .into_body()
+            .read_to_string(&mut body)
+            .await
+            .expect("failed to read annotations.csv response");
+        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
+        body
+    };
+
+    // Parse the annotations CSV. Skip over queries with zero relevance.
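+    // Row shape inferred from the parsing below (hypothetical example values):
+    //   python,how to read a file,https://github.com/owner/repo/blob/<sha>/src/io.py#L10-L42,3
+    // i.e. language, query, GitHub blob URL with a line fragment, relevance score.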
+    let rows = annotations_csv_content.lines().filter_map(|line| {
+        let mut values = line.split(',');
+        let _language = values.next()?;
+        let query = values.next()?;
+        let github_url = values.next()?;
+        let score = values.next()?;
+
+        if score == "0" {
+            return None;
+        }
+
+        let url_path = github_url.strip_prefix("https://github.com/")?;
+        let (url_path, hash) = url_path.split_once('#')?;
+        let (repo_name, url_path) = url_path.split_once("/blob/")?;
+        let (sha, file_path) = url_path.split_once('/')?;
+        let line_range = if let Some((start, end)) = hash.split_once('-') {
+            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
+        } else {
+            let row = hash.strip_prefix("L")?.parse().ok()?;
+            row..=row
+        };
+        Some((repo_name, sha, query, file_path, line_range))
+    });
+
+    // Group the annotations by repo and sha.
+    let mut evaluations_by_repo = BTreeMap::new();
+    for (repo_name, sha, query, file_path, lines) in rows {
+        let evaluation_project = evaluations_by_repo
+            .entry((repo_name, sha))
+            .or_insert_with(|| EvaluationProject {
+                repo: repo_name.to_string(),
+                sha: sha.to_string(),
+                queries: Vec::new(),
+            });
+
+        let ix = evaluation_project
+            .queries
+            .iter()
+            .position(|entry| entry.query == query)
+            .unwrap_or_else(|| {
+                evaluation_project.queries.push(EvaluationQuery {
+                    query: query.to_string(),
+                    expected_results: Vec::new(),
+                });
+                evaluation_project.queries.len() - 1
+            });
+        let results = &mut evaluation_project.queries[ix].expected_results;
+        let result = EvaluationSearchResult {
+            file: file_path.to_string(),
+            lines,
+        };
+        if !results.contains(&result) {
+            results.push(result);
+        }
+    }
+
+    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
+    let evaluations_path = dataset_dir.join("evaluations.json");
+    fs::write(
+        &evaluations_path,
+        serde_json::to_vec_pretty(&evaluations).unwrap(),
+    )
+    .unwrap();
+
+    eprintln!(
+        "Fetched CodeSearchNet evaluations into {}",
+        evaluations_path.display()
+    );
+
+    Ok(())
+}
+
+async fn run_evaluation(
+    only_repo: Option<String>,
+    executor: &BackgroundExecutor,
+    cx: &mut AsyncAppContext,
+) -> Result<()> {
+    cx.update(|cx| {
+        let mut store = SettingsStore::new(cx);
+        store
+            .set_default_settings(settings::default_settings().as_ref(), cx)
+            .unwrap();
+        cx.set_global(store);
+        client::init_settings(cx);
+        language::init(cx);
+        Project::init_settings(cx);
+        cx.update_flags(false, vec![]);
+    })
+    .unwrap();
+
+    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
+    let evaluations_path = dataset_dir.join("evaluations.json");
+    let repos_dir = Path::new(EVAL_REPOS_DIR);
+    let db_path = Path::new(EVAL_DB_PATH);
+    let http_client = http_client::HttpClientWithProxy::new(None, None);
+    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
+    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
+    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
+    let clock = Arc::new(RealSystemClock);
+    let client = cx
+        .update(|cx| {
+            Client::new(
+                clock,
+                Arc::new(http_client::HttpClientWithUrl::new(
+                    "https://zed.dev",
+                    None,
+                    None,
+                )),
+                cx,
+            )
+        })
+        .unwrap();
+    let user_store = cx
+        .new_model(|cx| UserStore::new(client.clone(), cx))
+        .unwrap();
+    let node_runtime = Arc::new(FakeNodeRuntime {});
+
+    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
+    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
+
+    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
+        http_client.clone(),
+        OpenAiEmbeddingModel::TextEmbedding3Small,
+        open_ai::OPEN_AI_API_URL.to_string(),
+        api_key,
+    ));
+
+    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
+    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
+        .unwrap();
+
+    let mut covered_result_count = 0;
+    let mut overlapped_result_count = 0;
+    let mut covered_file_count = 0;
+    let mut total_result_count = 0;
+    eprint!("Running evals.");
+
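+    // Terminology for the counters above, per expected result (see the loop below):
+    //   "covered"    - some search result spans the expected line range entirely,
+    //   "overlapped" - some search result contains at least one endpoint of it,
+    //   "files"      - some search result at least lands in the expected file.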
+    for evaluation_project in evaluations {
+        if only_repo
+            .as_ref()
+            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
+        {
+            continue;
+        }
+
+        eprint!("\r\x1B[2K");
+        eprint!(
+            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
+            covered_result_count,
+            total_result_count,
+            overlapped_result_count,
+            total_result_count,
+            covered_file_count,
+            total_result_count,
+            evaluation_project.repo
+        );
+
+        let repo_db_path =
+            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
+        let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider.clone(), cx)
+            .await
+            .unwrap();
+
+        let repo_dir = repos_dir.join(&evaluation_project.repo);
+        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
+            eprintln!("Skipping {}: directory not found", evaluation_project.repo);
+            continue;
+        }
+
+        let project = cx
+            .update(|cx| {
+                Project::local(
+                    client.clone(),
+                    node_runtime.clone(),
+                    user_store.clone(),
+                    language_registry.clone(),
+                    fs.clone(),
+                    None,
+                    cx,
+                )
+            })
+            .unwrap();
+
+        let (worktree, _) = project
+            .update(cx, |project, cx| {
+                project.find_or_create_worktree(repo_dir, true, cx)
+            })?
+            .await?;
+
+        worktree
+            .update(cx, |worktree, _| {
+                worktree.as_local().unwrap().scan_complete()
+            })
+            .unwrap()
+            .await;
+
+        let project_index = cx
+            .update(|cx| semantic_index.create_project_index(project.clone(), cx))
+            .unwrap();
+        wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;
+
+        for query in evaluation_project.queries {
+            let results = cx
+                .update(|cx| {
+                    let project_index = project_index.read(cx);
+                    project_index.search(query.query.clone(), SEARCH_RESULT_LIMIT, cx)
+                })
+                .unwrap()
+                .await
+                .unwrap();
+
+            let results = SemanticDb::load_results(results, &fs.clone(), &cx)
+                .await
+                .unwrap();
+
+            let mut project_covered_result_count = 0;
+            let mut project_overlapped_result_count = 0;
+            let mut project_covered_file_count = 0;
+            let mut covered_result_indices = Vec::new();
+            for expected_result in &query.expected_results {
+                let mut file_matched = false;
+                let mut range_overlapped = false;
+                let mut range_covered = false;
+
+                for (ix, result) in results.iter().enumerate() {
+                    if result.path.as_ref() == Path::new(&expected_result.file) {
+                        file_matched = true;
+                        let start_matched =
+                            result.row_range.contains(expected_result.lines.start());
+                        let end_matched = result.row_range.contains(expected_result.lines.end());
+
+                        if start_matched || end_matched {
+                            range_overlapped = true;
+                        }
+
+                        if start_matched && end_matched {
+                            range_covered = true;
+                            covered_result_indices.push(ix);
+                            break;
+                        }
+                    }
+                }
+
+                if range_covered {
+                    project_covered_result_count += 1
+                };
+                if range_overlapped {
+                    project_overlapped_result_count += 1
+                };
+                if file_matched {
+                    project_covered_file_count += 1
+                };
+            }
+            let outcome_repo = evaluation_project.repo.clone();
+
+            let query_results = EvaluationQueryOutcome {
+                repo: outcome_repo,
+                query: query.query,
+                total_result_count: query.expected_results.len(),
+                covered_result_count: project_covered_result_count,
+                overlapped_result_count: project_overlapped_result_count,
+                covered_file_count: project_covered_file_count,
+                expected_results: query.expected_results,
+                actual_results: results
+                    .iter()
+                    .map(|result| EvaluationSearchResult {
+                        file: result.path.to_string_lossy().to_string(),
+                        lines: result.row_range.clone(),
+                    })
+                    .collect(),
+                covered_result_indices,
+            };
+
+            overlapped_result_count += query_results.overlapped_result_count;
+            covered_result_count += query_results.covered_result_count;
+            covered_file_count += query_results.covered_file_count;
+            total_result_count += query_results.total_result_count;
+
+            println!("{}", serde_json::to_string(&query_results).unwrap());
+        }
+    }
+
+    eprint!(
+        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured.",
+        covered_result_count,
+        total_result_count,
+        overlapped_result_count,
+        total_result_count,
+        covered_file_count,
+        total_result_count,
+    );
+
+    Ok(())
+}
+
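+// Resolve once the project index reports `Status::Idle`, or after the optional
+// timeout elapses (in which case the eval proceeds with a partial index).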
+async fn wait_for_indexing_complete(
+    project_index: &Model<ProjectIndex>,
+    cx: &mut AsyncAppContext,
+    timeout: Option<Duration>,
+) {
+    let (tx, rx) = bounded(1);
+    let subscription = cx.update(|cx| {
+        cx.subscribe(project_index, move |_, event, _| {
+            if let Status::Idle = event {
+                let _ = tx.try_send(*event);
+            }
+        })
+    });
+
+    let result = match timeout {
+        Some(timeout_duration) => {
+            smol::future::or(
+                async {
+                    rx.recv().await.map_err(|_| ())?;
+                    Ok(())
+                },
+                async {
+                    Timer::after(timeout_duration).await;
+                    Err(())
+                },
+            )
+            .await
+        }
+        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
+    };
+
+    match result {
+        Ok(_) => (),
+        Err(_) => {
+            if let Some(timeout) = timeout {
+                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
+            }
+        }
+    }
+
+    drop(subscription);
+}
+
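+// Download every evaluation repository, sharding the list across roughly
+// eight concurrent tasks on the background executor.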
+async fn fetch_eval_repos(
+    executor: &BackgroundExecutor,
+    http_client: &dyn HttpClient,
+) -> Result<()> {
+    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
+    let evaluations_path = dataset_dir.join("evaluations.json");
+    let repos_dir = Path::new(EVAL_REPOS_DIR);
+
+    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
+    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
+
+    eprint!("Fetching evaluation repositories...");
+
+    executor
+        .scoped(move |scope| {
+            let done_count = Arc::new(AtomicUsize::new(0));
+            let len = evaluations.len();
+            for chunk in evaluations.chunks(evaluations.len() / 8) {
+                let chunk = chunk.to_vec();
+                let done_count = done_count.clone();
+                scope.spawn(async move {
+                    for EvaluationProject { repo, sha, .. } in chunk {
+                        eprint!(
+                            "\rFetching evaluation repositories ({}/{})...",
+                            done_count.load(SeqCst),
+                            len,
+                        );
+
+                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
+                        done_count.fetch_add(1, SeqCst);
+                    }
+                });
+            }
+        })
+        .await;
+
+    Ok(())
+}
+
+async fn fetch_eval_repo(
+    repo: String,
+    sha: String,
+    repos_dir: &Path,
+    http_client: &dyn HttpClient,
+) {
+    let Some((owner, repo_name)) = repo.split_once('/') else {
+        return;
+    };
+    let repo_dir = repos_dir.join(owner).join(repo_name);
+    fs::create_dir_all(&repo_dir).unwrap();
+    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
+    if skip_eval_path.exists() {
+        return;
+    }
+    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
+        if head_content.trim() == sha {
+            return;
+        }
+    }
+    let repo_response = http_client
+        .send(
+            http_client::Request::builder()
+                .method(Method::HEAD)
+                .uri(format!("https://github.com/{}", repo))
+                .body(Default::default())
+                .expect(""),
+        )
+        .await
+        .expect("failed to check github repo");
+    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
+        fs::write(&skip_eval_path, "").unwrap();
+        eprintln!(
+            "Repo {repo} is no longer public ({:?}). Skipping",
+            repo_response.status()
+        );
+        return;
+    }
+    if !repo_dir.join(".git").exists() {
+        let init_output = Command::new("git")
+            .current_dir(&repo_dir)
+            .args(&["init"])
+            .output()
+            .unwrap();
+        if !init_output.status.success() {
+            eprintln!(
+                "Failed to initialize git repository for {}: {}",
+                repo,
+                String::from_utf8_lossy(&init_output.stderr)
+            );
+            return;
+        }
+    }
+    let url = format!("https://github.com/{}.git", repo);
+    Command::new("git")
+        .current_dir(&repo_dir)
+        .args(&["remote", "add", "-f", "origin", &url])
+        .stdin(Stdio::null())
+        .output()
+        .unwrap();
+    let fetch_output = Command::new("git")
+        .current_dir(&repo_dir)
+        .args(&["fetch", "--depth", "1", "origin", &sha])
+        .stdin(Stdio::null())
+        .output()
+        .unwrap();
+    if !fetch_output.status.success() {
+        eprintln!(
+            "Failed to fetch {} for {}: {}",
+            sha,
+            repo,
+            String::from_utf8_lossy(&fetch_output.stderr)
+        );
+        return;
+    }
+    let checkout_output = Command::new("git")
+        .current_dir(&repo_dir)
+        .args(&["checkout", &sha])
+        .output()
+        .unwrap();
+
+    if !checkout_output.status.success() {
+        eprintln!(
+            "Failed to checkout {} for {}: {}",
+            sha,
+            repo,
+            String::from_utf8_lossy(&checkout_output.stderr)
+        );
+    }
+}
diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs
index 1841a1f394..7ea0029d79 100644
--- a/crates/http_client/src/http_client.rs
+++ b/crates/http_client/src/http_client.rs
@@ -5,6 +5,7 @@ use derive_more::Deref;
 use futures::future::BoxFuture;
 use futures_lite::FutureExt;
 use isahc::config::{Configurable, RedirectPolicy};
+pub use isahc::http;
 pub use isahc::{
     http::{Method, StatusCode, Uri},
     AsyncBody, Error, HttpClient as IsahcHttpClient, Request, Response,
@@ -226,7 +227,7 @@ pub fn client(user_agent: Option<String>, proxy: Option<Uri>) -> Arc<dyn HttpClient> {
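The `persist_embeddings` rewrite below hinges on `futures::select_biased!`: branches are polled in declaration order, so pending deletions win over embedding writes, and `complete` fires once every stream has terminated. A standalone sketch of the pattern (illustrative only, using `futures` channels rather than the crate's own types):

```rust
use futures::{channel::mpsc, executor::block_on, StreamExt};

fn main() {
    let (tx_a, mut rx_a) = mpsc::unbounded::<u32>();
    let (tx_b, mut rx_b) = mpsc::unbounded::<u32>();
    tx_a.unbounded_send(1).unwrap();
    tx_b.unbounded_send(2).unwrap();
    drop((tx_a, tx_b)); // close both channels so `complete` can fire

    block_on(async {
        loop {
            // Branches are polled top-to-bottom, so channel A is drained in
            // preference to channel B.
            futures::select_biased! {
                a = rx_a.next() => if let Some(a) = a { println!("a: {a}") },
                b = rx_b.next() => if let Some(b) = b { println!("b: {b}") },
                complete => break, // all streams terminated
            }
        }
    });
}
```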
diff --git a/crates/semantic_index/src/embedding_index.rs b/crates/semantic_index/src/embedding_index.rs
--- a/crates/semantic_index/src/embedding_index.rs
+++ b/crates/semantic_index/src/embedding_index.rs
@@ ... @@ impl EmbeddingIndex {
-                        match fs.load(&entry_abs_path).await {
-                            Ok(text) => {
-                                let language = language_registry
-                                    .language_for_file_path(&entry.path)
-                                    .await
-                                    .ok();
-                                let chunked_file = ChunkedFile {
-                                    chunks: chunking::chunk_text(
-                                        &text,
-                                        language.as_ref(),
-                                        &entry.path,
-                                    ),
-                                    handle,
-                                    path: entry.path,
-                                    mtime: entry.mtime,
-                                    text,
-                                };
+                        if let Some(text) = fs.load(&entry_abs_path).await.ok() {
+                            let language = language_registry
+                                .language_for_file_path(&entry.path)
+                                .await
+                                .ok();
+                            let chunked_file = ChunkedFile {
+                                chunks: chunking::chunk_text(
+                                    &text,
+                                    language.as_ref(),
+                                    &entry.path,
+                                ),
+                                handle,
+                                path: entry.path,
+                                mtime: entry.mtime,
+                                text,
+                            };
-                                if chunked_files_tx.send(chunked_file).await.is_err() {
-                                    return;
-                                }
-                            }
-                            Err(_) => {
-                                log::error!("Failed to read contents into a UTF-8 string: {entry_abs_path:?}");
+                            if chunked_files_tx.send(chunked_file).await.is_err() {
+                                return;
                             }
                         }
                     }
@@ -358,33 +353,37 @@ impl EmbeddingIndex {
     fn persist_embeddings(
         &self,
         mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
-        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
+        mut embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
         cx: &AppContext,
     ) -> Task<Result<()>> {
         let db_connection = self.db_connection.clone();
         let db = self.db;
+
         cx.background_executor().spawn(async move {
-            while let Some(deletion_range) = deleted_entry_ranges.next().await {
-                let mut txn = db_connection.write_txn()?;
-                let start = deletion_range.0.as_ref().map(|start| start.as_str());
-                let end = deletion_range.1.as_ref().map(|end| end.as_str());
-                log::debug!("deleting embeddings in range {:?}", &(start, end));
-                db.delete_range(&mut txn, &(start, end))?;
-                txn.commit()?;
-            }
-
-            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
-            while let Some(embedded_files) = embedded_files.next().await {
-                let mut txn = db_connection.write_txn()?;
-                for (file, _) in &embedded_files {
-                    log::debug!("saving embedding for file {:?}", file.path);
-                    let key = db_key_for_path(&file.path);
-                    db.put(&mut txn, &key, file)?;
+            loop {
+                // Interleave deletions and persists of embedded files
+                futures::select_biased! {
+                    deletion_range = deleted_entry_ranges.next() => {
+                        if let Some(deletion_range) = deletion_range {
+                            let mut txn = db_connection.write_txn()?;
+                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
+                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
+                            log::debug!("deleting embeddings in range {:?}", &(start, end));
+                            db.delete_range(&mut txn, &(start, end))?;
+                            txn.commit()?;
+                        }
+                    },
+                    file = embedded_files.next() => {
+                        if let Some((file, _)) = file {
+                            let mut txn = db_connection.write_txn()?;
+                            log::debug!("saving embedding for file {:?}", file.path);
+                            let key = db_key_for_path(&file.path);
+                            db.put(&mut txn, &key, &file)?;
+                            txn.commit()?;
+                        }
+                    },
+                    complete => break,
                 }
-                txn.commit()?;
-
-                drop(embedded_files);
-                log::debug!("committed");
             }
 
             Ok(())
diff --git a/crates/semantic_index/src/project_index.rs b/crates/semantic_index/src/project_index.rs
index 84a72c1a3d..5c35c93fa9 100644
--- a/crates/semantic_index/src/project_index.rs
+++ b/crates/semantic_index/src/project_index.rs
@@ -15,7 +15,14 @@ use log;
 use project::{Project, Worktree, WorktreeId};
 use serde::{Deserialize, Serialize};
 use smol::channel;
-use std::{cmp::Ordering, future::Future, num::NonZeroUsize, ops::Range, path::Path, sync::Arc};
+use std::{
+    cmp::Ordering,
+    future::Future,
+    num::NonZeroUsize,
+    ops::{Range, RangeInclusive},
+    path::{Path, PathBuf},
+    sync::Arc,
+};
 use util::ResultExt;
 
 #[derive(Debug)]
@@ -26,6 +33,14 @@ pub struct SearchResult {
     pub score: f32,
 }
 
+pub struct LoadedSearchResult {
+    pub path: Arc<Path>,
+    pub range: Range<usize>,
+    pub full_path: PathBuf,
+    pub file_content: String,
+    pub row_range: RangeInclusive<u32>,
+}
+
 pub struct WorktreeSearchResult {
     pub worktree_id: WorktreeId,
     pub path: Arc<Path>,
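`SemanticDb::load_results` (next file) converts each result's byte range into 0-based rows by counting newlines before the clamped offsets, then snaps the returned byte range out to whole lines. A hedged sketch of just the row arithmetic (hypothetical helper, not from the patch):

```rust
fn rows_for_byte_range(text: &str, start: usize, end: usize) -> std::ops::RangeInclusive<u32> {
    // Clamp, as load_results does, in case the indexed range is stale.
    let start = start.min(text.len());
    let end = end.min(text.len());
    let start_row = text[..start].matches('\n').count() as u32;
    let end_row = text[..end].matches('\n').count() as u32;
    start_row..=end_row
}

fn main() {
    let text = "alpha\nbeta\ngamma\n";
    // Byte 6 is the start of "beta" (row 1, 0-based).
    assert_eq!(rows_for_byte_range(text, 6, 10), 1..=1);
}
```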
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index f2b325ead6..3435d0a9ca 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -10,14 +10,16 @@ mod worktree_index;
 
 use anyhow::{Context as _, Result};
 use collections::HashMap;
+use fs::Fs;
 use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
 use project::Project;
-use project_index::ProjectIndex;
 use std::{path::PathBuf, sync::Arc};
 use ui::ViewContext;
+use util::ResultExt as _;
 use workspace::Workspace;
 
 pub use embedding::*;
+pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
 pub use project_index_debug_view::ProjectIndexDebugView;
 pub use summary_index::FileSummary;
@@ -56,27 +58,7 @@ impl SemanticDb {
 
         if cx.has_global::<SemanticDb>() {
             cx.update_global::<SemanticDb, _>(|this, cx| {
-                let project_index = cx.new_model(|cx| {
-                    ProjectIndex::new(
-                        project.clone(),
-                        this.db_connection.clone(),
-                        this.embedding_provider.clone(),
-                        cx,
-                    )
-                });
-
-                let project_weak = project.downgrade();
-                this.project_indices
-                    .insert(project_weak.clone(), project_index);
-
-                cx.on_release(move |_, _, cx| {
-                    if cx.has_global::<SemanticDb>() {
-                        cx.update_global::<SemanticDb, _>(|this, _| {
-                            this.project_indices.remove(&project_weak);
-                        })
-                    }
-                })
-                .detach();
+                this.create_project_index(project, cx);
             })
         } else {
             log::info!("No SemanticDb, skipping project index")
@@ -94,6 +76,50 @@ impl SemanticDb {
         })
     }
 
+    pub async fn load_results(
+        results: Vec<SearchResult>,
+        fs: &Arc<dyn Fs>,
+        cx: &AsyncAppContext,
+    ) -> Result<Vec<LoadedSearchResult>> {
+        let mut loaded_results = Vec::new();
+        for result in results {
+            let (full_path, file_content) = result.worktree.read_with(cx, |worktree, _cx| {
+                let entry_abs_path = worktree.abs_path().join(&result.path);
+                let mut entry_full_path = PathBuf::from(worktree.root_name());
+                entry_full_path.push(&result.path);
+                let file_content = async {
+                    let entry_abs_path = entry_abs_path;
+                    fs.load(&entry_abs_path).await
+                };
+                (entry_full_path, file_content)
+            })?;
+            if let Some(file_content) = file_content.await.log_err() {
+                let range_start = result.range.start.min(file_content.len());
+                let range_end = result.range.end.min(file_content.len());
+
+                let start_row = file_content[0..range_start].matches('\n').count() as u32;
+                let end_row = file_content[0..range_end].matches('\n').count() as u32;
+                let start_line_byte_offset = file_content[0..range_start]
+                    .rfind('\n')
+                    .map(|pos| pos + 1)
+                    .unwrap_or_default();
+                let end_line_byte_offset = file_content[range_end..]
+                    .find('\n')
+                    .map(|pos| range_end + pos)
+                    .unwrap_or_else(|| file_content.len());
+
+                loaded_results.push(LoadedSearchResult {
+                    path: result.path,
+                    range: start_line_byte_offset..end_line_byte_offset,
+                    full_path,
+                    file_content,
+                    row_range: start_row..=end_row,
+                });
+            }
+        }
+        Ok(loaded_results)
+    }
+
     pub fn project_index(
         &mut self,
         project: Model<Project>,
@@ -113,6 +139,36 @@ impl SemanticDb {
             })
         })
     }
+
+    pub fn create_project_index(
+        &mut self,
+        project: Model<Project>,
+        cx: &mut AppContext,
+    ) -> Model<ProjectIndex> {
+        let project_index = cx.new_model(|cx| {
+            ProjectIndex::new(
+                project.clone(),
+                self.db_connection.clone(),
+                self.embedding_provider.clone(),
+                cx,
+            )
+        });
+
+        let project_weak = project.downgrade();
+        self.project_indices
+            .insert(project_weak.clone(), project_index.clone());
+
+        cx.observe_release(&project, move |_, cx| {
+            if cx.has_global::<SemanticDb>() {
+                cx.update_global::<SemanticDb, _>(|this, _| {
+                    this.project_indices.remove(&project_weak);
+                })
+            }
+        })
+        .detach();
+
+        project_index
+    }
 }
 
 #[cfg(test)]
@@ -230,34 +286,13 @@ mod tests {
 
         let project = Project::test(fs, [project_path], cx).await;
 
-        cx.update(|cx| {
+        let project_index = cx.update(|cx| {
             let language_registry = project.read(cx).languages().clone();
             let node_runtime = project.read(cx).node_runtime().unwrap().clone();
             languages::init(language_registry, node_runtime, cx);
-
-            // Manually create and insert the ProjectIndex
-            let project_index = cx.new_model(|cx| {
-                ProjectIndex::new(
-                    project.clone(),
-                    semantic_index.db_connection.clone(),
-                    semantic_index.embedding_provider.clone(),
-                    cx,
-                )
-            });
-            semantic_index
-                .project_indices
-                .insert(project.downgrade(), project_index);
+            semantic_index.create_project_index(project.clone(), cx)
         });
 
-        let project_index = cx
-            .update(|_cx| {
-                semantic_index
-                    .project_indices
-                    .get(&project.downgrade())
-                    .cloned()
-            })
-            .unwrap();
-
         cx.run_until_parked();
         while cx
             .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))