ACP follow (#34235)

Closes #ISSUE

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Anthony Eid <hello@anthonyeid.me>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Conrad Irwin 2025-07-11 09:38:42 -06:00 committed by GitHub
parent 496bf0ec43
commit 993e0f55ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1090 additions and 208 deletions

View file

@ -2,14 +2,19 @@ pub use acp::ToolCallId;
use agent_servers::AgentServer;
use agentic_coding_protocol::{self as acp, UserMessageChunk};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::ActionLog;
use buffer_diff::BufferDiff;
use editor::{MultiBuffer, PathKey};
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use itertools::Itertools;
use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _};
use language::{
Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
text_diff,
};
use markdown::Markdown;
use project::Project;
use project::{AgentLocation, Project};
use std::collections::HashMap;
use std::error::Error;
use std::fmt::{Formatter, Write};
use std::{
@ -159,6 +164,18 @@ impl AgentThreadEntry {
Self::ToolCall(too_call) => too_call.to_markdown(cx),
}
}
pub fn diff(&self) -> Option<&Diff> {
if let AgentThreadEntry::ToolCall(ToolCall {
content: Some(ToolCallContent::Diff { diff }),
..
}) = self
{
Some(&diff)
} else {
None
}
}
}
#[derive(Debug)]
@ -168,6 +185,7 @@ pub struct ToolCall {
pub icon: IconName,
pub content: Option<ToolCallContent>,
pub status: ToolCallStatus,
pub locations: Vec<acp::ToolCallLocation>,
}
impl ToolCall {
@ -328,6 +346,8 @@ impl ToolCallContent {
pub struct Diff {
pub multibuffer: Entity<MultiBuffer>,
pub path: PathBuf,
pub new_buffer: Entity<Buffer>,
pub old_buffer: Entity<Buffer>,
_task: Task<Result<()>>,
}
@ -362,6 +382,7 @@ impl Diff {
let task = cx.spawn({
let multibuffer = multibuffer.clone();
let path = path.clone();
let new_buffer = new_buffer.clone();
async move |cx| {
diff_task.await?;
@ -401,6 +422,8 @@ impl Diff {
Self {
multibuffer,
path,
new_buffer,
old_buffer,
_task: task,
}
}
@ -421,6 +444,8 @@ pub struct AcpThread {
entries: Vec<AgentThreadEntry>,
title: SharedString,
project: Entity<Project>,
action_log: Entity<ActionLog>,
shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
send_task: Option<Task<()>>,
connection: Arc<acp::AgentConnection>,
child_status: Option<Task<Result<()>>>,
@ -522,7 +547,11 @@ impl AcpThread {
}
});
let action_log = cx.new(|_| ActionLog::new(project.clone()));
Self {
action_log,
shared_buffers: Default::default(),
entries: Default::default(),
title: "ACP Thread".into(),
project,
@ -534,6 +563,14 @@ impl AcpThread {
})
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn project(&self) -> &Entity<Project> {
&self.project
}
#[cfg(test)]
pub fn fake(
stdin: async_pipe::PipeWriter,
@ -558,7 +595,11 @@ impl AcpThread {
}
});
let action_log = cx.new(|_| ActionLog::new(project.clone()));
Self {
action_log,
shared_buffers: Default::default(),
entries: Default::default(),
title: "ACP Thread".into(),
project,
@ -589,6 +630,26 @@ impl AcpThread {
}
}
pub fn has_pending_edit_tool_calls(&self) -> bool {
for entry in self.entries.iter().rev() {
match entry {
AgentThreadEntry::UserMessage(_) => return false,
AgentThreadEntry::ToolCall(ToolCall {
status:
ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Running,
..
},
content: Some(ToolCallContent::Diff { .. }),
..
}) => return true,
AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
}
}
false
}
pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
self.entries.push(entry);
cx.emit(AcpThreadEvent::NewEntry);
@ -644,65 +705,63 @@ impl AcpThread {
pub fn request_tool_call(
&mut self,
label: String,
icon: acp::Icon,
content: Option<acp::ToolCallContent>,
confirmation: acp::ToolCallConfirmation,
tool_call: acp::RequestToolCallConfirmationParams,
cx: &mut Context<Self>,
) -> ToolCallRequest {
let (tx, rx) = oneshot::channel();
let status = ToolCallStatus::WaitingForConfirmation {
confirmation: ToolCallConfirmation::from_acp(
confirmation,
tool_call.confirmation,
self.project.read(cx).languages().clone(),
cx,
),
respond_tx: tx,
};
let id = self.insert_tool_call(label, status, icon, content, cx);
let id = self.insert_tool_call(tool_call.tool_call, status, cx);
ToolCallRequest { id, outcome: rx }
}
pub fn push_tool_call(
&mut self,
label: String,
icon: acp::Icon,
content: Option<acp::ToolCallContent>,
request: acp::PushToolCallParams,
cx: &mut Context<Self>,
) -> acp::ToolCallId {
let status = ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Running,
};
self.insert_tool_call(label, status, icon, content, cx)
self.insert_tool_call(request, status, cx)
}
fn insert_tool_call(
&mut self,
label: String,
tool_call: acp::PushToolCallParams,
status: ToolCallStatus,
icon: acp::Icon,
content: Option<acp::ToolCallContent>,
cx: &mut Context<Self>,
) -> acp::ToolCallId {
let language_registry = self.project.read(cx).languages().clone();
let id = acp::ToolCallId(self.entries.len() as u64);
self.push_entry(
AgentThreadEntry::ToolCall(ToolCall {
id,
label: cx.new(|cx| {
Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
}),
icon: acp_icon_to_ui_icon(icon),
content: content
.map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
status,
let call = ToolCall {
id,
label: cx.new(|cx| {
Markdown::new(
tool_call.label.into(),
Some(language_registry.clone()),
None,
cx,
)
}),
cx,
);
icon: acp_icon_to_ui_icon(tool_call.icon),
content: tool_call
.content
.map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
locations: tool_call.locations,
status,
};
self.push_entry(AgentThreadEntry::ToolCall(call), cx);
id
}
@ -804,14 +863,16 @@ impl AcpThread {
false
}
pub fn initialize(&self) -> impl use<> + Future<Output = Result<acp::InitializeResponse>> {
pub fn initialize(
&self,
) -> impl use<> + Future<Output = Result<acp::InitializeResponse, acp::Error>> {
let connection = self.connection.clone();
async move { Ok(connection.request(acp::InitializeParams).await?) }
async move { connection.request(acp::InitializeParams).await }
}
pub fn authenticate(&self) -> impl use<> + Future<Output = Result<()>> {
pub fn authenticate(&self) -> impl use<> + Future<Output = Result<(), acp::Error>> {
let connection = self.connection.clone();
async move { Ok(connection.request(acp::AuthenticateParams).await?) }
async move { connection.request(acp::AuthenticateParams).await }
}
#[cfg(test)]
@ -819,7 +880,7 @@ impl AcpThread {
&mut self,
message: &str,
cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
) -> BoxFuture<'static, Result<(), acp::Error>> {
self.send(
acp::SendUserMessageParams {
chunks: vec![acp::UserMessageChunk::Text {
@ -834,7 +895,7 @@ impl AcpThread {
&mut self,
message: acp::SendUserMessageParams,
cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> {
) -> BoxFuture<'static, Result<(), acp::Error>> {
let agent = self.connection.clone();
self.push_entry(
AgentThreadEntry::UserMessage(UserMessage::from_acp(
@ -865,7 +926,7 @@ impl AcpThread {
.boxed()
}
pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp::Error>> {
let agent = self.connection.clone();
if self.send_task.take().is_some() {
@ -898,13 +959,123 @@ impl AcpThread {
}
}
}
})
})?;
Ok(())
})
} else {
Task::ready(Ok(()))
}
}
pub fn read_text_file(
&self,
request: acp::ReadTextFileParams,
cx: &mut Context<Self>,
) -> Task<Result<String>> {
let project = self.project.clone();
let action_log = self.action_log.clone();
cx.spawn(async move |this, cx| {
let load = project.update(cx, |project, cx| {
let path = project
.project_path_for_absolute_path(&request.path, cx)
.context("invalid path")?;
anyhow::Ok(project.open_buffer(path, cx))
});
let buffer = load??.await?;
action_log.update(cx, |action_log, cx| {
action_log.buffer_read(buffer.clone(), cx);
})?;
project.update(cx, |project, cx| {
let position = buffer
.read(cx)
.snapshot()
.anchor_before(Point::new(request.line.unwrap_or_default(), 0));
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position,
}),
cx,
);
})?;
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
this.update(cx, |this, _| {
let text = snapshot.text();
this.shared_buffers.insert(buffer.clone(), snapshot);
text
})
})
}
pub fn write_text_file(
&self,
path: PathBuf,
content: String,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let project = self.project.clone();
let action_log = self.action_log.clone();
cx.spawn(async move |this, cx| {
let load = project.update(cx, |project, cx| {
let path = project
.project_path_for_absolute_path(&path, cx)
.context("invalid path")?;
anyhow::Ok(project.open_buffer(path, cx))
});
let buffer = load??.await?;
let snapshot = this.update(cx, |this, cx| {
this.shared_buffers
.get(&buffer)
.cloned()
.unwrap_or_else(|| buffer.read(cx).snapshot())
})?;
let edits = cx
.background_executor()
.spawn(async move {
let old_text = snapshot.text();
text_diff(old_text.as_str(), &content)
.into_iter()
.map(|(range, replacement)| {
(
snapshot.anchor_after(range.start)
..snapshot.anchor_before(range.end),
replacement,
)
})
.collect::<Vec<_>>()
})
.await;
cx.update(|cx| {
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: edits
.last()
.map(|(range, _)| range.end)
.unwrap_or(Anchor::MIN),
}),
cx,
);
});
action_log.update(cx, |action_log, cx| {
action_log.buffer_read(buffer.clone(), cx);
});
buffer.update(cx, |buffer, cx| {
buffer.edit(edits, None, cx);
});
action_log.update(cx, |action_log, cx| {
action_log.buffer_edited(buffer.clone(), cx);
});
})?;
project
.update(cx, |project, cx| project.save_buffer(buffer, cx))?
.await
})
}
pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
self.child_status.take()
}
@ -930,7 +1101,7 @@ impl acp::Client for AcpClientDelegate {
async fn stream_assistant_message_chunk(
&self,
params: acp::StreamAssistantMessageChunkParams,
) -> Result<()> {
) -> Result<(), acp::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
@ -947,45 +1118,37 @@ impl acp::Client for AcpClientDelegate {
async fn request_tool_call_confirmation(
&self,
request: acp::RequestToolCallConfirmationParams,
) -> Result<acp::RequestToolCallConfirmationResponse> {
) -> Result<acp::RequestToolCallConfirmationResponse, acp::Error> {
let cx = &mut self.cx.clone();
let ToolCallRequest { id, outcome } = cx
.update(|cx| {
self.thread.update(cx, |thread, cx| {
thread.request_tool_call(
request.label,
request.icon,
request.content,
request.confirmation,
cx,
)
})
self.thread
.update(cx, |thread, cx| thread.request_tool_call(request, cx))
})?
.context("Failed to update thread")?;
Ok(acp::RequestToolCallConfirmationResponse {
id,
outcome: outcome.await?,
outcome: outcome.await.map_err(acp::Error::into_internal_error)?,
})
}
async fn push_tool_call(
&self,
request: acp::PushToolCallParams,
) -> Result<acp::PushToolCallResponse> {
) -> Result<acp::PushToolCallResponse, acp::Error> {
let cx = &mut self.cx.clone();
let id = cx
.update(|cx| {
self.thread.update(cx, |thread, cx| {
thread.push_tool_call(request.label, request.icon, request.content, cx)
})
self.thread
.update(cx, |thread, cx| thread.push_tool_call(request, cx))
})?
.context("Failed to update thread")?;
Ok(acp::PushToolCallResponse { id })
}
async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<()> {
async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
@ -997,6 +1160,34 @@ impl acp::Client for AcpClientDelegate {
Ok(())
}
async fn read_text_file(
&self,
request: acp::ReadTextFileParams,
) -> Result<acp::ReadTextFileResponse, acp::Error> {
let content = self
.cx
.update(|cx| {
self.thread
.update(cx, |thread, cx| thread.read_text_file(request, cx))
})?
.context("Failed to update thread")?
.await?;
Ok(acp::ReadTextFileResponse { content })
}
async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> {
self.cx
.update(|cx| {
self.thread.update(cx, |thread, cx| {
thread.write_text_file(request.path, request.content, cx)
})
})?
.context("Failed to update thread")?
.await?;
Ok(())
}
}
fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
@ -1100,6 +1291,80 @@ mod tests {
);
}
#[gpui::test]
async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
.await;
let project = Project::test(fs.clone(), [], cx).await;
let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
let (worktree, pathbuf) = project
.update(cx, |project, cx| {
project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
})
.await
.unwrap();
let buffer = project
.update(cx, |project, cx| {
project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
})
.await
.unwrap();
let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
fake_server.update(cx, |fake_server, _| {
fake_server.on_user_message(move |_, server, mut cx| {
let read_file_tx = read_file_tx.clone();
async move {
let content = server
.update(&mut cx, |server, _| {
server.send_to_zed(acp::ReadTextFileParams {
path: path!("/tmp/foo").into(),
line: None,
limit: None,
})
})?
.await
.unwrap();
assert_eq!(content.content, "one\ntwo\nthree\n");
read_file_tx.take().unwrap().send(()).unwrap();
server
.update(&mut cx, |server, _| {
server.send_to_zed(acp::WriteTextFileParams {
path: path!("/tmp/foo").into(),
content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
})
})?
.await
.unwrap();
Ok(())
}
})
});
let request = thread.update(cx, |thread, cx| {
thread.send_raw("Extend the count in /tmp/foo", cx)
});
read_file_rx.await.ok();
buffer.update(cx, |buffer, cx| {
buffer.edit([(0..0, "zero\n".to_string())], None, cx);
});
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"zero\none\ntwo\nthree\nfour\nfive\n"
);
assert_eq!(
String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
"zero\none\ntwo\nthree\nfour\nfive\n"
);
request.await.unwrap();
}
#[gpui::test]
async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
init_test(cx);
@ -1124,6 +1389,7 @@ mod tests {
label: "Fetch".to_string(),
icon: acp::Icon::Globe,
content: None,
locations: vec![],
})
})?
.await
@ -1553,7 +1819,7 @@ mod tests {
acp::SendUserMessageParams,
Entity<FakeAcpServer>,
AsyncApp,
) -> LocalBoxFuture<'static, Result<()>>,
) -> LocalBoxFuture<'static, Result<(), acp::Error>>,
>,
>,
}
@ -1565,21 +1831,24 @@ mod tests {
}
impl acp::Agent for FakeAgent {
async fn initialize(&self) -> Result<acp::InitializeResponse> {
async fn initialize(&self) -> Result<acp::InitializeResponse, acp::Error> {
Ok(acp::InitializeResponse {
is_authenticated: true,
})
}
async fn authenticate(&self) -> Result<()> {
async fn authenticate(&self) -> Result<(), acp::Error> {
Ok(())
}
async fn cancel_send_message(&self) -> Result<()> {
async fn cancel_send_message(&self) -> Result<(), acp::Error> {
Ok(())
}
async fn send_user_message(&self, request: acp::SendUserMessageParams) -> Result<()> {
async fn send_user_message(
&self,
request: acp::SendUserMessageParams,
) -> Result<(), acp::Error> {
let mut cx = self.cx.clone();
let handler = self
.server
@ -1589,7 +1858,7 @@ mod tests {
if let Some(handler) = handler {
handler(request, self.server.clone(), self.cx.clone()).await
} else {
anyhow::bail!("No handler for on_user_message")
Err(anyhow::anyhow!("No handler for on_user_message").into())
}
}
}
@ -1624,7 +1893,7 @@ mod tests {
handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
+ 'static,
) where
F: Future<Output = Result<()>> + 'static,
F: Future<Output = Result<(), acp::Error>> + 'static,
{
self.on_user_message
.replace(Rc::new(move |request, server, cx| {