acp: Clean up entry views on rewind (#36197)

We were leaking diffs and terminals on rewind, we'll now clean them up.
This PR also introduces a refactor of how we mantain the entry view
state to use a `Vec` that's kept in sync with the thread entries.

Release Notes:

- N/A
This commit is contained in:
Agus Zubiaga 2025-08-14 15:30:18 -03:00 committed by GitHub
parent 2acfa5e948
commit 43ee604179
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 758 additions and 346 deletions

View file

@ -13,7 +13,7 @@ path = "src/acp_thread.rs"
doctest = false
[features]
test-support = ["gpui/test-support", "project/test-support"]
test-support = ["gpui/test-support", "project/test-support", "dep:parking_lot"]
[dependencies]
action_log.workspace = true
@ -29,6 +29,7 @@ gpui.workspace = true
itertools.workspace = true
language.workspace = true
markdown.workspace = true
parking_lot = { workspace = true, optional = true }
project.workspace = true
prompt_store.workspace = true
serde.workspace = true

View file

@ -1575,11 +1575,7 @@ mod tests {
let project = Project::test(fs, [], cx).await;
let connection = Rc::new(FakeAgentConnection::new());
let thread = cx
.spawn(async move |mut cx| {
connection
.new_thread(project, Path::new(path!("/test")), &mut cx)
.await
})
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
.await
.unwrap();
@ -1699,11 +1695,7 @@ mod tests {
));
let thread = cx
.spawn(async move |mut cx| {
connection
.new_thread(project, Path::new(path!("/test")), &mut cx)
.await
})
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
.await
.unwrap();
@ -1786,7 +1778,7 @@ mod tests {
.unwrap();
let thread = cx
.spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
.update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
.await
.unwrap();
@ -1849,11 +1841,7 @@ mod tests {
}));
let thread = cx
.spawn(async move |mut cx| {
connection
.new_thread(project, Path::new(path!("/test")), &mut cx)
.await
})
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
.await
.unwrap();
@ -1961,10 +1949,11 @@ mod tests {
}
}));
let thread = connection
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
let thread = cx
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
.await
.unwrap();
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
.await
.unwrap();
@ -2021,8 +2010,8 @@ mod tests {
.boxed_local()
}
}));
let thread = connection
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
let thread = cx
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
.await
.unwrap();
@ -2227,7 +2216,7 @@ mod tests {
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut gpui::AsyncApp,
cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(
rand::thread_rng()
@ -2237,9 +2226,8 @@ mod tests {
.collect::<String>()
.into(),
);
let thread = cx
.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
.unwrap();
let thread =
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread))
}

View file

@ -2,7 +2,7 @@ use crate::AcpThread;
use agent_client_protocol::{self as acp};
use anyhow::Result;
use collections::IndexMap;
use gpui::{AsyncApp, Entity, SharedString, Task};
use gpui::{Entity, SharedString, Task};
use project::Project;
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
@ -22,7 +22,7 @@ pub trait AgentConnection {
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>>;
fn auth_methods(&self) -> &[acp::AuthMethod];
@ -160,3 +160,155 @@ impl AgentModelList {
}
}
}
#[cfg(feature = "test-support")]
mod test_support {
use std::sync::Arc;
use collections::HashMap;
use futures::future::try_join_all;
use gpui::{AppContext as _, WeakEntity};
use parking_lot::Mutex;
use super::*;
#[derive(Clone, Default)]
pub struct StubAgentConnection {
sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
}
impl StubAgentConnection {
pub fn new() -> Self {
Self {
next_prompt_updates: Default::default(),
permission_requests: HashMap::default(),
sessions: Arc::default(),
}
}
pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
*self.next_prompt_updates.lock() = updates;
}
pub fn with_permission_requests(
mut self,
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
) -> Self {
self.permission_requests = permission_requests;
self
}
pub fn send_update(
&self,
session_id: acp::SessionId,
update: acp::SessionUpdate,
cx: &mut App,
) {
self.sessions
.lock()
.get(&session_id)
.unwrap()
.update(cx, |thread, cx| {
thread.handle_session_update(update.clone(), cx).unwrap();
})
.unwrap();
}
}
impl AgentConnection for StubAgentConnection {
fn auth_methods(&self) -> &[acp::AuthMethod] {
&[]
}
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
let thread =
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread))
}
fn authenticate(
&self,
_method_id: acp::AuthMethodId,
_cx: &mut App,
) -> Task<gpui::Result<()>> {
unimplemented!()
}
fn prompt(
&self,
_id: Option<UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> {
let sessions = self.sessions.lock();
let thread = sessions.get(&params.session_id).unwrap();
let mut tasks = vec![];
for update in self.next_prompt_updates.lock().drain(..) {
let thread = thread.clone();
let update = update.clone();
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
&& let Some(options) = self.permission_requests.get(&tool_call.id)
{
Some((tool_call.clone(), options.clone()))
} else {
None
};
let task = cx.spawn(async move |cx| {
if let Some((tool_call, options)) = permission_request {
let permission = thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(
tool_call.clone(),
options.clone(),
cx,
)
})?;
permission.await?;
}
thread.update(cx, |thread, cx| {
thread.handle_session_update(update.clone(), cx).unwrap();
})?;
anyhow::Ok(())
});
tasks.push(task);
}
cx.spawn(async move |_| {
try_join_all(tasks).await?;
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
})
}
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
unimplemented!()
}
fn session_editor(
&self,
_session_id: &agent_client_protocol::SessionId,
_cx: &mut App,
) -> Option<Rc<dyn AgentSessionEditor>> {
Some(Rc::new(StubAgentSessionEditor))
}
}
struct StubAgentSessionEditor;
impl AgentSessionEditor for StubAgentSessionEditor {
fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
Task::ready(Ok(()))
}
}
}
#[cfg(feature = "test-support")]
pub use test_support::*;

