From 43ee604179ccda222eed29a173ac19e0514e8679 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Thu, 14 Aug 2025 15:30:18 -0300 Subject: [PATCH] acp: Clean up entry views on rewind (#36197) We were leaking diffs and terminals on rewind, we'll now clean them up. This PR also introduces a refactor of how we mantain the entry view state to use a `Vec` that's kept in sync with the thread entries. Release Notes: - N/A --- crates/acp_thread/Cargo.toml | 3 +- crates/acp_thread/src/acp_thread.rs | 36 +- crates/acp_thread/src/connection.rs | 156 +++++- crates/agent2/src/agent.rs | 8 +- crates/agent2/src/tests/mod.rs | 2 +- crates/agent_servers/src/acp/v0.rs | 2 +- crates/agent_servers/src/acp/v1.rs | 2 +- crates/agent_servers/src/claude.rs | 2 +- crates/agent_servers/src/e2e_tests.rs | 4 +- crates/agent_ui/Cargo.toml | 1 + crates/agent_ui/src/acp.rs | 1 + crates/agent_ui/src/acp/entry_view_state.rs | 351 +++++++++++++ crates/agent_ui/src/acp/thread_view.rs | 536 +++++++++----------- 13 files changed, 758 insertions(+), 346 deletions(-) create mode 100644 crates/agent_ui/src/acp/entry_view_state.rs diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 2d0fe2d264..2b9a6513c8 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -13,7 +13,7 @@ path = "src/acp_thread.rs" doctest = false [features] -test-support = ["gpui/test-support", "project/test-support"] +test-support = ["gpui/test-support", "project/test-support", "dep:parking_lot"] [dependencies] action_log.workspace = true @@ -29,6 +29,7 @@ gpui.workspace = true itertools.workspace = true language.workspace = true markdown.workspace = true +parking_lot = { workspace = true, optional = true } project.workspace = true prompt_store.workspace = true serde.workspace = true diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index da4d82712a..4bdc42ea2e 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1575,11 +1575,7 @@ mod tests { let project = Project::test(fs, [], cx).await; let connection = Rc::new(FakeAgentConnection::new()); let thread = cx - .spawn(async move |mut cx| { - connection - .new_thread(project, Path::new(path!("/test")), &mut cx) - .await - }) + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -1699,11 +1695,7 @@ mod tests { )); let thread = cx - .spawn(async move |mut cx| { - connection - .new_thread(project, Path::new(path!("/test")), &mut cx) - .await - }) + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -1786,7 +1778,7 @@ mod tests { .unwrap(); let thread = cx - .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx)) + .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx)) .await .unwrap(); @@ -1849,11 +1841,7 @@ mod tests { })); let thread = cx - .spawn(async move |mut cx| { - connection - .new_thread(project, Path::new(path!("/test")), &mut cx) - .await - }) + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -1961,10 +1949,11 @@ mod tests { } })); - let thread = connection - .new_thread(project, Path::new(path!("/test")), &mut cx.to_async()) + let thread = cx + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx))) .await .unwrap(); @@ -2021,8 +2010,8 @@ mod tests { .boxed_local() } })); - let thread = connection - .new_thread(project, Path::new(path!("/test")), &mut cx.to_async()) + let thread = cx + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) .await .unwrap(); @@ -2227,7 +2216,7 @@ mod tests { self: Rc, project: Entity, _cwd: &Path, - cx: &mut gpui::AsyncApp, + cx: &mut gpui::App, ) -> Task>> { let session_id = acp::SessionId( rand::thread_rng() @@ -2237,9 +2226,8 @@ mod tests { .collect::() .into(), ); - let thread = cx - .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) - .unwrap(); + let thread = + cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)); self.sessions.lock().insert(session_id, thread.downgrade()); Task::ready(Ok(thread)) } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index c3167eb2d4..0f531acbde 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -2,7 +2,7 @@ use crate::AcpThread; use agent_client_protocol::{self as acp}; use anyhow::Result; use collections::IndexMap; -use gpui::{AsyncApp, Entity, SharedString, Task}; +use gpui::{Entity, SharedString, Task}; use project::Project; use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use ui::{App, IconName}; @@ -22,7 +22,7 @@ pub trait AgentConnection { self: Rc, project: Entity, cwd: &Path, - cx: &mut AsyncApp, + cx: &mut App, ) -> Task>>; fn auth_methods(&self) -> &[acp::AuthMethod]; @@ -160,3 +160,155 @@ impl AgentModelList { } } } + +#[cfg(feature = "test-support")] +mod test_support { + use std::sync::Arc; + + use collections::HashMap; + use futures::future::try_join_all; + use gpui::{AppContext as _, WeakEntity}; + use parking_lot::Mutex; + + use super::*; + + #[derive(Clone, Default)] + pub struct StubAgentConnection { + sessions: Arc>>>, + permission_requests: HashMap>, + next_prompt_updates: Arc>>, + } + + impl StubAgentConnection { + pub fn new() -> Self { + Self { + next_prompt_updates: Default::default(), + permission_requests: HashMap::default(), + sessions: Arc::default(), + } + } + + pub fn set_next_prompt_updates(&self, updates: Vec) { + *self.next_prompt_updates.lock() = updates; + } + + pub fn with_permission_requests( + mut self, + permission_requests: HashMap>, + ) -> Self { + self.permission_requests = permission_requests; + self + } + + pub fn send_update( + &self, + session_id: acp::SessionId, + update: acp::SessionUpdate, + cx: &mut App, + ) { + self.sessions + .lock() + .get(&session_id) + .unwrap() + .update(cx, |thread, cx| { + thread.handle_session_update(update.clone(), cx).unwrap(); + }) + .unwrap(); + } + } + + impl AgentConnection for StubAgentConnection { + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::App, + ) -> Task>> { + let session_id = acp::SessionId(self.sessions.lock().len().to_string().into()); + let thread = + cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)); + self.sessions.lock().insert(session_id, thread.downgrade()); + Task::ready(Ok(thread)) + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + unimplemented!() + } + + fn prompt( + &self, + _id: Option, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let sessions = self.sessions.lock(); + let thread = sessions.get(¶ms.session_id).unwrap(); + let mut tasks = vec![]; + for update in self.next_prompt_updates.lock().drain(..) { + let thread = thread.clone(); + let update = update.clone(); + let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update + && let Some(options) = self.permission_requests.get(&tool_call.id) + { + Some((tool_call.clone(), options.clone())) + } else { + None + }; + let task = cx.spawn(async move |cx| { + if let Some((tool_call, options)) = permission_request { + let permission = thread.update(cx, |thread, cx| { + thread.request_tool_call_authorization( + tool_call.clone(), + options.clone(), + cx, + ) + })?; + permission.await?; + } + thread.update(cx, |thread, cx| { + thread.handle_session_update(update.clone(), cx).unwrap(); + })?; + anyhow::Ok(()) + }); + tasks.push(task); + } + cx.spawn(async move |_| { + try_join_all(tasks).await?; + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + }) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + + fn session_editor( + &self, + _session_id: &agent_client_protocol::SessionId, + _cx: &mut App, + ) -> Option> { + Some(Rc::new(StubAgentSessionEditor)) + } + } + + struct StubAgentSessionEditor; + + impl AgentSessionEditor for StubAgentSessionEditor { + fn truncate(&self, _: UserMessageId, _: &mut App) -> Task> { + Task::ready(Ok(())) + } + } +} + +#[cfg(feature = "test-support")] +pub use test_support::*; diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 6ebcece2b5..9ac3c2d0e5 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -522,7 +522,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { self: Rc, project: Entity, cwd: &Path, - cx: &mut AsyncApp, + cx: &mut App, ) -> Task>> { let agent = self.0.clone(); log::info!("Creating new thread for project at: {:?}", cwd); @@ -940,11 +940,7 @@ mod tests { // Create a thread/session let acp_thread = cx .update(|cx| { - Rc::new(connection.clone()).new_thread( - project.clone(), - Path::new("/a"), - &mut cx.to_async(), - ) + Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx) }) .await .unwrap(); diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 637af73d1a..1df664c029 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -841,7 +841,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { // Create a thread using new_thread let connection_rc = Rc::new(connection.clone()); let acp_thread = cx - .update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async())) + .update(|cx| connection_rc.new_thread(project, cwd, cx)) .await .expect("new_thread should succeed"); diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index 327613de67..15f8635cde 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -423,7 +423,7 @@ impl AgentConnection for AcpConnection { self: Rc, project: Entity, _cwd: &Path, - cx: &mut AsyncApp, + cx: &mut App, ) -> Task>> { let task = self.connection.request_any( acp_old::InitializeParams { diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index de397fddf0..d93e3d023e 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -111,7 +111,7 @@ impl AgentConnection for AcpConnection { self: Rc, project: Entity, cwd: &Path, - cx: &mut AsyncApp, + cx: &mut App, ) -> Task>> { let conn = self.connection.clone(); let sessions = self.sessions.clone(); diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index c394ec4a9c..dbcda00e48 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -74,7 +74,7 @@ impl AgentConnection for ClaudeAgentConnection { self: Rc, project: Entity, cwd: &Path, - cx: &mut AsyncApp, + cx: &mut App, ) -> Task>> { let cwd = cwd.to_owned(); cx.spawn(async move |cx| { diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index ec6ca29b9d..5af7010f26 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -422,8 +422,8 @@ pub async fn new_test_thread( .await .unwrap(); - let thread = connection - .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async()) + let thread = cx + .update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx)) .await .unwrap(); diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index b6a5710aa4..13fd9d13c5 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -103,6 +103,7 @@ workspace.workspace = true zed_actions.workspace = true [dev-dependencies] +acp_thread = { workspace = true, features = ["test-support"] } agent = { workspace = true, features = ["test-support"] } assistant_context = { workspace = true, features = ["test-support"] } assistant_tools.workspace = true diff --git a/crates/agent_ui/src/acp.rs b/crates/agent_ui/src/acp.rs index 630aa730a6..831d296eeb 100644 --- a/crates/agent_ui/src/acp.rs +++ b/crates/agent_ui/src/acp.rs @@ -1,4 +1,5 @@ mod completion_provider; +mod entry_view_state; mod message_editor; mod model_selector; mod model_selector_popover; diff --git a/crates/agent_ui/src/acp/entry_view_state.rs b/crates/agent_ui/src/acp/entry_view_state.rs new file mode 100644 index 0000000000..2f5f855e90 --- /dev/null +++ b/crates/agent_ui/src/acp/entry_view_state.rs @@ -0,0 +1,351 @@ +use std::{collections::HashMap, ops::Range}; + +use acp_thread::AcpThread; +use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer}; +use gpui::{ + AnyEntity, App, AppContext as _, Entity, EntityId, TextStyleRefinement, WeakEntity, Window, +}; +use language::language_settings::SoftWrap; +use settings::Settings as _; +use terminal_view::TerminalView; +use theme::ThemeSettings; +use ui::TextSize; +use workspace::Workspace; + +#[derive(Default)] +pub struct EntryViewState { + entries: Vec, +} + +impl EntryViewState { + pub fn entry(&self, index: usize) -> Option<&Entry> { + self.entries.get(index) + } + + pub fn sync_entry( + &mut self, + workspace: WeakEntity, + thread: Entity, + index: usize, + window: &mut Window, + cx: &mut App, + ) { + debug_assert!(index <= self.entries.len()); + let entry = if let Some(entry) = self.entries.get_mut(index) { + entry + } else { + self.entries.push(Entry::default()); + self.entries.last_mut().unwrap() + }; + + entry.sync_diff_multibuffers(&thread, index, window, cx); + entry.sync_terminals(&workspace, &thread, index, window, cx); + } + + pub fn remove(&mut self, range: Range) { + self.entries.drain(range); + } + + pub fn settings_changed(&mut self, cx: &mut App) { + for entry in self.entries.iter() { + for view in entry.views.values() { + if let Ok(diff_editor) = view.clone().downcast::() { + diff_editor.update(cx, |diff_editor, cx| { + diff_editor + .set_text_style_refinement(diff_editor_text_style_refinement(cx)); + cx.notify(); + }) + } + } + } + } +} + +pub struct Entry { + views: HashMap, +} + +impl Entry { + pub fn editor_for_diff(&self, diff: &Entity) -> Option> { + self.views + .get(&diff.entity_id()) + .cloned() + .map(|entity| entity.downcast::().unwrap()) + } + + pub fn terminal( + &self, + terminal: &Entity, + ) -> Option> { + self.views + .get(&terminal.entity_id()) + .cloned() + .map(|entity| entity.downcast::().unwrap()) + } + + fn sync_diff_multibuffers( + &mut self, + thread: &Entity, + index: usize, + window: &mut Window, + cx: &mut App, + ) { + let Some(entry) = thread.read(cx).entries().get(index) else { + return; + }; + + let multibuffers = entry + .diffs() + .map(|diff| diff.read(cx).multibuffer().clone()); + + let multibuffers = multibuffers.collect::>(); + + for multibuffer in multibuffers { + if self.views.contains_key(&multibuffer.entity_id()) { + return; + } + + let editor = cx.new(|cx| { + let mut editor = Editor::new( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: true, + }, + multibuffer.clone(), + None, + window, + cx, + ); + editor.set_show_gutter(false, cx); + editor.disable_inline_diagnostics(); + editor.disable_expand_excerpt_buttons(cx); + editor.set_show_vertical_scrollbar(false, cx); + editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); + editor.set_soft_wrap_mode(SoftWrap::None, cx); + editor.scroll_manager.set_forbid_vertical_scroll(true); + editor.set_show_indent_guides(false, cx); + editor.set_read_only(true); + editor.set_show_breakpoints(false, cx); + editor.set_show_code_actions(false, cx); + editor.set_show_git_diff_gutter(false, cx); + editor.set_expand_all_diff_hunks(cx); + editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); + editor + }); + + let entity_id = multibuffer.entity_id(); + self.views.insert(entity_id, editor.into_any()); + } + } + + fn sync_terminals( + &mut self, + workspace: &WeakEntity, + thread: &Entity, + index: usize, + window: &mut Window, + cx: &mut App, + ) { + let Some(entry) = thread.read(cx).entries().get(index) else { + return; + }; + + let terminals = entry + .terminals() + .map(|terminal| terminal.clone()) + .collect::>(); + + for terminal in terminals { + if self.views.contains_key(&terminal.entity_id()) { + return; + } + + let Some(strong_workspace) = workspace.upgrade() else { + return; + }; + + let terminal_view = cx.new(|cx| { + let mut view = TerminalView::new( + terminal.read(cx).inner().clone(), + workspace.clone(), + None, + strong_workspace.read(cx).project().downgrade(), + window, + cx, + ); + view.set_embedded_mode(Some(1000), cx); + view + }); + + let entity_id = terminal.entity_id(); + self.views.insert(entity_id, terminal_view.into_any()); + } + } + + #[cfg(test)] + pub fn len(&self) -> usize { + self.views.len() + } +} + +fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement { + TextStyleRefinement { + font_size: Some( + TextSize::Small + .rems(cx) + .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) + .into(), + ), + ..Default::default() + } +} + +impl Default for Entry { + fn default() -> Self { + Self { + // Avoid allocating in the heap by default + views: HashMap::with_capacity(0), + } + } +} + +#[cfg(test)] +mod tests { + use std::{path::Path, rc::Rc}; + + use acp_thread::{AgentConnection, StubAgentConnection}; + use agent_client_protocol as acp; + use agent_settings::AgentSettings; + use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind}; + use editor::{EditorSettings, RowInfo}; + use fs::FakeFs; + use gpui::{SemanticVersion, TestAppContext}; + use multi_buffer::MultiBufferRow; + use pretty_assertions::assert_matches; + use project::Project; + use serde_json::json; + use settings::{Settings as _, SettingsStore}; + use theme::ThemeSettings; + use util::path; + use workspace::Workspace; + + use crate::acp::entry_view_state::EntryViewState; + + #[gpui::test] + async fn test_diff_sync(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + "hello.txt": "hi world" + }), + ) + .await; + let project = Project::test(fs, [Path::new(path!("/project"))], cx).await; + + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let tool_call = acp::ToolCall { + id: acp::ToolCallId("tool".into()), + title: "Tool call".into(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::InProgress, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/project/hello.txt".into(), + old_text: Some("hi world".into()), + new_text: "hello world".into(), + }, + }], + locations: vec![], + raw_input: None, + raw_output: None, + }; + let connection = Rc::new(StubAgentConnection::new()); + let thread = cx + .update(|_, cx| { + connection + .clone() + .new_thread(project, Path::new(path!("/project")), cx) + }) + .await + .unwrap(); + let session_id = thread.update(cx, |thread, _| thread.session_id().clone()); + + cx.update(|_, cx| { + connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx) + }); + + let mut view_state = EntryViewState::default(); + cx.update(|window, cx| { + view_state.sync_entry(workspace.downgrade(), thread.clone(), 0, window, cx); + }); + + let multibuffer = thread.read_with(cx, |thread, cx| { + thread + .entries() + .get(0) + .unwrap() + .diffs() + .next() + .unwrap() + .read(cx) + .multibuffer() + .clone() + }); + + cx.run_until_parked(); + + let entry = view_state.entry(0).unwrap(); + let diff_editor = entry.editor_for_diff(&multibuffer).unwrap(); + assert_eq!( + diff_editor.read_with(cx, |editor, cx| editor.text(cx)), + "hi world\nhello world" + ); + let row_infos = diff_editor.read_with(cx, |editor, cx| { + let multibuffer = editor.buffer().read(cx); + multibuffer + .snapshot(cx) + .row_infos(MultiBufferRow(0)) + .collect::>() + }); + assert_matches!( + row_infos.as_slice(), + [ + RowInfo { + multibuffer_row: Some(MultiBufferRow(0)), + diff_status: Some(DiffHunkStatus { + kind: DiffHunkStatusKind::Deleted, + .. + }), + .. + }, + RowInfo { + multibuffer_row: Some(MultiBufferRow(1)), + diff_status: Some(DiffHunkStatus { + kind: DiffHunkStatusKind::Added, + .. + }), + .. + } + ] + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + AgentSettings::register(cx); + workspace::init_settings(cx); + ThemeSettings::register(cx); + release_channel::init(SemanticVersion::default(), cx); + EditorSettings::register(cx); + }); + } +} diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 2a72cc6f48..0e90b93f4d 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -12,24 +12,22 @@ use audio::{Audio, Sound}; use buffer_diff::BufferDiff; use collections::{HashMap, HashSet}; use editor::scroll::Autoscroll; -use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer, PathKey, SelectionEffects}; +use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects}; use file_icons::FileIcons; use gpui::{ Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, EdgesRefinement, Empty, Entity, - EntityId, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, - PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, - TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, - linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between, + FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay, + SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, + Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, + linear_gradient, list, percentage, point, prelude::*, pulsating_between, }; use language::Buffer; -use language::language_settings::SoftWrap; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use project::Project; use prompt_store::PromptId; use rope::Point; use settings::{Settings as _, SettingsStore}; use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration}; -use terminal_view::TerminalView; use text::Anchor; use theme::ThemeSettings; use ui::{ @@ -41,6 +39,7 @@ use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::{Chat, ToggleModelSelector}; use zed_actions::assistant::OpenRulesLibrary; +use super::entry_view_state::EntryViewState; use crate::acp::AcpModelSelectorPopover; use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; use crate::agent_diff::AgentDiff; @@ -61,8 +60,7 @@ pub struct AcpThreadView { thread_store: Entity, text_thread_store: Entity, thread_state: ThreadState, - diff_editors: HashMap>, - terminal_views: HashMap>, + entry_view_state: EntryViewState, message_editor: Entity, model_selector: Option>, notifications: Vec>, @@ -149,8 +147,7 @@ impl AcpThreadView { model_selector: None, notifications: Vec::new(), notification_subscriptions: HashMap::default(), - diff_editors: Default::default(), - terminal_views: Default::default(), + entry_view_state: EntryViewState::default(), list_state: list_state.clone(), scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()), last_error: None, @@ -209,11 +206,18 @@ impl AcpThreadView { // }) // .ok(); - let result = match connection - .clone() - .new_thread(project.clone(), &root_dir, cx) - .await - { + let Some(result) = cx + .update(|_, cx| { + connection + .clone() + .new_thread(project.clone(), &root_dir, cx) + }) + .log_err() + else { + return; + }; + + let result = match result.await { Err(e) => { let mut cx = cx.clone(); if e.is::() { @@ -480,16 +484,29 @@ impl AcpThreadView { ) { match event { AcpThreadEvent::NewEntry => { - let index = thread.read(cx).entries().len() - 1; - self.sync_thread_entry_view(index, window, cx); + let len = thread.read(cx).entries().len(); + let index = len - 1; + self.entry_view_state.sync_entry( + self.workspace.clone(), + thread.clone(), + index, + window, + cx, + ); self.list_state.splice(index..index, 1); } AcpThreadEvent::EntryUpdated(index) => { - self.sync_thread_entry_view(*index, window, cx); + self.entry_view_state.sync_entry( + self.workspace.clone(), + thread.clone(), + *index, + window, + cx, + ); self.list_state.splice(*index..index + 1, 1); } AcpThreadEvent::EntriesRemoved(range) => { - // TODO: Clean up unused diff editors and terminal views + self.entry_view_state.remove(range.clone()); self.list_state.splice(range.clone(), 0); } AcpThreadEvent::ToolAuthorizationRequired => { @@ -523,128 +540,6 @@ impl AcpThreadView { cx.notify(); } - fn sync_thread_entry_view( - &mut self, - entry_ix: usize, - window: &mut Window, - cx: &mut Context, - ) { - self.sync_diff_multibuffers(entry_ix, window, cx); - self.sync_terminals(entry_ix, window, cx); - } - - fn sync_diff_multibuffers( - &mut self, - entry_ix: usize, - window: &mut Window, - cx: &mut Context, - ) { - let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else { - return; - }; - - let multibuffers = multibuffers.collect::>(); - - for multibuffer in multibuffers { - if self.diff_editors.contains_key(&multibuffer.entity_id()) { - return; - } - - let editor = cx.new(|cx| { - let mut editor = Editor::new( - EditorMode::Full { - scale_ui_elements_with_buffer_font_size: false, - show_active_line_background: false, - sized_by_content: true, - }, - multibuffer.clone(), - None, - window, - cx, - ); - editor.set_show_gutter(false, cx); - editor.disable_inline_diagnostics(); - editor.disable_expand_excerpt_buttons(cx); - editor.set_show_vertical_scrollbar(false, cx); - editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); - editor.set_soft_wrap_mode(SoftWrap::None, cx); - editor.scroll_manager.set_forbid_vertical_scroll(true); - editor.set_show_indent_guides(false, cx); - editor.set_read_only(true); - editor.set_show_breakpoints(false, cx); - editor.set_show_code_actions(false, cx); - editor.set_show_git_diff_gutter(false, cx); - editor.set_expand_all_diff_hunks(cx); - editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); - editor - }); - let entity_id = multibuffer.entity_id(); - cx.observe_release(&multibuffer, move |this, _, _| { - this.diff_editors.remove(&entity_id); - }) - .detach(); - - self.diff_editors.insert(entity_id, editor); - } - } - - fn entry_diff_multibuffers( - &self, - entry_ix: usize, - cx: &App, - ) -> Option>> { - let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - Some( - entry - .diffs() - .map(|diff| diff.read(cx).multibuffer().clone()), - ) - } - - fn sync_terminals(&mut self, entry_ix: usize, window: &mut Window, cx: &mut Context) { - let Some(terminals) = self.entry_terminals(entry_ix, cx) else { - return; - }; - - let terminals = terminals.collect::>(); - - for terminal in terminals { - if self.terminal_views.contains_key(&terminal.entity_id()) { - return; - } - - let terminal_view = cx.new(|cx| { - let mut view = TerminalView::new( - terminal.read(cx).inner().clone(), - self.workspace.clone(), - None, - self.project.downgrade(), - window, - cx, - ); - view.set_embedded_mode(Some(1000), cx); - view - }); - - let entity_id = terminal.entity_id(); - cx.observe_release(&terminal, move |this, _, _| { - this.terminal_views.remove(&entity_id); - }) - .detach(); - - self.terminal_views.insert(entity_id, terminal_view); - } - } - - fn entry_terminals( - &self, - entry_ix: usize, - cx: &App, - ) -> Option>> { - let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - Some(entry.terminals().map(|terminal| terminal.clone())) - } - fn authenticate( &mut self, method: acp::AuthMethodId, @@ -712,7 +607,7 @@ impl AcpThreadView { fn render_entry( &self, - index: usize, + entry_ix: usize, total_entries: usize, entry: &AgentThreadEntry, window: &mut Window, @@ -720,7 +615,7 @@ impl AcpThreadView { ) -> AnyElement { let primary = match &entry { AgentThreadEntry::UserMessage(message) => div() - .id(("user_message", index)) + .id(("user_message", entry_ix)) .py_4() .px_2() .children(message.id.clone().and_then(|message_id| { @@ -749,7 +644,9 @@ impl AcpThreadView { .text_xs() .id("message") .on_click(cx.listener({ - move |this, _, window, cx| this.start_editing_message(index, window, cx) + move |this, _, window, cx| { + this.start_editing_message(entry_ix, window, cx) + } })) .children( if let Some(editing) = self.editing_message.as_ref() @@ -787,7 +684,7 @@ impl AcpThreadView { AssistantMessageChunk::Thought { block } => { block.markdown().map(|md| { self.render_thinking_block( - index, + entry_ix, chunk_ix, md.clone(), window, @@ -803,7 +700,7 @@ impl AcpThreadView { v_flex() .px_5() .py_1() - .when(index + 1 == total_entries, |this| this.pb_4()) + .when(entry_ix + 1 == total_entries, |this| this.pb_4()) .w_full() .text_ui(cx) .child(message_body) @@ -815,10 +712,12 @@ impl AcpThreadView { div().w_full().py_1p5().px_5().map(|this| { if has_terminals { this.children(tool_call.terminals().map(|terminal| { - self.render_terminal_tool_call(terminal, tool_call, window, cx) + self.render_terminal_tool_call( + entry_ix, terminal, tool_call, window, cx, + ) })) } else { - this.child(self.render_tool_call(index, tool_call, window, cx)) + this.child(self.render_tool_call(entry_ix, tool_call, window, cx)) } }) } @@ -830,7 +729,7 @@ impl AcpThreadView { }; let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating); - let primary = if index == total_entries - 1 && !is_generating { + let primary = if entry_ix == total_entries - 1 && !is_generating { v_flex() .w_full() .child(primary) @@ -841,10 +740,10 @@ impl AcpThreadView { }; if let Some(editing) = self.editing_message.as_ref() - && editing.index < index + && editing.index < entry_ix { let backdrop = div() - .id(("backdrop", index)) + .id(("backdrop", entry_ix)) .size_full() .absolute() .inset_0() @@ -1125,7 +1024,9 @@ impl AcpThreadView { .w_full() .children(tool_call.content.iter().map(|content| { div() - .child(self.render_tool_call_content(content, tool_call, window, cx)) + .child( + self.render_tool_call_content(entry_ix, content, tool_call, window, cx), + ) .into_any_element() })) .child(self.render_permission_buttons( @@ -1139,7 +1040,9 @@ impl AcpThreadView { .w_full() .children(tool_call.content.iter().map(|content| { div() - .child(self.render_tool_call_content(content, tool_call, window, cx)) + .child( + self.render_tool_call_content(entry_ix, content, tool_call, window, cx), + ) .into_any_element() })), ToolCallStatus::Rejected => v_flex().size_0(), @@ -1257,6 +1160,7 @@ impl AcpThreadView { fn render_tool_call_content( &self, + entry_ix: usize, content: &ToolCallContent, tool_call: &ToolCall, window: &Window, @@ -1273,10 +1177,10 @@ impl AcpThreadView { } } ToolCallContent::Diff(diff) => { - self.render_diff_editor(&diff.read(cx).multibuffer(), cx) + self.render_diff_editor(entry_ix, &diff.read(cx).multibuffer(), cx) } ToolCallContent::Terminal(terminal) => { - self.render_terminal_tool_call(terminal, tool_call, window, cx) + self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx) } } } @@ -1420,6 +1324,7 @@ impl AcpThreadView { fn render_diff_editor( &self, + entry_ix: usize, multibuffer: &Entity, cx: &Context, ) -> AnyElement { @@ -1428,7 +1333,9 @@ impl AcpThreadView { .border_t_1() .border_color(self.tool_card_border_color(cx)) .child( - if let Some(editor) = self.diff_editors.get(&multibuffer.entity_id()) { + if let Some(entry) = self.entry_view_state.entry(entry_ix) + && let Some(editor) = entry.editor_for_diff(&multibuffer) + { editor.clone().into_any_element() } else { Empty.into_any() @@ -1439,6 +1346,7 @@ impl AcpThreadView { fn render_terminal_tool_call( &self, + entry_ix: usize, terminal: &Entity, tool_call: &ToolCall, window: &Window, @@ -1627,8 +1535,11 @@ impl AcpThreadView { })), ); - let show_output = - self.terminal_expanded && self.terminal_views.contains_key(&terminal.entity_id()); + let terminal_view = self + .entry_view_state + .entry(entry_ix) + .and_then(|entry| entry.terminal(&terminal)); + let show_output = self.terminal_expanded && terminal_view.is_some(); v_flex() .mb_2() @@ -1661,8 +1572,6 @@ impl AcpThreadView { ), ) .when(show_output, |this| { - let terminal_view = self.terminal_views.get(&terminal.entity_id()).unwrap(); - this.child( div() .pt_2() @@ -1672,7 +1581,7 @@ impl AcpThreadView { .bg(cx.theme().colors().editor_background) .rounded_b_md() .text_ui_sm(cx) - .child(terminal_view.clone()), + .children(terminal_view.clone()), ) }) .into_any() @@ -3075,12 +2984,7 @@ impl AcpThreadView { } fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context) { - for diff_editor in self.diff_editors.values() { - diff_editor.update(cx, |diff_editor, cx| { - diff_editor.set_text_style_refinement(diff_editor_text_style_refinement(cx)); - cx.notify(); - }) - } + self.entry_view_state.settings_changed(cx); } pub(crate) fn insert_dragged_files( @@ -3379,18 +3283,6 @@ fn plan_label_markdown_style( } } -fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement { - TextStyleRefinement { - font_size: Some( - TextSize::Small - .rems(cx) - .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) - .into(), - ), - ..Default::default() - } -} - fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { let default_md_style = default_markdown_style(true, window, cx); @@ -3405,16 +3297,16 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { #[cfg(test)] pub(crate) mod tests { - use std::{path::Path, sync::Arc}; + use std::path::Path; + use acp_thread::StubAgentConnection; use agent::{TextThreadStore, ThreadStore}; use agent_client_protocol::SessionId; use editor::EditorSettings; use fs::FakeFs; - use futures::future::try_join_all; use gpui::{SemanticVersion, TestAppContext, VisualTestContext}; - use parking_lot::Mutex; - use rand::Rng; + use project::Project; + use serde_json::json; use settings::SettingsStore; use super::*; @@ -3497,8 +3389,8 @@ pub(crate) mod tests { raw_input: None, raw_output: None, }; - let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)]) - .with_permission_requests(HashMap::from_iter([( + let connection = + StubAgentConnection::new().with_permission_requests(HashMap::from_iter([( tool_call_id, vec![acp::PermissionOption { id: acp::PermissionOptionId("1".into()), @@ -3506,6 +3398,9 @@ pub(crate) mod tests { kind: acp::PermissionOptionKind::AllowOnce, }], )])); + + connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(tool_call)]); + let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await; let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); @@ -3605,115 +3500,6 @@ pub(crate) mod tests { } } - #[derive(Clone, Default)] - struct StubAgentConnection { - sessions: Arc>>>, - permission_requests: HashMap>, - updates: Vec, - } - - impl StubAgentConnection { - fn new(updates: Vec) -> Self { - Self { - updates, - permission_requests: HashMap::default(), - sessions: Arc::default(), - } - } - - fn with_permission_requests( - mut self, - permission_requests: HashMap>, - ) -> Self { - self.permission_requests = permission_requests; - self - } - } - - impl AgentConnection for StubAgentConnection { - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] - } - - fn new_thread( - self: Rc, - project: Entity, - _cwd: &Path, - cx: &mut gpui::AsyncApp, - ) -> Task>> { - let session_id = SessionId( - rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(7) - .map(char::from) - .collect::() - .into(), - ); - let thread = cx - .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)) - .unwrap(); - self.sessions.lock().insert(session_id, thread.downgrade()); - Task::ready(Ok(thread)) - } - - fn authenticate( - &self, - _method_id: acp::AuthMethodId, - _cx: &mut App, - ) -> Task> { - unimplemented!() - } - - fn prompt( - &self, - _id: Option, - params: acp::PromptRequest, - cx: &mut App, - ) -> Task> { - let sessions = self.sessions.lock(); - let thread = sessions.get(¶ms.session_id).unwrap(); - let mut tasks = vec![]; - for update in &self.updates { - let thread = thread.clone(); - let update = update.clone(); - let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update - && let Some(options) = self.permission_requests.get(&tool_call.id) - { - Some((tool_call.clone(), options.clone())) - } else { - None - }; - let task = cx.spawn(async move |cx| { - if let Some((tool_call, options)) = permission_request { - let permission = thread.update(cx, |thread, cx| { - thread.request_tool_call_authorization( - tool_call.clone(), - options.clone(), - cx, - ) - })?; - permission.await?; - } - thread.update(cx, |thread, cx| { - thread.handle_session_update(update.clone(), cx).unwrap(); - })?; - anyhow::Ok(()) - }); - tasks.push(task); - } - cx.spawn(async move |_| { - try_join_all(tasks).await?; - Ok(acp::PromptResponse { - stop_reason: acp::StopReason::EndTurn, - }) - }) - } - - fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { - unimplemented!() - } - } - #[derive(Clone)] struct SaboteurAgentConnection; @@ -3722,19 +3508,17 @@ pub(crate) mod tests { self: Rc, project: Entity, _cwd: &Path, - cx: &mut gpui::AsyncApp, + cx: &mut gpui::App, ) -> Task>> { - Task::ready(Ok(cx - .new(|cx| { - AcpThread::new( - "SaboteurAgentConnection", - self, - project, - SessionId("test".into()), - cx, - ) - }) - .unwrap())) + Task::ready(Ok(cx.new(|cx| { + AcpThread::new( + "SaboteurAgentConnection", + self, + project, + SessionId("test".into()), + cx, + ) + }))) } fn auth_methods(&self) -> &[acp::AuthMethod] { @@ -3776,4 +3560,142 @@ pub(crate) mod tests { EditorSettings::register(cx); }); } + + #[gpui::test] + async fn test_rewind_views(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + "test1.txt": "old content 1", + "test2.txt": "old content 2" + }), + ) + .await; + let project = Project::test(fs, [Path::new("/project")], cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let thread_store = + cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx))); + let text_thread_store = + cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx))); + + let connection = Rc::new(StubAgentConnection::new()); + let thread_view = cx.update(|window, cx| { + cx.new(|cx| { + AcpThreadView::new( + Rc::new(StubAgentServer::new(connection.as_ref().clone())), + workspace.downgrade(), + project.clone(), + thread_store.clone(), + text_thread_store.clone(), + window, + cx, + ) + }) + }); + + cx.run_until_parked(); + + let thread = thread_view + .read_with(cx, |view, _| view.thread().cloned()) + .unwrap(); + + // First user message + connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("tool1".into()), + title: "Edit file 1".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Completed, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/project/test1.txt".into(), + old_text: Some("old content 1".into()), + new_text: "new content 1".into(), + }, + }], + locations: vec![], + raw_input: None, + raw_output: None, + })]); + + thread + .update(cx, |thread, cx| thread.send_raw("Give me a diff", cx)) + .await + .unwrap(); + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries().len(), 2); + }); + + thread_view.read_with(cx, |view, _| { + assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0); + assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1); + }); + + // Second user message + connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("tool2".into()), + title: "Edit file 2".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Completed, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/project/test2.txt".into(), + old_text: Some("old content 2".into()), + new_text: "new content 2".into(), + }, + }], + locations: vec![], + raw_input: None, + raw_output: None, + })]); + + thread + .update(cx, |thread, cx| thread.send_raw("Another one", cx)) + .await + .unwrap(); + cx.run_until_parked(); + + let second_user_message_id = thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries().len(), 4); + let AgentThreadEntry::UserMessage(user_message) = thread.entries().get(2).unwrap() + else { + panic!(); + }; + user_message.id.clone().unwrap() + }); + + thread_view.read_with(cx, |view, _| { + assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0); + assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1); + assert_eq!(view.entry_view_state.entry(2).unwrap().len(), 0); + assert_eq!(view.entry_view_state.entry(3).unwrap().len(), 1); + }); + + // Rewind to first message + thread + .update(cx, |thread, cx| thread.rewind(second_user_message_id, cx)) + .await + .unwrap(); + + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries().len(), 2); + }); + + thread_view.read_with(cx, |view, _| { + assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0); + assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1); + + // Old views should be dropped + assert!(view.entry_view_state.entry(2).is_none()); + assert!(view.entry_view_state.entry(3).is_none()); + }); + } }