Remove assistant ContextSnapshot (#27822)

Motivation for this is to simplify the context types and make it cleaner
to add image context.

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Michael Sloan 2025-03-31 21:57:09 -06:00 committed by GitHub
parent c729842804
commit d0276e6666
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 340 additions and 376 deletions

View file

@ -6,7 +6,7 @@ use crate::thread::{
};
use crate::thread_store::ThreadStore;
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
use crate::ui::{AgentNotification, AgentNotificationEvent, ContextPill};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
use collections::HashMap;
use editor::{Editor, MultiBuffer};
@ -487,14 +487,14 @@ impl ActiveThread {
let updated_context_ids = refresh_task.await;
this.update(cx, |this, cx| {
this.context_store.read_with(cx, |context_store, cx| {
this.context_store.read_with(cx, |context_store, _cx| {
context_store
.context()
.iter()
.filter(|context| {
updated_context_ids.contains(&context.id())
})
.flat_map(|context| context.snapshot(cx))
.cloned()
.collect()
})
})
@ -806,7 +806,7 @@ impl ActiveThread {
let thread = self.thread.read(cx);
// Get all the data we need from thread before we start using it in closures
let checkpoint = thread.checkpoint_for_message(message_id);
let context = thread.context_for_message(message_id);
let context = thread.context_for_message(message_id).collect::<Vec<_>>();
let tool_uses = thread.tool_uses_for_message(message_id, cx);
// Don't render user messages that are just there for returning tool results.
@ -926,53 +926,50 @@ impl ActiveThread {
.into_any_element(),
};
let message_content =
v_flex()
.gap_1p5()
.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.child(edit_message_editor)
} else {
div()
.min_h_6()
.text_ui(cx)
.child(self.render_message_content(message_id, rendered_message, cx))
},
)
.when_some(context, |parent, context| {
if !context.is_empty() {
parent.child(h_flex().flex_wrap().gap_1().children(
context.into_iter().map(|context| {
let context_id = context.id;
ContextPill::added(context, false, false, None).on_click(Rc::new(
cx.listener({
let workspace = workspace.clone();
let context_store = context_store.clone();
move |_, _, window, cx| {
if let Some(workspace) = workspace.upgrade() {
open_context(
context_id,
context_store.clone(),
workspace,
window,
cx,
);
cx.notify();
}
let message_content = v_flex()
.gap_1p5()
.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.child(edit_message_editor)
} else {
div()
.min_h_6()
.text_ui(cx)
.child(self.render_message_content(message_id, rendered_message, cx))
},
)
.when(!context.is_empty(), |parent| {
parent.child(
h_flex()
.flex_wrap()
.gap_1()
.children(context.into_iter().map(|context| {
let context_id = context.id();
ContextPill::added(AddedContext::new(context, cx), false, false, None)
.on_click(Rc::new(cx.listener({
let workspace = workspace.clone();
let context_store = context_store.clone();
move |_, _, window, cx| {
if let Some(workspace) = workspace.upgrade() {
open_context(
context_id,
context_store.clone(),
workspace,
window,
cx,
);
cx.notify();
}
}),
))
}),
))
} else {
parent
}
});
}
})))
})),
)
});
let styled_message = match message.role {
Role::User => v_flex()
@ -1974,7 +1971,7 @@ pub(crate) fn open_context(
}
}
AssistantContext::Directory(directory_context) => {
let path = directory_context.path.clone();
let path = directory_context.project_path.clone();
workspace.update(cx, |workspace, cx| {
workspace.project().update(cx, |project, cx| {
if let Some(entry) = project.entry_for_path(&path, cx) {

View file

@ -414,7 +414,11 @@ impl CodegenAlternative {
};
if let Some(context_store) = &self.context_store {
attach_context_to_message(&mut request_message, context_store.read(cx).snapshot(cx));
attach_context_to_message(
&mut request_message,
context_store.read(cx).context().iter(),
cx,
);
}
request_message.content.push(prompt.into());

View file

@ -1,8 +1,7 @@
use std::ops::Range;
use std::{ops::Range, sync::Arc};
use file_icons::FileIcons;
use gpui::{App, Entity, SharedString};
use language::Buffer;
use language::{Buffer, File};
use language_model::{LanguageModelRequestMessage, MessageContent};
use project::ProjectPath;
use serde::{Deserialize, Serialize};
@ -10,7 +9,7 @@ use text::{Anchor, BufferId};
use ui::IconName;
use util::post_inc;
use crate::{context_store::buffer_path_log_err, thread::Thread};
use crate::thread::Thread;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct ContextId(pub(crate) usize);
@ -21,19 +20,6 @@ impl ContextId {
}
}
/// Some context attached to a message in a thread.
#[derive(Debug, Clone)]
pub struct ContextSnapshot {
pub id: ContextId,
pub name: SharedString,
pub parent: Option<SharedString>,
pub tooltip: Option<SharedString>,
pub icon_path: Option<SharedString>,
pub kind: ContextKind,
/// Joining these strings separated by \n yields text for model. Not refreshed by `snapshot`.
pub text: Box<[SharedString]>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContextKind {
File,
@ -55,7 +41,7 @@ impl ContextKind {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum AssistantContext {
File(FileContext),
Directory(DirectoryContext),
@ -68,7 +54,7 @@ impl AssistantContext {
pub fn id(&self) -> ContextId {
match self {
Self::File(file) => file.id,
Self::Directory(directory) => directory.snapshot.id,
Self::Directory(directory) => directory.id,
Self::Symbol(symbol) => symbol.id,
Self::FetchedUrl(url) => url.id,
Self::Thread(thread) => thread.id,
@ -76,26 +62,26 @@ impl AssistantContext {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct FileContext {
pub id: ContextId,
pub context_buffer: ContextBuffer,
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct DirectoryContext {
pub path: ProjectPath,
pub id: ContextId,
pub project_path: ProjectPath,
pub context_buffers: Vec<ContextBuffer>,
pub snapshot: ContextSnapshot,
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct SymbolContext {
pub id: ContextId,
pub context_symbol: ContextSymbol,
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct FetchedUrlContext {
pub id: ContextId,
pub url: SharedString,
@ -105,24 +91,45 @@ pub struct FetchedUrlContext {
// TODO: Model<Thread> holds onto the thread even if the thread is deleted. Can either handle this
// explicitly or have a WeakModel<Thread> and remove during snapshot.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ThreadContext {
pub id: ContextId,
pub thread: Entity<Thread>,
pub text: SharedString,
}
impl ThreadContext {
pub fn summary(&self, cx: &App) -> SharedString {
self.thread
.read(cx)
.summary()
.unwrap_or("New thread".into())
}
}
// TODO: Model<Buffer> holds onto the buffer even if the file is deleted and closed. Should remove
// the context from the message editor in this case.
#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct ContextBuffer {
pub id: BufferId,
pub buffer: Entity<Buffer>,
pub file: Arc<dyn File>,
pub version: clock::Global,
pub text: SharedString,
}
impl std::fmt::Debug for ContextBuffer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ContextBuffer")
.field("id", &self.id)
.field("buffer", &self.buffer)
.field("version", &self.version)
.field("text", &self.text)
.finish()
}
}
#[derive(Debug, Clone)]
pub struct ContextSymbol {
pub id: ContextSymbolId,
@ -141,145 +148,10 @@ pub struct ContextSymbolId {
pub range: Range<Anchor>,
}
impl AssistantContext {
pub fn snapshot(&self, cx: &App) -> Option<ContextSnapshot> {
match &self {
Self::File(file_context) => file_context.snapshot(cx),
Self::Directory(directory_context) => Some(directory_context.snapshot()),
Self::Symbol(symbol_context) => symbol_context.snapshot(cx),
Self::FetchedUrl(fetched_url_context) => Some(fetched_url_context.snapshot()),
Self::Thread(thread_context) => Some(thread_context.snapshot(cx)),
}
}
}
impl FileContext {
pub fn snapshot(&self, cx: &App) -> Option<ContextSnapshot> {
let buffer = self.context_buffer.buffer.read(cx);
let path = buffer_path_log_err(buffer, cx)?;
let full_path: SharedString = path.to_string_lossy().into_owned().into();
let name = match path.file_name() {
Some(name) => name.to_string_lossy().into_owned().into(),
None => full_path.clone(),
};
let parent = path
.parent()
.and_then(|p| p.file_name())
.map(|p| p.to_string_lossy().into_owned().into());
let icon_path = FileIcons::get_icon(&path, cx);
Some(ContextSnapshot {
id: self.id,
name,
parent,
tooltip: Some(full_path),
icon_path,
kind: ContextKind::File,
text: Box::new([self.context_buffer.text.clone()]),
})
}
}
impl DirectoryContext {
pub fn new(
id: ContextId,
project_path: ProjectPath,
context_buffers: Vec<ContextBuffer>,
) -> DirectoryContext {
let full_path: SharedString = project_path.path.to_string_lossy().into_owned().into();
let name = match project_path.path.file_name() {
Some(name) => name.to_string_lossy().into_owned().into(),
None => full_path.clone(),
};
let parent = project_path
.path
.parent()
.and_then(|p| p.file_name())
.map(|p| p.to_string_lossy().into_owned().into());
// TODO: include directory path in text?
let text = context_buffers
.iter()
.map(|b| b.text.clone())
.collect::<Vec<_>>()
.into();
DirectoryContext {
path: project_path,
context_buffers,
snapshot: ContextSnapshot {
id,
name,
parent,
tooltip: Some(full_path),
icon_path: None,
kind: ContextKind::Directory,
text,
},
}
}
pub fn snapshot(&self) -> ContextSnapshot {
self.snapshot.clone()
}
}
impl SymbolContext {
pub fn snapshot(&self, cx: &App) -> Option<ContextSnapshot> {
let buffer = self.context_symbol.buffer.read(cx);
let name = self.context_symbol.id.name.clone();
let path = buffer_path_log_err(buffer, cx)?
.to_string_lossy()
.into_owned()
.into();
Some(ContextSnapshot {
id: self.id,
name,
parent: Some(path),
tooltip: None,
icon_path: None,
kind: ContextKind::Symbol,
text: Box::new([self.context_symbol.text.clone()]),
})
}
}
impl FetchedUrlContext {
pub fn snapshot(&self) -> ContextSnapshot {
ContextSnapshot {
id: self.id,
name: self.url.clone(),
parent: None,
tooltip: None,
icon_path: None,
kind: ContextKind::FetchedUrl,
text: Box::new([self.text.clone()]),
}
}
}
impl ThreadContext {
pub fn snapshot(&self, cx: &App) -> ContextSnapshot {
let thread = self.thread.read(cx);
ContextSnapshot {
id: self.id,
name: thread.summary().unwrap_or("New thread".into()),
parent: None,
tooltip: None,
icon_path: None,
kind: ContextKind::Thread,
text: Box::new([self.text.clone()]),
}
}
}
pub fn attach_context_to_message(
pub fn attach_context_to_message<'a>(
message: &mut LanguageModelRequestMessage,
contexts: impl Iterator<Item = ContextSnapshot>,
contexts: impl Iterator<Item = &'a AssistantContext>,
cx: &App,
) {
let mut file_context = Vec::new();
let mut directory_context = Vec::new();
@ -287,91 +159,61 @@ pub fn attach_context_to_message(
let mut fetch_context = Vec::new();
let mut thread_context = Vec::new();
let mut capacity = 0;
for context in contexts {
capacity += context.text.len();
match context.kind {
ContextKind::File => file_context.push(context),
ContextKind::Directory => directory_context.push(context),
ContextKind::Symbol => symbol_context.push(context),
ContextKind::FetchedUrl => fetch_context.push(context),
ContextKind::Thread => thread_context.push(context),
match context {
AssistantContext::File(context) => file_context.push(context),
AssistantContext::Directory(context) => directory_context.push(context),
AssistantContext::Symbol(context) => symbol_context.push(context),
AssistantContext::FetchedUrl(context) => fetch_context.push(context),
AssistantContext::Thread(context) => thread_context.push(context),
}
}
if !file_context.is_empty() {
capacity += 1;
}
if !directory_context.is_empty() {
capacity += 1;
}
if !symbol_context.is_empty() {
capacity += 1;
}
if !fetch_context.is_empty() {
capacity += 1 + fetch_context.len();
}
if !thread_context.is_empty() {
capacity += 1 + thread_context.len();
}
if capacity == 0 {
return;
}
let mut context_chunks = Vec::with_capacity(capacity);
let mut context_chunks = Vec::new();
if !file_context.is_empty() {
context_chunks.push("The following files are available:\n");
for context in &file_context {
for chunk in &context.text {
context_chunks.push(&chunk);
}
for context in file_context {
context_chunks.push(&context.context_buffer.text);
}
}
if !directory_context.is_empty() {
context_chunks.push("The following directories are available:\n");
for context in &directory_context {
for chunk in &context.text {
context_chunks.push(&chunk);
for context in directory_context {
for context_buffer in &context.context_buffers {
context_chunks.push(&context_buffer.text);
}
}
}
if !symbol_context.is_empty() {
context_chunks.push("The following symbols are available:\n");
for context in &symbol_context {
for chunk in &context.text {
context_chunks.push(&chunk);
}
for context in symbol_context {
context_chunks.push(&context.context_symbol.text);
}
}
if !fetch_context.is_empty() {
context_chunks.push("The following fetched results are available:\n");
for context in &fetch_context {
context_chunks.push(&context.name);
for chunk in &context.text {
context_chunks.push(&chunk);
}
context_chunks.push(&context.url);
context_chunks.push(&context.text);
}
}
// Need to own the SharedString for summary so that it can be referenced.
let mut thread_context_chunks = Vec::new();
if !thread_context.is_empty() {
context_chunks.push("The following previous conversation threads are available:\n");
for context in &thread_context {
context_chunks.push(&context.name);
for chunk in &context.text {
context_chunks.push(&chunk);
}
thread_context_chunks.push(context.summary(cx));
thread_context_chunks.push(context.text.clone());
}
}
debug_assert!(
context_chunks.len() == capacity,
"attach_context_message calculated capacity of {}, but length was {}",
capacity,
context_chunks.len()
);
for chunk in &thread_context_chunks {
context_chunks.push(chunk);
}
if !context_chunks.is_empty() {
message

View file

@ -2,20 +2,20 @@ use std::ops::Range;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use anyhow::{Context as _, Result, anyhow};
use collections::{BTreeMap, HashMap, HashSet};
use futures::{self, Future, FutureExt, future};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
use language::Buffer;
use language::{Buffer, File};
use project::{ProjectItem, ProjectPath, Worktree};
use rope::Rope;
use text::{Anchor, BufferId, OffsetRangeExt};
use util::maybe;
use util::{ResultExt, maybe};
use workspace::Workspace;
use crate::context::{
AssistantContext, ContextBuffer, ContextId, ContextSnapshot, ContextSymbol, ContextSymbolId,
DirectoryContext, FetchedUrlContext, FileContext, SymbolContext, ThreadContext,
AssistantContext, ContextBuffer, ContextId, ContextSymbol, ContextSymbolId, DirectoryContext,
FetchedUrlContext, FileContext, SymbolContext, ThreadContext,
};
use crate::context_strip::SuggestedContext;
use crate::thread::{Thread, ThreadId};
@ -50,12 +50,6 @@ impl ContextStore {
}
}
pub fn snapshot<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = ContextSnapshot> + 'a {
self.context()
.iter()
.flat_map(|context| context.snapshot(cx))
}
pub fn context(&self) -> &Vec<AssistantContext> {
&self.context
}
@ -121,7 +115,7 @@ impl ContextStore {
None,
cx.to_async(),
)
})?;
})??;
let text = text_task.await;
@ -144,13 +138,13 @@ impl ContextStore {
let Some(file) = buffer.file() else {
return Err(anyhow!("Buffer has no path."));
};
Ok(collect_buffer_info_and_text(
collect_buffer_info_and_text(
file.path().clone(),
buffer_entity,
buffer,
None,
cx.to_async(),
))
)
})??;
let text = text_task.await;
@ -166,8 +160,10 @@ impl ContextStore {
fn insert_file(&mut self, context_buffer: ContextBuffer) {
let id = self.next_context_id.post_inc();
self.files.insert(context_buffer.id, id);
self.context
.push(AssistantContext::File(FileContext { id, context_buffer }));
self.context.push(AssistantContext::File(FileContext {
id,
context_buffer: context_buffer,
}));
}
pub fn add_directory(
@ -231,15 +227,18 @@ impl ContextStore {
// Skip all binary files and other non-UTF8 files
if let Ok(buffer_entity) = buffer_entity {
let buffer = buffer_entity.read(cx);
let (buffer_info, text_task) = collect_buffer_info_and_text(
if let Some((buffer_info, text_task)) = collect_buffer_info_and_text(
path,
buffer_entity,
buffer,
None,
cx.to_async(),
);
buffer_infos.push(buffer_info);
text_tasks.push(text_task);
)
.log_err()
{
buffer_infos.push(buffer_info);
text_tasks.push(text_task);
}
}
}
anyhow::Ok(())
@ -253,7 +252,10 @@ impl ContextStore {
.collect::<Vec<_>>();
if context_buffers.is_empty() {
bail!("No text files found in {}", &project_path.path.display());
return Err(anyhow!(
"No text files found in {}",
&project_path.path.display()
));
}
this.update(cx, |this, _| {
@ -269,11 +271,11 @@ impl ContextStore {
self.directories.insert(project_path.path.to_path_buf(), id);
self.context
.push(AssistantContext::Directory(DirectoryContext::new(
.push(AssistantContext::Directory(DirectoryContext {
id,
project_path,
context_buffers,
)));
}));
}
pub fn add_symbol(
@ -314,13 +316,16 @@ impl ContextStore {
}
}
let (buffer_info, collect_content_task) = collect_buffer_info_and_text(
let (buffer_info, collect_content_task) = match collect_buffer_info_and_text(
file.path().clone(),
buffer,
buffer_ref,
Some(symbol_enclosing_range.clone()),
cx.to_async(),
);
) {
Ok((buffer_info, collect_context_task)) => (buffer_info, collect_context_task),
Err(err) => return Task::ready(Err(err)),
};
cx.spawn(async move |this, cx| {
let content = collect_content_task.await;
@ -568,6 +573,7 @@ pub enum FileInclusion {
// ContextBuffer without text.
struct BufferInfo {
buffer_entity: Entity<Buffer>,
file: Arc<dyn File>,
id: BufferId,
version: clock::Global,
}
@ -576,6 +582,7 @@ fn make_context_buffer(info: BufferInfo, text: SharedString) -> ContextBuffer {
ContextBuffer {
id: info.id,
buffer: info.buffer_entity,
file: info.file,
version: info.version,
text,
}
@ -604,10 +611,14 @@ fn collect_buffer_info_and_text(
buffer: &Buffer,
range: Option<Range<Anchor>>,
cx: AsyncApp,
) -> (BufferInfo, Task<SharedString>) {
) -> Result<(BufferInfo, Task<SharedString>)> {
let buffer_info = BufferInfo {
id: buffer.remote_id(),
buffer_entity,
file: buffer
.file()
.context("buffer context must have a file")?
.clone(),
version: buffer.version(),
};
// Important to collect version at the same time as content so that staleness logic is correct.
@ -617,23 +628,26 @@ fn collect_buffer_info_and_text(
buffer.as_rope().clone()
};
let text_task = cx.background_spawn(async move { to_fenced_codeblock(&path, content) });
(buffer_info, text_task)
Ok((buffer_info, text_task))
}
pub fn buffer_path_log_err(buffer: &Buffer, cx: &App) -> Option<Arc<Path>> {
if let Some(file) = buffer.file() {
let mut path = file.path().clone();
if path.as_os_str().is_empty() {
path = file.full_path(cx).into();
}
Some(path)
Some(file_path(file, cx))
} else {
log::error!("Buffer that had a path unexpectedly no longer has a path.");
None
}
}
pub fn file_path(file: &Arc<dyn File>, cx: &App) -> Arc<Path> {
let mut path = file.path().clone();
if path.as_os_str().is_empty() {
path = file.full_path(cx).into();
}
return path;
}
fn to_fenced_codeblock(path: &Path, content: Rope) -> SharedString {
let path_extension = path.extension().and_then(|ext| ext.to_str());
let path_string = path.to_string_lossy();
@ -714,7 +728,7 @@ pub fn refresh_context_store_text(
let buffer = buffer.read(cx);
buffer_path_log_err(&buffer, cx).map_or(false, |path| {
path.starts_with(&directory_context.path.path)
path.starts_with(&directory_context.project_path.path)
})
});
@ -801,13 +815,17 @@ fn refresh_directory_text(
let context_buffers = future::join_all(futures);
let id = directory_context.snapshot.id;
let path = directory_context.path.clone();
let id = directory_context.id;
let project_path = directory_context.project_path.clone();
Some(cx.spawn(async move |cx| {
let context_buffers = context_buffers.await;
context_store
.update(cx, |context_store, _| {
let new_directory_context = DirectoryContext::new(id, path, context_buffers);
let new_directory_context = DirectoryContext {
id,
project_path,
context_buffers,
};
context_store.replace_context(AssistantContext::Directory(new_directory_context));
})
.ok();
@ -870,7 +888,8 @@ fn refresh_context_buffer(
buffer,
None,
cx.to_async(),
);
)
.log_err()?;
Some(text_task.map(move |text| make_context_buffer(buffer_info, text)))
} else {
None
@ -891,7 +910,8 @@ fn refresh_context_symbol(
buffer,
Some(context_symbol.enclosing_range.clone()),
cx.to_async(),
);
)
.log_err()?;
let name = context_symbol.id.name.clone();
let range = context_symbol.id.range.clone();
let enclosing_range = context_symbol.enclosing_range.clone();

View file

@ -17,7 +17,7 @@ use crate::context_picker::{ConfirmBehavior, ContextPicker};
use crate::context_store::ContextStore;
use crate::thread::Thread;
use crate::thread_store::ThreadStore;
use crate::ui::ContextPill;
use crate::ui::{AddedContext, ContextPill};
use crate::{
AcceptSuggestedContext, AssistantPanel, FocusDown, FocusLeft, FocusRight, FocusUp,
RemoveAllContext, RemoveFocusedContext, ToggleContextPicker,
@ -363,19 +363,19 @@ impl Focusable for ContextStrip {
impl Render for ContextStrip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let context_store = self.context_store.read(cx);
let context = context_store
.context()
.iter()
.flat_map(|context| context.snapshot(cx))
.collect::<Vec<_>>();
let context = context_store.context();
let context_picker = self.context_picker.clone();
let focus_handle = self.focus_handle.clone();
let suggested_context = self.suggested_context(cx);
let dupe_names = context
let added_contexts = context
.iter()
.map(|context| context.name.clone())
.map(|c| AddedContext::new(c, cx))
.collect::<Vec<_>>();
let dupe_names = added_contexts
.iter()
.map(|c| c.name.clone())
.sorted()
.tuple_windows()
.filter(|(a, b)| a == b)
@ -461,34 +461,39 @@ impl Render for ContextStrip {
)
}
})
.children(context.iter().enumerate().map(|(i, context)| {
let id = context.id;
ContextPill::added(
context.clone(),
dupe_names.contains(&context.name),
self.focused_index == Some(i),
Some({
let id = context.id;
let context_store = self.context_store.clone();
Rc::new(cx.listener(move |_this, _event, _window, cx| {
context_store.update(cx, |this, _cx| {
this.remove_context(id);
});
cx.notify();
}))
.children(
added_contexts
.into_iter()
.enumerate()
.map(|(i, added_context)| {
let name = added_context.name.clone();
let id = added_context.id;
ContextPill::added(
added_context,
dupe_names.contains(&name),
self.focused_index == Some(i),
Some({
let context_store = self.context_store.clone();
Rc::new(cx.listener(move |_this, _event, _window, cx| {
context_store.update(cx, |this, _cx| {
this.remove_context(id);
});
cx.notify();
}))
}),
)
.on_click({
Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| {
if event.down.click_count > 1 {
this.open_context(id, window, cx);
} else {
this.focused_index = Some(i);
}
cx.notify();
}))
})
}),
)
.on_click(Rc::new(cx.listener(
move |this, event: &ClickEvent, window, cx| {
if event.down.click_count > 1 {
this.open_context(id, window, cx);
} else {
this.focused_index = Some(i);
}
cx.notify();
},
)))
}))
)
.when_some(suggested_context, |el, suggested| {
el.child(
ContextPill::suggested(

View file

@ -239,7 +239,7 @@ impl MessageEditor {
.ok();
thread
.update(cx, |thread, cx| {
let context = context_store.read(cx).snapshot(cx).collect::<Vec<_>>();
let context = context_store.read(cx).context().clone();
thread.action_log().update(cx, |action_log, cx| {
action_log.clear_reviewed_changes(cx);
});

View file

@ -252,7 +252,8 @@ impl TerminalInlineAssistant {
attach_context_to_message(
&mut request_message,
assist.context_store.read(cx).snapshot(cx),
assist.context_store.read(cx).context().iter(),
cx,
);
request_message.content.push(prompt.into());

View file

@ -29,7 +29,7 @@ use settings::Settings;
use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc};
use uuid::Uuid;
use crate::context::{ContextId, ContextSnapshot, attach_context_to_message};
use crate::context::{AssistantContext, ContextId, attach_context_to_message};
use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse,
@ -175,7 +175,7 @@ pub struct Thread {
pending_summary: Task<Option<()>>,
messages: Vec<Message>,
next_message_id: MessageId,
context: BTreeMap<ContextId, ContextSnapshot>,
context: BTreeMap<ContextId, AssistantContext>,
context_by_message: HashMap<MessageId, Vec<ContextId>>,
system_prompt_context: Option<AssistantSystemPromptContext>,
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
@ -473,15 +473,15 @@ impl Thread {
cx.notify();
}
pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
let context = self.context_by_message.get(&id)?;
Some(
context
.into_iter()
.filter_map(|context_id| self.context.get(&context_id))
.cloned()
.collect::<Vec<_>>(),
)
pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
self.context_by_message
.get(&id)
.into_iter()
.flat_map(|context| {
context
.iter()
.filter_map(|context_id| self.context.get(&context_id))
})
}
/// Returns whether all of the tool uses have finished running.
@ -513,15 +513,18 @@ impl Thread {
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
context: Vec<ContextSnapshot>,
context: Vec<AssistantContext>,
git_checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>,
) -> MessageId {
let message_id =
self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
let context_ids = context
.iter()
.map(|context| context.id())
.collect::<Vec<_>>();
self.context
.extend(context.into_iter().map(|context| (context.id, context)));
.extend(context.into_iter().map(|context| (context.id(), context)));
self.context_by_message.insert(message_id, context_ids);
if let Some(git_checkpoint) = git_checkpoint {
self.pending_checkpoint = Some(ThreadCheckpoint {
@ -889,9 +892,8 @@ impl Thread {
let referenced_context = referenced_context_ids
.into_iter()
.filter_map(|context_id| self.context.get(context_id))
.cloned();
attach_context_to_message(&mut context_message, referenced_context);
.filter_map(|context_id| self.context.get(context_id));
attach_context_to_message(&mut context_message, referenced_context, cx);
request.messages.push(context_message);
}
@ -1300,13 +1302,13 @@ impl Thread {
pub fn attach_tool_results(
&mut self,
updated_context: Vec<ContextSnapshot>,
updated_context: Vec<AssistantContext>,
cx: &mut Context<Self>,
) {
self.context.extend(
updated_context
.into_iter()
.map(|context| (context.id, context)),
.map(|context| (context.id(), context)),
);
// Insert a user message to contain the tool results.

View file

@ -1,14 +1,15 @@
use std::rc::Rc;
use file_icons::FileIcons;
use gpui::ClickEvent;
use ui::{IconButtonShape, Tooltip, prelude::*};
use crate::context::{ContextKind, ContextSnapshot};
use crate::context::{AssistantContext, ContextId, ContextKind};
#[derive(IntoElement)]
pub enum ContextPill {
Added {
context: ContextSnapshot,
context: AddedContext,
dupe_name: bool,
focused: bool,
on_click: Option<Rc<dyn Fn(&ClickEvent, &mut Window, &mut App)>>,
@ -25,7 +26,7 @@ pub enum ContextPill {
impl ContextPill {
pub fn added(
context: ContextSnapshot,
context: AddedContext,
dupe_name: bool,
focused: bool,
on_remove: Option<Rc<dyn Fn(&ClickEvent, &mut Window, &mut App)>>,
@ -77,17 +78,21 @@ impl ContextPill {
pub fn icon(&self) -> Icon {
match self {
Self::Added { context, .. } => match &context.icon_path {
Some(icon_path) => Icon::from_path(icon_path),
None => Icon::new(context.kind.icon()),
},
Self::Suggested {
icon_path: Some(icon_path),
..
}
| Self::Added {
context:
AddedContext {
icon_path: Some(icon_path),
..
},
..
} => Icon::from_path(icon_path),
Self::Suggested {
kind,
icon_path: None,
Self::Suggested { kind, .. }
| Self::Added {
context: AddedContext { kind, .. },
..
} => Icon::new(kind.icon()),
}
@ -144,7 +149,7 @@ impl RenderOnce for ContextPill {
element
}
})
.when_some(context.tooltip.clone(), |element, tooltip| {
.when_some(context.tooltip.as_ref(), |element, tooltip| {
element.tooltip(Tooltip::text(tooltip.clone()))
}),
)
@ -219,3 +224,91 @@ impl RenderOnce for ContextPill {
}
}
}
pub struct AddedContext {
pub id: ContextId,
pub kind: ContextKind,
pub name: SharedString,
pub parent: Option<SharedString>,
pub tooltip: Option<SharedString>,
pub icon_path: Option<SharedString>,
}
impl AddedContext {
pub fn new(context: &AssistantContext, cx: &App) -> AddedContext {
match context {
AssistantContext::File(file_context) => {
let full_path = file_context.context_buffer.file.full_path(cx);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path
.file_name()
.map(|n| n.to_string_lossy().into_owned().into())
.unwrap_or_else(|| full_path_string.clone());
let parent = full_path
.parent()
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: file_context.id,
kind: ContextKind::File,
name,
parent,
tooltip: Some(full_path_string),
icon_path: FileIcons::get_icon(&full_path, cx),
}
}
AssistantContext::Directory(directory_context) => {
// TODO: handle worktree disambiguation. Maybe by storing an `Arc<dyn File>` to also
// handle renames?
let full_path = &directory_context.project_path.path;
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path
.file_name()
.map(|n| n.to_string_lossy().into_owned().into())
.unwrap_or_else(|| full_path_string.clone());
let parent = full_path
.parent()
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: directory_context.id,
kind: ContextKind::Directory,
name,
parent,
tooltip: Some(full_path_string),
icon_path: None,
}
}
AssistantContext::Symbol(symbol_context) => AddedContext {
id: symbol_context.id,
kind: ContextKind::Symbol,
name: symbol_context.context_symbol.id.name.clone(),
parent: None,
tooltip: None,
icon_path: None,
},
AssistantContext::FetchedUrl(fetched_url_context) => AddedContext {
id: fetched_url_context.id,
kind: ContextKind::FetchedUrl,
name: fetched_url_context.url.clone(),
parent: None,
tooltip: None,
icon_path: None,
},
AssistantContext::Thread(thread_context) => AddedContext {
id: thread_context.id,
kind: ContextKind::Thread,
name: thread_context.summary(cx),
parent: None,
tooltip: None,
icon_path: None,
},
}
}
}