Tool authorization

This commit is contained in:
Agus Zubiaga 2025-07-01 20:32:21 -03:00
parent 6f768aefa2
commit 3ceeefe460
6 changed files with 206 additions and 21 deletions

1
Cargo.lock generated
View file

@ -17,6 +17,7 @@ dependencies = [
"futures 0.3.31", "futures 0.3.31",
"gpui", "gpui",
"language", "language",
"log",
"markdown", "markdown",
"parking_lot", "parking_lot",
"project", "project",

View file

@ -26,6 +26,7 @@ editor.workspace = true
futures.workspace = true futures.workspace = true
gpui.workspace = true gpui.workspace = true
language.workspace = true language.workspace = true
log.workspace = true
markdown.workspace = true markdown.workspace = true
parking_lot.workspace = true parking_lot.workspace = true
project.workspace = true project.workspace = true

View file

@ -4,12 +4,14 @@ mod thread_view;
use agentic_coding_protocol::{self as acp, Role}; use agentic_coding_protocol::{self as acp, Role};
use anyhow::Result; use anyhow::Result;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use futures::channel::oneshot;
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
use language::LanguageRegistry; use language::LanguageRegistry;
use markdown::Markdown; use markdown::Markdown;
use project::Project; use project::Project;
use std::{ops::Range, path::PathBuf, sync::Arc}; use std::{mem, ops::Range, path::PathBuf, sync::Arc};
use ui::App; use ui::App;
use util::{ResultExt, debug_panic};
pub use server::AcpServer; pub use server::AcpServer;
pub use thread_view::AcpThreadView; pub use thread_view::AcpThreadView;
@ -112,14 +114,32 @@ impl MessageChunk {
} }
} }
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Debug)]
pub enum AgentThreadEntryContent { pub enum AgentThreadEntryContent {
Message(Message), Message(Message),
ReadFile { path: PathBuf, content: String }, ReadFile { path: PathBuf, content: String },
ToolCall(ToolCall),
} }
#[derive(Debug)]
pub enum ToolCall {
WaitingForConfirmation {
id: ToolCallId,
tool_name: Entity<Markdown>,
description: Entity<Markdown>,
respond_tx: oneshot::Sender<bool>,
},
// todo! Running?
Allowed,
Rejected,
}
/// A `ThreadEntryId` that is known to be a ToolCall
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(usize); pub struct ToolCallId(ThreadEntryId);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(pub u64);
impl ThreadEntryId { impl ThreadEntryId {
pub fn post_inc(&mut self) -> Self { pub fn post_inc(&mut self) -> Self {
@ -146,7 +166,7 @@ pub struct AcpThread {
enum AcpThreadEvent { enum AcpThreadEvent {
NewEntry, NewEntry,
LastEntryUpdated, EntryUpdated(usize),
} }
impl EventEmitter<AcpThreadEvent> for AcpThread {} impl EventEmitter<AcpThreadEvent> for AcpThread {}
@ -184,22 +204,26 @@ impl AcpThread {
&self.entries &self.entries
} }
pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) { pub fn push_entry(
self.entries.push(ThreadEntry { &mut self,
id: self.next_entry_id.post_inc(), entry: AgentThreadEntryContent,
content: entry, cx: &mut Context<Self>,
}); ) -> ThreadEntryId {
cx.emit(AcpThreadEvent::NewEntry) let id = self.next_entry_id.post_inc();
self.entries.push(ThreadEntry { id, content: entry });
cx.emit(AcpThreadEvent::NewEntry);
id
} }
pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) { pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
let entries_len = self.entries.len();
if let Some(last_entry) = self.entries.last_mut() if let Some(last_entry) = self.entries.last_mut()
&& let AgentThreadEntryContent::Message(Message { && let AgentThreadEntryContent::Message(Message {
ref mut chunks, ref mut chunks,
role: Role::Assistant, role: Role::Assistant,
}) = last_entry.content }) = last_entry.content
{ {
cx.emit(AcpThreadEvent::LastEntryUpdated); cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
if let ( if let (
Some(MessageChunk::Text { chunk: old_chunk }), Some(MessageChunk::Text { chunk: old_chunk }),
@ -231,6 +255,74 @@ impl AcpThread {
); );
} }
pub fn push_tool_call(
&mut self,
title: String,
description: String,
respond_tx: oneshot::Sender<bool>,
cx: &mut Context<Self>,
) -> ToolCallId {
let language_registry = self.project.read(cx).languages().clone();
let entry_id = self.push_entry(
AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation {
// todo! clean up id creation
id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
tool_name: cx.new(|cx| {
Markdown::new(title.into(), Some(language_registry.clone()), None, cx)
}),
description: cx.new(|cx| {
Markdown::new(
description.into(),
Some(language_registry.clone()),
None,
cx,
)
}),
respond_tx,
}),
cx,
);
ToolCallId(entry_id)
}
pub fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context<Self>) {
let Some(entry) = self.entry_mut(id.0) else {
return;
};
let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
debug_panic!("expected ToolCall");
return;
};
let new_state = if allowed {
ToolCall::Allowed
} else {
ToolCall::Rejected
};
let call = mem::replace(call, new_state);
if let ToolCall::WaitingForConfirmation { respond_tx, .. } = call {
respond_tx.send(allowed).log_err();
} else {
debug_panic!("tried to authorize an already authorized tool call");
}
cx.emit(AcpThreadEvent::EntryUpdated(id.0.0 as usize));
}
fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
let entry = self.entries.get_mut(id.0 as usize);
debug_assert!(
entry.is_some(),
"We shouldn't give out ids to entries that don't exist"
);
entry
}
pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> { pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
let agent = self.server.clone(); let agent = self.server.clone();
let id = self.id.clone(); let id = self.id.clone();
@ -303,11 +395,13 @@ mod tests {
)); ));
assert!( assert!(
thread.entries().iter().any(|entry| { thread.entries().iter().any(|entry| {
entry.content match &entry.content {
== AgentThreadEntryContent::ReadFile { AgentThreadEntryContent::ReadFile { path, content } => {
path: "/private/tmp/foo".into(), path.to_string_lossy().to_string() == "/private/tmp/foo"
content: "Lorem ipsum dolor".into(), && content == "Lorem ipsum dolor"
} }
_ => false,
}
}), }),
"Thread does not contain entry. Actual: {:?}", "Thread does not contain entry. Actual: {:?}",
thread.entries() thread.entries()

View file

@ -1,8 +1,9 @@
use crate::{AcpThread, AgentThreadEntryContent, ThreadEntryId, ThreadId}; use crate::{AcpThread, AgentThreadEntryContent, ThreadEntryId, ThreadId, ToolCallId};
use agentic_coding_protocol as acp; use agentic_coding_protocol as acp;
use anyhow::{Context as _, Result}; use anyhow::{Context as _, Result};
use async_trait::async_trait; use async_trait::async_trait;
use collections::HashMap; use collections::HashMap;
use futures::channel::oneshot;
use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity}; use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
use parking_lot::Mutex; use parking_lot::Mutex;
use project::Project; use project::Project;
@ -185,6 +186,31 @@ impl acp::Client for AcpClientDelegate {
) -> Result<acp::GlobSearchResponse> { ) -> Result<acp::GlobSearchResponse> {
todo!() todo!()
} }
async fn request_tool_call(
&self,
request: acp::RequestToolCallParams,
) -> Result<acp::RequestToolCallResponse> {
let (tx, rx) = oneshot::channel();
let cx = &mut self.cx.clone();
let entry_id = cx
.update(|cx| {
self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
// todo! tools that don't require confirmation
thread.push_tool_call(request.tool_name, request.description, tx, cx)
})
})?
.context("Failed to update thread")?;
if dbg!(rx.await)? {
Ok(acp::RequestToolCallResponse::Allowed {
id: entry_id.into(),
})
} else {
Ok(acp::RequestToolCallResponse::Rejected)
}
}
} }
impl AcpServer { impl AcpServer {
@ -258,3 +284,15 @@ impl From<ThreadId> for acp::ThreadId {
acp::ThreadId(thread_id.0.to_string()) acp::ThreadId(thread_id.0.to_string())
} }
} }
impl From<acp::ToolCallId> for ToolCallId {
fn from(tool_call_id: acp::ToolCallId) -> Self {
Self(ThreadEntryId(tool_call_id.0.into()))
}
}
impl From<ToolCallId> for acp::ToolCallId {
fn from(tool_call_id: ToolCallId) -> Self {
acp::ToolCallId(tool_call_id.0.0)
}
}

