diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index ee99a96e84..6790649d83 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1,5 +1,7 @@ use crate::context::{AgentContextHandle, RULES_ICON}; -use crate::context_picker::MentionLink; +use crate::context_picker::{ContextPicker, MentionLink}; +use crate::context_store::ContextStore; +use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::thread::{ LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent, ThreadFeedback, @@ -14,14 +16,16 @@ use anyhow::Context as _; use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting}; use assistant_tool::ToolUseStatus; use collections::{HashMap, HashSet}; +use editor::actions::{MoveUp, Paste}; use editor::scroll::Autoscroll; use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer}; use gpui::{ - AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem, - DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla, ListAlignment, - ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, - Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle, - linear_color_stop, linear_gradient, list, percentage, pulsating_between, + AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardEntry, + ClipboardItem, DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla, + ListAlignment, ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, + StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation, + UnderlineStyle, WeakEntity, WindowHandle, linear_color_stop, linear_gradient, list, percentage, + pulsating_between, }; use language::{Buffer, Language, LanguageRegistry}; use language_model::{ @@ -41,7 +45,8 @@ use std::time::Duration; use text::ToPoint; use theme::ThemeSettings; use ui::{ - Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*, + Disclosure, IconButton, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize, + Tooltip, prelude::*, }; use util::ResultExt as _; use util::markdown::MarkdownCodeBlock; @@ -49,6 +54,7 @@ use workspace::Workspace; use zed_actions::assistant::OpenRulesLibrary; pub struct ActiveThread { + context_store: Entity, language_registry: Arc, thread_store: Entity, thread: Entity, @@ -61,7 +67,7 @@ pub struct ActiveThread { hide_scrollbar_task: Option>, rendered_messages_by_id: HashMap, rendered_tool_uses: HashMap, - editing_message: Option<(MessageId, EditMessageState)>, + editing_message: Option<(MessageId, EditingMessageState)>, expanded_tool_uses: HashMap, expanded_thinking_segments: HashMap<(MessageId, usize), bool>, expanded_code_blocks: HashMap<(MessageId, usize), bool>, @@ -72,6 +78,7 @@ pub struct ActiveThread { _subscriptions: Vec, notification_subscriptions: HashMap, Vec>, open_feedback_editors: HashMap>, + _load_edited_message_context_task: Option>, } struct RenderedMessage { @@ -725,10 +732,12 @@ fn open_markdown_link( } } -struct EditMessageState { +struct EditingMessageState { editor: Entity, + context_strip: Entity, + context_picker_menu_handle: PopoverMenuHandle, last_estimated_token_count: Option, - _subscription: Subscription, + _subscriptions: [Subscription; 2], _update_token_count_task: Option>, } @@ -736,6 +745,7 @@ impl ActiveThread { pub fn new( thread: Entity, thread_store: Entity, + context_store: Entity, language_registry: Arc, workspace: WeakEntity, window: &mut Window, @@ -758,6 +768,7 @@ impl ActiveThread { let mut this = Self { language_registry, thread_store, + context_store, thread: thread.clone(), workspace, save_thread_task: None, @@ -779,6 +790,7 @@ impl ActiveThread { _subscriptions: subscriptions, notification_subscriptions: HashMap::default(), open_feedback_editors: HashMap::default(), + _load_edited_message_context_task: None, }; for message in thread.read(cx).messages().cloned().collect::>() { @@ -1237,33 +1249,49 @@ impl ActiveThread { return; }; - let buffer = cx.new(|cx| { - MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx) - }); - let editor = cx.new(|cx| { - let mut editor = Editor::new( - editor::EditorMode::AutoHeight { max_lines: 8 }, - buffer, - None, - window, - cx, - ); + let editor = crate::message_editor::create_editor( + self.workspace.clone(), + self.context_store.downgrade(), + self.thread_store.downgrade(), + window, + cx, + ); + editor.update(cx, |editor, cx| { + editor.set_text(message_text.clone(), window, cx); editor.focus_handle(cx).focus(window); editor.move_to_end(&editor::actions::MoveToEnd, window, cx); - editor }); - let subscription = cx.subscribe(&editor, |this, _, event, cx| match event { + let buffer_edited_subscription = cx.subscribe(&editor, |this, _, event, cx| match event { EditorEvent::BufferEdited => { this.update_editing_message_token_count(true, cx); } _ => {} }); + + let context_picker_menu_handle = PopoverMenuHandle::default(); + let context_strip = cx.new(|cx| { + ContextStrip::new( + self.context_store.clone(), + self.workspace.clone(), + Some(self.thread_store.downgrade()), + context_picker_menu_handle.clone(), + SuggestContextKind::File, + window, + cx, + ) + }); + + let context_strip_subscription = + cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event); + self.editing_message = Some(( message_id, - EditMessageState { + EditingMessageState { editor: editor.clone(), + context_strip, + context_picker_menu_handle, last_estimated_token_count: None, - _subscription: subscription, + _subscriptions: [buffer_edited_subscription, context_strip_subscription], _update_token_count_task: None, }, )); @@ -1271,6 +1299,26 @@ impl ActiveThread { cx.notify(); } + fn handle_context_strip_event( + &mut self, + _context_strip: &Entity, + event: &ContextStripEvent, + window: &mut Window, + cx: &mut Context, + ) { + if let Some((_, state)) = self.editing_message.as_ref() { + match event { + ContextStripEvent::PickerDismissed + | ContextStripEvent::BlurredEmpty + | ContextStripEvent::BlurredDown => { + let editor_focus_handle = state.editor.focus_handle(cx); + window.focus(&editor_focus_handle); + } + ContextStripEvent::BlurredUp => {} + } + } + } + fn update_editing_message_token_count(&mut self, debounce: bool, cx: &mut Context) { let Some((message_id, state)) = self.editing_message.as_mut() else { return; @@ -1357,6 +1405,68 @@ impl ActiveThread { })); } + fn toggle_context_picker( + &mut self, + _: &crate::ToggleContextPicker, + window: &mut Window, + cx: &mut Context, + ) { + if let Some((_, state)) = self.editing_message.as_mut() { + let handle = state.context_picker_menu_handle.clone(); + window.defer(cx, move |window, cx| { + handle.toggle(window, cx); + }); + } + } + + fn remove_all_context( + &mut self, + _: &crate::RemoveAllContext, + _window: &mut Window, + cx: &mut Context, + ) { + self.context_store.update(cx, |store, _cx| store.clear()); + cx.notify(); + } + + fn move_up(&mut self, _: &MoveUp, window: &mut Window, cx: &mut Context) { + if let Some((_, state)) = self.editing_message.as_mut() { + if state.context_picker_menu_handle.is_deployed() { + cx.propagate(); + } else { + state.context_strip.focus_handle(cx).focus(window); + } + } + } + + fn paste(&mut self, _: &Paste, _window: &mut Window, cx: &mut Context) { + let images = cx + .read_from_clipboard() + .map(|item| { + item.into_entries() + .filter_map(|entry| { + if let ClipboardEntry::Image(image) = entry { + Some(image) + } else { + None + } + }) + .collect::>() + }) + .unwrap_or_default(); + + if images.is_empty() { + return; + } + cx.stop_propagation(); + + self.context_store.update(cx, |store, cx| { + for image in images { + store.add_image_instance(Arc::new(image), cx); + } + }); + } + fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { self.editing_message.take(); cx.notify(); @@ -1371,21 +1481,11 @@ impl ActiveThread { let Some((message_id, state)) = self.editing_message.take() else { return; }; - let edited_text = state.editor.read(cx).text(cx); - let thread_model = self.thread.update(cx, |thread, cx| { - thread.edit_message( - message_id, - Role::User, - vec![MessageSegment::Text(edited_text)], - cx, - ); - for message_id in self.messages_after(message_id) { - thread.delete_message(*message_id, cx); - } - thread.get_or_init_configured_model(cx) - }); - let Some(model) = thread_model else { + let Some(model) = self + .thread + .update(cx, |thread, cx| thread.get_or_init_configured_model(cx)) + else { return; }; @@ -1394,11 +1494,45 @@ impl ActiveThread { return; } - self.thread.update(cx, |thread, cx| { - thread.advance_prompt_id(); - thread.send_to_model(model.model, Some(window.window_handle()), cx); - }); - cx.notify(); + let edited_text = state.editor.read(cx).text(cx); + + let new_context = self + .context_store + .read(cx) + .new_context_for_thread(self.thread.read(cx), Some(message_id)); + + let project = self.thread.read(cx).project().clone(); + let prompt_store = self.thread_store.read(cx).prompt_store().clone(); + + let load_context_task = + crate::context::load_context(new_context, &project, &prompt_store, cx); + self._load_edited_message_context_task = + Some(cx.spawn_in(window, async move |this, cx| { + let context = load_context_task.await; + let _ = this + .update_in(cx, |this, window, cx| { + this.thread.update(cx, |thread, cx| { + thread.edit_message( + message_id, + Role::User, + vec![MessageSegment::Text(edited_text)], + Some(context.loaded_context), + cx, + ); + for message_id in this.messages_after(message_id) { + thread.delete_message(*message_id, cx); + } + }); + + this.thread.update(cx, |thread, cx| { + thread.advance_prompt_id(); + thread.send_to_model(model.model, Some(window.window_handle()), cx); + }); + this._load_edited_message_context_task = None; + cx.notify(); + }) + .log_err(); + })); } fn messages_after(&self, message_id: MessageId) -> &[MessageId] { @@ -1519,6 +1653,53 @@ impl ActiveThread { } } + fn render_edit_message_editor( + &self, + state: &EditingMessageState, + window: &mut Window, + cx: &Context, + ) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let font_size = TextSize::Small.rems(cx); + let line_height = font_size.to_pixels(window.rem_size()) * 1.75; + + let colors = cx.theme().colors(); + + let text_style = TextStyle { + color: colors.text, + font_family: settings.buffer_font.family.clone(), + font_fallbacks: settings.buffer_font.fallbacks.clone(), + font_features: settings.buffer_font.features.clone(), + font_size: font_size.into(), + line_height: line_height.into(), + ..Default::default() + }; + + v_flex() + .key_context("EditMessageEditor") + .on_action(cx.listener(Self::toggle_context_picker)) + .on_action(cx.listener(Self::remove_all_context)) + .on_action(cx.listener(Self::move_up)) + .on_action(cx.listener(Self::cancel_editing_message)) + .on_action(cx.listener(Self::confirm_editing_message)) + .capture_action(cx.listener(Self::paste)) + .min_h_6() + .flex_grow() + .w_full() + .gap_2() + .child(EditorElement::new( + &state.editor, + EditorStyle { + background: colors.editor_background, + local_player: cx.theme().players().local(), + text: text_style, + syntax: cx.theme().syntax().clone(), + ..Default::default() + }, + )) + .child(state.context_strip.clone()) + } + fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context) -> AnyElement { let message_id = self.messages[ix]; let Some(message) = self.thread.read(cx).message(message_id) else { @@ -1551,11 +1732,11 @@ impl ActiveThread { let generating_label = (is_generating && is_last_message) .then(|| AnimatedLabel::new("Generating").size(LabelSize::Small)); - let edit_message_editor = self + let editing_message_state = self .editing_message .as_ref() .filter(|(id, _)| *id == message_id) - .map(|(_, state)| state.editor.clone()); + .map(|(_, state)| state); let colors = cx.theme().colors(); let editor_bg_color = colors.editor_background; @@ -1690,77 +1871,43 @@ impl ActiveThread { let has_content = !message_is_empty || !added_context.is_empty(); let message_content = has_content.then(|| { - v_flex() - .w_full() - .gap_1() - .when(!message_is_empty, |parent| { - parent.child( - if let Some(edit_message_editor) = edit_message_editor.clone() { - let settings = ThemeSettings::get_global(cx); - let font_size = TextSize::Small.rems(cx); - let line_height = font_size.to_pixels(window.rem_size()) * 1.75; - - let text_style = TextStyle { - color: cx.theme().colors().text, - font_family: settings.buffer_font.family.clone(), - font_fallbacks: settings.buffer_font.fallbacks.clone(), - font_features: settings.buffer_font.features.clone(), - font_size: font_size.into(), - line_height: line_height.into(), - ..Default::default() - }; - - div() - .key_context("EditMessageEditor") - .on_action(cx.listener(Self::cancel_editing_message)) - .on_action(cx.listener(Self::confirm_editing_message)) - .min_h_6() - .flex_grow() - .w_full() - .child(EditorElement::new( - &edit_message_editor, - EditorStyle { - background: colors.editor_background, - local_player: cx.theme().players().local(), - text: text_style, - syntax: cx.theme().syntax().clone(), - ..Default::default() - }, - )) - .into_any() - } else { - div() - .min_h_6() - .child(self.render_message_content( - message_id, - rendered_message, - has_tool_uses, - workspace.clone(), - window, - cx, - )) - .into_any() - }, - ) - }) - .when(!added_context.is_empty(), |parent| { - parent.child(h_flex().flex_wrap().gap_1().children( - added_context.into_iter().map(|added_context| { - let context = added_context.handle.clone(); - ContextPill::added(added_context, false, false, None).on_click(Rc::new( - cx.listener({ - let workspace = workspace.clone(); - move |_, _, window, cx| { - if let Some(workspace) = workspace.upgrade() { - open_context(&context, workspace, window, cx); - cx.notify(); + if let Some(state) = editing_message_state.as_ref() { + self.render_edit_message_editor(state, window, cx) + .into_any_element() + } else { + v_flex() + .w_full() + .gap_1() + .when(!message_is_empty, |parent| { + parent.child(div().min_h_6().child(self.render_message_content( + message_id, + rendered_message, + has_tool_uses, + workspace.clone(), + window, + cx, + ))) + }) + .when(!added_context.is_empty(), |parent| { + parent.child(h_flex().flex_wrap().gap_1().children( + added_context.into_iter().map(|added_context| { + let context = added_context.handle.clone(); + ContextPill::added(added_context, false, false, None).on_click( + Rc::new(cx.listener({ + let workspace = workspace.clone(); + move |_, _, window, cx| { + if let Some(workspace) = workspace.upgrade() { + open_context(&context, workspace, window, cx); + cx.notify(); + } } - } - }), - )) - }), - )) - }) + })), + ) + }), + )) + }) + .into_any_element() + } }); let styled_message = match message.role { @@ -1785,8 +1932,8 @@ impl ActiveThread { .p_2p5() .gap_1() .children(message_content) - .when_some(edit_message_editor.clone(), |this, edit_editor| { - let edit_editor_clone = edit_editor.clone(); + .when_some(editing_message_state, |this, state| { + let focus_handle = state.editor.focus_handle(cx).clone(); this.w_full().justify_between().child( h_flex() .gap_0p5() @@ -1797,16 +1944,17 @@ impl ActiveThread { ) .shape(ui::IconButtonShape::Square) .icon_color(Color::Error) - .tooltip(move |window, cx| { - let focus_handle = - edit_editor_clone.focus_handle(cx); - Tooltip::for_action_in( - "Cancel Edit", - &menu::Cancel, - &focus_handle, - window, - cx, - ) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "Cancel Edit", + &menu::Cancel, + &focus_handle, + window, + cx, + ) + } }) .on_click(cx.listener(Self::handle_cancel_click)), ) @@ -1815,18 +1963,20 @@ impl ActiveThread { "confirm-edit-message", IconName::Check, ) - .disabled(edit_editor.read(cx).is_empty(cx)) + .disabled(state.editor.read(cx).is_empty(cx)) .shape(ui::IconButtonShape::Square) .icon_color(Color::Success) - .tooltip(move |window, cx| { - let focus_handle = edit_editor.focus_handle(cx); - Tooltip::for_action_in( - "Regenerate", - &menu::Confirm, - &focus_handle, - window, - cx, - ) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "Regenerate", + &menu::Confirm, + &focus_handle, + window, + cx, + ) + } }) .on_click( cx.listener(Self::handle_regenerate_click), @@ -1835,7 +1985,7 @@ impl ActiveThread { ) }), ) - .when(edit_message_editor.is_none(), |this| { + .when(editing_message_state.is_none(), |this| { this.tooltip(Tooltip::text("Click To Edit")) }) .on_click(cx.listener({ diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index 18e5b859a6..2f788d18f4 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -424,6 +424,7 @@ impl AssistantPanel { ActiveThread::new( thread.clone(), thread_store.clone(), + message_editor_context_store.clone(), language_registry.clone(), workspace.clone(), window, @@ -626,7 +627,7 @@ impl AssistantPanel { let thread_view = ActiveView::thread(thread.clone(), window, cx); self.set_active_view(thread_view, window, cx); - let message_editor_context_store = cx.new(|_cx| { + let context_store = cx.new(|_cx| { crate::context_store::ContextStore::new( self.project.downgrade(), Some(self.thread_store.downgrade()), @@ -639,7 +640,7 @@ impl AssistantPanel { .update(cx, |this, cx| this.open_thread(&other_thread_id, cx)); cx.spawn({ - let context_store = message_editor_context_store.clone(); + let context_store = context_store.clone(); async move |_panel, cx| { let other_thread = other_thread_task.await?; @@ -664,6 +665,7 @@ impl AssistantPanel { ActiveThread::new( thread.clone(), self.thread_store.clone(), + context_store.clone(), self.language_registry.clone(), self.workspace.clone(), window, @@ -682,7 +684,7 @@ impl AssistantPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - message_editor_context_store, + context_store, self.prompt_store.clone(), self.thread_store.downgrade(), thread, @@ -843,7 +845,7 @@ impl AssistantPanel { ) { let thread_view = ActiveView::thread(thread.clone(), window, cx); self.set_active_view(thread_view, window, cx); - let message_editor_context_store = cx.new(|_cx| { + let context_store = cx.new(|_cx| { crate::context_store::ContextStore::new( self.project.downgrade(), Some(self.thread_store.downgrade()), @@ -860,6 +862,7 @@ impl AssistantPanel { ActiveThread::new( thread.clone(), self.thread_store.clone(), + context_store.clone(), self.language_registry.clone(), self.workspace.clone(), window, @@ -878,7 +881,7 @@ impl AssistantPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - message_editor_context_store, + context_store, self.prompt_store.clone(), self.thread_store.downgrade(), thread, diff --git a/crates/agent/src/context_store.rs b/crates/agent/src/context_store.rs index 50ab1fc239..f572235d08 100644 --- a/crates/agent/src/context_store.rs +++ b/crates/agent/src/context_store.rs @@ -21,7 +21,7 @@ use crate::context::{ SymbolContextHandle, ThreadContextHandle, }; use crate::context_strip::SuggestedContext; -use crate::thread::{Thread, ThreadId}; +use crate::thread::{MessageId, Thread, ThreadId}; pub struct ContextStore { project: WeakEntity, @@ -54,9 +54,14 @@ impl ContextStore { self.context_thread_ids.clear(); } - pub fn new_context_for_thread(&self, thread: &Thread) -> Vec { + pub fn new_context_for_thread( + &self, + thread: &Thread, + exclude_messages_from_id: Option, + ) -> Vec { let existing_context = thread .messages() + .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id)) .flat_map(|message| { message .loaded_context diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index 3116bec5a3..3df2ecefc7 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -69,6 +69,56 @@ pub struct MessageEditor { const MAX_EDITOR_LINES: usize = 8; +pub(crate) fn create_editor( + workspace: WeakEntity, + context_store: WeakEntity, + thread_store: WeakEntity, + window: &mut Window, + cx: &mut App, +) -> Entity { + let language = Language::new( + language::LanguageConfig { + completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']), + ..Default::default() + }, + None, + ); + + let editor = cx.new(|cx| { + let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx)); + let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + let mut editor = Editor::new( + editor::EditorMode::AutoHeight { + max_lines: MAX_EDITOR_LINES, + }, + buffer, + None, + window, + cx, + ); + editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx); + editor.set_show_indent_guides(false, cx); + editor.set_soft_wrap(); + editor.set_context_menu_options(ContextMenuOptions { + min_entries_visible: 12, + max_entries_visible: 12, + placement: Some(ContextMenuPlacement::Above), + }); + editor + }); + + let editor_entity = editor.downgrade(); + editor.update(cx, |editor, _| { + editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new( + workspace, + context_store, + Some(thread_store), + editor_entity, + )))); + }); + editor +} + impl MessageEditor { pub fn new( fs: Arc, @@ -83,47 +133,14 @@ impl MessageEditor { let context_picker_menu_handle = PopoverMenuHandle::default(); let model_selector_menu_handle = PopoverMenuHandle::default(); - let language = Language::new( - language::LanguageConfig { - completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']), - ..Default::default() - }, - None, + let editor = create_editor( + workspace.clone(), + context_store.downgrade(), + thread_store.clone(), + window, + cx, ); - let editor = cx.new(|cx| { - let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx)); - let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); - let mut editor = Editor::new( - editor::EditorMode::AutoHeight { - max_lines: MAX_EDITOR_LINES, - }, - buffer, - None, - window, - cx, - ); - editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx); - editor.set_show_indent_guides(false, cx); - editor.set_soft_wrap(); - editor.set_context_menu_options(ContextMenuOptions { - min_entries_visible: 12, - max_entries_visible: 12, - placement: Some(ContextMenuPlacement::Above), - }); - editor - }); - - let editor_entity = editor.downgrade(); - editor.update(cx, |editor, _| { - editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new( - workspace.clone(), - context_store.downgrade(), - Some(thread_store.clone()), - editor_entity, - )))); - }); - let context_strip = cx.new(|cx| { ContextStrip::new( context_store.clone(), @@ -1041,7 +1058,7 @@ impl MessageEditor { let load_task = cx.spawn(async move |this, cx| { let Ok(load_task) = this.update(cx, |this, cx| { let new_context = this.context_store.read_with(cx, |context_store, cx| { - context_store.new_context_for_thread(this.thread.read(cx)) + context_store.new_context_for_thread(this.thread.read(cx), None) }); load_context(new_context, &this.project, &this.prompt_store, cx) }) else { diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 9d4e32ee3b..8e96ff5513 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -879,6 +879,7 @@ impl Thread { id: MessageId, new_role: Role, new_segments: Vec, + loaded_context: Option, cx: &mut Context, ) -> bool { let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else { @@ -886,6 +887,9 @@ impl Thread { }; message.role = new_role; message.segments = new_segments; + if let Some(context) = loaded_context { + message.loaded_context = context; + } self.touch_updated_at(); cx.emit(ThreadEvent::MessageEdited(id)); true @@ -2546,6 +2550,7 @@ fn main() {{ "file1.rs": "fn function1() {}\n", "file2.rs": "fn function2() {}\n", "file3.rs": "fn function3() {}\n", + "file4.rs": "fn function4() {}\n", }), ) .await; @@ -2558,7 +2563,7 @@ fn main() {{ .await .unwrap(); let new_contexts = context_store.update(cx, |store, cx| { - store.new_context_for_thread(thread.read(cx)) + store.new_context_for_thread(thread.read(cx), None) }); assert_eq!(new_contexts.len(), 1); let loaded_context = cx @@ -2573,7 +2578,7 @@ fn main() {{ .await .unwrap(); let new_contexts = context_store.update(cx, |store, cx| { - store.new_context_for_thread(thread.read(cx)) + store.new_context_for_thread(thread.read(cx), None) }); assert_eq!(new_contexts.len(), 1); let loaded_context = cx @@ -2589,7 +2594,7 @@ fn main() {{ .await .unwrap(); let new_contexts = context_store.update(cx, |store, cx| { - store.new_context_for_thread(thread.read(cx)) + store.new_context_for_thread(thread.read(cx), None) }); assert_eq!(new_contexts.len(), 1); let loaded_context = cx @@ -2640,6 +2645,55 @@ fn main() {{ assert!(!request.messages[3].string_contents().contains("file1.rs")); assert!(!request.messages[3].string_contents().contains("file2.rs")); assert!(request.messages[3].string_contents().contains("file3.rs")); + + add_file_to_context(&project, &context_store, "test/file4.rs", cx) + .await + .unwrap(); + let new_contexts = context_store.update(cx, |store, cx| { + store.new_context_for_thread(thread.read(cx), Some(message2_id)) + }); + assert_eq!(new_contexts.len(), 3); + let loaded_context = cx + .update(|cx| load_context(new_contexts, &project, &None, cx)) + .await + .loaded_context; + + assert!(!loaded_context.text.contains("file1.rs")); + assert!(loaded_context.text.contains("file2.rs")); + assert!(loaded_context.text.contains("file3.rs")); + assert!(loaded_context.text.contains("file4.rs")); + + let new_contexts = context_store.update(cx, |store, cx| { + // Remove file4.rs + store.remove_context(&loaded_context.contexts[2].handle(), cx); + store.new_context_for_thread(thread.read(cx), Some(message2_id)) + }); + assert_eq!(new_contexts.len(), 2); + let loaded_context = cx + .update(|cx| load_context(new_contexts, &project, &None, cx)) + .await + .loaded_context; + + assert!(!loaded_context.text.contains("file1.rs")); + assert!(loaded_context.text.contains("file2.rs")); + assert!(loaded_context.text.contains("file3.rs")); + assert!(!loaded_context.text.contains("file4.rs")); + + let new_contexts = context_store.update(cx, |store, cx| { + // Remove file3.rs + store.remove_context(&loaded_context.contexts[1].handle(), cx); + store.new_context_for_thread(thread.read(cx), Some(message2_id)) + }); + assert_eq!(new_contexts.len(), 1); + let loaded_context = cx + .update(|cx| load_context(new_contexts, &project, &None, cx)) + .await + .loaded_context; + + assert!(!loaded_context.text.contains("file1.rs")); + assert!(loaded_context.text.contains("file2.rs")); + assert!(!loaded_context.text.contains("file3.rs")); + assert!(!loaded_context.text.contains("file4.rs")); } #[gpui::test]