diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 71827d6948..7a00bd2320 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1,19 +1,18 @@ mod connection; +mod diff; + pub use connection::*; +pub use diff::*; use agent_client_protocol as acp; use anyhow::{Context as _, Result}; use assistant_tool::ActionLog; -use buffer_diff::BufferDiff; -use editor::{Bias, MultiBuffer, PathKey}; +use editor::Bias; use futures::future::{Fuse, FusedFuture}; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; use itertools::Itertools; -use language::{ - Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point, - text_diff, -}; +use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff}; use markdown::Markdown; use project::{AgentLocation, Project}; use std::collections::HashMap; @@ -141,7 +140,7 @@ impl AgentThreadEntry { } } - pub fn diffs(&self) -> impl Iterator { + pub fn diffs(&self) -> impl Iterator> { if let AgentThreadEntry::ToolCall(call) = self { itertools::Either::Left(call.diffs()) } else { @@ -250,7 +249,7 @@ impl ToolCall { } } - pub fn diffs(&self) -> impl Iterator { + pub fn diffs(&self) -> impl Iterator> { self.content.iter().filter_map(|content| match content { ToolCallContent::ContentBlock { .. } => None, ToolCallContent::Diff { diff } => Some(diff), @@ -390,7 +389,7 @@ impl ContentBlock { #[derive(Debug)] pub enum ToolCallContent { ContentBlock { content: ContentBlock }, - Diff { diff: Diff }, + Diff { diff: Entity }, } impl ToolCallContent { @@ -404,7 +403,7 @@ impl ToolCallContent { content: ContentBlock::new(content, &language_registry, cx), }, acp::ToolCallContent::Diff { diff } => Self::Diff { - diff: Diff::from_acp(diff, language_registry, cx), + diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)), }, } } @@ -412,108 +411,11 @@ impl ToolCallContent { pub fn to_markdown(&self, cx: &App) -> String { match self { Self::ContentBlock { content } => content.to_markdown(cx).to_string(), - Self::Diff { diff } => diff.to_markdown(cx), + Self::Diff { diff } => diff.read(cx).to_markdown(cx), } } } -#[derive(Debug)] -pub struct Diff { - pub multibuffer: Entity, - pub path: PathBuf, - _task: Task>, -} - -impl Diff { - pub fn from_acp( - diff: acp::Diff, - language_registry: Arc, - cx: &mut App, - ) -> Self { - let acp::Diff { - path, - old_text, - new_text, - } = diff; - - let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); - - let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); - let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx)); - let new_buffer_snapshot = new_buffer.read(cx).text_snapshot(); - let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx)); - - let task = cx.spawn({ - let multibuffer = multibuffer.clone(); - let path = path.clone(); - async move |cx| { - let language = language_registry - .language_for_file_path(&path) - .await - .log_err(); - - new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?; - - let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| { - buffer.set_language(language, cx); - buffer.snapshot() - })?; - - buffer_diff - .update(cx, |diff, cx| { - diff.set_base_text( - old_buffer_snapshot, - Some(language_registry), - new_buffer_snapshot, - cx, - ) - })? - .await?; - - multibuffer - .update(cx, |multibuffer, cx| { - let hunk_ranges = { - let buffer = new_buffer.read(cx); - let diff = buffer_diff.read(cx); - diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) - .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) - .collect::>() - }; - - multibuffer.set_excerpts_for_path( - PathKey::for_buffer(&new_buffer, cx), - new_buffer.clone(), - hunk_ranges, - editor::DEFAULT_MULTIBUFFER_CONTEXT, - cx, - ); - multibuffer.add_diff(buffer_diff, cx); - }) - .log_err(); - - anyhow::Ok(()) - } - }); - - Self { - multibuffer, - path, - _task: task, - } - } - - fn to_markdown(&self, cx: &App) -> String { - let buffer_text = self - .multibuffer - .read(cx) - .all_buffers() - .iter() - .map(|buffer| buffer.read(cx).text()) - .join("\n"); - format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text) - } -} - #[derive(Debug, Default)] pub struct Plan { pub entries: Vec, @@ -828,6 +730,21 @@ impl AcpThread { Ok(()) } + pub fn set_tool_call_diff( + &mut self, + tool_call_id: &acp::ToolCallId, + diff: Entity, + cx: &mut Context, + ) -> Result<()> { + let (ix, current_call) = self + .tool_call_mut(tool_call_id) + .context("Tool call not found")?; + current_call.content.clear(); + current_call.content.push(ToolCallContent::Diff { diff }); + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + Ok(()) + } + /// Updates a tool call if id matches an existing entry, otherwise inserts a new one. pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context) { let status = ToolCallStatus::Allowed { diff --git a/crates/acp_thread/src/diff.rs b/crates/acp_thread/src/diff.rs new file mode 100644 index 0000000000..01649eaebb --- /dev/null +++ b/crates/acp_thread/src/diff.rs @@ -0,0 +1,161 @@ +use agent_client_protocol as acp; +use anyhow::Result; +use buffer_diff::BufferDiff; +use editor::{MultiBuffer, PathKey}; +use gpui::{App, AppContext, Context, Entity, Task}; +use itertools::Itertools; +use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _}; +use std::{ + path::{Path, PathBuf}, + sync::Arc, +}; +use util::ResultExt; + +pub enum Diff { + Pending { + multibuffer: Entity, + base_text: Arc, + buffer: Entity, + buffer_diff: Entity, + }, + Ready { + path: PathBuf, + multibuffer: Entity, + _task: Task>, + }, +} + +impl Diff { + pub fn from_acp( + diff: acp::Diff, + language_registry: Arc, + cx: &mut Context, + ) -> Self { + let acp::Diff { + path, + old_text, + new_text, + } = diff; + + let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); + + let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); + let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx)); + let new_buffer_snapshot = new_buffer.read(cx).text_snapshot(); + let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx)); + + let task = cx.spawn({ + let multibuffer = multibuffer.clone(); + let path = path.clone(); + async move |_, cx| { + let language = language_registry + .language_for_file_path(&path) + .await + .log_err(); + + new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?; + + let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| { + buffer.set_language(language, cx); + buffer.snapshot() + })?; + + buffer_diff + .update(cx, |diff, cx| { + diff.set_base_text( + old_buffer_snapshot, + Some(language_registry), + new_buffer_snapshot, + cx, + ) + })? + .await?; + + multibuffer + .update(cx, |multibuffer, cx| { + let hunk_ranges = { + let buffer = new_buffer.read(cx); + let diff = buffer_diff.read(cx); + diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) + .collect::>() + }; + + multibuffer.set_excerpts_for_path( + PathKey::for_buffer(&new_buffer, cx), + new_buffer.clone(), + hunk_ranges, + editor::DEFAULT_MULTIBUFFER_CONTEXT, + cx, + ); + multibuffer.add_diff(buffer_diff, cx); + }) + .log_err(); + + anyhow::Ok(()) + } + }); + + Self::Ready { + multibuffer, + path, + _task: task, + } + } + + pub fn new(buffer: Entity, cx: &mut Context) -> Self { + let buffer_snapshot = buffer.read(cx).snapshot(); + let base_text = buffer_snapshot.text(); + let language_registry = buffer.read(cx).language_registry(); + let text_snapshot = buffer.read(cx).text_snapshot(); + let buffer_diff = cx.new(|cx| { + let mut diff = BufferDiff::new(&text_snapshot, cx); + let _ = diff.set_base_text( + buffer_snapshot.clone(), + language_registry, + text_snapshot, + cx, + ); + diff + }); + + let multibuffer = cx.new(|cx| { + let mut multibuffer = MultiBuffer::without_headers(Capability::ReadOnly); + multibuffer.add_diff(buffer_diff.clone(), cx); + multibuffer + }); + + Self::Pending { + multibuffer, + base_text: Arc::new(base_text), + buffer, + buffer_diff, + } + } + + pub fn multibuffer(&self) -> &Entity { + match self { + Self::Pending { multibuffer, .. } => multibuffer, + Self::Ready { multibuffer, .. } => multibuffer, + } + } + + pub fn to_markdown(&self, cx: &App) -> String { + let buffer_text = self + .multibuffer() + .read(cx) + .all_buffers() + .iter() + .map(|buffer| buffer.read(cx).text()) + .join("\n"); + let path = match self { + Diff::Pending { buffer, .. } => buffer.read(cx).file().map(|file| file.path().as_ref()), + Diff::Ready { path, .. } => Some(path.as_path()), + }; + format!( + "Diff: {}\n```\n{}\n```\n", + path.unwrap_or(Path::new("untitled")).display(), + buffer_text + ) + } +} diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 8a670b2478..4583688ad3 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -562,6 +562,15 @@ impl acp_thread::AgentConnection for NativeAgentConnection { ) })??; } + AgentResponseEvent::ToolCallDiff(tool_call_diff) => { + acp_thread.update(cx, |thread, cx| { + thread.set_tool_call_diff( + &tool_call_diff.tool_call_id, + tool_call_diff.diff, + cx, + ) + })??; + } AgentResponseEvent::Stop(stop_reason) => { log::debug!("Assistant message complete: {:?}", stop_reason); return Ok(acp::PromptResponse { stop_reason }); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 70cbde1449..9be2860eec 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -103,6 +103,7 @@ pub enum AgentResponseEvent { ToolCall(acp::ToolCall), ToolCallUpdate(acp::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), + ToolCallDiff(ToolCallDiff), Stop(acp::StopReason), } @@ -113,6 +114,12 @@ pub struct ToolCallAuthorization { pub response: oneshot::Sender, } +#[derive(Debug)] +pub struct ToolCallDiff { + pub tool_call_id: acp::ToolCallId, + pub diff: Entity, +} + pub struct Thread { messages: Vec, completion_mode: CompletionMode, diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 1bd495d1aa..cf0cf43e33 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -42,7 +42,7 @@ use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; use ::acp_thread::{ - AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff, + AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, }; @@ -731,7 +731,11 @@ impl AcpThreadView { cx: &App, ) -> Option>> { let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - Some(entry.diffs().map(|diff| diff.multibuffer.clone())) + Some( + entry + .diffs() + .map(|diff| diff.read(cx).multibuffer().clone()), + ) } fn authenticate( @@ -1313,10 +1317,9 @@ impl AcpThreadView { Empty.into_any_element() } } - ToolCallContent::Diff { - diff: Diff { multibuffer, .. }, - .. - } => self.render_diff_editor(multibuffer), + ToolCallContent::Diff { diff, .. } => { + self.render_diff_editor(&diff.read(cx).multibuffer()) + } } }