Allow Permission Request to only require a ToolCallUpdate instead of a full tool call (#36319)
Release Notes: - N/A
This commit is contained in:
parent
7b3fe0a474
commit
332626e582
11 changed files with 63 additions and 61 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -172,9 +172,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "agent-client-protocol"
|
||||
version = "0.0.24"
|
||||
version = "0.0.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fd68bbbef8e424fb8a605c5f0b00c360f682c4528b0a5feb5ec928aaf5ce28e"
|
||||
checksum = "2ab66add8be8d6a963f5bf4070045c1bbf36472837654c73e2298dd16bda5bf7"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.31",
|
||||
|
|
|
@ -426,7 +426,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
|||
#
|
||||
|
||||
agentic-coding-protocol = "0.0.10"
|
||||
agent-client-protocol = "0.0.24"
|
||||
agent-client-protocol = "0.0.25"
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||
any_vec = "0.14"
|
||||
|
|
|
@ -792,7 +792,7 @@ impl AcpThread {
|
|||
&mut self,
|
||||
update: acp::SessionUpdate,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<()> {
|
||||
) -> Result<(), acp::Error> {
|
||||
match update {
|
||||
acp::SessionUpdate::UserMessageChunk { content } => {
|
||||
self.push_user_content_block(None, content, cx);
|
||||
|
@ -804,7 +804,7 @@ impl AcpThread {
|
|||
self.push_assistant_content_block(content, true, cx);
|
||||
}
|
||||
acp::SessionUpdate::ToolCall(tool_call) => {
|
||||
self.upsert_tool_call(tool_call, cx);
|
||||
self.upsert_tool_call(tool_call, cx)?;
|
||||
}
|
||||
acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
|
||||
self.update_tool_call(tool_call_update, cx)?;
|
||||
|
@ -940,32 +940,40 @@ impl AcpThread {
|
|||
}
|
||||
|
||||
/// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
|
||||
pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
|
||||
pub fn upsert_tool_call(
|
||||
&mut self,
|
||||
tool_call: acp::ToolCall,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<(), acp::Error> {
|
||||
let status = ToolCallStatus::Allowed {
|
||||
status: tool_call.status,
|
||||
};
|
||||
self.upsert_tool_call_inner(tool_call, status, cx)
|
||||
self.upsert_tool_call_inner(tool_call.into(), status, cx)
|
||||
}
|
||||
|
||||
/// Fails if id does not match an existing entry.
|
||||
pub fn upsert_tool_call_inner(
|
||||
&mut self,
|
||||
tool_call: acp::ToolCall,
|
||||
tool_call_update: acp::ToolCallUpdate,
|
||||
status: ToolCallStatus,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
) -> Result<(), acp::Error> {
|
||||
let language_registry = self.project.read(cx).languages().clone();
|
||||
let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
|
||||
let id = call.id.clone();
|
||||
let id = tool_call_update.id.clone();
|
||||
|
||||
if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
|
||||
*current_call = call;
|
||||
if let Some((ix, current_call)) = self.tool_call_mut(&id) {
|
||||
current_call.update_fields(tool_call_update.fields, language_registry, cx);
|
||||
current_call.status = status;
|
||||
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||
} else {
|
||||
let call =
|
||||
ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
|
||||
self.push_entry(AgentThreadEntry::ToolCall(call), cx);
|
||||
};
|
||||
|
||||
self.resolve_locations(id, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
|
||||
|
@ -1034,10 +1042,10 @@ impl AcpThread {
|
|||
|
||||
pub fn request_tool_call_authorization(
|
||||
&mut self,
|
||||
tool_call: acp::ToolCall,
|
||||
tool_call: acp::ToolCallUpdate,
|
||||
options: Vec<acp::PermissionOption>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> oneshot::Receiver<acp::PermissionOptionId> {
|
||||
) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
let status = ToolCallStatus::WaitingForConfirmation {
|
||||
|
@ -1045,9 +1053,9 @@ impl AcpThread {
|
|||
respond_tx: tx,
|
||||
};
|
||||
|
||||
self.upsert_tool_call_inner(tool_call, status, cx);
|
||||
self.upsert_tool_call_inner(tool_call, status, cx)?;
|
||||
cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
|
||||
rx
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
pub fn authorize_tool_call(
|
||||
|
|
|
@ -286,12 +286,12 @@ mod test_support {
|
|||
if let Some((tool_call, options)) = permission_request {
|
||||
let permission = thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(
|
||||
tool_call.clone(),
|
||||
tool_call.clone().into(),
|
||||
options.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
permission.await?;
|
||||
permission?.await?;
|
||||
}
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(update.clone(), cx).unwrap();
|
||||
|
|
|
@ -514,7 +514,8 @@ impl NativeAgentConnection {
|
|||
thread.request_tool_call_authorization(tool_call, options, cx)
|
||||
})?;
|
||||
cx.background_spawn(async move {
|
||||
if let Some(option) = recv
|
||||
if let Some(recv) = recv.log_err()
|
||||
&& let Some(option) = recv
|
||||
.await
|
||||
.context("authorization sender was dropped")
|
||||
.log_err()
|
||||
|
@ -530,7 +531,7 @@ impl NativeAgentConnection {
|
|||
AgentResponseEvent::ToolCall(tool_call) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.upsert_tool_call(tool_call, cx)
|
||||
})?;
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::ToolCallUpdate(update) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
|
|
|
@ -448,7 +448,7 @@ pub enum AgentResponseEvent {
|
|||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolCallAuthorization {
|
||||
pub tool_call: acp::ToolCall,
|
||||
pub tool_call: acp::ToolCallUpdate,
|
||||
pub options: Vec<acp::PermissionOption>,
|
||||
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
||||
}
|
||||
|
@ -901,7 +901,7 @@ impl Thread {
|
|||
|
||||
let fs = self.project.read(cx).fs().clone();
|
||||
let tool_event_stream =
|
||||
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
|
||||
ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
|
||||
tool_event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
|
@ -1344,8 +1344,6 @@ impl AgentResponseEventStream {
|
|||
#[derive(Clone)]
|
||||
pub struct ToolCallEventStream {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
stream: AgentResponseEventStream,
|
||||
fs: Option<Arc<dyn Fs>>,
|
||||
}
|
||||
|
@ -1355,32 +1353,19 @@ impl ToolCallEventStream {
|
|||
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
|
||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
||||
|
||||
let stream = ToolCallEventStream::new(
|
||||
&LanguageModelToolUse {
|
||||
id: "test_id".into(),
|
||||
name: "test_tool".into(),
|
||||
raw_input: String::new(),
|
||||
input: serde_json::Value::Null,
|
||||
is_input_complete: true,
|
||||
},
|
||||
acp::ToolKind::Other,
|
||||
AgentResponseEventStream(events_tx),
|
||||
None,
|
||||
);
|
||||
let stream =
|
||||
ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None);
|
||||
|
||||
(stream, ToolCallEventStreamReceiver(events_rx))
|
||||
}
|
||||
|
||||
fn new(
|
||||
tool_use: &LanguageModelToolUse,
|
||||
kind: acp::ToolKind,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
stream: AgentResponseEventStream,
|
||||
fs: Option<Arc<dyn Fs>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
kind,
|
||||
input: tool_use.input.clone(),
|
||||
tool_use_id,
|
||||
stream,
|
||||
fs,
|
||||
}
|
||||
|
@ -1427,12 +1412,13 @@ impl ToolCallEventStream {
|
|||
.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
|
||||
ToolCallAuthorization {
|
||||
tool_call: AgentResponseEventStream::initial_tool_call(
|
||||
&self.tool_use_id,
|
||||
title.into(),
|
||||
self.kind.clone(),
|
||||
self.input.clone(),
|
||||
),
|
||||
tool_call: acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(self.tool_use_id.to_string().into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
title: Some(title.into()),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
options: vec![
|
||||
acp::PermissionOption {
|
||||
id: acp::PermissionOptionId("always_allow".into()),
|
||||
|
|
|
@ -1001,7 +1001,10 @@ mod tests {
|
|||
});
|
||||
|
||||
let event = stream_rx.expect_authorization().await;
|
||||
assert_eq!(event.tool_call.title, "test 1 (local settings)");
|
||||
assert_eq!(
|
||||
event.tool_call.fields.title,
|
||||
Some("test 1 (local settings)".into())
|
||||
);
|
||||
|
||||
// Test 2: Path outside project should require confirmation
|
||||
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
|
||||
|
@ -1018,7 +1021,7 @@ mod tests {
|
|||
});
|
||||
|
||||
let event = stream_rx.expect_authorization().await;
|
||||
assert_eq!(event.tool_call.title, "test 2");
|
||||
assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
|
||||
|
||||
// Test 3: Relative path without .zed should not require confirmation
|
||||
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
|
||||
|
@ -1051,7 +1054,10 @@ mod tests {
|
|||
)
|
||||
});
|
||||
let event = stream_rx.expect_authorization().await;
|
||||
assert_eq!(event.tool_call.title, "test 4 (local settings)");
|
||||
assert_eq!(
|
||||
event.tool_call.fields.title,
|
||||
Some("test 4 (local settings)".into())
|
||||
);
|
||||
|
||||
// Test 5: When always_allow_tool_actions is enabled, no confirmation needed
|
||||
cx.update(|cx| {
|
||||
|
|
|
@ -135,9 +135,9 @@ impl acp_old::Client for OldAcpClientDelegate {
|
|||
let response = cx
|
||||
.update(|cx| {
|
||||
self.thread.borrow().update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(tool_call, acp_options, cx)
|
||||
thread.request_tool_call_authorization(tool_call.into(), acp_options, cx)
|
||||
})
|
||||
})?
|
||||
})??
|
||||
.context("Failed to update thread")?
|
||||
.await;
|
||||
|
||||
|
@ -168,7 +168,7 @@ impl acp_old::Client for OldAcpClientDelegate {
|
|||
cx,
|
||||
)
|
||||
})
|
||||
})?
|
||||
})??
|
||||
.context("Failed to update thread")?;
|
||||
|
||||
Ok(acp_old::PushToolCallResponse {
|
||||
|
|
|
@ -233,7 +233,7 @@ impl acp::Client for ClientDelegate {
|
|||
thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
|
||||
})?;
|
||||
|
||||
let result = rx.await;
|
||||
let result = rx?.await;
|
||||
|
||||
let outcome = match result {
|
||||
Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
|
||||
|
|
|
@ -560,8 +560,9 @@ impl ClaudeAgentSession {
|
|||
thread.upsert_tool_call(
|
||||
claude_tool.as_acp(acp::ToolCallId(id.into())),
|
||||
cx,
|
||||
);
|
||||
)?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
|
|
@ -154,7 +154,7 @@ impl McpServerTool for PermissionTool {
|
|||
let chosen_option = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(
|
||||
claude_tool.as_acp(tool_call_id),
|
||||
claude_tool.as_acp(tool_call_id).into(),
|
||||
vec![
|
||||
acp::PermissionOption {
|
||||
id: allow_option_id.clone(),
|
||||
|
@ -169,7 +169,7 @@ impl McpServerTool for PermissionTool {
|
|||
],
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
})??
|
||||
.await?;
|
||||
|
||||
let response = if chosen_option == allow_option_id {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue