Refactor tool auth
This commit is contained in:
parent
d52e0f47b5
commit
294109c6da
2 changed files with 41 additions and 43 deletions
|
@ -99,11 +99,11 @@ impl AgentTool for ToolRequiringPermission {
|
|||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
_input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let auth_check = self.authorize(input, event_stream);
|
||||
let auth_check = event_stream.authorize("Authorize?");
|
||||
cx.foreground_executor().spawn(async move {
|
||||
auth_check.await?;
|
||||
Ok("Allowed".to_string())
|
||||
|
|
|
@ -498,7 +498,13 @@ impl Thread {
|
|||
}));
|
||||
};
|
||||
|
||||
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
|
||||
let tool_event_stream =
|
||||
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
|
||||
tool_event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
|
@ -519,25 +525,6 @@ impl Thread {
|
|||
}))
|
||||
}
|
||||
|
||||
fn run_tool(
|
||||
&self,
|
||||
tool: Arc<dyn AnyAgentTool>,
|
||||
tool_use: LanguageModelToolUse,
|
||||
event_stream: AgentResponseEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<String>> {
|
||||
cx.spawn(async move |_this, cx| {
|
||||
let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
|
||||
tool_event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_tool_use_json_parse_error_event(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
|
@ -694,16 +681,6 @@ where
|
|||
schemars::schema_for!(Self::Input)
|
||||
}
|
||||
|
||||
/// Allows the tool to authorize a given tool call with the user if necessary
|
||||
fn authorize(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
) -> impl use<Self> + Future<Output = Result<()>> {
|
||||
let json_input = serde_json::json!(&input);
|
||||
event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
|
@ -918,13 +895,21 @@ impl AgentResponseEventStream {
|
|||
#[derive(Clone)]
|
||||
pub struct ToolCallEventStream {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
stream: AgentResponseEventStream,
|
||||
}
|
||||
|
||||
impl ToolCallEventStream {
|
||||
fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
|
||||
fn new(
|
||||
tool_use: &LanguageModelToolUse,
|
||||
kind: acp::ToolKind,
|
||||
stream: AgentResponseEventStream,
|
||||
) -> Self {
|
||||
Self {
|
||||
tool_use_id,
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
kind,
|
||||
input: tool_use.input.clone(),
|
||||
stream,
|
||||
}
|
||||
}
|
||||
|
@ -940,14 +925,17 @@ impl ToolCallEventStream {
|
|||
});
|
||||
}
|
||||
|
||||
pub fn authorize(
|
||||
&self,
|
||||
title: String,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.stream
|
||||
.authorize_tool_call(&self.tool_use_id, title, kind, input)
|
||||
pub fn authorize<T>(&self, title: T) -> impl use<T> + Future<Output = Result<()>>
|
||||
where
|
||||
T: Into<String>,
|
||||
{
|
||||
let title = title.into();
|
||||
self.stream.authorize_tool_call(
|
||||
&self.tool_use_id,
|
||||
title,
|
||||
self.kind.clone(),
|
||||
self.input.clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -963,7 +951,17 @@ impl TestToolCallEventStream {
|
|||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx));
|
||||
let stream = ToolCallEventStream::new(
|
||||
&LanguageModelToolUse {
|
||||
id: "test_id".into(),
|
||||
name: "test_tool".into(),
|
||||
raw_input: String::new(),
|
||||
input: serde_json::Value::Null,
|
||||
is_input_complete: true,
|
||||
},
|
||||
acp::ToolKind::Other,
|
||||
AgentResponseEventStream(events_tx),
|
||||
);
|
||||
|
||||
Self {
|
||||
stream,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue