From 993e0f55ec88f3209f00198e8a6872f43bd6340d Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 11 Jul 2025 09:38:42 -0600 Subject: [PATCH] ACP follow (#34235) Closes #ISSUE Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga Co-authored-by: Anthony Eid Co-authored-by: Ben Brandt --- Cargo.lock | 5 +- Cargo.toml | 2 +- assets/keymaps/default-linux.json | 3 +- assets/keymaps/default-macos.json | 3 +- crates/acp/Cargo.toml | 1 + crates/acp/src/acp.rs | 389 +++++++++++++--- crates/agent_ui/src/acp/thread_view.rs | 563 +++++++++++++++++++++--- crates/agent_ui/src/agent_diff.rs | 273 +++++++++--- crates/agent_ui/src/agent_panel.rs | 27 +- crates/agent_ui/src/message_editor.rs | 10 +- crates/assistant_tool/src/action_log.rs | 22 +- 11 files changed, 1090 insertions(+), 208 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bc6783ce92..624126c163 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9,6 +9,7 @@ dependencies = [ "agent_servers", "agentic-coding-protocol", "anyhow", + "assistant_tool", "async-pipe", "buffer_diff", "editor", @@ -263,9 +264,9 @@ dependencies = [ [[package]] name = "agentic-coding-protocol" -version = "0.0.6" +version = "0.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ac0351749af7bf53c65042ef69fefb9351aa8b7efa0a813d6281377605c37d" +checksum = "a75f520bcc049ebe40c8c99427aa61b48ad78a01bcc96a13b350b903dcfb9438" dependencies = [ "anyhow", "chrono", diff --git a/Cargo.toml b/Cargo.toml index fd5cbff545..7e3b43e58a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -404,7 +404,7 @@ zlog_settings = { path = "crates/zlog_settings" } # External crates # -agentic-coding-protocol = "0.0.6" +agentic-coding-protocol = "0.0.7" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 489e4e6d0c..c660383d10 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -320,7 +320,8 @@ "bindings": { "enter": "agent::Chat", "up": "agent::PreviousHistoryMessage", - "down": "agent::NextHistoryMessage" + "down": "agent::NextHistoryMessage", + "shift-ctrl-r": "agent::OpenAgentDiff" } }, { diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index c7ab7c9273..dc109d94aa 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -371,7 +371,8 @@ "bindings": { "enter": "agent::Chat", "up": "agent::PreviousHistoryMessage", - "down": "agent::NextHistoryMessage" + "down": "agent::NextHistoryMessage", + "shift-ctrl-r": "agent::OpenAgentDiff" } }, { diff --git a/crates/acp/Cargo.toml b/crates/acp/Cargo.toml index dae6292e28..1570aeaef0 100644 --- a/crates/acp/Cargo.toml +++ b/crates/acp/Cargo.toml @@ -20,6 +20,7 @@ gemini = [] agent_servers.workspace = true agentic-coding-protocol.workspace = true anyhow.workspace = true +assistant_tool.workspace = true buffer_diff.workspace = true editor.workspace = true futures.workspace = true diff --git a/crates/acp/src/acp.rs b/crates/acp/src/acp.rs index ddb7c50f7a..0aa57513a7 100644 --- a/crates/acp/src/acp.rs +++ b/crates/acp/src/acp.rs @@ -2,14 +2,19 @@ pub use acp::ToolCallId; use agent_servers::AgentServer; use agentic_coding_protocol::{self as acp, UserMessageChunk}; use anyhow::{Context as _, Result, anyhow}; +use assistant_tool::ActionLog; use buffer_diff::BufferDiff; use editor::{MultiBuffer, PathKey}; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; use itertools::Itertools; -use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _}; +use language::{ + Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point, + text_diff, +}; use markdown::Markdown; -use project::Project; +use project::{AgentLocation, Project}; +use std::collections::HashMap; use std::error::Error; use std::fmt::{Formatter, Write}; use std::{ @@ -159,6 +164,18 @@ impl AgentThreadEntry { Self::ToolCall(too_call) => too_call.to_markdown(cx), } } + + pub fn diff(&self) -> Option<&Diff> { + if let AgentThreadEntry::ToolCall(ToolCall { + content: Some(ToolCallContent::Diff { diff }), + .. + }) = self + { + Some(&diff) + } else { + None + } + } } #[derive(Debug)] @@ -168,6 +185,7 @@ pub struct ToolCall { pub icon: IconName, pub content: Option, pub status: ToolCallStatus, + pub locations: Vec, } impl ToolCall { @@ -328,6 +346,8 @@ impl ToolCallContent { pub struct Diff { pub multibuffer: Entity, pub path: PathBuf, + pub new_buffer: Entity, + pub old_buffer: Entity, _task: Task>, } @@ -362,6 +382,7 @@ impl Diff { let task = cx.spawn({ let multibuffer = multibuffer.clone(); let path = path.clone(); + let new_buffer = new_buffer.clone(); async move |cx| { diff_task.await?; @@ -401,6 +422,8 @@ impl Diff { Self { multibuffer, path, + new_buffer, + old_buffer, _task: task, } } @@ -421,6 +444,8 @@ pub struct AcpThread { entries: Vec, title: SharedString, project: Entity, + action_log: Entity, + shared_buffers: HashMap, BufferSnapshot>, send_task: Option>, connection: Arc, child_status: Option>>, @@ -522,7 +547,11 @@ impl AcpThread { } }); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + Self { + action_log, + shared_buffers: Default::default(), entries: Default::default(), title: "ACP Thread".into(), project, @@ -534,6 +563,14 @@ impl AcpThread { }) } + pub fn action_log(&self) -> &Entity { + &self.action_log + } + + pub fn project(&self) -> &Entity { + &self.project + } + #[cfg(test)] pub fn fake( stdin: async_pipe::PipeWriter, @@ -558,7 +595,11 @@ impl AcpThread { } }); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + Self { + action_log, + shared_buffers: Default::default(), entries: Default::default(), title: "ACP Thread".into(), project, @@ -589,6 +630,26 @@ impl AcpThread { } } + pub fn has_pending_edit_tool_calls(&self) -> bool { + for entry in self.entries.iter().rev() { + match entry { + AgentThreadEntry::UserMessage(_) => return false, + AgentThreadEntry::ToolCall(ToolCall { + status: + ToolCallStatus::Allowed { + status: acp::ToolCallStatus::Running, + .. + }, + content: Some(ToolCallContent::Diff { .. }), + .. + }) => return true, + AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {} + } + } + + false + } + pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context) { self.entries.push(entry); cx.emit(AcpThreadEvent::NewEntry); @@ -644,65 +705,63 @@ impl AcpThread { pub fn request_tool_call( &mut self, - label: String, - icon: acp::Icon, - content: Option, - confirmation: acp::ToolCallConfirmation, + tool_call: acp::RequestToolCallConfirmationParams, cx: &mut Context, ) -> ToolCallRequest { let (tx, rx) = oneshot::channel(); let status = ToolCallStatus::WaitingForConfirmation { confirmation: ToolCallConfirmation::from_acp( - confirmation, + tool_call.confirmation, self.project.read(cx).languages().clone(), cx, ), respond_tx: tx, }; - let id = self.insert_tool_call(label, status, icon, content, cx); + let id = self.insert_tool_call(tool_call.tool_call, status, cx); ToolCallRequest { id, outcome: rx } } pub fn push_tool_call( &mut self, - label: String, - icon: acp::Icon, - content: Option, + request: acp::PushToolCallParams, cx: &mut Context, ) -> acp::ToolCallId { let status = ToolCallStatus::Allowed { status: acp::ToolCallStatus::Running, }; - self.insert_tool_call(label, status, icon, content, cx) + self.insert_tool_call(request, status, cx) } fn insert_tool_call( &mut self, - label: String, + tool_call: acp::PushToolCallParams, status: ToolCallStatus, - icon: acp::Icon, - content: Option, cx: &mut Context, ) -> acp::ToolCallId { let language_registry = self.project.read(cx).languages().clone(); let id = acp::ToolCallId(self.entries.len() as u64); - - self.push_entry( - AgentThreadEntry::ToolCall(ToolCall { - id, - label: cx.new(|cx| { - Markdown::new(label.into(), Some(language_registry.clone()), None, cx) - }), - icon: acp_icon_to_ui_icon(icon), - content: content - .map(|content| ToolCallContent::from_acp(content, language_registry, cx)), - status, + let call = ToolCall { + id, + label: cx.new(|cx| { + Markdown::new( + tool_call.label.into(), + Some(language_registry.clone()), + None, + cx, + ) }), - cx, - ); + icon: acp_icon_to_ui_icon(tool_call.icon), + content: tool_call + .content + .map(|content| ToolCallContent::from_acp(content, language_registry, cx)), + locations: tool_call.locations, + status, + }; + + self.push_entry(AgentThreadEntry::ToolCall(call), cx); id } @@ -804,14 +863,16 @@ impl AcpThread { false } - pub fn initialize(&self) -> impl use<> + Future> { + pub fn initialize( + &self, + ) -> impl use<> + Future> { let connection = self.connection.clone(); - async move { Ok(connection.request(acp::InitializeParams).await?) } + async move { connection.request(acp::InitializeParams).await } } - pub fn authenticate(&self) -> impl use<> + Future> { + pub fn authenticate(&self) -> impl use<> + Future> { let connection = self.connection.clone(); - async move { Ok(connection.request(acp::AuthenticateParams).await?) } + async move { connection.request(acp::AuthenticateParams).await } } #[cfg(test)] @@ -819,7 +880,7 @@ impl AcpThread { &mut self, message: &str, cx: &mut Context, - ) -> BoxFuture<'static, Result<()>> { + ) -> BoxFuture<'static, Result<(), acp::Error>> { self.send( acp::SendUserMessageParams { chunks: vec![acp::UserMessageChunk::Text { @@ -834,7 +895,7 @@ impl AcpThread { &mut self, message: acp::SendUserMessageParams, cx: &mut Context, - ) -> BoxFuture<'static, Result<()>> { + ) -> BoxFuture<'static, Result<(), acp::Error>> { let agent = self.connection.clone(); self.push_entry( AgentThreadEntry::UserMessage(UserMessage::from_acp( @@ -865,7 +926,7 @@ impl AcpThread { .boxed() } - pub fn cancel(&mut self, cx: &mut Context) -> Task> { + pub fn cancel(&mut self, cx: &mut Context) -> Task> { let agent = self.connection.clone(); if self.send_task.take().is_some() { @@ -898,13 +959,123 @@ impl AcpThread { } } } - }) + })?; + Ok(()) }) } else { Task::ready(Ok(())) } } + pub fn read_text_file( + &self, + request: acp::ReadTextFileParams, + cx: &mut Context, + ) -> Task> { + let project = self.project.clone(); + let action_log = self.action_log.clone(); + cx.spawn(async move |this, cx| { + let load = project.update(cx, |project, cx| { + let path = project + .project_path_for_absolute_path(&request.path, cx) + .context("invalid path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + let buffer = load??.await?; + + action_log.update(cx, |action_log, cx| { + action_log.buffer_read(buffer.clone(), cx); + })?; + project.update(cx, |project, cx| { + let position = buffer + .read(cx) + .snapshot() + .anchor_before(Point::new(request.line.unwrap_or_default(), 0)); + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position, + }), + cx, + ); + })?; + let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?; + this.update(cx, |this, _| { + let text = snapshot.text(); + this.shared_buffers.insert(buffer.clone(), snapshot); + text + }) + }) + } + + pub fn write_text_file( + &self, + path: PathBuf, + content: String, + cx: &mut Context, + ) -> Task> { + let project = self.project.clone(); + let action_log = self.action_log.clone(); + cx.spawn(async move |this, cx| { + let load = project.update(cx, |project, cx| { + let path = project + .project_path_for_absolute_path(&path, cx) + .context("invalid path")?; + anyhow::Ok(project.open_buffer(path, cx)) + }); + let buffer = load??.await?; + let snapshot = this.update(cx, |this, cx| { + this.shared_buffers + .get(&buffer) + .cloned() + .unwrap_or_else(|| buffer.read(cx).snapshot()) + })?; + let edits = cx + .background_executor() + .spawn(async move { + let old_text = snapshot.text(); + text_diff(old_text.as_str(), &content) + .into_iter() + .map(|(range, replacement)| { + ( + snapshot.anchor_after(range.start) + ..snapshot.anchor_before(range.end), + replacement, + ) + }) + .collect::>() + }) + .await; + cx.update(|cx| { + project.update(cx, |project, cx| { + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: edits + .last() + .map(|(range, _)| range.end) + .unwrap_or(Anchor::MIN), + }), + cx, + ); + }); + + action_log.update(cx, |action_log, cx| { + action_log.buffer_read(buffer.clone(), cx); + }); + buffer.update(cx, |buffer, cx| { + buffer.edit(edits, None, cx); + }); + action_log.update(cx, |action_log, cx| { + action_log.buffer_edited(buffer.clone(), cx); + }); + })?; + project + .update(cx, |project, cx| project.save_buffer(buffer, cx))? + .await + }) + } + pub fn child_status(&mut self) -> Option>> { self.child_status.take() } @@ -930,7 +1101,7 @@ impl acp::Client for AcpClientDelegate { async fn stream_assistant_message_chunk( &self, params: acp::StreamAssistantMessageChunkParams, - ) -> Result<()> { + ) -> Result<(), acp::Error> { let cx = &mut self.cx.clone(); cx.update(|cx| { @@ -947,45 +1118,37 @@ impl acp::Client for AcpClientDelegate { async fn request_tool_call_confirmation( &self, request: acp::RequestToolCallConfirmationParams, - ) -> Result { + ) -> Result { let cx = &mut self.cx.clone(); let ToolCallRequest { id, outcome } = cx .update(|cx| { - self.thread.update(cx, |thread, cx| { - thread.request_tool_call( - request.label, - request.icon, - request.content, - request.confirmation, - cx, - ) - }) + self.thread + .update(cx, |thread, cx| thread.request_tool_call(request, cx)) })? .context("Failed to update thread")?; Ok(acp::RequestToolCallConfirmationResponse { id, - outcome: outcome.await?, + outcome: outcome.await.map_err(acp::Error::into_internal_error)?, }) } async fn push_tool_call( &self, request: acp::PushToolCallParams, - ) -> Result { + ) -> Result { let cx = &mut self.cx.clone(); let id = cx .update(|cx| { - self.thread.update(cx, |thread, cx| { - thread.push_tool_call(request.label, request.icon, request.content, cx) - }) + self.thread + .update(cx, |thread, cx| thread.push_tool_call(request, cx)) })? .context("Failed to update thread")?; Ok(acp::PushToolCallResponse { id }) } - async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<()> { + async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> { let cx = &mut self.cx.clone(); cx.update(|cx| { @@ -997,6 +1160,34 @@ impl acp::Client for AcpClientDelegate { Ok(()) } + + async fn read_text_file( + &self, + request: acp::ReadTextFileParams, + ) -> Result { + let content = self + .cx + .update(|cx| { + self.thread + .update(cx, |thread, cx| thread.read_text_file(request, cx)) + })? + .context("Failed to update thread")? + .await?; + Ok(acp::ReadTextFileResponse { content }) + } + + async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> { + self.cx + .update(|cx| { + self.thread.update(cx, |thread, cx| { + thread.write_text_file(request.path, request.content, cx) + }) + })? + .context("Failed to update thread")? + .await?; + + Ok(()) + } } fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName { @@ -1100,6 +1291,80 @@ mod tests { ); } + #[gpui::test] + async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"})) + .await; + let project = Project::test(fs.clone(), [], cx).await; + let (thread, fake_server) = fake_acp_thread(project.clone(), cx); + let (worktree, pathbuf) = project + .update(cx, |project, cx| { + project.find_or_create_worktree(path!("/tmp/foo"), true, cx) + }) + .await + .unwrap(); + let buffer = project + .update(cx, |project, cx| { + project.open_buffer((worktree.read(cx).id(), pathbuf), cx) + }) + .await + .unwrap(); + + let (read_file_tx, read_file_rx) = oneshot::channel::<()>(); + let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx))); + + fake_server.update(cx, |fake_server, _| { + fake_server.on_user_message(move |_, server, mut cx| { + let read_file_tx = read_file_tx.clone(); + async move { + let content = server + .update(&mut cx, |server, _| { + server.send_to_zed(acp::ReadTextFileParams { + path: path!("/tmp/foo").into(), + line: None, + limit: None, + }) + })? + .await + .unwrap(); + assert_eq!(content.content, "one\ntwo\nthree\n"); + read_file_tx.take().unwrap().send(()).unwrap(); + server + .update(&mut cx, |server, _| { + server.send_to_zed(acp::WriteTextFileParams { + path: path!("/tmp/foo").into(), + content: "one\ntwo\nthree\nfour\nfive\n".to_string(), + }) + })? + .await + .unwrap(); + Ok(()) + } + }) + }); + + let request = thread.update(cx, |thread, cx| { + thread.send_raw("Extend the count in /tmp/foo", cx) + }); + read_file_rx.await.ok(); + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "zero\n".to_string())], None, cx); + }); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "zero\none\ntwo\nthree\nfour\nfive\n" + ); + assert_eq!( + String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(), + "zero\none\ntwo\nthree\nfour\nfive\n" + ); + request.await.unwrap(); + } + #[gpui::test] async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) { init_test(cx); @@ -1124,6 +1389,7 @@ mod tests { label: "Fetch".to_string(), icon: acp::Icon::Globe, content: None, + locations: vec![], }) })? .await @@ -1553,7 +1819,7 @@ mod tests { acp::SendUserMessageParams, Entity, AsyncApp, - ) -> LocalBoxFuture<'static, Result<()>>, + ) -> LocalBoxFuture<'static, Result<(), acp::Error>>, >, >, } @@ -1565,21 +1831,24 @@ mod tests { } impl acp::Agent for FakeAgent { - async fn initialize(&self) -> Result { + async fn initialize(&self) -> Result { Ok(acp::InitializeResponse { is_authenticated: true, }) } - async fn authenticate(&self) -> Result<()> { + async fn authenticate(&self) -> Result<(), acp::Error> { Ok(()) } - async fn cancel_send_message(&self) -> Result<()> { + async fn cancel_send_message(&self) -> Result<(), acp::Error> { Ok(()) } - async fn send_user_message(&self, request: acp::SendUserMessageParams) -> Result<()> { + async fn send_user_message( + &self, + request: acp::SendUserMessageParams, + ) -> Result<(), acp::Error> { let mut cx = self.cx.clone(); let handler = self .server @@ -1589,7 +1858,7 @@ mod tests { if let Some(handler) = handler { handler(request, self.server.clone(), self.cx.clone()).await } else { - anyhow::bail!("No handler for on_user_message") + Err(anyhow::anyhow!("No handler for on_user_message").into()) } } } @@ -1624,7 +1893,7 @@ mod tests { handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity, AsyncApp) -> F + 'static, ) where - F: Future> + 'static, + F: Future> + 'static, { self.on_user_message .replace(Rc::new(move |request, server, cx| { diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 2e3bf54837..3db5e52a0a 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,33 +1,37 @@ +use std::collections::BTreeMap; use std::path::Path; use std::rc::Rc; use std::sync::Arc; use std::time::Duration; use agentic_coding_protocol::{self as acp}; +use assistant_tool::ActionLog; +use buffer_diff::BufferDiff; use collections::{HashMap, HashSet}; use editor::{ AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode, - EditorStyle, MinimapVisibility, MultiBuffer, + EditorStyle, MinimapVisibility, MultiBuffer, PathKey, }; use file_icons::FileIcons; use futures::channel::oneshot; use gpui::{ - Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, Focusable, - Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement, Subscription, TextStyle, - TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, div, list, percentage, - prelude::*, pulsating_between, + Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, + FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement, + Subscription, Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, + Window, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*, + pulsating_between, }; -use gpui::{FocusHandle, Task}; use language::language_settings::SoftWrap; use language::{Buffer, Language}; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use parking_lot::Mutex; use project::Project; use settings::Settings as _; +use text::Anchor; use theme::ThemeSettings; -use ui::{Disclosure, Tooltip, prelude::*}; +use ui::{Disclosure, Divider, DividerColor, KeyBinding, Tooltip, prelude::*}; use util::ResultExt; -use workspace::Workspace; +use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; use ::acp::{ @@ -38,6 +42,8 @@ use ::acp::{ use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; use crate::acp::message_history::MessageHistory; +use crate::agent_diff::AgentDiff; +use crate::{AgentDiffPane, Follow, KeepAll, OpenAgentDiff, RejectAll}; const RESPONSE_PADDING_X: Pixels = px(19.); @@ -53,6 +59,7 @@ pub struct AcpThreadView { auth_task: Option>, expanded_tool_calls: HashSet, expanded_thinking_blocks: HashSet<(usize, usize)>, + edits_expanded: bool, message_history: MessageHistory, } @@ -62,7 +69,7 @@ enum ThreadState { }, Ready { thread: Entity, - _subscription: Subscription, + _subscription: [Subscription; 2], }, LoadError(LoadError), Unauthenticated { @@ -136,9 +143,9 @@ impl AcpThreadView { ); Self { - workspace, + workspace: workspace.clone(), project: project.clone(), - thread_state: Self::initial_state(project, window, cx), + thread_state: Self::initial_state(workspace, project, window, cx), message_editor, mention_set, diff_editors: Default::default(), @@ -147,11 +154,13 @@ impl AcpThreadView { auth_task: None, expanded_tool_calls: HashSet::default(), expanded_thinking_blocks: HashSet::default(), + edits_expanded: false, message_history: MessageHistory::new(), } } fn initial_state( + workspace: WeakEntity, project: Entity, window: &mut Window, cx: &mut Context, @@ -219,15 +228,23 @@ impl AcpThreadView { this.update_in(cx, |this, window, cx| { match result { Ok(()) => { - let subscription = + let thread_subscription = cx.subscribe_in(&thread, window, Self::handle_thread_event); + + let action_log = thread.read(cx).action_log().clone(); + let action_log_subscription = + cx.observe(&action_log, |_, _, cx| cx.notify()); + this.list_state .splice(0..0, thread.read(cx).entries().len()); + AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); + this.thread_state = ThreadState::Ready { thread, - _subscription: subscription, + _subscription: [thread_subscription, action_log_subscription], }; + cx.notify(); } Err(err) => { @@ -250,7 +267,7 @@ impl AcpThreadView { cx.notify(); } - fn thread(&self) -> Option<&Entity> { + pub fn thread(&self) -> Option<&Entity> { match &self.thread_state { ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => { Some(thread) @@ -281,7 +298,6 @@ impl AcpThreadView { let mut ix = 0; let mut chunks: Vec = Vec::new(); - let project = self.project.clone(); self.message_editor.update(cx, |editor, cx| { let text = editor.text(cx); @@ -377,6 +393,33 @@ impl AcpThreadView { ); } + fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context) { + if let Some(thread) = self.thread() { + AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err(); + } + } + + fn open_edited_buffer( + &mut self, + buffer: &Entity, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.thread() else { + return; + }; + + let Some(diff) = + AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err() + else { + return; + }; + + diff.update(cx, |diff, cx| { + diff.move_to_path(PathKey::for_buffer(&buffer, cx), window, cx) + }) + } + fn set_draft_message( message_editor: Entity, mention_set: Arc>, @@ -464,7 +507,8 @@ impl AcpThreadView { let count = self.list_state.item_count(); match event { AcpThreadEvent::NewEntry => { - self.sync_thread_entry_view(thread.read(cx).entries().len() - 1, window, cx); + let index = thread.read(cx).entries().len() - 1; + self.sync_thread_entry_view(index, window, cx); self.list_state.splice(count..count, 1); } AcpThreadEvent::EntryUpdated(index) => { @@ -537,15 +581,7 @@ impl AcpThreadView { fn entry_diff_multibuffer(&self, entry_ix: usize, cx: &App) -> Option> { let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - if let AgentThreadEntry::ToolCall(ToolCall { - content: Some(ToolCallContent::Diff { diff }), - .. - }) = &entry - { - Some(diff.multibuffer.clone()) - } else { - None - } + entry.diff().map(|diff| diff.multibuffer.clone()) } fn authenticate(&mut self, window: &mut Window, cx: &mut Context) { @@ -566,7 +602,8 @@ impl AcpThreadView { Markdown::new(format!("Error: {err}").into(), None, None, cx) })) } else { - this.thread_state = Self::initial_state(project.clone(), window, cx) + this.thread_state = + Self::initial_state(this.workspace.clone(), project.clone(), window, cx) } this.auth_task.take() }) @@ -1529,6 +1566,357 @@ impl AcpThreadView { container.into_any() } + fn render_edits_bar( + &self, + thread_entity: &Entity, + window: &mut Window, + cx: &Context, + ) -> Option { + let thread = thread_entity.read(cx); + let action_log = thread.action_log(); + let changed_buffers = action_log.read(cx).changed_buffers(cx); + + if changed_buffers.is_empty() { + return None; + } + + let editor_bg_color = cx.theme().colors().editor_background; + let active_color = cx.theme().colors().element_selected; + let bg_edit_files_disclosure = editor_bg_color.blend(active_color.opacity(0.3)); + + let pending_edits = thread.has_pending_edit_tool_calls(); + let expanded = self.edits_expanded; + + v_flex() + .mt_1() + .mx_2() + .bg(bg_edit_files_disclosure) + .border_1() + .border_b_0() + .border_color(cx.theme().colors().border) + .rounded_t_md() + .shadow(vec![gpui::BoxShadow { + color: gpui::black().opacity(0.15), + offset: point(px(1.), px(-1.)), + blur_radius: px(3.), + spread_radius: px(0.), + }]) + .child(self.render_edits_bar_summary( + action_log, + &changed_buffers, + expanded, + pending_edits, + window, + cx, + )) + .when(expanded, |parent| { + parent.child(self.render_edits_bar_files( + action_log, + &changed_buffers, + pending_edits, + cx, + )) + }) + .into_any() + .into() + } + + fn render_edits_bar_summary( + &self, + action_log: &Entity, + changed_buffers: &BTreeMap, Entity>, + expanded: bool, + pending_edits: bool, + window: &mut Window, + cx: &Context, + ) -> Div { + const EDIT_NOT_READY_TOOLTIP_LABEL: &str = "Wait until file edits are complete."; + + let focus_handle = self.focus_handle(cx); + + h_flex() + .p_1() + .justify_between() + .when(expanded, |this| { + this.border_b_1().border_color(cx.theme().colors().border) + }) + .child( + h_flex() + .id("edits-container") + .cursor_pointer() + .w_full() + .gap_1() + .child(Disclosure::new("edits-disclosure", expanded)) + .map(|this| { + if pending_edits { + this.child( + Label::new(format!( + "Editing {} {}…", + changed_buffers.len(), + if changed_buffers.len() == 1 { + "file" + } else { + "files" + } + )) + .color(Color::Muted) + .size(LabelSize::Small) + .with_animation( + "edit-label", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.3, 0.7)), + |label, delta| label.alpha(delta), + ), + ) + } else { + this.child( + Label::new("Edits") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child(Label::new("•").size(LabelSize::XSmall).color(Color::Muted)) + .child( + Label::new(format!( + "{} {}", + changed_buffers.len(), + if changed_buffers.len() == 1 { + "file" + } else { + "files" + } + )) + .size(LabelSize::Small) + .color(Color::Muted), + ) + } + }) + .on_click(cx.listener(|this, _, _, cx| { + this.edits_expanded = !this.edits_expanded; + cx.notify(); + })), + ) + .child( + h_flex() + .gap_1() + .child( + IconButton::new("review-changes", IconName::ListTodo) + .icon_size(IconSize::Small) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "Review Changes", + &OpenAgentDiff, + &focus_handle, + window, + cx, + ) + } + }) + .on_click(cx.listener(|_, _, window, cx| { + window.dispatch_action(OpenAgentDiff.boxed_clone(), cx); + })), + ) + .child(Divider::vertical().color(DividerColor::Border)) + .child( + Button::new("reject-all-changes", "Reject All") + .label_size(LabelSize::Small) + .disabled(pending_edits) + .when(pending_edits, |this| { + this.tooltip(Tooltip::text(EDIT_NOT_READY_TOOLTIP_LABEL)) + }) + .key_binding( + KeyBinding::for_action_in( + &RejectAll, + &focus_handle.clone(), + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(10.))), + ) + .on_click({ + let action_log = action_log.clone(); + cx.listener(move |_, _, _, cx| { + action_log.update(cx, |action_log, cx| { + action_log.reject_all_edits(cx).detach(); + }) + }) + }), + ) + .child( + Button::new("keep-all-changes", "Keep All") + .label_size(LabelSize::Small) + .disabled(pending_edits) + .when(pending_edits, |this| { + this.tooltip(Tooltip::text(EDIT_NOT_READY_TOOLTIP_LABEL)) + }) + .key_binding( + KeyBinding::for_action_in(&KeepAll, &focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(10.))), + ) + .on_click({ + let action_log = action_log.clone(); + cx.listener(move |_, _, _, cx| { + action_log.update(cx, |action_log, cx| { + action_log.keep_all_edits(cx); + }) + }) + }), + ), + ) + } + + fn render_edits_bar_files( + &self, + action_log: &Entity, + changed_buffers: &BTreeMap, Entity>, + pending_edits: bool, + cx: &Context, + ) -> Div { + let editor_bg_color = cx.theme().colors().editor_background; + + v_flex().children(changed_buffers.into_iter().enumerate().flat_map( + |(index, (buffer, _diff))| { + let file = buffer.read(cx).file()?; + let path = file.path(); + + let file_path = path.parent().and_then(|parent| { + let parent_str = parent.to_string_lossy(); + + if parent_str.is_empty() { + None + } else { + Some( + Label::new(format!("/{}{}", parent_str, std::path::MAIN_SEPARATOR_STR)) + .color(Color::Muted) + .size(LabelSize::XSmall) + .buffer_font(cx), + ) + } + }); + + let file_name = path.file_name().map(|name| { + Label::new(name.to_string_lossy().to_string()) + .size(LabelSize::XSmall) + .buffer_font(cx) + }); + + let file_icon = FileIcons::get_icon(&path, cx) + .map(Icon::from_path) + .map(|icon| icon.color(Color::Muted).size(IconSize::Small)) + .unwrap_or_else(|| { + Icon::new(IconName::File) + .color(Color::Muted) + .size(IconSize::Small) + }); + + let overlay_gradient = linear_gradient( + 90., + linear_color_stop(editor_bg_color, 1.), + linear_color_stop(editor_bg_color.opacity(0.2), 0.), + ); + + let element = h_flex() + .group("edited-code") + .id(("file-container", index)) + .relative() + .py_1() + .pl_2() + .pr_1() + .gap_2() + .justify_between() + .bg(editor_bg_color) + .when(index < changed_buffers.len() - 1, |parent| { + parent.border_color(cx.theme().colors().border).border_b_1() + }) + .child( + h_flex() + .id(("file-name", index)) + .pr_8() + .gap_1p5() + .max_w_full() + .overflow_x_scroll() + .child(file_icon) + .child(h_flex().gap_0p5().children(file_name).children(file_path)) + .on_click({ + let buffer = buffer.clone(); + cx.listener(move |this, _, window, cx| { + this.open_edited_buffer(&buffer, window, cx); + }) + }), + ) + .child( + h_flex() + .gap_1() + .visible_on_hover("edited-code") + .child( + Button::new("review", "Review") + .label_size(LabelSize::Small) + .on_click({ + let buffer = buffer.clone(); + cx.listener(move |this, _, window, cx| { + this.open_edited_buffer(&buffer, window, cx); + }) + }), + ) + .child(Divider::vertical().color(DividerColor::BorderVariant)) + .child( + Button::new("reject-file", "Reject") + .label_size(LabelSize::Small) + .disabled(pending_edits) + .on_click({ + let buffer = buffer.clone(); + let action_log = action_log.clone(); + move |_, _, cx| { + action_log.update(cx, |action_log, cx| { + action_log + .reject_edits_in_ranges( + buffer.clone(), + vec![Anchor::MIN..Anchor::MAX], + cx, + ) + .detach_and_log_err(cx); + }) + } + }), + ) + .child( + Button::new("keep-file", "Keep") + .label_size(LabelSize::Small) + .disabled(pending_edits) + .on_click({ + let buffer = buffer.clone(); + let action_log = action_log.clone(); + move |_, _, cx| { + action_log.update(cx, |action_log, cx| { + action_log.keep_edits_in_range( + buffer.clone(), + Anchor::MIN..Anchor::MAX, + cx, + ); + }) + } + }), + ), + ) + .child( + div() + .id("gradient-overlay") + .absolute() + .h_full() + .w_12() + .top_0() + .bottom_0() + .right(px(152.)) + .bg(overlay_gradient), + ); + + Some(element) + }, + )) + } + fn render_message_editor(&mut self, cx: &mut Context) -> AnyElement { let settings = ThemeSettings::get_global(cx); let font_size = TextSize::Small @@ -1559,6 +1947,76 @@ impl AcpThreadView { .into_any() } + fn render_send_button(&self, cx: &mut Context) -> AnyElement { + if self.thread().map_or(true, |thread| { + thread.read(cx).status() == ThreadStatus::Idle + }) { + let is_editor_empty = self.message_editor.read(cx).is_empty(cx); + IconButton::new("send-message", IconName::Send) + .icon_color(Color::Accent) + .style(ButtonStyle::Filled) + .disabled(self.thread().is_none() || is_editor_empty) + .on_click(cx.listener(|this, _, window, cx| { + this.chat(&Chat, window, cx); + })) + .when(!is_editor_empty, |button| { + button.tooltip(move |window, cx| Tooltip::for_action("Send", &Chat, window, cx)) + }) + .when(is_editor_empty, |button| { + button.tooltip(Tooltip::text("Type a message to submit")) + }) + .into_any_element() + } else { + IconButton::new("stop-generation", IconName::StopFilled) + .icon_color(Color::Error) + .style(ButtonStyle::Tinted(ui::TintColor::Error)) + .tooltip(move |window, cx| { + Tooltip::for_action("Stop Generation", &editor::actions::Cancel, window, cx) + }) + .on_click(cx.listener(|this, _event, _, cx| this.cancel(cx))) + .into_any_element() + } + } + + fn render_follow_toggle(&self, cx: &mut Context) -> impl IntoElement { + let following = self + .workspace + .read_with(cx, |workspace, _| { + workspace.is_being_followed(CollaboratorId::Agent) + }) + .unwrap_or(false); + + IconButton::new("follow-agent", IconName::Crosshair) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .toggle_state(following) + .selected_icon_color(Some(Color::Custom(cx.theme().players().agent().cursor))) + .tooltip(move |window, cx| { + if following { + Tooltip::for_action("Stop Following Agent", &Follow, window, cx) + } else { + Tooltip::with_meta( + "Follow Agent", + Some(&Follow), + "Track the agent's location as it reads and edits files.", + window, + cx, + ) + } + }) + .on_click(cx.listener(move |this, _, window, cx| { + this.workspace + .update(cx, |workspace, cx| { + if following { + workspace.unfollow(CollaboratorId::Agent, window, cx); + } else { + workspace.follow(CollaboratorId::Agent, window, cx); + } + }) + .ok(); + })) + } + fn render_markdown(&self, markdown: Entity, style: MarkdownStyle) -> MarkdownElement { let workspace = self.workspace.clone(); MarkdownElement::new(markdown, style).on_url_click(move |text, window, cx| { @@ -1673,10 +2131,6 @@ impl Focusable for AcpThreadView { impl Render for AcpThreadView { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let text = self.message_editor.read(cx).text(cx); - let is_editor_empty = text.is_empty(); - let focus_handle = self.message_editor.focus_handle(cx); - let open_as_markdown = IconButton::new("open-as-markdown", IconName::DocumentText) .icon_size(IconSize::XSmall) .icon_color(Color::Ignored) @@ -1702,6 +2156,7 @@ impl Render for AcpThreadView { .on_action(cx.listener(Self::chat)) .on_action(cx.listener(Self::previous_history_message)) .on_action(cx.listener(Self::next_history_message)) + .on_action(cx.listener(Self::open_agent_diff)) .child(match &self.thread_state { ThreadState::Unauthenticated { .. } => v_flex() .p_2() @@ -1755,6 +2210,7 @@ impl Render for AcpThreadView { .child(LoadingLabel::new("").size(LabelSize::Small)) .into(), }) + .children(self.render_edits_bar(&thread, window, cx)) } else { this.child(self.render_empty_state(false, cx)) } @@ -1782,47 +2238,12 @@ impl Render for AcpThreadView { .border_t_1() .border_color(cx.theme().colors().border) .child(self.render_message_editor(cx)) - .child({ - let thread = self.thread(); - - h_flex().justify_end().child( - if thread.map_or(true, |thread| { - thread.read(cx).status() == ThreadStatus::Idle - }) { - IconButton::new("send-message", IconName::Send) - .icon_color(Color::Accent) - .style(ButtonStyle::Filled) - .disabled(thread.is_none() || is_editor_empty) - .on_click({ - let focus_handle = focus_handle.clone(); - move |_event, window, cx| { - focus_handle.dispatch_action(&Chat, window, cx); - } - }) - .when(!is_editor_empty, |button| { - button.tooltip(move |window, cx| { - Tooltip::for_action("Send", &Chat, window, cx) - }) - }) - .when(is_editor_empty, |button| { - button.tooltip(Tooltip::text("Type a message to submit")) - }) - } else { - IconButton::new("stop-generation", IconName::StopFilled) - .icon_color(Color::Error) - .style(ButtonStyle::Tinted(ui::TintColor::Error)) - .tooltip(move |window, cx| { - Tooltip::for_action( - "Stop Generation", - &editor::actions::Cancel, - window, - cx, - ) - }) - .on_click(cx.listener(|this, _event, _, cx| this.cancel(cx))) - }, - ) - }), + .child( + h_flex() + .justify_between() + .child(self.render_follow_toggle(cx)) + .child(self.render_send_button(cx)), + ), ) } } diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 1a0f3ff27d..31fb0dd69f 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1,7 +1,9 @@ use crate::{Keep, KeepAll, OpenAgentDiff, Reject, RejectAll}; -use agent::{Thread, ThreadEvent}; +use acp::{AcpThread, AcpThreadEvent}; +use agent::{Thread, ThreadEvent, ThreadSummary}; use agent_settings::AgentSettings; use anyhow::Result; +use assistant_tool::ActionLog; use buffer_diff::DiffHunkStatus; use collections::{HashMap, HashSet}; use editor::{ @@ -41,16 +43,108 @@ use zed_actions::assistant::ToggleFocus; pub struct AgentDiffPane { multibuffer: Entity, editor: Entity, - thread: Entity, + thread: AgentDiffThread, focus_handle: FocusHandle, workspace: WeakEntity, title: SharedString, _subscriptions: Vec, } +#[derive(PartialEq, Eq, Clone)] +pub enum AgentDiffThread { + Native(Entity), + AcpThread(Entity), +} + +impl AgentDiffThread { + fn project(&self, cx: &App) -> Entity { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).project().clone(), + AgentDiffThread::AcpThread(thread) => thread.read(cx).project().clone(), + } + } + fn action_log(&self, cx: &App) -> Entity { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).action_log().clone(), + AgentDiffThread::AcpThread(thread) => thread.read(cx).action_log().clone(), + } + } + + fn summary(&self, cx: &App) -> ThreadSummary { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).summary().clone(), + AgentDiffThread::AcpThread(thread) => ThreadSummary::Ready(thread.read(cx).title()), + } + } + + fn is_generating(&self, cx: &App) -> bool { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).is_generating(), + AgentDiffThread::AcpThread(thread) => { + thread.read(cx).status() == acp::ThreadStatus::Generating + } + } + } + + fn has_pending_edit_tool_uses(&self, cx: &App) -> bool { + match self { + AgentDiffThread::Native(thread) => thread.read(cx).has_pending_edit_tool_uses(), + AgentDiffThread::AcpThread(thread) => thread.read(cx).has_pending_edit_tool_calls(), + } + } + + fn downgrade(&self) -> WeakAgentDiffThread { + match self { + AgentDiffThread::Native(thread) => WeakAgentDiffThread::Native(thread.downgrade()), + AgentDiffThread::AcpThread(thread) => { + WeakAgentDiffThread::AcpThread(thread.downgrade()) + } + } + } +} + +impl From> for AgentDiffThread { + fn from(entity: Entity) -> Self { + AgentDiffThread::Native(entity) + } +} + +impl From> for AgentDiffThread { + fn from(entity: Entity) -> Self { + AgentDiffThread::AcpThread(entity) + } +} + +#[derive(PartialEq, Eq, Clone)] +pub enum WeakAgentDiffThread { + Native(WeakEntity), + AcpThread(WeakEntity), +} + +impl WeakAgentDiffThread { + pub fn upgrade(&self) -> Option { + match self { + WeakAgentDiffThread::Native(weak) => weak.upgrade().map(AgentDiffThread::Native), + WeakAgentDiffThread::AcpThread(weak) => weak.upgrade().map(AgentDiffThread::AcpThread), + } + } +} + +impl From> for WeakAgentDiffThread { + fn from(entity: WeakEntity) -> Self { + WeakAgentDiffThread::Native(entity) + } +} + +impl From> for WeakAgentDiffThread { + fn from(entity: WeakEntity) -> Self { + WeakAgentDiffThread::AcpThread(entity) + } +} + impl AgentDiffPane { pub fn deploy( - thread: Entity, + thread: impl Into, workspace: WeakEntity, window: &mut Window, cx: &mut App, @@ -61,14 +155,16 @@ impl AgentDiffPane { } pub fn deploy_in_workspace( - thread: Entity, + thread: impl Into, workspace: &mut Workspace, window: &mut Window, cx: &mut Context, ) -> Entity { + let thread = thread.into(); let existing_diff = workspace .items_of_type::(cx) .find(|diff| diff.read(cx).thread == thread); + if let Some(existing_diff) = existing_diff { workspace.activate_item(&existing_diff, true, true, window, cx); existing_diff @@ -81,7 +177,7 @@ impl AgentDiffPane { } pub fn new( - thread: Entity, + thread: AgentDiffThread, workspace: WeakEntity, window: &mut Window, cx: &mut Context, @@ -89,7 +185,7 @@ impl AgentDiffPane { let focus_handle = cx.focus_handle(); let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadWrite)); - let project = thread.read(cx).project().clone(); + let project = thread.project(cx).clone(); let editor = cx.new(|cx| { let mut editor = Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx); @@ -100,16 +196,27 @@ impl AgentDiffPane { editor }); - let action_log = thread.read(cx).action_log().clone(); + let action_log = thread.action_log(cx).clone(); + let mut this = Self { - _subscriptions: vec![ - cx.observe_in(&action_log, window, |this, _action_log, window, cx| { - this.update_excerpts(window, cx) - }), - cx.subscribe(&thread, |this, _thread, event, cx| { - this.handle_thread_event(event, cx) - }), - ], + _subscriptions: [ + Some( + cx.observe_in(&action_log, window, |this, _action_log, window, cx| { + this.update_excerpts(window, cx) + }), + ), + match &thread { + AgentDiffThread::Native(thread) => { + Some(cx.subscribe(&thread, |this, _thread, event, cx| { + this.handle_thread_event(event, cx) + })) + } + AgentDiffThread::AcpThread(_) => None, + }, + ] + .into_iter() + .flatten() + .collect(), title: SharedString::default(), multibuffer, editor, @@ -123,8 +230,7 @@ impl AgentDiffPane { } fn update_excerpts(&mut self, window: &mut Window, cx: &mut Context) { - let thread = self.thread.read(cx); - let changed_buffers = thread.action_log().read(cx).changed_buffers(cx); + let changed_buffers = self.thread.action_log(cx).read(cx).changed_buffers(cx); let mut paths_to_delete = self.multibuffer.read(cx).paths().collect::>(); for (buffer, diff_handle) in changed_buffers { @@ -211,7 +317,7 @@ impl AgentDiffPane { } fn update_title(&mut self, cx: &mut Context) { - let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes"); + let new_title = self.thread.summary(cx).unwrap_or("Agent Changes"); if new_title != self.title { self.title = new_title; cx.emit(EditorEvent::TitleChanged); @@ -275,14 +381,15 @@ impl AgentDiffPane { fn keep_all(&mut self, _: &KeepAll, _window: &mut Window, cx: &mut Context) { self.thread - .update(cx, |thread, cx| thread.keep_all_edits(cx)); + .action_log(cx) + .update(cx, |action_log, cx| action_log.keep_all_edits(cx)) } } fn keep_edits_in_selection( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut Context, ) { @@ -297,7 +404,7 @@ fn keep_edits_in_selection( fn reject_edits_in_selection( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut Context, ) { @@ -311,7 +418,7 @@ fn reject_edits_in_selection( fn keep_edits_in_ranges( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + thread: &AgentDiffThread, ranges: Vec>, window: &mut Window, cx: &mut Context, @@ -326,8 +433,8 @@ fn keep_edits_in_ranges( for hunk in &diff_hunks_in_ranges { let buffer = multibuffer.read(cx).buffer(hunk.buffer_id); if let Some(buffer) = buffer { - thread.update(cx, |thread, cx| { - thread.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx) + thread.action_log(cx).update(cx, |action_log, cx| { + action_log.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx) }); } } @@ -336,7 +443,7 @@ fn keep_edits_in_ranges( fn reject_edits_in_ranges( editor: &mut Editor, buffer_snapshot: &MultiBufferSnapshot, - thread: &Entity, + thread: &AgentDiffThread, ranges: Vec>, window: &mut Window, cx: &mut Context, @@ -362,8 +469,9 @@ fn reject_edits_in_ranges( for (buffer, ranges) in ranges_by_buffer { thread - .update(cx, |thread, cx| { - thread.reject_edits_in_ranges(buffer, ranges, cx) + .action_log(cx) + .update(cx, |action_log, cx| { + action_log.reject_edits_in_ranges(buffer, ranges, cx) }) .detach_and_log_err(cx); } @@ -461,7 +569,7 @@ impl Item for AgentDiffPane { } fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement { - let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes"); + let summary = self.thread.summary(cx).unwrap_or("Agent Changes"); Label::new(format!("Review: {}", summary)) .color(if params.selected { Color::Default @@ -641,7 +749,7 @@ impl Render for AgentDiffPane { } } -fn diff_hunk_controls(thread: &Entity) -> editor::RenderDiffHunkControlsFn { +fn diff_hunk_controls(thread: &AgentDiffThread) -> editor::RenderDiffHunkControlsFn { let thread = thread.clone(); Arc::new( @@ -676,7 +784,7 @@ fn render_diff_hunk_controls( hunk_range: Range, is_created_file: bool, line_height: Pixels, - thread: &Entity, + thread: &AgentDiffThread, editor: &Entity, window: &mut Window, cx: &mut App, @@ -1112,11 +1220,8 @@ impl Render for AgentDiffToolbar { return Empty.into_any(); }; - let has_pending_edit_tool_use = agent_diff - .read(cx) - .thread - .read(cx) - .has_pending_edit_tool_uses(); + let has_pending_edit_tool_use = + agent_diff.read(cx).thread.has_pending_edit_tool_uses(cx); if has_pending_edit_tool_use { return div().px_2().child(spinner_icon).into_any(); @@ -1187,8 +1292,8 @@ pub enum EditorState { } struct WorkspaceThread { - thread: WeakEntity, - _thread_subscriptions: [Subscription; 2], + thread: WeakAgentDiffThread, + _thread_subscriptions: (Subscription, Subscription), singleton_editors: HashMap, HashMap, Subscription>>, _settings_subscription: Subscription, _workspace_subscription: Option, @@ -1212,23 +1317,23 @@ impl AgentDiff { pub fn set_active_thread( workspace: &WeakEntity, - thread: &Entity, + thread: impl Into, window: &mut Window, cx: &mut App, ) { Self::global(cx).update(cx, |this, cx| { - this.register_active_thread_impl(workspace, thread, window, cx); + this.register_active_thread_impl(workspace, thread.into(), window, cx); }); } fn register_active_thread_impl( &mut self, workspace: &WeakEntity, - thread: &Entity, + thread: AgentDiffThread, window: &mut Window, cx: &mut Context, ) { - let action_log = thread.read(cx).action_log().clone(); + let action_log = thread.action_log(cx).clone(); let action_log_subscription = cx.observe_in(&action_log, window, { let workspace = workspace.clone(); @@ -1237,17 +1342,25 @@ impl AgentDiff { } }); - let thread_subscription = cx.subscribe_in(&thread, window, { - let workspace = workspace.clone(); - move |this, _thread, event, window, cx| { - this.handle_thread_event(&workspace, event, window, cx) - } - }); + let thread_subscription = match &thread { + AgentDiffThread::Native(thread) => cx.subscribe_in(&thread, window, { + let workspace = workspace.clone(); + move |this, _thread, event, window, cx| { + this.handle_native_thread_event(&workspace, event, window, cx) + } + }), + AgentDiffThread::AcpThread(thread) => cx.subscribe_in(&thread, window, { + let workspace = workspace.clone(); + move |this, thread, event, window, cx| { + this.handle_acp_thread_event(&workspace, thread, event, window, cx) + } + }), + }; if let Some(workspace_thread) = self.workspace_threads.get_mut(&workspace) { // replace thread and action log subscription, but keep editors workspace_thread.thread = thread.downgrade(); - workspace_thread._thread_subscriptions = [action_log_subscription, thread_subscription]; + workspace_thread._thread_subscriptions = (action_log_subscription, thread_subscription); self.update_reviewing_editors(&workspace, window, cx); return; } @@ -1272,7 +1385,7 @@ impl AgentDiff { workspace.clone(), WorkspaceThread { thread: thread.downgrade(), - _thread_subscriptions: [action_log_subscription, thread_subscription], + _thread_subscriptions: (action_log_subscription, thread_subscription), singleton_editors: HashMap::default(), _settings_subscription: settings_subscription, _workspace_subscription: workspace_subscription, @@ -1319,7 +1432,7 @@ impl AgentDiff { fn register_review_action( workspace: &mut Workspace, - review: impl Fn(&Entity, &Entity, &mut Window, &mut App) -> PostReviewState + review: impl Fn(&Entity, &AgentDiffThread, &mut Window, &mut App) -> PostReviewState + 'static, this: &Entity, ) { @@ -1338,7 +1451,7 @@ impl AgentDiff { }); } - fn handle_thread_event( + fn handle_native_thread_event( &mut self, workspace: &WeakEntity, event: &ThreadEvent, @@ -1380,6 +1493,40 @@ impl AgentDiff { } } + fn handle_acp_thread_event( + &mut self, + workspace: &WeakEntity, + thread: &Entity, + event: &AcpThreadEvent, + window: &mut Window, + cx: &mut Context, + ) { + match event { + AcpThreadEvent::NewEntry => { + if thread + .read(cx) + .entries() + .last() + .and_then(|entry| entry.diff()) + .is_some() + { + self.update_reviewing_editors(workspace, window, cx); + } + } + AcpThreadEvent::EntryUpdated(ix) => { + if thread + .read(cx) + .entries() + .get(*ix) + .and_then(|entry| entry.diff()) + .is_some() + { + self.update_reviewing_editors(workspace, window, cx); + } + } + } + } + fn handle_workspace_event( &mut self, workspace: &Entity, @@ -1485,7 +1632,7 @@ impl AgentDiff { return; }; - let action_log = thread.read(cx).action_log(); + let action_log = thread.action_log(cx); let changed_buffers = action_log.read(cx).changed_buffers(cx); let mut unaffected = self.reviewing_editors.clone(); @@ -1510,7 +1657,7 @@ impl AgentDiff { multibuffer.add_diff(diff_handle.clone(), cx); }); - let new_state = if thread.read(cx).is_generating() { + let new_state = if thread.is_generating(cx) { EditorState::Generating } else { EditorState::Reviewing @@ -1606,7 +1753,7 @@ impl AgentDiff { fn keep_all( editor: &Entity, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1626,7 +1773,7 @@ impl AgentDiff { fn reject_all( editor: &Entity, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1646,7 +1793,7 @@ impl AgentDiff { fn keep( editor: &Entity, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1659,7 +1806,7 @@ impl AgentDiff { fn reject( editor: &Entity, - thread: &Entity, + thread: &AgentDiffThread, window: &mut Window, cx: &mut App, ) -> PostReviewState { @@ -1682,7 +1829,7 @@ impl AgentDiff { fn review_in_active_editor( &mut self, workspace: &mut Workspace, - review: impl Fn(&Entity, &Entity, &mut Window, &mut App) -> PostReviewState, + review: impl Fn(&Entity, &AgentDiffThread, &mut Window, &mut App) -> PostReviewState, window: &mut Window, cx: &mut Context, ) -> Option>> { @@ -1703,7 +1850,7 @@ impl AgentDiff { if let PostReviewState::AllReviewed = review(&editor, &thread, window, cx) { if let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton() { - let changed_buffers = thread.read(cx).action_log().read(cx).changed_buffers(cx); + let changed_buffers = thread.action_log(cx).read(cx).changed_buffers(cx); let mut keys = changed_buffers.keys().cycle(); keys.find(|k| *k == &curr_buffer); @@ -1801,8 +1948,9 @@ mod tests { }) .await .unwrap(); - let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); + let thread = + AgentDiffThread::Native(thread_store.update(cx, |store, cx| store.create_thread(cx))); + let action_log = cx.read(|cx| thread.action_log(cx)); let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); @@ -1988,8 +2136,9 @@ mod tests { }); // Set the active thread + let thread = AgentDiffThread::Native(thread); cx.update(|window, cx| { - AgentDiff::set_active_thread(&workspace.downgrade(), &thread, window, cx) + AgentDiff::set_active_thread(&workspace.downgrade(), thread.clone(), window, cx) }); let buffer1 = project diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 8485c5f092..7f3addc1f4 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -8,6 +8,7 @@ use db::kvp::{Dismissable, KEY_VALUE_STORE}; use serde::{Deserialize, Serialize}; use crate::NewAcpThread; +use crate::agent_diff::AgentDiffThread; use crate::language_model_selector::ToggleModelSelector; use crate::{ AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode, @@ -624,7 +625,7 @@ impl AgentPanel { } }; - AgentDiff::set_active_thread(&workspace, &thread, window, cx); + AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); let weak_panel = weak_self.clone(); @@ -845,7 +846,7 @@ impl AgentPanel { let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); self.set_active_view(thread_view, window, cx); - AgentDiff::set_active_thread(&self.workspace, &thread, window, cx); + AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx); } fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context) { @@ -890,11 +891,20 @@ impl AgentPanel { cx.spawn_in(window, async move |this, cx| { let thread_view = cx.new_window_entity(|window, cx| { - crate::acp::AcpThreadView::new(workspace, project, window, cx) + crate::acp::AcpThreadView::new(workspace.clone(), project, window, cx) })?; this.update_in(cx, |this, window, cx| { - this.set_active_view(ActiveView::AcpThread { thread_view }, window, cx); + this.set_active_view( + ActiveView::AcpThread { + thread_view: thread_view.clone(), + }, + window, + cx, + ); }) + .log_err(); + + anyhow::Ok(()) }) .detach(); } @@ -1050,7 +1060,7 @@ impl AgentPanel { let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); self.set_active_view(thread_view, window, cx); - AgentDiff::set_active_thread(&self.workspace, &thread, window, cx); + AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx); } pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context) { @@ -1181,7 +1191,12 @@ impl AgentPanel { let thread = thread.read(cx).thread().clone(); self.workspace .update(cx, |workspace, cx| { - AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx) + AgentDiffPane::deploy_in_workspace( + AgentDiffThread::Native(thread), + workspace, + window, + cx, + ) }) .log_err(); } diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 25c62c5fb3..d2b136f274 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; use std::rc::Rc; use std::sync::Arc; +use crate::agent_diff::AgentDiffThread; use crate::agent_model_selector::AgentModelSelector; use crate::language_model_selector::ToggleModelSelector; use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip}; @@ -475,9 +476,12 @@ impl MessageEditor { window: &mut Window, cx: &mut Context, ) { - if let Ok(diff) = - AgentDiffPane::deploy(self.thread.clone(), self.workspace.clone(), window, cx) - { + if let Ok(diff) = AgentDiffPane::deploy( + AgentDiffThread::Native(self.thread.clone()), + self.workspace.clone(), + window, + cx, + ) { let path_key = multi_buffer::PathKey::for_buffer(&buffer, cx); diff.update(cx, |diff, cx| diff.move_to_path(path_key, window, cx)); } diff --git a/crates/assistant_tool/src/action_log.rs b/crates/assistant_tool/src/action_log.rs index 2071a1f444..e983075cd1 100644 --- a/crates/assistant_tool/src/action_log.rs +++ b/crates/assistant_tool/src/action_log.rs @@ -8,7 +8,7 @@ use language::{Anchor, Buffer, BufferEvent, DiskState, Point, ToPoint}; use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle}; use std::{cmp, ops::Range, sync::Arc}; use text::{Edit, Patch, Rope}; -use util::RangeExt; +use util::{RangeExt, ResultExt as _}; /// Tracks actions performed by tools in a thread pub struct ActionLog { @@ -47,6 +47,10 @@ impl ActionLog { self.edited_since_project_diagnostics_check } + pub fn latest_snapshot(&self, buffer: &Entity) -> Option { + Some(self.tracked_buffers.get(buffer)?.snapshot.clone()) + } + fn track_buffer_internal( &mut self, buffer: Entity, @@ -715,6 +719,22 @@ impl ActionLog { cx.notify(); } + pub fn reject_all_edits(&mut self, cx: &mut Context) -> Task<()> { + let futures = self.changed_buffers(cx).into_keys().map(|buffer| { + let reject = self.reject_edits_in_ranges(buffer, vec![Anchor::MIN..Anchor::MAX], cx); + + async move { + reject.await.log_err(); + } + }); + + let task = futures::future::join_all(futures); + + cx.spawn(async move |_, _| { + task.await; + }) + } + /// Returns the set of buffers that contain edits that haven't been reviewed by the user. pub fn changed_buffers(&self, cx: &App) -> BTreeMap, Entity> { self.tracked_buffers