Refactor tool auth

This commit is contained in:
Antonio Scandurra 2025-08-08 10:50:53 +02:00
parent d52e0f47b5
commit 294109c6da
2 changed files with 41 additions and 43 deletions

View file

@ -99,11 +99,11 @@ impl AgentTool for ToolRequiringPermission {
fn run(
self: Arc<Self>,
input: Self::Input,
_input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
let auth_check = self.authorize(input, event_stream);
let auth_check = event_stream.authorize("Authorize?");
cx.foreground_executor().spawn(async move {
auth_check.await?;
Ok("Allowed".to_string())

View file

@ -498,7 +498,13 @@ impl Thread {
}));
};
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
let tool_event_stream =
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
tool_event_stream.send_update(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
Some(cx.foreground_executor().spawn(async move {
match tool_result.await {
Ok(tool_output) => LanguageModelToolResult {
@ -519,25 +525,6 @@ impl Thread {
}))
}
fn run_tool(
&self,
tool: Arc<dyn AnyAgentTool>,
tool_use: LanguageModelToolUse,
event_stream: AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Task<Result<String>> {
cx.spawn(async move |_this, cx| {
let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
tool_event_stream.send_update(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
.await
})
}
fn handle_tool_use_json_parse_error_event(
&mut self,
tool_use_id: LanguageModelToolUseId,
@ -694,16 +681,6 @@ where
schemars::schema_for!(Self::Input)
}
/// Allows the tool to authorize a given tool call with the user if necessary
fn authorize(
&self,
input: Self::Input,
event_stream: ToolCallEventStream,
) -> impl use<Self> + Future<Output = Result<()>> {
let json_input = serde_json::json!(&input);
event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
}
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
@ -918,13 +895,21 @@ impl AgentResponseEventStream {
#[derive(Clone)]
pub struct ToolCallEventStream {
tool_use_id: LanguageModelToolUseId,
kind: acp::ToolKind,
input: serde_json::Value,
stream: AgentResponseEventStream,
}
impl ToolCallEventStream {
fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
fn new(
tool_use: &LanguageModelToolUse,
kind: acp::ToolKind,
stream: AgentResponseEventStream,
) -> Self {
Self {
tool_use_id,
tool_use_id: tool_use.id.clone(),
kind,
input: tool_use.input.clone(),
stream,
}
}
@ -940,14 +925,17 @@ impl ToolCallEventStream {
});
}
pub fn authorize(
&self,
title: String,
kind: acp::ToolKind,
input: serde_json::Value,
) -> impl use<> + Future<Output = Result<()>> {
self.stream
.authorize_tool_call(&self.tool_use_id, title, kind, input)
pub fn authorize<T>(&self, title: T) -> impl use<T> + Future<Output = Result<()>>
where
T: Into<String>,
{
let title = title.into();
self.stream.authorize_tool_call(
&self.tool_use_id,
title,
self.kind.clone(),
self.input.clone(),
)
}
}
@ -963,7 +951,17 @@ impl TestToolCallEventStream {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx));
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
);
Self {
stream,