From 91ffa02e2c7ee30b9a172ce5944ad96a747a453e Mon Sep 17 00:00:00 2001
From: Richard Feldman
Date: Fri, 13 Sep 2024 13:17:49 -0400
Subject: [PATCH] /auto (#16696)

Add `/auto` behind a feature flag that's disabled for now, even for staff.

We've decided on a different design for context inference, but there are
parts of /auto that will be useful for that, so we want them in the code
base even if they're unused for now.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra
Co-authored-by: Marshall Bowers
---
 Cargo.lock                                    |   27 +
 Cargo.toml                                    |    2 +
 crates/assistant/src/assistant.rs             |   23 +-
 crates/assistant/src/assistant_panel.rs       |   14 +
 crates/assistant/src/assistant_settings.rs    |    1 +
 crates/assistant/src/slash_command.rs         |    1 +
 .../src/slash_command/auto_command.rs         |  360 ++++++
 .../slash_command/prompt_after_summary.txt    |   24 +
 .../slash_command/prompt_before_summary.txt   |   31 +
 .../src/slash_command/search_command.rs       |    9 +-
 crates/collab/k8s/collab.template.yml         |   10 +-
 crates/collab/src/db/queries/projects.rs      |    5 +
 crates/collab/src/db/queries/rooms.rs         |    5 +
 crates/collab/src/lib.rs                      |    8 +-
 crates/collab/src/llm.rs                      |    4 +-
 crates/collab/src/llm/db/queries/providers.rs |   13 +-
 crates/collab/src/llm/db/seed.rs              |    9 +
 crates/collab/src/tests/test_server.rs        |    4 +-
 crates/feature_flags/Cargo.toml               |    1 +
 crates/feature_flags/src/feature_flags.rs     |   58 +-
 crates/fs/src/fs.rs                           |   20 +-
 crates/git/src/status.rs                      |    1 -
 crates/http_client/src/http_client.rs         |    4 +
 .../language_model/src/model/cloud_model.rs   |    4 +-
 .../language_model/src/provider/anthropic.rs  |    2 +-
 crates/language_model/src/provider/google.rs  |    4 +-
 crates/language_model/src/provider/open_ai.rs |    2 +-
 crates/language_model/src/registry.rs         |    6 +-
 crates/project_panel/src/project_panel.rs     |    1 +
 crates/proto/proto/zed.proto                  |    1 +
 crates/semantic_index/Cargo.toml              |    4 +
 crates/semantic_index/examples/index.rs       |    5 +-
 crates/semantic_index/src/embedding.rs        |   12 +-
 crates/semantic_index/src/embedding_index.rs  |  469 +++++++
 crates/semantic_index/src/indexing.rs         |   49 +
 crates/semantic_index/src/project_index.rs    |  523 ++++++++
 .../src/project_index_debug_view.rs           |   16 +-
 crates/semantic_index/src/semantic_index.rs   | 1135 ++---------------
 crates/semantic_index/src/summary_backlog.rs  |   48 +
 crates/semantic_index/src/summary_index.rs    |  693 ++++++++++
 crates/semantic_index/src/worktree_index.rs   |  217 ++++
 crates/worktree/src/worktree.rs               |    5 +
 42 files changed, 2776 insertions(+), 1054 deletions(-)
 create mode 100644 crates/assistant/src/slash_command/auto_command.rs
 create mode 100644 crates/assistant/src/slash_command/prompt_after_summary.txt
 create mode 100644 crates/assistant/src/slash_command/prompt_before_summary.txt
 create mode 100644 crates/semantic_index/src/embedding_index.rs
 create mode 100644 crates/semantic_index/src/indexing.rs
 create mode 100644 crates/semantic_index/src/project_index.rs
 create mode 100644 crates/semantic_index/src/summary_backlog.rs
 create mode 100644 crates/semantic_index/src/summary_index.rs
 create mode 100644 crates/semantic_index/src/worktree_index.rs

diff --git a/Cargo.lock b/Cargo.lock
index 5eaf3ddde1..793cb66ad7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -304,6 +304,9 @@ name = "arrayvec"
 version = "0.7.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
+dependencies = [
+ "serde",
+]
 
 [[package]]
 name = "as-raw-xcb-connection"
@@ -1709,6 +1712,19 @@ dependencies = [
  "profiling",
 ]
 
+[[package]]
+name = "blake3"
"1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + [[package]] name = "block" version = "0.1.6" @@ -2752,6 +2768,12 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "context_servers" version = "0.1.0" @@ -4187,6 +4209,7 @@ dependencies = [ name = "feature_flags" version = "0.1.0" dependencies = [ + "futures 0.3.30", "gpui", ] @@ -9814,10 +9837,13 @@ name = "semantic_index" version = "0.1.0" dependencies = [ "anyhow", + "arrayvec", + "blake3", "client", "clock", "collections", "env_logger", + "feature_flags", "fs", "futures 0.3.30", "futures-batch", @@ -9825,6 +9851,7 @@ dependencies = [ "heed", "http_client", "language", + "language_model", "languages", "log", "open_ai", diff --git a/Cargo.toml b/Cargo.toml index 79f5ce2dcf..53109002fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -309,6 +309,7 @@ aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/alacritty/alacritty", rev = "91d034ff8b53867143c005acfaa14609147c9a2c" } any_vec = "0.14" anyhow = "1.0.86" +arrayvec = { version = "0.7.4", features = ["serde"] } ashpd = "0.9.1" async-compression = { version = "0.4", features = ["gzip", "futures-io"] } async-dispatcher = "0.1" @@ -325,6 +326,7 @@ bitflags = "2.6.0" blade-graphics = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" } blade-macros = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" } blade-util = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" } +blake3 = "1.5.3" cargo_metadata = "0.18" cargo_toml = "0.20" chrono = { version = "0.4", features = ["serde"] } diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 70e37ba239..7a73c188ec 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -37,13 +37,13 @@ use language_model::{ pub(crate) use model_selector::*; pub use prompts::PromptBuilder; use prompts::PromptLoadingParams; -use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; +use semantic_index::{CloudEmbeddingProvider, SemanticDb}; use serde::{Deserialize, Serialize}; use settings::{update_settings_file, Settings, SettingsStore}; use slash_command::{ - context_server_command, default_command, diagnostics_command, docs_command, fetch_command, - file_command, now_command, project_command, prompt_command, search_command, symbols_command, - tab_command, terminal_command, workflow_command, + auto_command, context_server_command, default_command, diagnostics_command, docs_command, + fetch_command, file_command, now_command, project_command, prompt_command, search_command, + symbols_command, tab_command, terminal_command, workflow_command, }; use std::path::PathBuf; use std::sync::Arc; @@ -210,12 +210,13 @@ pub fn init( let client = client.clone(); async move { let embedding_provider = CloudEmbeddingProvider::new(client.clone()); - let semantic_index = SemanticIndex::new( + let semantic_index = SemanticDb::new( paths::embeddings_dir().join("semantic-index-db.0.mdb"), Arc::new(embedding_provider), &mut cx, ) .await?; + cx.update(|cx| cx.set_global(semantic_index)) } }) 
@@ -364,6 +365,7 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) {
 
 fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
     let slash_command_registry = SlashCommandRegistry::global(cx);
+
     slash_command_registry.register_command(file_command::FileSlashCommand, true);
     slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
     slash_command_registry.register_command(tab_command::TabSlashCommand, true);
@@ -382,6 +384,17 @@ fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut
     }
     slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
 
+    cx.observe_flag::<feature_flags::AutoCommand, _>({
+        let slash_command_registry = slash_command_registry.clone();
+        move |is_enabled, _cx| {
+            if is_enabled {
+                // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
+                slash_command_registry.register_command(auto_command::AutoCommand, true);
+            }
+        }
+    })
+    .detach();
+
     update_slash_commands_from_settings(cx);
     cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
         .detach();
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 634f2231cd..51c9aa9b4e 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -4723,6 +4723,20 @@ impl Render for ContextEditorToolbarItem {
         let weak_self = cx.view().downgrade();
         let right_side = h_flex()
             .gap_2()
+            // TODO display this in a nicer way, once we have a design for it.
+            // .children({
+            //     let project = self
+            //         .workspace
+            //         .upgrade()
+            //         .map(|workspace| workspace.read(cx).project().downgrade());
+            //
+            //     let scan_items_remaining = cx.update_global(|db: &mut SemanticDb, cx| {
+            //         project.and_then(|project| db.remaining_summaries(&project, cx))
+            //     });
+
+            //     scan_items_remaining
+            //         .map(|remaining_items| format!("Files to scan: {}", remaining_items))
+            // })
             .child(
                 ModelSelector::new(
                     self.fs.clone(),
diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index 3e326886d5..7939eacd93 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -519,6 +519,7 @@ impl Settings for AssistantSettings {
                 &mut settings.default_model,
                 value.default_model.map(Into::into),
             );
+            // merge(&mut settings.infer_context, value.infer_context); TODO re-enable this once we ship context inference
         }
 
         Ok(settings)
diff --git a/crates/assistant/src/slash_command.rs b/crates/assistant/src/slash_command.rs
index b1a97688b2..387e8231e4 100644
--- a/crates/assistant/src/slash_command.rs
+++ b/crates/assistant/src/slash_command.rs
@@ -19,6 +19,7 @@ use std::{
 use ui::ActiveTheme;
 use workspace::Workspace;
 
+pub mod auto_command;
 pub mod context_server_command;
 pub mod default_command;
 pub mod diagnostics_command;
diff --git a/crates/assistant/src/slash_command/auto_command.rs b/crates/assistant/src/slash_command/auto_command.rs
new file mode 100644
index 0000000000..cedfc63702
--- /dev/null
+++ b/crates/assistant/src/slash_command/auto_command.rs
@@ -0,0 +1,360 @@
+use super::create_label_for_command;
+use super::{SlashCommand, SlashCommandOutput};
+use anyhow::{anyhow, Result};
+use assistant_slash_command::ArgumentCompletion;
+use feature_flags::FeatureFlag;
+use futures::StreamExt;
+use gpui::{AppContext, AsyncAppContext, Task, WeakView};
+use language::{CodeLabel, LspAdapterDelegate};
+use language_model::{
+    LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
+    LanguageModelRequestMessage, Role,
+};
+use semantic_index::{FileSummary, SemanticDb};
+use smol::channel;
+use std::sync::{atomic::AtomicBool, Arc};
+use ui::{BorrowAppContext, WindowContext};
+use util::ResultExt;
+use workspace::Workspace;
+
+pub struct AutoSlashCommandFeatureFlag;
+
+impl FeatureFlag for AutoSlashCommandFeatureFlag {
+    const NAME: &'static str = "auto-slash-command";
+}
+
+pub(crate) struct AutoCommand;
+
+impl SlashCommand for AutoCommand {
+    fn name(&self) -> String {
+        "auto".into()
+    }
+
+    fn description(&self) -> String {
+        "Automatically infer what context to add, based on your prompt".into()
+    }
+
+    fn menu_text(&self) -> String {
+        "Automatically Infer Context".into()
+    }
+
+    fn label(&self, cx: &AppContext) -> CodeLabel {
+        create_label_for_command("auto", &["--prompt"], cx)
+    }
+
+    fn complete_argument(
+        self: Arc<Self>,
+        _arguments: &[String],
+        _cancel: Arc<AtomicBool>,
+        workspace: Option<WeakView<Workspace>>,
+        cx: &mut WindowContext,
+    ) -> Task<Result<Vec<ArgumentCompletion>>> {
+        // There's no autocomplete for a prompt, since it's arbitrary text.
+        // However, we can use this opportunity to kick off a drain of the backlog.
+        // That way, it can hopefully be done resummarizing by the time we've actually
+        // typed out our prompt. This re-runs on every keystroke during autocomplete,
+        // but in the future, we could instead do it only once, when /auto is first entered.
+        let Some(workspace) = workspace.and_then(|ws| ws.upgrade()) else {
+            log::warn!("workspace was dropped or unavailable during /auto autocomplete");
+
+            return Task::ready(Ok(Vec::new()));
+        };
+
+        let project = workspace.read(cx).project().clone();
+        let Some(project_index) =
+            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
+        else {
+            return Task::ready(Err(anyhow!("No project indexer, cannot use /auto")));
+        };
+
+        let cx: &mut AppContext = cx;
+
+        cx.spawn(|cx: gpui::AsyncAppContext| async move {
+            let task = project_index.read_with(&cx, |project_index, cx| {
+                project_index.flush_summary_backlogs(cx)
+            })?;
+
+            cx.background_executor().spawn(task).await;
+
+            anyhow::Ok(Vec::new())
+        })
+    }
+
+    fn requires_argument(&self) -> bool {
+        true
+    }
+
+    fn run(
+        self: Arc<Self>,
+        arguments: &[String],
+        workspace: WeakView<Workspace>,
+        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
+        cx: &mut WindowContext,
+    ) -> Task<Result<SlashCommandOutput>> {
+        let Some(workspace) = workspace.upgrade() else {
+            return Task::ready(Err(anyhow::anyhow!("workspace was dropped")));
+        };
+        if arguments.is_empty() {
+            return Task::ready(Err(anyhow!("missing prompt")));
+        };
+        let argument = arguments.join(" ");
+        let original_prompt = argument.to_string();
+        let project = workspace.read(cx).project().clone();
+        let Some(project_index) =
+            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
+        else {
+            return Task::ready(Err(anyhow!("no project indexer")));
+        };
+
+        let task = cx.spawn(|cx: gpui::AsyncWindowContext| async move {
+            let summaries = project_index
+                .read_with(&cx, |project_index, cx| project_index.all_summaries(cx))?
+                .await?;
+
+            commands_for_summaries(&summaries, &original_prompt, &cx).await
+        });
+
+        // As a convenience, append /auto's argument to the end of the prompt
+        // so you don't have to write it again.
+        let original_prompt = argument.to_string();
+
+        cx.background_executor().spawn(async move {
+            let commands = task.await?;
+            let mut prompt = String::new();
+
+            log::info!(
+                "Translating this response into slash-commands: {:?}",
+                commands
+            );
+
+            for command in commands {
+                prompt.push('/');
+                prompt.push_str(&command.name);
+                prompt.push(' ');
+                prompt.push_str(&command.arg);
+                prompt.push('\n');
+            }
+
+            prompt.push('\n');
+            prompt.push_str(&original_prompt);
+
+            Ok(SlashCommandOutput {
+                text: prompt,
+                sections: Vec::new(),
+                run_commands_in_text: true,
+            })
+        })
+    }
+}
+
+const PROMPT_INSTRUCTIONS_BEFORE_SUMMARY: &str = include_str!("prompt_before_summary.txt");
+const PROMPT_INSTRUCTIONS_AFTER_SUMMARY: &str = include_str!("prompt_after_summary.txt");
+
+fn summaries_prompt(summaries: &[FileSummary], original_prompt: &str) -> String {
+    let json_summaries = serde_json::to_string(summaries).unwrap();
+
+    format!("{PROMPT_INSTRUCTIONS_BEFORE_SUMMARY}\n{json_summaries}\n{PROMPT_INSTRUCTIONS_AFTER_SUMMARY}\n{original_prompt}")
+}
+
+/// The slash commands that the model is told about, and which we look for in the inference response.
+const SUPPORTED_SLASH_COMMANDS: &[&str] = &["search", "file"];
+
+#[derive(Debug, Clone)]
+struct CommandToRun {
+    name: String,
+    arg: String,
+}
+
+/// Given the pre-indexed file summaries for this project, as well as the original prompt
+/// string passed to `/auto`, get a list of slash commands to run, along with their arguments.
+///
+/// The prompt's output does not include the slashes (to reduce the chance that it makes a mistake),
+/// so taking one of these returned Strings and turning it into a real slash-command-with-argument
+/// involves prepending a slash to it.
+///
+/// This function will validate that each of the returned lines begins with one of SUPPORTED_SLASH_COMMANDS.
+/// Any other lines it encounters will be discarded, with a warning logged.
+async fn commands_for_summaries(
+    summaries: &[FileSummary],
+    original_prompt: &str,
+    cx: &AsyncAppContext,
+) -> Result<Vec<CommandToRun>> {
+    if summaries.is_empty() {
+        log::warn!("Inferring no context because there were no summaries available.");
+        return Ok(Vec::new());
+    }
+
+    // Use the globally configured model to translate the summaries into slash-commands,
+    // because Qwen2-7B-Instruct has not done a good job at that task.
+    let Some(model) = cx.update(|cx| LanguageModelRegistry::read_global(cx).active_model())? else {
+        log::warn!("Can't infer context because there's no active model.");
+        return Ok(Vec::new());
+    };
+    // Only go up to 90% of the actual max token count, to reduce chances of
+    // exceeding the token count due to inaccuracies in the token counting heuristic.
+    let max_token_count = (model.max_token_count() * 9) / 10;
+
+    // Rather than recursing (which would require this async function use a pinned box),
+    // we use an explicit stack of arguments and answers for when we need to "recurse."
+    let mut stack = vec![summaries];
+    let mut final_response = Vec::new();
+    let mut prompts = Vec::new();
+
+    // TODO We only need to create multiple Requests because we currently
+    // don't have the ability to tell if a CompletionProvider::complete response
+    // was a "too many tokens in this request" error. If we had that, then
+    // we could try the request once, instead of having to make separate requests
+    // to check the token count and then afterwards to run the actual prompt.
+    let make_request = |prompt: String| LanguageModelRequest {
+        messages: vec![LanguageModelRequestMessage {
+            role: Role::User,
+            content: vec![prompt.into()],
+            // Nothing in here will benefit from caching
+            cache: false,
+        }],
+        tools: Vec::new(),
+        stop: Vec::new(),
+        temperature: 1.0,
+    };
+
+    while let Some(current_summaries) = stack.pop() {
+        // The split can result in one slice being empty and the other having one element.
+        // Whenever that happens, skip the empty one.
+        if current_summaries.is_empty() {
+            continue;
+        }
+
+        log::info!(
+            "Inferring prompt context using {} file summaries",
+            current_summaries.len()
+        );
+
+        let prompt = summaries_prompt(&current_summaries, original_prompt);
+        let start = std::time::Instant::now();
+        // Per OpenAI, 1 token ~= 4 chars in English (we go with 4.5 to overestimate a bit, because failed API requests cost a lot of perf)
+        // Verifying this against an actual model.count_tokens() confirms that it's usually within ~5% of the correct answer, whereas
+        // getting the correct answer from tiktoken takes hundreds of milliseconds (compared to this arithmetic being ~free).
+        // source: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
+        let token_estimate = prompt.len() * 2 / 9;
+        let duration = start.elapsed();
+        log::info!(
+            "Time taken to count tokens for prompt of length {:?}B: {:?}",
+            prompt.len(),
+            duration
+        );
+
+        if token_estimate < max_token_count {
+            prompts.push(prompt);
+        } else if current_summaries.len() == 1 {
+            log::warn!("Inferring context for a single file's summary failed because the prompt's token length exceeded the model's token limit.");
+        } else {
+            log::info!(
+                "Context inference using file summaries resulted in a prompt containing {token_estimate} tokens, which exceeded the model's max of {max_token_count}. Retrying as two separate prompts, each including half the number of summaries.",
+            );
+            let (left, right) = current_summaries.split_at(current_summaries.len() / 2);
+            stack.push(right);
+            stack.push(left);
+        }
+    }
+
+    let all_start = std::time::Instant::now();
+
+    let (tx, rx) = channel::bounded(1024);
+
+    let completion_streams = prompts
+        .into_iter()
+        .map(|prompt| {
+            let request = make_request(prompt.clone());
+            let model = model.clone();
+            let tx = tx.clone();
+            let stream = model.stream_completion(request, &cx);
+
+            (stream, tx)
+        })
+        .collect::<Vec<_>>();
+
+    cx.background_executor()
+        .spawn(async move {
+            let futures = completion_streams
+                .into_iter()
+                .enumerate()
+                .map(|(ix, (stream, tx))| async move {
+                    let start = std::time::Instant::now();
+                    let events = stream.await?;
+                    log::info!("Time taken for awaiting /auto chunk stream #{ix}: {:?}", start.elapsed());
+
+                    let completion: String = events
+                        .filter_map(|event| async {
+                            if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
+                                Some(text)
+                            } else {
+                                None
+                            }
+                        })
+                        .collect()
+                        .await;
+
+                    log::info!("Time taken for all /auto chunks to come back for #{ix}: {:?}", start.elapsed());
+
+                    for line in completion.split('\n') {
+                        if let Some(first_space) = line.find(' ') {
+                            let command = &line[..first_space].trim();
+                            let arg = &line[first_space..].trim();
+
+                            tx.send(CommandToRun {
+                                name: command.to_string(),
+                                arg: arg.to_string(),
+                            })
+                            .await?;
+                        } else if !line.trim().is_empty() {
+                            // All slash-commands currently supported in context inference need a space for the argument.
+                            log::warn!(
+                                "Context inference returned a non-blank line that contained no spaces (meaning no argument for the slash command): {:?}",
+                                line
+                            );
+                        }
+                    }
+
+                    anyhow::Ok(())
+                })
+                .collect::<Vec<_>>();
+
+            let _ = futures::future::try_join_all(futures).await.log_err();
+
+            let duration = all_start.elapsed();
+            eprintln!("All futures completed in {:?}", duration);
+        })
+        .await;
+
+    drop(tx); // Close the channel so that rx.collect() won't hang. This is safe because all futures have completed.
+    let results = rx.collect::<Vec<_>>().await;
+    eprintln!(
+        "Finished collecting from the channel with {} results",
+        results.len()
+    );
+    for command in results {
+        // Don't return empty or duplicate commands
+        if !command.name.is_empty()
+            && !final_response
+                .iter()
+                .any(|cmd: &CommandToRun| cmd.name == command.name && cmd.arg == command.arg)
+        {
+            if SUPPORTED_SLASH_COMMANDS
+                .iter()
+                .any(|supported| &command.name == supported)
+            {
+                final_response.push(command);
+            } else {
+                log::warn!(
+                    "Context inference returned an unrecognized slash command: {:?}",
+                    command
+                );
+            }
+        }
+    }
+
+    // Sort the commands by name (reversed just so that /search appears before /file)
+    final_response.sort_by(|cmd1, cmd2| cmd1.name.cmp(&cmd2.name).reverse());
+
+    Ok(final_response)
+}
diff --git a/crates/assistant/src/slash_command/prompt_after_summary.txt b/crates/assistant/src/slash_command/prompt_after_summary.txt
new file mode 100644
index 0000000000..fc139a1fcb
--- /dev/null
+++ b/crates/assistant/src/slash_command/prompt_after_summary.txt
@@ -0,0 +1,24 @@
+Actions have a cost, so only include actions that you think
+will be helpful to you in doing a great job answering the
+prompt in the future.
+
+You must respond ONLY with a list of actions you would like to
+perform. Each action should be on its own line, and followed by a space and then its parameter.
+
+Actions can be performed more than once with different parameters.
+Here is an example valid response:
+
+```
+file path/to/my/file.txt
+file path/to/another/file.txt
+search something to search for
+search something else to search for
+```
+
+Once again, do not forget: you must respond ONLY in the format of
+one action per line, and the action name should be followed by
+its parameter. Your response must not include anything other
+than a list of actions, with one action per line, in this format.
+It is extremely important that you do not deviate from this format even slightly!
+
+This is the end of my instructions for how to respond. The rest is the prompt:
diff --git a/crates/assistant/src/slash_command/prompt_before_summary.txt b/crates/assistant/src/slash_command/prompt_before_summary.txt
new file mode 100644
index 0000000000..5d8db1b8f7
--- /dev/null
+++ b/crates/assistant/src/slash_command/prompt_before_summary.txt
@@ -0,0 +1,31 @@
+I'm going to give you a prompt. I don't want you to respond
+to the prompt itself. I want you to figure out which of the following
+actions on my project, if any, would help you answer the prompt.
+
+Here are the actions:
+
+## file
+
+This action's parameter is a file path to one of the files
+in the project. If you ask for this action, I will tell you
+the full contents of the file, so you can learn all the
+details of the file.
+
+## search
+
+This action's parameter is a string to do a semantic search for
+across the files in the project. (You will have a JSON summary
+of all the files in the project.)
+It will tell you which files this string
+(or similar strings; it is a semantic search) appears in,
+as well as some context of the lines surrounding each result.
+It's very important that you only use this action when you think
+that searching across the specific files in this project for the query
+in question will be useful. For example, don't use this command to search
+for queries you might put into a general Web search engine, because those
+will be too general to give useful results in this project-specific search.
+
+---
+
+That was the end of the list of actions.
+
+Here is a JSON summary of each of the files in my project:
diff --git a/crates/assistant/src/slash_command/search_command.rs b/crates/assistant/src/slash_command/search_command.rs
index 4da8a5585f..3a513ed9ad 100644
--- a/crates/assistant/src/slash_command/search_command.rs
+++ b/crates/assistant/src/slash_command/search_command.rs
@@ -8,7 +8,7 @@ use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
 use feature_flags::FeatureFlag;
 use gpui::{AppContext, Task, WeakView};
 use language::{CodeLabel, LineEnding, LspAdapterDelegate};
-use semantic_index::SemanticIndex;
+use semantic_index::SemanticDb;
 use std::{
     fmt::Write,
     path::PathBuf,
@@ -92,8 +92,11 @@ impl SlashCommand for SearchSlashCommand {
         let project = workspace.read(cx).project().clone();
         let fs = project.read(cx).fs().clone();
 
-        let project_index =
-            cx.update_global(|index: &mut SemanticIndex, cx| index.project_index(project, cx));
+        let Some(project_index) =
+            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
+        else {
+            return Task::ready(Err(anyhow::anyhow!("no project indexer")));
+        };
 
         cx.spawn(|cx| async move {
             let results = project_index
diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml
index dcd935166a..f5e454c3fc 100644
--- a/crates/collab/k8s/collab.template.yml
+++ b/crates/collab/k8s/collab.template.yml
@@ -149,16 +149,16 @@ spec:
               secretKeyRef:
                 name: google-ai
                 key: api_key
-          - name: QWEN2_7B_API_KEY
+          - name: RUNPOD_API_KEY
             valueFrom:
               secretKeyRef:
-                name: hugging-face
+                name: runpod
                 key: api_key
-          - name: QWEN2_7B_API_URL
+          - name: RUNPOD_API_SUMMARY_URL
             valueFrom:
              secretKeyRef:
-                name: hugging-face
-                key: qwen2_api_url
+                name: runpod
+                key: summary
           - name: BLOB_STORE_ACCESS_KEY
             valueFrom:
               secretKeyRef:
diff --git a/crates/collab/src/db/queries/projects.rs b/crates/collab/src/db/queries/projects.rs
index a6956c8496..c6db54b572 100644
--- a/crates/collab/src/db/queries/projects.rs
+++ b/crates/collab/src/db/queries/projects.rs
@@ -728,6 +728,11 @@ impl Database {
                     is_ignored: db_entry.is_ignored,
                     is_external: db_entry.is_external,
                     git_status: db_entry.git_status.map(|status| status as i32),
+                    // This is only used in the summarization backlog, so if it's None,
+                    // that just means we won't be able to detect when to resummarize
+                    // based on total number of backlogged bytes - instead, we'd go
+                    // on number of files only. That shouldn't be a huge deal in practice.
+                    size: None,
                     is_fifo: db_entry.is_fifo,
                 });
             }
diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs
index 1669ddbb3b..635e2d232f 100644
--- a/crates/collab/src/db/queries/rooms.rs
+++ b/crates/collab/src/db/queries/rooms.rs
@@ -663,6 +663,11 @@ impl Database {
                     is_ignored: db_entry.is_ignored,
                     is_external: db_entry.is_external,
                     git_status: db_entry.git_status.map(|status| status as i32),
+                    // This is only used in the summarization backlog, so if it's None,
+                    // that just means we won't be able to detect when to resummarize
+                    // based on total number of backlogged bytes - instead, we'd go
+                    // on number of files only. That shouldn't be a huge deal in practice.
+                    size: None,
                     is_fifo: db_entry.is_fifo,
                 });
             }
diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs
index 461adc3575..81ff3ff21f 100644
--- a/crates/collab/src/lib.rs
+++ b/crates/collab/src/lib.rs
@@ -170,8 +170,8 @@ pub struct Config {
     pub anthropic_api_key: Option<Arc<str>>,
     pub anthropic_staff_api_key: Option<Arc<str>>,
     pub llm_closed_beta_model_name: Option<Arc<str>>,
-    pub qwen2_7b_api_key: Option<Arc<str>>,
-    pub qwen2_7b_api_url: Option<Arc<str>>,
+    pub runpod_api_key: Option<Arc<str>>,
+    pub runpod_api_summary_url: Option<Arc<str>>,
     pub zed_client_checksum_seed: Option<String>,
     pub slack_panics_webhook: Option<String>,
     pub auto_join_channel_id: Option<ChannelId>,
@@ -235,8 +235,8 @@ impl Config {
             stripe_api_key: None,
             stripe_price_id: None,
             supermaven_admin_api_key: None,
-            qwen2_7b_api_key: None,
-            qwen2_7b_api_url: None,
+            runpod_api_key: None,
+            runpod_api_summary_url: None,
             user_backfiller_github_access_token: None,
         }
     }
diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs
index e1a3454368..def4499ae4 100644
--- a/crates/collab/src/llm.rs
+++ b/crates/collab/src/llm.rs
@@ -402,12 +402,12 @@ async fn perform_completion(
         LanguageModelProvider::Zed => {
             let api_key = state
                 .config
-                .qwen2_7b_api_key
+                .runpod_api_key
                 .as_ref()
                 .context("no Qwen2-7B API key configured on the server")?;
             let api_url = state
                 .config
-                .qwen2_7b_api_url
+                .runpod_api_summary_url
                 .as_ref()
                 .context("no Qwen2-7B URL configured on the server")?;
             let chunks = open_ai::stream_completion(
diff --git a/crates/collab/src/llm/db/queries/providers.rs b/crates/collab/src/llm/db/queries/providers.rs
index 8a73b399c6..7e51061cee 100644
--- a/crates/collab/src/llm/db/queries/providers.rs
+++ b/crates/collab/src/llm/db/queries/providers.rs
@@ -1,5 +1,5 @@
 use super::*;
-use sea_orm::QueryOrder;
+use sea_orm::{sea_query::OnConflict, QueryOrder};
 use std::str::FromStr;
 use strum::IntoEnumIterator as _;
 
@@ -99,6 +99,17 @@ impl LlmDatabase {
                     ..Default::default()
                 }
             }))
+            .on_conflict(
+                OnConflict::columns([model::Column::ProviderId, model::Column::Name])
+                    .update_columns([
+                        model::Column::MaxRequestsPerMinute,
+                        model::Column::MaxTokensPerMinute,
+                        model::Column::MaxTokensPerDay,
+                        model::Column::PricePerMillionInputTokens,
+                        model::Column::PricePerMillionOutputTokens,
+                    ])
+                    .to_owned(),
+            )
             .exec_without_returning(&*tx)
             .await?;
         Ok(())
diff --git a/crates/collab/src/llm/db/seed.rs b/crates/collab/src/llm/db/seed.rs
index 55c6c30cd5..24bc224227 100644
--- a/crates/collab/src/llm/db/seed.rs
+++ b/crates/collab/src/llm/db/seed.rs
@@ -40,6 +40,15 @@ pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool)
             price_per_million_input_tokens: 25, // $0.25/MTok
             price_per_million_output_tokens: 125, // $1.25/MTok
         },
+        ModelParams {
+            provider: LanguageModelProvider::Zed,
+            name: "Qwen/Qwen2-7B-Instruct".into(),
+            max_requests_per_minute: 5,
+            max_tokens_per_minute: 25_000, // These are arbitrary limits we've set to cap costs; we control this number
+            max_tokens_per_day: 300_000,
+            price_per_million_input_tokens: 25,
+            price_per_million_output_tokens: 125,
+        },
     ])
     .await
 }
diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs
index e691afceda..1421e4c7f7 100644
--- a/crates/collab/src/tests/test_server.rs
+++ b/crates/collab/src/tests/test_server.rs
@@ -679,8 +679,8 @@ impl TestServer {
                 stripe_api_key: None,
                 stripe_price_id: None,
                 supermaven_admin_api_key: None,
-                qwen2_7b_api_key: None,
-                qwen2_7b_api_url: None,
+                runpod_api_key: None,
+                runpod_api_summary_url: None,
                 user_backfiller_github_access_token: None,
             },
         })
diff --git a/crates/feature_flags/Cargo.toml b/crates/feature_flags/Cargo.toml
index 101e90c646..834e315af3 100644
--- a/crates/feature_flags/Cargo.toml
+++ b/crates/feature_flags/Cargo.toml
@@ -13,3 +13,4 @@ path = "src/feature_flags.rs"
 
 [dependencies]
 gpui.workspace = true
+futures.workspace = true
diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs
index 29768138af..fb4e192023 100644
--- a/crates/feature_flags/src/feature_flags.rs
+++ b/crates/feature_flags/src/feature_flags.rs
@@ -1,4 +1,10 @@
+use futures::{channel::oneshot, FutureExt as _};
 use gpui::{AppContext, Global, Subscription, ViewContext};
+use std::{
+    future::Future,
+    pin::Pin,
+    task::{Context, Poll},
+};
 
 #[derive(Default)]
 struct FeatureFlags {
@@ -53,6 +59,15 @@ impl FeatureFlag for ZedPro {
     const NAME: &'static str = "zed-pro";
 }
 
+pub struct AutoCommand {}
+impl FeatureFlag for AutoCommand {
+    const NAME: &'static str = "auto-command";
+
+    fn enabled_for_staff() -> bool {
+        false
+    }
+}
+
 pub trait FeatureFlagViewExt<V: 'static> {
     fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
     where
@@ -75,6 +90,7 @@
 }
 
 pub trait FeatureFlagAppExt {
+    fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag;
     fn update_flags(&mut self, staff: bool, flags: Vec<String>);
     fn set_staff(&mut self, staff: bool);
     fn has_flag<T: FeatureFlag>(&self) -> bool;
@@ -82,7 +98,7 @@ pub trait FeatureFlagAppExt {
 
     fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
     where
-        F: Fn(bool, &mut AppContext) + 'static;
+        F: FnMut(bool, &mut AppContext) + 'static;
 }
 
 impl FeatureFlagAppExt for AppContext {
@@ -109,13 +125,49 @@ impl FeatureFlagAppExt for AppContext {
             .unwrap_or(false)
     }
 
-    fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
+    fn observe_flag<T: FeatureFlag, F>(&mut self, mut callback: F) -> Subscription
     where
-        F: Fn(bool, &mut AppContext) + 'static,
+        F: FnMut(bool, &mut AppContext) + 'static,
     {
         self.observe_global::<FeatureFlags>(move |cx| {
             let feature_flags = cx.global::<FeatureFlags>();
             callback(feature_flags.has_flag::<T>(), cx);
         })
     }
+
+    fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag {
+        let (tx, rx) = oneshot::channel::<bool>();
+        let mut tx = Some(tx);
+        let subscription: Option<Subscription>;
+
+        match self.try_global::<FeatureFlags>() {
+            Some(feature_flags) => {
+                subscription = None;
+                tx.take().unwrap().send(feature_flags.has_flag::<T>()).ok();
+            }
+            None => {
+                subscription = Some(self.observe_global::<FeatureFlags>(move |cx| {
+                    let feature_flags = cx.global::<FeatureFlags>();
+                    if let Some(tx) = tx.take() {
+                        tx.send(feature_flags.has_flag::<T>()).ok();
+                    }
+                }));
+            }
+        }
+
+        WaitForFlag(rx, subscription)
+    }
+}
+
+pub struct WaitForFlag(oneshot::Receiver<bool>, Option<Subscription>);
+
+impl Future for WaitForFlag {
+    type Output = bool;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        self.0.poll_unpin(cx).map(|result| {
+            self.1.take();
+            result.unwrap_or(false)
+        })
+    }
+}
diff --git a/crates/fs/src/fs.rs b/crates/fs/src/fs.rs
index 0ec5a4c601..b649831fd2 100644
--- a/crates/fs/src/fs.rs
+++ b/crates/fs/src/fs.rs
@@ -171,6 +171,7 @@ pub struct Metadata {
     pub mtime: SystemTime,
     pub is_symlink: bool,
     pub is_dir: bool,
+    pub len: u64,
     pub is_fifo: bool,
 }
 
@@ -497,6 +498,7 @@ impl Fs for RealFs {
         Ok(Some(Metadata {
             inode,
             mtime: metadata.modified().unwrap(),
+            len: metadata.len(),
             is_symlink,
             is_dir: metadata.file_type().is_dir(),
             is_fifo,
@@ -800,11 +802,13 @@ enum FakeFsEntry {
     File {
         inode: u64,
         mtime: SystemTime,
+        len: u64,
         content: Vec<u8>,
     },
     Dir {
         inode: u64,
         mtime: SystemTime,
+        len: u64,
         entries: BTreeMap<String, Arc<Mutex<FakeFsEntry>>>,
         git_repo_state: Option<Arc<Mutex<git::repository::FakeGitRepositoryState>>>,
     },
@@ -935,6 +939,7 @@ impl FakeFs {
             root: Arc::new(Mutex::new(FakeFsEntry::Dir {
                 inode: 0,
                 mtime: SystemTime::UNIX_EPOCH,
+                len: 0,
                 entries: Default::default(),
                 git_repo_state: None,
             })),
@@ -969,6 +974,7 @@ impl FakeFs {
                     inode: new_inode,
                     mtime: new_mtime,
                     content: Vec::new(),
+                    len: 0,
                 })));
             }
             btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() {
@@ -1016,6 +1022,7 @@ impl FakeFs {
         let file = Arc::new(Mutex::new(FakeFsEntry::File {
             inode,
             mtime,
+            len: content.len() as u64,
             content,
         }));
         let mut kind = None;
@@ -1369,6 +1376,7 @@ impl Fs for FakeFs {
                 Arc::new(Mutex::new(FakeFsEntry::Dir {
                     inode,
                     mtime,
+                    len: 0,
                     entries: Default::default(),
                     git_repo_state: None,
                 }))
@@ -1391,6 +1399,7 @@ impl Fs for FakeFs {
             let file = Arc::new(Mutex::new(FakeFsEntry::File {
                 inode,
                 mtime,
+                len: 0,
                 content: Vec::new(),
             }));
             let mut kind = Some(PathEventKind::Created);
@@ -1539,6 +1548,7 @@ impl Fs for FakeFs {
                 e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
                     inode,
                     mtime,
+                    len: content.len() as u64,
                     content: Vec::new(),
                 })))
                 .clone(),
@@ -1694,16 +1704,22 @@ impl Fs for FakeFs {
         let entry = entry.lock();
 
         Ok(Some(match &*entry {
-            FakeFsEntry::File { inode, mtime, .. } => Metadata {
+            FakeFsEntry::File {
+                inode, mtime, len, ..
+            } => Metadata {
                 inode: *inode,
                 mtime: *mtime,
+                len: *len,
                 is_dir: false,
                 is_symlink,
                 is_fifo: false,
             },
-            FakeFsEntry::Dir { inode, mtime, .. } => Metadata {
+            FakeFsEntry::Dir {
+                inode, mtime, len, ..
+            } => Metadata {
                 inode: *inode,
                 mtime: *mtime,
+                len: *len,
                 is_dir: true,
                 is_symlink,
                 is_fifo: false,
diff --git a/crates/git/src/status.rs b/crates/git/src/status.rs
index e6098ffd3c..6eb98ecefe 100644
--- a/crates/git/src/status.rs
+++ b/crates/git/src/status.rs
@@ -57,7 +57,6 @@ impl GitStatus {
             let stderr = String::from_utf8_lossy(&output.stderr);
             return Err(anyhow!("git status process failed: {}", stderr));
         }
-
         let stdout = String::from_utf8_lossy(&output.stdout);
         let mut entries = stdout
             .split('\0')
diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs
index 452be0a243..1841a1f394 100644
--- a/crates/http_client/src/http_client.rs
+++ b/crates/http_client/src/http_client.rs
@@ -221,6 +221,10 @@ impl HttpClient for HttpClientWithUrl {
 
 pub fn client(user_agent: Option<String>, proxy: Option<Uri>) -> Arc<dyn HttpClient> {
     let mut builder = isahc::HttpClient::builder()
+        // Some requests to Qwen2 models on Runpod can take 32+ seconds,
+        // especially if there's a cold boot involved. We may need to have
+        // those requests use a different http client, because global timeouts
+        // of 50 and 60 seconds, respectively, would be very high!
         .connect_timeout(Duration::from_secs(5))
         .low_speed_timeout(100, Duration::from_secs(5))
         .proxy(proxy.clone());
diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs
index f36b6b2788..be0812eab9 100644
--- a/crates/language_model/src/model/cloud_model.rs
+++ b/crates/language_model/src/model/cloud_model.rs
@@ -17,14 +17,14 @@ pub enum CloudModel {
 
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
 pub enum ZedModel {
-    #[serde(rename = "qwen2-7b-instruct")]
+    #[serde(rename = "Qwen/Qwen2-7B-Instruct")]
     Qwen2_7bInstruct,
 }
 
 impl ZedModel {
     pub fn id(&self) -> &str {
         match self {
-            ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct",
+            ZedModel::Qwen2_7bInstruct => "Qwen/Qwen2-7B-Instruct",
         }
     }
diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs
index eac4ad3021..9f7135aef7 100644
--- a/crates/language_model/src/provider/anthropic.rs
+++ b/crates/language_model/src/provider/anthropic.rs
@@ -319,7 +319,7 @@ impl AnthropicModel {
         };
 
         async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let api_key = api_key.ok_or_else(|| anyhow!("Missing Anthropic API Key"))?;
             let request = anthropic::stream_completion(
                 http_client.as_ref(),
                 &api_url,
diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs
index fc4a7a7a34..005f35ff8b 100644
--- a/crates/language_model/src/provider/google.rs
+++ b/crates/language_model/src/provider/google.rs
@@ -265,7 +265,7 @@ impl LanguageModel for GoogleLanguageModel {
         let low_speed_timeout = settings.low_speed_timeout;
 
         async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API Key"))?;
             let response = google_ai::count_tokens(
                 http_client.as_ref(),
                 &api_url,
@@ -304,7 +304,7 @@ impl LanguageModel for GoogleLanguageModel {
         };
 
         let future = self.rate_limiter.stream(async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API Key"))?;
             let response = stream_generate_content(
                 http_client.as_ref(),
                 &api_url,
diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs
index 3a371499eb..fe5e60caec 100644
--- a/crates/language_model/src/provider/open_ai.rs
+++ b/crates/language_model/src/provider/open_ai.rs
@@ -239,7 +239,7 @@ impl OpenAiLanguageModel {
         };
 
         let future = self.request_limiter.stream(async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenAI API Key"))?;
             let request = stream_completion(
                 http_client.as_ref(),
                 &api_url,
diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs
index 589dfe776a..b3c8ef5f57 100644
--- a/crates/language_model/src/registry.rs
+++ b/crates/language_model/src/registry.rs
@@ -159,11 +159,13 @@ impl LanguageModelRegistry {
         providers
     }
 
-    pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
+    pub fn available_models<'a>(
+        &'a self,
+        cx: &'a AppContext,
+    ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
         self.providers
             .values()
             .flat_map(|provider| provider.provided_models(cx))
-            .collect()
     }
 
     pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs
index c77a2170dd..c8e1ce28eb 100644
--- a/crates/project_panel/src/project_panel.rs
+++ b/crates/project_panel/src/project_panel.rs
@@ -1823,6 +1823,7 @@ impl ProjectPanel {
                 path: entry.path.join("\0").into(),
                 inode: 0,
                 mtime: entry.mtime,
+                size: entry.size,
                 is_ignored: entry.is_ignored,
                 is_external: false,
                 is_private: false,
diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto
index e5d767fffb..f59e8146b6 100644
--- a/crates/proto/proto/zed.proto
+++ b/crates/proto/proto/zed.proto
@@ -1855,6 +1855,7 @@ message Entry {
     bool is_external = 8;
     optional GitStatus git_status = 9;
     bool is_fifo = 10;
+    optional uint64 size = 11;
 }
 
 message RepositoryEntry {
diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml
index 4fd3a86b29..c8dbb6a9f5 100644
--- a/crates/semantic_index/Cargo.toml
+++ b/crates/semantic_index/Cargo.toml
@@ -19,14 +19,18 @@ crate-type = ["bin"]
 
 [dependencies]
 anyhow.workspace = true
+arrayvec.workspace = true
+blake3.workspace = true
 client.workspace = true
 clock.workspace = true
 collections.workspace = true
+feature_flags.workspace = true
 fs.workspace = true
 futures.workspace = true
 futures-batch.workspace = true
 gpui.workspace = true
 language.workspace = true
+language_model.workspace = true
 log.workspace = true
 heed.workspace = true
 http_client.workspace = true
diff --git a/crates/semantic_index/examples/index.rs b/crates/semantic_index/examples/index.rs
index e536ea1db6..977473d1dc 100644
--- a/crates/semantic_index/examples/index.rs
+++ b/crates/semantic_index/examples/index.rs
@@ -4,7 +4,7 @@ use gpui::App;
 use http_client::HttpClientWithUrl;
 use language::language_settings::AllLanguageSettings;
 use project::Project;
-use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
+use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticDb};
 use settings::SettingsStore;
 use std::{
     path::{Path, PathBuf},
@@ -50,7 +50,7 @@ fn main() {
     ));
 
     cx.spawn(|mut cx| async move {
-        let semantic_index = SemanticIndex::new(
+        let semantic_index = SemanticDb::new(
             PathBuf::from("/tmp/semantic-index-db.mdb"),
             embedding_provider,
             &mut cx,
@@ -71,6 +71,7 @@ fn main() {
 
         let project_index = cx
             .update(|cx| semantic_index.project_index(project.clone(), cx))
+            .unwrap()
             .unwrap();
 
         let (tx, rx) = oneshot::channel();
diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs
index b5195c8911..b05c4ac9da 100644
--- a/crates/semantic_index/src/embedding.rs
+++ b/crates/semantic_index/src/embedding.rs
@@ -12,6 +12,12 @@ use futures::{future::BoxFuture, FutureExt};
 use serde::{Deserialize, Serialize};
 use std::{fmt, future};
 
+/// Trait for embedding providers. Texts in, vectors out.
+pub trait EmbeddingProvider: Sync + Send {
+    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
+    fn batch_size(&self) -> usize;
+}
+
 #[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
 pub struct Embedding(Vec<f32>);
 
@@ -68,12 +74,6 @@ impl fmt::Display for Embedding {
     }
 }
 
-/// Trait for embedding providers. Texts in, vectors out.
-pub trait EmbeddingProvider: Sync + Send {
-    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
-    fn batch_size(&self) -> usize;
-}
-
 #[derive(Debug)]
 pub struct TextToEmbed<'a> {
     pub text: &'a str,
diff --git a/crates/semantic_index/src/embedding_index.rs b/crates/semantic_index/src/embedding_index.rs
new file mode 100644
index 0000000000..dd7c58dc11
--- /dev/null
+++ b/crates/semantic_index/src/embedding_index.rs
@@ -0,0 +1,469 @@
+use crate::{
+    chunking::{self, Chunk},
+    embedding::{Embedding, EmbeddingProvider, TextToEmbed},
+    indexing::{IndexingEntryHandle, IndexingEntrySet},
+};
+use anyhow::{anyhow, Context as _, Result};
+use collections::Bound;
+use fs::Fs;
+use futures::stream::StreamExt;
+use futures_batch::ChunksTimeoutStreamExt;
+use gpui::{AppContext, Model, Task};
+use heed::types::{SerdeBincode, Str};
+use language::LanguageRegistry;
+use log;
+use project::{Entry, UpdatedEntriesSet, Worktree};
+use serde::{Deserialize, Serialize};
+use smol::channel;
+use std::{
+    cmp::Ordering,
+    future::Future,
+    iter,
+    path::Path,
+    sync::Arc,
+    time::{Duration, SystemTime},
+};
+use util::ResultExt;
+use worktree::Snapshot;
+
+pub struct EmbeddingIndex {
+    worktree: Model<Worktree>,
+    db_connection: heed::Env,
+    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
+    fs: Arc<dyn Fs>,
+    language_registry: Arc<LanguageRegistry>,
+    embedding_provider: Arc<dyn EmbeddingProvider>,
+    entry_ids_being_indexed: Arc<IndexingEntrySet>,
+}
+
+impl EmbeddingIndex {
+    pub fn new(
+        worktree: Model<Worktree>,
+        fs: Arc<dyn Fs>,
+        db_connection: heed::Env,
+        embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
+        language_registry: Arc<LanguageRegistry>,
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        entry_ids_being_indexed: Arc<IndexingEntrySet>,
+    ) -> Self {
+        Self {
+            worktree,
+            fs,
+            db_connection,
+            db: embedding_db,
+            language_registry,
+            embedding_provider,
+            entry_ids_being_indexed,
+        }
+    }
+
+    pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
+        &self.db
+    }
+
+    pub fn index_entries_changed_on_disk(
+        &self,
+        cx: &AppContext,
+    ) -> impl Future<Output = Result<()>> {
+        let worktree = self.worktree.read(cx).snapshot();
+        let worktree_abs_path = worktree.abs_path().clone();
+        let scan = self.scan_entries(worktree, cx);
+        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
+        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
+        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
+        async move {
+            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
+            Ok(())
+        }
+    }
+
+    pub fn index_updated_entries(
+        &self,
+        updated_entries: UpdatedEntriesSet,
+        cx: &AppContext,
+    ) -> impl Future<Output = Result<()>> {
+        let worktree = self.worktree.read(cx).snapshot();
+        let worktree_abs_path = worktree.abs_path().clone();
+        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
+        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
+        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
+        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
+        async move {
+            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
+            Ok(())
+        }
+    }
+
+    fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
+        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
+        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
+        let db_connection = self.db_connection.clone();
+        let db = self.db;
+        let entries_being_indexed = self.entry_ids_being_indexed.clone();
+        let task = cx.background_executor().spawn(async move {
+            let txn = db_connection
+                .read_txn()
+                .context("failed to create read transaction")?;
transaction")?; + let mut db_entries = db + .iter(&txn) + .context("failed to create iterator")? + .move_between_keys() + .peekable(); + + let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None; + for entry in worktree.files(false, 0) { + log::trace!("scanning for embedding index: {:?}", &entry.path); + + let entry_db_key = db_key_for_path(&entry.path); + + let mut saved_mtime = None; + while let Some(db_entry) = db_entries.peek() { + match db_entry { + Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) { + Ordering::Less => { + if let Some(deletion_range) = deletion_range.as_mut() { + deletion_range.1 = Bound::Included(db_path); + } else { + deletion_range = + Some((Bound::Included(db_path), Bound::Included(db_path))); + } + + db_entries.next(); + } + Ordering::Equal => { + if let Some(deletion_range) = deletion_range.take() { + deleted_entry_ranges_tx + .send(( + deletion_range.0.map(ToString::to_string), + deletion_range.1.map(ToString::to_string), + )) + .await?; + } + saved_mtime = db_embedded_file.mtime; + db_entries.next(); + break; + } + Ordering::Greater => { + break; + } + }, + Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?, + } + } + + if entry.mtime != saved_mtime { + let handle = entries_being_indexed.insert(entry.id); + updated_entries_tx.send((entry.clone(), handle)).await?; + } + } + + if let Some(db_entry) = db_entries.next() { + let (db_path, _) = db_entry?; + deleted_entry_ranges_tx + .send((Bound::Included(db_path.to_string()), Bound::Unbounded)) + .await?; + } + + Ok(()) + }); + + ScanEntries { + updated_entries: updated_entries_rx, + deleted_entry_ranges: deleted_entry_ranges_rx, + task, + } + } + + fn scan_updated_entries( + &self, + worktree: Snapshot, + updated_entries: UpdatedEntriesSet, + cx: &AppContext, + ) -> ScanEntries { + let (updated_entries_tx, updated_entries_rx) = channel::bounded(512); + let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128); + let entries_being_indexed = self.entry_ids_being_indexed.clone(); + let task = cx.background_executor().spawn(async move { + for (path, entry_id, status) in updated_entries.iter() { + match status { + project::PathChange::Added + | project::PathChange::Updated + | project::PathChange::AddedOrUpdated => { + if let Some(entry) = worktree.entry_for_id(*entry_id) { + if entry.is_file() { + let handle = entries_being_indexed.insert(entry.id); + updated_entries_tx.send((entry.clone(), handle)).await?; + } + } + } + project::PathChange::Removed => { + let db_path = db_key_for_path(path); + deleted_entry_ranges_tx + .send((Bound::Included(db_path.clone()), Bound::Included(db_path))) + .await?; + } + project::PathChange::Loaded => { + // Do nothing. 
+                    }
+                }
+            }
+
+            Ok(())
+        });
+
+        ScanEntries {
+            updated_entries: updated_entries_rx,
+            deleted_entry_ranges: deleted_entry_ranges_rx,
+            task,
+        }
+    }
+
+    fn chunk_files(
+        &self,
+        worktree_abs_path: Arc<Path>,
+        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
+        cx: &AppContext,
+    ) -> ChunkFiles {
+        let language_registry = self.language_registry.clone();
+        let fs = self.fs.clone();
+        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
+        let task = cx.spawn(|cx| async move {
+            cx.background_executor()
+                .scoped(|cx| {
+                    for _ in 0..cx.num_cpus() {
+                        cx.spawn(async {
+                            while let Ok((entry, handle)) = entries.recv().await {
+                                let entry_abs_path = worktree_abs_path.join(&entry.path);
+                                match fs.load(&entry_abs_path).await {
+                                    Ok(text) => {
+                                        let language = language_registry
+                                            .language_for_file_path(&entry.path)
+                                            .await
+                                            .ok();
+                                        let chunked_file = ChunkedFile {
+                                            chunks: chunking::chunk_text(
+                                                &text,
+                                                language.as_ref(),
+                                                &entry.path,
+                                            ),
+                                            handle,
+                                            path: entry.path,
+                                            mtime: entry.mtime,
+                                            text,
+                                        };
+
+                                        if chunked_files_tx.send(chunked_file).await.is_err() {
+                                            return;
+                                        }
+                                    }
+                                    Err(_) => {
+                                        log::error!("Failed to read contents into a UTF-8 string: {entry_abs_path:?}");
+                                    }
+                                }
+                            }
+                        });
+                    }
+                })
+                .await;
+            Ok(())
+        });
+
+        ChunkFiles {
+            files: chunked_files_rx,
+            task,
+        }
+    }
+
+    pub fn embed_files(
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        chunked_files: channel::Receiver<ChunkedFile>,
+        cx: &AppContext,
+    ) -> EmbedFiles {
+        let embedding_provider = embedding_provider.clone();
+        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
+        let task = cx.background_executor().spawn(async move {
+            let mut chunked_file_batches =
+                chunked_files.chunks_timeout(512, Duration::from_secs(2));
+            while let Some(chunked_files) = chunked_file_batches.next().await {
+                // View the batch of files as a vec of chunks
+                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
+                // Once those are done, reassemble them back into the files in which they belong
+                // If any embeddings fail for a file, the entire file is discarded
+
+                let chunks: Vec<TextToEmbed> = chunked_files
+                    .iter()
+                    .flat_map(|file| {
+                        file.chunks.iter().map(|chunk| TextToEmbed {
+                            text: &file.text[chunk.range.clone()],
+                            digest: chunk.digest,
+                        })
+                    })
+                    .collect::<Vec<_>>();
+
+                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
+                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
+                    if let Some(batch_embeddings) =
+                        embedding_provider.embed(embedding_batch).await.log_err()
+                    {
+                        if batch_embeddings.len() == embedding_batch.len() {
+                            embeddings.extend(batch_embeddings.into_iter().map(Some));
+                            continue;
+                        }
+                        log::error!(
+                            "embedding provider returned unexpected embedding count {}, expected {}",
+                            batch_embeddings.len(), embedding_batch.len()
+                        );
+                    }
+
+                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
+                }
+
+                let mut embeddings = embeddings.into_iter();
+                for chunked_file in chunked_files {
+                    let mut embedded_file = EmbeddedFile {
+                        path: chunked_file.path,
+                        mtime: chunked_file.mtime,
+                        chunks: Vec::new(),
+                    };
+
+                    let mut embedded_all_chunks = true;
+                    for (chunk, embedding) in
+                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
+                    {
+                        if let Some(embedding) = embedding {
+                            embedded_file
+                                .chunks
+                                .push(EmbeddedChunk { chunk, embedding });
+                        } else {
+                            embedded_all_chunks = false;
+                        }
+                    }
+
+                    if embedded_all_chunks {
+                        embedded_files_tx
+                            .send((embedded_file, chunked_file.handle))
+                            .await?;
+                    }
+                }
+            }
+            Ok(())
+        });
+
+        EmbedFiles {
+            files: embedded_files_rx,
+            task,
+        }
+    }
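The reassembly loop in `embed_files` above relies on an ordering invariant: the flat embeddings vector lines up, in order, with the chunks of every file in the batch, so consuming it with `by_ref` hands each file exactly its own chunk count. A small self-contained illustration of that invariant with toy data (hypothetical values, not the real `Chunk`/`Embedding` types):

```rust
fn main() {
    // Two "files" with 2 and 3 chunks respectively.
    let files = vec![vec!["a1", "a2"], vec!["b1", "b2", "b3"]];
    // One Option per flattened chunk, in the same order; a None stands in
    // for a chunk whose embedding batch failed.
    let embeddings: Vec<Option<u32>> = vec![Some(1), Some(2), Some(3), None, Some(5)];

    let mut embeddings = embeddings.into_iter();
    for (ix, chunks) in files.into_iter().enumerate() {
        // `by_ref` keeps consuming the same iterator across files, so each
        // file takes exactly `chunks.len()` embeddings off the front.
        let embedded: Vec<_> = chunks.iter().zip(embeddings.by_ref()).collect();
        // Mirroring `embedded_all_chunks`: keep a file only if every one of
        // its chunks embedded successfully.
        let keep = embedded.iter().all(|(_, e)| e.is_some());
        println!("file {ix}: keep = {keep}"); // file 0: true, file 1: false
    }
}
```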
+
+    fn persist_embeddings(
+        &self,
+        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
+        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
+        cx: &AppContext,
+    ) -> Task<Result<()>> {
+        let db_connection = self.db_connection.clone();
+        let db = self.db;
+        cx.background_executor().spawn(async move {
+            while let Some(deletion_range) = deleted_entry_ranges.next().await {
+                let mut txn = db_connection.write_txn()?;
+                let start = deletion_range.0.as_ref().map(|start| start.as_str());
+                let end = deletion_range.1.as_ref().map(|end| end.as_str());
+                log::debug!("deleting embeddings in range {:?}", &(start, end));
+                db.delete_range(&mut txn, &(start, end))?;
+                txn.commit()?;
+            }
+
+            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
+            while let Some(embedded_files) = embedded_files.next().await {
+                let mut txn = db_connection.write_txn()?;
+                for (file, _) in &embedded_files {
+                    log::debug!("saving embedding for file {:?}", file.path);
+                    let key = db_key_for_path(&file.path);
+                    db.put(&mut txn, &key, file)?;
+                }
+                txn.commit()?;
+
+                drop(embedded_files);
+                log::debug!("committed");
+            }
+
+            Ok(())
+        })
+    }
+
+    pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
+        let connection = self.db_connection.clone();
+        let db = self.db;
+        cx.background_executor().spawn(async move {
+            let tx = connection
+                .read_txn()
+                .context("failed to create read transaction")?;
+            let result = db
+                .iter(&tx)?
+                .map(|entry| Ok(entry?.1.path.clone()))
+                .collect::<Result<Vec<Arc<Path>>>>();
+            drop(tx);
+            result
+        })
+    }
+
+    pub fn chunks_for_path(
+        &self,
+        path: Arc<Path>,
+        cx: &AppContext,
+    ) -> Task<Result<Vec<EmbeddedChunk>>> {
+        let connection = self.db_connection.clone();
+        let db = self.db;
+        cx.background_executor().spawn(async move {
+            let tx = connection
+                .read_txn()
+                .context("failed to create read transaction")?;
+            Ok(db
+                .get(&tx, &db_key_for_path(&path))?
+                .ok_or_else(|| anyhow!("no such path"))?
+                .chunks
+                .clone())
+        })
+    }
+}
+
+struct ScanEntries {
+    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
+    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
+    task: Task<Result<()>>,
+}
+
+struct ChunkFiles {
+    files: channel::Receiver<ChunkedFile>,
+    task: Task<Result<()>>,
+}
+
+pub struct ChunkedFile {
+    pub path: Arc<Path>,
+    pub mtime: Option<SystemTime>,
+    pub handle: IndexingEntryHandle,
+    pub text: String,
+    pub chunks: Vec<Chunk>,
+}
+
+pub struct EmbedFiles {
+    pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
+    pub task: Task<Result<()>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct EmbeddedFile {
+    pub path: Arc<Path>,
+    pub mtime: Option<SystemTime>,
+    pub chunks: Vec<EmbeddedChunk>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct EmbeddedChunk {
+    pub chunk: Chunk,
+    pub embedding: Embedding,
+}
+
+fn db_key_for_path(path: &Arc<Path>) -> String {
+    path.to_string_lossy().replace('/', "\0")
+}
diff --git a/crates/semantic_index/src/indexing.rs b/crates/semantic_index/src/indexing.rs
new file mode 100644
index 0000000000..aca9504891
--- /dev/null
+++ b/crates/semantic_index/src/indexing.rs
@@ -0,0 +1,49 @@
+use collections::HashSet;
+use parking_lot::Mutex;
+use project::ProjectEntryId;
+use smol::channel;
+use std::sync::{Arc, Weak};
+
+/// The set of entries that are currently being indexed.
+pub struct IndexingEntrySet {
+    entry_ids: Mutex<HashSet<ProjectEntryId>>,
+    tx: channel::Sender<()>,
+}
+
+/// When dropped, removes the entry from the set of entries that are being indexed.
+#[derive(Clone)] +pub(crate) struct IndexingEntryHandle { + entry_id: ProjectEntryId, + set: Weak, +} + +impl IndexingEntrySet { + pub fn new(tx: channel::Sender<()>) -> Self { + Self { + entry_ids: Default::default(), + tx, + } + } + + pub fn insert(self: &Arc, entry_id: ProjectEntryId) -> IndexingEntryHandle { + self.entry_ids.lock().insert(entry_id); + self.tx.send_blocking(()).ok(); + IndexingEntryHandle { + entry_id, + set: Arc::downgrade(self), + } + } + + pub fn len(&self) -> usize { + self.entry_ids.lock().len() + } +} + +impl Drop for IndexingEntryHandle { + fn drop(&mut self) { + if let Some(set) = self.set.upgrade() { + set.tx.send_blocking(()).ok(); + set.entry_ids.lock().remove(&self.entry_id); + } + } +} diff --git a/crates/semantic_index/src/project_index.rs b/crates/semantic_index/src/project_index.rs new file mode 100644 index 0000000000..84a72c1a3d --- /dev/null +++ b/crates/semantic_index/src/project_index.rs @@ -0,0 +1,523 @@ +use crate::{ + embedding::{EmbeddingProvider, TextToEmbed}, + summary_index::FileSummary, + worktree_index::{WorktreeIndex, WorktreeIndexHandle}, +}; +use anyhow::{anyhow, Context, Result}; +use collections::HashMap; +use fs::Fs; +use futures::{stream::StreamExt, FutureExt}; +use gpui::{ + AppContext, Entity, EntityId, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel, +}; +use language::LanguageRegistry; +use log; +use project::{Project, Worktree, WorktreeId}; +use serde::{Deserialize, Serialize}; +use smol::channel; +use std::{cmp::Ordering, future::Future, num::NonZeroUsize, ops::Range, path::Path, sync::Arc}; +use util::ResultExt; + +#[derive(Debug)] +pub struct SearchResult { + pub worktree: Model, + pub path: Arc, + pub range: Range, + pub score: f32, +} + +pub struct WorktreeSearchResult { + pub worktree_id: WorktreeId, + pub path: Arc, + pub range: Range, + pub score: f32, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub enum Status { + Idle, + Loading, + Scanning { remaining_count: NonZeroUsize }, +} + +pub struct ProjectIndex { + db_connection: heed::Env, + project: WeakModel, + worktree_indices: HashMap, + language_registry: Arc, + fs: Arc, + last_status: Status, + status_tx: channel::Sender<()>, + embedding_provider: Arc, + _maintain_status: Task<()>, + _subscription: Subscription, +} + +impl ProjectIndex { + pub fn new( + project: Model, + db_connection: heed::Env, + embedding_provider: Arc, + cx: &mut ModelContext, + ) -> Self { + let language_registry = project.read(cx).languages().clone(); + let fs = project.read(cx).fs().clone(); + let (status_tx, mut status_rx) = channel::unbounded(); + let mut this = ProjectIndex { + db_connection, + project: project.downgrade(), + worktree_indices: HashMap::default(), + language_registry, + fs, + status_tx, + last_status: Status::Idle, + embedding_provider, + _subscription: cx.subscribe(&project, Self::handle_project_event), + _maintain_status: cx.spawn(|this, mut cx| async move { + while status_rx.next().await.is_some() { + if this + .update(&mut cx, |this, cx| this.update_status(cx)) + .is_err() + { + break; + } + } + }), + }; + this.update_worktree_indices(cx); + this + } + + pub fn status(&self) -> Status { + self.last_status + } + + pub fn project(&self) -> WeakModel { + self.project.clone() + } + + pub fn fs(&self) -> Arc { + self.fs.clone() + } + + fn handle_project_event( + &mut self, + _: Model, + event: &project::Event, + cx: &mut ModelContext, + ) { + match event { + project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { 
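+                // A worktree was added or removed, so rebuild the set of
+                // per-worktree indices to match the project's current state.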
+ self.update_worktree_indices(cx); + } + _ => {} + } + } + + fn update_worktree_indices(&mut self, cx: &mut ModelContext) { + let Some(project) = self.project.upgrade() else { + return; + }; + + let worktrees = project + .read(cx) + .visible_worktrees(cx) + .filter_map(|worktree| { + if worktree.read(cx).is_local() { + Some((worktree.entity_id(), worktree)) + } else { + None + } + }) + .collect::>(); + + self.worktree_indices + .retain(|worktree_id, _| worktrees.contains_key(worktree_id)); + for (worktree_id, worktree) in worktrees { + self.worktree_indices.entry(worktree_id).or_insert_with(|| { + let worktree_index = WorktreeIndex::load( + worktree.clone(), + self.db_connection.clone(), + self.language_registry.clone(), + self.fs.clone(), + self.status_tx.clone(), + self.embedding_provider.clone(), + cx, + ); + + let load_worktree = cx.spawn(|this, mut cx| async move { + let result = match worktree_index.await { + Ok(worktree_index) => { + this.update(&mut cx, |this, _| { + this.worktree_indices.insert( + worktree_id, + WorktreeIndexHandle::Loaded { + index: worktree_index.clone(), + }, + ); + })?; + Ok(worktree_index) + } + Err(error) => { + this.update(&mut cx, |this, _cx| { + this.worktree_indices.remove(&worktree_id) + })?; + Err(Arc::new(error)) + } + }; + + this.update(&mut cx, |this, cx| this.update_status(cx))?; + + result + }); + + WorktreeIndexHandle::Loading { + index: load_worktree.shared(), + } + }); + } + + self.update_status(cx); + } + + fn update_status(&mut self, cx: &mut ModelContext) { + let mut indexing_count = 0; + let mut any_loading = false; + + for index in self.worktree_indices.values_mut() { + match index { + WorktreeIndexHandle::Loading { .. } => { + any_loading = true; + break; + } + WorktreeIndexHandle::Loaded { index, .. } => { + indexing_count += index.read(cx).entry_ids_being_indexed().len(); + } + } + } + + let status = if any_loading { + Status::Loading + } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) { + Status::Scanning { remaining_count } + } else { + Status::Idle + }; + + if status != self.last_status { + self.last_status = status; + cx.emit(status); + } + } + + pub fn search( + &self, + query: String, + limit: usize, + cx: &AppContext, + ) -> Task>> { + let (chunks_tx, chunks_rx) = channel::bounded(1024); + let mut worktree_scan_tasks = Vec::new(); + for worktree_index in self.worktree_indices.values() { + let worktree_index = worktree_index.clone(); + let chunks_tx = chunks_tx.clone(); + worktree_scan_tasks.push(cx.spawn(|cx| async move { + let index = match worktree_index { + WorktreeIndexHandle::Loading { index } => { + index.clone().await.map_err(|error| anyhow!(error))? + } + WorktreeIndexHandle::Loaded { index } => index.clone(), + }; + + index + .read_with(&cx, |index, cx| { + let worktree_id = index.worktree().read(cx).id(); + let db_connection = index.db_connection().clone(); + let db = *index.embedding_index().db(); + cx.background_executor().spawn(async move { + let txn = db_connection + .read_txn() + .context("failed to create read transaction")?; + let db_entries = db.iter(&txn).context("failed to iterate database")?; + for db_entry in db_entries { + let (_key, db_embedded_file) = db_entry?; + for chunk in db_embedded_file.chunks { + chunks_tx + .send((worktree_id, db_embedded_file.path.clone(), chunk)) + .await?; + } + } + anyhow::Ok(()) + }) + })? 
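+                    // `read_with` returned a background task that streams this
+                    // worktree's stored chunks out of the database; await it here.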
+ .await + })); + } + drop(chunks_tx); + + let project = self.project.clone(); + let embedding_provider = self.embedding_provider.clone(); + cx.spawn(|cx| async move { + #[cfg(debug_assertions)] + let embedding_query_start = std::time::Instant::now(); + log::info!("Searching for {query}"); + + let query_embeddings = embedding_provider + .embed(&[TextToEmbed::new(&query)]) + .await?; + let query_embedding = query_embeddings + .into_iter() + .next() + .ok_or_else(|| anyhow!("no embedding for query"))?; + + let mut results_by_worker = Vec::new(); + for _ in 0..cx.background_executor().num_cpus() { + results_by_worker.push(Vec::::new()); + } + + #[cfg(debug_assertions)] + let search_start = std::time::Instant::now(); + + cx.background_executor() + .scoped(|cx| { + for results in results_by_worker.iter_mut() { + cx.spawn(async { + while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await { + let score = chunk.embedding.similarity(&query_embedding); + let ix = match results.binary_search_by(|probe| { + score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal) + }) { + Ok(ix) | Err(ix) => ix, + }; + results.insert( + ix, + WorktreeSearchResult { + worktree_id, + path: path.clone(), + range: chunk.chunk.range.clone(), + score, + }, + ); + results.truncate(limit); + } + }); + } + }) + .await; + + for scan_task in futures::future::join_all(worktree_scan_tasks).await { + scan_task.log_err(); + } + + project.read_with(&cx, |project, cx| { + let mut search_results = Vec::with_capacity(results_by_worker.len() * limit); + for worker_results in results_by_worker { + search_results.extend(worker_results.into_iter().filter_map(|result| { + Some(SearchResult { + worktree: project.worktree_for_id(result.worktree_id, cx)?, + path: result.path, + range: result.range, + score: result.score, + }) + })); + } + search_results.sort_unstable_by(|a, b| { + b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal) + }); + search_results.truncate(limit); + + #[cfg(debug_assertions)] + { + let search_elapsed = search_start.elapsed(); + log::debug!( + "searched {} entries in {:?}", + search_results.len(), + search_elapsed + ); + let embedding_query_elapsed = embedding_query_start.elapsed(); + log::debug!("embedding query took {:?}", embedding_query_elapsed); + } + + search_results + }) + }) + } + + #[cfg(test)] + pub fn path_count(&self, cx: &AppContext) -> Result { + let mut result = 0; + for worktree_index in self.worktree_indices.values() { + if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index { + result += index.read(cx).path_count()?; + } + } + Ok(result) + } + + pub(crate) fn worktree_index( + &self, + worktree_id: WorktreeId, + cx: &AppContext, + ) -> Option> { + for index in self.worktree_indices.values() { + if let WorktreeIndexHandle::Loaded { index, .. } = index { + if index.read(cx).worktree().read(cx).id() == worktree_id { + return Some(index.clone()); + } + } + } + None + } + + pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec> { + let mut result = self + .worktree_indices + .values() + .filter_map(|index| { + if let WorktreeIndexHandle::Loaded { index, .. 
} = index { + Some(index.clone()) + } else { + None + } + }) + .collect::>(); + result.sort_by_key(|index| index.read(cx).worktree().read(cx).id()); + result + } + + pub fn all_summaries(&self, cx: &AppContext) -> Task>> { + let (summaries_tx, summaries_rx) = channel::bounded(1024); + let mut worktree_scan_tasks = Vec::new(); + for worktree_index in self.worktree_indices.values() { + let worktree_index = worktree_index.clone(); + let summaries_tx: channel::Sender<(String, String)> = summaries_tx.clone(); + worktree_scan_tasks.push(cx.spawn(|cx| async move { + let index = match worktree_index { + WorktreeIndexHandle::Loading { index } => { + index.clone().await.map_err(|error| anyhow!(error))? + } + WorktreeIndexHandle::Loaded { index } => index.clone(), + }; + + index + .read_with(&cx, |index, cx| { + let db_connection = index.db_connection().clone(); + let summary_index = index.summary_index(); + let file_digest_db = summary_index.file_digest_db(); + let summary_db = summary_index.summary_db(); + + cx.background_executor().spawn(async move { + let txn = db_connection + .read_txn() + .context("failed to create db read transaction")?; + let db_entries = file_digest_db + .iter(&txn) + .context("failed to iterate database")?; + for db_entry in db_entries { + let (file_path, db_file) = db_entry?; + + match summary_db.get(&txn, &db_file.digest) { + Ok(opt_summary) => { + // Currently, we only use summaries we already have. If the file hasn't been + // summarized yet, then we skip it and don't include it in the inferred context. + // If we want to do just-in-time summarization, this would be the place to do it! + if let Some(summary) = opt_summary { + summaries_tx + .send((file_path.to_string(), summary.to_string())) + .await?; + } else { + log::warn!("No summary found for {:?}", &db_file); + } + } + Err(err) => { + log::error!( + "Error reading from summary database: {:?}", + err + ); + } + } + } + anyhow::Ok(()) + }) + })? + .await + })); + } + drop(summaries_tx); + + let project = self.project.clone(); + cx.spawn(|cx| async move { + let mut results_by_worker = Vec::new(); + for _ in 0..cx.background_executor().num_cpus() { + results_by_worker.push(Vec::::new()); + } + + cx.background_executor() + .scoped(|cx| { + for results in results_by_worker.iter_mut() { + cx.spawn(async { + while let Ok((filename, summary)) = summaries_rx.recv().await { + results.push(FileSummary { filename, summary }); + } + }); + } + }) + .await; + + for scan_task in futures::future::join_all(worktree_scan_tasks).await { + scan_task.log_err(); + } + + project.read_with(&cx, |_project, _cx| { + results_by_worker.into_iter().flatten().collect() + }) + }) + } + + /// Empty out the backlogs of all the worktrees in the project + pub fn flush_summary_backlogs(&self, cx: &AppContext) -> impl Future { + let flush_start = std::time::Instant::now(); + + futures::future::join_all(self.worktree_indices.values().map(|worktree_index| { + let worktree_index = worktree_index.clone(); + + cx.spawn(|cx| async move { + let index = match worktree_index { + WorktreeIndexHandle::Loading { index } => { + index.clone().await.map_err(|error| anyhow!(error))? + } + WorktreeIndexHandle::Loaded { index } => index.clone(), + }; + let worktree_abs_path = + cx.update(|cx| index.read(cx).worktree().read(cx).abs_path())?; + + index + .read_with(&cx, |index, cx| { + cx.background_executor() + .spawn(index.summary_index().flush_backlog(worktree_abs_path, cx)) + })? 
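+                    // `flush_backlog` runs on the background executor; await the
+                    // returned task so `join_all` sees this worktree complete.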
+ .await + }) + })) + .map(move |results| { + // Log any errors, but don't block the user. These summaries are supposed to + // improve quality by providing extra context, but they aren't hard requirements! + for result in results { + if let Err(err) = result { + log::error!("Error flushing summary backlog: {:?}", err); + } + } + + log::info!("Summary backlog flushed in {:?}", flush_start.elapsed()); + }) + } + + pub fn remaining_summaries(&self, cx: &mut ModelContext) -> usize { + self.worktree_indices(cx) + .iter() + .map(|index| index.read(cx).summary_index().backlog_len()) + .sum() + } +} + +impl EventEmitter for ProjectIndex {} diff --git a/crates/semantic_index/src/project_index_debug_view.rs b/crates/semantic_index/src/project_index_debug_view.rs index e5881a24e7..d6628064ac 100644 --- a/crates/semantic_index/src/project_index_debug_view.rs +++ b/crates/semantic_index/src/project_index_debug_view.rs @@ -55,8 +55,12 @@ impl ProjectIndexDebugView { for index in worktree_indices { let (root_path, worktree_id, worktree_paths) = index.read_with(&cx, |index, cx| { - let worktree = index.worktree.read(cx); - (worktree.abs_path(), worktree.id(), index.paths(cx)) + let worktree = index.worktree().read(cx); + ( + worktree.abs_path(), + worktree.id(), + index.embedding_index().paths(cx), + ) })?; rows.push(Row::Worktree(root_path)); rows.extend( @@ -82,10 +86,12 @@ impl ProjectIndexDebugView { cx: &mut ViewContext, ) -> Option<()> { let project_index = self.index.read(cx); - let fs = project_index.fs.clone(); + let fs = project_index.fs().clone(); let worktree_index = project_index.worktree_index(worktree_id, cx)?.read(cx); - let root_path = worktree_index.worktree.read(cx).abs_path(); - let chunks = worktree_index.chunks_for_path(file_path.clone(), cx); + let root_path = worktree_index.worktree().read(cx).abs_path(); + let chunks = worktree_index + .embedding_index() + .chunks_for_path(file_path.clone(), cx); cx.spawn(|this, mut cx| async move { let chunks = chunks.await?; diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index fad3a5d3e8..f2b325ead6 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -1,48 +1,35 @@ mod chunking; mod embedding; +mod embedding_index; +mod indexing; +mod project_index; mod project_index_debug_view; +mod summary_backlog; +mod summary_index; +mod worktree_index; + +use anyhow::{Context as _, Result}; +use collections::HashMap; +use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel}; +use project::Project; +use project_index::ProjectIndex; +use std::{path::PathBuf, sync::Arc}; +use ui::ViewContext; +use workspace::Workspace; -use anyhow::{anyhow, Context as _, Result}; -use chunking::{chunk_text, Chunk}; -use collections::{Bound, HashMap, HashSet}; pub use embedding::*; -use fs::Fs; -use futures::{future::Shared, stream::StreamExt, FutureExt}; -use futures_batch::ChunksTimeoutStreamExt; -use gpui::{ - AppContext, AsyncAppContext, BorrowAppContext, Context, Entity, EntityId, EventEmitter, Global, - Model, ModelContext, Subscription, Task, WeakModel, -}; -use heed::types::{SerdeBincode, Str}; -use language::LanguageRegistry; -use parking_lot::Mutex; -use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId}; -use serde::{Deserialize, Serialize}; -use smol::channel; -use std::{ - cmp::Ordering, - future::Future, - iter, - num::NonZeroUsize, - ops::Range, - path::{Path, PathBuf}, - sync::{Arc, 
Weak}, - time::{Duration, SystemTime}, -}; -use util::ResultExt; -use worktree::Snapshot; - pub use project_index_debug_view::ProjectIndexDebugView; +pub use summary_index::FileSummary; -pub struct SemanticIndex { +pub struct SemanticDb { embedding_provider: Arc, db_connection: heed::Env, project_indices: HashMap, Model>, } -impl Global for SemanticIndex {} +impl Global for SemanticDb {} -impl SemanticIndex { +impl SemanticDb { pub async fn new( db_path: PathBuf, embedding_provider: Arc, @@ -62,7 +49,45 @@ impl SemanticIndex { .await .context("opening database connection")?; - Ok(SemanticIndex { + cx.update(|cx| { + cx.observe_new_views( + |workspace: &mut Workspace, cx: &mut ViewContext| { + let project = workspace.project().clone(); + + if cx.has_global::() { + cx.update_global::(|this, cx| { + let project_index = cx.new_model(|cx| { + ProjectIndex::new( + project.clone(), + this.db_connection.clone(), + this.embedding_provider.clone(), + cx, + ) + }); + + let project_weak = project.downgrade(); + this.project_indices + .insert(project_weak.clone(), project_index); + + cx.on_release(move |_, _, cx| { + if cx.has_global::() { + cx.update_global::(|this, _| { + this.project_indices.remove(&project_weak); + }) + } + }) + .detach(); + }) + } else { + log::info!("No SemanticDb, skipping project index") + } + }, + ) + .detach(); + }) + .ok(); + + Ok(SemanticDb { db_connection, embedding_provider, project_indices: HashMap::default(), @@ -72,985 +97,50 @@ impl SemanticIndex { pub fn project_index( &mut self, project: Model, + _cx: &mut AppContext, + ) -> Option> { + self.project_indices.get(&project.downgrade()).cloned() + } + + pub fn remaining_summaries( + &self, + project: &WeakModel, cx: &mut AppContext, - ) -> Model { - let project_weak = project.downgrade(); - project.update(cx, move |_, cx| { - cx.on_release(move |_, cx| { - if cx.has_global::() { - cx.update_global::(|this, _| { - this.project_indices.remove(&project_weak); - }) - } - }) - .detach(); - }); - - self.project_indices - .entry(project.downgrade()) - .or_insert_with(|| { - cx.new_model(|cx| { - ProjectIndex::new( - project, - self.db_connection.clone(), - self.embedding_provider.clone(), - cx, - ) - }) - }) - .clone() - } -} - -pub struct ProjectIndex { - db_connection: heed::Env, - project: WeakModel, - worktree_indices: HashMap, - language_registry: Arc, - fs: Arc, - last_status: Status, - status_tx: channel::Sender<()>, - embedding_provider: Arc, - _maintain_status: Task<()>, - _subscription: Subscription, -} - -#[derive(Clone)] -enum WorktreeIndexHandle { - Loading { - index: Shared, Arc>>>, - }, - Loaded { - index: Model, - }, -} - -impl ProjectIndex { - fn new( - project: Model, - db_connection: heed::Env, - embedding_provider: Arc, - cx: &mut ModelContext, - ) -> Self { - let language_registry = project.read(cx).languages().clone(); - let fs = project.read(cx).fs().clone(); - let (status_tx, mut status_rx) = channel::unbounded(); - let mut this = ProjectIndex { - db_connection, - project: project.downgrade(), - worktree_indices: HashMap::default(), - language_registry, - fs, - status_tx, - last_status: Status::Idle, - embedding_provider, - _subscription: cx.subscribe(&project, Self::handle_project_event), - _maintain_status: cx.spawn(|this, mut cx| async move { - while status_rx.next().await.is_some() { - if this - .update(&mut cx, |this, cx| this.update_status(cx)) - .is_err() - { - break; - } - } - }), - }; - this.update_worktree_indices(cx); - this - } - - pub fn status(&self) -> Status { - self.last_status - } - - 
pub fn project(&self) -> WeakModel { - self.project.clone() - } - - pub fn fs(&self) -> Arc { - self.fs.clone() - } - - fn handle_project_event( - &mut self, - _: Model, - event: &project::Event, - cx: &mut ModelContext, - ) { - match event { - project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { - self.update_worktree_indices(cx); - } - _ => {} - } - } - - fn update_worktree_indices(&mut self, cx: &mut ModelContext) { - let Some(project) = self.project.upgrade() else { - return; - }; - - let worktrees = project - .read(cx) - .visible_worktrees(cx) - .filter_map(|worktree| { - if worktree.read(cx).is_local() { - Some((worktree.entity_id(), worktree)) - } else { - None - } - }) - .collect::>(); - - self.worktree_indices - .retain(|worktree_id, _| worktrees.contains_key(worktree_id)); - for (worktree_id, worktree) in worktrees { - self.worktree_indices.entry(worktree_id).or_insert_with(|| { - let worktree_index = WorktreeIndex::load( - worktree.clone(), - self.db_connection.clone(), - self.language_registry.clone(), - self.fs.clone(), - self.status_tx.clone(), - self.embedding_provider.clone(), - cx, - ); - - let load_worktree = cx.spawn(|this, mut cx| async move { - let result = match worktree_index.await { - Ok(worktree_index) => { - this.update(&mut cx, |this, _| { - this.worktree_indices.insert( - worktree_id, - WorktreeIndexHandle::Loaded { - index: worktree_index.clone(), - }, - ); - })?; - Ok(worktree_index) - } - Err(error) => { - this.update(&mut cx, |this, _cx| { - this.worktree_indices.remove(&worktree_id) - })?; - Err(Arc::new(error)) - } - }; - - this.update(&mut cx, |this, cx| this.update_status(cx))?; - - result - }); - - WorktreeIndexHandle::Loading { - index: load_worktree.shared(), - } - }); - } - - self.update_status(cx); - } - - fn update_status(&mut self, cx: &mut ModelContext) { - let mut indexing_count = 0; - let mut any_loading = false; - - for index in self.worktree_indices.values_mut() { - match index { - WorktreeIndexHandle::Loading { .. } => { - any_loading = true; - break; - } - WorktreeIndexHandle::Loaded { index, .. } => { - indexing_count += index.read(cx).entry_ids_being_indexed.len(); - } - } - } - - let status = if any_loading { - Status::Loading - } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) { - Status::Scanning { remaining_count } - } else { - Status::Idle - }; - - if status != self.last_status { - self.last_status = status; - cx.emit(status); - } - } - - pub fn search( - &self, - query: String, - limit: usize, - cx: &AppContext, - ) -> Task>> { - let (chunks_tx, chunks_rx) = channel::bounded(1024); - let mut worktree_scan_tasks = Vec::new(); - for worktree_index in self.worktree_indices.values() { - let worktree_index = worktree_index.clone(); - let chunks_tx = chunks_tx.clone(); - worktree_scan_tasks.push(cx.spawn(|cx| async move { - let index = match worktree_index { - WorktreeIndexHandle::Loading { index } => { - index.clone().await.map_err(|error| anyhow!(error))? 
- } - WorktreeIndexHandle::Loaded { index } => index.clone(), - }; - - index - .read_with(&cx, |index, cx| { - let worktree_id = index.worktree.read(cx).id(); - let db_connection = index.db_connection.clone(); - let db = index.db; - cx.background_executor().spawn(async move { - let txn = db_connection - .read_txn() - .context("failed to create read transaction")?; - let db_entries = db.iter(&txn).context("failed to iterate database")?; - for db_entry in db_entries { - let (_key, db_embedded_file) = db_entry?; - for chunk in db_embedded_file.chunks { - chunks_tx - .send((worktree_id, db_embedded_file.path.clone(), chunk)) - .await?; - } - } - anyhow::Ok(()) - }) - })? - .await - })); - } - drop(chunks_tx); - - let project = self.project.clone(); - let embedding_provider = self.embedding_provider.clone(); - cx.spawn(|cx| async move { - #[cfg(debug_assertions)] - let embedding_query_start = std::time::Instant::now(); - log::info!("Searching for {query}"); - - let query_embeddings = embedding_provider - .embed(&[TextToEmbed::new(&query)]) - .await?; - let query_embedding = query_embeddings - .into_iter() - .next() - .ok_or_else(|| anyhow!("no embedding for query"))?; - - let mut results_by_worker = Vec::new(); - for _ in 0..cx.background_executor().num_cpus() { - results_by_worker.push(Vec::::new()); - } - - #[cfg(debug_assertions)] - let search_start = std::time::Instant::now(); - - cx.background_executor() - .scoped(|cx| { - for results in results_by_worker.iter_mut() { - cx.spawn(async { - while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await { - let score = chunk.embedding.similarity(&query_embedding); - let ix = match results.binary_search_by(|probe| { - score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal) - }) { - Ok(ix) | Err(ix) => ix, - }; - results.insert( - ix, - WorktreeSearchResult { - worktree_id, - path: path.clone(), - range: chunk.chunk.range.clone(), - score, - }, - ); - results.truncate(limit); - } - }); - } - }) - .await; - - for scan_task in futures::future::join_all(worktree_scan_tasks).await { - scan_task.log_err(); - } - - project.read_with(&cx, |project, cx| { - let mut search_results = Vec::with_capacity(results_by_worker.len() * limit); - for worker_results in results_by_worker { - search_results.extend(worker_results.into_iter().filter_map(|result| { - Some(SearchResult { - worktree: project.worktree_for_id(result.worktree_id, cx)?, - path: result.path, - range: result.range, - score: result.score, - }) - })); - } - search_results.sort_unstable_by(|a, b| { - b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal) - }); - search_results.truncate(limit); - - #[cfg(debug_assertions)] - { - let search_elapsed = search_start.elapsed(); - log::debug!( - "searched {} entries in {:?}", - search_results.len(), - search_elapsed - ); - let embedding_query_elapsed = embedding_query_start.elapsed(); - log::debug!("embedding query took {:?}", embedding_query_elapsed); - } - - search_results + ) -> Option { + self.project_indices.get(project).map(|project_index| { + project_index.update(cx, |project_index, cx| { + project_index.remaining_summaries(cx) }) }) } - - #[cfg(test)] - pub fn path_count(&self, cx: &AppContext) -> Result { - let mut result = 0; - for worktree_index in self.worktree_indices.values() { - if let WorktreeIndexHandle::Loaded { index, .. 
} = worktree_index { - result += index.read(cx).path_count()?; - } - } - Ok(result) - } - - pub(crate) fn worktree_index( - &self, - worktree_id: WorktreeId, - cx: &AppContext, - ) -> Option> { - for index in self.worktree_indices.values() { - if let WorktreeIndexHandle::Loaded { index, .. } = index { - if index.read(cx).worktree.read(cx).id() == worktree_id { - return Some(index.clone()); - } - } - } - None - } - - pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec> { - let mut result = self - .worktree_indices - .values() - .filter_map(|index| { - if let WorktreeIndexHandle::Loaded { index, .. } = index { - Some(index.clone()) - } else { - None - } - }) - .collect::>(); - result.sort_by_key(|index| index.read(cx).worktree.read(cx).id()); - result - } -} - -pub struct SearchResult { - pub worktree: Model, - pub path: Arc, - pub range: Range, - pub score: f32, -} - -pub struct WorktreeSearchResult { - pub worktree_id: WorktreeId, - pub path: Arc, - pub range: Range, - pub score: f32, -} - -#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] -pub enum Status { - Idle, - Loading, - Scanning { remaining_count: NonZeroUsize }, -} - -impl EventEmitter for ProjectIndex {} - -struct WorktreeIndex { - worktree: Model, - db_connection: heed::Env, - db: heed::Database>, - language_registry: Arc, - fs: Arc, - embedding_provider: Arc, - entry_ids_being_indexed: Arc, - _index_entries: Task>, - _subscription: Subscription, -} - -impl WorktreeIndex { - pub fn load( - worktree: Model, - db_connection: heed::Env, - language_registry: Arc, - fs: Arc, - status_tx: channel::Sender<()>, - embedding_provider: Arc, - cx: &mut AppContext, - ) -> Task>> { - let worktree_abs_path = worktree.read(cx).abs_path(); - cx.spawn(|mut cx| async move { - let db = cx - .background_executor() - .spawn({ - let db_connection = db_connection.clone(); - async move { - let mut txn = db_connection.write_txn()?; - let db_name = worktree_abs_path.to_string_lossy(); - let db = db_connection.create_database(&mut txn, Some(&db_name))?; - txn.commit()?; - anyhow::Ok(db) - } - }) - .await?; - cx.new_model(|cx| { - Self::new( - worktree, - db_connection, - db, - status_tx, - language_registry, - fs, - embedding_provider, - cx, - ) - }) - }) - } - - #[allow(clippy::too_many_arguments)] - fn new( - worktree: Model, - db_connection: heed::Env, - db: heed::Database>, - status: channel::Sender<()>, - language_registry: Arc, - fs: Arc, - embedding_provider: Arc, - cx: &mut ModelContext, - ) -> Self { - let (updated_entries_tx, updated_entries_rx) = channel::unbounded(); - let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| { - if let worktree::Event::UpdatedEntries(update) = event { - _ = updated_entries_tx.try_send(update.clone()); - } - }); - - Self { - db_connection, - db, - worktree, - language_registry, - fs, - embedding_provider, - entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)), - _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)), - _subscription, - } - } - - async fn index_entries( - this: WeakModel, - updated_entries: channel::Receiver, - mut cx: AsyncAppContext, - ) -> Result<()> { - let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?; - index.await.log_err(); - - while let Ok(updated_entries) = updated_entries.recv().await { - let index = this.update(&mut cx, |this, cx| { - this.index_updated_entries(updated_entries, cx) - })?; - index.await.log_err(); - } - - Ok(()) - } - - fn 
index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future> { - let worktree = self.worktree.read(cx).snapshot(); - let worktree_abs_path = worktree.abs_path().clone(); - let scan = self.scan_entries(worktree, cx); - let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx); - let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx); - let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx); - async move { - futures::try_join!(scan.task, chunk.task, embed.task, persist)?; - Ok(()) - } - } - - fn index_updated_entries( - &self, - updated_entries: UpdatedEntriesSet, - cx: &AppContext, - ) -> impl Future> { - let worktree = self.worktree.read(cx).snapshot(); - let worktree_abs_path = worktree.abs_path().clone(); - let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx); - let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx); - let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx); - let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx); - async move { - futures::try_join!(scan.task, chunk.task, embed.task, persist)?; - Ok(()) - } - } - - fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries { - let (updated_entries_tx, updated_entries_rx) = channel::bounded(512); - let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128); - let db_connection = self.db_connection.clone(); - let db = self.db; - let entries_being_indexed = self.entry_ids_being_indexed.clone(); - let task = cx.background_executor().spawn(async move { - let txn = db_connection - .read_txn() - .context("failed to create read transaction")?; - let mut db_entries = db - .iter(&txn) - .context("failed to create iterator")? 
- .move_between_keys() - .peekable(); - - let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None; - for entry in worktree.files(false, 0) { - let entry_db_key = db_key_for_path(&entry.path); - - let mut saved_mtime = None; - while let Some(db_entry) = db_entries.peek() { - match db_entry { - Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) { - Ordering::Less => { - if let Some(deletion_range) = deletion_range.as_mut() { - deletion_range.1 = Bound::Included(db_path); - } else { - deletion_range = - Some((Bound::Included(db_path), Bound::Included(db_path))); - } - - db_entries.next(); - } - Ordering::Equal => { - if let Some(deletion_range) = deletion_range.take() { - deleted_entry_ranges_tx - .send(( - deletion_range.0.map(ToString::to_string), - deletion_range.1.map(ToString::to_string), - )) - .await?; - } - saved_mtime = db_embedded_file.mtime; - db_entries.next(); - break; - } - Ordering::Greater => { - break; - } - }, - Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?, - } - } - - if entry.mtime != saved_mtime { - let handle = entries_being_indexed.insert(entry.id); - updated_entries_tx.send((entry.clone(), handle)).await?; - } - } - - if let Some(db_entry) = db_entries.next() { - let (db_path, _) = db_entry?; - deleted_entry_ranges_tx - .send((Bound::Included(db_path.to_string()), Bound::Unbounded)) - .await?; - } - - Ok(()) - }); - - ScanEntries { - updated_entries: updated_entries_rx, - deleted_entry_ranges: deleted_entry_ranges_rx, - task, - } - } - - fn scan_updated_entries( - &self, - worktree: Snapshot, - updated_entries: UpdatedEntriesSet, - cx: &AppContext, - ) -> ScanEntries { - let (updated_entries_tx, updated_entries_rx) = channel::bounded(512); - let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128); - let entries_being_indexed = self.entry_ids_being_indexed.clone(); - let task = cx.background_executor().spawn(async move { - for (path, entry_id, status) in updated_entries.iter() { - match status { - project::PathChange::Added - | project::PathChange::Updated - | project::PathChange::AddedOrUpdated => { - if let Some(entry) = worktree.entry_for_id(*entry_id) { - if entry.is_file() { - let handle = entries_being_indexed.insert(entry.id); - updated_entries_tx.send((entry.clone(), handle)).await?; - } - } - } - project::PathChange::Removed => { - let db_path = db_key_for_path(path); - deleted_entry_ranges_tx - .send((Bound::Included(db_path.clone()), Bound::Included(db_path))) - .await?; - } - project::PathChange::Loaded => { - // Do nothing. 
- } - } - } - - Ok(()) - }); - - ScanEntries { - updated_entries: updated_entries_rx, - deleted_entry_ranges: deleted_entry_ranges_rx, - task, - } - } - - fn chunk_files( - &self, - worktree_abs_path: Arc, - entries: channel::Receiver<(Entry, IndexingEntryHandle)>, - cx: &AppContext, - ) -> ChunkFiles { - let language_registry = self.language_registry.clone(); - let fs = self.fs.clone(); - let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048); - let task = cx.spawn(|cx| async move { - cx.background_executor() - .scoped(|cx| { - for _ in 0..cx.num_cpus() { - cx.spawn(async { - while let Ok((entry, handle)) = entries.recv().await { - let entry_abs_path = worktree_abs_path.join(&entry.path); - let Some(text) = fs - .load(&entry_abs_path) - .await - .with_context(|| { - format!("failed to read path {entry_abs_path:?}") - }) - .log_err() - else { - continue; - }; - let language = language_registry - .language_for_file_path(&entry.path) - .await - .ok(); - let chunked_file = ChunkedFile { - chunks: chunk_text(&text, language.as_ref(), &entry.path), - handle, - path: entry.path, - mtime: entry.mtime, - text, - }; - - if chunked_files_tx.send(chunked_file).await.is_err() { - return; - } - } - }); - } - }) - .await; - Ok(()) - }); - - ChunkFiles { - files: chunked_files_rx, - task, - } - } - - fn embed_files( - embedding_provider: Arc, - chunked_files: channel::Receiver, - cx: &AppContext, - ) -> EmbedFiles { - let embedding_provider = embedding_provider.clone(); - let (embedded_files_tx, embedded_files_rx) = channel::bounded(512); - let task = cx.background_executor().spawn(async move { - let mut chunked_file_batches = - chunked_files.chunks_timeout(512, Duration::from_secs(2)); - while let Some(chunked_files) = chunked_file_batches.next().await { - // View the batch of files as a vec of chunks - // Flatten out to a vec of chunks that we can subdivide into batch sized pieces - // Once those are done, reassemble them back into the files in which they belong - // If any embeddings fail for a file, the entire file is discarded - - let chunks: Vec = chunked_files - .iter() - .flat_map(|file| { - file.chunks.iter().map(|chunk| TextToEmbed { - text: &file.text[chunk.range.clone()], - digest: chunk.digest, - }) - }) - .collect::>(); - - let mut embeddings: Vec> = Vec::new(); - for embedding_batch in chunks.chunks(embedding_provider.batch_size()) { - if let Some(batch_embeddings) = - embedding_provider.embed(embedding_batch).await.log_err() - { - if batch_embeddings.len() == embedding_batch.len() { - embeddings.extend(batch_embeddings.into_iter().map(Some)); - continue; - } - log::error!( - "embedding provider returned unexpected embedding count {}, expected {}", - batch_embeddings.len(), embedding_batch.len() - ); - } - - embeddings.extend(iter::repeat(None).take(embedding_batch.len())); - } - - let mut embeddings = embeddings.into_iter(); - for chunked_file in chunked_files { - let mut embedded_file = EmbeddedFile { - path: chunked_file.path, - mtime: chunked_file.mtime, - chunks: Vec::new(), - }; - - let mut embedded_all_chunks = true; - for (chunk, embedding) in - chunked_file.chunks.into_iter().zip(embeddings.by_ref()) - { - if let Some(embedding) = embedding { - embedded_file - .chunks - .push(EmbeddedChunk { chunk, embedding }); - } else { - embedded_all_chunks = false; - } - } - - if embedded_all_chunks { - embedded_files_tx - .send((embedded_file, chunked_file.handle)) - .await?; - } - } - } - Ok(()) - }); - - EmbedFiles { - files: embedded_files_rx, - task, - } - } - - fn 
persist_embeddings( - &self, - mut deleted_entry_ranges: channel::Receiver<(Bound, Bound)>, - embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>, - cx: &AppContext, - ) -> Task> { - let db_connection = self.db_connection.clone(); - let db = self.db; - cx.background_executor().spawn(async move { - while let Some(deletion_range) = deleted_entry_ranges.next().await { - let mut txn = db_connection.write_txn()?; - let start = deletion_range.0.as_ref().map(|start| start.as_str()); - let end = deletion_range.1.as_ref().map(|end| end.as_str()); - log::debug!("deleting embeddings in range {:?}", &(start, end)); - db.delete_range(&mut txn, &(start, end))?; - txn.commit()?; - } - - let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2)); - while let Some(embedded_files) = embedded_files.next().await { - let mut txn = db_connection.write_txn()?; - for (file, _) in &embedded_files { - log::debug!("saving embedding for file {:?}", file.path); - let key = db_key_for_path(&file.path); - db.put(&mut txn, &key, file)?; - } - txn.commit()?; - - drop(embedded_files); - log::debug!("committed"); - } - - Ok(()) - }) - } - - fn paths(&self, cx: &AppContext) -> Task>>> { - let connection = self.db_connection.clone(); - let db = self.db; - cx.background_executor().spawn(async move { - let tx = connection - .read_txn() - .context("failed to create read transaction")?; - let result = db - .iter(&tx)? - .map(|entry| Ok(entry?.1.path.clone())) - .collect::>>>(); - drop(tx); - result - }) - } - - fn chunks_for_path( - &self, - path: Arc, - cx: &AppContext, - ) -> Task>> { - let connection = self.db_connection.clone(); - let db = self.db; - cx.background_executor().spawn(async move { - let tx = connection - .read_txn() - .context("failed to create read transaction")?; - Ok(db - .get(&tx, &db_key_for_path(&path))? - .ok_or_else(|| anyhow!("no such path"))? - .chunks - .clone()) - }) - } - - #[cfg(test)] - fn path_count(&self) -> Result { - let txn = self - .db_connection - .read_txn() - .context("failed to create read transaction")?; - Ok(self.db.len(&txn)?) - } -} - -struct ScanEntries { - updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>, - deleted_entry_ranges: channel::Receiver<(Bound, Bound)>, - task: Task>, -} - -struct ChunkFiles { - files: channel::Receiver, - task: Task>, -} - -struct ChunkedFile { - pub path: Arc, - pub mtime: Option, - pub handle: IndexingEntryHandle, - pub text: String, - pub chunks: Vec, -} - -struct EmbedFiles { - files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>, - task: Task>, -} - -#[derive(Debug, Serialize, Deserialize)] -struct EmbeddedFile { - path: Arc, - mtime: Option, - chunks: Vec, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -struct EmbeddedChunk { - chunk: Chunk, - embedding: Embedding, -} - -/// The set of entries that are currently being indexed. -struct IndexingEntrySet { - entry_ids: Mutex>, - tx: channel::Sender<()>, -} - -/// When dropped, removes the entry from the set of entries that are being indexed. 
-#[derive(Clone)] -struct IndexingEntryHandle { - entry_id: ProjectEntryId, - set: Weak, -} - -impl IndexingEntrySet { - fn new(tx: channel::Sender<()>) -> Self { - Self { - entry_ids: Default::default(), - tx, - } - } - - fn insert(self: &Arc, entry_id: ProjectEntryId) -> IndexingEntryHandle { - self.entry_ids.lock().insert(entry_id); - self.tx.send_blocking(()).ok(); - IndexingEntryHandle { - entry_id, - set: Arc::downgrade(self), - } - } - - pub fn len(&self) -> usize { - self.entry_ids.lock().len() - } -} - -impl Drop for IndexingEntryHandle { - fn drop(&mut self) { - if let Some(set) = self.set.upgrade() { - set.tx.send_blocking(()).ok(); - set.entry_ids.lock().remove(&self.entry_id); - } - } -} - -fn db_key_for_path(path: &Arc) -> String { - path.to_string_lossy().replace('/', "\0") } #[cfg(test)] mod tests { use super::*; + use anyhow::anyhow; + use chunking::Chunk; + use embedding_index::{ChunkedFile, EmbeddingIndex}; + use feature_flags::FeatureFlagAppExt; + use fs::FakeFs; use futures::{future::BoxFuture, FutureExt}; use gpui::TestAppContext; + use indexing::IndexingEntrySet; use language::language_settings::AllLanguageSettings; - use project::Project; + use project::{Project, ProjectEntryId}; + use serde_json::json; use settings::SettingsStore; + use smol::{channel, stream::StreamExt}; use std::{future, path::Path, sync::Arc}; fn init_test(cx: &mut TestAppContext) { + env_logger::try_init().ok(); + cx.update(|cx| { let store = SettingsStore::test(cx); cx.set_global(store); language::init(cx); + cx.update_flags(false, vec![]); Project::init_settings(cx); SettingsStore::update(cx, |store, cx| { store.update_user_settings::(cx, |_| {}); @@ -1100,7 +190,7 @@ mod tests { let temp_dir = tempfile::tempdir().unwrap(); - let mut semantic_index = SemanticIndex::new( + let mut semantic_index = SemanticDb::new( temp_dir.path().into(), Arc::new(TestEmbeddingProvider::new(16, |text| { let mut embedding = vec![0f32; 2]; @@ -1124,26 +214,57 @@ mod tests { .await .unwrap(); - let project_path = Path::new("./fixture"); + let fs = FakeFs::new(cx.executor()); + let project_path = Path::new("/fake_project"); - let project = cx - .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await }) - .await; + fs.insert_tree( + project_path, + json!({ + "fixture": { + "main.rs": include_str!("../fixture/main.rs"), + "needle.md": include_str!("../fixture/needle.md"), + } + }), + ) + .await; + + let project = Project::test(fs, [project_path], cx).await; cx.update(|cx| { let language_registry = project.read(cx).languages().clone(); let node_runtime = project.read(cx).node_runtime().unwrap().clone(); languages::init(language_registry, node_runtime, cx); + + // Manually create and insert the ProjectIndex + let project_index = cx.new_model(|cx| { + ProjectIndex::new( + project.clone(), + semantic_index.db_connection.clone(), + semantic_index.embedding_provider.clone(), + cx, + ) + }); + semantic_index + .project_indices + .insert(project.downgrade(), project_index); }); - let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx)); + let project_index = cx + .update(|_cx| { + semantic_index + .project_indices + .get(&project.downgrade()) + .cloned() + }) + .unwrap(); - while project_index - .read_with(cx, |index, cx| index.path_count(cx)) + cx.run_until_parked(); + while cx + .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx)) .unwrap() - == 0 + > 0 { - project_index.next_event(cx).await; + cx.run_until_parked(); } let results = cx @@ -1155,7 
+276,11 @@ mod tests { .await .unwrap(); - assert!(results.len() > 1, "should have found some results"); + assert!( + results.len() > 1, + "should have found some results, but only found {:?}", + results + ); for result in &results { println!("result: {:?}", result.path); @@ -1165,7 +290,7 @@ mod tests { // Find result that is greater than 0.5 let search_result = results.iter().find(|result| result.score > 0.9).unwrap(); - assert_eq!(search_result.path.to_string_lossy(), "needle.md"); + assert_eq!(search_result.path.to_string_lossy(), "fixture/needle.md"); let content = cx .update(|cx| { @@ -1236,7 +361,7 @@ mod tests { chunked_files_tx.close(); let embed_files_task = - cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, cx)); + cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx)); embed_files_task.task.await.unwrap(); let mut embedded_files_rx = embed_files_task.files; diff --git a/crates/semantic_index/src/summary_backlog.rs b/crates/semantic_index/src/summary_backlog.rs new file mode 100644 index 0000000000..c6d8e33a45 --- /dev/null +++ b/crates/semantic_index/src/summary_backlog.rs @@ -0,0 +1,48 @@ +use collections::HashMap; +use std::{path::Path, sync::Arc, time::SystemTime}; + +const MAX_FILES_BEFORE_RESUMMARIZE: usize = 4; +const MAX_BYTES_BEFORE_RESUMMARIZE: u64 = 1_000_000; // 1 MB + +#[derive(Default, Debug)] +pub struct SummaryBacklog { + /// Key: path to a file that needs summarization, but that we haven't summarized yet. Value: that file's size on disk, in bytes, and its mtime. + files: HashMap, (u64, Option)>, + /// Cache of the sum of all values in `files`, so we don't have to traverse the whole map to check if we're over the byte limit. + total_bytes: u64, +} + +impl SummaryBacklog { + /// Store the given path in the backlog, along with how many bytes are in it. + pub fn insert(&mut self, path: Arc, bytes_on_disk: u64, mtime: Option) { + let (prev_bytes, _) = self + .files + .insert(path, (bytes_on_disk, mtime)) + .unwrap_or_default(); // Default to 0 prev_bytes + + // Update the cached total by subtracting out the old amount and adding the new one. + self.total_bytes = self.total_bytes - prev_bytes + bytes_on_disk; + } + + /// Returns true if the total number of bytes in the backlog exceeds a predefined threshold. + pub fn needs_drain(&self) -> bool { + self.files.len() > MAX_FILES_BEFORE_RESUMMARIZE || + // The whole purpose of the cached total_bytes is to make this comparison cheap. + // Otherwise we'd have to traverse the entire dictionary every time we wanted this answer. + self.total_bytes > MAX_BYTES_BEFORE_RESUMMARIZE + } + + /// Remove all the entries in the backlog and return the file paths as an iterator. 
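+    ///
+    /// Draining also resets the cached byte total. A hypothetical usage:
+    ///
+    /// ```ignore
+    /// backlog.insert(path.clone(), 2_000_000, mtime); // exceeds MAX_BYTES_BEFORE_RESUMMARIZE
+    /// assert!(backlog.needs_drain());
+    /// let pending: Vec<_> = backlog.drain().collect(); // yields (path, mtime) pairs
+    /// assert_eq!(backlog.len(), 0);
+    /// ```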
+    #[allow(clippy::needless_lifetimes)] // Clippy thinks this 'a can be elided, but eliding it gives a compile error
+    pub fn drain<'a>(&'a mut self) -> impl Iterator<Item = (Arc<Path>, Option<SystemTime>)> + 'a {
+        self.total_bytes = 0;
+
+        self.files
+            .drain()
+            .map(|(path, (_size, mtime))| (path, mtime))
+    }
+
+    pub fn len(&self) -> usize {
+        self.files.len()
+    }
+}
diff --git a/crates/semantic_index/src/summary_index.rs b/crates/semantic_index/src/summary_index.rs
new file mode 100644
index 0000000000..08f25ae028
--- /dev/null
+++ b/crates/semantic_index/src/summary_index.rs
@@ -0,0 +1,693 @@
+use anyhow::{anyhow, Context as _, Result};
+use arrayvec::ArrayString;
+use fs::Fs;
+use futures::{stream::StreamExt, TryFutureExt};
+use futures_batch::ChunksTimeoutStreamExt;
+use gpui::{AppContext, Model, Task};
+use heed::{
+    types::{SerdeBincode, Str},
+    RoTxn,
+};
+use language_model::{
+    LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
+    LanguageModelRequestMessage, Role,
+};
+use log;
+use parking_lot::Mutex;
+use project::{Entry, UpdatedEntriesSet, Worktree};
+use serde::{Deserialize, Serialize};
+use smol::channel;
+use std::{
+    future::Future,
+    path::Path,
+    sync::Arc,
+    time::{Duration, Instant, SystemTime},
+};
+use util::ResultExt;
+use worktree::Snapshot;
+
+use crate::{indexing::IndexingEntrySet, summary_backlog::SummaryBacklog};
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct FileSummary {
+    pub filename: String,
+    pub summary: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct UnsummarizedFile {
+    // Path to the file on disk
+    path: Arc<Path>,
+    // The mtime of the file on disk
+    mtime: Option<SystemTime>,
+    // BLAKE3 hash of the source file's contents
+    digest: Blake3Digest,
+    // The source file's contents
+    contents: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct SummarizedFile {
+    // Path to the file on disk
+    path: String,
+    // The mtime of the file on disk
+    mtime: Option<SystemTime>,
+    // BLAKE3 hash of the source file's contents
+    digest: Blake3Digest,
+    // The LLM's summary of the file's contents
+    summary: String,
+}
+
+/// This is what blake3's to_hex() method returns - see https://docs.rs/blake3/1.5.3/src/blake3/lib.rs.html#246
+pub type Blake3Digest = ArrayString<{ blake3::OUT_LEN * 2 }>;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FileDigest {
+    pub mtime: Option<SystemTime>,
+    pub digest: Blake3Digest,
+}
+
+struct NeedsSummary {
+    files: channel::Receiver<UnsummarizedFile>,
+    task: Task<Result<()>>,
+}
+
+struct SummarizeFiles {
+    files: channel::Receiver<SummarizedFile>,
+    task: Task<Result<()>>,
+}
+
+pub struct SummaryIndex {
+    worktree: Model<Worktree>,
+    fs: Arc<dyn Fs>,
+    db_connection: heed::Env,
+    file_digest_db: heed::Database<Str, SerdeBincode<FileDigest>>, // Key: file path. Val: BLAKE3 digest of its contents.
+    summary_db: heed::Database<SerdeBincode<Blake3Digest>, Str>, // Key: BLAKE3 digest of a file's contents. Val: LLM summary of those contents.
+    backlog: Arc<Mutex<SummaryBacklog>>,
+    _entry_ids_being_indexed: Arc<IndexingEntrySet>, // TODO can this be removed?
+} + +struct Backlogged { + paths_to_digest: channel::Receiver, Option)>>, + task: Task>, +} + +struct MightNeedSummaryFiles { + files: channel::Receiver, + task: Task>, +} + +impl SummaryIndex { + pub fn new( + worktree: Model, + fs: Arc, + db_connection: heed::Env, + file_digest_db: heed::Database>, + summary_db: heed::Database, Str>, + _entry_ids_being_indexed: Arc, + ) -> Self { + Self { + worktree, + fs, + db_connection, + file_digest_db, + summary_db, + _entry_ids_being_indexed, + backlog: Default::default(), + } + } + + pub fn file_digest_db(&self) -> heed::Database> { + self.file_digest_db + } + + pub fn summary_db(&self) -> heed::Database, Str> { + self.summary_db + } + + pub fn index_entries_changed_on_disk( + &self, + is_auto_available: bool, + cx: &AppContext, + ) -> impl Future> { + let start = Instant::now(); + let backlogged; + let digest; + let needs_summary; + let summaries; + let persist; + + if is_auto_available { + let worktree = self.worktree.read(cx).snapshot(); + let worktree_abs_path = worktree.abs_path().clone(); + + backlogged = self.scan_entries(worktree, cx); + digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx); + needs_summary = self.check_summary_cache(digest.files, cx); + summaries = self.summarize_files(needs_summary.files, cx); + persist = self.persist_summaries(summaries.files, cx); + } else { + // This feature is only staff-shipped, so make the rest of these no-ops. + backlogged = Backlogged { + paths_to_digest: channel::unbounded().1, + task: Task::ready(Ok(())), + }; + digest = MightNeedSummaryFiles { + files: channel::unbounded().1, + task: Task::ready(Ok(())), + }; + needs_summary = NeedsSummary { + files: channel::unbounded().1, + task: Task::ready(Ok(())), + }; + summaries = SummarizeFiles { + files: channel::unbounded().1, + task: Task::ready(Ok(())), + }; + persist = Task::ready(Ok(())); + } + + async move { + futures::try_join!( + backlogged.task, + digest.task, + needs_summary.task, + summaries.task, + persist + )?; + + if is_auto_available { + log::info!( + "Summarizing everything that changed on disk took {:?}", + start.elapsed() + ); + } + + Ok(()) + } + } + + pub fn index_updated_entries( + &mut self, + updated_entries: UpdatedEntriesSet, + is_auto_available: bool, + cx: &AppContext, + ) -> impl Future> { + let start = Instant::now(); + let backlogged; + let digest; + let needs_summary; + let summaries; + let persist; + + if is_auto_available { + let worktree = self.worktree.read(cx).snapshot(); + let worktree_abs_path = worktree.abs_path().clone(); + + backlogged = self.scan_updated_entries(worktree, updated_entries.clone(), cx); + digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx); + needs_summary = self.check_summary_cache(digest.files, cx); + summaries = self.summarize_files(needs_summary.files, cx); + persist = self.persist_summaries(summaries.files, cx); + } else { + // This feature is only staff-shipped, so make the rest of these no-ops. 
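+            // Each stage below keeps only the receiving half of a fresh channel;
+            // the sender is dropped immediately, so every downstream task sees a
+            // closed, empty stream and completes without doing any work.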
+ backlogged = Backlogged { + paths_to_digest: channel::unbounded().1, + task: Task::ready(Ok(())), + }; + digest = MightNeedSummaryFiles { + files: channel::unbounded().1, + task: Task::ready(Ok(())), + }; + needs_summary = NeedsSummary { + files: channel::unbounded().1, + task: Task::ready(Ok(())), + }; + summaries = SummarizeFiles { + files: channel::unbounded().1, + task: Task::ready(Ok(())), + }; + persist = Task::ready(Ok(())); + } + + async move { + futures::try_join!( + backlogged.task, + digest.task, + needs_summary.task, + summaries.task, + persist + )?; + + log::info!("Summarizing updated entries took {:?}", start.elapsed()); + + Ok(()) + } + } + + fn check_summary_cache( + &self, + mut might_need_summary: channel::Receiver, + cx: &AppContext, + ) -> NeedsSummary { + let db_connection = self.db_connection.clone(); + let db = self.summary_db; + let (needs_summary_tx, needs_summary_rx) = channel::bounded(512); + let task = cx.background_executor().spawn(async move { + while let Some(file) = might_need_summary.next().await { + let tx = db_connection + .read_txn() + .context("Failed to create read transaction for checking which hashes are in summary cache")?; + + match db.get(&tx, &file.digest) { + Ok(opt_answer) => { + if opt_answer.is_none() { + // It's not in the summary cache db, so we need to summarize it. + log::debug!("File {:?} (digest {:?}) was NOT in the db cache and needs to be resummarized.", file.path.display(), &file.digest); + needs_summary_tx.send(file).await?; + } else { + log::debug!("File {:?} (digest {:?}) was in the db cache and does not need to be resummarized.", file.path.display(), &file.digest); + } + } + Err(err) => { + log::error!("Reading from the summaries database failed: {:?}", err); + } + } + } + + Ok(()) + }); + + NeedsSummary { + files: needs_summary_rx, + task, + } + } + + fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> Backlogged { + let (tx, rx) = channel::bounded(512); + let db_connection = self.db_connection.clone(); + let digest_db = self.file_digest_db; + let backlog = Arc::clone(&self.backlog); + let task = cx.background_executor().spawn(async move { + let txn = db_connection + .read_txn() + .context("failed to create read transaction")?; + + for entry in worktree.files(false, 0) { + let needs_summary = + Self::add_to_backlog(Arc::clone(&backlog), digest_db, &txn, entry); + + if !needs_summary.is_empty() { + tx.send(needs_summary).await?; + } + } + + // TODO delete db entries for deleted files + + Ok(()) + }); + + Backlogged { + paths_to_digest: rx, + task, + } + } + + fn add_to_backlog( + backlog: Arc>, + digest_db: heed::Database>, + txn: &RoTxn<'_>, + entry: &Entry, + ) -> Vec<(Arc, Option)> { + let entry_db_key = db_key_for_path(&entry.path); + + match digest_db.get(&txn, &entry_db_key) { + Ok(opt_saved_digest) => { + // The file path is the same, but the mtime is different. (Or there was no mtime.) + // It needs updating, so add it to the backlog! Then, if the backlog is full, drain it and summarize its contents. 
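+                // Note that `None` mtimes compare equal only to `None`, so gaining
+                // or losing an mtime also counts as a change worth re-summarizing.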
+ if entry.mtime != opt_saved_digest.and_then(|digest| digest.mtime) { + let mut backlog = backlog.lock(); + + log::info!( + "Inserting {:?} ({:?} bytes) into backlog", + &entry.path, + entry.size, + ); + backlog.insert(Arc::clone(&entry.path), entry.size, entry.mtime); + + if backlog.needs_drain() { + log::info!("Draining summary backlog..."); + return backlog.drain().collect(); + } + } + } + Err(err) => { + log::error!( + "Error trying to get file digest db entry {:?}: {:?}", + &entry_db_key, + err + ); + } + } + + Vec::new() + } + + fn scan_updated_entries( + &self, + worktree: Snapshot, + updated_entries: UpdatedEntriesSet, + cx: &AppContext, + ) -> Backlogged { + log::info!("Scanning for updated entries that might need summarization..."); + let (tx, rx) = channel::bounded(512); + // let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128); + let db_connection = self.db_connection.clone(); + let digest_db = self.file_digest_db; + let backlog = Arc::clone(&self.backlog); + let task = cx.background_executor().spawn(async move { + let txn = db_connection + .read_txn() + .context("failed to create read transaction")?; + + for (path, entry_id, status) in updated_entries.iter() { + match status { + project::PathChange::Loaded + | project::PathChange::Added + | project::PathChange::Updated + | project::PathChange::AddedOrUpdated => { + if let Some(entry) = worktree.entry_for_id(*entry_id) { + if entry.is_file() { + let needs_summary = Self::add_to_backlog( + Arc::clone(&backlog), + digest_db, + &txn, + entry, + ); + + if !needs_summary.is_empty() { + tx.send(needs_summary).await?; + } + } + } + } + project::PathChange::Removed => { + let _db_path = db_key_for_path(path); + // TODO delete db entries for deleted files + // deleted_entry_ranges_tx + // .send((Bound::Included(db_path.clone()), Bound::Included(db_path))) + // .await?; + } + } + } + + Ok(()) + }); + + Backlogged { + paths_to_digest: rx, + // deleted_entry_ranges: deleted_entry_ranges_rx, + task, + } + } + + fn digest_files( + &self, + paths: channel::Receiver, Option)>>, + worktree_abs_path: Arc, + cx: &AppContext, + ) -> MightNeedSummaryFiles { + let fs = self.fs.clone(); + let (rx, tx) = channel::bounded(2048); + let task = cx.spawn(|cx| async move { + cx.background_executor() + .scoped(|cx| { + for _ in 0..cx.num_cpus() { + cx.spawn(async { + while let Ok(pairs) = paths.recv().await { + // Note: we could process all these files concurrently if desired. Might or might not speed things up. + for (path, mtime) in pairs { + let entry_abs_path = worktree_abs_path.join(&path); + + // Load the file's contents and compute its hash digest. + let unsummarized_file = { + let Some(contents) = fs + .load(&entry_abs_path) + .await + .with_context(|| { + format!("failed to read path {entry_abs_path:?}") + }) + .log_err() + else { + continue; + }; + + let digest = { + let mut hasher = blake3::Hasher::new(); + // Incorporate both the (relative) file path as well as the contents of the file into the hash. + // This is because in some languages and frameworks, identical files can do different things + // depending on their paths (e.g. Rails controllers). It's also why we send the path to the model. 
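+                                        // A consequence: moving a file changes its
+                                        // digest even when its contents are identical.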
+                                            hasher.update(path.display().to_string().as_bytes());
+                                            hasher.update(contents.as_bytes());
+                                            hasher.finalize().to_hex()
+                                        };
+
+                                        UnsummarizedFile {
+                                            digest,
+                                            contents,
+                                            path,
+                                            mtime,
+                                        }
+                                    };
+
+                                    if let Err(err) = tx
+                                        .send(unsummarized_file)
+                                        .map_err(|error| anyhow!(error))
+                                        .await
+                                    {
+                                        log::error!("Error: {:?}", err);
+
+                                        return;
+                                    }
+                                }
+                            }
+                        });
+                    }
+                })
+                .await;
+            Ok(())
+        });
+
+        MightNeedSummaryFiles { files: rx, task }
+    }
+
+    fn summarize_files(
+        &self,
+        mut unsummarized_files: channel::Receiver<UnsummarizedFile>,
+        cx: &AppContext,
+    ) -> SummarizeFiles {
+        let (summarized_tx, summarized_rx) = channel::bounded(512);
+        let task = cx.spawn(|cx| async move {
+            while let Some(file) = unsummarized_files.next().await {
+                log::debug!("Summarizing {:?}", file);
+                let summary = cx
+                    .update(|cx| Self::summarize_code(&file.contents, &file.path, cx))?
+                    .await
+                    .unwrap_or_else(|err| {
+                        // Log a warning because we'll continue anyway.
+                        // In the future, we may want to try splitting it up into multiple requests and concatenating the summaries,
+                        // but this might give bad summaries due to cutting off source code files in the middle.
+                        log::warn!("Failed to summarize {} - {:?}", file.path.display(), err);
+
+                        String::new()
+                    });
+
+                // Note that the summary could be empty because of an error talking to a cloud provider,
+                // e.g. because the context limit was exceeded. In that case, we return Ok(String::new()).
+                if !summary.is_empty() {
+                    summarized_tx
+                        .send(SummarizedFile {
+                            path: file.path.display().to_string(),
+                            digest: file.digest,
+                            summary,
+                            mtime: file.mtime,
+                        })
+                        .await?
+                }
+            }
+
+            Ok(())
+        });
+
+        SummarizeFiles {
+            files: summarized_rx,
+            task,
+        }
+    }
+
+    fn summarize_code(
+        code: &str,
+        path: &Path,
+        cx: &AppContext,
+    ) -> impl Future<Output = Result<String>> {
+        let start = Instant::now();
+        let (summary_model_id, use_cache): (LanguageModelId, bool) = (
+            "Qwen/Qwen2-7B-Instruct".to_string().into(), // TODO read this from the user's settings.
+            false, // qwen2 doesn't have a cache, but we should probably infer this from the model
+        );
+        let Some(model) = LanguageModelRegistry::read_global(cx)
+            .available_models(cx)
+            .find(|model| model.id() == summary_model_id)
+        else {
+            return cx.background_executor().spawn(async move {
+                Err(anyhow!("Couldn't find the preferred summarization model ({:?}) in the language registry's available models", summary_model_id))
+            });
+        };
+        let utf8_path = path.to_string_lossy();
+        const PROMPT_BEFORE_CODE: &str = "Summarize what the code in this file does in 3 sentences, using no newlines or bullet points in the summary:";
+        let prompt = format!("{PROMPT_BEFORE_CODE}\n{utf8_path}:\n{code}");
+
+        log::debug!(
+            "Summarizing code by sending this prompt to {:?}: {:?}",
+            model.name(),
+            &prompt
+        );
+
+        let request = LanguageModelRequest {
+            messages: vec![LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![prompt.into()],
+                cache: use_cache,
+            }],
+            tools: Vec::new(),
+            stop: Vec::new(),
+            temperature: 1.0,
+        };
+
+        let code_len = code.len();
+        cx.spawn(|cx| async move {
+            let stream = model.stream_completion(request, &cx);
+            cx.background_executor()
+                .spawn(async move {
+                    let answer: String = stream
+                        .await?
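+                        // Keep only the text chunks from the streamed response,
+                        // ignoring any non-text completion events (e.g. stop events).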
+                        .filter_map(|event| async {
+                            if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
+                                Some(text)
+                            } else {
+                                None
+                            }
+                        })
+                        .collect()
+                        .await;
+
+                    log::info!(
+                        "It took {:?} to summarize {:?} bytes of code.",
+                        start.elapsed(),
+                        code_len
+                    );
+
+                    log::debug!("Summary was: {:?}", &answer);
+
+                    Ok(answer)
+                })
+                .await
+
+            // TODO if summarization failed, put it back in the backlog!
+        })
+    }
+
+    fn persist_summaries(
+        &self,
+        summaries: channel::Receiver<SummarizedFile>,
+        cx: &AppContext,
+    ) -> Task<Result<()>> {
+        let db_connection = self.db_connection.clone();
+        let digest_db = self.file_digest_db;
+        let summary_db = self.summary_db;
+        cx.background_executor().spawn(async move {
+            let mut summaries = summaries.chunks_timeout(4096, Duration::from_secs(2));
+            while let Some(summaries) = summaries.next().await {
+                let mut txn = db_connection.write_txn()?;
+                for file in &summaries {
+                    log::debug!(
+                        "Saving summary of {:?} - which is {} bytes of summary for content digest {:?}",
+                        &file.path,
+                        file.summary.len(),
+                        file.digest
+                    );
+                    digest_db.put(
+                        &mut txn,
+                        // Encode the key the same way db_key_for_path does for the
+                        // lookups in add_to_backlog, so the saved digest is actually
+                        // found on the next scan.
+                        &file.path.replace('/', "\0"),
+                        &FileDigest {
+                            mtime: file.mtime,
+                            digest: file.digest,
+                        },
+                    )?;
+                    summary_db.put(&mut txn, &file.digest, &file.summary)?;
+                }
+                txn.commit()?;
+
+                drop(summaries);
+                log::debug!("committed summaries");
+            }
+
+            Ok(())
+        })
+    }
+
+    /// Empty out the backlog of files that haven't been resummarized, and resummarize them immediately.
+    pub(crate) fn flush_backlog(
+        &self,
+        worktree_abs_path: Arc<Path>,
+        cx: &AppContext,
+    ) -> impl Future<Output = Result<()>> {
+        let start = Instant::now();
+        let backlogged = {
+            let (tx, rx) = channel::bounded(512);
+            let needs_summary: Vec<(Arc<Path>, Option<SystemTime>)> = {
+                let mut backlog = self.backlog.lock();
+
+                backlog.drain().collect()
+            };
+
+            let task = cx.background_executor().spawn(async move {
+                tx.send(needs_summary).await?;
+                Ok(())
+            });
+
+            Backlogged {
+                paths_to_digest: rx,
+                task,
+            }
+        };
+
+        let digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
+        let needs_summary = self.check_summary_cache(digest.files, cx);
+        let summaries = self.summarize_files(needs_summary.files, cx);
+        let persist = self.persist_summaries(summaries.files, cx);
+
+        async move {
+            futures::try_join!(
+                backlogged.task,
+                digest.task,
+                needs_summary.task,
+                summaries.task,
+                persist
+            )?;
+
+            log::info!("Summarizing backlogged entries took {:?}", start.elapsed());
+
+            Ok(())
+        }
+    }
+
+    pub(crate) fn backlog_len(&self) -> usize {
+        self.backlog.lock().len()
+    }
+}
+
+fn db_key_for_path(path: &Arc<Path>) -> String {
+    path.to_string_lossy().replace('/', "\0")
+}
diff --git a/crates/semantic_index/src/worktree_index.rs b/crates/semantic_index/src/worktree_index.rs
new file mode 100644
index 0000000000..7ca5a49619
--- /dev/null
+++ b/crates/semantic_index/src/worktree_index.rs
@@ -0,0 +1,217 @@
+use crate::embedding::EmbeddingProvider;
+use crate::embedding_index::EmbeddingIndex;
+use crate::indexing::IndexingEntrySet;
+use crate::summary_index::SummaryIndex;
+use anyhow::Result;
+use feature_flags::{AutoCommand, FeatureFlagAppExt};
+use fs::Fs;
+use futures::future::Shared;
+use gpui::{
+    AppContext, AsyncAppContext, Context, Model, ModelContext, Subscription, Task, WeakModel,
+};
+use language::LanguageRegistry;
+use log;
+use project::{UpdatedEntriesSet, Worktree};
+use smol::channel;
+use std::sync::Arc;
+use util::ResultExt;
+
+#[derive(Clone)]
+pub enum WorktreeIndexHandle {
+    Loading {
+        index: Shared<Task<Result<Model<WorktreeIndex>, Arc<anyhow::Error>>>>,
+    },
+    Loaded {
+        index: Model<WorktreeIndex>,
+    },
+}
+
+pub struct WorktreeIndex {
+    worktree: Model<Worktree>,
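+    /// Shared LMDB environment (via heed) in which both the embedding index
+    /// and the summary index keep their databases.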
+    db_connection: heed::Env,
+    embedding_index: EmbeddingIndex,
+    summary_index: SummaryIndex,
+    entry_ids_being_indexed: Arc<IndexingEntrySet>,
+    _index_entries: Task<Result<()>>,
+    _subscription: Subscription,
+}
+
+impl WorktreeIndex {
+    pub fn load(
+        worktree: Model<Worktree>,
+        db_connection: heed::Env,
+        language_registry: Arc<LanguageRegistry>,
+        fs: Arc<dyn Fs>,
+        status_tx: channel::Sender<()>,
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        cx: &mut AppContext,
+    ) -> Task<Result<Model<WorktreeIndex>>> {
+        let worktree_for_index = worktree.clone();
+        let worktree_for_summary = worktree.clone();
+        let worktree_abs_path = worktree.read(cx).abs_path();
+        let embedding_fs = Arc::clone(&fs);
+        let summary_fs = fs;
+        cx.spawn(|mut cx| async move {
+            let entries_being_indexed = Arc::new(IndexingEntrySet::new(status_tx));
+            let (embedding_index, summary_index) = cx
+                .background_executor()
+                .spawn({
+                    let entries_being_indexed = Arc::clone(&entries_being_indexed);
+                    let db_connection = db_connection.clone();
+                    async move {
+                        let mut txn = db_connection.write_txn()?;
+                        let embedding_index = {
+                            let db_name = worktree_abs_path.to_string_lossy();
+                            let db = db_connection.create_database(&mut txn, Some(&db_name))?;
+
+                            EmbeddingIndex::new(
+                                worktree_for_index,
+                                embedding_fs,
+                                db_connection.clone(),
+                                db,
+                                language_registry,
+                                embedding_provider,
+                                Arc::clone(&entries_being_indexed),
+                            )
+                        };
+                        let summary_index = {
+                            let file_digest_db = {
+                                let db_name =
+                                    // Prepend something that wouldn't be found at the beginning of an
+                                    // absolute path, so we don't get db key namespace conflicts with
+                                    // embeddings, which use the abs path as a key.
+                                    format!("digests-{}", worktree_abs_path.to_string_lossy());
+                                db_connection.create_database(&mut txn, Some(&db_name))?
+                            };
+                            let summary_db = {
+                                let db_name =
+                                    // Prepend something that wouldn't be found at the beginning of an
+                                    // absolute path, so we don't get db key namespace conflicts with
+                                    // embeddings, which use the abs path as a key.
+                                    format!("summaries-{}", worktree_abs_path.to_string_lossy());
+                                db_connection.create_database(&mut txn, Some(&db_name))?
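+                                // (This db maps blake3 content digests to summary strings,
+                                // so a file whose digest is unchanged is never re-summarized.)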
+                            };
+                            SummaryIndex::new(
+                                worktree_for_summary,
+                                summary_fs,
+                                db_connection.clone(),
+                                file_digest_db,
+                                summary_db,
+                                Arc::clone(&entries_being_indexed),
+                            )
+                        };
+                        txn.commit()?;
+                        anyhow::Ok((embedding_index, summary_index))
+                    }
+                })
+                .await?;
+
+            cx.new_model(|cx| {
+                Self::new(
+                    worktree,
+                    db_connection,
+                    embedding_index,
+                    summary_index,
+                    entries_being_indexed,
+                    cx,
+                )
+            })
+        })
+    }
+
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        worktree: Model<Worktree>,
+        db_connection: heed::Env,
+        embedding_index: EmbeddingIndex,
+        summary_index: SummaryIndex,
+        entry_ids_being_indexed: Arc<IndexingEntrySet>,
+        cx: &mut ModelContext<Self>,
+    ) -> Self {
+        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
+        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
+            if let worktree::Event::UpdatedEntries(update) = event {
+                log::debug!("Updating entries...");
+                _ = updated_entries_tx.try_send(update.clone());
+            }
+        });
+
+        Self {
+            db_connection,
+            embedding_index,
+            summary_index,
+            worktree,
+            entry_ids_being_indexed,
+            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
+            _subscription,
+        }
+    }
+
+    pub fn entry_ids_being_indexed(&self) -> &IndexingEntrySet {
+        self.entry_ids_being_indexed.as_ref()
+    }
+
+    pub fn worktree(&self) -> &Model<Worktree> {
+        &self.worktree
+    }
+
+    pub fn db_connection(&self) -> &heed::Env {
+        &self.db_connection
+    }
+
+    pub fn embedding_index(&self) -> &EmbeddingIndex {
+        &self.embedding_index
+    }
+
+    pub fn summary_index(&self) -> &SummaryIndex {
+        &self.summary_index
+    }
+
+    async fn index_entries(
+        this: WeakModel<Self>,
+        updated_entries: channel::Receiver<UpdatedEntriesSet>,
+        mut cx: AsyncAppContext,
+    ) -> Result<()> {
+        let is_auto_available = cx.update(|cx| cx.wait_for_flag::<AutoCommand>())?.await;
+        let index = this.update(&mut cx, |this, cx| {
+            futures::future::try_join(
+                this.embedding_index.index_entries_changed_on_disk(cx),
+                this.summary_index
+                    .index_entries_changed_on_disk(is_auto_available, cx),
+            )
+        })?;
+        index.await.log_err();
+
+        while let Ok(updated_entries) = updated_entries.recv().await {
+            let is_auto_available = cx
+                .update(|cx| cx.has_flag::<AutoCommand>())
+                .unwrap_or(false);
+
+            let index = this.update(&mut cx, |this, cx| {
+                futures::future::try_join(
+                    this.embedding_index
+                        .index_updated_entries(updated_entries.clone(), cx),
+                    this.summary_index.index_updated_entries(
+                        updated_entries,
+                        is_auto_available,
+                        cx,
+                    ),
+                )
+            })?;
+            index.await.log_err();
+        }
+
+        Ok(())
+    }
+
+    #[cfg(test)]
+    pub fn path_count(&self) -> Result<u64> {
+        use anyhow::Context;
+
+        let txn = self
+            .db_connection
+            .read_txn()
+            .context("failed to create read transaction")?;
+        Ok(self.embedding_index().db().len(&txn)?)
+    }
+}
diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs
index c6e64deb59..584524a1d7 100644
--- a/crates/worktree/src/worktree.rs
+++ b/crates/worktree/src/worktree.rs
@@ -3227,6 +3227,8 @@ pub struct Entry {
     pub git_status: Option<GitFileStatus>,
     /// Whether this entry is considered to be a `.env` file.
     pub is_private: bool,
+    /// The entry's size on disk, in bytes.
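+    /// (Zero for entries deserialized from a peer that didn't report a size.)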
+    pub size: u64,
     pub char_bag: CharBag,
     pub is_fifo: bool,
 }
@@ -3282,6 +3284,7 @@ impl Entry {
             path,
             inode: metadata.inode,
             mtime: Some(metadata.mtime),
+            size: metadata.len,
             canonical_path,
             is_symlink: metadata.is_symlink,
             is_ignored: false,
@@ -5210,6 +5213,7 @@ impl<'a> From<&'a Entry> for proto::Entry {
             is_external: entry.is_external,
             git_status: entry.git_status.map(git_status_to_proto),
             is_fifo: entry.is_fifo,
+            size: Some(entry.size),
         }
     }
 }
@@ -5231,6 +5235,7 @@ impl<'a> TryFrom<(&'a CharBag, proto::Entry)> for Entry {
             path,
             inode: entry.inode,
             mtime: entry.mtime.map(|time| time.into()),
+            size: entry.size.unwrap_or(0),
             canonical_path: None,
             is_ignored: entry.is_ignored,
             is_external: entry.is_external,