Merge branch 'acp' of github.com:zed-industries/zed into acp

This commit is contained in:
Agus Zubiaga 2025-07-02 11:33:22 -03:00
commit 348bc52a3f
2 changed files with 30 additions and 91 deletions

View file

@ -123,7 +123,7 @@ pub enum AgentThreadEntryContent {
#[derive(Debug)] #[derive(Debug)]
pub struct ToolCall { pub struct ToolCall {
id: ToolCallId, id: ToolCallId,
tool_name: Entity<Markdown>, display_name: Entity<Markdown>,
status: ToolCallStatus, status: ToolCallStatus,
} }
@ -271,7 +271,7 @@ impl AcpThread {
pub fn request_tool_call( pub fn request_tool_call(
&mut self, &mut self,
title: String, display_name: String,
confirmation: acp::ToolCallConfirmation, confirmation: acp::ToolCallConfirmation,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> ToolCallRequest { ) -> ToolCallRequest {
@ -282,22 +282,22 @@ impl AcpThread {
respond_tx: tx, respond_tx: tx,
}; };
let id = self.insert_tool_call(title, status, cx); let id = self.insert_tool_call(display_name, status, cx);
ToolCallRequest { id, outcome: rx } ToolCallRequest { id, outcome: rx }
} }
pub fn push_tool_call(&mut self, title: String, cx: &mut Context<Self>) -> ToolCallId { pub fn push_tool_call(&mut self, display_name: String, cx: &mut Context<Self>) -> ToolCallId {
let status = ToolCallStatus::Allowed { let status = ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Running, status: acp::ToolCallStatus::Running,
content: None, content: None,
}; };
self.insert_tool_call(title, status, cx) self.insert_tool_call(display_name, status, cx)
} }
fn insert_tool_call( fn insert_tool_call(
&mut self, &mut self,
title: String, display_name: String,
status: ToolCallStatus, status: ToolCallStatus,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> ToolCallId { ) -> ToolCallId {
@ -307,8 +307,13 @@ impl AcpThread {
AgentThreadEntryContent::ToolCall(ToolCall { AgentThreadEntryContent::ToolCall(ToolCall {
// todo! clean up id creation // todo! clean up id creation
id: ToolCallId(ThreadEntryId(self.entries.len() as u64)), id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
tool_name: cx.new(|cx| { display_name: cx.new(|cx| {
Markdown::new(title.into(), Some(language_registry.clone()), None, cx) Markdown::new(
display_name.into(),
Some(language_registry.clone()),
None,
cx,
)
}), }),
status, status,
}), }),
@ -441,13 +446,11 @@ pub struct ToolCallRequest {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use futures::{FutureExt as _, channel::mpsc, select};
use gpui::{AsyncApp, TestAppContext}; use gpui::{AsyncApp, TestAppContext};
use project::FakeFs; use project::FakeFs;
use serde_json::json; use serde_json::json;
use settings::SettingsStore; use settings::SettingsStore;
use smol::stream::StreamExt; use std::{env, path::Path, process::Stdio};
use std::{env, path::Path, process::Stdio, time::Duration};
use util::path; use util::path;
fn init_test(cx: &mut TestAppContext) { fn init_test(cx: &mut TestAppContext) {
@ -509,27 +512,27 @@ mod tests {
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap(); let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
let thread = server.create_thread(&mut cx.to_async()).await.unwrap(); let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
let full_turn = thread.update(cx, |thread, cx| { thread
.update(cx, |thread, cx| {
thread.send( thread.send(
"Read the '/private/tmp/foo' file and tell me what you see.", "Read the '/private/tmp/foo' file and tell me what you see.",
cx, cx,
) )
}); })
.await
run_until_tool_call(&thread, cx).await; .unwrap();
thread.read_with(cx, |thread, cx| {
let tool_call_id = thread.read_with(cx, |thread, cx| {
let AgentThreadEntryContent::ToolCall(ToolCall { let AgentThreadEntryContent::ToolCall(ToolCall {
id, id,
tool_name, display_name,
status: ToolCallStatus::Allowed { .. }, status: ToolCallStatus::Allowed { content, .. },
}) = &thread.entries().last().unwrap().content }) = &thread.entries()[1].content
else { else {
panic!(); panic!();
}; };
tool_name.read_with(cx, |md, _cx| { display_name.read_with(cx, |md, _cx| {
assert_eq!(md.source(), "read_file"); assert_eq!(md.source(), "ReadFile");
}); });
// todo! // todo!
@ -542,70 +545,6 @@ mod tests {
// }); // });
*id *id
}); });
thread.update(cx, |thread, cx| {
thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
assert!(matches!(
thread.entries().last().unwrap().content,
AgentThreadEntryContent::ToolCall(ToolCall {
status: ToolCallStatus::Allowed { .. },
..
})
));
});
full_turn.await.unwrap();
thread.read_with(cx, |thread, _| {
assert!(thread.entries.len() >= 3, "{:?}", &thread.entries);
assert!(matches!(
thread.entries[0].content,
AgentThreadEntryContent::Message(Message {
role: Role::User,
..
})
));
assert!(matches!(
thread.entries[1].content,
AgentThreadEntryContent::ToolCall(ToolCall {
status: ToolCallStatus::Allowed { .. },
..
})
));
assert!(matches!(
thread.entries[2].content,
AgentThreadEntryContent::Message(Message {
role: Role::Assistant,
..
})
));
});
}
async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
let (mut tx, mut rx) = mpsc::channel(1);
let subscription = cx.update(|cx| {
cx.subscribe(thread, move |thread, _, cx| {
if thread
.read(cx)
.entries
.iter()
.any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
{
tx.try_send(()).unwrap();
}
})
});
select! {
_ = cx.executor().timer(Duration::from_secs(5)).fuse() => {
panic!("Timeout waiting for tool call")
}
_ = rx.next().fuse() => {
drop(subscription);
}
}
} }
pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> { pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {

View file

@ -329,7 +329,7 @@ impl AcpThreadView {
.color(Color::Muted), .color(Color::Muted),
) )
.child(MarkdownElement::new( .child(MarkdownElement::new(
tool_call.tool_name.clone(), tool_call.display_name.clone(),
default_markdown_style(window, cx), default_markdown_style(window, cx),
)) ))
.child(div().w_full()) .child(div().w_full())