Add a slash command for automatically retrieving relevant context (#17972)

* [x] put this slash command behind a feature flag until we release
embedding access to the general population
* [x] choose a name for this slash command and name the rust module to
match

Release Notes:

- N/A

---------

Co-authored-by: Jason <jason@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Jason Mancuso <7891333+jvmncs@users.noreply.github.com>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Max Brunsfeld 2024-09-20 15:09:18 -07:00 committed by GitHub
parent 5905fbb9ac
commit e309fbda2a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 683 additions and 223 deletions

View file

@ -0,0 +1,8 @@
A software developer is asking a question about their project. The source files in their project have been indexed into a database of semantic text embeddings.
Your task is to generate a list of 4 diverse search queries that can be run on this embedding database, in order to retrieve a list of code snippets
that are relevant to the developer's question. Redundant search queries will be heavily penalized, so only include another query if it's sufficiently
distinct from previous ones.
Here is the question that's been asked, together with context that the developer has added manually:
{{{context_buffer}}}

View file

@ -41,9 +41,10 @@ use semantic_index::{CloudEmbeddingProvider, SemanticDb};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore}; use settings::{update_settings_file, Settings, SettingsStore};
use slash_command::{ use slash_command::{
auto_command, context_server_command, default_command, delta_command, diagnostics_command, auto_command, cargo_workspace_command, context_server_command, default_command, delta_command,
docs_command, fetch_command, file_command, now_command, project_command, prompt_command, diagnostics_command, docs_command, fetch_command, file_command, now_command, project_command,
search_command, symbols_command, tab_command, terminal_command, workflow_command, prompt_command, search_command, symbols_command, tab_command, terminal_command,
workflow_command,
}; };
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
@ -384,20 +385,33 @@ fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut
slash_command_registry.register_command(delta_command::DeltaSlashCommand, true); slash_command_registry.register_command(delta_command::DeltaSlashCommand, true);
slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true); slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
slash_command_registry.register_command(tab_command::TabSlashCommand, true); slash_command_registry.register_command(tab_command::TabSlashCommand, true);
slash_command_registry.register_command(project_command::ProjectSlashCommand, true); slash_command_registry
.register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
slash_command_registry.register_command(prompt_command::PromptSlashCommand, true); slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
slash_command_registry.register_command(default_command::DefaultSlashCommand, false); slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true); slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
slash_command_registry.register_command(now_command::NowSlashCommand, false); slash_command_registry.register_command(now_command::NowSlashCommand, false);
slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true); slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
if let Some(prompt_builder) = prompt_builder { if let Some(prompt_builder) = prompt_builder {
slash_command_registry.register_command( slash_command_registry.register_command(
workflow_command::WorkflowSlashCommand::new(prompt_builder.clone()), workflow_command::WorkflowSlashCommand::new(prompt_builder.clone()),
true, true,
); );
cx.observe_flag::<project_command::ProjectSlashCommandFeatureFlag, _>({
let slash_command_registry = slash_command_registry.clone();
move |is_enabled, _cx| {
if is_enabled {
slash_command_registry.register_command(
project_command::ProjectSlashCommand::new(prompt_builder.clone()),
true,
);
}
}
})
.detach();
} }
slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({ cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({
let slash_command_registry = slash_command_registry.clone(); let slash_command_registry = slash_command_registry.clone();
@ -435,10 +449,12 @@ fn update_slash_commands_from_settings(cx: &mut AppContext) {
slash_command_registry.unregister_command(docs_command::DocsSlashCommand); slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
} }
if settings.project.enabled { if settings.cargo_workspace.enabled {
slash_command_registry.register_command(project_command::ProjectSlashCommand, true); slash_command_registry
.register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
} else { } else {
slash_command_registry.unregister_command(project_command::ProjectSlashCommand); slash_command_registry
.unregister_command(cargo_workspace_command::CargoWorkspaceSlashCommand);
} }
} }

View file

@ -1967,8 +1967,9 @@ impl Context {
} }
pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> { pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
let provider = LanguageModelRegistry::read_global(cx).active_provider()?; let model_registry = LanguageModelRegistry::read_global(cx);
let model = LanguageModelRegistry::read_global(cx).active_model()?; let provider = model_registry.active_provider()?;
let model = model_registry.active_model()?;
let last_message_id = self.get_last_valid_message_id(cx)?; let last_message_id = self.get_last_valid_message_id(cx)?;
if !provider.is_authenticated(cx) { if !provider.is_authenticated(cx) {

View file

@ -40,6 +40,11 @@ pub struct TerminalAssistantPromptContext {
pub user_prompt: String, pub user_prompt: String,
} }
#[derive(Serialize)]
pub struct ProjectSlashCommandPromptContext {
pub context_buffer: String,
}
/// Context required to generate a workflow step resolution prompt. /// Context required to generate a workflow step resolution prompt.
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct StepResolutionContext { pub struct StepResolutionContext {
@ -317,4 +322,14 @@ impl PromptBuilder {
pub fn generate_workflow_prompt(&self) -> Result<String, RenderError> { pub fn generate_workflow_prompt(&self) -> Result<String, RenderError> {
self.handlebars.lock().render("edit_workflow", &()) self.handlebars.lock().render("edit_workflow", &())
} }
pub fn generate_project_slash_command_prompt(
&self,
context_buffer: String,
) -> Result<String, RenderError> {
self.handlebars.lock().render(
"project_slash_command",
&ProjectSlashCommandPromptContext { context_buffer },
)
}
} }

View file

@ -18,8 +18,8 @@ use std::{
}; };
use ui::ActiveTheme; use ui::ActiveTheme;
use workspace::Workspace; use workspace::Workspace;
pub mod auto_command; pub mod auto_command;
pub mod cargo_workspace_command;
pub mod context_server_command; pub mod context_server_command;
pub mod default_command; pub mod default_command;
pub mod delta_command; pub mod delta_command;

View file

@ -0,0 +1,153 @@
use super::{SlashCommand, SlashCommandOutput};
use anyhow::{anyhow, Context, Result};
use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
use fs::Fs;
use gpui::{AppContext, Model, Task, WeakView};
use language::{BufferSnapshot, LspAdapterDelegate};
use project::{Project, ProjectPath};
use std::{
fmt::Write,
path::Path,
sync::{atomic::AtomicBool, Arc},
};
use ui::prelude::*;
use workspace::Workspace;
/// Slash command that inserts a summary of the current Cargo workspace
/// (workspace members and dependencies, or package name/description and
/// dependencies) into the assistant context.
pub(crate) struct CargoWorkspaceSlashCommand;

impl CargoWorkspaceSlashCommand {
    /// Loads the `Cargo.toml` at `path_to_cargo_toml` via `fs` and renders a
    /// prose summary of the manifest.
    ///
    /// For a workspace manifest this lists the workspace members, the default
    /// members (if any), and the workspace dependencies. For a plain package
    /// manifest it lists the package name, its description (when one can be
    /// resolved), and its dependencies. Returns the assembled message text.
    async fn build_message(fs: Arc<dyn Fs>, path_to_cargo_toml: &Path) -> Result<String> {
        let buffer = fs.load(path_to_cargo_toml).await?;
        // Parse the raw TOML into a typed manifest; errors on malformed Cargo.toml.
        let cargo_toml: cargo_toml::Manifest = toml::from_str(&buffer)?;

        let mut message = String::new();
        writeln!(message, "You are in a Rust project.")?;

        if let Some(workspace) = cargo_toml.workspace {
            writeln!(
                message,
                "The project is a Cargo workspace with the following members:"
            )?;
            for member in workspace.members {
                writeln!(message, "- {member}")?;
            }

            if !workspace.default_members.is_empty() {
                writeln!(message, "The default members are:")?;
                for member in workspace.default_members {
                    writeln!(message, "- {member}")?;
                }
            }

            if !workspace.dependencies.is_empty() {
                writeln!(
                    message,
                    "The following workspace dependencies are installed:"
                )?;
                for dependency in workspace.dependencies.keys() {
                    writeln!(message, "- {dependency}")?;
                }
            }
        } else if let Some(package) = cargo_toml.package {
            writeln!(
                message,
                "The project name is \"{name}\".",
                name = package.name
            )?;

            // `description` is an inheritable field in cargo_toml; `get()` can
            // fail when it is declared as inherited from a workspace manifest
            // that was not loaded, in which case we simply omit it.
            let description = package
                .description
                .as_ref()
                .and_then(|description| description.get().ok().cloned());
            if let Some(description) = description.as_ref() {
                writeln!(message, "It describes itself as \"{description}\".")?;
            }

            if !cargo_toml.dependencies.is_empty() {
                writeln!(message, "The following dependencies are installed:")?;
                for dependency in cargo_toml.dependencies.keys() {
                    writeln!(message, "- {dependency}")?;
                }
            }
        }

        Ok(message)
    }

    /// Returns the absolute path of the `Cargo.toml` entry at the root of the
    /// project's first worktree, or `None` if the project has no worktrees or
    /// no such entry. NOTE(review): only the first worktree is consulted —
    /// multi-worktree projects ignore the rest; confirm this is intended.
    fn path_to_cargo_toml(project: Model<Project>, cx: &mut AppContext) -> Option<Arc<Path>> {
        let worktree = project.read(cx).worktrees(cx).next()?;
        let worktree = worktree.read(cx);
        let entry = worktree.entry_for_path("Cargo.toml")?;
        let path = ProjectPath {
            worktree_id: worktree.id(),
            path: entry.path.clone(),
        };
        Some(Arc::from(
            project.read(cx).absolute_path(&path, cx)?.as_path(),
        ))
    }
}
impl SlashCommand for CargoWorkspaceSlashCommand {
    fn name(&self) -> String {
        "cargo-workspace".into()
    }

    fn description(&self) -> String {
        "insert project workspace metadata".into()
    }

    fn menu_text(&self) -> String {
        "Insert Project Workspace Metadata".into()
    }

    /// `/cargo-workspace` takes no arguments, so argument completion always
    /// fails with an explanatory error.
    fn complete_argument(
        self: Arc<Self>,
        _arguments: &[String],
        _cancel: Arc<AtomicBool>,
        _workspace: Option<WeakView<Workspace>>,
        _cx: &mut WindowContext,
    ) -> Task<Result<Vec<ArgumentCompletion>>> {
        // Fixed: the error message was ungrammatical ("does not require argument").
        Task::ready(Err(anyhow!("this command does not require an argument")))
    }

    fn requires_argument(&self) -> bool {
        false
    }

    /// Builds the slash-command output by locating the project's `Cargo.toml`,
    /// summarizing it on the background executor, and wrapping the resulting
    /// text in a single "Project" section.
    fn run(
        self: Arc<Self>,
        _arguments: &[String],
        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
        _context_buffer: BufferSnapshot,
        workspace: WeakView<Workspace>,
        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
        cx: &mut WindowContext,
    ) -> Task<Result<SlashCommandOutput>> {
        let output = workspace.update(cx, |workspace, cx| {
            let project = workspace.project().clone();
            let fs = workspace.project().read(cx).fs().clone();
            let path = Self::path_to_cargo_toml(project, cx);
            // File I/O and TOML parsing happen off the main thread.
            let output = cx.background_executor().spawn(async move {
                // `.context` (not `with_context` + closure) is idiomatic for a
                // constant message.
                let path = path.context("Cargo.toml not found")?;
                Self::build_message(fs, &path).await
            });
            cx.foreground_executor().spawn(async move {
                let text = output.await?;
                // One section spanning the entire generated text.
                let range = 0..text.len();
                Ok(SlashCommandOutput {
                    text,
                    sections: vec![SlashCommandOutputSection {
                        range,
                        icon: IconName::FileTree,
                        label: "Project".into(),
                        metadata: None,
                    }],
                    run_commands_in_text: false,
                })
            })
        });
        // If the workspace view was dropped, surface that error immediately.
        output.unwrap_or_else(|error| Task::ready(Err(error)))
    }
}

View file

@ -1,90 +1,39 @@
use super::{SlashCommand, SlashCommandOutput}; use super::{
use anyhow::{anyhow, Context, Result}; create_label_for_command, search_command::add_search_result_section, SlashCommand,
SlashCommandOutput,
};
use crate::PromptBuilder;
use anyhow::{anyhow, Result};
use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection}; use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
use fs::Fs; use feature_flags::FeatureFlag;
use gpui::{AppContext, Model, Task, WeakView}; use gpui::{AppContext, Task, WeakView, WindowContext};
use language::{BufferSnapshot, LspAdapterDelegate}; use language::{Anchor, CodeLabel, LspAdapterDelegate};
use project::{Project, ProjectPath}; use language_model::{LanguageModelRegistry, LanguageModelTool};
use schemars::JsonSchema;
use semantic_index::SemanticDb;
use serde::Deserialize;
pub struct ProjectSlashCommandFeatureFlag;
impl FeatureFlag for ProjectSlashCommandFeatureFlag {
const NAME: &'static str = "project-slash-command";
}
use std::{ use std::{
fmt::Write, fmt::Write as _,
path::Path, ops::DerefMut,
sync::{atomic::AtomicBool, Arc}, sync::{atomic::AtomicBool, Arc},
}; };
use ui::prelude::*; use ui::{BorrowAppContext as _, IconName};
use workspace::Workspace; use workspace::Workspace;
pub(crate) struct ProjectSlashCommand; pub struct ProjectSlashCommand {
prompt_builder: Arc<PromptBuilder>,
}
impl ProjectSlashCommand { impl ProjectSlashCommand {
async fn build_message(fs: Arc<dyn Fs>, path_to_cargo_toml: &Path) -> Result<String> { pub fn new(prompt_builder: Arc<PromptBuilder>) -> Self {
let buffer = fs.load(path_to_cargo_toml).await?; Self { prompt_builder }
let cargo_toml: cargo_toml::Manifest = toml::from_str(&buffer)?;
let mut message = String::new();
writeln!(message, "You are in a Rust project.")?;
if let Some(workspace) = cargo_toml.workspace {
writeln!(
message,
"The project is a Cargo workspace with the following members:"
)?;
for member in workspace.members {
writeln!(message, "- {member}")?;
}
if !workspace.default_members.is_empty() {
writeln!(message, "The default members are:")?;
for member in workspace.default_members {
writeln!(message, "- {member}")?;
}
}
if !workspace.dependencies.is_empty() {
writeln!(
message,
"The following workspace dependencies are installed:"
)?;
for dependency in workspace.dependencies.keys() {
writeln!(message, "- {dependency}")?;
}
}
} else if let Some(package) = cargo_toml.package {
writeln!(
message,
"The project name is \"{name}\".",
name = package.name
)?;
let description = package
.description
.as_ref()
.and_then(|description| description.get().ok().cloned());
if let Some(description) = description.as_ref() {
writeln!(message, "It describes itself as \"{description}\".")?;
}
if !cargo_toml.dependencies.is_empty() {
writeln!(message, "The following dependencies are installed:")?;
for dependency in cargo_toml.dependencies.keys() {
writeln!(message, "- {dependency}")?;
}
}
}
Ok(message)
}
fn path_to_cargo_toml(project: Model<Project>, cx: &mut AppContext) -> Option<Arc<Path>> {
let worktree = project.read(cx).worktrees(cx).next()?;
let worktree = worktree.read(cx);
let entry = worktree.entry_for_path("Cargo.toml")?;
let path = ProjectPath {
worktree_id: worktree.id(),
path: entry.path.clone(),
};
Some(Arc::from(
project.read(cx).absolute_path(&path, cx)?.as_path(),
))
} }
} }
@ -93,12 +42,20 @@ impl SlashCommand for ProjectSlashCommand {
"project".into() "project".into()
} }
fn label(&self, cx: &AppContext) -> CodeLabel {
create_label_for_command("project", &[], cx)
}
fn description(&self) -> String { fn description(&self) -> String {
"insert project metadata".into() "Generate semantic searches based on the current context".into()
} }
fn menu_text(&self) -> String { fn menu_text(&self) -> String {
"Insert Project Metadata".into() "Project Context".into()
}
fn requires_argument(&self) -> bool {
false
} }
fn complete_argument( fn complete_argument(
@ -108,46 +65,126 @@ impl SlashCommand for ProjectSlashCommand {
_workspace: Option<WeakView<Workspace>>, _workspace: Option<WeakView<Workspace>>,
_cx: &mut WindowContext, _cx: &mut WindowContext,
) -> Task<Result<Vec<ArgumentCompletion>>> { ) -> Task<Result<Vec<ArgumentCompletion>>> {
Task::ready(Err(anyhow!("this command does not require argument"))) Task::ready(Ok(Vec::new()))
}
fn requires_argument(&self) -> bool {
false
} }
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
_arguments: &[String], _arguments: &[String],
_context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>], _context_slash_command_output_sections: &[SlashCommandOutputSection<Anchor>],
_context_buffer: BufferSnapshot, context_buffer: language::BufferSnapshot,
workspace: WeakView<Workspace>, workspace: WeakView<Workspace>,
_delegate: Option<Arc<dyn LspAdapterDelegate>>, _delegate: Option<Arc<dyn LspAdapterDelegate>>,
cx: &mut WindowContext, cx: &mut WindowContext,
) -> Task<Result<SlashCommandOutput>> { ) -> Task<Result<SlashCommandOutput>> {
let output = workspace.update(cx, |workspace, cx| { let model_registry = LanguageModelRegistry::read_global(cx);
let project = workspace.project().clone(); let current_model = model_registry.active_model();
let fs = workspace.project().read(cx).fs().clone(); let prompt_builder = self.prompt_builder.clone();
let path = Self::path_to_cargo_toml(project, cx);
let output = cx.background_executor().spawn(async move {
let path = path.with_context(|| "Cargo.toml not found")?;
Self::build_message(fs, &path).await
});
cx.foreground_executor().spawn(async move { let Some(workspace) = workspace.upgrade() else {
let text = output.await?; return Task::ready(Err(anyhow::anyhow!("workspace was dropped")));
let range = 0..text.len(); };
Ok(SlashCommandOutput { let project = workspace.read(cx).project().clone();
text, let fs = project.read(cx).fs().clone();
sections: vec![SlashCommandOutputSection { let Some(project_index) =
range, cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
icon: IconName::FileTree, else {
label: "Project".into(), return Task::ready(Err(anyhow::anyhow!("no project indexer")));
};
cx.spawn(|mut cx| async move {
let current_model = current_model.ok_or_else(|| anyhow!("no model selected"))?;
let prompt =
prompt_builder.generate_project_slash_command_prompt(context_buffer.text())?;
let search_queries = current_model
.use_tool::<SearchQueries>(
language_model::LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: language_model::Role::User,
content: vec![language_model::MessageContent::Text(prompt)],
cache: false,
}],
tools: vec![],
stop: vec![],
temperature: None,
},
cx.deref_mut(),
)
.await?
.search_queries;
let results = project_index
.read_with(&cx, |project_index, cx| {
project_index.search(search_queries.clone(), 25, cx)
})?
.await?;
let results = SemanticDb::load_results(results, &fs, &cx).await?;
cx.background_executor()
.spawn(async move {
let mut output = "Project context:\n".to_string();
let mut sections = Vec::new();
for (ix, query) in search_queries.into_iter().enumerate() {
let start_ix = output.len();
writeln!(&mut output, "Results for {query}:").unwrap();
let mut has_results = false;
for result in &results {
if result.query_index == ix {
add_search_result_section(result, &mut output, &mut sections);
has_results = true;
}
}
if has_results {
sections.push(SlashCommandOutputSection {
range: start_ix..output.len(),
icon: IconName::MagnifyingGlass,
label: query.into(),
metadata: None,
});
output.push('\n');
} else {
output.truncate(start_ix);
}
}
sections.push(SlashCommandOutputSection {
range: 0..output.len(),
icon: IconName::Book,
label: "Project context".into(),
metadata: None, metadata: None,
}], });
run_commands_in_text: false,
Ok(SlashCommandOutput {
text: output,
sections,
run_commands_in_text: true,
})
}) })
}) .await
}); })
output.unwrap_or_else(|error| Task::ready(Err(error))) }
}
// NOTE: these doc comments are not just documentation — schemars emits them as
// the `description` fields of the tool's JSON schema, which is what the
// language model reads when filling in this tool call.
#[derive(JsonSchema, Deserialize)]
struct SearchQueries {
    /// An array of semantic search queries.
    ///
    /// These queries will be used to search the user's codebase.
    /// The function can only accept 4 queries, otherwise it will error.
    /// As such, it's important that you limit the length of the search_queries array to 4 queries or fewer.
    search_queries: Vec<String>,
}
impl LanguageModelTool for SearchQueries {
fn name() -> String {
"search_queries".to_string()
}
fn description() -> String {
"Generate semantic search queries based on context".to_string()
} }
} }

