From 332626e5825564e97afc969292c90d9b0fb40b6d Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Sat, 16 Aug 2025 17:04:09 +0200 Subject: [PATCH] Allow Permission Request to only require a ToolCallUpdate instead of a full tool call (#36319) Release Notes: - N/A --- Cargo.lock | 4 +- Cargo.toml | 2 +- crates/acp_thread/src/acp_thread.rs | 36 ++++++++++------- crates/acp_thread/src/connection.rs | 4 +- crates/agent2/src/agent.rs | 11 ++--- crates/agent2/src/thread.rs | 40 ++++++------------- crates/agent2/src/tools/edit_file_tool.rs | 12 ++++-- crates/agent_servers/src/acp/v0.rs | 6 +-- crates/agent_servers/src/acp/v1.rs | 2 +- crates/agent_servers/src/claude.rs | 3 +- crates/agent_servers/src/claude/mcp_server.rs | 4 +- 11 files changed, 63 insertions(+), 61 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1bce72b3a1..f59d92739b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -172,9 +172,9 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.0.24" +version = "0.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fd68bbbef8e424fb8a605c5f0b00c360f682c4528b0a5feb5ec928aaf5ce28e" +checksum = "2ab66add8be8d6a963f5bf4070045c1bbf36472837654c73e2298dd16bda5bf7" dependencies = [ "anyhow", "futures 0.3.31", diff --git a/Cargo.toml b/Cargo.toml index 644b6c0f40..b467e8743e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -426,7 +426,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.24" +agent-client-protocol = "0.0.25" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 2ef94a3cbe..3bb1b99ba1 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -792,7 +792,7 @@ impl AcpThread { &mut self, update: acp::SessionUpdate, cx: &mut Context, - ) -> Result<()> { + ) -> Result<(), acp::Error> { match update { acp::SessionUpdate::UserMessageChunk { content } => { self.push_user_content_block(None, content, cx); @@ -804,7 +804,7 @@ impl AcpThread { self.push_assistant_content_block(content, true, cx); } acp::SessionUpdate::ToolCall(tool_call) => { - self.upsert_tool_call(tool_call, cx); + self.upsert_tool_call(tool_call, cx)?; } acp::SessionUpdate::ToolCallUpdate(tool_call_update) => { self.update_tool_call(tool_call_update, cx)?; @@ -940,32 +940,40 @@ impl AcpThread { } /// Updates a tool call if id matches an existing entry, otherwise inserts a new one. - pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context) { + pub fn upsert_tool_call( + &mut self, + tool_call: acp::ToolCall, + cx: &mut Context, + ) -> Result<(), acp::Error> { let status = ToolCallStatus::Allowed { status: tool_call.status, }; - self.upsert_tool_call_inner(tool_call, status, cx) + self.upsert_tool_call_inner(tool_call.into(), status, cx) } + /// Fails if id does not match an existing entry. pub fn upsert_tool_call_inner( &mut self, - tool_call: acp::ToolCall, + tool_call_update: acp::ToolCallUpdate, status: ToolCallStatus, cx: &mut Context, - ) { + ) -> Result<(), acp::Error> { let language_registry = self.project.read(cx).languages().clone(); - let call = ToolCall::from_acp(tool_call, status, language_registry, cx); - let id = call.id.clone(); + let id = tool_call_update.id.clone(); - if let Some((ix, current_call)) = self.tool_call_mut(&call.id) { - *current_call = call; + if let Some((ix, current_call)) = self.tool_call_mut(&id) { + current_call.update_fields(tool_call_update.fields, language_registry, cx); + current_call.status = status; cx.emit(AcpThreadEvent::EntryUpdated(ix)); } else { + let call = + ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx); self.push_entry(AgentThreadEntry::ToolCall(call), cx); }; self.resolve_locations(id, cx); + Ok(()) } fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { @@ -1034,10 +1042,10 @@ impl AcpThread { pub fn request_tool_call_authorization( &mut self, - tool_call: acp::ToolCall, + tool_call: acp::ToolCallUpdate, options: Vec, cx: &mut Context, - ) -> oneshot::Receiver { + ) -> Result, acp::Error> { let (tx, rx) = oneshot::channel(); let status = ToolCallStatus::WaitingForConfirmation { @@ -1045,9 +1053,9 @@ impl AcpThread { respond_tx: tx, }; - self.upsert_tool_call_inner(tool_call, status, cx); + self.upsert_tool_call_inner(tool_call, status, cx)?; cx.emit(AcpThreadEvent::ToolAuthorizationRequired); - rx + Ok(rx) } pub fn authorize_tool_call( diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index b2116020fb..7497d2309f 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -286,12 +286,12 @@ mod test_support { if let Some((tool_call, options)) = permission_request { let permission = thread.update(cx, |thread, cx| { thread.request_tool_call_authorization( - tool_call.clone(), + tool_call.clone().into(), options.clone(), cx, ) })?; - permission.await?; + permission?.await?; } thread.update(cx, |thread, cx| { thread.handle_session_update(update.clone(), cx).unwrap(); diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 358365d11f..d63e3f8134 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -514,10 +514,11 @@ impl NativeAgentConnection { thread.request_tool_call_authorization(tool_call, options, cx) })?; cx.background_spawn(async move { - if let Some(option) = recv - .await - .context("authorization sender was dropped") - .log_err() + if let Some(recv) = recv.log_err() + && let Some(option) = recv + .await + .context("authorization sender was dropped") + .log_err() { response .send(option) @@ -530,7 +531,7 @@ impl NativeAgentConnection { AgentResponseEvent::ToolCall(tool_call) => { acp_thread.update(cx, |thread, cx| { thread.upsert_tool_call(tool_call, cx) - })?; + })??; } AgentResponseEvent::ToolCallUpdate(update) => { acp_thread.update(cx, |thread, cx| { diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index cfd67f4b05..0741bb9e08 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -448,7 +448,7 @@ pub enum AgentResponseEvent { #[derive(Debug)] pub struct ToolCallAuthorization { - pub tool_call: acp::ToolCall, + pub tool_call: acp::ToolCallUpdate, pub options: Vec, pub response: oneshot::Sender, } @@ -901,7 +901,7 @@ impl Thread { let fs = self.project.read(cx).fs().clone(); let tool_event_stream = - ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs)); + ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs)); tool_event_stream.update_fields(acp::ToolCallUpdateFields { status: Some(acp::ToolCallStatus::InProgress), ..Default::default() @@ -1344,8 +1344,6 @@ impl AgentResponseEventStream { #[derive(Clone)] pub struct ToolCallEventStream { tool_use_id: LanguageModelToolUseId, - kind: acp::ToolKind, - input: serde_json::Value, stream: AgentResponseEventStream, fs: Option>, } @@ -1355,32 +1353,19 @@ impl ToolCallEventStream { pub fn test() -> (Self, ToolCallEventStreamReceiver) { let (events_tx, events_rx) = mpsc::unbounded::>(); - let stream = ToolCallEventStream::new( - &LanguageModelToolUse { - id: "test_id".into(), - name: "test_tool".into(), - raw_input: String::new(), - input: serde_json::Value::Null, - is_input_complete: true, - }, - acp::ToolKind::Other, - AgentResponseEventStream(events_tx), - None, - ); + let stream = + ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None); (stream, ToolCallEventStreamReceiver(events_rx)) } fn new( - tool_use: &LanguageModelToolUse, - kind: acp::ToolKind, + tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream, fs: Option>, ) -> Self { Self { - tool_use_id: tool_use.id.clone(), - kind, - input: tool_use.input.clone(), + tool_use_id, stream, fs, } @@ -1427,12 +1412,13 @@ impl ToolCallEventStream { .0 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( ToolCallAuthorization { - tool_call: AgentResponseEventStream::initial_tool_call( - &self.tool_use_id, - title.into(), - self.kind.clone(), - self.input.clone(), - ), + tool_call: acp::ToolCallUpdate { + id: acp::ToolCallId(self.tool_use_id.to_string().into()), + fields: acp::ToolCallUpdateFields { + title: Some(title.into()), + ..Default::default() + }, + }, options: vec![ acp::PermissionOption { id: acp::PermissionOptionId("always_allow".into()), diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index c77b9f6a69..4b4f98daec 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -1001,7 +1001,10 @@ mod tests { }); let event = stream_rx.expect_authorization().await; - assert_eq!(event.tool_call.title, "test 1 (local settings)"); + assert_eq!( + event.tool_call.fields.title, + Some("test 1 (local settings)".into()) + ); // Test 2: Path outside project should require confirmation let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); @@ -1018,7 +1021,7 @@ mod tests { }); let event = stream_rx.expect_authorization().await; - assert_eq!(event.tool_call.title, "test 2"); + assert_eq!(event.tool_call.fields.title, Some("test 2".into())); // Test 3: Relative path without .zed should not require confirmation let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); @@ -1051,7 +1054,10 @@ mod tests { ) }); let event = stream_rx.expect_authorization().await; - assert_eq!(event.tool_call.title, "test 4 (local settings)"); + assert_eq!( + event.tool_call.fields.title, + Some("test 4 (local settings)".into()) + ); // Test 5: When always_allow_tool_actions is enabled, no confirmation needed cx.update(|cx| { diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index e936c87643..74647f7313 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -135,9 +135,9 @@ impl acp_old::Client for OldAcpClientDelegate { let response = cx .update(|cx| { self.thread.borrow().update(cx, |thread, cx| { - thread.request_tool_call_authorization(tool_call, acp_options, cx) + thread.request_tool_call_authorization(tool_call.into(), acp_options, cx) }) - })? + })?? .context("Failed to update thread")? .await; @@ -168,7 +168,7 @@ impl acp_old::Client for OldAcpClientDelegate { cx, ) }) - })? + })?? .context("Failed to update thread")?; Ok(acp_old::PushToolCallResponse { diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index 6cf9801d06..506ae80886 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -233,7 +233,7 @@ impl acp::Client for ClientDelegate { thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx) })?; - let result = rx.await; + let result = rx?.await; let outcome = match result { Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option }, diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 14a179ba3d..4b3a173349 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -560,8 +560,9 @@ impl ClaudeAgentSession { thread.upsert_tool_call( claude_tool.as_acp(acp::ToolCallId(id.into())), cx, - ); + )?; } + anyhow::Ok(()) }) .log_err(); } diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs index 53a8556e74..22cb2f8f8d 100644 --- a/crates/agent_servers/src/claude/mcp_server.rs +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -154,7 +154,7 @@ impl McpServerTool for PermissionTool { let chosen_option = thread .update(cx, |thread, cx| { thread.request_tool_call_authorization( - claude_tool.as_acp(tool_call_id), + claude_tool.as_acp(tool_call_id).into(), vec![ acp::PermissionOption { id: allow_option_id.clone(), @@ -169,7 +169,7 @@ impl McpServerTool for PermissionTool { ], cx, ) - })? + })?? .await?; let response = if chosen_option == allow_option_id {