Add a slash command for automatically retrieving relevant context (#17972)

* [x] put this slash command behind a feature flag until we release
embedding access to the general population
* [x] choose a name for this slash command and name the Rust module to
match

Release Notes:

- N/A

---------

Co-authored-by: Jason <jason@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Jason Mancuso <7891333+jvmncs@users.noreply.github.com>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Max Brunsfeld 2024-09-20 15:09:18 -07:00 committed by GitHub
parent 5905fbb9ac
commit e309fbda2a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 683 additions and 223 deletions

View file

@ -0,0 +1,8 @@
A software developer is asking a question about their project. The source files in their project have been indexed into a database of semantic text embeddings.
Your task is to generate a list of 4 diverse search queries that can be run on this embedding database, in order to retrieve a list of code snippets
that are relevant to the developer's question. Redundant search queries will be heavily penalized, so only include another query if it's sufficiently
distinct from previous ones.
Here is the question that's been asked, together with context that the developer has added manually:
{{{context_buffer}}}

View file

@ -41,9 +41,10 @@ use semantic_index::{CloudEmbeddingProvider, SemanticDb};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
use slash_command::{
auto_command, context_server_command, default_command, delta_command, diagnostics_command,
docs_command, fetch_command, file_command, now_command, project_command, prompt_command,
search_command, symbols_command, tab_command, terminal_command, workflow_command,
auto_command, cargo_workspace_command, context_server_command, default_command, delta_command,
diagnostics_command, docs_command, fetch_command, file_command, now_command, project_command,
prompt_command, search_command, symbols_command, tab_command, terminal_command,
workflow_command,
};
use std::path::PathBuf;
use std::sync::Arc;
@ -384,20 +385,33 @@ fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut
slash_command_registry.register_command(delta_command::DeltaSlashCommand, true);
slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
slash_command_registry.register_command(tab_command::TabSlashCommand, true);
slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
slash_command_registry
.register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
slash_command_registry.register_command(now_command::NowSlashCommand, false);
slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
if let Some(prompt_builder) = prompt_builder {
slash_command_registry.register_command(
workflow_command::WorkflowSlashCommand::new(prompt_builder.clone()),
true,
);
cx.observe_flag::<project_command::ProjectSlashCommandFeatureFlag, _>({
let slash_command_registry = slash_command_registry.clone();
move |is_enabled, _cx| {
if is_enabled {
slash_command_registry.register_command(
project_command::ProjectSlashCommand::new(prompt_builder.clone()),
true,
);
}
}
})
.detach();
}
slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({
let slash_command_registry = slash_command_registry.clone();
@ -435,10 +449,12 @@ fn update_slash_commands_from_settings(cx: &mut AppContext) {
slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
}
if settings.project.enabled {
slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
if settings.cargo_workspace.enabled {
slash_command_registry
.register_command(cargo_workspace_command::CargoWorkspaceSlashCommand, true);
} else {
slash_command_registry.unregister_command(project_command::ProjectSlashCommand);
slash_command_registry
.unregister_command(cargo_workspace_command::CargoWorkspaceSlashCommand);
}
}

View file

@ -1967,8 +1967,9 @@ impl Context {
}
pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let model_registry = LanguageModelRegistry::read_global(cx);
let provider = model_registry.active_provider()?;
let model = model_registry.active_model()?;
let last_message_id = self.get_last_valid_message_id(cx)?;
if !provider.is_authenticated(cx) {

View file

@ -40,6 +40,11 @@ pub struct TerminalAssistantPromptContext {
pub user_prompt: String,
}
/// Template data for the `project_slash_command` prompt.
///
/// Built by `PromptBuilder::generate_project_slash_command_prompt` and serialized
/// for the Handlebars renderer.
#[derive(Serialize)]
pub struct ProjectSlashCommandPromptContext {
/// Text made available to the template as `context_buffer`.
pub context_buffer: String,
}
/// Context required to generate a workflow step resolution prompt.
#[derive(Debug, Serialize)]
pub struct StepResolutionContext {
@ -317,4 +322,14 @@ impl PromptBuilder {
pub fn generate_workflow_prompt(&self) -> Result<String, RenderError> {
self.handlebars.lock().render("edit_workflow", &())
}
/// Renders the `project_slash_command` template.
///
/// `context_buffer` is wrapped in a [`ProjectSlashCommandPromptContext`] and
/// exposed to the template under the same name.
///
/// # Errors
///
/// Returns the `RenderError` produced by the template engine if rendering fails.
pub fn generate_project_slash_command_prompt(
    &self,
    context_buffer: String,
) -> Result<String, RenderError> {
    let template_data = ProjectSlashCommandPromptContext { context_buffer };
    let registry = self.handlebars.lock();
    registry.render("project_slash_command", &template_data)
}
}

View file

@ -18,8 +18,8 @@ use std::{
};
use ui::ActiveTheme;
use workspace::Workspace;
pub mod auto_command;
pub mod cargo_workspace_command;
pub mod context_server_command;
pub mod default_command;
pub mod delta_command;

View file

@ -0,0 +1,153 @@
use super::{SlashCommand, SlashCommandOutput};
use anyhow::{anyhow, Context, Result};
use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
use fs::Fs;
use gpui::{AppContext, Model, Task, WeakView};
use language::{BufferSnapshot, LspAdapterDelegate};
use project::{Project, ProjectPath};
use std::{
fmt::Write,
path::Path,
sync::{atomic::AtomicBool, Arc},
};
use ui::prelude::*;
use workspace::Workspace;
/// Slash command that summarizes the `Cargo.toml` found at the root of the
/// project's first worktree.
pub(crate) struct CargoWorkspaceSlashCommand;

impl CargoWorkspaceSlashCommand {
    /// Appends one `- item` line to `out` for every element of `items`.
    fn write_bullets<T: std::fmt::Display>(
        out: &mut String,
        items: impl IntoIterator<Item = T>,
    ) -> std::fmt::Result {
        for item in items {
            writeln!(out, "- {item}")?;
        }
        Ok(())
    }

    /// Loads the manifest at `path_to_cargo_toml` and renders a plain-text summary.
    ///
    /// A workspace manifest lists its members, default members and workspace
    /// dependencies. Otherwise, a package manifest lists its name, description
    /// (when one can be read) and dependencies. A manifest with neither section
    /// yields only the introductory line.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be loaded or is not a valid manifest.
    async fn build_message(fs: Arc<dyn Fs>, path_to_cargo_toml: &Path) -> Result<String> {
        let manifest_text = fs.load(path_to_cargo_toml).await?;
        let manifest: cargo_toml::Manifest = toml::from_str(&manifest_text)?;

        let mut summary = String::new();
        writeln!(summary, "You are in a Rust project.")?;

        // The workspace section takes precedence over the package section.
        match (manifest.workspace, manifest.package) {
            (Some(workspace), _) => {
                writeln!(
                    summary,
                    "The project is a Cargo workspace with the following members:"
                )?;
                Self::write_bullets(&mut summary, &workspace.members)?;

                if !workspace.default_members.is_empty() {
                    writeln!(summary, "The default members are:")?;
                    Self::write_bullets(&mut summary, &workspace.default_members)?;
                }

                if !workspace.dependencies.is_empty() {
                    writeln!(
                        summary,
                        "The following workspace dependencies are installed:"
                    )?;
                    Self::write_bullets(&mut summary, workspace.dependencies.keys())?;
                }
            }
            (None, Some(package)) => {
                writeln!(
                    summary,
                    "The project name is \"{name}\".",
                    name = package.name
                )?;

                // A description that cannot be read directly is silently omitted.
                let description = package
                    .description
                    .as_ref()
                    .and_then(|description| description.get().ok());
                if let Some(description) = description {
                    writeln!(summary, "It describes itself as \"{description}\".")?;
                }

                if !manifest.dependencies.is_empty() {
                    writeln!(summary, "The following dependencies are installed:")?;
                    Self::write_bullets(&mut summary, manifest.dependencies.keys())?;
                }
            }
            (None, None) => {}
        }

        Ok(summary)
    }

    /// Returns the absolute path of `Cargo.toml` at the root of the project's
    /// first worktree.
    ///
    /// Yields `None` when the project has no worktree, the worktree has no such
    /// entry, or the absolute path cannot be resolved.
    fn path_to_cargo_toml(project: Model<Project>, cx: &mut AppContext) -> Option<Arc<Path>> {
        let first_worktree = project.read(cx).worktrees(cx).next()?;
        let first_worktree = first_worktree.read(cx);
        let manifest_entry = first_worktree.entry_for_path("Cargo.toml")?;
        let manifest_path = ProjectPath {
            worktree_id: first_worktree.id(),
            path: manifest_entry.path.clone(),
        };
        project
            .read(cx)
            .absolute_path(&manifest_path, cx)
            .map(|absolute| Arc::from(absolute.as_path()))
    }
}
impl SlashCommand for CargoWorkspaceSlashCommand {
fn name(&self) -> String {
"cargo-workspace".into()
}
fn description(&self) -> String {
"insert project workspace metadata".into()
}
fn menu_text(&self) -> String {
"Insert Project Workspace Metadata".into()
}
fn complete_argument(
self: Arc<Self>,
_arguments: &[String],
_cancel: Arc<AtomicBool>,
_workspace: Option<WeakView<Workspace>>,
_cx: &mut WindowContext,
) -> Task<Result<Vec<ArgumentCompletion>>> {
Task::ready(Err(anyhow!("this command does not require argument")))
}
fn requires_argument(&self) -> bool {
false
}
fn run(
self: Arc<Self>,
_arguments: &[String],
_context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
_context_buffer: BufferSnapshot,
workspace: WeakView<Workspace>,
_delegate: Option<Arc<dyn LspAdapterDelegate>>,
cx: &mut WindowContext,
) -> Task<Result<SlashCommandOutput>> {
let output = workspace.update(cx, |workspace, cx| {
let project = workspace.project().clone();
let fs = workspace.project().read(cx).fs().clone();
let path = Self::path_to_cargo_toml(project, cx);
let output = cx.background_executor().spawn(async move {
let path = path.with_context(|| "Cargo.toml not found")?;
Self::build_message(fs, &path).await
});
cx.foreground_executor().spawn(async move {
let text = output.await?;
let range = 0..text.len();
Ok(SlashCommandOutput {
text,
sections: vec![SlashCommandOutputSection {
range,
icon: IconName::FileTree,
label: "Project".into(),
metadata: None,
}],
run_commands_in_text: false,
})
})
});
output.unwrap_or_else(|error| Task::ready(Err(error)))
}
}

View file

@ -1,90 +1,39 @@
use super::{SlashCommand, SlashCommandOutput};
use anyhow::{anyhow, Context, Result};
use super::{
create_label_for_command, search_command::add_search_result_section, SlashCommand,
SlashCommandOutput,
};
use crate::PromptBuilder;
use anyhow::{anyhow, Result};
use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
use fs::Fs;
use gpui::{AppContext, Model, Task, WeakView};
use language::{BufferSnapshot, LspAdapterDelegate};
use project::{Project, ProjectPath};
use feature_flags::FeatureFlag;
use gpui::{AppContext, Task, WeakView, WindowContext};
use language::{Anchor, CodeLabel, LspAdapterDelegate};
use language_model::{LanguageModelRegistry, LanguageModelTool};
use schemars::JsonSchema;
use semantic_index::SemanticDb;
use serde::Deserialize;
/// Feature flag gating the `/project` slash command.
///
/// `ProjectSlashCommand` is registered from an `observe_flag` callback on this
/// type, and only when the flag is reported as enabled.
pub struct ProjectSlashCommandFeatureFlag;
impl FeatureFlag for ProjectSlashCommandFeatureFlag {
// Name of the flag.
const NAME: &'static str = "project-slash-command";
}
use std::{
fmt::Write,
path::Path,
fmt::Write as _,
ops::DerefMut,
sync::{atomic::AtomicBool, Arc},
};
use ui::prelude::*;
use ui::{BorrowAppContext as _, IconName};
use workspace::Workspace;
pub(crate) struct ProjectSlashCommand;
/// The `/project` slash command.
///
/// Asks the active language model for semantic search queries derived from the
/// current context, runs them against the project's semantic index, and inserts
/// the matching excerpts.
pub struct ProjectSlashCommand {
/// Renders the prompt that asks the model for search queries.
prompt_builder: Arc<PromptBuilder>,
}
impl ProjectSlashCommand {
async fn build_message(fs: Arc<dyn Fs>, path_to_cargo_toml: &Path) -> Result<String> {
let buffer = fs.load(path_to_cargo_toml).await?;
let cargo_toml: cargo_toml::Manifest = toml::from_str(&buffer)?;
let mut message = String::new();
writeln!(message, "You are in a Rust project.")?;
if let Some(workspace) = cargo_toml.workspace {
writeln!(
message,
"The project is a Cargo workspace with the following members:"
)?;
for member in workspace.members {
writeln!(message, "- {member}")?;
}
if !workspace.default_members.is_empty() {
writeln!(message, "The default members are:")?;
for member in workspace.default_members {
writeln!(message, "- {member}")?;
}
}
if !workspace.dependencies.is_empty() {
writeln!(
message,
"The following workspace dependencies are installed:"
)?;
for dependency in workspace.dependencies.keys() {
writeln!(message, "- {dependency}")?;
}
}
} else if let Some(package) = cargo_toml.package {
writeln!(
message,
"The project name is \"{name}\".",
name = package.name
)?;
let description = package
.description
.as_ref()
.and_then(|description| description.get().ok().cloned());
if let Some(description) = description.as_ref() {
writeln!(message, "It describes itself as \"{description}\".")?;
}
if !cargo_toml.dependencies.is_empty() {
writeln!(message, "The following dependencies are installed:")?;
for dependency in cargo_toml.dependencies.keys() {
writeln!(message, "- {dependency}")?;
}
}
}
Ok(message)
}
fn path_to_cargo_toml(project: Model<Project>, cx: &mut AppContext) -> Option<Arc<Path>> {
let worktree = project.read(cx).worktrees(cx).next()?;
let worktree = worktree.read(cx);
let entry = worktree.entry_for_path("Cargo.toml")?;
let path = ProjectPath {
worktree_id: worktree.id(),
path: entry.path.clone(),
};
Some(Arc::from(
project.read(cx).absolute_path(&path, cx)?.as_path(),
))
/// Creates the command, storing the `prompt_builder` it will use to render its prompt.
pub fn new(prompt_builder: Arc<PromptBuilder>) -> Self {
Self { prompt_builder }
}
}
@ -93,12 +42,20 @@ impl SlashCommand for ProjectSlashCommand {
"project".into()
}
/// Label for this command, built by `create_label_for_command` from the name
/// `"project"` and an empty argument list.
fn label(&self, cx: &AppContext) -> CodeLabel {
create_label_for_command("project", &[], cx)
}
fn description(&self) -> String {
"insert project metadata".into()
"Generate semantic searches based on the current context".into()
}
fn menu_text(&self) -> String {
"Insert Project Metadata".into()
"Project Context".into()
}
fn requires_argument(&self) -> bool {
false
}
fn complete_argument(
@ -108,46 +65,126 @@ impl SlashCommand for ProjectSlashCommand {
_workspace: Option<WeakView<Workspace>>,
_cx: &mut WindowContext,
) -> Task<Result<Vec<ArgumentCompletion>>> {
Task::ready(Err(anyhow!("this command does not require argument")))
}
fn requires_argument(&self) -> bool {
false
Task::ready(Ok(Vec::new()))
}
fn run(
self: Arc<Self>,
_arguments: &[String],
_context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
_context_buffer: BufferSnapshot,
_context_slash_command_output_sections: &[SlashCommandOutputSection<Anchor>],
context_buffer: language::BufferSnapshot,
workspace: WeakView<Workspace>,
_delegate: Option<Arc<dyn LspAdapterDelegate>>,
cx: &mut WindowContext,
) -> Task<Result<SlashCommandOutput>> {
let output = workspace.update(cx, |workspace, cx| {
let project = workspace.project().clone();
let fs = workspace.project().read(cx).fs().clone();
let path = Self::path_to_cargo_toml(project, cx);
let output = cx.background_executor().spawn(async move {
let path = path.with_context(|| "Cargo.toml not found")?;
Self::build_message(fs, &path).await
});
let model_registry = LanguageModelRegistry::read_global(cx);
let current_model = model_registry.active_model();
let prompt_builder = self.prompt_builder.clone();
cx.foreground_executor().spawn(async move {
let text = output.await?;
let range = 0..text.len();
Ok(SlashCommandOutput {
text,
sections: vec![SlashCommandOutputSection {
range,
icon: IconName::FileTree,
label: "Project".into(),
let Some(workspace) = workspace.upgrade() else {
return Task::ready(Err(anyhow::anyhow!("workspace was dropped")));
};
let project = workspace.read(cx).project().clone();
let fs = project.read(cx).fs().clone();
let Some(project_index) =
cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
else {
return Task::ready(Err(anyhow::anyhow!("no project indexer")));
};
cx.spawn(|mut cx| async move {
let current_model = current_model.ok_or_else(|| anyhow!("no model selected"))?;
let prompt =
prompt_builder.generate_project_slash_command_prompt(context_buffer.text())?;
let search_queries = current_model
.use_tool::<SearchQueries>(
language_model::LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: language_model::Role::User,
content: vec![language_model::MessageContent::Text(prompt)],
cache: false,
}],
tools: vec![],
stop: vec![],
temperature: None,
},
cx.deref_mut(),
)
.await?
.search_queries;
let results = project_index
.read_with(&cx, |project_index, cx| {
project_index.search(search_queries.clone(), 25, cx)
})?
.await?;
let results = SemanticDb::load_results(results, &fs, &cx).await?;
cx.background_executor()
.spawn(async move {
let mut output = "Project context:\n".to_string();
let mut sections = Vec::new();
for (ix, query) in search_queries.into_iter().enumerate() {
let start_ix = output.len();
writeln!(&mut output, "Results for {query}:").unwrap();
let mut has_results = false;
for result in &results {
if result.query_index == ix {
add_search_result_section(result, &mut output, &mut sections);
has_results = true;
}
}
if has_results {
sections.push(SlashCommandOutputSection {
range: start_ix..output.len(),
icon: IconName::MagnifyingGlass,
label: query.into(),
metadata: None,
});
output.push('\n');
} else {
output.truncate(start_ix);
}
}
sections.push(SlashCommandOutputSection {
range: 0..output.len(),
icon: IconName::Book,
label: "Project context".into(),
metadata: None,
}],
run_commands_in_text: false,
});
Ok(SlashCommandOutput {
text: output,
sections,
run_commands_in_text: true,
})
})
})
});
output.unwrap_or_else(|error| Task::ready(Err(error)))
.await
})
}
}
// Tool-call payload deserialized from the language model's response to the
// `/project` prompt.
//
// NOTE(review): with `JsonSchema` derived, the `///` comment on the field below is
// presumably emitted as the field's schema description and shown to the model.
// Treat its wording as prompt text and keep the limit in agreement with the
// prompt template, which asks for 4 queries.
#[derive(JsonSchema, Deserialize)]
struct SearchQueries {
    /// An array of semantic search queries.
    ///
    /// These queries will be used to search the user's codebase.
    /// The function can only accept 4 queries, otherwise it will error.
    /// As such, it's important that you limit the length of the search_queries array to 4 queries or fewer.
    search_queries: Vec<String>,
}
impl LanguageModelTool for SearchQueries {
    /// Name of the tool.
    fn name() -> String {
        String::from("search_queries")
    }

    /// Description of the tool.
    fn description() -> String {
        String::from("Generate semantic search queries based on context")
    }
}

View file

@ -7,7 +7,7 @@ use anyhow::Result;
use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
use feature_flags::FeatureFlag;
use gpui::{AppContext, Task, WeakView};
use language::{CodeLabel, LineEnding, LspAdapterDelegate};
use language::{CodeLabel, LspAdapterDelegate};
use semantic_index::{LoadedSearchResult, SemanticDb};
use std::{
fmt::Write,
@ -101,7 +101,7 @@ impl SlashCommand for SearchSlashCommand {
cx.spawn(|cx| async move {
let results = project_index
.read_with(&cx, |project_index, cx| {
project_index.search(query.clone(), limit.unwrap_or(5), cx)
project_index.search(vec![query.clone()], limit.unwrap_or(5), cx)
})?
.await?;
@ -112,31 +112,8 @@ impl SlashCommand for SearchSlashCommand {
.spawn(async move {
let mut text = format!("Search results for {query}:\n");
let mut sections = Vec::new();
for LoadedSearchResult {
path,
range,
full_path,
file_content,
row_range,
} in loaded_results
{
let section_start_ix = text.len();
text.push_str(&codeblock_fence_for_path(
Some(&path),
Some(row_range.clone()),
));
let mut excerpt = file_content[range].to_string();
LineEnding::normalize(&mut excerpt);
text.push_str(&excerpt);
writeln!(text, "\n```\n").unwrap();
let section_end_ix = text.len() - 1;
sections.push(build_entry_output_section(
section_start_ix..section_end_ix,
Some(&full_path),
false,
Some(row_range.start() + 1..row_range.end() + 1),
));
for loaded_result in &loaded_results {
add_search_result_section(loaded_result, &mut text, &mut sections);
}
let query = SharedString::from(query);
@ -159,3 +136,35 @@ impl SlashCommand for SearchSlashCommand {
})
}
}
/// Appends one search result to `text` as a fenced code block and records a
/// matching entry in `sections`.
///
/// The appended text is the opening fence from `codeblock_fence_for_path`, the
/// excerpt (followed by a newline if the text does not already end with one),
/// and a closing fence followed by a blank line. The recorded section runs from
/// the start of the opening fence to just before the final newline, and its row
/// range is the result's row range shifted up by one.
pub fn add_search_result_section(
    loaded_result: &LoadedSearchResult,
    text: &mut String,
    sections: &mut Vec<SlashCommandOutputSection<usize>>,
) {
    let rows = &loaded_result.row_range;
    let section_start = text.len();

    let opening_fence = codeblock_fence_for_path(Some(&loaded_result.path), Some(rows.clone()));
    text.push_str(&opening_fence);
    text.push_str(&loaded_result.excerpt_content);
    if !text.ends_with('\n') {
        text.push('\n');
    }
    write!(text, "```\n\n").unwrap();

    // Leave the trailing blank-line newline outside the section.
    let section_end = text.len() - 1;
    let shifted_rows = rows.start() + 1..rows.end() + 1;
    let section = build_entry_output_section(
        section_start..section_end,
        Some(&loaded_result.full_path),
        false,
        Some(shifted_rows),
    );
    sections.push(section);
}

View file

@ -10,9 +10,9 @@ pub struct SlashCommandSettings {
/// Settings for the `/docs` slash command.
#[serde(default)]
pub docs: DocsCommandSettings,
/// Settings for the `/project` slash command.
/// Settings for the `/cargo-workspace` slash command.
#[serde(default)]
pub project: ProjectCommandSettings,
pub cargo_workspace: CargoWorkspaceCommandSettings,
}
/// Settings for the `/docs` slash command.
@ -23,10 +23,10 @@ pub struct DocsCommandSettings {
pub enabled: bool,
}
/// Settings for the `/project` slash command.
/// Settings for the `/cargo-workspace` slash command.
#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)]
pub struct ProjectCommandSettings {
/// Whether `/project` is enabled.
pub struct CargoWorkspaceCommandSettings {
/// Whether `/cargo-workspace` is enabled.
#[serde(default)]
pub enabled: bool,
}

View file

@ -438,7 +438,7 @@ async fn run_eval_project(
loop {
match cx.update(|cx| {
let project_index = project_index.read(cx);
project_index.search(query.query.clone(), SEARCH_RESULT_LIMIT, cx)
project_index.search(vec![query.query.clone()], SEARCH_RESULT_LIMIT, cx)
}) {
Ok(task) => match task.await {
Ok(answer) => {

View file

@ -98,7 +98,7 @@ fn main() {
.update(|cx| {
let project_index = project_index.read(cx);
let query = "converting an anchor to a point";
project_index.search(query.into(), 4, cx)
project_index.search(vec![query.into()], 4, cx)
})
.unwrap()
.await

View file

@ -42,14 +42,23 @@ impl Embedding {
self.0.len()
}
pub fn similarity(self, other: &Embedding) -> f32 {
debug_assert_eq!(self.0.len(), other.0.len());
self.0
pub fn similarity(&self, others: &[Embedding]) -> (f32, usize) {
debug_assert!(others.iter().all(|other| self.0.len() == other.0.len()));
others
.iter()
.copied()
.zip(other.0.iter().copied())
.map(|(a, b)| a * b)
.sum()
.enumerate()
.map(|(index, other)| {
let dot_product: f32 = self
.0
.iter()
.copied()
.zip(other.0.iter().copied())
.map(|(a, b)| a * b)
.sum();
(dot_product, index)
})
.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or((0.0, 0))
}
}

View file

@ -31,20 +31,23 @@ pub struct SearchResult {
pub path: Arc<Path>,
pub range: Range<usize>,
pub score: f32,
pub query_index: usize,
}
#[derive(Debug, PartialEq, Eq)]
pub struct LoadedSearchResult {
pub path: Arc<Path>,
pub range: Range<usize>,
pub full_path: PathBuf,
pub file_content: String,
pub excerpt_content: String,
pub row_range: RangeInclusive<u32>,
pub query_index: usize,
}
/// A scored chunk match identified by worktree id and path.
///
/// `ProjectIndex::search` builds these while ranking chunks.
pub struct WorktreeSearchResult {
pub worktree_id: WorktreeId,
/// Path of the file containing the matched chunk.
pub path: Arc<Path>,
/// Range of the matched chunk, copied from the chunk record
/// (presumably byte offsets into the file — confirm).
pub range: Range<usize>,
/// Index, within the queries passed to `search`, of the query whose embedding
/// scored highest against this chunk.
pub query_index: usize,
/// Similarity between the chunk's embedding and that best-matching query embedding.
pub score: f32,
}
@ -227,7 +230,7 @@ impl ProjectIndex {
pub fn search(
&self,
query: String,
queries: Vec<String>,
limit: usize,
cx: &AppContext,
) -> Task<Result<Vec<SearchResult>>> {
@ -275,15 +278,18 @@ impl ProjectIndex {
cx.spawn(|cx| async move {
#[cfg(debug_assertions)]
let embedding_query_start = std::time::Instant::now();
log::info!("Searching for {query}");
log::info!("Searching for {queries:?}");
let queries: Vec<TextToEmbed> = queries
.iter()
.map(|s| TextToEmbed::new(s.as_str()))
.collect();
let query_embeddings = embedding_provider
.embed(&[TextToEmbed::new(&query)])
.await?;
let query_embedding = query_embeddings
.into_iter()
.next()
.ok_or_else(|| anyhow!("no embedding for query"))?;
let query_embeddings = embedding_provider.embed(&queries[..]).await?;
if query_embeddings.len() != queries.len() {
return Err(anyhow!(
"The number of query embeddings does not match the number of queries"
));
}
let mut results_by_worker = Vec::new();
for _ in 0..cx.background_executor().num_cpus() {
@ -292,28 +298,34 @@ impl ProjectIndex {
#[cfg(debug_assertions)]
let search_start = std::time::Instant::now();
cx.background_executor()
.scoped(|cx| {
for results in results_by_worker.iter_mut() {
cx.spawn(async {
while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
let score = chunk.embedding.similarity(&query_embedding);
let (score, query_index) =
chunk.embedding.similarity(&query_embeddings);
let ix = match results.binary_search_by(|probe| {
score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
}) {
Ok(ix) | Err(ix) => ix,
};
results.insert(
ix,
WorktreeSearchResult {
worktree_id,
path: path.clone(),
range: chunk.chunk.range.clone(),
score,
},
);
results.truncate(limit);
if ix < limit {
results.insert(
ix,
WorktreeSearchResult {
worktree_id,
path: path.clone(),
range: chunk.chunk.range.clone(),
query_index,
score,
},
);
if results.len() > limit {
results.pop();
}
}
}
});
}
@ -333,6 +345,7 @@ impl ProjectIndex {
path: result.path,
range: result.range,
score: result.score,
query_index: result.query_index,
})
}));
}

View file

@ -12,8 +12,13 @@ use anyhow::{Context as _, Result};
use collections::HashMap;
use fs::Fs;
use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
use project::Project;
use std::{path::PathBuf, sync::Arc};
use language::LineEnding;
use project::{Project, Worktree};
use std::{
cmp::Ordering,
path::{Path, PathBuf},
sync::Arc,
};
use ui::ViewContext;
use util::ResultExt as _;
use workspace::Workspace;
@ -77,46 +82,127 @@ impl SemanticDb {
}
pub async fn load_results(
results: Vec<SearchResult>,
mut results: Vec<SearchResult>,
fs: &Arc<dyn Fs>,
cx: &AsyncAppContext,
) -> Result<Vec<LoadedSearchResult>> {
let mut loaded_results = Vec::new();
for result in results {
let (full_path, file_content) = result.worktree.read_with(cx, |worktree, _cx| {
let entry_abs_path = worktree.abs_path().join(&result.path);
let mut entry_full_path = PathBuf::from(worktree.root_name());
entry_full_path.push(&result.path);
let file_content = async {
let entry_abs_path = entry_abs_path;
fs.load(&entry_abs_path).await
};
(entry_full_path, file_content)
})?;
if let Some(file_content) = file_content.await.log_err() {
let range_start = result.range.start.min(file_content.len());
let range_end = result.range.end.min(file_content.len());
let start_row = file_content[0..range_start].matches('\n').count() as u32;
let end_row = file_content[0..range_end].matches('\n').count() as u32;
let start_line_byte_offset = file_content[0..range_start]
.rfind('\n')
.map(|pos| pos + 1)
.unwrap_or_default();
let end_line_byte_offset = file_content[range_end..]
.find('\n')
.map(|pos| range_end + pos)
.unwrap_or_else(|| file_content.len());
loaded_results.push(LoadedSearchResult {
path: result.path,
range: start_line_byte_offset..end_line_byte_offset,
full_path,
file_content,
row_range: start_row..=end_row,
});
let mut max_scores_by_path = HashMap::<_, (f32, usize)>::default();
for result in &results {
let (score, query_index) = max_scores_by_path
.entry((result.worktree.clone(), result.path.clone()))
.or_default();
if result.score > *score {
*score = result.score;
*query_index = result.query_index;
}
}
results.sort_by(|a, b| {
let max_score_a = max_scores_by_path[&(a.worktree.clone(), a.path.clone())].0;
let max_score_b = max_scores_by_path[&(b.worktree.clone(), b.path.clone())].0;
max_score_b
.partial_cmp(&max_score_a)
.unwrap_or(Ordering::Equal)
.then_with(|| a.worktree.entity_id().cmp(&b.worktree.entity_id()))
.then_with(|| a.path.cmp(&b.path))
.then_with(|| a.range.start.cmp(&b.range.start))
});
let mut last_loaded_file: Option<(Model<Worktree>, Arc<Path>, PathBuf, String)> = None;
let mut loaded_results = Vec::<LoadedSearchResult>::new();
for result in results {
let full_path;
let file_content;
if let Some(last_loaded_file) =
last_loaded_file
.as_ref()
.filter(|(last_worktree, last_path, _, _)| {
last_worktree == &result.worktree && last_path == &result.path
})
{
full_path = last_loaded_file.2.clone();
file_content = &last_loaded_file.3;
} else {
let output = result.worktree.read_with(cx, |worktree, _cx| {
let entry_abs_path = worktree.abs_path().join(&result.path);
let mut entry_full_path = PathBuf::from(worktree.root_name());
entry_full_path.push(&result.path);
let file_content = async {
let entry_abs_path = entry_abs_path;
fs.load(&entry_abs_path).await
};
(entry_full_path, file_content)
})?;
full_path = output.0;
let Some(content) = output.1.await.log_err() else {
continue;
};
last_loaded_file = Some((
result.worktree.clone(),
result.path.clone(),
full_path.clone(),
content,
));
file_content = &last_loaded_file.as_ref().unwrap().3;
};
let query_index = max_scores_by_path[&(result.worktree.clone(), result.path.clone())].1;
let mut range_start = result.range.start.min(file_content.len());
let mut range_end = result.range.end.min(file_content.len());
while !file_content.is_char_boundary(range_start) {
range_start += 1;
}
while !file_content.is_char_boundary(range_end) {
range_end += 1;
}
let start_row = file_content[0..range_start].matches('\n').count() as u32;
let mut end_row = file_content[0..range_end].matches('\n').count() as u32;
let start_line_byte_offset = file_content[0..range_start]
.rfind('\n')
.map(|pos| pos + 1)
.unwrap_or_default();
let mut end_line_byte_offset = range_end;
if file_content[..end_line_byte_offset].ends_with('\n') {
end_row -= 1;
} else {
end_line_byte_offset = file_content[range_end..]
.find('\n')
.map(|pos| range_end + pos + 1)
.unwrap_or_else(|| file_content.len());
}
let mut excerpt_content =
file_content[start_line_byte_offset..end_line_byte_offset].to_string();
LineEnding::normalize(&mut excerpt_content);
if let Some(prev_result) = loaded_results.last_mut() {
if prev_result.full_path == full_path {
if *prev_result.row_range.end() + 1 == start_row {
prev_result.row_range = *prev_result.row_range.start()..=end_row;
prev_result.excerpt_content.push_str(&excerpt_content);
continue;
}
}
}
loaded_results.push(LoadedSearchResult {
path: result.path,
full_path,
excerpt_content,
row_range: start_row..=end_row,
query_index,
});
}
for result in &mut loaded_results {
while result.excerpt_content.ends_with("\n\n") {
result.excerpt_content.pop();
result.row_range =
*result.row_range.start()..=result.row_range.end().saturating_sub(1)
}
}
Ok(loaded_results)
}
@ -312,7 +398,7 @@ mod tests {
.update(|cx| {
let project_index = project_index.read(cx);
let query = "garbage in, garbage out";
project_index.search(query.into(), 4, cx)
project_index.search(vec![query.into()], 4, cx)
})
.await
.unwrap();
@ -426,4 +512,117 @@ mod tests {
],
);
}
/// Exercises `SemanticDb::load_results` against a two-file fake project:
/// a chunk that already falls on line boundaries, a chunk that must be widened
/// to whole lines, and adjacent chunks that are merged and ordered by file.
#[gpui::test]
async fn test_load_search_results(cx: &mut TestAppContext) {
    init_test(cx);

    let project_path = Path::new("/fake_project");
    let file1_content = "one\ntwo\nthree\nfour\nfive\n";
    let file2_content = "aaa\nbbb\nccc\nddd\neee\n";

    let fake_fs = FakeFs::new(cx.executor());
    fake_fs
        .insert_tree(
            project_path,
            json!({
                "file1.txt": file1_content,
                "file2.txt": file2_content,
            }),
        )
        .await;

    let fs = fake_fs as Arc<dyn Fs>;
    let project = Project::test(fs.clone(), [project_path], cx).await;
    let worktree = project.read_with(cx, |project, cx| project.worktrees(cx).next().unwrap());

    // Byte offset of the first occurrence of `needle` in file1.
    let offset_in_file1 = |needle: &str| file1_content.find(needle).unwrap();

    // chunk that is already newline-aligned
    let aligned_chunk = vec![SearchResult {
        worktree: worktree.clone(),
        path: Path::new("file1.txt").into(),
        range: 0..offset_in_file1("four"),
        score: 0.5,
        query_index: 0,
    }];
    let loaded = SemanticDb::load_results(aligned_chunk, &fs, &cx.to_async())
        .await
        .unwrap();
    assert_eq!(
        loaded,
        [LoadedSearchResult {
            path: Path::new("file1.txt").into(),
            full_path: "fake_project/file1.txt".into(),
            excerpt_content: "one\ntwo\nthree\n".into(),
            row_range: 0..=2,
            query_index: 0,
        }]
    );

    // chunk that is *not* newline-aligned
    let unaligned_chunk = vec![SearchResult {
        worktree: worktree.clone(),
        path: Path::new("file1.txt").into(),
        range: offset_in_file1("two") + 1..offset_in_file1("four") + 2,
        score: 0.5,
        query_index: 0,
    }];
    let loaded = SemanticDb::load_results(unaligned_chunk, &fs, &cx.to_async())
        .await
        .unwrap();
    assert_eq!(
        loaded,
        [LoadedSearchResult {
            path: Path::new("file1.txt").into(),
            full_path: "fake_project/file1.txt".into(),
            excerpt_content: "two\nthree\nfour\n".into(),
            row_range: 1..=3,
            query_index: 0,
        }]
    );

    // chunks that are adjacent
    let adjacent_chunks = vec![
        SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: offset_in_file1("two")..file1_content.len(),
            score: 0.6,
            query_index: 0,
        },
        SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: 0..offset_in_file1("two"),
            score: 0.5,
            query_index: 1,
        },
        SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file2.txt").into(),
            range: 0..file2_content.len(),
            score: 0.8,
            query_index: 1,
        },
    ];
    let loaded = SemanticDb::load_results(adjacent_chunks, &fs, &cx.to_async())
        .await
        .unwrap();
    assert_eq!(
        loaded,
        [
            LoadedSearchResult {
                path: Path::new("file2.txt").into(),
                full_path: "fake_project/file2.txt".into(),
                excerpt_content: file2_content.into(),
                row_range: 0..=4,
                query_index: 1,
            },
            LoadedSearchResult {
                path: Path::new("file1.txt").into(),
                full_path: "fake_project/file1.txt".into(),
                excerpt_content: file1_content.into(),
                row_range: 0..=4,
                query_index: 0,
            },
        ]
    );
}
}