assistant2: Add support for referencing symbols as context (#27513)

TODO

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-03-28 17:56:14 +01:00 committed by GitHub
parent da47013e56
commit a916bbf00c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 838 additions and 20 deletions

View file

@ -1,3 +1,4 @@
use std::ops::Range;
use std::path::{Path, PathBuf};
use std::sync::Arc;
@ -6,15 +7,15 @@ use collections::{BTreeMap, HashMap, HashSet};
use futures::{self, future, Future, FutureExt};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
use language::Buffer;
use project::{ProjectPath, Worktree};
use project::{ProjectItem, ProjectPath, Worktree};
use rope::Rope;
use text::BufferId;
use text::{Anchor, BufferId, OffsetRangeExt};
use util::maybe;
use workspace::Workspace;
use crate::context::{
AssistantContext, ContextBuffer, ContextId, ContextSnapshot, DirectoryContext,
FetchedUrlContext, FileContext, ThreadContext,
AssistantContext, ContextBuffer, ContextId, ContextSnapshot, ContextSymbol, ContextSymbolId,
DirectoryContext, FetchedUrlContext, FileContext, SymbolContext, ThreadContext,
};
use crate::context_strip::SuggestedContext;
use crate::thread::{Thread, ThreadId};
@ -26,6 +27,9 @@ pub struct ContextStore {
next_context_id: ContextId,
files: BTreeMap<BufferId, ContextId>,
directories: HashMap<PathBuf, ContextId>,
symbols: HashMap<ContextSymbolId, ContextId>,
symbol_buffers: HashMap<ContextSymbolId, Entity<Buffer>>,
symbols_by_path: HashMap<ProjectPath, Vec<ContextSymbolId>>,
threads: HashMap<ThreadId, ContextId>,
fetched_urls: HashMap<String, ContextId>,
}
@ -38,6 +42,9 @@ impl ContextStore {
next_context_id: ContextId(0),
files: BTreeMap::default(),
directories: HashMap::default(),
symbols: HashMap::default(),
symbol_buffers: HashMap::default(),
symbols_by_path: HashMap::default(),
threads: HashMap::default(),
fetched_urls: HashMap::default(),
}
@ -107,6 +114,7 @@ impl ContextStore {
project_path.path.clone(),
buffer_entity,
buffer,
None,
cx.to_async(),
)
})?;
@ -136,6 +144,7 @@ impl ContextStore {
file.path().clone(),
buffer_entity,
buffer,
None,
cx.to_async(),
))
})??;
@ -222,6 +231,7 @@ impl ContextStore {
path,
buffer_entity,
buffer,
None,
cx.to_async(),
);
buffer_infos.push(buffer_info);
@ -262,6 +272,84 @@ impl ContextStore {
)));
}
pub fn add_symbol(
&mut self,
buffer: Entity<Buffer>,
symbol_name: SharedString,
symbol_range: Range<Anchor>,
symbol_enclosing_range: Range<Anchor>,
remove_if_exists: bool,
cx: &mut Context<Self>,
) -> Task<Result<bool>> {
let buffer_ref = buffer.read(cx);
let Some(file) = buffer_ref.file() else {
return Task::ready(Err(anyhow!("Buffer has no path.")));
};
let Some(project_path) = buffer_ref.project_path(cx) else {
return Task::ready(Err(anyhow!("Buffer has no project path.")));
};
if let Some(symbols_for_path) = self.symbols_by_path.get(&project_path) {
let mut matching_symbol_id = None;
for symbol in symbols_for_path {
if &symbol.name == &symbol_name {
let snapshot = buffer_ref.snapshot();
if symbol.range.to_offset(&snapshot) == symbol_range.to_offset(&snapshot) {
matching_symbol_id = self.symbols.get(symbol).cloned();
break;
}
}
}
if let Some(id) = matching_symbol_id {
if remove_if_exists {
self.remove_context(id);
}
return Task::ready(Ok(false));
}
}
let (buffer_info, collect_content_task) = collect_buffer_info_and_text(
file.path().clone(),
buffer,
buffer_ref,
Some(symbol_enclosing_range.clone()),
cx.to_async(),
);
cx.spawn(async move |this, cx| {
let content = collect_content_task.await;
this.update(cx, |this, _cx| {
this.insert_symbol(make_context_symbol(
buffer_info,
project_path,
symbol_name,
symbol_range,
symbol_enclosing_range,
content,
))
})?;
anyhow::Ok(true)
})
}
fn insert_symbol(&mut self, context_symbol: ContextSymbol) {
let id = self.next_context_id.post_inc();
self.symbols.insert(context_symbol.id.clone(), id);
self.symbols_by_path
.entry(context_symbol.id.path.clone())
.or_insert_with(Vec::new)
.push(context_symbol.id.clone());
self.symbol_buffers
.insert(context_symbol.id.clone(), context_symbol.buffer.clone());
self.context.push(AssistantContext::Symbol(SymbolContext {
id,
context_symbol,
}));
}
pub fn add_thread(
&mut self,
thread: Entity<Thread>,
@ -340,6 +428,19 @@ impl ContextStore {
AssistantContext::Directory(_) => {
self.directories.retain(|_, context_id| *context_id != id);
}
AssistantContext::Symbol(symbol) => {
if let Some(symbols_in_path) =
self.symbols_by_path.get_mut(&symbol.context_symbol.id.path)
{
symbols_in_path.retain(|s| {
self.symbols
.get(s)
.map_or(false, |context_id| *context_id != id)
});
}
self.symbol_buffers.remove(&symbol.context_symbol.id);
self.symbols.retain(|_, context_id| *context_id != id);
}
AssistantContext::FetchedUrl(_) => {
self.fetched_urls.retain(|_, context_id| *context_id != id);
}
@ -403,6 +504,18 @@ impl ContextStore {
self.directories.get(path).copied()
}
pub fn included_symbol(&self, symbol_id: &ContextSymbolId) -> Option<ContextId> {
self.symbols.get(symbol_id).copied()
}
pub fn included_symbols_by_path(&self) -> &HashMap<ProjectPath, Vec<ContextSymbolId>> {
&self.symbols_by_path
}
pub fn buffer_for_symbol(&self, symbol_id: &ContextSymbolId) -> Option<Entity<Buffer>> {
self.symbol_buffers.get(symbol_id).cloned()
}
pub fn includes_thread(&self, thread_id: &ThreadId) -> Option<ContextId> {
self.threads.get(thread_id).copied()
}
@ -431,6 +544,7 @@ impl ContextStore {
buffer_path_log_err(buffer).map(|p| p.to_path_buf())
}
AssistantContext::Directory(_)
| AssistantContext::Symbol(_)
| AssistantContext::FetchedUrl(_)
| AssistantContext::Thread(_) => None,
})
@ -463,10 +577,28 @@ fn make_context_buffer(info: BufferInfo, text: SharedString) -> ContextBuffer {
}
}
fn make_context_symbol(
info: BufferInfo,
path: ProjectPath,
name: SharedString,
range: Range<Anchor>,
enclosing_range: Range<Anchor>,
text: SharedString,
) -> ContextSymbol {
ContextSymbol {
id: ContextSymbolId { name, range, path },
buffer_version: info.version,
enclosing_range,
buffer: info.buffer_entity,
text,
}
}
fn collect_buffer_info_and_text(
path: Arc<Path>,
buffer_entity: Entity<Buffer>,
buffer: &Buffer,
range: Option<Range<Anchor>>,
cx: AsyncApp,
) -> (BufferInfo, Task<SharedString>) {
let buffer_info = BufferInfo {
@ -475,7 +607,11 @@ fn collect_buffer_info_and_text(
version: buffer.version(),
};
// Important to collect version at the same time as content so that staleness logic is correct.
let content = buffer.as_rope().clone();
let content = if let Some(range) = range {
buffer.text_for_range(range).collect::<Rope>()
} else {
buffer.as_rope().clone()
};
let text_task = cx.background_spawn(async move { to_fenced_codeblock(&path, content) });
(buffer_info, text_task)
}
@ -577,6 +713,14 @@ pub fn refresh_context_store_text(
return refresh_directory_text(context_store, directory_context, cx);
}
}
AssistantContext::Symbol(symbol_context) => {
if changed_buffers.is_empty()
|| changed_buffers.contains(&symbol_context.context_symbol.buffer)
{
let context_store = context_store.clone();
return refresh_symbol_text(context_store, symbol_context, cx);
}
}
AssistantContext::Thread(thread_context) => {
if changed_buffers.is_empty() {
let context_store = context_store.clone();
@ -660,6 +804,28 @@ fn refresh_directory_text(
}))
}
fn refresh_symbol_text(
context_store: Entity<ContextStore>,
symbol_context: &SymbolContext,
cx: &App,
) -> Option<Task<()>> {
let id = symbol_context.id;
let task = refresh_context_symbol(&symbol_context.context_symbol, cx);
if let Some(task) = task {
Some(cx.spawn(async move |cx| {
let context_symbol = task.await;
context_store
.update(cx, |context_store, _| {
let new_symbol_context = SymbolContext { id, context_symbol };
context_store.replace_context(AssistantContext::Symbol(new_symbol_context));
})
.ok();
}))
} else {
None
}
}
fn refresh_thread_text(
context_store: Entity<ContextStore>,
thread_context: &ThreadContext,
@ -692,6 +858,7 @@ fn refresh_context_buffer(
path,
context_buffer.buffer.clone(),
buffer,
None,
cx.to_async(),
);
Some(text_task.map(move |text| make_context_buffer(buffer_info, text)))
@ -699,3 +866,36 @@ fn refresh_context_buffer(
None
}
}
fn refresh_context_symbol(
context_symbol: &ContextSymbol,
cx: &App,
) -> Option<impl Future<Output = ContextSymbol>> {
let buffer = context_symbol.buffer.read(cx);
let path = buffer_path_log_err(buffer)?;
let project_path = buffer.project_path(cx)?;
if buffer.version.changed_since(&context_symbol.buffer_version) {
let (buffer_info, text_task) = collect_buffer_info_and_text(
path,
context_symbol.buffer.clone(),
buffer,
Some(context_symbol.enclosing_range.clone()),
cx.to_async(),
);
let name = context_symbol.id.name.clone();
let range = context_symbol.id.range.clone();
let enclosing_range = context_symbol.enclosing_range.clone();
Some(text_task.map(move |text| {
make_context_symbol(
buffer_info,
project_path,
name,
range,
enclosing_range,
text,
)
}))
} else {
None
}
}