View file

@ -13,13 +13,14 @@ use markdown::{HeadingLevelStyles, MarkdownElement, MarkdownStyle};
use project::Project; use project::Project;
use settings::Settings as _; use settings::Settings as _;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::Tooltip;
use ui::prelude::*; use ui::prelude::*;
use ui::{Button, Tooltip};
use util::ResultExt; use util::ResultExt;
use zed_actions::agent::Chat; use zed_actions::agent::Chat;
use crate::{ use crate::{
AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, MessageChunk, Role, ThreadEntry, AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, MessageChunk, Role, ThreadEntry,
ToolCall, ToolCallId,
}; };
pub struct AcpThreadView { pub struct AcpThreadView {
@ -100,8 +101,8 @@ impl AcpThreadView {
AcpThreadEvent::NewEntry => { AcpThreadEvent::NewEntry => {
this.list_state.splice(count..count, 1); this.list_state.splice(count..count, 1);
} }
AcpThreadEvent::LastEntryUpdated => { AcpThreadEvent::EntryUpdated(index) => {
this.list_state.splice(count - 1..count, 1); this.list_state.splice(*index..*index + 1, 1);
} }
} }
cx.notify(); cx.notify();
@ -149,7 +150,7 @@ impl AcpThreadView {
fn thread(&self) -> Option<&Entity<AcpThread>> { fn thread(&self) -> Option<&Entity<AcpThread>> {
match &self.thread_state { match &self.thread_state {
ThreadState::Ready { thread, .. } => Some(thread), ThreadState::Ready { thread, .. } => Some(thread),
_ => None, ThreadState::Loading { .. } | ThreadState::LoadError(..) => None,
} }
} }
@ -187,6 +188,16 @@ impl AcpThreadView {
}); });
} }
fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context<Self>) {
let Some(thread) = self.thread() else {
return;
};
thread.update(cx, |thread, cx| {
thread.authorize_tool_call(id, allowed, cx);
});
cx.notify();
}
fn render_entry( fn render_entry(
&self, &self,
entry: &ThreadEntry, entry: &ThreadEntry,
@ -236,6 +247,46 @@ impl AcpThreadView {
.child(format!("<Reading file {}>", path.display())) .child(format!("<Reading file {}>", path.display()))
.into_any() .into_any()
} }
AgentThreadEntryContent::ToolCall(tool_call) => match tool_call {
ToolCall::WaitingForConfirmation {
id,
tool_name,
description,
..
} => {
let id = *id;
v_flex()
.elevation_1(cx)
.child(MarkdownElement::new(
tool_name.clone(),
default_markdown_style(window, cx),
))
.child(MarkdownElement::new(
description.clone(),
default_markdown_style(window, cx),
))
.child(
h_flex()
.child(Button::new(("allow", id.0.0), "Allow").on_click(
cx.listener({
move |this, _, _, cx| {
this.authorize_tool_call(id, true, cx);
}
}),
))
.child(Button::new(("reject", id.0.0), "Reject").on_click(
cx.listener({
move |this, _, _, cx| {
this.authorize_tool_call(id, false, cx);
}
}),
)),
)
.into_any()
}
ToolCall::Allowed => div().child("Allowed!").into_any(),
ToolCall::Rejected => div().child("Rejected!").into_any(),
},
} }
} }
} }

View file

@ -39,7 +39,7 @@ pub trait StyledExt: Styled + Sized {
/// Sets `bg()`, `rounded_lg()`, `border()`, `border_color()`, `shadow()` /// Sets `bg()`, `rounded_lg()`, `border()`, `border_color()`, `shadow()`
/// ///
/// Example Elements: Title Bar, Panel, Tab Bar, Editor /// Example Elements: Title Bar, Panel, Tab Bar, Editor
fn elevation_1(self, cx: &mut App) -> Self { fn elevation_1(self, cx: &App) -> Self {
elevated(self, cx, ElevationIndex::Surface) elevated(self, cx, ElevationIndex::Surface)
} }