agent: Allow adding/removing context when editing existing messages (#29698)

Release Notes:

- agent: Support adding/removing context when editing existing message

---------

Co-authored-by: Cole Miller <m@cole-miller.net>
Co-authored-by: Cole Miller <cole@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2025-05-01 03:39:34 +02:00 committed by GitHub
parent f046d70625
commit 1bf9e15f26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 418 additions and 189 deletions

View file

@ -1,5 +1,7 @@
use crate::context::{AgentContextHandle, RULES_ICON};
use crate::context_picker::MentionLink;
use crate::context_picker::{ContextPicker, MentionLink};
use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
ThreadFeedback,
@ -14,14 +16,16 @@ use anyhow::Context as _;
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
use assistant_tool::ToolUseStatus;
use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste};
use editor::scroll::Autoscroll;
use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer};
use gpui::{
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem,
DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla, ListAlignment,
ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription,
Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardEntry,
ClipboardItem, DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla,
ListAlignment, ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful,
StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
UnderlineStyle, WeakEntity, WindowHandle, linear_color_stop, linear_gradient, list, percentage,
pulsating_between,
};
use language::{Buffer, Language, LanguageRegistry};
use language_model::{
@ -41,7 +45,8 @@ use std::time::Duration;
use text::ToPoint;
use theme::ThemeSettings;
use ui::{
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
Disclosure, IconButton, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize,
Tooltip, prelude::*,
};
use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock;
@ -49,6 +54,7 @@ use workspace::Workspace;
use zed_actions::assistant::OpenRulesLibrary;
pub struct ActiveThread {
context_store: Entity<ContextStore>,
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
thread: Entity<Thread>,
@ -61,7 +67,7 @@ pub struct ActiveThread {
hide_scrollbar_task: Option<Task<()>>,
rendered_messages_by_id: HashMap<MessageId, RenderedMessage>,
rendered_tool_uses: HashMap<LanguageModelToolUseId, RenderedToolUse>,
editing_message: Option<(MessageId, EditMessageState)>,
editing_message: Option<(MessageId, EditingMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
expanded_thinking_segments: HashMap<(MessageId, usize), bool>,
expanded_code_blocks: HashMap<(MessageId, usize), bool>,
@ -72,6 +78,7 @@ pub struct ActiveThread {
_subscriptions: Vec<Subscription>,
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
open_feedback_editors: HashMap<MessageId, Entity<Editor>>,
_load_edited_message_context_task: Option<Task<()>>,
}
struct RenderedMessage {
@ -725,10 +732,12 @@ fn open_markdown_link(
}
}
struct EditMessageState {
struct EditingMessageState {
editor: Entity<Editor>,
context_strip: Entity<ContextStrip>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
last_estimated_token_count: Option<usize>,
_subscription: Subscription,
_subscriptions: [Subscription; 2],
_update_token_count_task: Option<Task<()>>,
}
@ -736,6 +745,7 @@ impl ActiveThread {
pub fn new(
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
context_store: Entity<ContextStore>,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
@ -758,6 +768,7 @@ impl ActiveThread {
let mut this = Self {
language_registry,
thread_store,
context_store,
thread: thread.clone(),
workspace,
save_thread_task: None,
@ -779,6 +790,7 @@ impl ActiveThread {
_subscriptions: subscriptions,
notification_subscriptions: HashMap::default(),
open_feedback_editors: HashMap::default(),
_load_edited_message_context_task: None,
};
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
@ -1237,33 +1249,49 @@ impl ActiveThread {
return;
};
let buffer = cx.new(|cx| {
MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
});
let editor = cx.new(|cx| {
let mut editor = Editor::new(
editor::EditorMode::AutoHeight { max_lines: 8 },
buffer,
None,
window,
cx,
);
let editor = crate::message_editor::create_editor(
self.workspace.clone(),
self.context_store.downgrade(),
self.thread_store.downgrade(),
window,
cx,
);
editor.update(cx, |editor, cx| {
editor.set_text(message_text.clone(), window, cx);
editor.focus_handle(cx).focus(window);
editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
editor
});
let subscription = cx.subscribe(&editor, |this, _, event, cx| match event {
let buffer_edited_subscription = cx.subscribe(&editor, |this, _, event, cx| match event {
EditorEvent::BufferEdited => {
this.update_editing_message_token_count(true, cx);
}
_ => {}
});
let context_picker_menu_handle = PopoverMenuHandle::default();
let context_strip = cx.new(|cx| {
ContextStrip::new(
self.context_store.clone(),
self.workspace.clone(),
Some(self.thread_store.downgrade()),
context_picker_menu_handle.clone(),
SuggestContextKind::File,
window,
cx,
)
});
let context_strip_subscription =
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event);
self.editing_message = Some((
message_id,
EditMessageState {
EditingMessageState {
editor: editor.clone(),
context_strip,
context_picker_menu_handle,
last_estimated_token_count: None,
_subscription: subscription,
_subscriptions: [buffer_edited_subscription, context_strip_subscription],
_update_token_count_task: None,
},
));
@ -1271,6 +1299,26 @@ impl ActiveThread {
cx.notify();
}
fn handle_context_strip_event(
&mut self,
_context_strip: &Entity<ContextStrip>,
event: &ContextStripEvent,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some((_, state)) = self.editing_message.as_ref() {
match event {
ContextStripEvent::PickerDismissed
| ContextStripEvent::BlurredEmpty
| ContextStripEvent::BlurredDown => {
let editor_focus_handle = state.editor.focus_handle(cx);
window.focus(&editor_focus_handle);
}
ContextStripEvent::BlurredUp => {}
}
}
}
fn update_editing_message_token_count(&mut self, debounce: bool, cx: &mut Context<Self>) {
let Some((message_id, state)) = self.editing_message.as_mut() else {
return;
@ -1357,6 +1405,68 @@ impl ActiveThread {
}));
}
fn toggle_context_picker(
&mut self,
_: &crate::ToggleContextPicker,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some((_, state)) = self.editing_message.as_mut() {
let handle = state.context_picker_menu_handle.clone();
window.defer(cx, move |window, cx| {
handle.toggle(window, cx);
});
}
}
fn remove_all_context(
&mut self,
_: &crate::RemoveAllContext,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.context_store.update(cx, |store, _cx| store.clear());
cx.notify();
}
fn move_up(&mut self, _: &MoveUp, window: &mut Window, cx: &mut Context<Self>) {
if let Some((_, state)) = self.editing_message.as_mut() {
if state.context_picker_menu_handle.is_deployed() {
cx.propagate();
} else {
state.context_strip.focus_handle(cx).focus(window);
}
}
}
fn paste(&mut self, _: &Paste, _window: &mut Window, cx: &mut Context<Self>) {
let images = cx
.read_from_clipboard()
.map(|item| {
item.into_entries()
.filter_map(|entry| {
if let ClipboardEntry::Image(image) = entry {
Some(image)
} else {
None
}
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
if images.is_empty() {
return;
}
cx.stop_propagation();
self.context_store.update(cx, |store, cx| {
for image in images {
store.add_image_instance(Arc::new(image), cx);
}
});
}
fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
self.editing_message.take();
cx.notify();
@ -1371,21 +1481,11 @@ impl ActiveThread {
let Some((message_id, state)) = self.editing_message.take() else {
return;
};
let edited_text = state.editor.read(cx).text(cx);
let thread_model = self.thread.update(cx, |thread, cx| {
thread.edit_message(
message_id,
Role::User,
vec![MessageSegment::Text(edited_text)],
cx,
);
for message_id in self.messages_after(message_id) {
thread.delete_message(*message_id, cx);
}
thread.get_or_init_configured_model(cx)
});
let Some(model) = thread_model else {
let Some(model) = self
.thread
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
else {
return;
};
@ -1394,11 +1494,45 @@ impl ActiveThread {
return;
}
self.thread.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(model.model, Some(window.window_handle()), cx);
});
cx.notify();
let edited_text = state.editor.read(cx).text(cx);
let new_context = self
.context_store
.read(cx)
.new_context_for_thread(self.thread.read(cx), Some(message_id));
let project = self.thread.read(cx).project().clone();
let prompt_store = self.thread_store.read(cx).prompt_store().clone();
let load_context_task =
crate::context::load_context(new_context, &project, &prompt_store, cx);
self._load_edited_message_context_task =
Some(cx.spawn_in(window, async move |this, cx| {
let context = load_context_task.await;
let _ = this
.update_in(cx, |this, window, cx| {
this.thread.update(cx, |thread, cx| {
thread.edit_message(
message_id,
Role::User,
vec![MessageSegment::Text(edited_text)],
Some(context.loaded_context),
cx,
);
for message_id in this.messages_after(message_id) {
thread.delete_message(*message_id, cx);
}
});
this.thread.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(model.model, Some(window.window_handle()), cx);
});
this._load_edited_message_context_task = None;
cx.notify();
})
.log_err();
}));
}
fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
@ -1519,6 +1653,53 @@ impl ActiveThread {
}
}
fn render_edit_message_editor(
&self,
state: &EditingMessageState,
window: &mut Window,
cx: &Context<Self>,
) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small.rems(cx);
let line_height = font_size.to_pixels(window.rem_size()) * 1.75;
let colors = cx.theme().colors();
let text_style = TextStyle {
color: colors.text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
v_flex()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::toggle_context_picker))
.on_action(cx.listener(Self::remove_all_context))
.on_action(cx.listener(Self::move_up))
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.capture_action(cx.listener(Self::paste))
.min_h_6()
.flex_grow()
.w_full()
.gap_2()
.child(EditorElement::new(
&state.editor,
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.child(state.context_strip.clone())
}
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
let message_id = self.messages[ix];
let Some(message) = self.thread.read(cx).message(message_id) else {
@ -1551,11 +1732,11 @@ impl ActiveThread {
let generating_label = (is_generating && is_last_message)
.then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
let edit_message_editor = self
let editing_message_state = self
.editing_message
.as_ref()
.filter(|(id, _)| *id == message_id)
.map(|(_, state)| state.editor.clone());
.map(|(_, state)| state);
let colors = cx.theme().colors();
let editor_bg_color = colors.editor_background;
@ -1690,77 +1871,43 @@ impl ActiveThread {
let has_content = !message_is_empty || !added_context.is_empty();
let message_content = has_content.then(|| {
v_flex()
.w_full()
.gap_1()
.when(!message_is_empty, |parent| {
parent.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small.rems(cx);
let line_height = font_size.to_pixels(window.rem_size()) * 1.75;
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.flex_grow()
.w_full()
.child(EditorElement::new(
&edit_message_editor,
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.into_any()
} else {
div()
.min_h_6()
.child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
workspace.clone(),
window,
cx,
))
.into_any()
},
)
})
.when(!added_context.is_empty(), |parent| {
parent.child(h_flex().flex_wrap().gap_1().children(
added_context.into_iter().map(|added_context| {
let context = added_context.handle.clone();
ContextPill::added(added_context, false, false, None).on_click(Rc::new(
cx.listener({
let workspace = workspace.clone();
move |_, _, window, cx| {
if let Some(workspace) = workspace.upgrade() {
open_context(&context, workspace, window, cx);
cx.notify();
if let Some(state) = editing_message_state.as_ref() {
self.render_edit_message_editor(state, window, cx)
.into_any_element()
} else {
v_flex()
.w_full()
.gap_1()
.when(!message_is_empty, |parent| {
parent.child(div().min_h_6().child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
workspace.clone(),
window,
cx,
)))
})
.when(!added_context.is_empty(), |parent| {
parent.child(h_flex().flex_wrap().gap_1().children(
added_context.into_iter().map(|added_context| {
let context = added_context.handle.clone();
ContextPill::added(added_context, false, false, None).on_click(
Rc::new(cx.listener({
let workspace = workspace.clone();
move |_, _, window, cx| {
if let Some(workspace) = workspace.upgrade() {
open_context(&context, workspace, window, cx);
cx.notify();
}
}
}
}),
))
}),
))
})
})),
)
}),
))
})
.into_any_element()
}
});
let styled_message = match message.role {
@ -1785,8 +1932,8 @@ impl ActiveThread {
.p_2p5()
.gap_1()
.children(message_content)
.when_some(edit_message_editor.clone(), |this, edit_editor| {
let edit_editor_clone = edit_editor.clone();
.when_some(editing_message_state, |this, state| {
let focus_handle = state.editor.focus_handle(cx).clone();
this.w_full().justify_between().child(
h_flex()
.gap_0p5()
@ -1797,16 +1944,17 @@ impl ActiveThread {
)
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Error)
.tooltip(move |window, cx| {
let focus_handle =
edit_editor_clone.focus_handle(cx);
Tooltip::for_action_in(
"Cancel Edit",
&menu::Cancel,
&focus_handle,
window,
cx,
)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Cancel Edit",
&menu::Cancel,
&focus_handle,
window,
cx,
)
}
})
.on_click(cx.listener(Self::handle_cancel_click)),
)
@ -1815,18 +1963,20 @@ impl ActiveThread {
"confirm-edit-message",
IconName::Check,
)
.disabled(edit_editor.read(cx).is_empty(cx))
.disabled(state.editor.read(cx).is_empty(cx))
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Success)
.tooltip(move |window, cx| {
let focus_handle = edit_editor.focus_handle(cx);
Tooltip::for_action_in(
"Regenerate",
&menu::Confirm,
&focus_handle,
window,
cx,
)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Regenerate",
&menu::Confirm,
&focus_handle,
window,
cx,
)
}
})
.on_click(
cx.listener(Self::handle_regenerate_click),
@ -1835,7 +1985,7 @@ impl ActiveThread {
)
}),
)
.when(edit_message_editor.is_none(), |this| {
.when(editing_message_state.is_none(), |this| {
this.tooltip(Tooltip::text("Click To Edit"))
})
.on_click(cx.listener({

View file

@ -424,6 +424,7 @@ impl AssistantPanel {
ActiveThread::new(
thread.clone(),
thread_store.clone(),
message_editor_context_store.clone(),
language_registry.clone(),
workspace.clone(),
window,
@ -626,7 +627,7 @@ impl AssistantPanel {
let thread_view = ActiveView::thread(thread.clone(), window, cx);
self.set_active_view(thread_view, window, cx);
let message_editor_context_store = cx.new(|_cx| {
let context_store = cx.new(|_cx| {
crate::context_store::ContextStore::new(
self.project.downgrade(),
Some(self.thread_store.downgrade()),
@ -639,7 +640,7 @@ impl AssistantPanel {
.update(cx, |this, cx| this.open_thread(&other_thread_id, cx));
cx.spawn({
let context_store = message_editor_context_store.clone();
let context_store = context_store.clone();
async move |_panel, cx| {
let other_thread = other_thread_task.await?;
@ -664,6 +665,7 @@ impl AssistantPanel {
ActiveThread::new(
thread.clone(),
self.thread_store.clone(),
context_store.clone(),
self.language_registry.clone(),
self.workspace.clone(),
window,
@ -682,7 +684,7 @@ impl AssistantPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
message_editor_context_store,
context_store,
self.prompt_store.clone(),
self.thread_store.downgrade(),
thread,
@ -843,7 +845,7 @@ impl AssistantPanel {
) {
let thread_view = ActiveView::thread(thread.clone(), window, cx);
self.set_active_view(thread_view, window, cx);
let message_editor_context_store = cx.new(|_cx| {
let context_store = cx.new(|_cx| {
crate::context_store::ContextStore::new(
self.project.downgrade(),
Some(self.thread_store.downgrade()),
@ -860,6 +862,7 @@ impl AssistantPanel {
ActiveThread::new(
thread.clone(),
self.thread_store.clone(),
context_store.clone(),
self.language_registry.clone(),
self.workspace.clone(),
window,
@ -878,7 +881,7 @@ impl AssistantPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
message_editor_context_store,
context_store,
self.prompt_store.clone(),
self.thread_store.downgrade(),
thread,

View file

@ -21,7 +21,7 @@ use crate::context::{
SymbolContextHandle, ThreadContextHandle,
};
use crate::context_strip::SuggestedContext;
use crate::thread::{Thread, ThreadId};
use crate::thread::{MessageId, Thread, ThreadId};
pub struct ContextStore {
project: WeakEntity<Project>,
@ -54,9 +54,14 @@ impl ContextStore {
self.context_thread_ids.clear();
}
pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> {
pub fn new_context_for_thread(
&self,
thread: &Thread,
exclude_messages_from_id: Option<MessageId>,
) -> Vec<AgentContextHandle> {
let existing_context = thread
.messages()
.take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
.flat_map(|message| {
message
.loaded_context

View file

@ -69,6 +69,56 @@ pub struct MessageEditor {
const MAX_EDITOR_LINES: usize = 8;
pub(crate) fn create_editor(
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Entity<Editor> {
let language = Language::new(
language::LanguageConfig {
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
..Default::default()
},
None,
);
let editor = cx.new(|cx| {
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(
editor::EditorMode::AutoHeight {
max_lines: MAX_EDITOR_LINES,
},
buffer,
None,
window,
cx,
);
editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx);
editor.set_show_indent_guides(false, cx);
editor.set_soft_wrap();
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: Some(ContextMenuPlacement::Above),
});
editor
});
let editor_entity = editor.downgrade();
editor.update(cx, |editor, _| {
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
workspace,
context_store,
Some(thread_store),
editor_entity,
))));
});
editor
}
impl MessageEditor {
pub fn new(
fs: Arc<dyn Fs>,
@ -83,47 +133,14 @@ impl MessageEditor {
let context_picker_menu_handle = PopoverMenuHandle::default();
let model_selector_menu_handle = PopoverMenuHandle::default();
let language = Language::new(
language::LanguageConfig {
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
..Default::default()
},
None,
let editor = create_editor(
workspace.clone(),
context_store.downgrade(),
thread_store.clone(),
window,
cx,
);
let editor = cx.new(|cx| {
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(
editor::EditorMode::AutoHeight {
max_lines: MAX_EDITOR_LINES,
},
buffer,
None,
window,
cx,
);
editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx);
editor.set_show_indent_guides(false, cx);
editor.set_soft_wrap();
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: Some(ContextMenuPlacement::Above),
});
editor
});
let editor_entity = editor.downgrade();
editor.update(cx, |editor, _| {
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
workspace.clone(),
context_store.downgrade(),
Some(thread_store.clone()),
editor_entity,
))));
});
let context_strip = cx.new(|cx| {
ContextStrip::new(
context_store.clone(),
@ -1041,7 +1058,7 @@ impl MessageEditor {
let load_task = cx.spawn(async move |this, cx| {
let Ok(load_task) = this.update(cx, |this, cx| {
let new_context = this.context_store.read_with(cx, |context_store, cx| {
context_store.new_context_for_thread(this.thread.read(cx))
context_store.new_context_for_thread(this.thread.read(cx), None)
});
load_context(new_context, &this.project, &this.prompt_store, cx)
}) else {

View file

@ -879,6 +879,7 @@ impl Thread {
id: MessageId,
new_role: Role,
new_segments: Vec<MessageSegment>,
loaded_context: Option<LoadedContext>,
cx: &mut Context<Self>,
) -> bool {
let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
@ -886,6 +887,9 @@ impl Thread {
};
message.role = new_role;
message.segments = new_segments;
if let Some(context) = loaded_context {
message.loaded_context = context;
}
self.touch_updated_at();
cx.emit(ThreadEvent::MessageEdited(id));
true
@ -2546,6 +2550,7 @@ fn main() {{
"file1.rs": "fn function1() {}\n",
"file2.rs": "fn function2() {}\n",
"file3.rs": "fn function3() {}\n",
"file4.rs": "fn function4() {}\n",
}),
)
.await;
@ -2558,7 +2563,7 @@ fn main() {{
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx))
store.new_context_for_thread(thread.read(cx), None)
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
@ -2573,7 +2578,7 @@ fn main() {{
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx))
store.new_context_for_thread(thread.read(cx), None)
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
@ -2589,7 +2594,7 @@ fn main() {{
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx))
store.new_context_for_thread(thread.read(cx), None)
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
@ -2640,6 +2645,55 @@ fn main() {{
assert!(!request.messages[3].string_contents().contains("file1.rs"));
assert!(!request.messages[3].string_contents().contains("file2.rs"));
assert!(request.messages[3].string_contents().contains("file3.rs"));
add_file_to_context(&project, &context_store, "test/file4.rs", cx)
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 3);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(loaded_context.text.contains("file3.rs"));
assert!(loaded_context.text.contains("file4.rs"));
let new_contexts = context_store.update(cx, |store, cx| {
// Remove file4.rs
store.remove_context(&loaded_context.contexts[2].handle(), cx);
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 2);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(loaded_context.text.contains("file3.rs"));
assert!(!loaded_context.text.contains("file4.rs"));
let new_contexts = context_store.update(cx, |store, cx| {
// Remove file3.rs
store.remove_context(&loaded_context.contexts[1].handle(), cx);
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(!loaded_context.text.contains("file3.rs"));
assert!(!loaded_context.text.contains("file4.rs"));
}
#[gpui::test]