Tool authorization
This commit is contained in:
parent
6f768aefa2
commit
3ceeefe460
6 changed files with 206 additions and 21 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -17,6 +17,7 @@ dependencies = [
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
"gpui",
|
"gpui",
|
||||||
"language",
|
"language",
|
||||||
|
"log",
|
||||||
"markdown",
|
"markdown",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"project",
|
"project",
|
||||||
|
|
|
@ -26,6 +26,7 @@ editor.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
language.workspace = true
|
language.workspace = true
|
||||||
|
log.workspace = true
|
||||||
markdown.workspace = true
|
markdown.workspace = true
|
||||||
parking_lot.workspace = true
|
parking_lot.workspace = true
|
||||||
project.workspace = true
|
project.workspace = true
|
||||||
|
|
|
@ -4,12 +4,14 @@ mod thread_view;
|
||||||
use agentic_coding_protocol::{self as acp, Role};
|
use agentic_coding_protocol::{self as acp, Role};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
|
use futures::channel::oneshot;
|
||||||
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
|
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
|
||||||
use language::LanguageRegistry;
|
use language::LanguageRegistry;
|
||||||
use markdown::Markdown;
|
use markdown::Markdown;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use std::{ops::Range, path::PathBuf, sync::Arc};
|
use std::{mem, ops::Range, path::PathBuf, sync::Arc};
|
||||||
use ui::App;
|
use ui::App;
|
||||||
|
use util::{ResultExt, debug_panic};
|
||||||
|
|
||||||
pub use server::AcpServer;
|
pub use server::AcpServer;
|
||||||
pub use thread_view::AcpThreadView;
|
pub use thread_view::AcpThreadView;
|
||||||
|
@ -112,14 +114,32 @@ impl MessageChunk {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
#[derive(Debug)]
|
||||||
pub enum AgentThreadEntryContent {
|
pub enum AgentThreadEntryContent {
|
||||||
Message(Message),
|
Message(Message),
|
||||||
ReadFile { path: PathBuf, content: String },
|
ReadFile { path: PathBuf, content: String },
|
||||||
|
ToolCall(ToolCall),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum ToolCall {
|
||||||
|
WaitingForConfirmation {
|
||||||
|
id: ToolCallId,
|
||||||
|
tool_name: Entity<Markdown>,
|
||||||
|
description: Entity<Markdown>,
|
||||||
|
respond_tx: oneshot::Sender<bool>,
|
||||||
|
},
|
||||||
|
// todo! Running?
|
||||||
|
Allowed,
|
||||||
|
Rejected,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A `ThreadEntryId` that is known to be a ToolCall
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||||
pub struct ThreadEntryId(usize);
|
pub struct ToolCallId(ThreadEntryId);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||||
|
pub struct ThreadEntryId(pub u64);
|
||||||
|
|
||||||
impl ThreadEntryId {
|
impl ThreadEntryId {
|
||||||
pub fn post_inc(&mut self) -> Self {
|
pub fn post_inc(&mut self) -> Self {
|
||||||
|
@ -146,7 +166,7 @@ pub struct AcpThread {
|
||||||
|
|
||||||
enum AcpThreadEvent {
|
enum AcpThreadEvent {
|
||||||
NewEntry,
|
NewEntry,
|
||||||
LastEntryUpdated,
|
EntryUpdated(usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EventEmitter<AcpThreadEvent> for AcpThread {}
|
impl EventEmitter<AcpThreadEvent> for AcpThread {}
|
||||||
|
@ -184,22 +204,26 @@ impl AcpThread {
|
||||||
&self.entries
|
&self.entries
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn push_entry(&mut self, entry: AgentThreadEntryContent, cx: &mut Context<Self>) {
|
pub fn push_entry(
|
||||||
self.entries.push(ThreadEntry {
|
&mut self,
|
||||||
id: self.next_entry_id.post_inc(),
|
entry: AgentThreadEntryContent,
|
||||||
content: entry,
|
cx: &mut Context<Self>,
|
||||||
});
|
) -> ThreadEntryId {
|
||||||
cx.emit(AcpThreadEvent::NewEntry)
|
let id = self.next_entry_id.post_inc();
|
||||||
|
self.entries.push(ThreadEntry { id, content: entry });
|
||||||
|
cx.emit(AcpThreadEvent::NewEntry);
|
||||||
|
id
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
|
pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
|
||||||
|
let entries_len = self.entries.len();
|
||||||
if let Some(last_entry) = self.entries.last_mut()
|
if let Some(last_entry) = self.entries.last_mut()
|
||||||
&& let AgentThreadEntryContent::Message(Message {
|
&& let AgentThreadEntryContent::Message(Message {
|
||||||
ref mut chunks,
|
ref mut chunks,
|
||||||
role: Role::Assistant,
|
role: Role::Assistant,
|
||||||
}) = last_entry.content
|
}) = last_entry.content
|
||||||
{
|
{
|
||||||
cx.emit(AcpThreadEvent::LastEntryUpdated);
|
cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
|
||||||
|
|
||||||
if let (
|
if let (
|
||||||
Some(MessageChunk::Text { chunk: old_chunk }),
|
Some(MessageChunk::Text { chunk: old_chunk }),
|
||||||
|
@ -231,6 +255,74 @@ impl AcpThread {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn push_tool_call(
|
||||||
|
&mut self,
|
||||||
|
title: String,
|
||||||
|
description: String,
|
||||||
|
respond_tx: oneshot::Sender<bool>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> ToolCallId {
|
||||||
|
let language_registry = self.project.read(cx).languages().clone();
|
||||||
|
|
||||||
|
let entry_id = self.push_entry(
|
||||||
|
AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation {
|
||||||
|
// todo! clean up id creation
|
||||||
|
id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
|
||||||
|
tool_name: cx.new(|cx| {
|
||||||
|
Markdown::new(title.into(), Some(language_registry.clone()), None, cx)
|
||||||
|
}),
|
||||||
|
description: cx.new(|cx| {
|
||||||
|
Markdown::new(
|
||||||
|
description.into(),
|
||||||
|
Some(language_registry.clone()),
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
respond_tx,
|
||||||
|
}),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
|
||||||
|
ToolCallId(entry_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context<Self>) {
|
||||||
|
let Some(entry) = self.entry_mut(id.0) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
|
||||||
|
debug_panic!("expected ToolCall");
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let new_state = if allowed {
|
||||||
|
ToolCall::Allowed
|
||||||
|
} else {
|
||||||
|
ToolCall::Rejected
|
||||||
|
};
|
||||||
|
|
||||||
|
let call = mem::replace(call, new_state);
|
||||||
|
|
||||||
|
if let ToolCall::WaitingForConfirmation { respond_tx, .. } = call {
|
||||||
|
respond_tx.send(allowed).log_err();
|
||||||
|
} else {
|
||||||
|
debug_panic!("tried to authorize an already authorized tool call");
|
||||||
|
}
|
||||||
|
|
||||||
|
cx.emit(AcpThreadEvent::EntryUpdated(id.0.0 as usize));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
|
||||||
|
let entry = self.entries.get_mut(id.0 as usize);
|
||||||
|
debug_assert!(
|
||||||
|
entry.is_some(),
|
||||||
|
"We shouldn't give out ids to entries that don't exist"
|
||||||
|
);
|
||||||
|
entry
|
||||||
|
}
|
||||||
|
|
||||||
pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
|
pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||||
let agent = self.server.clone();
|
let agent = self.server.clone();
|
||||||
let id = self.id.clone();
|
let id = self.id.clone();
|
||||||
|
@ -303,11 +395,13 @@ mod tests {
|
||||||
));
|
));
|
||||||
assert!(
|
assert!(
|
||||||
thread.entries().iter().any(|entry| {
|
thread.entries().iter().any(|entry| {
|
||||||
entry.content
|
match &entry.content {
|
||||||
== AgentThreadEntryContent::ReadFile {
|
AgentThreadEntryContent::ReadFile { path, content } => {
|
||||||
path: "/private/tmp/foo".into(),
|
path.to_string_lossy().to_string() == "/private/tmp/foo"
|
||||||
content: "Lorem ipsum dolor".into(),
|
&& content == "Lorem ipsum dolor"
|
||||||
}
|
}
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
}),
|
}),
|
||||||
"Thread does not contain entry. Actual: {:?}",
|
"Thread does not contain entry. Actual: {:?}",
|
||||||
thread.entries()
|
thread.entries()
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
use crate::{AcpThread, AgentThreadEntryContent, ThreadEntryId, ThreadId};
|
use crate::{AcpThread, AgentThreadEntryContent, ThreadEntryId, ThreadId, ToolCallId};
|
||||||
use agentic_coding_protocol as acp;
|
use agentic_coding_protocol as acp;
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
|
use futures::channel::oneshot;
|
||||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
|
use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
@ -185,6 +186,31 @@ impl acp::Client for AcpClientDelegate {
|
||||||
) -> Result<acp::GlobSearchResponse> {
|
) -> Result<acp::GlobSearchResponse> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn request_tool_call(
|
||||||
|
&self,
|
||||||
|
request: acp::RequestToolCallParams,
|
||||||
|
) -> Result<acp::RequestToolCallResponse> {
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
|
||||||
|
let cx = &mut self.cx.clone();
|
||||||
|
let entry_id = cx
|
||||||
|
.update(|cx| {
|
||||||
|
self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
|
||||||
|
// todo! tools that don't require confirmation
|
||||||
|
thread.push_tool_call(request.tool_name, request.description, tx, cx)
|
||||||
|
})
|
||||||
|
})?
|
||||||
|
.context("Failed to update thread")?;
|
||||||
|
|
||||||
|
if dbg!(rx.await)? {
|
||||||
|
Ok(acp::RequestToolCallResponse::Allowed {
|
||||||
|
id: entry_id.into(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(acp::RequestToolCallResponse::Rejected)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AcpServer {
|
impl AcpServer {
|
||||||
|
@ -258,3 +284,15 @@ impl From<ThreadId> for acp::ThreadId {
|
||||||
acp::ThreadId(thread_id.0.to_string())
|
acp::ThreadId(thread_id.0.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<acp::ToolCallId> for ToolCallId {
|
||||||
|
fn from(tool_call_id: acp::ToolCallId) -> Self {
|
||||||
|
Self(ThreadEntryId(tool_call_id.0.into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ToolCallId> for acp::ToolCallId {
|
||||||
|
fn from(tool_call_id: ToolCallId) -> Self {
|
||||||
|
acp::ToolCallId(tool_call_id.0.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -13,13 +13,14 @@ use markdown::{HeadingLevelStyles, MarkdownElement, MarkdownStyle};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use settings::Settings as _;
|
use settings::Settings as _;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::Tooltip;
|
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
|
use ui::{Button, Tooltip};
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
use zed_actions::agent::Chat;
|
use zed_actions::agent::Chat;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, MessageChunk, Role, ThreadEntry,
|
AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, MessageChunk, Role, ThreadEntry,
|
||||||
|
ToolCall, ToolCallId,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct AcpThreadView {
|
pub struct AcpThreadView {
|
||||||
|
@ -100,8 +101,8 @@ impl AcpThreadView {
|
||||||
AcpThreadEvent::NewEntry => {
|
AcpThreadEvent::NewEntry => {
|
||||||
this.list_state.splice(count..count, 1);
|
this.list_state.splice(count..count, 1);
|
||||||
}
|
}
|
||||||
AcpThreadEvent::LastEntryUpdated => {
|
AcpThreadEvent::EntryUpdated(index) => {
|
||||||
this.list_state.splice(count - 1..count, 1);
|
this.list_state.splice(*index..*index + 1, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
@ -149,7 +150,7 @@ impl AcpThreadView {
|
||||||
fn thread(&self) -> Option<&Entity<AcpThread>> {
|
fn thread(&self) -> Option<&Entity<AcpThread>> {
|
||||||
match &self.thread_state {
|
match &self.thread_state {
|
||||||
ThreadState::Ready { thread, .. } => Some(thread),
|
ThreadState::Ready { thread, .. } => Some(thread),
|
||||||
_ => None,
|
ThreadState::Loading { .. } | ThreadState::LoadError(..) => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,6 +188,16 @@ impl AcpThreadView {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context<Self>) {
|
||||||
|
let Some(thread) = self.thread() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.authorize_tool_call(id, allowed, cx);
|
||||||
|
});
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
fn render_entry(
|
fn render_entry(
|
||||||
&self,
|
&self,
|
||||||
entry: &ThreadEntry,
|
entry: &ThreadEntry,
|
||||||
|
@ -236,6 +247,46 @@ impl AcpThreadView {
|
||||||
.child(format!("<Reading file {}>", path.display()))
|
.child(format!("<Reading file {}>", path.display()))
|
||||||
.into_any()
|
.into_any()
|
||||||
}
|
}
|
||||||
|
AgentThreadEntryContent::ToolCall(tool_call) => match tool_call {
|
||||||
|
ToolCall::WaitingForConfirmation {
|
||||||
|
id,
|
||||||
|
tool_name,
|
||||||
|
description,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
let id = *id;
|
||||||
|
v_flex()
|
||||||
|
.elevation_1(cx)
|
||||||
|
.child(MarkdownElement::new(
|
||||||
|
tool_name.clone(),
|
||||||
|
default_markdown_style(window, cx),
|
||||||
|
))
|
||||||
|
.child(MarkdownElement::new(
|
||||||
|
description.clone(),
|
||||||
|
default_markdown_style(window, cx),
|
||||||
|
))
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.child(Button::new(("allow", id.0.0), "Allow").on_click(
|
||||||
|
cx.listener({
|
||||||
|
move |this, _, _, cx| {
|
||||||
|
this.authorize_tool_call(id, true, cx);
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
.child(Button::new(("reject", id.0.0), "Reject").on_click(
|
||||||
|
cx.listener({
|
||||||
|
move |this, _, _, cx| {
|
||||||
|
this.authorize_tool_call(id, false, cx);
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_any()
|
||||||
|
}
|
||||||
|
ToolCall::Allowed => div().child("Allowed!").into_any(),
|
||||||
|
ToolCall::Rejected => div().child("Rejected!").into_any(),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,7 +39,7 @@ pub trait StyledExt: Styled + Sized {
|
||||||
/// Sets `bg()`, `rounded_lg()`, `border()`, `border_color()`, `shadow()`
|
/// Sets `bg()`, `rounded_lg()`, `border()`, `border_color()`, `shadow()`
|
||||||
///
|
///
|
||||||
/// Example Elements: Title Bar, Panel, Tab Bar, Editor
|
/// Example Elements: Title Bar, Panel, Tab Bar, Editor
|
||||||
fn elevation_1(self, cx: &mut App) -> Self {
|
fn elevation_1(self, cx: &App) -> Self {
|
||||||
elevated(self, cx, ElevationIndex::Surface)
|
elevated(self, cx, ElevationIndex::Surface)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue