agent: Allow adding/removing context when editing existing messages (#29698)

Release Notes:

- agent: Support adding/removing context when editing existing message

---------

Co-authored-by: Cole Miller <m@cole-miller.net>
Co-authored-by: Cole Miller <cole@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2025-05-01 03:39:34 +02:00 committed by GitHub
parent f046d70625
commit 1bf9e15f26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 418 additions and 189 deletions

View file

@ -1,5 +1,7 @@
use crate::context::{AgentContextHandle, RULES_ICON}; use crate::context::{AgentContextHandle, RULES_ICON};
use crate::context_picker::MentionLink; use crate::context_picker::{ContextPicker, MentionLink};
use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::thread::{ use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent, LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
ThreadFeedback, ThreadFeedback,
@ -14,14 +16,16 @@ use anyhow::Context as _;
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting}; use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
use assistant_tool::ToolUseStatus; use assistant_tool::ToolUseStatus;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste};
use editor::scroll::Autoscroll; use editor::scroll::Autoscroll;
use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer}; use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer};
use gpui::{ use gpui::{
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem, AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardEntry,
DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla, ListAlignment, ClipboardItem, DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla,
ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, ListAlignment, ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful,
Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
linear_color_stop, linear_gradient, list, percentage, pulsating_between, UnderlineStyle, WeakEntity, WindowHandle, linear_color_stop, linear_gradient, list, percentage,
pulsating_between,
}; };
use language::{Buffer, Language, LanguageRegistry}; use language::{Buffer, Language, LanguageRegistry};
use language_model::{ use language_model::{
@ -41,7 +45,8 @@ use std::time::Duration;
use text::ToPoint; use text::ToPoint;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::{ use ui::{
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*, Disclosure, IconButton, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize,
Tooltip, prelude::*,
}; };
use util::ResultExt as _; use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock; use util::markdown::MarkdownCodeBlock;
@ -49,6 +54,7 @@ use workspace::Workspace;
use zed_actions::assistant::OpenRulesLibrary; use zed_actions::assistant::OpenRulesLibrary;
pub struct ActiveThread { pub struct ActiveThread {
context_store: Entity<ContextStore>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
thread: Entity<Thread>, thread: Entity<Thread>,
@ -61,7 +67,7 @@ pub struct ActiveThread {
hide_scrollbar_task: Option<Task<()>>, hide_scrollbar_task: Option<Task<()>>,
rendered_messages_by_id: HashMap<MessageId, RenderedMessage>, rendered_messages_by_id: HashMap<MessageId, RenderedMessage>,
rendered_tool_uses: HashMap<LanguageModelToolUseId, RenderedToolUse>, rendered_tool_uses: HashMap<LanguageModelToolUseId, RenderedToolUse>,
editing_message: Option<(MessageId, EditMessageState)>, editing_message: Option<(MessageId, EditingMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>, expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
expanded_thinking_segments: HashMap<(MessageId, usize), bool>, expanded_thinking_segments: HashMap<(MessageId, usize), bool>,
expanded_code_blocks: HashMap<(MessageId, usize), bool>, expanded_code_blocks: HashMap<(MessageId, usize), bool>,
@ -72,6 +78,7 @@ pub struct ActiveThread {
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>, notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
open_feedback_editors: HashMap<MessageId, Entity<Editor>>, open_feedback_editors: HashMap<MessageId, Entity<Editor>>,
_load_edited_message_context_task: Option<Task<()>>,
} }
struct RenderedMessage { struct RenderedMessage {
@ -725,10 +732,12 @@ fn open_markdown_link(
} }
} }
struct EditMessageState { struct EditingMessageState {
editor: Entity<Editor>, editor: Entity<Editor>,
context_strip: Entity<ContextStrip>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
last_estimated_token_count: Option<usize>, last_estimated_token_count: Option<usize>,
_subscription: Subscription, _subscriptions: [Subscription; 2],
_update_token_count_task: Option<Task<()>>, _update_token_count_task: Option<Task<()>>,
} }
@ -736,6 +745,7 @@ impl ActiveThread {
pub fn new( pub fn new(
thread: Entity<Thread>, thread: Entity<Thread>,
thread_store: Entity<ThreadStore>, thread_store: Entity<ThreadStore>,
context_store: Entity<ContextStore>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
window: &mut Window, window: &mut Window,
@ -758,6 +768,7 @@ impl ActiveThread {
let mut this = Self { let mut this = Self {
language_registry, language_registry,
thread_store, thread_store,
context_store,
thread: thread.clone(), thread: thread.clone(),
workspace, workspace,
save_thread_task: None, save_thread_task: None,
@ -779,6 +790,7 @@ impl ActiveThread {
_subscriptions: subscriptions, _subscriptions: subscriptions,
notification_subscriptions: HashMap::default(), notification_subscriptions: HashMap::default(),
open_feedback_editors: HashMap::default(), open_feedback_editors: HashMap::default(),
_load_edited_message_context_task: None,
}; };
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() { for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
@ -1237,33 +1249,49 @@ impl ActiveThread {
return; return;
}; };
let buffer = cx.new(|cx| { let editor = crate::message_editor::create_editor(
MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx) self.workspace.clone(),
}); self.context_store.downgrade(),
let editor = cx.new(|cx| { self.thread_store.downgrade(),
let mut editor = Editor::new( window,
editor::EditorMode::AutoHeight { max_lines: 8 }, cx,
buffer, );
None, editor.update(cx, |editor, cx| {
window, editor.set_text(message_text.clone(), window, cx);
cx,
);
editor.focus_handle(cx).focus(window); editor.focus_handle(cx).focus(window);
editor.move_to_end(&editor::actions::MoveToEnd, window, cx); editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
editor
}); });
let subscription = cx.subscribe(&editor, |this, _, event, cx| match event { let buffer_edited_subscription = cx.subscribe(&editor, |this, _, event, cx| match event {
EditorEvent::BufferEdited => { EditorEvent::BufferEdited => {
this.update_editing_message_token_count(true, cx); this.update_editing_message_token_count(true, cx);
} }
_ => {} _ => {}
}); });
let context_picker_menu_handle = PopoverMenuHandle::default();
let context_strip = cx.new(|cx| {
ContextStrip::new(
self.context_store.clone(),
self.workspace.clone(),
Some(self.thread_store.downgrade()),
context_picker_menu_handle.clone(),
SuggestContextKind::File,
window,
cx,
)
});
let context_strip_subscription =
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event);
self.editing_message = Some(( self.editing_message = Some((
message_id, message_id,
EditMessageState { EditingMessageState {
editor: editor.clone(), editor: editor.clone(),
context_strip,
context_picker_menu_handle,
last_estimated_token_count: None, last_estimated_token_count: None,
_subscription: subscription, _subscriptions: [buffer_edited_subscription, context_strip_subscription],
_update_token_count_task: None, _update_token_count_task: None,
}, },
)); ));
@ -1271,6 +1299,26 @@ impl ActiveThread {
cx.notify(); cx.notify();
} }
fn handle_context_strip_event(
&mut self,
_context_strip: &Entity<ContextStrip>,
event: &ContextStripEvent,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some((_, state)) = self.editing_message.as_ref() {
match event {
ContextStripEvent::PickerDismissed
| ContextStripEvent::BlurredEmpty
| ContextStripEvent::BlurredDown => {
let editor_focus_handle = state.editor.focus_handle(cx);
window.focus(&editor_focus_handle);
}
ContextStripEvent::BlurredUp => {}
}
}
}
fn update_editing_message_token_count(&mut self, debounce: bool, cx: &mut Context<Self>) { fn update_editing_message_token_count(&mut self, debounce: bool, cx: &mut Context<Self>) {
let Some((message_id, state)) = self.editing_message.as_mut() else { let Some((message_id, state)) = self.editing_message.as_mut() else {
return; return;
@ -1357,6 +1405,68 @@ impl ActiveThread {
})); }));
} }
fn toggle_context_picker(
&mut self,
_: &crate::ToggleContextPicker,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some((_, state)) = self.editing_message.as_mut() {
let handle = state.context_picker_menu_handle.clone();
window.defer(cx, move |window, cx| {
handle.toggle(window, cx);
});
}
}
fn remove_all_context(
&mut self,
_: &crate::RemoveAllContext,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.context_store.update(cx, |store, _cx| store.clear());
cx.notify();
}
fn move_up(&mut self, _: &MoveUp, window: &mut Window, cx: &mut Context<Self>) {
if let Some((_, state)) = self.editing_message.as_mut() {
if state.context_picker_menu_handle.is_deployed() {
cx.propagate();
} else {
state.context_strip.focus_handle(cx).focus(window);
}
}
}
fn paste(&mut self, _: &Paste, _window: &mut Window, cx: &mut Context<Self>) {
let images = cx
.read_from_clipboard()
.map(|item| {
item.into_entries()
.filter_map(|entry| {
if let ClipboardEntry::Image(image) = entry {
Some(image)
} else {
None
}
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
if images.is_empty() {
return;
}
cx.stop_propagation();
self.context_store.update(cx, |store, cx| {
for image in images {
store.add_image_instance(Arc::new(image), cx);
}
});
}
fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) { fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
self.editing_message.take(); self.editing_message.take();
cx.notify(); cx.notify();
@ -1371,21 +1481,11 @@ impl ActiveThread {
let Some((message_id, state)) = self.editing_message.take() else { let Some((message_id, state)) = self.editing_message.take() else {
return; return;
}; };
let edited_text = state.editor.read(cx).text(cx);
let thread_model = self.thread.update(cx, |thread, cx| {
thread.edit_message(
message_id,
Role::User,
vec![MessageSegment::Text(edited_text)],
cx,
);
for message_id in self.messages_after(message_id) {
thread.delete_message(*message_id, cx);
}
thread.get_or_init_configured_model(cx)
});
let Some(model) = thread_model else { let Some(model) = self
.thread
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
else {
return; return;
}; };
@ -1394,11 +1494,45 @@ impl ActiveThread {
return; return;
} }
self.thread.update(cx, |thread, cx| { let edited_text = state.editor.read(cx).text(cx);
thread.advance_prompt_id();
thread.send_to_model(model.model, Some(window.window_handle()), cx); let new_context = self
}); .context_store
cx.notify(); .read(cx)
.new_context_for_thread(self.thread.read(cx), Some(message_id));
let project = self.thread.read(cx).project().clone();
let prompt_store = self.thread_store.read(cx).prompt_store().clone();
let load_context_task =
crate::context::load_context(new_context, &project, &prompt_store, cx);
self._load_edited_message_context_task =
Some(cx.spawn_in(window, async move |this, cx| {
let context = load_context_task.await;
let _ = this
.update_in(cx, |this, window, cx| {
this.thread.update(cx, |thread, cx| {
thread.edit_message(
message_id,
Role::User,
vec![MessageSegment::Text(edited_text)],
Some(context.loaded_context),
cx,
);
for message_id in this.messages_after(message_id) {
thread.delete_message(*message_id, cx);
}
});
this.thread.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(model.model, Some(window.window_handle()), cx);
});
this._load_edited_message_context_task = None;
cx.notify();
})
.log_err();
}));
} }
fn messages_after(&self, message_id: MessageId) -> &[MessageId] { fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
@ -1519,6 +1653,53 @@ impl ActiveThread {
} }
} }
fn render_edit_message_editor(
&self,
state: &EditingMessageState,
window: &mut Window,
cx: &Context<Self>,
) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small.rems(cx);
let line_height = font_size.to_pixels(window.rem_size()) * 1.75;
let colors = cx.theme().colors();
let text_style = TextStyle {
color: colors.text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
v_flex()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::toggle_context_picker))
.on_action(cx.listener(Self::remove_all_context))
.on_action(cx.listener(Self::move_up))
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.capture_action(cx.listener(Self::paste))
.min_h_6()
.flex_grow()
.w_full()
.gap_2()
.child(EditorElement::new(
&state.editor,
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.child(state.context_strip.clone())
}
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement { fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
let message_id = self.messages[ix]; let message_id = self.messages[ix];
let Some(message) = self.thread.read(cx).message(message_id) else { let Some(message) = self.thread.read(cx).message(message_id) else {
@ -1551,11 +1732,11 @@ impl ActiveThread {
let generating_label = (is_generating && is_last_message) let generating_label = (is_generating && is_last_message)
.then(|| AnimatedLabel::new("Generating").size(LabelSize::Small)); .then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
let edit_message_editor = self let editing_message_state = self
.editing_message .editing_message
.as_ref() .as_ref()
.filter(|(id, _)| *id == message_id) .filter(|(id, _)| *id == message_id)
.map(|(_, state)| state.editor.clone()); .map(|(_, state)| state);
let colors = cx.theme().colors(); let colors = cx.theme().colors();
let editor_bg_color = colors.editor_background; let editor_bg_color = colors.editor_background;
@ -1690,77 +1871,43 @@ impl ActiveThread {
let has_content = !message_is_empty || !added_context.is_empty(); let has_content = !message_is_empty || !added_context.is_empty();
let message_content = has_content.then(|| { let message_content = has_content.then(|| {
v_flex() if let Some(state) = editing_message_state.as_ref() {
.w_full() self.render_edit_message_editor(state, window, cx)
.gap_1() .into_any_element()
.when(!message_is_empty, |parent| { } else {
parent.child( v_flex()
if let Some(edit_message_editor) = edit_message_editor.clone() { .w_full()
let settings = ThemeSettings::get_global(cx); .gap_1()
let font_size = TextSize::Small.rems(cx); .when(!message_is_empty, |parent| {
let line_height = font_size.to_pixels(window.rem_size()) * 1.75; parent.child(div().min_h_6().child(self.render_message_content(
message_id,
let text_style = TextStyle { rendered_message,
color: cx.theme().colors().text, has_tool_uses,
font_family: settings.buffer_font.family.clone(), workspace.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(), window,
font_features: settings.buffer_font.features.clone(), cx,
font_size: font_size.into(), )))
line_height: line_height.into(), })
..Default::default() .when(!added_context.is_empty(), |parent| {
}; parent.child(h_flex().flex_wrap().gap_1().children(
added_context.into_iter().map(|added_context| {
div() let context = added_context.handle.clone();
.key_context("EditMessageEditor") ContextPill::added(added_context, false, false, None).on_click(
.on_action(cx.listener(Self::cancel_editing_message)) Rc::new(cx.listener({
.on_action(cx.listener(Self::confirm_editing_message)) let workspace = workspace.clone();
.min_h_6() move |_, _, window, cx| {
.flex_grow() if let Some(workspace) = workspace.upgrade() {
.w_full() open_context(&context, workspace, window, cx);
.child(EditorElement::new( cx.notify();
&edit_message_editor, }
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.into_any()
} else {
div()
.min_h_6()
.child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
workspace.clone(),
window,
cx,
))
.into_any()
},
)
})
.when(!added_context.is_empty(), |parent| {
parent.child(h_flex().flex_wrap().gap_1().children(
added_context.into_iter().map(|added_context| {
let context = added_context.handle.clone();
ContextPill::added(added_context, false, false, None).on_click(Rc::new(
cx.listener({
let workspace = workspace.clone();
move |_, _, window, cx| {
if let Some(workspace) = workspace.upgrade() {
open_context(&context, workspace, window, cx);
cx.notify();
} }
} })),
}), )
)) }),
}), ))
)) })
}) .into_any_element()
}
}); });
let styled_message = match message.role { let styled_message = match message.role {
@ -1785,8 +1932,8 @@ impl ActiveThread {
.p_2p5() .p_2p5()
.gap_1() .gap_1()
.children(message_content) .children(message_content)
.when_some(edit_message_editor.clone(), |this, edit_editor| { .when_some(editing_message_state, |this, state| {
let edit_editor_clone = edit_editor.clone(); let focus_handle = state.editor.focus_handle(cx).clone();
this.w_full().justify_between().child( this.w_full().justify_between().child(
h_flex() h_flex()
.gap_0p5() .gap_0p5()
@ -1797,16 +1944,17 @@ impl ActiveThread {
) )
.shape(ui::IconButtonShape::Square) .shape(ui::IconButtonShape::Square)
.icon_color(Color::Error) .icon_color(Color::Error)
.tooltip(move |window, cx| { .tooltip({
let focus_handle = let focus_handle = focus_handle.clone();
edit_editor_clone.focus_handle(cx); move |window, cx| {
Tooltip::for_action_in( Tooltip::for_action_in(
"Cancel Edit", "Cancel Edit",
&menu::Cancel, &menu::Cancel,
&focus_handle, &focus_handle,
window, window,
cx, cx,
) )
}
}) })
.on_click(cx.listener(Self::handle_cancel_click)), .on_click(cx.listener(Self::handle_cancel_click)),
) )
@ -1815,18 +1963,20 @@ impl ActiveThread {
"confirm-edit-message", "confirm-edit-message",
IconName::Check, IconName::Check,
) )
.disabled(edit_editor.read(cx).is_empty(cx)) .disabled(state.editor.read(cx).is_empty(cx))
.shape(ui::IconButtonShape::Square) .shape(ui::IconButtonShape::Square)
.icon_color(Color::Success) .icon_color(Color::Success)
.tooltip(move |window, cx| { .tooltip({
let focus_handle = edit_editor.focus_handle(cx); let focus_handle = focus_handle.clone();
Tooltip::for_action_in( move |window, cx| {
"Regenerate", Tooltip::for_action_in(
&menu::Confirm, "Regenerate",
&focus_handle, &menu::Confirm,
window, &focus_handle,
cx, window,
) cx,
)
}
}) })
.on_click( .on_click(
cx.listener(Self::handle_regenerate_click), cx.listener(Self::handle_regenerate_click),
@ -1835,7 +1985,7 @@ impl ActiveThread {
) )
}), }),
) )
.when(edit_message_editor.is_none(), |this| { .when(editing_message_state.is_none(), |this| {
this.tooltip(Tooltip::text("Click To Edit")) this.tooltip(Tooltip::text("Click To Edit"))
}) })
.on_click(cx.listener({ .on_click(cx.listener({

View file

@ -424,6 +424,7 @@ impl AssistantPanel {
ActiveThread::new( ActiveThread::new(
thread.clone(), thread.clone(),
thread_store.clone(), thread_store.clone(),
message_editor_context_store.clone(),
language_registry.clone(), language_registry.clone(),
workspace.clone(), workspace.clone(),
window, window,
@ -626,7 +627,7 @@ impl AssistantPanel {
let thread_view = ActiveView::thread(thread.clone(), window, cx); let thread_view = ActiveView::thread(thread.clone(), window, cx);
self.set_active_view(thread_view, window, cx); self.set_active_view(thread_view, window, cx);
let message_editor_context_store = cx.new(|_cx| { let context_store = cx.new(|_cx| {
crate::context_store::ContextStore::new( crate::context_store::ContextStore::new(
self.project.downgrade(), self.project.downgrade(),
Some(self.thread_store.downgrade()), Some(self.thread_store.downgrade()),
@ -639,7 +640,7 @@ impl AssistantPanel {
.update(cx, |this, cx| this.open_thread(&other_thread_id, cx)); .update(cx, |this, cx| this.open_thread(&other_thread_id, cx));
cx.spawn({ cx.spawn({
let context_store = message_editor_context_store.clone(); let context_store = context_store.clone();
async move |_panel, cx| { async move |_panel, cx| {
let other_thread = other_thread_task.await?; let other_thread = other_thread_task.await?;
@ -664,6 +665,7 @@ impl AssistantPanel {
ActiveThread::new( ActiveThread::new(
thread.clone(), thread.clone(),
self.thread_store.clone(), self.thread_store.clone(),
context_store.clone(),
self.language_registry.clone(), self.language_registry.clone(),
self.workspace.clone(), self.workspace.clone(),
window, window,
@ -682,7 +684,7 @@ impl AssistantPanel {
MessageEditor::new( MessageEditor::new(
self.fs.clone(), self.fs.clone(),
self.workspace.clone(), self.workspace.clone(),
message_editor_context_store, context_store,
self.prompt_store.clone(), self.prompt_store.clone(),
self.thread_store.downgrade(), self.thread_store.downgrade(),
thread, thread,
@ -843,7 +845,7 @@ impl AssistantPanel {
) { ) {
let thread_view = ActiveView::thread(thread.clone(), window, cx); let thread_view = ActiveView::thread(thread.clone(), window, cx);
self.set_active_view(thread_view, window, cx); self.set_active_view(thread_view, window, cx);
let message_editor_context_store = cx.new(|_cx| { let context_store = cx.new(|_cx| {
crate::context_store::ContextStore::new( crate::context_store::ContextStore::new(
self.project.downgrade(), self.project.downgrade(),
Some(self.thread_store.downgrade()), Some(self.thread_store.downgrade()),
@ -860,6 +862,7 @@ impl AssistantPanel {
ActiveThread::new( ActiveThread::new(
thread.clone(), thread.clone(),
self.thread_store.clone(), self.thread_store.clone(),
context_store.clone(),
self.language_registry.clone(), self.language_registry.clone(),
self.workspace.clone(), self.workspace.clone(),
window, window,
@ -878,7 +881,7 @@ impl AssistantPanel {
MessageEditor::new( MessageEditor::new(
self.fs.clone(), self.fs.clone(),
self.workspace.clone(), self.workspace.clone(),
message_editor_context_store, context_store,
self.prompt_store.clone(), self.prompt_store.clone(),
self.thread_store.downgrade(), self.thread_store.downgrade(),
thread, thread,

View file

@ -21,7 +21,7 @@ use crate::context::{
SymbolContextHandle, ThreadContextHandle, SymbolContextHandle, ThreadContextHandle,
}; };
use crate::context_strip::SuggestedContext; use crate::context_strip::SuggestedContext;
use crate::thread::{Thread, ThreadId}; use crate::thread::{MessageId, Thread, ThreadId};
pub struct ContextStore { pub struct ContextStore {
project: WeakEntity<Project>, project: WeakEntity<Project>,
@ -54,9 +54,14 @@ impl ContextStore {
self.context_thread_ids.clear(); self.context_thread_ids.clear();
} }
pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> { pub fn new_context_for_thread(
&self,
thread: &Thread,
exclude_messages_from_id: Option<MessageId>,
) -> Vec<AgentContextHandle> {
let existing_context = thread let existing_context = thread
.messages() .messages()
.take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
.flat_map(|message| { .flat_map(|message| {
message message
.loaded_context .loaded_context

View file

@ -69,6 +69,56 @@ pub struct MessageEditor {
const MAX_EDITOR_LINES: usize = 8; const MAX_EDITOR_LINES: usize = 8;
pub(crate) fn create_editor(
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Entity<Editor> {
let language = Language::new(
language::LanguageConfig {
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
..Default::default()
},
None,
);
let editor = cx.new(|cx| {
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(
editor::EditorMode::AutoHeight {
max_lines: MAX_EDITOR_LINES,
},
buffer,
None,
window,
cx,
);
editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx);
editor.set_show_indent_guides(false, cx);
editor.set_soft_wrap();
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: Some(ContextMenuPlacement::Above),
});
editor
});
let editor_entity = editor.downgrade();
editor.update(cx, |editor, _| {
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
workspace,
context_store,
Some(thread_store),
editor_entity,
))));
});
editor
}
impl MessageEditor { impl MessageEditor {
pub fn new( pub fn new(
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
@ -83,47 +133,14 @@ impl MessageEditor {
let context_picker_menu_handle = PopoverMenuHandle::default(); let context_picker_menu_handle = PopoverMenuHandle::default();
let model_selector_menu_handle = PopoverMenuHandle::default(); let model_selector_menu_handle = PopoverMenuHandle::default();
let language = Language::new( let editor = create_editor(
language::LanguageConfig { workspace.clone(),
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']), context_store.downgrade(),
..Default::default() thread_store.clone(),
}, window,
None, cx,
); );
let editor = cx.new(|cx| {
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(
editor::EditorMode::AutoHeight {
max_lines: MAX_EDITOR_LINES,
},
buffer,
None,
window,
cx,
);
editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx);
editor.set_show_indent_guides(false, cx);
editor.set_soft_wrap();
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: Some(ContextMenuPlacement::Above),
});
editor
});
let editor_entity = editor.downgrade();
editor.update(cx, |editor, _| {
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
workspace.clone(),
context_store.downgrade(),
Some(thread_store.clone()),
editor_entity,
))));
});
let context_strip = cx.new(|cx| { let context_strip = cx.new(|cx| {
ContextStrip::new( ContextStrip::new(
context_store.clone(), context_store.clone(),
@ -1041,7 +1058,7 @@ impl MessageEditor {
let load_task = cx.spawn(async move |this, cx| { let load_task = cx.spawn(async move |this, cx| {
let Ok(load_task) = this.update(cx, |this, cx| { let Ok(load_task) = this.update(cx, |this, cx| {
let new_context = this.context_store.read_with(cx, |context_store, cx| { let new_context = this.context_store.read_with(cx, |context_store, cx| {
context_store.new_context_for_thread(this.thread.read(cx)) context_store.new_context_for_thread(this.thread.read(cx), None)
}); });
load_context(new_context, &this.project, &this.prompt_store, cx) load_context(new_context, &this.project, &this.prompt_store, cx)
}) else { }) else {

View file

@ -879,6 +879,7 @@ impl Thread {
id: MessageId, id: MessageId,
new_role: Role, new_role: Role,
new_segments: Vec<MessageSegment>, new_segments: Vec<MessageSegment>,
loaded_context: Option<LoadedContext>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> bool { ) -> bool {
let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else { let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
@ -886,6 +887,9 @@ impl Thread {
}; };
message.role = new_role; message.role = new_role;
message.segments = new_segments; message.segments = new_segments;
if let Some(context) = loaded_context {
message.loaded_context = context;
}
self.touch_updated_at(); self.touch_updated_at();
cx.emit(ThreadEvent::MessageEdited(id)); cx.emit(ThreadEvent::MessageEdited(id));
true true
@ -2546,6 +2550,7 @@ fn main() {{
"file1.rs": "fn function1() {}\n", "file1.rs": "fn function1() {}\n",
"file2.rs": "fn function2() {}\n", "file2.rs": "fn function2() {}\n",
"file3.rs": "fn function3() {}\n", "file3.rs": "fn function3() {}\n",
"file4.rs": "fn function4() {}\n",
}), }),
) )
.await; .await;
@ -2558,7 +2563,7 @@ fn main() {{
.await .await
.unwrap(); .unwrap();
let new_contexts = context_store.update(cx, |store, cx| { let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx)) store.new_context_for_thread(thread.read(cx), None)
}); });
assert_eq!(new_contexts.len(), 1); assert_eq!(new_contexts.len(), 1);
let loaded_context = cx let loaded_context = cx
@ -2573,7 +2578,7 @@ fn main() {{
.await .await
.unwrap(); .unwrap();
let new_contexts = context_store.update(cx, |store, cx| { let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx)) store.new_context_for_thread(thread.read(cx), None)
}); });
assert_eq!(new_contexts.len(), 1); assert_eq!(new_contexts.len(), 1);
let loaded_context = cx let loaded_context = cx
@ -2589,7 +2594,7 @@ fn main() {{
.await .await
.unwrap(); .unwrap();
let new_contexts = context_store.update(cx, |store, cx| { let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx)) store.new_context_for_thread(thread.read(cx), None)
}); });
assert_eq!(new_contexts.len(), 1); assert_eq!(new_contexts.len(), 1);
let loaded_context = cx let loaded_context = cx
@ -2640,6 +2645,55 @@ fn main() {{
assert!(!request.messages[3].string_contents().contains("file1.rs")); assert!(!request.messages[3].string_contents().contains("file1.rs"));
assert!(!request.messages[3].string_contents().contains("file2.rs")); assert!(!request.messages[3].string_contents().contains("file2.rs"));
assert!(request.messages[3].string_contents().contains("file3.rs")); assert!(request.messages[3].string_contents().contains("file3.rs"));
add_file_to_context(&project, &context_store, "test/file4.rs", cx)
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 3);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(loaded_context.text.contains("file3.rs"));
assert!(loaded_context.text.contains("file4.rs"));
let new_contexts = context_store.update(cx, |store, cx| {
// Remove file4.rs
store.remove_context(&loaded_context.contexts[2].handle(), cx);
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 2);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(loaded_context.text.contains("file3.rs"));
assert!(!loaded_context.text.contains("file4.rs"));
let new_contexts = context_store.update(cx, |store, cx| {
// Remove file3.rs
store.remove_context(&loaded_context.contexts[1].handle(), cx);
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(!loaded_context.text.contains("file3.rs"));
assert!(!loaded_context.text.contains("file4.rs"));
} }
#[gpui::test] #[gpui::test]