Update to new agent schema (#35578)

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
Ben Brandt 2025-08-04 15:49:41 +02:00 committed by GitHub
parent dea64d3373
commit f17943e4a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 741 additions and 1168 deletions

View file

@ -1,7 +1,5 @@
mod connection;
mod old_acp_support;
pub use connection::*;
pub use old_acp_support::*;
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
@ -391,7 +389,7 @@ impl ToolCallContent {
cx: &mut App,
) -> Self {
match content {
acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock {
acp::ToolCallContent::Content { content } => Self::ContentBlock {
content: ContentBlock::new(content, &language_registry, cx),
},
acp::ToolCallContent::Diff { diff } => Self::Diff {
@ -619,6 +617,7 @@ impl Error for LoadError {}
impl AcpThread {
pub fn new(
title: impl Into<SharedString>,
connection: Rc<dyn AgentConnection>,
project: Entity<Project>,
session_id: acp::SessionId,
@ -631,7 +630,7 @@ impl AcpThread {
shared_buffers: Default::default(),
entries: Default::default(),
plan: Default::default(),
title: connection.name().into(),
title: title.into(),
project,
send_task: None,
connection,
@ -708,14 +707,14 @@ impl AcpThread {
cx: &mut Context<Self>,
) -> Result<()> {
match update {
acp::SessionUpdate::UserMessage(content_block) => {
self.push_user_content_block(content_block, cx);
acp::SessionUpdate::UserMessageChunk { content } => {
self.push_user_content_block(content, cx);
}
acp::SessionUpdate::AgentMessageChunk(content_block) => {
self.push_assistant_content_block(content_block, false, cx);
acp::SessionUpdate::AgentMessageChunk { content } => {
self.push_assistant_content_block(content, false, cx);
}
acp::SessionUpdate::AgentThoughtChunk(content_block) => {
self.push_assistant_content_block(content_block, true, cx);
acp::SessionUpdate::AgentThoughtChunk { content } => {
self.push_assistant_content_block(content, true, cx);
}
acp::SessionUpdate::ToolCall(tool_call) => {
self.upsert_tool_call(tool_call, cx);
@ -984,10 +983,6 @@ impl AcpThread {
cx.notify();
}
pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
self.connection.authenticate(cx)
}
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
&mut self,
@ -1029,7 +1024,7 @@ impl AcpThread {
let result = this
.update(cx, |this, cx| {
this.connection.prompt(
acp::PromptArguments {
acp::PromptRequest {
prompt: message,
session_id: this.session_id.clone(),
},
@ -1239,21 +1234,15 @@ impl AcpThread {
#[cfg(test)]
mod tests {
use super::*;
use agentic_coding_protocol as acp_old;
use anyhow::anyhow;
use async_pipe::{PipeReader, PipeWriter};
use futures::{
channel::mpsc,
future::{LocalBoxFuture, try_join_all},
select,
};
use futures::{channel::mpsc, future::LocalBoxFuture, select};
use gpui::{AsyncApp, TestAppContext, WeakEntity};
use indoc::indoc;
use project::FakeFs;
use rand::Rng as _;
use serde_json::json;
use settings::SettingsStore;
use smol::{future::BoxedLocal, stream::StreamExt as _};
use smol::stream::StreamExt as _;
use std::{cell::RefCell, rc::Rc, time::Duration};
use util::path;
@ -1274,7 +1263,15 @@ mod tests {
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
let (thread, _fake_server) = fake_acp_thread(project, cx);
let connection = Rc::new(FakeAgentConnection::new());
let thread = cx
.spawn(async move |mut cx| {
connection
.new_thread(project, Path::new(path!("/test")), &mut cx)
.await
})
.await
.unwrap();
// Test creating a new user message
thread.update(cx, |thread, cx| {
@ -1354,34 +1351,40 @@ mod tests {
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
let (thread, fake_server) = fake_acp_thread(project, cx);
let connection = Rc::new(FakeAgentConnection::new().on_user_message(
|_, thread, mut cx| {
async move {
thread.update(&mut cx, |thread, cx| {
thread
.handle_session_update(
acp::SessionUpdate::AgentThoughtChunk {
content: "Thinking ".into(),
},
cx,
)
.unwrap();
thread
.handle_session_update(
acp::SessionUpdate::AgentThoughtChunk {
content: "hard!".into(),
},
cx,
)
.unwrap();
})
}
.boxed_local()
},
));
fake_server.update(cx, |fake_server, _| {
fake_server.on_user_message(move |_, server, mut cx| async move {
server
.update(&mut cx, |server, _| {
server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
chunk: acp_old::AssistantMessageChunk::Thought {
thought: "Thinking ".into(),
},
})
})?
let thread = cx
.spawn(async move |mut cx| {
connection
.new_thread(project, Path::new(path!("/test")), &mut cx)
.await
.unwrap();
server
.update(&mut cx, |server, _| {
server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
chunk: acp_old::AssistantMessageChunk::Thought {
thought: "hard!".into(),
},
})
})?
.await
.unwrap();
Ok(())
})
});
.await
.unwrap();
thread
.update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
@ -1414,7 +1417,38 @@ mod tests {
fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
.await;
let project = Project::test(fs.clone(), [], cx).await;
let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
let connection = Rc::new(FakeAgentConnection::new().on_user_message(
move |_, thread, mut cx| {
let read_file_tx = read_file_tx.clone();
async move {
let content = thread
.update(&mut cx, |thread, cx| {
thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
})
.unwrap()
.await
.unwrap();
assert_eq!(content, "one\ntwo\nthree\n");
read_file_tx.take().unwrap().send(()).unwrap();
thread
.update(&mut cx, |thread, cx| {
thread.write_text_file(
path!("/tmp/foo").into(),
"one\ntwo\nthree\nfour\nfive\n".to_string(),
cx,
)
})
.unwrap()
.await
.unwrap();
Ok(())
}
.boxed_local()
},
));
let (worktree, pathbuf) = project
.update(cx, |project, cx| {
project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
@ -1428,38 +1462,10 @@ mod tests {
.await
.unwrap();
let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
fake_server.update(cx, |fake_server, _| {
fake_server.on_user_message(move |_, server, mut cx| {
let read_file_tx = read_file_tx.clone();
async move {
let content = server
.update(&mut cx, |server, _| {
server.send_to_zed(acp_old::ReadTextFileParams {
path: path!("/tmp/foo").into(),
line: None,
limit: None,
})
})?
.await
.unwrap();
assert_eq!(content.content, "one\ntwo\nthree\n");
read_file_tx.take().unwrap().send(()).unwrap();
server
.update(&mut cx, |server, _| {
server.send_to_zed(acp_old::WriteTextFileParams {
path: path!("/tmp/foo").into(),
content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
})
})?
.await
.unwrap();
Ok(())
}
})
});
let thread = cx
.spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
.await
.unwrap();
let request = thread.update(cx, |thread, cx| {
thread.send_raw("Extend the count in /tmp/foo", cx)
@ -1486,36 +1492,44 @@ mod tests {
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
let (thread, fake_server) = fake_acp_thread(project, cx);
let id = acp::ToolCallId("test".into());
let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();
let tool_call_id = Rc::new(RefCell::new(None));
let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
fake_server.update(cx, |fake_server, _| {
let tool_call_id = tool_call_id.clone();
fake_server.on_user_message(move |_, server, mut cx| {
let end_turn_rx = end_turn_rx.clone();
let tool_call_id = tool_call_id.clone();
let connection = Rc::new(FakeAgentConnection::new().on_user_message({
let id = id.clone();
move |_, thread, mut cx| {
let id = id.clone();
async move {
let tool_call_result = server
.update(&mut cx, |server, _| {
server.send_to_zed(acp_old::PushToolCallParams {
label: "Fetch".to_string(),
icon: acp_old::Icon::Globe,
content: None,
locations: vec![],
})
})?
.await
thread
.update(&mut cx, |thread, cx| {
thread.handle_session_update(
acp::SessionUpdate::ToolCall(acp::ToolCall {
id: id.clone(),
label: "Label".into(),
kind: acp::ToolKind::Fetch,
status: acp::ToolCallStatus::InProgress,
content: vec![],
locations: vec![],
raw_input: None,
}),
cx,
)
})
.unwrap()
.unwrap();
*tool_call_id.clone().borrow_mut() = Some(tool_call_result.id);
end_turn_rx.take().unwrap().await.ok();
Ok(())
}
.boxed_local()
}
}));
let thread = cx
.spawn(async move |mut cx| {
connection
.new_thread(project, Path::new(path!("/test")), &mut cx)
.await
})
});
.await
.unwrap();
let request = thread.update(cx, |thread, cx| {
thread.send_raw("Fetch https://example.com", cx)
@ -1536,8 +1550,6 @@ mod tests {
));
});
cx.run_until_parked();
thread.update(cx, |thread, cx| thread.cancel(cx)).await;
thread.read_with(cx, |thread, _| {
@ -1550,19 +1562,22 @@ mod tests {
));
});
fake_server
.update(cx, |fake_server, _| {
fake_server.send_to_zed(acp_old::UpdateToolCallParams {
tool_call_id: tool_call_id.borrow().unwrap(),
status: acp_old::ToolCallStatus::Finished,
content: None,
})
thread
.update(cx, |thread, cx| {
thread.handle_session_update(
acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
id,
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
..Default::default()
},
}),
cx,
)
})
.await
.unwrap();
drop(end_turn_tx);
assert!(request.await.unwrap_err().to_string().contains("canceled"));
request.await.unwrap();
thread.read_with(cx, |thread, _| {
assert!(matches!(
@ -1585,23 +1600,37 @@ mod tests {
fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
let connection = Rc::new(StubAgentConnection::new(vec![
acp::SessionUpdate::ToolCall(acp::ToolCall {
id: acp::ToolCallId("test".into()),
label: "Label".into(),
kind: acp::ToolKind::Edit,
status: acp::ToolCallStatus::Completed,
content: vec![acp::ToolCallContent::Diff {
diff: acp::Diff {
path: "/test/test.txt".into(),
old_text: None,
new_text: "foo".into(),
},
}],
locations: vec![],
raw_input: None,
}),
]));
let connection = Rc::new(FakeAgentConnection::new().on_user_message({
move |_, thread, mut cx| {
async move {
thread
.update(&mut cx, |thread, cx| {
thread.handle_session_update(
acp::SessionUpdate::ToolCall(acp::ToolCall {
id: acp::ToolCallId("test".into()),
label: "Label".into(),
kind: acp::ToolKind::Edit,
status: acp::ToolCallStatus::Completed,
content: vec![acp::ToolCallContent::Diff {
diff: acp::Diff {
path: "/test/test.txt".into(),
old_text: None,
new_text: "foo".into(),
},
}],
locations: vec![],
raw_input: None,
}),
cx,
)
})
.unwrap()
.unwrap();
Ok(())
}
.boxed_local()
}
}));
let thread = connection
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
@ -1642,25 +1671,53 @@ mod tests {
}
#[derive(Clone, Default)]
struct StubAgentConnection {
struct FakeAgentConnection {
auth_methods: Vec<acp::AuthMethod>,
sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
updates: Vec<acp::SessionUpdate>,
on_user_message: Option<
Rc<
dyn Fn(
acp::PromptRequest,
WeakEntity<AcpThread>,
AsyncApp,
) -> LocalBoxFuture<'static, Result<()>>
+ 'static,
>,
>,
}
impl StubAgentConnection {
fn new(updates: Vec<acp::SessionUpdate>) -> Self {
impl FakeAgentConnection {
fn new() -> Self {
Self {
updates,
permission_requests: HashMap::default(),
auth_methods: Vec::new(),
on_user_message: None,
sessions: Arc::default(),
}
}
#[expect(unused)]
fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
self.auth_methods = auth_methods;
self
}
fn on_user_message(
mut self,
handler: impl Fn(
acp::PromptRequest,
WeakEntity<AcpThread>,
AsyncApp,
) -> LocalBoxFuture<'static, Result<()>>
+ 'static,
) -> Self {
self.on_user_message.replace(Rc::new(handler));
self
}
}
impl AgentConnection for StubAgentConnection {
fn name(&self) -> &'static str {
"StubAgentConnection"
impl AgentConnection for FakeAgentConnection {
fn auth_methods(&self) -> &[acp::AuthMethod] {
&self.auth_methods
}
fn new_thread(
@ -1678,222 +1735,43 @@ mod tests {
.into(),
);
let thread = cx
.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))
.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
.unwrap();
self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread))
}
fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
unimplemented!()
fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
if self.auth_methods().iter().any(|m| m.id == method) {
Task::ready(Ok(()))
} else {
Task::ready(Err(anyhow!("Invalid Auth Method")))
}
}
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<gpui::Result<()>> {
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
let sessions = self.sessions.lock();
let thread = sessions.get(&params.session_id).unwrap();
let mut tasks = vec![];
for update in &self.updates {
if let Some(handler) = &self.on_user_message {
let handler = handler.clone();
let thread = thread.clone();
let update = update.clone();
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
&& let Some(options) = self.permission_requests.get(&tool_call.id)
{
Some((tool_call.clone(), options.clone()))
} else {
None
};
let task = cx.spawn(async move |cx| {
if let Some((tool_call, options)) = permission_request {
let permission = thread.update(cx, |thread, cx| {
thread.request_tool_call_permission(
tool_call.clone(),
options.clone(),
cx,
)
})?;
permission.await?;
}
thread.update(cx, |thread, cx| {
thread.handle_session_update(update.clone(), cx).unwrap();
})?;
anyhow::Ok(())
});
tasks.push(task);
}
cx.spawn(async move |_| {
try_join_all(tasks).await?;
Ok(())
})
}
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
unimplemented!()
}
}
pub fn fake_acp_thread(
project: Entity<Project>,
cx: &mut TestAppContext,
) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
let (stdin_tx, stdin_rx) = async_pipe::pipe();
let (stdout_tx, stdout_rx) = async_pipe::pipe();
let thread = cx.new(|cx| {
let foreground_executor = cx.foreground_executor().clone();
let thread_rc = Rc::new(RefCell::new(cx.entity().downgrade()));
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
OldAcpClientDelegate::new(thread_rc.clone(), cx.to_async()),
stdin_tx,
stdout_rx,
move |fut| {
foreground_executor.spawn(fut).detach();
},
);
let io_task = cx.background_spawn({
async move {
io_fut.await.log_err();
Ok(())
}
});
let connection = OldAcpAgentConnection {
name: "test",
connection,
child_status: io_task,
current_thread: thread_rc,
};
AcpThread::new(
Rc::new(connection),
project,
acp::SessionId("test".into()),
cx,
)
});
let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
(thread, agent)
}
pub struct FakeAcpServer {
connection: acp_old::ClientConnection,
_io_task: Task<()>,
on_user_message: Option<
Rc<
dyn Fn(
acp_old::SendUserMessageParams,
Entity<FakeAcpServer>,
AsyncApp,
) -> LocalBoxFuture<'static, Result<(), acp_old::Error>>,
>,
>,
}
#[derive(Clone)]
struct FakeAgent {
server: Entity<FakeAcpServer>,
cx: AsyncApp,
cancel_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
}
impl acp_old::Agent for FakeAgent {
async fn initialize(
&self,
params: acp_old::InitializeParams,
) -> Result<acp_old::InitializeResponse, acp_old::Error> {
Ok(acp_old::InitializeResponse {
protocol_version: params.protocol_version,
is_authenticated: true,
})
}
async fn authenticate(&self) -> Result<(), acp_old::Error> {
Ok(())
}
async fn cancel_send_message(&self) -> Result<(), acp_old::Error> {
if let Some(cancel_tx) = self.cancel_tx.take() {
cancel_tx.send(()).log_err();
}
Ok(())
}
async fn send_user_message(
&self,
request: acp_old::SendUserMessageParams,
) -> Result<(), acp_old::Error> {
let (cancel_tx, cancel_rx) = oneshot::channel();
self.cancel_tx.replace(Some(cancel_tx));
let mut cx = self.cx.clone();
let handler = self
.server
.update(&mut cx, |server, _| server.on_user_message.clone())
.ok()
.flatten();
if let Some(handler) = handler {
select! {
_ = cancel_rx.fuse() => Err(anyhow::anyhow!("Message sending canceled").into()),
_ = handler(request, self.server.clone(), self.cx.clone()).fuse() => Ok(()),
}
cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
} else {
Err(anyhow::anyhow!("No handler for on_user_message").into())
}
}
}
impl FakeAcpServer {
fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
let agent = FakeAgent {
server: cx.entity(),
cx: cx.to_async(),
cancel_tx: Default::default(),
};
let foreground_executor = cx.foreground_executor().clone();
let (connection, io_fut) = acp_old::ClientConnection::connect_to_client(
agent.clone(),
stdout,
stdin,
move |fut| {
foreground_executor.spawn(fut).detach();
},
);
FakeAcpServer {
connection: connection,
on_user_message: None,
_io_task: cx.background_spawn(async move {
io_fut.await.log_err();
}),
Task::ready(Ok(()))
}
}
fn on_user_message<F>(
&mut self,
handler: impl for<'a> Fn(
acp_old::SendUserMessageParams,
Entity<FakeAcpServer>,
AsyncApp,
) -> F
+ 'static,
) where
F: Future<Output = Result<(), acp_old::Error>> + 'static,
{
self.on_user_message
.replace(Rc::new(move |request, server, cx| {
handler(request, server, cx).boxed_local()
}));
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
let sessions = self.sessions.lock();
let thread = sessions.get(&session_id).unwrap().clone();
fn send_to_zed<T: acp_old::ClientRequest + 'static>(
&self,
message: T,
) -> BoxedLocal<Result<T::Response>> {
self.connection
.request(message)
.map(|f| f.map_err(|err| anyhow!(err)))
.boxed_local()
cx.spawn(async move |cx| {
thread
.update(cx, |thread, cx| thread.cancel(cx))
.unwrap()
.await
})
.detach();
}
}
}