Centralize always_allow
logic when authorizing agent2 tools (#35988)
Release Notes: - N/A --------- Co-authored-by: Cole Miller <cole@zed.dev> Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de> Co-authored-by: Agus Zubiaga <agus@zed.dev> Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
parent
56c4992b9a
commit
365b5aa31d
7 changed files with 136 additions and 40 deletions
|
@ -4,9 +4,11 @@ use action_log::ActionLog;
|
||||||
use agent_client_protocol::{self as acp};
|
use agent_client_protocol::{self as acp};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use client::{Client, UserStore};
|
use client::{Client, UserStore};
|
||||||
use fs::FakeFs;
|
use fs::{FakeFs, Fs};
|
||||||
use futures::channel::mpsc::UnboundedReceiver;
|
use futures::channel::mpsc::UnboundedReceiver;
|
||||||
use gpui::{AppContext, Entity, Task, TestAppContext, http_client::FakeHttpClient};
|
use gpui::{
|
||||||
|
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
|
||||||
|
};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||||
|
@ -19,6 +21,7 @@ use reqwest_client::ReqwestClient;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use settings::SettingsStore;
|
||||||
use smol::stream::StreamExt;
|
use smol::stream::StreamExt;
|
||||||
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
|
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||||
use util::path;
|
use util::path;
|
||||||
|
@ -282,6 +285,63 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||||
})
|
})
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Simulate yet another tool call.
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
LanguageModelToolUse {
|
||||||
|
id: "tool_id_3".into(),
|
||||||
|
name: ToolRequiringPermission.name().into(),
|
||||||
|
raw_input: "{}".into(),
|
||||||
|
input: json!({}),
|
||||||
|
is_input_complete: true,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
|
||||||
|
// Respond by always allowing tools.
|
||||||
|
let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
|
||||||
|
tool_call_auth_3
|
||||||
|
.response
|
||||||
|
.send(tool_call_auth_3.options[0].id.clone())
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
let message = completion.messages.last().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
message.content,
|
||||||
|
vec![MessageContent::ToolResult(LanguageModelToolResult {
|
||||||
|
tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
|
||||||
|
tool_name: ToolRequiringPermission.name().into(),
|
||||||
|
is_error: false,
|
||||||
|
content: "Allowed".into(),
|
||||||
|
output: Some("Allowed".into())
|
||||||
|
})]
|
||||||
|
);
|
||||||
|
|
||||||
|
// Simulate a final tool call, ensuring we don't trigger authorization.
|
||||||
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
LanguageModelToolUse {
|
||||||
|
id: "tool_id_4".into(),
|
||||||
|
name: ToolRequiringPermission.name().into(),
|
||||||
|
raw_input: "{}".into(),
|
||||||
|
input: json!({}),
|
||||||
|
is_input_complete: true,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
cx.run_until_parked();
|
||||||
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
let message = completion.messages.last().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
message.content,
|
||||||
|
vec![MessageContent::ToolResult(LanguageModelToolResult {
|
||||||
|
tool_use_id: "tool_id_4".into(),
|
||||||
|
tool_name: ToolRequiringPermission.name().into(),
|
||||||
|
is_error: false,
|
||||||
|
content: "Allowed".into(),
|
||||||
|
output: Some("Allowed".into())
|
||||||
|
})]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
|
@ -773,13 +833,17 @@ impl TestModel {
|
||||||
|
|
||||||
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||||
cx.executor().allow_parking();
|
cx.executor().allow_parking();
|
||||||
|
|
||||||
|
let fs = FakeFs::new(cx.background_executor.clone());
|
||||||
|
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
settings::init(cx);
|
settings::init(cx);
|
||||||
|
watch_settings(fs.clone(), cx);
|
||||||
Project::init_settings(cx);
|
Project::init_settings(cx);
|
||||||
|
agent_settings::init(cx);
|
||||||
});
|
});
|
||||||
let templates = Templates::new();
|
let templates = Templates::new();
|
||||||
|
|
||||||
let fs = FakeFs::new(cx.background_executor.clone());
|
|
||||||
fs.insert_tree(path!("/test"), json!({})).await;
|
fs.insert_tree(path!("/test"), json!({})).await;
|
||||||
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
||||||
|
|
||||||
|
@ -841,3 +905,26 @@ fn init_logger() {
|
||||||
env_logger::init();
|
env_logger::init();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
|
||||||
|
let fs = fs.clone();
|
||||||
|
cx.spawn({
|
||||||
|
async move |cx| {
|
||||||
|
let mut new_settings_content_rx = settings::watch_config_file(
|
||||||
|
cx.background_executor(),
|
||||||
|
fs,
|
||||||
|
paths::settings_file().clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
while let Some(new_settings_content) = new_settings_content_rx.next().await {
|
||||||
|
cx.update(|cx| {
|
||||||
|
SettingsStore::update_global(cx, |settings, cx| {
|
||||||
|
settings.set_user_settings(&new_settings_content, cx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
}
|
||||||
|
|
|
@ -110,9 +110,9 @@ impl AgentTool for ToolRequiringPermission {
|
||||||
event_stream: ToolCallEventStream,
|
event_stream: ToolCallEventStream,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<String>> {
|
) -> Task<Result<String>> {
|
||||||
let auth_check = event_stream.authorize("Authorize?".into());
|
let authorize = event_stream.authorize("Authorize?", cx);
|
||||||
cx.foreground_executor().spawn(async move {
|
cx.foreground_executor().spawn(async move {
|
||||||
auth_check.await?;
|
authorize.await?;
|
||||||
Ok("Allowed".to_string())
|
Ok("Allowed".to_string())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
use crate::{SystemPromptTemplate, Template, Templates};
|
use crate::{SystemPromptTemplate, Template, Templates};
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
|
use agent_settings::AgentSettings;
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use assistant_tool::adapt_schema_to_format;
|
use assistant_tool::adapt_schema_to_format;
|
||||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
|
use fs::Fs;
|
||||||
use futures::{
|
use futures::{
|
||||||
channel::{mpsc, oneshot},
|
channel::{mpsc, oneshot},
|
||||||
stream::FuturesUnordered,
|
stream::FuturesUnordered,
|
||||||
|
@ -21,8 +23,9 @@ use project::Project;
|
||||||
use prompt_store::ProjectContext;
|
use prompt_store::ProjectContext;
|
||||||
use schemars::{JsonSchema, Schema};
|
use schemars::{JsonSchema, Schema};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use settings::{Settings, update_settings_file};
|
||||||
use smol::stream::StreamExt;
|
use smol::stream::StreamExt;
|
||||||
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
|
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, rc::Rc, sync::Arc};
|
||||||
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
@ -506,8 +509,9 @@ impl Thread {
|
||||||
}));
|
}));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let fs = self.project.read(cx).fs().clone();
|
||||||
let tool_event_stream =
|
let tool_event_stream =
|
||||||
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
|
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
|
||||||
tool_event_stream.update_fields(acp::ToolCallUpdateFields {
|
tool_event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||||
status: Some(acp::ToolCallStatus::InProgress),
|
status: Some(acp::ToolCallStatus::InProgress),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
@ -884,6 +888,7 @@ pub struct ToolCallEventStream {
|
||||||
kind: acp::ToolKind,
|
kind: acp::ToolKind,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
stream: AgentResponseEventStream,
|
stream: AgentResponseEventStream,
|
||||||
|
fs: Option<Arc<dyn Fs>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolCallEventStream {
|
impl ToolCallEventStream {
|
||||||
|
@ -902,6 +907,7 @@ impl ToolCallEventStream {
|
||||||
},
|
},
|
||||||
acp::ToolKind::Other,
|
acp::ToolKind::Other,
|
||||||
AgentResponseEventStream(events_tx),
|
AgentResponseEventStream(events_tx),
|
||||||
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
(stream, ToolCallEventStreamReceiver(events_rx))
|
(stream, ToolCallEventStreamReceiver(events_rx))
|
||||||
|
@ -911,12 +917,14 @@ impl ToolCallEventStream {
|
||||||
tool_use: &LanguageModelToolUse,
|
tool_use: &LanguageModelToolUse,
|
||||||
kind: acp::ToolKind,
|
kind: acp::ToolKind,
|
||||||
stream: AgentResponseEventStream,
|
stream: AgentResponseEventStream,
|
||||||
|
fs: Option<Arc<dyn Fs>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
tool_use_id: tool_use.id.clone(),
|
tool_use_id: tool_use.id.clone(),
|
||||||
kind,
|
kind,
|
||||||
input: tool_use.input.clone(),
|
input: tool_use.input.clone(),
|
||||||
stream,
|
stream,
|
||||||
|
fs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -951,7 +959,11 @@ impl ToolCallEventStream {
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
|
pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
|
||||||
|
if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
|
||||||
|
return Task::ready(Ok(()));
|
||||||
|
}
|
||||||
|
|
||||||
let (response_tx, response_rx) = oneshot::channel();
|
let (response_tx, response_rx) = oneshot::channel();
|
||||||
self.stream
|
self.stream
|
||||||
.0
|
.0
|
||||||
|
@ -959,7 +971,7 @@ impl ToolCallEventStream {
|
||||||
ToolCallAuthorization {
|
ToolCallAuthorization {
|
||||||
tool_call: AgentResponseEventStream::initial_tool_call(
|
tool_call: AgentResponseEventStream::initial_tool_call(
|
||||||
&self.tool_use_id,
|
&self.tool_use_id,
|
||||||
title,
|
title.into(),
|
||||||
self.kind.clone(),
|
self.kind.clone(),
|
||||||
self.input.clone(),
|
self.input.clone(),
|
||||||
),
|
),
|
||||||
|
@ -984,12 +996,22 @@ impl ToolCallEventStream {
|
||||||
},
|
},
|
||||||
)))
|
)))
|
||||||
.ok();
|
.ok();
|
||||||
async move {
|
let fs = self.fs.clone();
|
||||||
match response_rx.await?.0.as_ref() {
|
cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
|
||||||
"allow" | "always_allow" => Ok(()),
|
"always_allow" => {
|
||||||
_ => Err(anyhow!("Permission to run tool denied by user")),
|
if let Some(fs) = fs.clone() {
|
||||||
|
cx.update(|cx| {
|
||||||
|
update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
|
||||||
|
settings.set_always_allow_tool_actions(true);
|
||||||
|
});
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
"allow" => Ok(()),
|
||||||
|
_ => Err(anyhow!("Permission to run tool denied by user")),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -133,7 +133,7 @@ impl EditFileTool {
|
||||||
&self,
|
&self,
|
||||||
input: &EditFileToolInput,
|
input: &EditFileToolInput,
|
||||||
event_stream: &ToolCallEventStream,
|
event_stream: &ToolCallEventStream,
|
||||||
cx: &App,
|
cx: &mut App,
|
||||||
) -> Task<Result<()>> {
|
) -> Task<Result<()>> {
|
||||||
if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
|
if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
|
||||||
return Task::ready(Ok(()));
|
return Task::ready(Ok(()));
|
||||||
|
@ -147,8 +147,9 @@ impl EditFileTool {
|
||||||
.components()
|
.components()
|
||||||
.any(|component| component.as_os_str() == local_settings_folder.as_os_str())
|
.any(|component| component.as_os_str() == local_settings_folder.as_os_str())
|
||||||
{
|
{
|
||||||
return cx.foreground_executor().spawn(
|
return event_stream.authorize(
|
||||||
event_stream.authorize(format!("{} (local settings)", input.display_description)),
|
format!("{} (local settings)", input.display_description),
|
||||||
|
cx,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -156,9 +157,9 @@ impl EditFileTool {
|
||||||
// so check for that edge case too.
|
// so check for that edge case too.
|
||||||
if let Ok(canonical_path) = std::fs::canonicalize(&input.path) {
|
if let Ok(canonical_path) = std::fs::canonicalize(&input.path) {
|
||||||
if canonical_path.starts_with(paths::config_dir()) {
|
if canonical_path.starts_with(paths::config_dir()) {
|
||||||
return cx.foreground_executor().spawn(
|
return event_stream.authorize(
|
||||||
event_stream
|
format!("{} (global settings)", input.display_description),
|
||||||
.authorize(format!("{} (global settings)", input.display_description)),
|
cx,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -173,8 +174,7 @@ impl EditFileTool {
|
||||||
if project_path.is_some() {
|
if project_path.is_some() {
|
||||||
Task::ready(Ok(()))
|
Task::ready(Ok(()))
|
||||||
} else {
|
} else {
|
||||||
cx.foreground_executor()
|
event_stream.authorize(&input.display_description, cx)
|
||||||
.spawn(event_stream.authorize(input.display_description.clone()))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,7 +65,7 @@ impl AgentTool for OpenTool {
|
||||||
) -> Task<Result<Self::Output>> {
|
) -> Task<Result<Self::Output>> {
|
||||||
// If path_or_url turns out to be a path in the project, make it absolute.
|
// If path_or_url turns out to be a path in the project, make it absolute.
|
||||||
let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx);
|
let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx);
|
||||||
let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())).to_string());
|
let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx);
|
||||||
cx.background_spawn(async move {
|
cx.background_spawn(async move {
|
||||||
authorize.await?;
|
authorize.await?;
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||||
use project::{Project, terminals::TerminalKind};
|
use project::{Project, terminals::TerminalKind};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use settings::Settings;
|
|
||||||
use std::{
|
use std::{
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
|
@ -61,21 +60,6 @@ impl TerminalTool {
|
||||||
determine_shell: determine_shell.shared(),
|
determine_shell: determine_shell.shared(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authorize(
|
|
||||||
&self,
|
|
||||||
input: &TerminalToolInput,
|
|
||||||
event_stream: &ToolCallEventStream,
|
|
||||||
cx: &App,
|
|
||||||
) -> Task<Result<()>> {
|
|
||||||
if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
|
|
||||||
return Task::ready(Ok(()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: do we want to have a special title here?
|
|
||||||
cx.foreground_executor()
|
|
||||||
.spawn(event_stream.authorize(self.initial_title(Ok(input.clone())).to_string()))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentTool for TerminalTool {
|
impl AgentTool for TerminalTool {
|
||||||
|
@ -152,7 +136,7 @@ impl AgentTool for TerminalTool {
|
||||||
env
|
env
|
||||||
});
|
});
|
||||||
|
|
||||||
let authorize = self.authorize(&input, &event_stream, cx);
|
let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx);
|
||||||
|
|
||||||
cx.spawn({
|
cx.spawn({
|
||||||
async move |cx| {
|
async move |cx| {
|
||||||
|
|
|
@ -2172,6 +2172,9 @@ impl Fs for FakeFs {
|
||||||
async fn atomic_write(&self, path: PathBuf, data: String) -> Result<()> {
|
async fn atomic_write(&self, path: PathBuf, data: String) -> Result<()> {
|
||||||
self.simulate_random_delay().await;
|
self.simulate_random_delay().await;
|
||||||
let path = normalize_path(path.as_path());
|
let path = normalize_path(path.as_path());
|
||||||
|
if let Some(path) = path.parent() {
|
||||||
|
self.create_dir(path).await?;
|
||||||
|
}
|
||||||
self.write_file_internal(path, data.into_bytes(), true)?;
|
self.write_file_internal(path, data.into_bytes(), true)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue