From a916bbf00c2757df465357b2b15d8191ca76f7db Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Fri, 28 Mar 2025 17:56:14 +0100 Subject: [PATCH] assistant2: Add support for referencing symbols as context (#27513) TODO Release Notes: - N/A --- Cargo.lock | 1 + crates/assistant2/Cargo.toml | 1 + crates/assistant2/src/context.rs | 69 ++- crates/assistant2/src/context_picker.rs | 28 +- .../src/context_picker/completion_provider.rs | 100 +++- .../context_picker/symbol_context_picker.rs | 438 ++++++++++++++++++ crates/assistant2/src/context_store.rs | 210 ++++++++- crates/assistant2/src/ui/context_pill.rs | 7 +- crates/editor/src/editor.rs | 4 +- 9 files changed, 838 insertions(+), 20 deletions(-) create mode 100644 crates/assistant2/src/context_picker/symbol_context_picker.rs diff --git a/Cargo.lock b/Cargo.lock index 24972889bc..f001aaa2d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -482,6 +482,7 @@ dependencies = [ "markdown", "menu", "multi_buffer", + "ordered-float 2.10.1", "parking_lot", "paths", "picker", diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index 411fe508c0..8d10c21fc4 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -55,6 +55,7 @@ lsp.workspace = true markdown.workspace = true menu.workspace = true multi_buffer.workspace = true +ordered-float.workspace = true parking_lot.workspace = true paths.workspace = true picker.workspace = true diff --git a/crates/assistant2/src/context.rs b/crates/assistant2/src/context.rs index 44c81c82ec..d86542210e 100644 --- a/crates/assistant2/src/context.rs +++ b/crates/assistant2/src/context.rs @@ -1,12 +1,13 @@ -use std::path::Path; use std::rc::Rc; +use std::{ops::Range, path::Path}; use file_icons::FileIcons; use gpui::{App, Entity, SharedString}; use language::Buffer; use language_model::{LanguageModelRequestMessage, MessageContent}; +use project::ProjectPath; use serde::{Deserialize, Serialize}; -use text::BufferId; +use text::{Anchor, BufferId}; use ui::IconName; use util::post_inc; @@ -38,6 +39,7 @@ pub struct ContextSnapshot { pub enum ContextKind { File, Directory, + Symbol, FetchedUrl, Thread, } @@ -47,6 +49,7 @@ impl ContextKind { match self { ContextKind::File => IconName::File, ContextKind::Directory => IconName::Folder, + ContextKind::Symbol => IconName::Code, ContextKind::FetchedUrl => IconName::Globe, ContextKind::Thread => IconName::MessageCircle, } @@ -57,6 +60,7 @@ impl ContextKind { pub enum AssistantContext { File(FileContext), Directory(DirectoryContext), + Symbol(SymbolContext), FetchedUrl(FetchedUrlContext), Thread(ThreadContext), } @@ -66,6 +70,7 @@ impl AssistantContext { match self { Self::File(file) => file.id, Self::Directory(directory) => directory.snapshot.id, + Self::Symbol(symbol) => symbol.id, Self::FetchedUrl(url) => url.id, Self::Thread(thread) => thread.id, } @@ -85,6 +90,12 @@ pub struct DirectoryContext { pub snapshot: ContextSnapshot, } +#[derive(Debug)] +pub struct SymbolContext { + pub id: ContextId, + pub context_symbol: ContextSymbol, +} + #[derive(Debug)] pub struct FetchedUrlContext { pub id: ContextId, @@ -113,11 +124,30 @@ pub struct ContextBuffer { pub text: SharedString, } +#[derive(Debug, Clone)] +pub struct ContextSymbol { + pub id: ContextSymbolId, + pub buffer: Entity, + pub buffer_version: clock::Global, + /// The range that the symbol encloses, e.g. for function symbol, this will + /// include not only the signature, but also the body + pub enclosing_range: Range, + pub text: SharedString, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ContextSymbolId { + pub path: ProjectPath, + pub name: SharedString, + pub range: Range, +} + impl AssistantContext { pub fn snapshot(&self, cx: &App) -> Option { match &self { Self::File(file_context) => file_context.snapshot(cx), Self::Directory(directory_context) => Some(directory_context.snapshot()), + Self::Symbol(symbol_context) => symbol_context.snapshot(cx), Self::FetchedUrl(fetched_url_context) => Some(fetched_url_context.snapshot()), Self::Thread(thread_context) => Some(thread_context.snapshot(cx)), } @@ -197,6 +227,27 @@ impl DirectoryContext { } } +impl SymbolContext { + pub fn snapshot(&self, cx: &App) -> Option { + let buffer = self.context_symbol.buffer.read(cx); + let name = self.context_symbol.id.name.clone(); + let path = buffer_path_log_err(buffer)? + .to_string_lossy() + .into_owned() + .into(); + + Some(ContextSnapshot { + id: self.id, + name, + parent: Some(path), + tooltip: None, + icon_path: None, + kind: ContextKind::Symbol, + text: Box::new([self.context_symbol.text.clone()]), + }) + } +} + impl FetchedUrlContext { pub fn snapshot(&self) -> ContextSnapshot { ContextSnapshot { @@ -232,6 +283,7 @@ pub fn attach_context_to_message( ) { let mut file_context = Vec::new(); let mut directory_context = Vec::new(); + let mut symbol_context = Vec::new(); let mut fetch_context = Vec::new(); let mut thread_context = Vec::new(); @@ -241,6 +293,7 @@ pub fn attach_context_to_message( match context.kind { ContextKind::File => file_context.push(context), ContextKind::Directory => directory_context.push(context), + ContextKind::Symbol => symbol_context.push(context), ContextKind::FetchedUrl => fetch_context.push(context), ContextKind::Thread => thread_context.push(context), } @@ -251,6 +304,9 @@ pub fn attach_context_to_message( if !directory_context.is_empty() { capacity += 1; } + if !symbol_context.is_empty() { + capacity += 1; + } if !fetch_context.is_empty() { capacity += 1 + fetch_context.len(); } @@ -281,6 +337,15 @@ pub fn attach_context_to_message( } } + if !symbol_context.is_empty() { + context_chunks.push("The following symbols are available:\n"); + for context in &symbol_context { + for chunk in &context.text { + context_chunks.push(&chunk); + } + } + } + if !fetch_context.is_empty() { context_chunks.push("The following fetched results are available:\n"); for context in &fetch_context { diff --git a/crates/assistant2/src/context_picker.rs b/crates/assistant2/src/context_picker.rs index d9dba7d0c8..7db267edaa 100644 --- a/crates/assistant2/src/context_picker.rs +++ b/crates/assistant2/src/context_picker.rs @@ -1,6 +1,7 @@ mod completion_provider; mod fetch_context_picker; mod file_context_picker; +mod symbol_context_picker; mod thread_context_picker; use std::ops::Range; @@ -16,6 +17,7 @@ use gpui::{ }; use multi_buffer::MultiBufferRow; use project::ProjectPath; +use symbol_context_picker::SymbolContextPicker; use thread_context_picker::{render_thread_context_entry, ThreadContextEntry}; use ui::{ prelude::*, ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, @@ -39,6 +41,7 @@ pub enum ConfirmBehavior { #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum ContextPickerMode { File, + Symbol, Fetch, Thread, } @@ -49,6 +52,7 @@ impl TryFrom<&str> for ContextPickerMode { fn try_from(value: &str) -> Result { match value { "file" => Ok(Self::File), + "symbol" => Ok(Self::Symbol), "fetch" => Ok(Self::Fetch), "thread" => Ok(Self::Thread), _ => Err(format!("Invalid context picker mode: {}", value)), @@ -60,6 +64,7 @@ impl ContextPickerMode { pub fn mention_prefix(&self) -> &'static str { match self { Self::File => "file", + Self::Symbol => "symbol", Self::Fetch => "fetch", Self::Thread => "thread", } @@ -68,6 +73,7 @@ impl ContextPickerMode { pub fn label(&self) -> &'static str { match self { Self::File => "Files & Directories", + Self::Symbol => "Symbols", Self::Fetch => "Fetch", Self::Thread => "Thread", } @@ -76,6 +82,7 @@ impl ContextPickerMode { pub fn icon(&self) -> IconName { match self { Self::File => IconName::File, + Self::Symbol => IconName::Code, Self::Fetch => IconName::Globe, Self::Thread => IconName::MessageCircle, } @@ -86,6 +93,7 @@ impl ContextPickerMode { enum ContextPickerState { Default(Entity), File(Entity), + Symbol(Entity), Fetch(Entity), Thread(Entity), } @@ -205,6 +213,18 @@ impl ContextPicker { ) })); } + ContextPickerMode::Symbol => { + self.mode = ContextPickerState::Symbol(cx.new(|cx| { + SymbolContextPicker::new( + context_picker.clone(), + self.workspace.clone(), + self.context_store.clone(), + self.confirm_behavior, + window, + cx, + ) + })); + } ContextPickerMode::Fetch => { self.mode = ContextPickerState::Fetch(cx.new(|cx| { FetchContextPicker::new( @@ -416,6 +436,7 @@ impl Focusable for ContextPicker { match &self.mode { ContextPickerState::Default(menu) => menu.focus_handle(cx), ContextPickerState::File(file_picker) => file_picker.focus_handle(cx), + ContextPickerState::Symbol(symbol_picker) => symbol_picker.focus_handle(cx), ContextPickerState::Fetch(fetch_picker) => fetch_picker.focus_handle(cx), ContextPickerState::Thread(thread_picker) => thread_picker.focus_handle(cx), } @@ -430,6 +451,7 @@ impl Render for ContextPicker { .map(|parent| match &self.mode { ContextPickerState::Default(menu) => parent.child(menu.clone()), ContextPickerState::File(file_picker) => parent.child(file_picker.clone()), + ContextPickerState::Symbol(symbol_picker) => parent.child(symbol_picker.clone()), ContextPickerState::Fetch(fetch_picker) => parent.child(fetch_picker.clone()), ContextPickerState::Thread(thread_picker) => parent.child(thread_picker.clone()), }) @@ -446,7 +468,11 @@ enum RecentEntry { fn supported_context_picker_modes( thread_store: &Option>, ) -> Vec { - let mut modes = vec![ContextPickerMode::File, ContextPickerMode::Fetch]; + let mut modes = vec![ + ContextPickerMode::File, + ContextPickerMode::Symbol, + ContextPickerMode::Fetch, + ]; if thread_store.is_some() { modes.push(ContextPickerMode::Thread); } diff --git a/crates/assistant2/src/context_picker/completion_provider.rs b/crates/assistant2/src/context_picker/completion_provider.rs index 8f59a0d34f..30f0f0074e 100644 --- a/crates/assistant2/src/context_picker/completion_provider.rs +++ b/crates/assistant2/src/context_picker/completion_provider.rs @@ -12,7 +12,7 @@ use gpui::{App, Entity, Task, WeakEntity}; use http_client::HttpClientWithUrl; use language::{Buffer, CodeLabel, HighlightId}; use lsp::CompletionContext; -use project::{Completion, CompletionIntent, ProjectPath, WorktreeId}; +use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId}; use rope::Point; use text::{Anchor, ToPoint}; use ui::prelude::*; @@ -308,6 +308,66 @@ impl ContextPickerCompletionProvider { )), } } + + fn completion_for_symbol( + symbol: Symbol, + excerpt_id: ExcerptId, + source_range: Range, + editor: Entity, + context_store: Entity, + workspace: Entity, + cx: &mut App, + ) -> Option { + let path_prefix = workspace + .read(cx) + .project() + .read(cx) + .worktree_for_id(symbol.path.worktree_id, cx)? + .read(cx) + .root_name(); + + let (file_name, _) = super::file_context_picker::extract_file_name_and_directory( + &symbol.path.path, + path_prefix, + ); + + let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId); + let mut label = CodeLabel::plain(symbol.name.clone(), None); + label.push_str(" ", None); + label.push_str(&file_name, comment_id); + + let new_text = format!("@symbol {}:{}", file_name, symbol.name); + let new_text_len = new_text.len(); + Some(Completion { + old_range: source_range.clone(), + new_text, + label, + documentation: None, + source: project::CompletionSource::Custom, + icon_path: Some(IconName::Code.path().into()), + confirm: Some(confirm_completion_callback( + IconName::Code.path().into(), + symbol.name.clone().into(), + excerpt_id, + source_range.start, + new_text_len, + editor.clone(), + move |cx| { + let symbol = symbol.clone(); + let context_store = context_store.clone(); + let workspace = workspace.clone(); + super::symbol_context_picker::add_symbol( + symbol.clone(), + false, + workspace.clone(), + context_store.downgrade(), + cx, + ) + .detach_and_log_err(cx); + }, + )), + }) + } } impl CompletionProvider for ContextPickerCompletionProvider { @@ -350,14 +410,10 @@ impl CompletionProvider for ContextPickerCompletionProvider { cx.spawn(async move |_, cx| { let mut completions = Vec::new(); - let MentionCompletion { - mode: category, - argument, - .. - } = state; + let MentionCompletion { mode, argument, .. } = state; let query = argument.unwrap_or_else(|| "".to_string()); - match category { + match mode { Some(ContextPickerMode::File) => { let path_matches = cx .update(|cx| { @@ -392,6 +448,35 @@ impl CompletionProvider for ContextPickerCompletionProvider { })?; } } + Some(ContextPickerMode::Symbol) => { + if let Some(editor) = editor.upgrade() { + let symbol_matches = cx + .update(|cx| { + super::symbol_context_picker::search_symbols( + query, + Arc::new(AtomicBool::default()), + &workspace, + cx, + ) + })? + .await?; + cx.update(|cx| { + completions.extend(symbol_matches.into_iter().filter_map( + |(_, symbol)| { + Self::completion_for_symbol( + symbol, + excerpt_id, + source_range.clone(), + editor.clone(), + context_store.clone(), + workspace.clone(), + cx, + ) + }, + )); + })?; + } + } Some(ContextPickerMode::Fetch) => { if let Some(editor) = editor.upgrade() { if !query.is_empty() { @@ -792,6 +877,7 @@ mod tests { "five.txt dir/b/", "four.txt dir/a/", "Files & Directories", + "Symbols", "Fetch" ] ); diff --git a/crates/assistant2/src/context_picker/symbol_context_picker.rs b/crates/assistant2/src/context_picker/symbol_context_picker.rs new file mode 100644 index 0000000000..d792467611 --- /dev/null +++ b/crates/assistant2/src/context_picker/symbol_context_picker.rs @@ -0,0 +1,438 @@ +use std::cmp::Reverse; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; + +use anyhow::{Context as _, Result}; +use fuzzy::{StringMatch, StringMatchCandidate}; +use gpui::{ + App, AppContext, DismissEvent, Entity, FocusHandle, Focusable, Stateful, Task, WeakEntity, +}; +use ordered_float::OrderedFloat; +use picker::{Picker, PickerDelegate}; +use project::{DocumentSymbol, Symbol}; +use text::OffsetRangeExt; +use ui::{prelude::*, ListItem}; +use util::ResultExt as _; +use workspace::Workspace; + +use crate::context_picker::{ConfirmBehavior, ContextPicker}; +use crate::context_store::ContextStore; + +pub struct SymbolContextPicker { + picker: Entity>, +} + +impl SymbolContextPicker { + pub fn new( + context_picker: WeakEntity, + workspace: WeakEntity, + context_store: WeakEntity, + confirm_behavior: ConfirmBehavior, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let delegate = SymbolContextPickerDelegate::new( + context_picker, + workspace, + context_store, + confirm_behavior, + ); + let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); + + Self { picker } + } +} + +impl Focusable for SymbolContextPicker { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.picker.focus_handle(cx) + } +} + +impl Render for SymbolContextPicker { + fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { + self.picker.clone() + } +} + +pub struct SymbolContextPickerDelegate { + context_picker: WeakEntity, + workspace: WeakEntity, + context_store: WeakEntity, + confirm_behavior: ConfirmBehavior, + matches: Vec, + selected_index: usize, +} + +impl SymbolContextPickerDelegate { + pub fn new( + context_picker: WeakEntity, + workspace: WeakEntity, + context_store: WeakEntity, + confirm_behavior: ConfirmBehavior, + ) -> Self { + Self { + context_picker, + workspace, + context_store, + confirm_behavior, + matches: Vec::new(), + selected_index: 0, + } + } +} + +impl PickerDelegate for SymbolContextPickerDelegate { + type ListItem = ListItem; + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index( + &mut self, + ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) { + self.selected_index = ix; + } + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { + "Search symbols…".into() + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context>, + ) -> Task<()> { + let Some(workspace) = self.workspace.upgrade() else { + return Task::ready(()); + }; + + let search_task = search_symbols(query, Arc::::default(), &workspace, cx); + let context_store = self.context_store.clone(); + cx.spawn_in(window, async move |this, cx| { + let symbols = search_task + .await + .context("Failed to load symbols") + .log_err() + .unwrap_or_default(); + + let symbol_entries = context_store + .read_with(cx, |context_store, cx| { + compute_symbol_entries(symbols, context_store, cx) + }) + .log_err() + .unwrap_or_default(); + + this.update(cx, |this, _cx| { + this.delegate.matches = symbol_entries; + }) + .log_err(); + }) + } + + fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { + let Some(mat) = self.matches.get(self.selected_index) else { + return; + }; + let Some(workspace) = self.workspace.upgrade() else { + return; + }; + + let confirm_behavior = self.confirm_behavior; + let add_symbol_task = add_symbol( + mat.symbol.clone(), + true, + workspace, + self.context_store.clone(), + cx, + ); + + let selected_index = self.selected_index; + cx.spawn_in(window, async move |this, cx| { + let included = add_symbol_task.await?; + this.update_in(cx, |this, window, cx| { + if let Some(mat) = this.delegate.matches.get_mut(selected_index) { + mat.is_included = included; + } + match confirm_behavior { + ConfirmBehavior::KeepOpen => {} + ConfirmBehavior::Close => this.delegate.dismissed(window, cx), + } + }) + }) + .detach_and_log_err(cx); + } + + fn dismissed(&mut self, _: &mut Window, cx: &mut Context>) { + self.context_picker + .update(cx, |_, cx| { + cx.emit(DismissEvent); + }) + .ok(); + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _window: &mut Window, + _: &mut Context>, + ) -> Option { + let mat = &self.matches[ix]; + + Some(ListItem::new(ix).inset(true).toggle_state(selected).child( + render_symbol_context_entry( + ElementId::NamedInteger("symbol-ctx-picker".into(), ix), + mat, + ), + )) + } +} + +pub(crate) struct SymbolEntry { + pub symbol: Symbol, + pub is_included: bool, +} + +pub(crate) fn add_symbol( + symbol: Symbol, + remove_if_exists: bool, + workspace: Entity, + context_store: WeakEntity, + cx: &mut App, +) -> Task> { + let project = workspace.read(cx).project().clone(); + let open_buffer_task = project.update(cx, |project, cx| { + project.open_buffer(symbol.path.clone(), cx) + }); + cx.spawn(async move |cx| { + let buffer = open_buffer_task.await?; + let document_symbols = project + .update(cx, |project, cx| project.document_symbols(&buffer, cx))? + .await?; + + // Try to find a matching document symbol. Document symbols include + // not only the symbol itself (e.g. function name), but they also + // include the context that they contain (e.g. function body). + let (name, range, enclosing_range) = if let Some(DocumentSymbol { + name, + range, + selection_range, + .. + }) = + find_matching_symbol(&symbol, document_symbols.as_slice()) + { + (name, selection_range, range) + } else { + // If we do not find a matching document symbol, fall back to + // just the symbol itself + (symbol.name, symbol.range.clone(), symbol.range) + }; + + let (range, enclosing_range) = buffer.read_with(cx, |buffer, _| { + ( + buffer.anchor_after(range.start)..buffer.anchor_before(range.end), + buffer.anchor_after(enclosing_range.start) + ..buffer.anchor_before(enclosing_range.end), + ) + })?; + + context_store + .update(cx, move |context_store, cx| { + context_store.add_symbol( + buffer, + name.into(), + range, + enclosing_range, + remove_if_exists, + cx, + ) + })? + .await + }) +} + +fn find_matching_symbol(symbol: &Symbol, candidates: &[DocumentSymbol]) -> Option { + let mut candidates = candidates.iter(); + let mut candidate = candidates.next()?; + + loop { + if candidate.range.start > symbol.range.end { + return None; + } + if candidate.range.end < symbol.range.start { + candidate = candidates.next()?; + continue; + } + if candidate.selection_range == symbol.range { + return Some(candidate.clone()); + } + if candidate.range.start <= symbol.range.start && symbol.range.end <= candidate.range.end { + candidates = candidate.children.iter(); + candidate = candidates.next()?; + continue; + } + return None; + } +} + +pub(crate) fn search_symbols( + query: String, + cancellation_flag: Arc, + workspace: &Entity, + cx: &mut App, +) -> Task>> { + let symbols_task = workspace.update(cx, |workspace, cx| { + workspace + .project() + .update(cx, |project, cx| project.symbols(&query, cx)) + }); + let project = workspace.read(cx).project().clone(); + cx.spawn(async move |cx| { + let symbols = symbols_task.await?; + let (visible_match_candidates, external_match_candidates): (Vec<_>, Vec<_>) = project + .update(cx, |project, cx| { + symbols + .iter() + .enumerate() + .map(|(id, symbol)| StringMatchCandidate::new(id, &symbol.label.filter_text())) + .partition(|candidate| { + project + .entry_for_path(&symbols[candidate.id].path, cx) + .map_or(false, |e| !e.is_ignored) + }) + })?; + + const MAX_MATCHES: usize = 100; + let mut visible_matches = cx.background_executor().block(fuzzy::match_strings( + &visible_match_candidates, + &query, + false, + MAX_MATCHES, + &cancellation_flag, + cx.background_executor().clone(), + )); + let mut external_matches = cx.background_executor().block(fuzzy::match_strings( + &external_match_candidates, + &query, + false, + MAX_MATCHES - visible_matches.len().min(MAX_MATCHES), + &cancellation_flag, + cx.background_executor().clone(), + )); + let sort_key_for_match = |mat: &StringMatch| { + let symbol = &symbols[mat.candidate_id]; + (Reverse(OrderedFloat(mat.score)), symbol.label.filter_text()) + }; + + visible_matches.sort_unstable_by_key(sort_key_for_match); + external_matches.sort_unstable_by_key(sort_key_for_match); + let mut matches = visible_matches; + matches.append(&mut external_matches); + + Ok(matches + .into_iter() + .map(|mut mat| { + let symbol = symbols[mat.candidate_id].clone(); + let filter_start = symbol.label.filter_range.start; + for position in &mut mat.positions { + *position += filter_start; + } + (mat, symbol) + }) + .collect()) + }) +} + +fn compute_symbol_entries( + symbols: Vec<(StringMatch, Symbol)>, + context_store: &ContextStore, + cx: &App, +) -> Vec { + let mut symbol_entries = Vec::with_capacity(symbols.len()); + for (_, symbol) in symbols { + let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path); + let is_included = if let Some(symbols_for_path) = symbols_for_path { + let mut is_included = false; + for included_symbol_id in symbols_for_path { + if included_symbol_id.name.as_ref() == symbol.name.as_str() { + if let Some(buffer) = context_store.buffer_for_symbol(included_symbol_id) { + let snapshot = buffer.read(cx).snapshot(); + let included_symbol_range = + included_symbol_id.range.to_point_utf16(&snapshot); + + if included_symbol_range.start == symbol.range.start.0 + && included_symbol_range.end == symbol.range.end.0 + { + is_included = true; + break; + } + } + } + } + is_included + } else { + false + }; + + symbol_entries.push(SymbolEntry { + symbol, + is_included, + }) + } + symbol_entries +} + +pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful
{ + let path = entry + .symbol + .path + .path + .file_name() + .map(|s| s.to_string_lossy()) + .unwrap_or_default(); + let symbol_location = format!("{} L{}", path, entry.symbol.range.start.0.row + 1); + + h_flex() + .id(id) + .gap_1p5() + .w_full() + .child( + Icon::new(IconName::Code) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child( + h_flex() + .gap_1() + .child(Label::new(&entry.symbol.name)) + .child( + Label::new(symbol_location) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + .when(entry.is_included, |el| { + el.child( + h_flex() + .w_full() + .justify_end() + .gap_0p5() + .child( + Icon::new(IconName::Check) + .size(IconSize::Small) + .color(Color::Success), + ) + .child(Label::new("Added").size(LabelSize::Small)), + ) + }) +} diff --git a/crates/assistant2/src/context_store.rs b/crates/assistant2/src/context_store.rs index 3dfaaca3c8..7177f2114a 100644 --- a/crates/assistant2/src/context_store.rs +++ b/crates/assistant2/src/context_store.rs @@ -1,3 +1,4 @@ +use std::ops::Range; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -6,15 +7,15 @@ use collections::{BTreeMap, HashMap, HashSet}; use futures::{self, future, Future, FutureExt}; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task, WeakEntity}; use language::Buffer; -use project::{ProjectPath, Worktree}; +use project::{ProjectItem, ProjectPath, Worktree}; use rope::Rope; -use text::BufferId; +use text::{Anchor, BufferId, OffsetRangeExt}; use util::maybe; use workspace::Workspace; use crate::context::{ - AssistantContext, ContextBuffer, ContextId, ContextSnapshot, DirectoryContext, - FetchedUrlContext, FileContext, ThreadContext, + AssistantContext, ContextBuffer, ContextId, ContextSnapshot, ContextSymbol, ContextSymbolId, + DirectoryContext, FetchedUrlContext, FileContext, SymbolContext, ThreadContext, }; use crate::context_strip::SuggestedContext; use crate::thread::{Thread, ThreadId}; @@ -26,6 +27,9 @@ pub struct ContextStore { next_context_id: ContextId, files: BTreeMap, directories: HashMap, + symbols: HashMap, + symbol_buffers: HashMap>, + symbols_by_path: HashMap>, threads: HashMap, fetched_urls: HashMap, } @@ -38,6 +42,9 @@ impl ContextStore { next_context_id: ContextId(0), files: BTreeMap::default(), directories: HashMap::default(), + symbols: HashMap::default(), + symbol_buffers: HashMap::default(), + symbols_by_path: HashMap::default(), threads: HashMap::default(), fetched_urls: HashMap::default(), } @@ -107,6 +114,7 @@ impl ContextStore { project_path.path.clone(), buffer_entity, buffer, + None, cx.to_async(), ) })?; @@ -136,6 +144,7 @@ impl ContextStore { file.path().clone(), buffer_entity, buffer, + None, cx.to_async(), )) })??; @@ -222,6 +231,7 @@ impl ContextStore { path, buffer_entity, buffer, + None, cx.to_async(), ); buffer_infos.push(buffer_info); @@ -262,6 +272,84 @@ impl ContextStore { ))); } + pub fn add_symbol( + &mut self, + buffer: Entity, + symbol_name: SharedString, + symbol_range: Range, + symbol_enclosing_range: Range, + remove_if_exists: bool, + cx: &mut Context, + ) -> Task> { + let buffer_ref = buffer.read(cx); + let Some(file) = buffer_ref.file() else { + return Task::ready(Err(anyhow!("Buffer has no path."))); + }; + + let Some(project_path) = buffer_ref.project_path(cx) else { + return Task::ready(Err(anyhow!("Buffer has no project path."))); + }; + + if let Some(symbols_for_path) = self.symbols_by_path.get(&project_path) { + let mut matching_symbol_id = None; + for symbol in symbols_for_path { + if &symbol.name == &symbol_name { + let snapshot = buffer_ref.snapshot(); + if symbol.range.to_offset(&snapshot) == symbol_range.to_offset(&snapshot) { + matching_symbol_id = self.symbols.get(symbol).cloned(); + break; + } + } + } + + if let Some(id) = matching_symbol_id { + if remove_if_exists { + self.remove_context(id); + } + return Task::ready(Ok(false)); + } + } + + let (buffer_info, collect_content_task) = collect_buffer_info_and_text( + file.path().clone(), + buffer, + buffer_ref, + Some(symbol_enclosing_range.clone()), + cx.to_async(), + ); + + cx.spawn(async move |this, cx| { + let content = collect_content_task.await; + + this.update(cx, |this, _cx| { + this.insert_symbol(make_context_symbol( + buffer_info, + project_path, + symbol_name, + symbol_range, + symbol_enclosing_range, + content, + )) + })?; + anyhow::Ok(true) + }) + } + + fn insert_symbol(&mut self, context_symbol: ContextSymbol) { + let id = self.next_context_id.post_inc(); + self.symbols.insert(context_symbol.id.clone(), id); + self.symbols_by_path + .entry(context_symbol.id.path.clone()) + .or_insert_with(Vec::new) + .push(context_symbol.id.clone()); + self.symbol_buffers + .insert(context_symbol.id.clone(), context_symbol.buffer.clone()); + self.context.push(AssistantContext::Symbol(SymbolContext { + id, + context_symbol, + })); + } + pub fn add_thread( &mut self, thread: Entity, @@ -340,6 +428,19 @@ impl ContextStore { AssistantContext::Directory(_) => { self.directories.retain(|_, context_id| *context_id != id); } + AssistantContext::Symbol(symbol) => { + if let Some(symbols_in_path) = + self.symbols_by_path.get_mut(&symbol.context_symbol.id.path) + { + symbols_in_path.retain(|s| { + self.symbols + .get(s) + .map_or(false, |context_id| *context_id != id) + }); + } + self.symbol_buffers.remove(&symbol.context_symbol.id); + self.symbols.retain(|_, context_id| *context_id != id); + } AssistantContext::FetchedUrl(_) => { self.fetched_urls.retain(|_, context_id| *context_id != id); } @@ -403,6 +504,18 @@ impl ContextStore { self.directories.get(path).copied() } + pub fn included_symbol(&self, symbol_id: &ContextSymbolId) -> Option { + self.symbols.get(symbol_id).copied() + } + + pub fn included_symbols_by_path(&self) -> &HashMap> { + &self.symbols_by_path + } + + pub fn buffer_for_symbol(&self, symbol_id: &ContextSymbolId) -> Option> { + self.symbol_buffers.get(symbol_id).cloned() + } + pub fn includes_thread(&self, thread_id: &ThreadId) -> Option { self.threads.get(thread_id).copied() } @@ -431,6 +544,7 @@ impl ContextStore { buffer_path_log_err(buffer).map(|p| p.to_path_buf()) } AssistantContext::Directory(_) + | AssistantContext::Symbol(_) | AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => None, }) @@ -463,10 +577,28 @@ fn make_context_buffer(info: BufferInfo, text: SharedString) -> ContextBuffer { } } +fn make_context_symbol( + info: BufferInfo, + path: ProjectPath, + name: SharedString, + range: Range, + enclosing_range: Range, + text: SharedString, +) -> ContextSymbol { + ContextSymbol { + id: ContextSymbolId { name, range, path }, + buffer_version: info.version, + enclosing_range, + buffer: info.buffer_entity, + text, + } +} + fn collect_buffer_info_and_text( path: Arc, buffer_entity: Entity, buffer: &Buffer, + range: Option>, cx: AsyncApp, ) -> (BufferInfo, Task) { let buffer_info = BufferInfo { @@ -475,7 +607,11 @@ fn collect_buffer_info_and_text( version: buffer.version(), }; // Important to collect version at the same time as content so that staleness logic is correct. - let content = buffer.as_rope().clone(); + let content = if let Some(range) = range { + buffer.text_for_range(range).collect::() + } else { + buffer.as_rope().clone() + }; let text_task = cx.background_spawn(async move { to_fenced_codeblock(&path, content) }); (buffer_info, text_task) } @@ -577,6 +713,14 @@ pub fn refresh_context_store_text( return refresh_directory_text(context_store, directory_context, cx); } } + AssistantContext::Symbol(symbol_context) => { + if changed_buffers.is_empty() + || changed_buffers.contains(&symbol_context.context_symbol.buffer) + { + let context_store = context_store.clone(); + return refresh_symbol_text(context_store, symbol_context, cx); + } + } AssistantContext::Thread(thread_context) => { if changed_buffers.is_empty() { let context_store = context_store.clone(); @@ -660,6 +804,28 @@ fn refresh_directory_text( })) } +fn refresh_symbol_text( + context_store: Entity, + symbol_context: &SymbolContext, + cx: &App, +) -> Option> { + let id = symbol_context.id; + let task = refresh_context_symbol(&symbol_context.context_symbol, cx); + if let Some(task) = task { + Some(cx.spawn(async move |cx| { + let context_symbol = task.await; + context_store + .update(cx, |context_store, _| { + let new_symbol_context = SymbolContext { id, context_symbol }; + context_store.replace_context(AssistantContext::Symbol(new_symbol_context)); + }) + .ok(); + })) + } else { + None + } +} + fn refresh_thread_text( context_store: Entity, thread_context: &ThreadContext, @@ -692,6 +858,7 @@ fn refresh_context_buffer( path, context_buffer.buffer.clone(), buffer, + None, cx.to_async(), ); Some(text_task.map(move |text| make_context_buffer(buffer_info, text))) @@ -699,3 +866,36 @@ fn refresh_context_buffer( None } } + +fn refresh_context_symbol( + context_symbol: &ContextSymbol, + cx: &App, +) -> Option> { + let buffer = context_symbol.buffer.read(cx); + let path = buffer_path_log_err(buffer)?; + let project_path = buffer.project_path(cx)?; + if buffer.version.changed_since(&context_symbol.buffer_version) { + let (buffer_info, text_task) = collect_buffer_info_and_text( + path, + context_symbol.buffer.clone(), + buffer, + Some(context_symbol.enclosing_range.clone()), + cx.to_async(), + ); + let name = context_symbol.id.name.clone(); + let range = context_symbol.id.range.clone(); + let enclosing_range = context_symbol.enclosing_range.clone(); + Some(text_task.map(move |text| { + make_context_symbol( + buffer_info, + project_path, + name, + range, + enclosing_range, + text, + ) + })) + } else { + None + } +} diff --git a/crates/assistant2/src/ui/context_pill.rs b/crates/assistant2/src/ui/context_pill.rs index 00b3dba867..5fac7c3758 100644 --- a/crates/assistant2/src/ui/context_pill.rs +++ b/crates/assistant2/src/ui/context_pill.rs @@ -190,9 +190,10 @@ impl RenderOnce for ContextPill { .child( Label::new(match kind { ContextKind::File => "Active Tab", - ContextKind::Thread | ContextKind::Directory | ContextKind::FetchedUrl => { - "Active" - } + ContextKind::Thread + | ContextKind::Directory + | ContextKind::FetchedUrl + | ContextKind::Symbol => "Active", }) .size(LabelSize::XSmall) .color(Color::Muted), diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 82d61fedf0..4a78d7c9a9 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -15494,9 +15494,9 @@ impl Editor { } } - pub fn project_path(&self, cx: &mut Context) -> Option { + pub fn project_path(&self, cx: &App) -> Option { if let Some(buffer) = self.buffer.read(cx).as_singleton() { - buffer.read_with(cx, |buffer, cx| buffer.project_path(cx)) + buffer.read(cx).project_path(cx) } else { None }