From 365b5aa31d606f8ecac440de98a81f405f751d67 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 11 Aug 2025 19:22:19 +0200 Subject: [PATCH] Centralize `always_allow` logic when authorizing agent2 tools (#35988) Release Notes: - N/A --------- Co-authored-by: Cole Miller Co-authored-by: Bennet Bo Fenner Co-authored-by: Agus Zubiaga Co-authored-by: Ben Brandt --- crates/agent2/src/tests/mod.rs | 93 ++++++++++++++++++++++- crates/agent2/src/tests/test_tools.rs | 4 +- crates/agent2/src/thread.rs | 40 +++++++--- crates/agent2/src/tools/edit_file_tool.rs | 16 ++-- crates/agent2/src/tools/open_tool.rs | 2 +- crates/agent2/src/tools/terminal_tool.rs | 18 +---- crates/fs/src/fs.rs | 3 + 7 files changed, 136 insertions(+), 40 deletions(-) diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index b47816f35c..d6aaddf2c2 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -4,9 +4,11 @@ use action_log::ActionLog; use agent_client_protocol::{self as acp}; use anyhow::Result; use client::{Client, UserStore}; -use fs::FakeFs; +use fs::{FakeFs, Fs}; use futures::channel::mpsc::UnboundedReceiver; -use gpui::{AppContext, Entity, Task, TestAppContext, http_client::FakeHttpClient}; +use gpui::{ + App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient, +}; use indoc::indoc; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, @@ -19,6 +21,7 @@ use reqwest_client::ReqwestClient; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::json; +use settings::SettingsStore; use smol::stream::StreamExt; use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration}; use util::path; @@ -282,6 +285,63 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { }) ] ); + + // Simulate yet another tool call. + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_3".into(), + name: ToolRequiringPermission.name().into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + + // Respond by always allowing tools. + let tool_call_auth_3 = next_tool_call_authorization(&mut events).await; + tool_call_auth_3 + .response + .send(tool_call_auth_3.options[0].id.clone()) + .unwrap(); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + let message = completion.messages.last().unwrap(); + assert_eq!( + message.content, + vec![MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(), + tool_name: ToolRequiringPermission.name().into(), + is_error: false, + content: "Allowed".into(), + output: Some("Allowed".into()) + })] + ); + + // Simulate a final tool call, ensuring we don't trigger authorization. + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_4".into(), + name: ToolRequiringPermission.name().into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + let completion = fake_model.pending_completions().pop().unwrap(); + let message = completion.messages.last().unwrap(); + assert_eq!( + message.content, + vec![MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: "tool_id_4".into(), + tool_name: ToolRequiringPermission.name().into(), + is_error: false, + content: "Allowed".into(), + output: Some("Allowed".into()) + })] + ); } #[gpui::test] @@ -773,13 +833,17 @@ impl TestModel { async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.background_executor.clone()); + cx.update(|cx| { settings::init(cx); + watch_settings(fs.clone(), cx); Project::init_settings(cx); + agent_settings::init(cx); }); let templates = Templates::new(); - let fs = FakeFs::new(cx.background_executor.clone()); fs.insert_tree(path!("/test"), json!({})).await; let project = Project::test(fs, [path!("/test").as_ref()], cx).await; @@ -841,3 +905,26 @@ fn init_logger() { env_logger::init(); } } + +fn watch_settings(fs: Arc, cx: &mut App) { + let fs = fs.clone(); + cx.spawn({ + async move |cx| { + let mut new_settings_content_rx = settings::watch_config_file( + cx.background_executor(), + fs, + paths::settings_file().clone(), + ); + + while let Some(new_settings_content) = new_settings_content_rx.next().await { + cx.update(|cx| { + SettingsStore::update_global(cx, |settings, cx| { + settings.set_user_settings(&new_settings_content, cx) + }) + }) + .ok(); + } + } + }) + .detach(); +} diff --git a/crates/agent2/src/tests/test_tools.rs b/crates/agent2/src/tests/test_tools.rs index d06614f3fe..7c7b81f52f 100644 --- a/crates/agent2/src/tests/test_tools.rs +++ b/crates/agent2/src/tests/test_tools.rs @@ -110,9 +110,9 @@ impl AgentTool for ToolRequiringPermission { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let auth_check = event_stream.authorize("Authorize?".into()); + let authorize = event_stream.authorize("Authorize?", cx); cx.foreground_executor().spawn(async move { - auth_check.await?; + authorize.await?; Ok("Allowed".to_string()) }) } diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index dd8e5476ab..23a0f7972d 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,10 +1,12 @@ use crate::{SystemPromptTemplate, Template, Templates}; use action_log::ActionLog; use agent_client_protocol as acp; +use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; use cloud_llm_client::{CompletionIntent, CompletionMode}; use collections::HashMap; +use fs::Fs; use futures::{ channel::{mpsc, oneshot}, stream::FuturesUnordered, @@ -21,8 +23,9 @@ use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; +use settings::{Settings, update_settings_file}; use smol::stream::StreamExt; -use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc}; +use std::{cell::RefCell, collections::BTreeMap, fmt::Write, rc::Rc, sync::Arc}; use util::{ResultExt, markdown::MarkdownCodeBlock}; #[derive(Debug, Clone)] @@ -506,8 +509,9 @@ impl Thread { })); }; + let fs = self.project.read(cx).fs().clone(); let tool_event_stream = - ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone()); + ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs)); tool_event_stream.update_fields(acp::ToolCallUpdateFields { status: Some(acp::ToolCallStatus::InProgress), ..Default::default() @@ -884,6 +888,7 @@ pub struct ToolCallEventStream { kind: acp::ToolKind, input: serde_json::Value, stream: AgentResponseEventStream, + fs: Option>, } impl ToolCallEventStream { @@ -902,6 +907,7 @@ impl ToolCallEventStream { }, acp::ToolKind::Other, AgentResponseEventStream(events_tx), + None, ); (stream, ToolCallEventStreamReceiver(events_rx)) @@ -911,12 +917,14 @@ impl ToolCallEventStream { tool_use: &LanguageModelToolUse, kind: acp::ToolKind, stream: AgentResponseEventStream, + fs: Option>, ) -> Self { Self { tool_use_id: tool_use.id.clone(), kind, input: tool_use.input.clone(), stream, + fs, } } @@ -951,7 +959,11 @@ impl ToolCallEventStream { .ok(); } - pub fn authorize(&self, title: String) -> impl use<> + Future> { + pub fn authorize(&self, title: impl Into, cx: &mut App) -> Task> { + if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { + return Task::ready(Ok(())); + } + let (response_tx, response_rx) = oneshot::channel(); self.stream .0 @@ -959,7 +971,7 @@ impl ToolCallEventStream { ToolCallAuthorization { tool_call: AgentResponseEventStream::initial_tool_call( &self.tool_use_id, - title, + title.into(), self.kind.clone(), self.input.clone(), ), @@ -984,12 +996,22 @@ impl ToolCallEventStream { }, ))) .ok(); - async move { - match response_rx.await?.0.as_ref() { - "allow" | "always_allow" => Ok(()), - _ => Err(anyhow!("Permission to run tool denied by user")), + let fs = self.fs.clone(); + cx.spawn(async move |cx| match response_rx.await?.0.as_ref() { + "always_allow" => { + if let Some(fs) = fs.clone() { + cx.update(|cx| { + update_settings_file::(fs, cx, |settings, _| { + settings.set_always_allow_tool_actions(true); + }); + })?; + } + + Ok(()) } - } + "allow" => Ok(()), + _ => Err(anyhow!("Permission to run tool denied by user")), + }) } } diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index d9a4cdf8ba..88764d1953 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -133,7 +133,7 @@ impl EditFileTool { &self, input: &EditFileToolInput, event_stream: &ToolCallEventStream, - cx: &App, + cx: &mut App, ) -> Task> { if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { return Task::ready(Ok(())); @@ -147,8 +147,9 @@ impl EditFileTool { .components() .any(|component| component.as_os_str() == local_settings_folder.as_os_str()) { - return cx.foreground_executor().spawn( - event_stream.authorize(format!("{} (local settings)", input.display_description)), + return event_stream.authorize( + format!("{} (local settings)", input.display_description), + cx, ); } @@ -156,9 +157,9 @@ impl EditFileTool { // so check for that edge case too. if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { if canonical_path.starts_with(paths::config_dir()) { - return cx.foreground_executor().spawn( - event_stream - .authorize(format!("{} (global settings)", input.display_description)), + return event_stream.authorize( + format!("{} (global settings)", input.display_description), + cx, ); } } @@ -173,8 +174,7 @@ impl EditFileTool { if project_path.is_some() { Task::ready(Ok(())) } else { - cx.foreground_executor() - .spawn(event_stream.authorize(input.display_description.clone())) + event_stream.authorize(&input.display_description, cx) } } } diff --git a/crates/agent2/src/tools/open_tool.rs b/crates/agent2/src/tools/open_tool.rs index 0860b62a51..36420560c1 100644 --- a/crates/agent2/src/tools/open_tool.rs +++ b/crates/agent2/src/tools/open_tool.rs @@ -65,7 +65,7 @@ impl AgentTool for OpenTool { ) -> Task> { // If path_or_url turns out to be a path in the project, make it absolute. let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx); - let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())).to_string()); + let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx); cx.background_spawn(async move { authorize.await?; diff --git a/crates/agent2/src/tools/terminal_tool.rs b/crates/agent2/src/tools/terminal_tool.rs index c0b34444dd..ecb855ac34 100644 --- a/crates/agent2/src/tools/terminal_tool.rs +++ b/crates/agent2/src/tools/terminal_tool.rs @@ -5,7 +5,6 @@ use gpui::{App, AppContext, Entity, SharedString, Task}; use project::{Project, terminals::TerminalKind}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::Settings; use std::{ path::{Path, PathBuf}, sync::Arc, @@ -61,21 +60,6 @@ impl TerminalTool { determine_shell: determine_shell.shared(), } } - - fn authorize( - &self, - input: &TerminalToolInput, - event_stream: &ToolCallEventStream, - cx: &App, - ) -> Task> { - if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { - return Task::ready(Ok(())); - } - - // TODO: do we want to have a special title here? - cx.foreground_executor() - .spawn(event_stream.authorize(self.initial_title(Ok(input.clone())).to_string())) - } } impl AgentTool for TerminalTool { @@ -152,7 +136,7 @@ impl AgentTool for TerminalTool { env }); - let authorize = self.authorize(&input, &event_stream, cx); + let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx); cx.spawn({ async move |cx| { diff --git a/crates/fs/src/fs.rs b/crates/fs/src/fs.rs index af8fe129ab..a2b75ac6a7 100644 --- a/crates/fs/src/fs.rs +++ b/crates/fs/src/fs.rs @@ -2172,6 +2172,9 @@ impl Fs for FakeFs { async fn atomic_write(&self, path: PathBuf, data: String) -> Result<()> { self.simulate_random_delay().await; let path = normalize_path(path.as_path()); + if let Some(path) = path.parent() { + self.create_dir(path).await?; + } self.write_file_internal(path, data.into_bytes(), true)?; Ok(()) }