View file

@ -522,7 +522,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
let agent = self.0.clone();
log::info!("Creating new thread for project at: {:?}", cwd);
@ -940,11 +940,7 @@ mod tests {
// Create a thread/session
let acp_thread = cx
.update(|cx| {
Rc::new(connection.clone()).new_thread(
project.clone(),
Path::new("/a"),
&mut cx.to_async(),
)
Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
})
.await
.unwrap();

View file

@ -841,7 +841,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
// Create a thread using new_thread
let connection_rc = Rc::new(connection.clone());
let acp_thread = cx
.update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async()))
.update(|cx| connection_rc.new_thread(project, cwd, cx))
.await
.expect("new_thread should succeed");

View file

@ -423,7 +423,7 @@ impl AgentConnection for AcpConnection {
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut AsyncApp,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
let task = self.connection.request_any(
acp_old::InitializeParams {

View file

@ -111,7 +111,7 @@ impl AgentConnection for AcpConnection {
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
let conn = self.connection.clone();
let sessions = self.sessions.clone();

View file

@ -74,7 +74,7 @@ impl AgentConnection for ClaudeAgentConnection {
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
let cwd = cwd.to_owned();
cx.spawn(async move |cx| {

View file

@ -422,8 +422,8 @@ pub async fn new_test_thread(
.await
.unwrap();
let thread = connection
.new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
let thread = cx
.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
.await
.unwrap();

View file

@ -103,6 +103,7 @@ workspace.workspace = true
zed_actions.workspace = true
[dev-dependencies]
acp_thread = { workspace = true, features = ["test-support"] }
agent = { workspace = true, features = ["test-support"] }
assistant_context = { workspace = true, features = ["test-support"] }
assistant_tools.workspace = true

View file

@ -1,4 +1,5 @@
mod completion_provider;
mod entry_view_state;
mod message_editor;
mod model_selector;
mod model_selector_popover;

View file

@ -0,0 +1,351 @@
use std::{collections::HashMap, ops::Range};
use acp_thread::AcpThread;
use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer};
use gpui::{
AnyEntity, App, AppContext as _, Entity, EntityId, TextStyleRefinement, WeakEntity, Window,
};
use language::language_settings::SoftWrap;
use settings::Settings as _;
use terminal_view::TerminalView;
use theme::ThemeSettings;
use ui::TextSize;
use workspace::Workspace;
#[derive(Default)]
pub struct EntryViewState {
entries: Vec<Entry>,
}
impl EntryViewState {
pub fn entry(&self, index: usize) -> Option<&Entry> {
self.entries.get(index)
}
pub fn sync_entry(
&mut self,
workspace: WeakEntity<Workspace>,
thread: Entity<AcpThread>,
index: usize,
window: &mut Window,
cx: &mut App,
) {
debug_assert!(index <= self.entries.len());
let entry = if let Some(entry) = self.entries.get_mut(index) {
entry
} else {
self.entries.push(Entry::default());
self.entries.last_mut().unwrap()
};
entry.sync_diff_multibuffers(&thread, index, window, cx);
entry.sync_terminals(&workspace, &thread, index, window, cx);
}
pub fn remove(&mut self, range: Range<usize>) {
self.entries.drain(range);
}
pub fn settings_changed(&mut self, cx: &mut App) {
for entry in self.entries.iter() {
for view in entry.views.values() {
if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
diff_editor.update(cx, |diff_editor, cx| {
diff_editor
.set_text_style_refinement(diff_editor_text_style_refinement(cx));
cx.notify();
})
}
}
}
}
}
pub struct Entry {
views: HashMap<EntityId, AnyEntity>,
}
impl Entry {
pub fn editor_for_diff(&self, diff: &Entity<MultiBuffer>) -> Option<Entity<Editor>> {
self.views
.get(&diff.entity_id())
.cloned()
.map(|entity| entity.downcast::<Editor>().unwrap())
}
pub fn terminal(
&self,
terminal: &Entity<acp_thread::Terminal>,
) -> Option<Entity<TerminalView>> {
self.views
.get(&terminal.entity_id())
.cloned()
.map(|entity| entity.downcast::<TerminalView>().unwrap())
}
fn sync_diff_multibuffers(
&mut self,
thread: &Entity<AcpThread>,
index: usize,
window: &mut Window,
cx: &mut App,
) {
let Some(entry) = thread.read(cx).entries().get(index) else {
return;
};
let multibuffers = entry
.diffs()
.map(|diff| diff.read(cx).multibuffer().clone());
let multibuffers = multibuffers.collect::<Vec<_>>();
for multibuffer in multibuffers {
if self.views.contains_key(&multibuffer.entity_id()) {
return;
}
let editor = cx.new(|cx| {
let mut editor = Editor::new(
EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: true,
},
multibuffer.clone(),
None,
window,
cx,
);
editor.set_show_gutter(false, cx);
editor.disable_inline_diagnostics();
editor.disable_expand_excerpt_buttons(cx);
editor.set_show_vertical_scrollbar(false, cx);
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
editor.set_soft_wrap_mode(SoftWrap::None, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_show_indent_guides(false, cx);
editor.set_read_only(true);
editor.set_show_breakpoints(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_expand_all_diff_hunks(cx);
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
editor
});
let entity_id = multibuffer.entity_id();
self.views.insert(entity_id, editor.into_any());
}
}
fn sync_terminals(
&mut self,
workspace: &WeakEntity<Workspace>,
thread: &Entity<AcpThread>,
index: usize,
window: &mut Window,
cx: &mut App,
) {
let Some(entry) = thread.read(cx).entries().get(index) else {
return;
};
let terminals = entry
.terminals()
.map(|terminal| terminal.clone())
.collect::<Vec<_>>();
for terminal in terminals {
if self.views.contains_key(&terminal.entity_id()) {
return;
}
let Some(strong_workspace) = workspace.upgrade() else {
return;
};
let terminal_view = cx.new(|cx| {
let mut view = TerminalView::new(
terminal.read(cx).inner().clone(),
workspace.clone(),
None,
strong_workspace.read(cx).project().downgrade(),
window,
cx,
);
view.set_embedded_mode(Some(1000), cx);
view
});
let entity_id = terminal.entity_id();
self.views.insert(entity_id, terminal_view.into_any());
}
}
#[cfg(test)]
pub fn len(&self) -> usize {
self.views.len()
}
}
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
TextStyleRefinement {
font_size: Some(
TextSize::Small
.rems(cx)
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
.into(),
),
..Default::default()
}
}
impl Default for Entry {
fn default() -> Self {
Self {
// Avoid allocating in the heap by default
views: HashMap::with_capacity(0),
}
}
}
#[cfg(test)]
mod tests {
use std::{path::Path, rc::Rc};
use acp_thread::{AgentConnection, StubAgentConnection};
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
use editor::{EditorSettings, RowInfo};
use fs::FakeFs;
use gpui::{SemanticVersion, TestAppContext};
use multi_buffer::MultiBufferRow;
use pretty_assertions::assert_matches;
use project::Project;
use serde_json::json;
use settings::{Settings as _, SettingsStore};
use theme::ThemeSettings;
use util::path;
use workspace::Workspace;
use crate::acp::entry_view_state::EntryViewState;
#[gpui::test]
async fn test_diff_sync(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/project",
json!({
"hello.txt": "hi world"
}),
)
.await;
let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let tool_call = acp::ToolCall {
id: acp::ToolCallId("tool".into()),
title: "Tool call".into(),
kind: acp::ToolKind::Other,
status: acp::ToolCallStatus::InProgress,
content: vec![acp::ToolCallContent::Diff {
diff: acp::Diff {
path: "/project/hello.txt".into(),
old_text: Some("hi world".into()),
new_text: "hello world".into(),
},
}],
locations: vec![],
raw_input: None,
raw_output: None,
};
let connection = Rc::new(StubAgentConnection::new());
let thread = cx
.update(|_, cx| {
connection
.clone()
.new_thread(project, Path::new(path!("/project")), cx)
})
.await
.unwrap();
let session_id = thread.update(cx, |thread, _| thread.session_id().clone());
cx.update(|_, cx| {
connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
});
let mut view_state = EntryViewState::default();
cx.update(|window, cx| {
view_state.sync_entry(workspace.downgrade(), thread.clone(), 0, window, cx);
});
let multibuffer = thread.read_with(cx, |thread, cx| {
thread
.entries()
.get(0)
.unwrap()
.diffs()
.next()
.unwrap()
.read(cx)
.multibuffer()
.clone()
});
cx.run_until_parked();
let entry = view_state.entry(0).unwrap();
let diff_editor = entry.editor_for_diff(&multibuffer).unwrap();
assert_eq!(
diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
"hi world\nhello world"
);
let row_infos = diff_editor.read_with(cx, |editor, cx| {
let multibuffer = editor.buffer().read(cx);
multibuffer
.snapshot(cx)
.row_infos(MultiBufferRow(0))
.collect::<Vec<_>>()
});
assert_matches!(
row_infos.as_slice(),
[
RowInfo {
multibuffer_row: Some(MultiBufferRow(0)),
diff_status: Some(DiffHunkStatus {
kind: DiffHunkStatusKind::Deleted,
..
}),
..
},
RowInfo {
multibuffer_row: Some(MultiBufferRow(1)),
diff_status: Some(DiffHunkStatus {
kind: DiffHunkStatusKind::Added,
..
}),
..
}
]
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
AgentSettings::register(cx);
workspace::init_settings(cx);
ThemeSettings::register(cx);
release_channel::init(SemanticVersion::default(), cx);
EditorSettings::register(cx);
});
}
}

View file

@ -12,24 +12,22 @@ use audio::{Audio, Sound};
use buffer_diff::BufferDiff;
use collections::{HashMap, HashSet};
use editor::scroll::Autoscroll;
use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer, PathKey, SelectionEffects};
use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects};
use file_icons::FileIcons;
use gpui::{
Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, EdgesRefinement, Empty, Entity,
EntityId, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton,
PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle,
TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div,
linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between,
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay,
SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement,
Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop,
linear_gradient, list, percentage, point, prelude::*, pulsating_between,
};
use language::Buffer;
use language::language_settings::SoftWrap;
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
use project::Project;
use prompt_store::PromptId;
use rope::Point;
use settings::{Settings as _, SettingsStore};
use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
use terminal_view::TerminalView;
use text::Anchor;
use theme::ThemeSettings;
use ui::{
@ -41,6 +39,7 @@ use workspace::{CollaboratorId, Workspace};
use zed_actions::agent::{Chat, ToggleModelSelector};
use zed_actions::assistant::OpenRulesLibrary;
use super::entry_view_state::EntryViewState;
use crate::acp::AcpModelSelectorPopover;
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
use crate::agent_diff::AgentDiff;
@ -61,8 +60,7 @@ pub struct AcpThreadView {
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
thread_state: ThreadState,
diff_editors: HashMap<EntityId, Entity<Editor>>,
terminal_views: HashMap<EntityId, Entity<TerminalView>>,
entry_view_state: EntryViewState,
message_editor: Entity<MessageEditor>,
model_selector: Option<Entity<AcpModelSelectorPopover>>,
notifications: Vec<WindowHandle<AgentNotification>>,
@ -149,8 +147,7 @@ impl AcpThreadView {
model_selector: None,
notifications: Vec::new(),
notification_subscriptions: HashMap::default(),
diff_editors: Default::default(),
terminal_views: Default::default(),
entry_view_state: EntryViewState::default(),
list_state: list_state.clone(),
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
last_error: None,
@ -209,11 +206,18 @@ impl AcpThreadView {
// })
// .ok();
let result = match connection
.clone()
.new_thread(project.clone(), &root_dir, cx)
.await
{
let Some(result) = cx
.update(|_, cx| {
connection
.clone()
.new_thread(project.clone(), &root_dir, cx)
})
.log_err()
else {
return;
};
let result = match result.await {
Err(e) => {
let mut cx = cx.clone();
if e.is::<acp_thread::AuthRequired>() {
@ -480,16 +484,29 @@ impl AcpThreadView {
) {
match event {
AcpThreadEvent::NewEntry => {
let index = thread.read(cx).entries().len() - 1;
self.sync_thread_entry_view(index, window, cx);
let len = thread.read(cx).entries().len();
let index = len - 1;
self.entry_view_state.sync_entry(
self.workspace.clone(),
thread.clone(),
index,
window,
cx,
);
self.list_state.splice(index..index, 1);
}
AcpThreadEvent::EntryUpdated(index) => {
self.sync_thread_entry_view(*index, window, cx);
self.entry_view_state.sync_entry(
self.workspace.clone(),
thread.clone(),
*index,
window,
cx,
);
self.list_state.splice(*index..index + 1, 1);
}
AcpThreadEvent::EntriesRemoved(range) => {
// TODO: Clean up unused diff editors and terminal views
self.entry_view_state.remove(range.clone());
self.list_state.splice(range.clone(), 0);
}
AcpThreadEvent::ToolAuthorizationRequired => {
@ -523,128 +540,6 @@ impl AcpThreadView {
cx.notify();
}
fn sync_thread_entry_view(
&mut self,
entry_ix: usize,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.sync_diff_multibuffers(entry_ix, window, cx);
self.sync_terminals(entry_ix, window, cx);
}
fn sync_diff_multibuffers(
&mut self,
entry_ix: usize,
window: &mut Window,
cx: &mut Context<Self>,
) {
let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else {
return;
};
let multibuffers = multibuffers.collect::<Vec<_>>();
for multibuffer in multibuffers {
if self.diff_editors.contains_key(&multibuffer.entity_id()) {
return;
}
let editor = cx.new(|cx| {
let mut editor = Editor::new(
EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: true,
},
multibuffer.clone(),
None,
window,
cx,
);
editor.set_show_gutter(false, cx);
editor.disable_inline_diagnostics();
editor.disable_expand_excerpt_buttons(cx);
editor.set_show_vertical_scrollbar(false, cx);
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
editor.set_soft_wrap_mode(SoftWrap::None, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_show_indent_guides(false, cx);
editor.set_read_only(true);
editor.set_show_breakpoints(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_expand_all_diff_hunks(cx);
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
editor
});
let entity_id = multibuffer.entity_id();
cx.observe_release(&multibuffer, move |this, _, _| {
this.diff_editors.remove(&entity_id);
})
.detach();
self.diff_editors.insert(entity_id, editor);
}
}
fn entry_diff_multibuffers(
&self,
entry_ix: usize,
cx: &App,
) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
Some(
entry
.diffs()
.map(|diff| diff.read(cx).multibuffer().clone()),
)
}
fn sync_terminals(&mut self, entry_ix: usize, window: &mut Window, cx: &mut Context<Self>) {
let Some(terminals) = self.entry_terminals(entry_ix, cx) else {
return;
};
let terminals = terminals.collect::<Vec<_>>();
for terminal in terminals {
if self.terminal_views.contains_key(&terminal.entity_id()) {
return;
}
let terminal_view = cx.new(|cx| {
let mut view = TerminalView::new(
terminal.read(cx).inner().clone(),
self.workspace.clone(),
None,
self.project.downgrade(),
window,
cx,
);
view.set_embedded_mode(Some(1000), cx);
view
});
let entity_id = terminal.entity_id();
cx.observe_release(&terminal, move |this, _, _| {
this.terminal_views.remove(&entity_id);
})
.detach();
self.terminal_views.insert(entity_id, terminal_view);
}
}
fn entry_terminals(
&self,
entry_ix: usize,
cx: &App,
) -> Option<impl Iterator<Item = Entity<acp_thread::Terminal>>> {
let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
Some(entry.terminals().map(|terminal| terminal.clone()))
}
fn authenticate(
&mut self,
method: acp::AuthMethodId,
@ -712,7 +607,7 @@ impl AcpThreadView {
fn render_entry(
&self,
index: usize,
entry_ix: usize,
total_entries: usize,
entry: &AgentThreadEntry,
window: &mut Window,
@ -720,7 +615,7 @@ impl AcpThreadView {
) -> AnyElement {
let primary = match &entry {
AgentThreadEntry::UserMessage(message) => div()
.id(("user_message", index))
.id(("user_message", entry_ix))
.py_4()
.px_2()
.children(message.id.clone().and_then(|message_id| {
@ -749,7 +644,9 @@ impl AcpThreadView {
.text_xs()
.id("message")
.on_click(cx.listener({
move |this, _, window, cx| this.start_editing_message(index, window, cx)
move |this, _, window, cx| {
this.start_editing_message(entry_ix, window, cx)
}
}))
.children(
if let Some(editing) = self.editing_message.as_ref()
@ -787,7 +684,7 @@ impl AcpThreadView {
AssistantMessageChunk::Thought { block } => {
block.markdown().map(|md| {
self.render_thinking_block(
index,
entry_ix,
chunk_ix,
md.clone(),
window,
@ -803,7 +700,7 @@ impl AcpThreadView {
v_flex()
.px_5()
.py_1()
.when(index + 1 == total_entries, |this| this.pb_4())
.when(entry_ix + 1 == total_entries, |this| this.pb_4())
.w_full()
.text_ui(cx)
.child(message_body)
@ -815,10 +712,12 @@ impl AcpThreadView {
div().w_full().py_1p5().px_5().map(|this| {
if has_terminals {
this.children(tool_call.terminals().map(|terminal| {
self.render_terminal_tool_call(terminal, tool_call, window, cx)
self.render_terminal_tool_call(
entry_ix, terminal, tool_call, window, cx,
)
}))
} else {
this.child(self.render_tool_call(index, tool_call, window, cx))
this.child(self.render_tool_call(entry_ix, tool_call, window, cx))
}
})
}
@ -830,7 +729,7 @@ impl AcpThreadView {
};
let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating);
let primary = if index == total_entries - 1 && !is_generating {
let primary = if entry_ix == total_entries - 1 && !is_generating {
v_flex()
.w_full()
.child(primary)
@ -841,10 +740,10 @@ impl AcpThreadView {
};
if let Some(editing) = self.editing_message.as_ref()
&& editing.index < index
&& editing.index < entry_ix
{
let backdrop = div()
.id(("backdrop", index))
.id(("backdrop", entry_ix))
.size_full()
.absolute()
.inset_0()
@ -1125,7 +1024,9 @@ impl AcpThreadView {
.w_full()
.children(tool_call.content.iter().map(|content| {
div()
.child(self.render_tool_call_content(content, tool_call, window, cx))
.child(
self.render_tool_call_content(entry_ix, content, tool_call, window, cx),
)
.into_any_element()
}))
.child(self.render_permission_buttons(
@ -1139,7 +1040,9 @@ impl AcpThreadView {
.w_full()
.children(tool_call.content.iter().map(|content| {
div()
.child(self.render_tool_call_content(content, tool_call, window, cx))
.child(
self.render_tool_call_content(entry_ix, content, tool_call, window, cx),
)
.into_any_element()
})),
ToolCallStatus::Rejected => v_flex().size_0(),
@ -1257,6 +1160,7 @@ impl AcpThreadView {
fn render_tool_call_content(
&self,
entry_ix: usize,
content: &ToolCallContent,
tool_call: &ToolCall,
window: &Window,
@ -1273,10 +1177,10 @@ impl AcpThreadView {
}
}
ToolCallContent::Diff(diff) => {
self.render_diff_editor(&diff.read(cx).multibuffer(), cx)
self.render_diff_editor(entry_ix, &diff.read(cx).multibuffer(), cx)
}
ToolCallContent::Terminal(terminal) => {
self.render_terminal_tool_call(terminal, tool_call, window, cx)
self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx)
}
}
}
@ -1420,6 +1324,7 @@ impl AcpThreadView {
fn render_diff_editor(
&self,
entry_ix: usize,
multibuffer: &Entity<MultiBuffer>,
cx: &Context<Self>,
) -> AnyElement {
@ -1428,7 +1333,9 @@ impl AcpThreadView {
.border_t_1()
.border_color(self.tool_card_border_color(cx))
.child(
if let Some(editor) = self.diff_editors.get(&multibuffer.entity_id()) {
if let Some(entry) = self.entry_view_state.entry(entry_ix)
&& let Some(editor) = entry.editor_for_diff(&multibuffer)
{
editor.clone().into_any_element()
} else {
Empty.into_any()
@ -1439,6 +1346,7 @@ impl AcpThreadView {
fn render_terminal_tool_call(
&self,
entry_ix: usize,
terminal: &Entity<acp_thread::Terminal>,
tool_call: &ToolCall,
window: &Window,
@ -1627,8 +1535,11 @@ impl AcpThreadView {
})),
);
let show_output =
self.terminal_expanded && self.terminal_views.contains_key(&terminal.entity_id());
let terminal_view = self
.entry_view_state
.entry(entry_ix)
.and_then(|entry| entry.terminal(&terminal));
let show_output = self.terminal_expanded && terminal_view.is_some();
v_flex()
.mb_2()
@ -1661,8 +1572,6 @@ impl AcpThreadView {
),
)
.when(show_output, |this| {
let terminal_view = self.terminal_views.get(&terminal.entity_id()).unwrap();
this.child(
div()
.pt_2()
@ -1672,7 +1581,7 @@ impl AcpThreadView {
.bg(cx.theme().colors().editor_background)
.rounded_b_md()
.text_ui_sm(cx)
.child(terminal_view.clone()),
.children(terminal_view.clone()),
)
})
.into_any()
@ -3075,12 +2984,7 @@ impl AcpThreadView {
}
fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
for diff_editor in self.diff_editors.values() {
diff_editor.update(cx, |diff_editor, cx| {
diff_editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
cx.notify();
})
}
self.entry_view_state.settings_changed(cx);
}
pub(crate) fn insert_dragged_files(
@ -3379,18 +3283,6 @@ fn plan_label_markdown_style(
}
}
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
TextStyleRefinement {
font_size: Some(
TextSize::Small
.rems(cx)
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
.into(),
),
..Default::default()
}
}
fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let default_md_style = default_markdown_style(true, window, cx);
@ -3405,16 +3297,16 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
#[cfg(test)]
pub(crate) mod tests {
use std::{path::Path, sync::Arc};
use std::path::Path;
use acp_thread::StubAgentConnection;
use agent::{TextThreadStore, ThreadStore};
use agent_client_protocol::SessionId;
use editor::EditorSettings;
use fs::FakeFs;
use futures::future::try_join_all;
use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
use parking_lot::Mutex;
use rand::Rng;
use project::Project;
use serde_json::json;
use settings::SettingsStore;
use super::*;
@ -3497,8 +3389,8 @@ pub(crate) mod tests {
raw_input: None,
raw_output: None,
};
let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)])
.with_permission_requests(HashMap::from_iter([(
let connection =
StubAgentConnection::new().with_permission_requests(HashMap::from_iter([(
tool_call_id,
vec![acp::PermissionOption {
id: acp::PermissionOptionId("1".into()),
@ -3506,6 +3398,9 @@ pub(crate) mod tests {
kind: acp::PermissionOptionKind::AllowOnce,
}],
)]));
connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(tool_call)]);
let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await;
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
@ -3605,115 +3500,6 @@ pub(crate) mod tests {
}
}
#[derive(Clone, Default)]
struct StubAgentConnection {
sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
updates: Vec<acp::SessionUpdate>,
}
impl StubAgentConnection {
fn new(updates: Vec<acp::SessionUpdate>) -> Self {
Self {
updates,
permission_requests: HashMap::default(),
sessions: Arc::default(),
}
}
fn with_permission_requests(
mut self,
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
) -> Self {
self.permission_requests = permission_requests;
self
}
}
impl AgentConnection for StubAgentConnection {
fn auth_methods(&self) -> &[acp::AuthMethod] {
&[]
}
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut gpui::AsyncApp,
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = SessionId(
rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(7)
.map(char::from)
.collect::<String>()
.into(),
);
let thread = cx
.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
.unwrap();
self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread))
}
fn authenticate(
&self,
_method_id: acp::AuthMethodId,
_cx: &mut App,
) -> Task<gpui::Result<()>> {
unimplemented!()
}
fn prompt(
&self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> {
let sessions = self.sessions.lock();
let thread = sessions.get(&params.session_id).unwrap();
let mut tasks = vec![];
for update in &self.updates {
let thread = thread.clone();
let update = update.clone();
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
&& let Some(options) = self.permission_requests.get(&tool_call.id)
{
Some((tool_call.clone(), options.clone()))
} else {
None
};
let task = cx.spawn(async move |cx| {
if let Some((tool_call, options)) = permission_request {
let permission = thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(
tool_call.clone(),
options.clone(),
cx,
)
})?;
permission.await?;
}
thread.update(cx, |thread, cx| {
thread.handle_session_update(update.clone(), cx).unwrap();
})?;
anyhow::Ok(())
});
tasks.push(task);
}
cx.spawn(async move |_| {
try_join_all(tasks).await?;
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
})
}
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
unimplemented!()
}
}
#[derive(Clone)]
struct SaboteurAgentConnection;
@ -3722,19 +3508,17 @@ pub(crate) mod tests {
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut gpui::AsyncApp,
cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
Task::ready(Ok(cx
.new(|cx| {
AcpThread::new(
"SaboteurAgentConnection",
self,
project,
SessionId("test".into()),
cx,
)
})
.unwrap()))
Task::ready(Ok(cx.new(|cx| {
AcpThread::new(
"SaboteurAgentConnection",
self,
project,
SessionId("test".into()),
cx,
)
})))
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
@ -3776,4 +3560,142 @@ pub(crate) mod tests {
EditorSettings::register(cx);
});
}
#[gpui::test]
async fn test_rewind_views(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/project",
json!({
"test1.txt": "old content 1",
"test2.txt": "old content 2"
}),
)
.await;
let project = Project::test(fs, [Path::new("/project")], cx).await;
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let thread_store =
cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx)));
let text_thread_store =
cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx)));
let connection = Rc::new(StubAgentConnection::new());
let thread_view = cx.update(|window, cx| {
cx.new(|cx| {
AcpThreadView::new(
Rc::new(StubAgentServer::new(connection.as_ref().clone())),
workspace.downgrade(),
project.clone(),
thread_store.clone(),
text_thread_store.clone(),
window,
cx,
)
})
});
cx.run_until_parked();
let thread = thread_view
.read_with(cx, |view, _| view.thread().cloned())
.unwrap();
// First user message
connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(acp::ToolCall {
id: acp::ToolCallId("tool1".into()),
title: "Edit file 1".into(),
kind: acp::ToolKind::Edit,
status: acp::ToolCallStatus::Completed,
content: vec![acp::ToolCallContent::Diff {
diff: acp::Diff {
path: "/project/test1.txt".into(),
old_text: Some("old content 1".into()),
new_text: "new content 1".into(),
},
}],
locations: vec![],
raw_input: None,
raw_output: None,
})]);
thread
.update(cx, |thread, cx| thread.send_raw("Give me a diff", cx))
.await
.unwrap();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(thread.entries().len(), 2);
});
thread_view.read_with(cx, |view, _| {
assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
});
// Second user message
connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(acp::ToolCall {
id: acp::ToolCallId("tool2".into()),
title: "Edit file 2".into(),
kind: acp::ToolKind::Edit,
status: acp::ToolCallStatus::Completed,
content: vec![acp::ToolCallContent::Diff {
diff: acp::Diff {
path: "/project/test2.txt".into(),
old_text: Some("old content 2".into()),
new_text: "new content 2".into(),
},
}],
locations: vec![],
raw_input: None,
raw_output: None,
})]);
thread
.update(cx, |thread, cx| thread.send_raw("Another one", cx))
.await
.unwrap();
cx.run_until_parked();
let second_user_message_id = thread.read_with(cx, |thread, _| {
assert_eq!(thread.entries().len(), 4);
let AgentThreadEntry::UserMessage(user_message) = thread.entries().get(2).unwrap()
else {
panic!();
};
user_message.id.clone().unwrap()
});
thread_view.read_with(cx, |view, _| {
assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
assert_eq!(view.entry_view_state.entry(2).unwrap().len(), 0);
assert_eq!(view.entry_view_state.entry(3).unwrap().len(), 1);
});
// Rewind to first message
thread
.update(cx, |thread, cx| thread.rewind(second_user_message_id, cx))
.await
.unwrap();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(thread.entries().len(), 2);
});
thread_view.read_with(cx, |view, _| {
assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
// Old views should be dropped
assert!(view.entry_view_state.entry(2).is_none());
assert!(view.entry_view_state.entry(3).is_none());
});
}
}