From e309fbda2a95a55a043ad41ead97c568c7aeef19 Mon Sep 17 00:00:00 2001
From: Max Brunsfeld
Date: Fri, 20 Sep 2024 15:09:18 -0700
Subject: [PATCH] Add a slash command for automatically retrieving relevant
 context (#17972)

* [x] put this slash command behind a feature flag until we release
  embedding access to the general population
* [x] choose a name for this slash command and name the rust module to match

Release Notes:

- N/A

---------

Co-authored-by: Jason
Co-authored-by: Richard
Co-authored-by: Jason Mancuso <7891333+jvmncs@users.noreply.github.com>
Co-authored-by: Richard Feldman
---
 assets/prompts/project_slash_command.hbs | 8 +
 crates/assistant/src/assistant.rs | 32 +-
 crates/assistant/src/context.rs | 5 +-
 crates/assistant/src/prompts.rs | 15 +
 crates/assistant/src/slash_command.rs | 2 +-
 .../slash_command/cargo_workspace_command.rs | 153 ++++++++++
 .../src/slash_command/project_command.rs | 257 +++++++++-------
 .../src/slash_command/search_command.rs | 63 ++--
 .../assistant/src/slash_command_settings.rs | 10 +-
 crates/evals/src/eval.rs | 2 +-
 crates/semantic_index/examples/index.rs | 2 +-
 crates/semantic_index/src/embedding.rs | 23 +-
 crates/semantic_index/src/project_index.rs | 59 ++--
 crates/semantic_index/src/semantic_index.rs | 275 +++++++++++++++---
 14 files changed, 683 insertions(+), 223 deletions(-)
 create mode 100644 assets/prompts/project_slash_command.hbs
 create mode 100644 crates/assistant/src/slash_command/cargo_workspace_command.rs

diff --git a/assets/prompts/project_slash_command.hbs b/assets/prompts/project_slash_command.hbs
new file mode 100644
index 0000000000..6c63f71d89
--- /dev/null
+++ b/assets/prompts/project_slash_command.hbs
@@ -0,0 +1,8 @@
+A software developer is asking a question about their project. The source files in their project have been indexed into a database of semantic text embeddings.
+Your task is to generate a list of 4 diverse search queries that can be run on this embedding database, in order to retrieve a list of code snippets
+that are relevant to the developer's question. Redundant search queries will be heavily penalized, so only include another query if it's sufficiently
+distinct from previous ones.
+
+Here is the question that's been asked, together with context that the developer has added manually:
+
+{{{context_buffer}}}
diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs
index 8b9c66ee55..9cc63af5a1 100644
--- a/crates/assistant/src/assistant.rs
+++ b/crates/assistant/src/assistant.rs
@@ -41,9 +41,10 @@ use semantic_index::{CloudEmbeddingProvider, SemanticDb};
 use serde::{Deserialize, Serialize};
 use settings::{update_settings_file, Settings, SettingsStore};
 use slash_command::{
-    auto_command, context_server_command, default_command, delta_command, diagnostics_command,
-    docs_command, fetch_command, file_command, now_command, project_command, prompt_command,
-    search_command, symbols_command, tab_command, terminal_command, workflow_command,
+    auto_command, cargo_workspace_command, context_server_command, default_command, delta_command,
+    diagnostics_command, docs_command, fetch_command, file_command, now_command, project_command,
+    prompt_command, search_command, symbols_command, tab_command, terminal_command,
+    workflow_command,
 };
 use std::path::PathBuf;
 use std::sync::Arc;
@@ -384,20 +385,33 @@ fn register_slash_commands(prompt_builder: Option>, cx: &mut
     slash_command_registry.register_command(delta_command::DeltaSlashCommand, true);
     slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
     slash_command_registry.register_command(tab_command::TabSlashCommand, true);
-    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
+    slash_command_registry
+        .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
     slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
     slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
     slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
     slash_command_registry.register_command(now_command::NowSlashCommand, false);
     slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
+    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);

     if let Some(prompt_builder) = prompt_builder {
         slash_command_registry.register_command(
             workflow_command::WorkflowSlashCommand::new(prompt_builder.clone()),
             true,
         );
+        cx.observe_flag::({
+            let slash_command_registry = slash_command_registry.clone();
+            move |is_enabled, _cx| {
+                if is_enabled {
+                    slash_command_registry.register_command(
+                        project_command::ProjectSlashCommand::new(prompt_builder.clone()),
+                        true,
+                    );
+                }
+            }
+        })
+        .detach();
     }
-    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);

     cx.observe_flag::({
         let slash_command_registry = slash_command_registry.clone();
@@ -435,10 +449,12 @@ fn update_slash_commands_from_settings(cx: &mut AppContext) {
         slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
     }

-    if settings.project.enabled {
-        slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
+    if settings.cargo_workspace.enabled {
+        slash_command_registry
+            .register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
     } else {
-        slash_command_registry.unregister_command(project_command::ProjectSlashCommand);
+        slash_command_registry
+            .unregister_command(cargo_workspace_command::CargoWorkspaceSlashCommand);
     }
 }
diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs
index 97a5b3ea98..1cac47831f 100644
--- a/crates/assistant/src/context.rs
+++ b/crates/assistant/src/context.rs @@ -1967,8 +1967,9 @@ impl Context { } pub fn assist(&mut self, cx: &mut ModelContext) -> Option { - let provider = LanguageModelRegistry::read_global(cx).active_provider()?; - let model = LanguageModelRegistry::read_global(cx).active_model()?; + let model_registry = LanguageModelRegistry::read_global(cx); + let provider = model_registry.active_provider()?; + let model = model_registry.active_model()?; let last_message_id = self.get_last_valid_message_id(cx)?; if !provider.is_authenticated(cx) { diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index 3b9f75bac9..106935cb88 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -40,6 +40,11 @@ pub struct TerminalAssistantPromptContext { pub user_prompt: String, } +#[derive(Serialize)] +pub struct ProjectSlashCommandPromptContext { + pub context_buffer: String, +} + /// Context required to generate a workflow step resolution prompt. #[derive(Debug, Serialize)] pub struct StepResolutionContext { @@ -317,4 +322,14 @@ impl PromptBuilder { pub fn generate_workflow_prompt(&self) -> Result { self.handlebars.lock().render("edit_workflow", &()) } + + pub fn generate_project_slash_command_prompt( + &self, + context_buffer: String, + ) -> Result { + self.handlebars.lock().render( + "project_slash_command", + &ProjectSlashCommandPromptContext { context_buffer }, + ) + } } diff --git a/crates/assistant/src/slash_command.rs b/crates/assistant/src/slash_command.rs index cf957a15c6..e430e35622 100644 --- a/crates/assistant/src/slash_command.rs +++ b/crates/assistant/src/slash_command.rs @@ -18,8 +18,8 @@ use std::{ }; use ui::ActiveTheme; use workspace::Workspace; - pub mod auto_command; +pub mod cargo_workspace_command; pub mod context_server_command; pub mod default_command; pub mod delta_command; diff --git a/crates/assistant/src/slash_command/cargo_workspace_command.rs b/crates/assistant/src/slash_command/cargo_workspace_command.rs new file mode 100644 index 0000000000..baf16d7f01 --- /dev/null +++ b/crates/assistant/src/slash_command/cargo_workspace_command.rs @@ -0,0 +1,153 @@ +use super::{SlashCommand, SlashCommandOutput}; +use anyhow::{anyhow, Context, Result}; +use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection}; +use fs::Fs; +use gpui::{AppContext, Model, Task, WeakView}; +use language::{BufferSnapshot, LspAdapterDelegate}; +use project::{Project, ProjectPath}; +use std::{ + fmt::Write, + path::Path, + sync::{atomic::AtomicBool, Arc}, +}; +use ui::prelude::*; +use workspace::Workspace; + +pub(crate) struct CargoWorkspaceSlashCommand; + +impl CargoWorkspaceSlashCommand { + async fn build_message(fs: Arc, path_to_cargo_toml: &Path) -> Result { + let buffer = fs.load(path_to_cargo_toml).await?; + let cargo_toml: cargo_toml::Manifest = toml::from_str(&buffer)?; + + let mut message = String::new(); + writeln!(message, "You are in a Rust project.")?; + + if let Some(workspace) = cargo_toml.workspace { + writeln!( + message, + "The project is a Cargo workspace with the following members:" + )?; + for member in workspace.members { + writeln!(message, "- {member}")?; + } + + if !workspace.default_members.is_empty() { + writeln!(message, "The default members are:")?; + for member in workspace.default_members { + writeln!(message, "- {member}")?; + } + } + + if !workspace.dependencies.is_empty() { + writeln!( + message, + "The following workspace dependencies are installed:" + )?; + for dependency in workspace.dependencies.keys() { + 
writeln!(message, "- {dependency}")?; + } + } + } else if let Some(package) = cargo_toml.package { + writeln!( + message, + "The project name is \"{name}\".", + name = package.name + )?; + + let description = package + .description + .as_ref() + .and_then(|description| description.get().ok().cloned()); + if let Some(description) = description.as_ref() { + writeln!(message, "It describes itself as \"{description}\".")?; + } + + if !cargo_toml.dependencies.is_empty() { + writeln!(message, "The following dependencies are installed:")?; + for dependency in cargo_toml.dependencies.keys() { + writeln!(message, "- {dependency}")?; + } + } + } + + Ok(message) + } + + fn path_to_cargo_toml(project: Model, cx: &mut AppContext) -> Option> { + let worktree = project.read(cx).worktrees(cx).next()?; + let worktree = worktree.read(cx); + let entry = worktree.entry_for_path("Cargo.toml")?; + let path = ProjectPath { + worktree_id: worktree.id(), + path: entry.path.clone(), + }; + Some(Arc::from( + project.read(cx).absolute_path(&path, cx)?.as_path(), + )) + } +} + +impl SlashCommand for CargoWorkspaceSlashCommand { + fn name(&self) -> String { + "cargo-workspace".into() + } + + fn description(&self) -> String { + "insert project workspace metadata".into() + } + + fn menu_text(&self) -> String { + "Insert Project Workspace Metadata".into() + } + + fn complete_argument( + self: Arc, + _arguments: &[String], + _cancel: Arc, + _workspace: Option>, + _cx: &mut WindowContext, + ) -> Task>> { + Task::ready(Err(anyhow!("this command does not require argument"))) + } + + fn requires_argument(&self) -> bool { + false + } + + fn run( + self: Arc, + _arguments: &[String], + _context_slash_command_output_sections: &[SlashCommandOutputSection], + _context_buffer: BufferSnapshot, + workspace: WeakView, + _delegate: Option>, + cx: &mut WindowContext, + ) -> Task> { + let output = workspace.update(cx, |workspace, cx| { + let project = workspace.project().clone(); + let fs = workspace.project().read(cx).fs().clone(); + let path = Self::path_to_cargo_toml(project, cx); + let output = cx.background_executor().spawn(async move { + let path = path.with_context(|| "Cargo.toml not found")?; + Self::build_message(fs, &path).await + }); + + cx.foreground_executor().spawn(async move { + let text = output.await?; + let range = 0..text.len(); + Ok(SlashCommandOutput { + text, + sections: vec![SlashCommandOutputSection { + range, + icon: IconName::FileTree, + label: "Project".into(), + metadata: None, + }], + run_commands_in_text: false, + }) + }) + }); + output.unwrap_or_else(|error| Task::ready(Err(error))) + } +} diff --git a/crates/assistant/src/slash_command/project_command.rs b/crates/assistant/src/slash_command/project_command.rs index 3e8596d942..197e91d91a 100644 --- a/crates/assistant/src/slash_command/project_command.rs +++ b/crates/assistant/src/slash_command/project_command.rs @@ -1,90 +1,39 @@ -use super::{SlashCommand, SlashCommandOutput}; -use anyhow::{anyhow, Context, Result}; +use super::{ + create_label_for_command, search_command::add_search_result_section, SlashCommand, + SlashCommandOutput, +}; +use crate::PromptBuilder; +use anyhow::{anyhow, Result}; use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection}; -use fs::Fs; -use gpui::{AppContext, Model, Task, WeakView}; -use language::{BufferSnapshot, LspAdapterDelegate}; -use project::{Project, ProjectPath}; +use feature_flags::FeatureFlag; +use gpui::{AppContext, Task, WeakView, WindowContext}; +use language::{Anchor, CodeLabel, 
LspAdapterDelegate}; +use language_model::{LanguageModelRegistry, LanguageModelTool}; +use schemars::JsonSchema; +use semantic_index::SemanticDb; +use serde::Deserialize; + +pub struct ProjectSlashCommandFeatureFlag; + +impl FeatureFlag for ProjectSlashCommandFeatureFlag { + const NAME: &'static str = "project-slash-command"; +} + use std::{ - fmt::Write, - path::Path, + fmt::Write as _, + ops::DerefMut, sync::{atomic::AtomicBool, Arc}, }; -use ui::prelude::*; +use ui::{BorrowAppContext as _, IconName}; use workspace::Workspace; -pub(crate) struct ProjectSlashCommand; +pub struct ProjectSlashCommand { + prompt_builder: Arc, +} impl ProjectSlashCommand { - async fn build_message(fs: Arc, path_to_cargo_toml: &Path) -> Result { - let buffer = fs.load(path_to_cargo_toml).await?; - let cargo_toml: cargo_toml::Manifest = toml::from_str(&buffer)?; - - let mut message = String::new(); - writeln!(message, "You are in a Rust project.")?; - - if let Some(workspace) = cargo_toml.workspace { - writeln!( - message, - "The project is a Cargo workspace with the following members:" - )?; - for member in workspace.members { - writeln!(message, "- {member}")?; - } - - if !workspace.default_members.is_empty() { - writeln!(message, "The default members are:")?; - for member in workspace.default_members { - writeln!(message, "- {member}")?; - } - } - - if !workspace.dependencies.is_empty() { - writeln!( - message, - "The following workspace dependencies are installed:" - )?; - for dependency in workspace.dependencies.keys() { - writeln!(message, "- {dependency}")?; - } - } - } else if let Some(package) = cargo_toml.package { - writeln!( - message, - "The project name is \"{name}\".", - name = package.name - )?; - - let description = package - .description - .as_ref() - .and_then(|description| description.get().ok().cloned()); - if let Some(description) = description.as_ref() { - writeln!(message, "It describes itself as \"{description}\".")?; - } - - if !cargo_toml.dependencies.is_empty() { - writeln!(message, "The following dependencies are installed:")?; - for dependency in cargo_toml.dependencies.keys() { - writeln!(message, "- {dependency}")?; - } - } - } - - Ok(message) - } - - fn path_to_cargo_toml(project: Model, cx: &mut AppContext) -> Option> { - let worktree = project.read(cx).worktrees(cx).next()?; - let worktree = worktree.read(cx); - let entry = worktree.entry_for_path("Cargo.toml")?; - let path = ProjectPath { - worktree_id: worktree.id(), - path: entry.path.clone(), - }; - Some(Arc::from( - project.read(cx).absolute_path(&path, cx)?.as_path(), - )) + pub fn new(prompt_builder: Arc) -> Self { + Self { prompt_builder } } } @@ -93,12 +42,20 @@ impl SlashCommand for ProjectSlashCommand { "project".into() } + fn label(&self, cx: &AppContext) -> CodeLabel { + create_label_for_command("project", &[], cx) + } + fn description(&self) -> String { - "insert project metadata".into() + "Generate semantic searches based on the current context".into() } fn menu_text(&self) -> String { - "Insert Project Metadata".into() + "Project Context".into() + } + + fn requires_argument(&self) -> bool { + false } fn complete_argument( @@ -108,46 +65,126 @@ impl SlashCommand for ProjectSlashCommand { _workspace: Option>, _cx: &mut WindowContext, ) -> Task>> { - Task::ready(Err(anyhow!("this command does not require argument"))) - } - - fn requires_argument(&self) -> bool { - false + Task::ready(Ok(Vec::new())) } fn run( self: Arc, _arguments: &[String], - _context_slash_command_output_sections: &[SlashCommandOutputSection], 
- _context_buffer: BufferSnapshot, + _context_slash_command_output_sections: &[SlashCommandOutputSection], + context_buffer: language::BufferSnapshot, workspace: WeakView, _delegate: Option>, cx: &mut WindowContext, ) -> Task> { - let output = workspace.update(cx, |workspace, cx| { - let project = workspace.project().clone(); - let fs = workspace.project().read(cx).fs().clone(); - let path = Self::path_to_cargo_toml(project, cx); - let output = cx.background_executor().spawn(async move { - let path = path.with_context(|| "Cargo.toml not found")?; - Self::build_message(fs, &path).await - }); + let model_registry = LanguageModelRegistry::read_global(cx); + let current_model = model_registry.active_model(); + let prompt_builder = self.prompt_builder.clone(); - cx.foreground_executor().spawn(async move { - let text = output.await?; - let range = 0..text.len(); - Ok(SlashCommandOutput { - text, - sections: vec![SlashCommandOutputSection { - range, - icon: IconName::FileTree, - label: "Project".into(), + let Some(workspace) = workspace.upgrade() else { + return Task::ready(Err(anyhow::anyhow!("workspace was dropped"))); + }; + let project = workspace.read(cx).project().clone(); + let fs = project.read(cx).fs().clone(); + let Some(project_index) = + cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx)) + else { + return Task::ready(Err(anyhow::anyhow!("no project indexer"))); + }; + + cx.spawn(|mut cx| async move { + let current_model = current_model.ok_or_else(|| anyhow!("no model selected"))?; + + let prompt = + prompt_builder.generate_project_slash_command_prompt(context_buffer.text())?; + + let search_queries = current_model + .use_tool::( + language_model::LanguageModelRequest { + messages: vec![language_model::LanguageModelRequestMessage { + role: language_model::Role::User, + content: vec![language_model::MessageContent::Text(prompt)], + cache: false, + }], + tools: vec![], + stop: vec![], + temperature: None, + }, + cx.deref_mut(), + ) + .await? + .search_queries; + + let results = project_index + .read_with(&cx, |project_index, cx| { + project_index.search(search_queries.clone(), 25, cx) + })? + .await?; + + let results = SemanticDb::load_results(results, &fs, &cx).await?; + + cx.background_executor() + .spawn(async move { + let mut output = "Project context:\n".to_string(); + let mut sections = Vec::new(); + + for (ix, query) in search_queries.into_iter().enumerate() { + let start_ix = output.len(); + writeln!(&mut output, "Results for {query}:").unwrap(); + let mut has_results = false; + for result in &results { + if result.query_index == ix { + add_search_result_section(result, &mut output, &mut sections); + has_results = true; + } + } + if has_results { + sections.push(SlashCommandOutputSection { + range: start_ix..output.len(), + icon: IconName::MagnifyingGlass, + label: query.into(), + metadata: None, + }); + output.push('\n'); + } else { + output.truncate(start_ix); + } + } + + sections.push(SlashCommandOutputSection { + range: 0..output.len(), + icon: IconName::Book, + label: "Project context".into(), metadata: None, - }], - run_commands_in_text: false, + }); + + Ok(SlashCommandOutput { + text: output, + sections, + run_commands_in_text: true, + }) }) - }) - }); - output.unwrap_or_else(|error| Task::ready(Err(error))) + .await + }) + } +} + +#[derive(JsonSchema, Deserialize)] +struct SearchQueries { + /// An array of semantic search queries. + /// + /// These queries will be used to search the user's codebase. 
+ /// The function can only accept 4 queries, otherwise it will error. + /// As such, it's important that you limit the length of the search_queries array to 5 queries or less. + search_queries: Vec, +} + +impl LanguageModelTool for SearchQueries { + fn name() -> String { + "search_queries".to_string() + } + + fn description() -> String { + "Generate semantic search queries based on context".to_string() } } diff --git a/crates/assistant/src/slash_command/search_command.rs b/crates/assistant/src/slash_command/search_command.rs index 7e408cad39..f0f3ee3d25 100644 --- a/crates/assistant/src/slash_command/search_command.rs +++ b/crates/assistant/src/slash_command/search_command.rs @@ -7,7 +7,7 @@ use anyhow::Result; use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection}; use feature_flags::FeatureFlag; use gpui::{AppContext, Task, WeakView}; -use language::{CodeLabel, LineEnding, LspAdapterDelegate}; +use language::{CodeLabel, LspAdapterDelegate}; use semantic_index::{LoadedSearchResult, SemanticDb}; use std::{ fmt::Write, @@ -101,7 +101,7 @@ impl SlashCommand for SearchSlashCommand { cx.spawn(|cx| async move { let results = project_index .read_with(&cx, |project_index, cx| { - project_index.search(query.clone(), limit.unwrap_or(5), cx) + project_index.search(vec![query.clone()], limit.unwrap_or(5), cx) })? .await?; @@ -112,31 +112,8 @@ impl SlashCommand for SearchSlashCommand { .spawn(async move { let mut text = format!("Search results for {query}:\n"); let mut sections = Vec::new(); - for LoadedSearchResult { - path, - range, - full_path, - file_content, - row_range, - } in loaded_results - { - let section_start_ix = text.len(); - text.push_str(&codeblock_fence_for_path( - Some(&path), - Some(row_range.clone()), - )); - - let mut excerpt = file_content[range].to_string(); - LineEnding::normalize(&mut excerpt); - text.push_str(&excerpt); - writeln!(text, "\n```\n").unwrap(); - let section_end_ix = text.len() - 1; - sections.push(build_entry_output_section( - section_start_ix..section_end_ix, - Some(&full_path), - false, - Some(row_range.start() + 1..row_range.end() + 1), - )); + for loaded_result in &loaded_results { + add_search_result_section(loaded_result, &mut text, &mut sections); } let query = SharedString::from(query); @@ -159,3 +136,35 @@ impl SlashCommand for SearchSlashCommand { }) } } + +pub fn add_search_result_section( + loaded_result: &LoadedSearchResult, + text: &mut String, + sections: &mut Vec>, +) { + let LoadedSearchResult { + path, + full_path, + excerpt_content, + row_range, + .. + } = loaded_result; + let section_start_ix = text.len(); + text.push_str(&codeblock_fence_for_path( + Some(&path), + Some(row_range.clone()), + )); + + text.push_str(&excerpt_content); + if !text.ends_with('\n') { + text.push('\n'); + } + writeln!(text, "```\n").unwrap(); + let section_end_ix = text.len() - 1; + sections.push(build_entry_output_section( + section_start_ix..section_end_ix, + Some(&full_path), + false, + Some(row_range.start() + 1..row_range.end() + 1), + )); +} diff --git a/crates/assistant/src/slash_command_settings.rs b/crates/assistant/src/slash_command_settings.rs index eda950b6a2..c524b37803 100644 --- a/crates/assistant/src/slash_command_settings.rs +++ b/crates/assistant/src/slash_command_settings.rs @@ -10,9 +10,9 @@ pub struct SlashCommandSettings { /// Settings for the `/docs` slash command. #[serde(default)] pub docs: DocsCommandSettings, - /// Settings for the `/project` slash command. + /// Settings for the `/cargo-workspace` slash command. 
#[serde(default)] - pub project: ProjectCommandSettings, + pub cargo_workspace: CargoWorkspaceCommandSettings, } /// Settings for the `/docs` slash command. @@ -23,10 +23,10 @@ pub struct DocsCommandSettings { pub enabled: bool, } -/// Settings for the `/project` slash command. +/// Settings for the `/cargo-workspace` slash command. #[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)] -pub struct ProjectCommandSettings { - /// Whether `/project` is enabled. +pub struct CargoWorkspaceCommandSettings { + /// Whether `/cargo-workspace` is enabled. #[serde(default)] pub enabled: bool, } diff --git a/crates/evals/src/eval.rs b/crates/evals/src/eval.rs index 0580053373..e2c8b42644 100644 --- a/crates/evals/src/eval.rs +++ b/crates/evals/src/eval.rs @@ -438,7 +438,7 @@ async fn run_eval_project( loop { match cx.update(|cx| { let project_index = project_index.read(cx); - project_index.search(query.query.clone(), SEARCH_RESULT_LIMIT, cx) + project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx) }) { Ok(task) => match task.await { Ok(answer) => { diff --git a/crates/semantic_index/examples/index.rs b/crates/semantic_index/examples/index.rs index 0cc3f9f317..c5c2c633a1 100644 --- a/crates/semantic_index/examples/index.rs +++ b/crates/semantic_index/examples/index.rs @@ -98,7 +98,7 @@ fn main() { .update(|cx| { let project_index = project_index.read(cx); let query = "converting an anchor to a point"; - project_index.search(query.into(), 4, cx) + project_index.search(vec![query.into()], 4, cx) }) .unwrap() .await diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index b05c4ac9da..1e1e0f0be7 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -42,14 +42,23 @@ impl Embedding { self.0.len() } - pub fn similarity(self, other: &Embedding) -> f32 { - debug_assert_eq!(self.0.len(), other.0.len()); - self.0 + pub fn similarity(&self, others: &[Embedding]) -> (f32, usize) { + debug_assert!(others.iter().all(|other| self.0.len() == other.0.len())); + others .iter() - .copied() - .zip(other.0.iter().copied()) - .map(|(a, b)| a * b) - .sum() + .enumerate() + .map(|(index, other)| { + let dot_product: f32 = self + .0 + .iter() + .copied() + .zip(other.0.iter().copied()) + .map(|(a, b)| a * b) + .sum(); + (dot_product, index) + }) + .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)) + .unwrap_or((0.0, 0)) } } diff --git a/crates/semantic_index/src/project_index.rs b/crates/semantic_index/src/project_index.rs index 5c35c93fa9..21c036d60a 100644 --- a/crates/semantic_index/src/project_index.rs +++ b/crates/semantic_index/src/project_index.rs @@ -31,20 +31,23 @@ pub struct SearchResult { pub path: Arc, pub range: Range, pub score: f32, + pub query_index: usize, } +#[derive(Debug, PartialEq, Eq)] pub struct LoadedSearchResult { pub path: Arc, - pub range: Range, pub full_path: PathBuf, - pub file_content: String, + pub excerpt_content: String, pub row_range: RangeInclusive, + pub query_index: usize, } pub struct WorktreeSearchResult { pub worktree_id: WorktreeId, pub path: Arc, pub range: Range, + pub query_index: usize, pub score: f32, } @@ -227,7 +230,7 @@ impl ProjectIndex { pub fn search( &self, - query: String, + queries: Vec, limit: usize, cx: &AppContext, ) -> Task>> { @@ -275,15 +278,18 @@ impl ProjectIndex { cx.spawn(|cx| async move { #[cfg(debug_assertions)] let embedding_query_start = std::time::Instant::now(); - log::info!("Searching for {query}"); + 
log::info!("Searching for {queries:?}"); + let queries: Vec = queries + .iter() + .map(|s| TextToEmbed::new(s.as_str())) + .collect(); - let query_embeddings = embedding_provider - .embed(&[TextToEmbed::new(&query)]) - .await?; - let query_embedding = query_embeddings - .into_iter() - .next() - .ok_or_else(|| anyhow!("no embedding for query"))?; + let query_embeddings = embedding_provider.embed(&queries[..]).await?; + if query_embeddings.len() != queries.len() { + return Err(anyhow!( + "The number of query embeddings does not match the number of queries" + )); + } let mut results_by_worker = Vec::new(); for _ in 0..cx.background_executor().num_cpus() { @@ -292,28 +298,34 @@ impl ProjectIndex { #[cfg(debug_assertions)] let search_start = std::time::Instant::now(); - cx.background_executor() .scoped(|cx| { for results in results_by_worker.iter_mut() { cx.spawn(async { while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await { - let score = chunk.embedding.similarity(&query_embedding); + let (score, query_index) = + chunk.embedding.similarity(&query_embeddings); + let ix = match results.binary_search_by(|probe| { score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal) }) { Ok(ix) | Err(ix) => ix, }; - results.insert( - ix, - WorktreeSearchResult { - worktree_id, - path: path.clone(), - range: chunk.chunk.range.clone(), - score, - }, - ); - results.truncate(limit); + if ix < limit { + results.insert( + ix, + WorktreeSearchResult { + worktree_id, + path: path.clone(), + range: chunk.chunk.range.clone(), + query_index, + score, + }, + ); + if results.len() > limit { + results.pop(); + } + } } }); } @@ -333,6 +345,7 @@ impl ProjectIndex { path: result.path, range: result.range, score: result.score, + query_index: result.query_index, }) })); } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 6c97ece024..332b4271a0 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -12,8 +12,13 @@ use anyhow::{Context as _, Result}; use collections::HashMap; use fs::Fs; use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel}; -use project::Project; -use std::{path::PathBuf, sync::Arc}; +use language::LineEnding; +use project::{Project, Worktree}; +use std::{ + cmp::Ordering, + path::{Path, PathBuf}, + sync::Arc, +}; use ui::ViewContext; use util::ResultExt as _; use workspace::Workspace; @@ -77,46 +82,127 @@ impl SemanticDb { } pub async fn load_results( - results: Vec, + mut results: Vec, fs: &Arc, cx: &AsyncAppContext, ) -> Result> { - let mut loaded_results = Vec::new(); - for result in results { - let (full_path, file_content) = result.worktree.read_with(cx, |worktree, _cx| { - let entry_abs_path = worktree.abs_path().join(&result.path); - let mut entry_full_path = PathBuf::from(worktree.root_name()); - entry_full_path.push(&result.path); - let file_content = async { - let entry_abs_path = entry_abs_path; - fs.load(&entry_abs_path).await - }; - (entry_full_path, file_content) - })?; - if let Some(file_content) = file_content.await.log_err() { - let range_start = result.range.start.min(file_content.len()); - let range_end = result.range.end.min(file_content.len()); - - let start_row = file_content[0..range_start].matches('\n').count() as u32; - let end_row = file_content[0..range_end].matches('\n').count() as u32; - let start_line_byte_offset = file_content[0..range_start] - .rfind('\n') - .map(|pos| pos + 1) - .unwrap_or_default(); - let 
end_line_byte_offset = file_content[range_end..] - .find('\n') - .map(|pos| range_end + pos) - .unwrap_or_else(|| file_content.len()); - - loaded_results.push(LoadedSearchResult { - path: result.path, - range: start_line_byte_offset..end_line_byte_offset, - full_path, - file_content, - row_range: start_row..=end_row, - }); + let mut max_scores_by_path = HashMap::<_, (f32, usize)>::default(); + for result in &results { + let (score, query_index) = max_scores_by_path + .entry((result.worktree.clone(), result.path.clone())) + .or_default(); + if result.score > *score { + *score = result.score; + *query_index = result.query_index; } } + + results.sort_by(|a, b| { + let max_score_a = max_scores_by_path[&(a.worktree.clone(), a.path.clone())].0; + let max_score_b = max_scores_by_path[&(b.worktree.clone(), b.path.clone())].0; + max_score_b + .partial_cmp(&max_score_a) + .unwrap_or(Ordering::Equal) + .then_with(|| a.worktree.entity_id().cmp(&b.worktree.entity_id())) + .then_with(|| a.path.cmp(&b.path)) + .then_with(|| a.range.start.cmp(&b.range.start)) + }); + + let mut last_loaded_file: Option<(Model, Arc, PathBuf, String)> = None; + let mut loaded_results = Vec::::new(); + for result in results { + let full_path; + let file_content; + if let Some(last_loaded_file) = + last_loaded_file + .as_ref() + .filter(|(last_worktree, last_path, _, _)| { + last_worktree == &result.worktree && last_path == &result.path + }) + { + full_path = last_loaded_file.2.clone(); + file_content = &last_loaded_file.3; + } else { + let output = result.worktree.read_with(cx, |worktree, _cx| { + let entry_abs_path = worktree.abs_path().join(&result.path); + let mut entry_full_path = PathBuf::from(worktree.root_name()); + entry_full_path.push(&result.path); + let file_content = async { + let entry_abs_path = entry_abs_path; + fs.load(&entry_abs_path).await + }; + (entry_full_path, file_content) + })?; + full_path = output.0; + let Some(content) = output.1.await.log_err() else { + continue; + }; + last_loaded_file = Some(( + result.worktree.clone(), + result.path.clone(), + full_path.clone(), + content, + )); + file_content = &last_loaded_file.as_ref().unwrap().3; + }; + + let query_index = max_scores_by_path[&(result.worktree.clone(), result.path.clone())].1; + + let mut range_start = result.range.start.min(file_content.len()); + let mut range_end = result.range.end.min(file_content.len()); + while !file_content.is_char_boundary(range_start) { + range_start += 1; + } + while !file_content.is_char_boundary(range_end) { + range_end += 1; + } + + let start_row = file_content[0..range_start].matches('\n').count() as u32; + let mut end_row = file_content[0..range_end].matches('\n').count() as u32; + let start_line_byte_offset = file_content[0..range_start] + .rfind('\n') + .map(|pos| pos + 1) + .unwrap_or_default(); + let mut end_line_byte_offset = range_end; + if file_content[..end_line_byte_offset].ends_with('\n') { + end_row -= 1; + } else { + end_line_byte_offset = file_content[range_end..] 
+ .find('\n') + .map(|pos| range_end + pos + 1) + .unwrap_or_else(|| file_content.len()); + } + let mut excerpt_content = + file_content[start_line_byte_offset..end_line_byte_offset].to_string(); + LineEnding::normalize(&mut excerpt_content); + + if let Some(prev_result) = loaded_results.last_mut() { + if prev_result.full_path == full_path { + if *prev_result.row_range.end() + 1 == start_row { + prev_result.row_range = *prev_result.row_range.start()..=end_row; + prev_result.excerpt_content.push_str(&excerpt_content); + continue; + } + } + } + + loaded_results.push(LoadedSearchResult { + path: result.path, + full_path, + excerpt_content, + row_range: start_row..=end_row, + query_index, + }); + } + + for result in &mut loaded_results { + while result.excerpt_content.ends_with("\n\n") { + result.excerpt_content.pop(); + result.row_range = + *result.row_range.start()..=result.row_range.end().saturating_sub(1) + } + } + Ok(loaded_results) } @@ -312,7 +398,7 @@ mod tests { .update(|cx| { let project_index = project_index.read(cx); let query = "garbage in, garbage out"; - project_index.search(query.into(), 4, cx) + project_index.search(vec![query.into()], 4, cx) }) .await .unwrap(); @@ -426,4 +512,117 @@ mod tests { ], ); } + + #[gpui::test] + async fn test_load_search_results(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project_path = Path::new("/fake_project"); + + let file1_content = "one\ntwo\nthree\nfour\nfive\n"; + let file2_content = "aaa\nbbb\nccc\nddd\neee\n"; + + fs.insert_tree( + project_path, + json!({ + "file1.txt": file1_content, + "file2.txt": file2_content, + }), + ) + .await; + + let fs = fs as Arc; + let project = Project::test(fs.clone(), [project_path], cx).await; + let worktree = project.read_with(cx, |project, cx| project.worktrees(cx).next().unwrap()); + + // chunk that is already newline-aligned + let search_results = vec![SearchResult { + worktree: worktree.clone(), + path: Path::new("file1.txt").into(), + range: 0..file1_content.find("four").unwrap(), + score: 0.5, + query_index: 0, + }]; + assert_eq!( + SemanticDb::load_results(search_results, &fs, &cx.to_async()) + .await + .unwrap(), + &[LoadedSearchResult { + path: Path::new("file1.txt").into(), + full_path: "fake_project/file1.txt".into(), + excerpt_content: "one\ntwo\nthree\n".into(), + row_range: 0..=2, + query_index: 0, + }] + ); + + // chunk that is *not* newline-aligned + let search_results = vec![SearchResult { + worktree: worktree.clone(), + path: Path::new("file1.txt").into(), + range: file1_content.find("two").unwrap() + 1..file1_content.find("four").unwrap() + 2, + score: 0.5, + query_index: 0, + }]; + assert_eq!( + SemanticDb::load_results(search_results, &fs, &cx.to_async()) + .await + .unwrap(), + &[LoadedSearchResult { + path: Path::new("file1.txt").into(), + full_path: "fake_project/file1.txt".into(), + excerpt_content: "two\nthree\nfour\n".into(), + row_range: 1..=3, + query_index: 0, + }] + ); + + // chunks that are adjacent + + let search_results = vec![ + SearchResult { + worktree: worktree.clone(), + path: Path::new("file1.txt").into(), + range: file1_content.find("two").unwrap()..file1_content.len(), + score: 0.6, + query_index: 0, + }, + SearchResult { + worktree: worktree.clone(), + path: Path::new("file1.txt").into(), + range: 0..file1_content.find("two").unwrap(), + score: 0.5, + query_index: 1, + }, + SearchResult { + worktree: worktree.clone(), + path: Path::new("file2.txt").into(), + range: 0..file2_content.len(), + score: 0.8, + query_index: 
1,
+            },
+        ];
+        assert_eq!(
+            SemanticDb::load_results(search_results, &fs, &cx.to_async())
+                .await
+                .unwrap(),
+            &[
+                LoadedSearchResult {
+                    path: Path::new("file2.txt").into(),
+                    full_path: "fake_project/file2.txt".into(),
+                    excerpt_content: file2_content.into(),
+                    row_range: 0..=4,
+                    query_index: 1,
+                },
+                LoadedSearchResult {
+                    path: Path::new("file1.txt").into(),
+                    full_path: "fake_project/file1.txt".into(),
+                    excerpt_content: file1_content.into(),
+                    row_range: 0..=4,
+                    query_index: 0,
+                }
+            ]
+        );
+    }
 }