View file

@ -7,7 +7,7 @@ use anyhow::Result;
use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection}; use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
use feature_flags::FeatureFlag; use feature_flags::FeatureFlag;
use gpui::{AppContext, Task, WeakView}; use gpui::{AppContext, Task, WeakView};
use language::{CodeLabel, LineEnding, LspAdapterDelegate}; use language::{CodeLabel, LspAdapterDelegate};
use semantic_index::{LoadedSearchResult, SemanticDb}; use semantic_index::{LoadedSearchResult, SemanticDb};
use std::{ use std::{
fmt::Write, fmt::Write,
@ -101,7 +101,7 @@ impl SlashCommand for SearchSlashCommand {
cx.spawn(|cx| async move { cx.spawn(|cx| async move {
let results = project_index let results = project_index
.read_with(&cx, |project_index, cx| { .read_with(&cx, |project_index, cx| {
project_index.search(query.clone(), limit.unwrap_or(5), cx) project_index.search(vec![query.clone()], limit.unwrap_or(5), cx)
})? })?
.await?; .await?;
@ -112,31 +112,8 @@ impl SlashCommand for SearchSlashCommand {
.spawn(async move { .spawn(async move {
let mut text = format!("Search results for {query}:\n"); let mut text = format!("Search results for {query}:\n");
let mut sections = Vec::new(); let mut sections = Vec::new();
for LoadedSearchResult { for loaded_result in &loaded_results {
path, add_search_result_section(loaded_result, &mut text, &mut sections);
range,
full_path,
file_content,
row_range,
} in loaded_results
{
let section_start_ix = text.len();
text.push_str(&codeblock_fence_for_path(
Some(&path),
Some(row_range.clone()),
));
let mut excerpt = file_content[range].to_string();
LineEnding::normalize(&mut excerpt);
text.push_str(&excerpt);
writeln!(text, "\n```\n").unwrap();
let section_end_ix = text.len() - 1;
sections.push(build_entry_output_section(
section_start_ix..section_end_ix,
Some(&full_path),
false,
Some(row_range.start() + 1..row_range.end() + 1),
));
} }
let query = SharedString::from(query); let query = SharedString::from(query);
@ -159,3 +136,35 @@ impl SlashCommand for SearchSlashCommand {
}) })
} }
} }
/// Appends one loaded semantic-search result to `text` as a fenced code block
/// and records a matching `SlashCommandOutputSection` (byte range into `text`)
/// in `sections`. Shared by the `/search` and `/project` slash commands.
pub fn add_search_result_section(
    loaded_result: &LoadedSearchResult,
    text: &mut String,
    sections: &mut Vec<SlashCommandOutputSection<usize>>,
) {
    let LoadedSearchResult {
        path,
        full_path,
        excerpt_content,
        row_range,
        ..
    } = loaded_result;
    let section_start_ix = text.len();
    // Opening fence annotated with the result's path and row range.
    text.push_str(&codeblock_fence_for_path(
        Some(&path),
        Some(row_range.clone()),
    ));
    text.push_str(&excerpt_content);
    // Ensure the closing fence starts on its own line even when the excerpt
    // has no trailing newline.
    if !text.ends_with('\n') {
        text.push('\n');
    }
    // Closing fence plus a blank separator line after it.
    writeln!(text, "```\n").unwrap();
    // Exclude the trailing blank line from the section's byte range.
    let section_end_ix = text.len() - 1;
    sections.push(build_entry_output_section(
        section_start_ix..section_end_ix,
        Some(&full_path),
        false,
        // The +1 suggests `row_range` is 0-based and is converted to 1-based
        // line numbers for display — TODO confirm against LoadedSearchResult.
        Some(row_range.start() + 1..row_range.end() + 1),
    ));
}

View file

@ -10,9 +10,9 @@ pub struct SlashCommandSettings {
/// Settings for the `/docs` slash command. /// Settings for the `/docs` slash command.
#[serde(default)] #[serde(default)]
pub docs: DocsCommandSettings, pub docs: DocsCommandSettings,
/// Settings for the `/project` slash command. /// Settings for the `/cargo-workspace` slash command.
#[serde(default)] #[serde(default)]
pub project: ProjectCommandSettings, pub cargo_workspace: CargoWorkspaceCommandSettings,
} }
/// Settings for the `/docs` slash command. /// Settings for the `/docs` slash command.
@ -23,10 +23,10 @@ pub struct DocsCommandSettings {
pub enabled: bool, pub enabled: bool,
} }
/// Settings for the `/project` slash command. /// Settings for the `/cargo-workspace` slash command.
#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)] #[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)]
pub struct ProjectCommandSettings { pub struct CargoWorkspaceCommandSettings {
/// Whether `/project` is enabled. /// Whether `/cargo-workspace` is enabled.
#[serde(default)] #[serde(default)]
pub enabled: bool, pub enabled: bool,
} }

View file

@ -438,7 +438,7 @@ async fn run_eval_project(
loop { loop {
match cx.update(|cx| { match cx.update(|cx| {
let project_index = project_index.read(cx); let project_index = project_index.read(cx);
project_index.search(query.query.clone(), SEARCH_RESULT_LIMIT, cx) project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
}) { }) {
Ok(task) => match task.await { Ok(task) => match task.await {
Ok(answer) => { Ok(answer) => {

View file

@ -98,7 +98,7 @@ fn main() {
.update(|cx| { .update(|cx| {
let project_index = project_index.read(cx); let project_index = project_index.read(cx);
let query = "converting an anchor to a point"; let query = "converting an anchor to a point";
project_index.search(query.into(), 4, cx) project_index.search(vec![query.into()], 4, cx)
}) })
.unwrap() .unwrap()
.await .await

View file

@ -42,14 +42,23 @@ impl Embedding {
self.0.len() self.0.len()
} }
pub fn similarity(self, other: &Embedding) -> f32 { pub fn similarity(&self, others: &[Embedding]) -> (f32, usize) {
debug_assert_eq!(self.0.len(), other.0.len()); debug_assert!(others.iter().all(|other| self.0.len() == other.0.len()));
self.0 others
.iter() .iter()
.copied() .enumerate()
.zip(other.0.iter().copied()) .map(|(index, other)| {
.map(|(a, b)| a * b) let dot_product: f32 = self
.sum() .0
.iter()
.copied()
.zip(other.0.iter().copied())
.map(|(a, b)| a * b)
.sum();
(dot_product, index)
})
.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or((0.0, 0))
} }
} }

View file

@ -31,20 +31,23 @@ pub struct SearchResult {
pub path: Arc<Path>, pub path: Arc<Path>,
pub range: Range<usize>, pub range: Range<usize>,
pub score: f32, pub score: f32,
pub query_index: usize,
} }
#[derive(Debug, PartialEq, Eq)]
pub struct LoadedSearchResult { pub struct LoadedSearchResult {
pub path: Arc<Path>, pub path: Arc<Path>,
pub range: Range<usize>,
pub full_path: PathBuf, pub full_path: PathBuf,
pub file_content: String, pub excerpt_content: String,
pub row_range: RangeInclusive<u32>, pub row_range: RangeInclusive<u32>,
pub query_index: usize,
} }
pub struct WorktreeSearchResult { pub struct WorktreeSearchResult {
pub worktree_id: WorktreeId, pub worktree_id: WorktreeId,
pub path: Arc<Path>, pub path: Arc<Path>,
pub range: Range<usize>, pub range: Range<usize>,
pub query_index: usize,
pub score: f32, pub score: f32,
} }
@ -227,7 +230,7 @@ impl ProjectIndex {
pub fn search( pub fn search(
&self, &self,
query: String, queries: Vec<String>,
limit: usize, limit: usize,
cx: &AppContext, cx: &AppContext,
) -> Task<Result<Vec<SearchResult>>> { ) -> Task<Result<Vec<SearchResult>>> {
@ -275,15 +278,18 @@ impl ProjectIndex {
cx.spawn(|cx| async move { cx.spawn(|cx| async move {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
let embedding_query_start = std::time::Instant::now(); let embedding_query_start = std::time::Instant::now();
log::info!("Searching for {query}"); log::info!("Searching for {queries:?}");
let queries: Vec<TextToEmbed> = queries
.iter()
.map(|s| TextToEmbed::new(s.as_str()))
.collect();
let query_embeddings = embedding_provider let query_embeddings = embedding_provider.embed(&queries[..]).await?;
.embed(&[TextToEmbed::new(&query)]) if query_embeddings.len() != queries.len() {
.await?; return Err(anyhow!(
let query_embedding = query_embeddings "The number of query embeddings does not match the number of queries"
.into_iter() ));
.next() }
.ok_or_else(|| anyhow!("no embedding for query"))?;
let mut results_by_worker = Vec::new(); let mut results_by_worker = Vec::new();
for _ in 0..cx.background_executor().num_cpus() { for _ in 0..cx.background_executor().num_cpus() {
@ -292,28 +298,34 @@ impl ProjectIndex {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
let search_start = std::time::Instant::now(); let search_start = std::time::Instant::now();
cx.background_executor() cx.background_executor()
.scoped(|cx| { .scoped(|cx| {
for results in results_by_worker.iter_mut() { for results in results_by_worker.iter_mut() {
cx.spawn(async { cx.spawn(async {
while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await { while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
let score = chunk.embedding.similarity(&query_embedding); let (score, query_index) =
chunk.embedding.similarity(&query_embeddings);
let ix = match results.binary_search_by(|probe| { let ix = match results.binary_search_by(|probe| {
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal) score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
}) { }) {
Ok(ix) | Err(ix) => ix, Ok(ix) | Err(ix) => ix,
}; };
results.insert( if ix < limit {
ix, results.insert(
WorktreeSearchResult { ix,
worktree_id, WorktreeSearchResult {
path: path.clone(), worktree_id,
range: chunk.chunk.range.clone(), path: path.clone(),
score, range: chunk.chunk.range.clone(),
}, query_index,
); score,
results.truncate(limit); },
);
if results.len() > limit {
results.pop();
}
}
} }
}); });
} }
@ -333,6 +345,7 @@ impl ProjectIndex {
path: result.path, path: result.path,
range: result.range, range: result.range,
score: result.score, score: result.score,
query_index: result.query_index,
}) })
})); }));
} }

View file

@ -12,8 +12,13 @@ use anyhow::{Context as _, Result};
use collections::HashMap; use collections::HashMap;
use fs::Fs; use fs::Fs;
use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel}; use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
use project::Project; use language::LineEnding;
use std::{path::PathBuf, sync::Arc}; use project::{Project, Worktree};
use std::{
cmp::Ordering,
path::{Path, PathBuf},
sync::Arc,
};
use ui::ViewContext; use ui::ViewContext;
use util::ResultExt as _; use util::ResultExt as _;
use workspace::Workspace; use workspace::Workspace;
@ -77,46 +82,127 @@ impl SemanticDb {
} }
/// Loads the file contents backing a set of semantic-search results and
/// converts each raw byte-range hit into a line-aligned, displayable excerpt.
///
/// Processing steps (order matters):
/// 1. Record, per (worktree, path), the best score seen and the index of the
///    query that produced it.
/// 2. Sort results so that higher-scoring files come first, then group by
///    worktree/path and by ascending range start within a file.
/// 3. For each result, load the file (reusing the previously loaded content
///    when consecutive results share a file), snap the hit's byte range
///    outward to line boundaries, and merge excerpts that are line-adjacent
///    within the same file.
/// 4. Trim trailing blank lines from each merged excerpt.
///
/// Returns one `LoadedSearchResult` per merged excerpt; files that fail to
/// load are skipped (the error is logged via `log_err`).
pub async fn load_results(
    mut results: Vec<SearchResult>,
    fs: &Arc<dyn Fs>,
    cx: &AsyncAppContext,
) -> Result<Vec<LoadedSearchResult>> {
    // Step 1: best (score, query_index) per file, used both for ordering and
    // for attributing each excerpt to the query that matched it best.
    let mut max_scores_by_path = HashMap::<_, (f32, usize)>::default();
    for result in &results {
        let (score, query_index) = max_scores_by_path
            .entry((result.worktree.clone(), result.path.clone()))
            .or_default();
        if result.score > *score {
            *score = result.score;
            *query_index = result.query_index;
        }
    }

    // Step 2: order by descending best-score of the containing file, then by
    // worktree/path so results in one file are contiguous (required by the
    // file-content caching and excerpt merging below), then by range start so
    // adjacent ranges can be merged in a single pass.
    results.sort_by(|a, b| {
        let max_score_a = max_scores_by_path[&(a.worktree.clone(), a.path.clone())].0;
        let max_score_b = max_scores_by_path[&(b.worktree.clone(), b.path.clone())].0;
        max_score_b
            .partial_cmp(&max_score_a)
            .unwrap_or(Ordering::Equal)
            .then_with(|| a.worktree.entity_id().cmp(&b.worktree.entity_id()))
            .then_with(|| a.path.cmp(&b.path))
            .then_with(|| a.range.start.cmp(&b.range.start))
    });

    // Cache of the most recently loaded file: (worktree, path, full_path,
    // content). Because results for one file are now contiguous, a one-entry
    // cache avoids re-reading a file once per hit.
    let mut last_loaded_file: Option<(Model<Worktree>, Arc<Path>, PathBuf, String)> = None;
    let mut loaded_results = Vec::<LoadedSearchResult>::new();
    for result in results {
        let full_path;
        let file_content;
        if let Some(last_loaded_file) =
            last_loaded_file
                .as_ref()
                .filter(|(last_worktree, last_path, _, _)| {
                    last_worktree == &result.worktree && last_path == &result.path
                })
        {
            // Same file as the previous result: reuse the cached content.
            full_path = last_loaded_file.2.clone();
            file_content = &last_loaded_file.3;
        } else {
            // Resolve the absolute path on the worktree, then load from disk.
            let output = result.worktree.read_with(cx, |worktree, _cx| {
                let entry_abs_path = worktree.abs_path().join(&result.path);
                let mut entry_full_path = PathBuf::from(worktree.root_name());
                entry_full_path.push(&result.path);
                let file_content = async {
                    let entry_abs_path = entry_abs_path;
                    fs.load(&entry_abs_path).await
                };
                (entry_full_path, file_content)
            })?;
            full_path = output.0;
            // A file that fails to load is logged and skipped, not fatal.
            let Some(content) = output.1.await.log_err() else {
                continue;
            };
            last_loaded_file = Some((
                result.worktree.clone(),
                result.path.clone(),
                full_path.clone(),
                content,
            ));
            file_content = &last_loaded_file.as_ref().unwrap().3;
        };

        let query_index = max_scores_by_path[&(result.worktree.clone(), result.path.clone())].1;

        // Clamp the stored byte range to the current file length (the index
        // may be stale relative to the file on disk), then nudge each end
        // forward to a UTF-8 char boundary so the slicing below cannot panic.
        let mut range_start = result.range.start.min(file_content.len());
        let mut range_end = result.range.end.min(file_content.len());
        while !file_content.is_char_boundary(range_start) {
            range_start += 1;
        }
        while !file_content.is_char_boundary(range_end) {
            range_end += 1;
        }

        let start_row = file_content[0..range_start].matches('\n').count() as u32;
        let mut end_row = file_content[0..range_end].matches('\n').count() as u32;
        // Expand the range outward to whole lines: back to the start of the
        // first line, forward past the end of the last line.
        let start_line_byte_offset = file_content[0..range_start]
            .rfind('\n')
            .map(|pos| pos + 1)
            .unwrap_or_default();
        let mut end_line_byte_offset = range_end;
        if file_content[..end_line_byte_offset].ends_with('\n') {
            // Range already ends just after a newline; the trailing newline
            // belongs to the previous row, so don't count an extra row.
            end_row -= 1;
        } else {
            // Extend through the rest of the current line (including its
            // newline), or to EOF if the file doesn't end with one.
            end_line_byte_offset = file_content[range_end..]
                .find('\n')
                .map(|pos| range_end + pos + 1)
                .unwrap_or_else(|| file_content.len());
        }
        let mut excerpt_content =
            file_content[start_line_byte_offset..end_line_byte_offset].to_string();
        LineEnding::normalize(&mut excerpt_content);

        // Merge with the previous excerpt when it is the same file and this
        // excerpt starts on the very next row — avoids overlapping/abutting
        // snippets in the output.
        if let Some(prev_result) = loaded_results.last_mut() {
            if prev_result.full_path == full_path {
                if *prev_result.row_range.end() + 1 == start_row {
                    prev_result.row_range = *prev_result.row_range.start()..=end_row;
                    prev_result.excerpt_content.push_str(&excerpt_content);
                    continue;
                }
            }
        }

        loaded_results.push(LoadedSearchResult {
            path: result.path,
            full_path,
            excerpt_content,
            row_range: start_row..=end_row,
            query_index,
        });
    }

    // Step 4: drop trailing blank lines, shrinking the row range to match.
    for result in &mut loaded_results {
        while result.excerpt_content.ends_with("\n\n") {
            result.excerpt_content.pop();
            result.row_range =
                *result.row_range.start()..=result.row_range.end().saturating_sub(1)
        }
    }

    Ok(loaded_results)
}
@ -312,7 +398,7 @@ mod tests {
.update(|cx| { .update(|cx| {
let project_index = project_index.read(cx); let project_index = project_index.read(cx);
let query = "garbage in, garbage out"; let query = "garbage in, garbage out";
project_index.search(query.into(), 4, cx) project_index.search(vec![query.into()], 4, cx)
}) })
.await .await
.unwrap(); .unwrap();
@ -426,4 +512,117 @@ mod tests {
], ],
); );
} }
#[gpui::test]
async fn test_load_search_results(cx: &mut TestAppContext) {
    init_test(cx);

    // Two small fixture files in a fake project; every line is 3-4 bytes plus
    // a newline, which makes the expected byte offsets easy to reason about.
    let fs = FakeFs::new(cx.executor());
    let project_path = Path::new("/fake_project");
    let file1_content = "one\ntwo\nthree\nfour\nfive\n";
    let file2_content = "aaa\nbbb\nccc\nddd\neee\n";
    fs.insert_tree(
        project_path,
        json!({
            "file1.txt": file1_content,
            "file2.txt": file2_content,
        }),
    )
    .await;
    let fs = fs as Arc<dyn Fs>;
    let project = Project::test(fs.clone(), [project_path], cx).await;
    let worktree = project.read_with(cx, |project, cx| project.worktrees(cx).next().unwrap());

    // Case 1: a chunk whose byte range already lies on newline boundaries is
    // loaded verbatim.
    let aligned_results = vec![SearchResult {
        worktree: worktree.clone(),
        path: Path::new("file1.txt").into(),
        range: 0..file1_content.find("four").unwrap(),
        score: 0.5,
        query_index: 0,
    }];
    let loaded = SemanticDb::load_results(aligned_results, &fs, &cx.to_async())
        .await
        .unwrap();
    assert_eq!(
        loaded,
        &[LoadedSearchResult {
            path: Path::new("file1.txt").into(),
            full_path: "fake_project/file1.txt".into(),
            excerpt_content: "one\ntwo\nthree\n".into(),
            row_range: 0..=2,
            query_index: 0,
        }]
    );

    // Case 2: a chunk that starts and ends mid-line is expanded outward to
    // whole lines.
    let unaligned_results = vec![SearchResult {
        worktree: worktree.clone(),
        path: Path::new("file1.txt").into(),
        range: file1_content.find("two").unwrap() + 1..file1_content.find("four").unwrap() + 2,
        score: 0.5,
        query_index: 0,
    }];
    let loaded = SemanticDb::load_results(unaligned_results, &fs, &cx.to_async())
        .await
        .unwrap();
    assert_eq!(
        loaded,
        &[LoadedSearchResult {
            path: Path::new("file1.txt").into(),
            full_path: "fake_project/file1.txt".into(),
            excerpt_content: "two\nthree\nfour\n".into(),
            row_range: 1..=3,
            query_index: 0,
        }]
    );

    // Case 3: adjacent chunks in one file are merged into a single excerpt,
    // and files are ordered by their best score (file2 scores 0.8 vs. 0.6),
    // with the query index of each file's best-scoring hit attached.
    let adjacent_results = vec![
        SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: file1_content.find("two").unwrap()..file1_content.len(),
            score: 0.6,
            query_index: 0,
        },
        SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: 0..file1_content.find("two").unwrap(),
            score: 0.5,
            query_index: 1,
        },
        SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file2.txt").into(),
            range: 0..file2_content.len(),
            score: 0.8,
            query_index: 1,
        },
    ];
    let loaded = SemanticDb::load_results(adjacent_results, &fs, &cx.to_async())
        .await
        .unwrap();
    assert_eq!(
        loaded,
        &[
            LoadedSearchResult {
                path: Path::new("file2.txt").into(),
                full_path: "fake_project/file2.txt".into(),
                excerpt_content: file2_content.into(),
                row_range: 0..=4,
                query_index: 1,
            },
            LoadedSearchResult {
                path: Path::new("file1.txt").into(),
                full_path: "fake_project/file1.txt".into(),
                excerpt_content: file1_content.into(),
                row_range: 0..=4,
                query_index: 0,
            }
        ]
    );
}
} }