Centralize always_allow logic when authorizing agent2 tools (#35988)

Release Notes:

- N/A

---------

Co-authored-by: Cole Miller <cole@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-08-11 19:22:19 +02:00 committed by GitHub
parent 56c4992b9a
commit 365b5aa31d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 136 additions and 40 deletions

View file

@ -1,10 +1,12 @@
use crate::{SystemPromptTemplate, Template, Templates};
use action_log::ActionLog;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use cloud_llm_client::{CompletionIntent, CompletionMode};
use collections::HashMap;
use fs::Fs;
use futures::{
channel::{mpsc, oneshot},
stream::FuturesUnordered,
@ -21,8 +23,9 @@ use project::Project;
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
use settings::{Settings, update_settings_file};
use smol::stream::StreamExt;
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, rc::Rc, sync::Arc};
use util::{ResultExt, markdown::MarkdownCodeBlock};
#[derive(Debug, Clone)]
@ -506,8 +509,9 @@ impl Thread {
}));
};
let fs = self.project.read(cx).fs().clone();
let tool_event_stream =
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
tool_event_stream.update_fields(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
@ -884,6 +888,7 @@ pub struct ToolCallEventStream {
kind: acp::ToolKind,
input: serde_json::Value,
stream: AgentResponseEventStream,
fs: Option<Arc<dyn Fs>>,
}
impl ToolCallEventStream {
@ -902,6 +907,7 @@ impl ToolCallEventStream {
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
None,
);
(stream, ToolCallEventStreamReceiver(events_rx))
@ -911,12 +917,14 @@ impl ToolCallEventStream {
tool_use: &LanguageModelToolUse,
kind: acp::ToolKind,
stream: AgentResponseEventStream,
fs: Option<Arc<dyn Fs>>,
) -> Self {
Self {
tool_use_id: tool_use.id.clone(),
kind,
input: tool_use.input.clone(),
stream,
fs,
}
}
@ -951,7 +959,11 @@ impl ToolCallEventStream {
.ok();
}
pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
return Task::ready(Ok(()));
}
let (response_tx, response_rx) = oneshot::channel();
self.stream
.0
@ -959,7 +971,7 @@ impl ToolCallEventStream {
ToolCallAuthorization {
tool_call: AgentResponseEventStream::initial_tool_call(
&self.tool_use_id,
title,
title.into(),
self.kind.clone(),
self.input.clone(),
),
@ -984,12 +996,22 @@ impl ToolCallEventStream {
},
)))
.ok();
async move {
match response_rx.await?.0.as_ref() {
"allow" | "always_allow" => Ok(()),
_ => Err(anyhow!("Permission to run tool denied by user")),
let fs = self.fs.clone();
cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
"always_allow" => {
if let Some(fs) = fs.clone() {
cx.update(|cx| {
update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
settings.set_always_allow_tool_actions(true);
});
})?;
}
Ok(())
}
}
"allow" => Ok(()),
_ => Err(anyhow!("Permission to run tool denied by user")),
})
}
}