Centralize always_allow logic when authorizing agent2 tools (#35988)

Release Notes:

- N/A

---------

Co-authored-by: Cole Miller <cole@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-08-11 19:22:19 +02:00 committed by GitHub
parent 56c4992b9a
commit 365b5aa31d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 136 additions and 40 deletions

View file

@ -4,9 +4,11 @@ use action_log::ActionLog;
use agent_client_protocol::{self as acp};
use anyhow::Result;
use client::{Client, UserStore};
use fs::FakeFs;
use fs::{FakeFs, Fs};
use futures::channel::mpsc::UnboundedReceiver;
use gpui::{AppContext, Entity, Task, TestAppContext, http_client::FakeHttpClient};
use gpui::{
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
};
use indoc::indoc;
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
@ -19,6 +21,7 @@ use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use settings::SettingsStore;
use smol::stream::StreamExt;
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;
@ -282,6 +285,63 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
})
]
);
// Simulate yet another tool call.
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_id_3".into(),
name: ToolRequiringPermission.name().into(),
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
},
));
fake_model.end_last_completion_stream();
// Respond by always allowing tools.
let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
tool_call_auth_3
.response
.send(tool_call_auth_3.options[0].id.clone())
.unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
let message = completion.messages.last().unwrap();
assert_eq!(
message.content,
vec![MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
})]
);
// Simulate a final tool call, ensuring we don't trigger authorization.
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_id_4".into(),
name: ToolRequiringPermission.name().into(),
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
let message = completion.messages.last().unwrap();
assert_eq!(
message.content,
vec![MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: "tool_id_4".into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
})]
);
}
#[gpui::test]
@ -773,13 +833,17 @@ impl TestModel {
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
cx.executor().allow_parking();
let fs = FakeFs::new(cx.background_executor.clone());
cx.update(|cx| {
settings::init(cx);
watch_settings(fs.clone(), cx);
Project::init_settings(cx);
agent_settings::init(cx);
});
let templates = Templates::new();
let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
@ -841,3 +905,26 @@ fn init_logger() {
env_logger::init();
}
}
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
let fs = fs.clone();
cx.spawn({
async move |cx| {
let mut new_settings_content_rx = settings::watch_config_file(
cx.background_executor(),
fs,
paths::settings_file().clone(),
);
while let Some(new_settings_content) = new_settings_content_rx.next().await {
cx.update(|cx| {
SettingsStore::update_global(cx, |settings, cx| {
settings.set_user_settings(&new_settings_content, cx)
})
})
.ok();
}
}
})
.detach();
}

View file

@ -110,9 +110,9 @@ impl AgentTool for ToolRequiringPermission {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
let auth_check = event_stream.authorize("Authorize?".into());
let authorize = event_stream.authorize("Authorize?", cx);
cx.foreground_executor().spawn(async move {
auth_check.await?;
authorize.await?;
Ok("Allowed".to_string())
})
}

View file

@ -1,10 +1,12 @@
use crate::{SystemPromptTemplate, Template, Templates};
use action_log::ActionLog;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use cloud_llm_client::{CompletionIntent, CompletionMode};
use collections::HashMap;
use fs::Fs;
use futures::{
channel::{mpsc, oneshot},
stream::FuturesUnordered,
@ -21,8 +23,9 @@ use project::Project;
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
use settings::{Settings, update_settings_file};
use smol::stream::StreamExt;
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, rc::Rc, sync::Arc};
use util::{ResultExt, markdown::MarkdownCodeBlock};
#[derive(Debug, Clone)]
@ -506,8 +509,9 @@ impl Thread {
}));
};
let fs = self.project.read(cx).fs().clone();
let tool_event_stream =
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
tool_event_stream.update_fields(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
@ -884,6 +888,7 @@ pub struct ToolCallEventStream {
kind: acp::ToolKind,
input: serde_json::Value,
stream: AgentResponseEventStream,
fs: Option<Arc<dyn Fs>>,
}
impl ToolCallEventStream {
@ -902,6 +907,7 @@ impl ToolCallEventStream {
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
None,
);
(stream, ToolCallEventStreamReceiver(events_rx))
@ -911,12 +917,14 @@ impl ToolCallEventStream {
tool_use: &LanguageModelToolUse,
kind: acp::ToolKind,
stream: AgentResponseEventStream,
fs: Option<Arc<dyn Fs>>,
) -> Self {
Self {
tool_use_id: tool_use.id.clone(),
kind,
input: tool_use.input.clone(),
stream,
fs,
}
}
@ -951,7 +959,11 @@ impl ToolCallEventStream {
.ok();
}
pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
return Task::ready(Ok(()));
}
let (response_tx, response_rx) = oneshot::channel();
self.stream
.0
@ -959,7 +971,7 @@ impl ToolCallEventStream {
ToolCallAuthorization {
tool_call: AgentResponseEventStream::initial_tool_call(
&self.tool_use_id,
title,
title.into(),
self.kind.clone(),
self.input.clone(),
),
@ -984,12 +996,22 @@ impl ToolCallEventStream {
},
)))
.ok();
async move {
match response_rx.await?.0.as_ref() {
"allow" | "always_allow" => Ok(()),
_ => Err(anyhow!("Permission to run tool denied by user")),
let fs = self.fs.clone();
cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
"always_allow" => {
if let Some(fs) = fs.clone() {
cx.update(|cx| {
update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
settings.set_always_allow_tool_actions(true);
});
})?;
}
Ok(())
}
}
"allow" => Ok(()),
_ => Err(anyhow!("Permission to run tool denied by user")),
})
}
}

View file

@ -133,7 +133,7 @@ impl EditFileTool {
&self,
input: &EditFileToolInput,
event_stream: &ToolCallEventStream,
cx: &App,
cx: &mut App,
) -> Task<Result<()>> {
if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
return Task::ready(Ok(()));
@ -147,8 +147,9 @@ impl EditFileTool {
.components()
.any(|component| component.as_os_str() == local_settings_folder.as_os_str())
{
return cx.foreground_executor().spawn(
event_stream.authorize(format!("{} (local settings)", input.display_description)),
return event_stream.authorize(
format!("{} (local settings)", input.display_description),
cx,
);
}
@ -156,9 +157,9 @@ impl EditFileTool {
// so check for that edge case too.
if let Ok(canonical_path) = std::fs::canonicalize(&input.path) {
if canonical_path.starts_with(paths::config_dir()) {
return cx.foreground_executor().spawn(
event_stream
.authorize(format!("{} (global settings)", input.display_description)),
return event_stream.authorize(
format!("{} (global settings)", input.display_description),
cx,
);
}
}
@ -173,8 +174,7 @@ impl EditFileTool {
if project_path.is_some() {
Task::ready(Ok(()))
} else {
cx.foreground_executor()
.spawn(event_stream.authorize(input.display_description.clone()))
event_stream.authorize(&input.display_description, cx)
}
}
}

View file

@ -65,7 +65,7 @@ impl AgentTool for OpenTool {
) -> Task<Result<Self::Output>> {
// If path_or_url turns out to be a path in the project, make it absolute.
let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx);
let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())).to_string());
let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx);
cx.background_spawn(async move {
authorize.await?;

View file

@ -5,7 +5,6 @@ use gpui::{App, AppContext, Entity, SharedString, Task};
use project::{Project, terminals::TerminalKind};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
path::{Path, PathBuf},
sync::Arc,
@ -61,21 +60,6 @@ impl TerminalTool {
determine_shell: determine_shell.shared(),
}
}
fn authorize(
&self,
input: &TerminalToolInput,
event_stream: &ToolCallEventStream,
cx: &App,
) -> Task<Result<()>> {
if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
return Task::ready(Ok(()));
}
// TODO: do we want to have a special title here?
cx.foreground_executor()
.spawn(event_stream.authorize(self.initial_title(Ok(input.clone())).to_string()))
}
}
impl AgentTool for TerminalTool {
@ -152,7 +136,7 @@ impl AgentTool for TerminalTool {
env
});
let authorize = self.authorize(&input, &event_stream, cx);
let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx);
cx.spawn({
async move |cx| {

View file

@ -2172,6 +2172,9 @@ impl Fs for FakeFs {
async fn atomic_write(&self, path: PathBuf, data: String) -> Result<()> {
self.simulate_random_delay().await;
let path = normalize_path(path.as_path());
if let Some(path) = path.parent() {
self.create_dir(path).await?;
}
self.write_file_internal(path, data.into_bytes(), true)?;
Ok(())
}