diff --git a/Cargo.lock b/Cargo.lock index a8a8d12e37..4d84249b00 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -158,6 +158,7 @@ dependencies = [ "serde_json", "settings", "smol", + "strum 0.27.1", "tempfile", "ui", "util", diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 1e3947351a..ae22725d5e 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -664,7 +664,7 @@ impl AcpThread { cx: &mut Context, ) -> Result { let project = self.project.read(cx).languages().clone(); - let Some((_, call)) = self.tool_call_mut(tool_call_id) else { + let Some((idx, call)) = self.tool_call_mut(tool_call_id) else { anyhow::bail!("Tool call not found"); }; @@ -675,6 +675,8 @@ impl AcpThread { respond_tx: tx, }; + cx.emit(AcpThreadEvent::EntryUpdated(idx)); + Ok(ToolCallRequest { id: tool_call_id, outcome: rx, @@ -768,8 +770,13 @@ impl AcpThread { let language_registry = self.project.read(cx).languages().clone(); let (ix, call) = self.tool_call_mut(id).context("Entry not found")?; - call.content = new_content - .map(|new_content| ToolCallContent::from_acp(new_content, language_registry, cx)); + if let Some(new_content) = new_content { + call.content = Some(ToolCallContent::from_acp( + new_content, + language_registry, + cx, + )); + } match &mut call.status { ToolCallStatus::Allowed { status } => { diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 2d68148264..f3df25f709 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -33,6 +33,7 @@ serde.workspace = true serde_json.workspace = true settings.workspace = true smol.workspace = true +strum.workspace = true tempfile.workspace = true ui.workspace = true util.workspace = true diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 5760a96d8c..52c6012267 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -281,14 +281,18 @@ impl ClaudeAgentConnection { } => { let id = tool_id_map.borrow_mut().remove(&tool_use_id); if let Some(id) = id { + let content = content.to_string(); delegate .update_tool_call(UpdateToolCallParams { tool_call_id: id, status: acp::ToolCallStatus::Finished, - content: Some(ToolCallContent::Markdown { - // For now we only include text content - markdown: content.to_string(), - }), + // Don't unset existing content + content: (!content.is_empty()).then_some( + ToolCallContent::Markdown { + // For now we only include text content + markdown: content, + }, + ), }) .await .log_err(); @@ -577,7 +581,7 @@ pub(crate) mod tests { use super::*; use serde_json::json; - // crate::common_e2e_tests!(ClaudeCode); + crate::common_e2e_tests!(ClaudeCode); pub fn local_command() -> AgentServerCommand { AgentServerCommand { diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs index e3ac6c14e2..a2d6b487b2 100644 --- a/crates/agent_servers/src/claude/tools.rs +++ b/crates/agent_servers/src/claude/tools.rs @@ -118,13 +118,106 @@ impl ClaudeTool { pub fn content(&self) -> Option { match &self { - ClaudeTool::Other { input, .. } => Some(acp::ToolCallContent::Markdown { + Self::Other { input, .. } => Some(acp::ToolCallContent::Markdown { markdown: format!( "```json\n{}```", serde_json::to_string_pretty(&input).unwrap_or("{}".to_string()) ), }), - _ => None, + Self::Task(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.prompt.clone(), + }), + Self::NotebookRead(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.notebook_path.display().to_string(), + }), + Self::NotebookEdit(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.new_source.clone(), + }), + Self::Terminal(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: format!( + "`{}`\n\n{}", + params.command, + params.description.as_deref().unwrap_or_default() + ), + }), + Self::ReadFile(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.abs_path.display().to_string(), + }), + Self::Ls(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.path.display().to_string(), + }), + Self::Glob(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.to_string(), + }), + Self::Grep(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: format!("`{params}`"), + }), + Self::WebFetch(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.prompt.clone(), + }), + Self::WebSearch(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.to_string(), + }), + Self::TodoWrite(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params + .todos + .iter() + .map(|todo| { + format!( + "- {} {}: {}", + match todo.status { + TodoStatus::Completed => "✅", + TodoStatus::InProgress => "🚧", + TodoStatus::Pending => "⬜", + }, + todo.priority, + todo.content + ) + }) + .join("\n"), + }), + Self::ExitPlanMode(Some(params)) => Some(acp::ToolCallContent::Markdown { + markdown: params.plan.clone(), + }), + Self::Edit(Some(params)) => Some(acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.abs_path.clone(), + old_text: Some(params.old_text.clone()), + new_text: params.new_text.clone(), + }, + }), + Self::Write(Some(params)) => Some(acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.file_path.clone(), + old_text: None, + new_text: params.content.clone(), + }, + }), + Self::MultiEdit(Some(params)) => { + // todo: show multiple edits in a multibuffer? + params.edits.first().map(|edit| acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.file_path.clone(), + old_text: Some(edit.old_string.clone()), + new_text: edit.new_string.clone(), + }, + }) + } + Self::Task(None) + | Self::NotebookRead(None) + | Self::NotebookEdit(None) + | Self::Terminal(None) + | Self::ReadFile(None) + | Self::Ls(None) + | Self::Glob(None) + | Self::Grep(None) + | Self::WebFetch(None) + | Self::WebSearch(None) + | Self::TodoWrite(None) + | Self::ExitPlanMode(None) + | Self::Edit(None) + | Self::Write(None) + | Self::MultiEdit(None) => None, } } @@ -513,7 +606,7 @@ impl std::fmt::Display for GrepToolParams { } } -#[derive(Deserialize, Serialize, JsonSchema, Debug)] +#[derive(Deserialize, Serialize, JsonSchema, strum::Display, Debug)] #[serde(rename_all = "snake_case")] pub enum TodoPriority { High, diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index 923c6cdd6f..12f74cb13e 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -111,18 +111,21 @@ pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestApp .await .unwrap(); thread.read_with(cx, |thread, _cx| { - assert!(matches!( - &thread.entries()[2], - AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { .. }, - .. - }) - )); - - assert!(matches!( - thread.entries()[3], - AgentThreadEntry::AssistantMessage(_) - )); + assert!(thread.entries().iter().any(|entry| { + matches!( + entry, + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { .. }, + .. + }) + ) + })); + assert!( + thread + .entries() + .iter() + .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) }) + ); }); } @@ -134,10 +137,26 @@ pub async fn test_tool_call_with_confirmation( let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let full_turn = thread.update(cx, |thread, cx| { - thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx) + thread.send_raw( + r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#, + cx, + ) }); - run_until_first_tool_call(&thread, cx).await; + run_until_first_tool_call( + &thread, + |entry| { + matches!( + entry, + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::WaitingForConfirmation { .. }, + .. + }) + ) + }, + cx, + ) + .await; let tool_call_id = thread.read_with(cx, |thread, _cx| { let AgentThreadEntry::ToolCall(ToolCall { @@ -148,12 +167,16 @@ pub async fn test_tool_call_with_confirmation( .. }, .. - }) = &thread.entries()[2] + }) = &thread + .entries() + .iter() + .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_))) + .unwrap() else { panic!(); }; - assert_eq!(root_command, "echo"); + assert!(root_command.contains("touch")); *id }); @@ -161,13 +184,13 @@ pub async fn test_tool_call_with_confirmation( thread.update(cx, |thread, cx| { thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx); - assert!(matches!( - &thread.entries()[2], + assert!(thread.entries().iter().any(|entry| matches!( + entry, AgentThreadEntry::ToolCall(ToolCall { status: ToolCallStatus::Allowed { .. }, .. }) - )); + ))); }); full_turn.await.unwrap(); @@ -177,15 +200,19 @@ pub async fn test_tool_call_with_confirmation( content: Some(ToolCallContent::Markdown { markdown }), status: ToolCallStatus::Allowed { .. }, .. - }) = &thread.entries()[2] + }) = thread + .entries() + .iter() + .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_))) + .unwrap() else { panic!(); }; markdown.read_with(cx, |md, _cx| { assert!( - md.source().contains("Hello, world!"), - r#"Expected '{}' to contain "Hello, world!""#, + md.source().contains("Hello"), + r#"Expected '{}' to contain "Hello""#, md.source() ); }); @@ -198,10 +225,26 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let full_turn = thread.update(cx, |thread, cx| { - thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx) + thread.send_raw( + r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#, + cx, + ) }); - let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await; + let first_tool_call_ix = run_until_first_tool_call( + &thread, + |entry| { + matches!( + entry, + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::WaitingForConfirmation { .. }, + .. + }) + ) + }, + cx, + ) + .await; thread.read_with(cx, |thread, _cx| { let AgentThreadEntry::ToolCall(ToolCall { @@ -217,7 +260,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon panic!("{:?}", thread.entries()[1]); }; - assert_eq!(root_command, "echo"); + assert!(root_command.contains("touch")); *id }); @@ -340,6 +383,7 @@ pub async fn new_test_thread( pub async fn run_until_first_tool_call( thread: &Entity, + wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static, cx: &mut TestAppContext, ) -> usize { let (mut tx, mut rx) = mpsc::channel::(1); @@ -347,7 +391,7 @@ pub async fn run_until_first_tool_call( let subscription = cx.update(|cx| { cx.subscribe(thread, move |thread, _, cx| { for (ix, entry) in thread.read(cx).entries().iter().enumerate() { - if matches!(entry, AgentThreadEntry::ToolCall(_)) { + if wait_until(entry) { return tx.try_send(ix).unwrap(); } } @@ -357,7 +401,7 @@ pub async fn run_until_first_tool_call( select! { // We have to use a smol timer here because // cx.background_executor().timer isn't real in the test context - _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => { + _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => { panic!("Timeout waiting for tool call") } ix = rx.next().fuse() => {