From 2d0f10c48a3308bde3079fdb5301006256f8d50b Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Thu, 24 Jul 2025 14:39:29 -0300 Subject: [PATCH] Refactor to use new ACP crate (#35043) This will prepare us for running the protocol over MCP Release Notes: - N/A --------- Co-authored-by: Ben Brandt Co-authored-by: Conrad Irwin Co-authored-by: Richard Feldman --- Cargo.lock | 15 +- Cargo.toml | 1 + crates/acp_thread/Cargo.toml | 1 + crates/acp_thread/src/acp_thread.rs | 1002 +++++++---------- crates/acp_thread/src/connection.rs | 36 +- crates/acp_thread/src/old_acp_support.rs | 461 ++++++++ crates/agent_servers/Cargo.toml | 1 + crates/agent_servers/src/agent_servers.rs | 11 +- crates/agent_servers/src/claude.rs | 497 ++++---- crates/agent_servers/src/claude/mcp_server.rs | 179 +-- crates/agent_servers/src/claude/tools.rs | 275 ++--- crates/agent_servers/src/e2e_tests.rs | 95 +- crates/agent_servers/src/gemini.rs | 98 +- .../agent_servers/src/stdio_agent_server.rs | 119 -- crates/agent_ui/Cargo.toml | 4 +- crates/agent_ui/src/acp/thread_view.rs | 652 +++++------ crates/agent_ui/src/agent_diff.rs | 6 +- crates/agent_ui/src/agent_panel.rs | 2 +- crates/context_server/src/client.rs | 87 +- crates/context_server/src/protocol.rs | 20 + crates/context_server/src/types.rs | 16 +- 21 files changed, 1830 insertions(+), 1748 deletions(-) create mode 100644 crates/acp_thread/src/old_acp_support.rs delete mode 100644 crates/agent_servers/src/stdio_agent_server.rs diff --git a/Cargo.lock b/Cargo.lock index 8f791d395a..2c65131db0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6,6 +6,7 @@ version = 4 name = "acp_thread" version = "0.1.0" dependencies = [ + "agent-client-protocol", "agentic-coding-protocol", "anyhow", "assistant_tool", @@ -135,11 +136,23 @@ dependencies = [ "zstd", ] +[[package]] +name = "agent-client-protocol" +version = "0.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fb7f39671e02f8a1aeb625652feae40b6fc2597baaa97e028a98863477aecbd" +dependencies = [ + "schemars", + "serde", + "serde_json", +] + [[package]] name = "agent_servers" version = "0.1.0" dependencies = [ "acp_thread", + "agent-client-protocol", "agentic-coding-protocol", "anyhow", "collections", @@ -195,9 +208,9 @@ version = "0.1.0" dependencies = [ "acp_thread", "agent", + "agent-client-protocol", "agent_servers", "agent_settings", - "agentic-coding-protocol", "ai_onboarding", "anyhow", "assistant_context", diff --git a/Cargo.toml b/Cargo.toml index ec793a7429..9062950127 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -413,6 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" +agent-client-protocol = "0.0.10" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index b44c25ccc9..011f26f364 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -16,6 +16,7 @@ doctest = false test-support = ["gpui/test-support", "project/test-support"] [dependencies] +agent-client-protocol.workspace = true agentic-coding-protocol.workspace = true anyhow.workspace = true assistant_tool.workspace = true diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 9af1eeb187..3c6c21205f 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1,17 +1,15 @@ mod connection; +mod old_acp_support; pub use connection::*; +pub use old_acp_support::*; -pub use acp::ToolCallId; -use agentic_coding_protocol::{ - self as acp, AgentRequest, ProtocolVersion, ToolCallConfirmationOutcome, ToolCallLocation, - UserMessageChunk, -}; +use agent_client_protocol as acp; use anyhow::{Context as _, Result}; use assistant_tool::ActionLog; use buffer_diff::BufferDiff; use editor::{Bias, MultiBuffer, PathKey}; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; -use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; +use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; use itertools::Itertools; use language::{ Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point, @@ -21,46 +19,37 @@ use markdown::Markdown; use project::{AgentLocation, Project}; use std::collections::HashMap; use std::error::Error; -use std::fmt::{Formatter, Write}; +use std::fmt::Formatter; +use std::rc::Rc; use std::{ fmt::Display, mem, path::{Path, PathBuf}, sync::Arc, }; -use ui::{App, IconName}; +use ui::App; use util::ResultExt; -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Debug)] pub struct UserMessage { - pub content: Entity, + pub content: ContentBlock, } impl UserMessage { pub fn from_acp( - message: &acp::SendUserMessageParams, + message: impl IntoIterator, language_registry: Arc, cx: &mut App, ) -> Self { - let mut md_source = String::new(); - - for chunk in &message.chunks { - match chunk { - UserMessageChunk::Text { text } => md_source.push_str(&text), - UserMessageChunk::Path { path } => { - write!(&mut md_source, "{}", MentionPath(&path)).unwrap() - } - } - } - - Self { - content: cx - .new(|cx| Markdown::new(md_source.into(), Some(language_registry), None, cx)), + let mut content = ContentBlock::Empty; + for chunk in message { + content.append(chunk, &language_registry, cx) } + Self { content: content } } fn to_markdown(&self, cx: &App) -> String { - format!("## User\n\n{}\n\n", self.content.read(cx).source()) + format!("## User\n\n{}\n\n", self.content.to_markdown(cx)) } } @@ -96,7 +85,7 @@ impl Display for MentionPath<'_> { } } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Debug, PartialEq)] pub struct AssistantMessage { pub chunks: Vec, } @@ -113,42 +102,24 @@ impl AssistantMessage { } } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Debug, PartialEq)] pub enum AssistantMessageChunk { - Text { chunk: Entity }, - Thought { chunk: Entity }, + Message { block: ContentBlock }, + Thought { block: ContentBlock }, } impl AssistantMessageChunk { - pub fn from_acp( - chunk: acp::AssistantMessageChunk, - language_registry: Arc, - cx: &mut App, - ) -> Self { - match chunk { - acp::AssistantMessageChunk::Text { text } => Self::Text { - chunk: cx.new(|cx| Markdown::new(text.into(), Some(language_registry), None, cx)), - }, - acp::AssistantMessageChunk::Thought { thought } => Self::Thought { - chunk: cx - .new(|cx| Markdown::new(thought.into(), Some(language_registry), None, cx)), - }, - } - } - - pub fn from_str(chunk: &str, language_registry: Arc, cx: &mut App) -> Self { - Self::Text { - chunk: cx.new(|cx| { - Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx) - }), + pub fn from_str(chunk: &str, language_registry: &Arc, cx: &mut App) -> Self { + Self::Message { + block: ContentBlock::new(chunk.into(), language_registry, cx), } } fn to_markdown(&self, cx: &App) -> String { match self { - Self::Text { chunk } => chunk.read(cx).source().to_string(), - Self::Thought { chunk } => { - format!("\n{}\n", chunk.read(cx).source()) + Self::Message { block } => block.to_markdown(cx).to_string(), + Self::Thought { block } => { + format!("\n{}\n", block.to_markdown(cx)) } } } @@ -166,19 +137,15 @@ impl AgentThreadEntry { match self { Self::UserMessage(message) => message.to_markdown(cx), Self::AssistantMessage(message) => message.to_markdown(cx), - Self::ToolCall(too_call) => too_call.to_markdown(cx), + Self::ToolCall(tool_call) => tool_call.to_markdown(cx), } } - pub fn diff(&self) -> Option<&Diff> { - if let AgentThreadEntry::ToolCall(ToolCall { - content: Some(ToolCallContent::Diff { diff }), - .. - }) = self - { - Some(&diff) + pub fn diffs(&self) -> impl Iterator { + if let AgentThreadEntry::ToolCall(call) = self { + itertools::Either::Left(call.diffs()) } else { - None + itertools::Either::Right(std::iter::empty()) } } @@ -195,20 +162,54 @@ impl AgentThreadEntry { pub struct ToolCall { pub id: acp::ToolCallId, pub label: Entity, - pub icon: IconName, - pub content: Option, + pub kind: acp::ToolKind, + pub content: Vec, pub status: ToolCallStatus, pub locations: Vec, } impl ToolCall { + fn from_acp( + tool_call: acp::ToolCall, + status: ToolCallStatus, + language_registry: Arc, + cx: &mut App, + ) -> Self { + Self { + id: tool_call.id, + label: cx.new(|cx| { + Markdown::new( + tool_call.label.into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + kind: tool_call.kind, + content: tool_call + .content + .into_iter() + .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx)) + .collect(), + locations: tool_call.locations, + status, + } + } + + pub fn diffs(&self) -> impl Iterator { + self.content.iter().filter_map(|content| match content { + ToolCallContent::ContentBlock { .. } => None, + ToolCallContent::Diff { diff } => Some(diff), + }) + } + fn to_markdown(&self, cx: &App) -> String { let mut markdown = format!( "**Tool Call: {}**\nStatus: {}\n\n", self.label.read(cx).source(), self.status ); - if let Some(content) = &self.content { + for content in &self.content { markdown.push_str(content.to_markdown(cx).as_str()); markdown.push_str("\n\n"); } @@ -219,8 +220,8 @@ impl ToolCall { #[derive(Debug)] pub enum ToolCallStatus { WaitingForConfirmation { - confirmation: ToolCallConfirmation, - respond_tx: oneshot::Sender, + options: Vec, + respond_tx: oneshot::Sender, }, Allowed { status: acp::ToolCallStatus, @@ -237,9 +238,9 @@ impl Display for ToolCallStatus { match self { ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation", ToolCallStatus::Allowed { status } => match status { - acp::ToolCallStatus::Running => "Running", - acp::ToolCallStatus::Finished => "Finished", - acp::ToolCallStatus::Error => "Error", + acp::ToolCallStatus::InProgress => "In Progress", + acp::ToolCallStatus::Completed => "Completed", + acp::ToolCallStatus::Failed => "Failed", }, ToolCallStatus::Rejected => "Rejected", ToolCallStatus::Canceled => "Canceled", @@ -248,86 +249,92 @@ impl Display for ToolCallStatus { } } -#[derive(Debug)] -pub enum ToolCallConfirmation { - Edit { - description: Option>, - }, - Execute { - command: String, - root_command: String, - description: Option>, - }, - Mcp { - server_name: String, - tool_name: String, - tool_display_name: String, - description: Option>, - }, - Fetch { - urls: Vec, - description: Option>, - }, - Other { - description: Entity, - }, +#[derive(Debug, PartialEq, Clone)] +pub enum ContentBlock { + Empty, + Markdown { markdown: Entity }, } -impl ToolCallConfirmation { - pub fn from_acp( - confirmation: acp::ToolCallConfirmation, +impl ContentBlock { + pub fn new( + block: acp::ContentBlock, + language_registry: &Arc, + cx: &mut App, + ) -> Self { + let mut this = Self::Empty; + this.append(block, language_registry, cx); + this + } + + pub fn new_combined( + blocks: impl IntoIterator, language_registry: Arc, cx: &mut App, ) -> Self { - let to_md = |description: String, cx: &mut App| -> Entity { - cx.new(|cx| { - Markdown::new( - description.into(), - Some(language_registry.clone()), - None, - cx, - ) - }) + let mut this = Self::Empty; + for block in blocks { + this.append(block, &language_registry, cx); + } + this + } + + pub fn append( + &mut self, + block: acp::ContentBlock, + language_registry: &Arc, + cx: &mut App, + ) { + let new_content = match block { + acp::ContentBlock::Text(text_content) => text_content.text.clone(), + acp::ContentBlock::ResourceLink(resource_link) => { + if let Some(path) = resource_link.uri.strip_prefix("file://") { + format!("{}", MentionPath(path.as_ref())) + } else { + resource_link.uri.clone() + } + } + acp::ContentBlock::Image(_) + | acp::ContentBlock::Audio(_) + | acp::ContentBlock::Resource(_) => String::new(), }; - match confirmation { - acp::ToolCallConfirmation::Edit { description } => Self::Edit { - description: description.map(|description| to_md(description, cx)), - }, - acp::ToolCallConfirmation::Execute { - command, - root_command, - description, - } => Self::Execute { - command, - root_command, - description: description.map(|description| to_md(description, cx)), - }, - acp::ToolCallConfirmation::Mcp { - server_name, - tool_name, - tool_display_name, - description, - } => Self::Mcp { - server_name, - tool_name, - tool_display_name, - description: description.map(|description| to_md(description, cx)), - }, - acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch { - urls: urls.iter().map(|url| url.into()).collect(), - description: description.map(|description| to_md(description, cx)), - }, - acp::ToolCallConfirmation::Other { description } => Self::Other { - description: to_md(description, cx), - }, + match self { + ContentBlock::Empty => { + *self = ContentBlock::Markdown { + markdown: cx.new(|cx| { + Markdown::new( + new_content.into(), + Some(language_registry.clone()), + None, + cx, + ) + }), + }; + } + ContentBlock::Markdown { markdown } => { + markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx)); + } + } + } + + fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str { + match self { + ContentBlock::Empty => "", + ContentBlock::Markdown { markdown } => markdown.read(cx).source(), + } + } + + pub fn markdown(&self) -> Option<&Entity> { + match self { + ContentBlock::Empty => None, + ContentBlock::Markdown { markdown } => Some(markdown), } } } #[derive(Debug)] pub enum ToolCallContent { - Markdown { markdown: Entity }, + ContentBlock { content: ContentBlock }, Diff { diff: Diff }, } @@ -338,8 +345,8 @@ impl ToolCallContent { cx: &mut App, ) -> Self { match content { - acp::ToolCallContent::Markdown { markdown } => Self::Markdown { - markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)), + acp::ToolCallContent::ContentBlock { content } => Self::ContentBlock { + content: ContentBlock::new(content, &language_registry, cx), }, acp::ToolCallContent::Diff { diff } => Self::Diff { diff: Diff::from_acp(diff, language_registry, cx), @@ -347,9 +354,9 @@ impl ToolCallContent { } } - fn to_markdown(&self, cx: &App) -> String { + pub fn to_markdown(&self, cx: &App) -> String { match self { - Self::Markdown { markdown } => markdown.read(cx).source().to_string(), + Self::ContentBlock { content } => content.to_markdown(cx).to_string(), Self::Diff { diff } => diff.to_markdown(cx), } } @@ -520,8 +527,8 @@ pub struct AcpThread { action_log: Entity, shared_buffers: HashMap, BufferSnapshot>, send_task: Option>, - connection: Arc, - child_status: Option>>, + connection: Rc, + session_id: acp::SessionId, } pub enum AcpThreadEvent { @@ -563,10 +570,9 @@ impl Error for LoadError {} impl AcpThread { pub fn new( - connection: impl AgentConnection + 'static, - title: SharedString, - child_status: Option>>, + connection: Rc, project: Entity, + session_id: acp::SessionId, cx: &mut Context, ) -> Self { let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -576,24 +582,11 @@ impl AcpThread { shared_buffers: Default::default(), entries: Default::default(), plan: Default::default(), - title, + title: connection.name().into(), project, send_task: None, - connection: Arc::new(connection), - child_status, - } - } - - /// Send a request to the agent and wait for a response. - pub fn request( - &self, - params: R, - ) -> impl use + Future> { - let params = params.into_any(); - let result = self.connection.request_any(params); - async move { - let result = result.await?; - Ok(R::response_from_any(result)?) + connection, + session_id, } } @@ -629,15 +622,7 @@ impl AcpThread { for entry in self.entries.iter().rev() { match entry { AgentThreadEntry::UserMessage(_) => return false, - AgentThreadEntry::ToolCall(ToolCall { - status: - ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, - .. - }, - content: Some(ToolCallContent::Diff { .. }), - .. - }) => return true, + AgentThreadEntry::ToolCall(call) if call.diffs().next().is_some() => return true, AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {} } } @@ -652,42 +637,37 @@ impl AcpThread { pub fn push_assistant_chunk( &mut self, - chunk: acp::AssistantMessageChunk, + chunk: acp::ContentBlock, + is_thought: bool, cx: &mut Context, ) { + let language_registry = self.project.read(cx).languages().clone(); let entries_len = self.entries.len(); if let Some(last_entry) = self.entries.last_mut() && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry { cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); - - match (chunks.last_mut(), &chunk) { - ( - Some(AssistantMessageChunk::Text { chunk: old_chunk }), - acp::AssistantMessageChunk::Text { text: new_chunk }, - ) - | ( - Some(AssistantMessageChunk::Thought { chunk: old_chunk }), - acp::AssistantMessageChunk::Thought { thought: new_chunk }, - ) => { - old_chunk.update(cx, |old_chunk, cx| { - old_chunk.append(&new_chunk, cx); - }); + match (chunks.last_mut(), is_thought) { + (Some(AssistantMessageChunk::Message { block }), false) + | (Some(AssistantMessageChunk::Thought { block }), true) => { + block.append(chunk, &language_registry, cx) } _ => { - chunks.push(AssistantMessageChunk::from_acp( - chunk, - self.project.read(cx).languages().clone(), - cx, - )); + let block = ContentBlock::new(chunk, &language_registry, cx); + if is_thought { + chunks.push(AssistantMessageChunk::Thought { block }) + } else { + chunks.push(AssistantMessageChunk::Message { block }) + } } } } else { - let chunk = AssistantMessageChunk::from_acp( - chunk, - self.project.read(cx).languages().clone(), - cx, - ); + let block = ContentBlock::new(chunk, &language_registry, cx); + let chunk = if is_thought { + AssistantMessageChunk::Thought { block } + } else { + AssistantMessageChunk::Message { block } + }; self.push_entry( AgentThreadEntry::AssistantMessage(AssistantMessage { @@ -698,122 +678,122 @@ impl AcpThread { } } - pub fn request_new_tool_call( + pub fn update_tool_call( &mut self, - tool_call: acp::RequestToolCallConfirmationParams, + id: acp::ToolCallId, + status: acp::ToolCallStatus, + content: Option>, cx: &mut Context, - ) -> ToolCallRequest { - let (tx, rx) = oneshot::channel(); + ) -> Result<()> { + let languages = self.project.read(cx).languages().clone(); + let (ix, current_call) = self.tool_call_mut(&id).context("Tool call not found")?; - let status = ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::from_acp( - tool_call.confirmation, - self.project.read(cx).languages().clone(), - cx, - ), - respond_tx: tx, - }; + if let Some(content) = content { + current_call.content = content + .into_iter() + .map(|chunk| ToolCallContent::from_acp(chunk, languages.clone(), cx)) + .collect(); + } + current_call.status = ToolCallStatus::Allowed { status }; - let id = self.insert_tool_call(tool_call.tool_call, status, cx); - ToolCallRequest { id, outcome: rx } + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + + Ok(()) } - pub fn request_tool_call_confirmation( - &mut self, - tool_call_id: ToolCallId, - confirmation: acp::ToolCallConfirmation, - cx: &mut Context, - ) -> Result { - let project = self.project.read(cx).languages().clone(); - let Some((idx, call)) = self.tool_call_mut(tool_call_id) else { - anyhow::bail!("Tool call not found"); - }; - - let (tx, rx) = oneshot::channel(); - - call.status = ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::from_acp(confirmation, project, cx), - respond_tx: tx, - }; - - cx.emit(AcpThreadEvent::EntryUpdated(idx)); - - Ok(ToolCallRequest { - id: tool_call_id, - outcome: rx, - }) - } - - pub fn push_tool_call( - &mut self, - request: acp::PushToolCallParams, - cx: &mut Context, - ) -> acp::ToolCallId { + /// Updates a tool call if id matches an existing entry, otherwise inserts a new one. + pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context) { let status = ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, + status: tool_call.status, }; - - self.insert_tool_call(request, status, cx) + self.upsert_tool_call_inner(tool_call, status, cx) } - fn insert_tool_call( + pub fn upsert_tool_call_inner( &mut self, - tool_call: acp::PushToolCallParams, + tool_call: acp::ToolCall, status: ToolCallStatus, cx: &mut Context, - ) -> acp::ToolCallId { + ) { let language_registry = self.project.read(cx).languages().clone(); - let id = acp::ToolCallId(self.entries.len() as u64); - let call = ToolCall { - id, - label: cx.new(|cx| { - Markdown::new( - tool_call.label.into(), - Some(language_registry.clone()), - None, - cx, - ) - }), - icon: acp_icon_to_ui_icon(tool_call.icon), - content: tool_call - .content - .map(|content| ToolCallContent::from_acp(content, language_registry, cx)), - locations: tool_call.locations, - status, - }; + let call = ToolCall::from_acp(tool_call, status, language_registry, cx); let location = call.locations.last().cloned(); + + if let Some((ix, current_call)) = self.tool_call_mut(&call.id) { + *current_call = call; + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + } else { + self.push_entry(AgentThreadEntry::ToolCall(call), cx); + } + if let Some(location) = location { self.set_project_location(location, cx) } + } - self.push_entry(AgentThreadEntry::ToolCall(call), cx); + fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { + // The tool call we are looking for is typically the last one, or very close to the end. + // At the moment, it doesn't seem like a hashmap would be a good fit for this use case. + self.entries + .iter_mut() + .enumerate() + .rev() + .find_map(|(index, tool_call)| { + if let AgentThreadEntry::ToolCall(tool_call) = tool_call + && &tool_call.id == id + { + Some((index, tool_call)) + } else { + None + } + }) + } - id + pub fn request_tool_call_permission( + &mut self, + tool_call: acp::ToolCall, + options: Vec, + cx: &mut Context, + ) -> oneshot::Receiver { + let (tx, rx) = oneshot::channel(); + + let status = ToolCallStatus::WaitingForConfirmation { + options, + respond_tx: tx, + }; + + self.upsert_tool_call_inner(tool_call, status, cx); + rx } pub fn authorize_tool_call( &mut self, id: acp::ToolCallId, - outcome: acp::ToolCallConfirmationOutcome, + option_id: acp::PermissionOptionId, + option_kind: acp::PermissionOptionKind, cx: &mut Context, ) { - let Some((ix, call)) = self.tool_call_mut(id) else { + let Some((ix, call)) = self.tool_call_mut(&id) else { return; }; - let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject { - ToolCallStatus::Rejected - } else { - ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, + let new_status = match option_kind { + acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => { + ToolCallStatus::Rejected + } + acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => { + ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress, + } } }; let curr_status = mem::replace(&mut call.status, new_status); if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status { - respond_tx.send(outcome).log_err(); + respond_tx.send(option_id).log_err(); } else if cfg!(debug_assertions) { panic!("tried to authorize an already authorized tool call"); } @@ -821,70 +801,11 @@ impl AcpThread { cx.emit(AcpThreadEvent::EntryUpdated(ix)); } - pub fn update_tool_call( - &mut self, - id: acp::ToolCallId, - new_status: acp::ToolCallStatus, - new_content: Option, - cx: &mut Context, - ) -> Result<()> { - let language_registry = self.project.read(cx).languages().clone(); - let (ix, call) = self.tool_call_mut(id).context("Entry not found")?; - - if let Some(new_content) = new_content { - call.content = Some(ToolCallContent::from_acp( - new_content, - language_registry, - cx, - )); - } - - match &mut call.status { - ToolCallStatus::Allowed { status } => { - *status = new_status; - } - ToolCallStatus::WaitingForConfirmation { .. } => { - anyhow::bail!("Tool call hasn't been authorized yet") - } - ToolCallStatus::Rejected => { - anyhow::bail!("Tool call was rejected and therefore can't be updated") - } - ToolCallStatus::Canceled => { - call.status = ToolCallStatus::Allowed { status: new_status }; - } - } - - let location = call.locations.last().cloned(); - if let Some(location) = location { - self.set_project_location(location, cx) - } - - cx.emit(AcpThreadEvent::EntryUpdated(ix)); - Ok(()) - } - - fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { - let entry = self.entries.get_mut(id.0 as usize); - debug_assert!( - entry.is_some(), - "We shouldn't give out ids to entries that don't exist" - ); - match entry { - Some(AgentThreadEntry::ToolCall(call)) if call.id == id => Some((id.0 as usize, call)), - _ => { - if cfg!(debug_assertions) { - panic!("entry is not a tool call"); - } - None - } - } - } - pub fn plan(&self) -> &Plan { &self.plan } - pub fn update_plan(&mut self, request: acp::UpdatePlanParams, cx: &mut Context) { + pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context) { self.plan = Plan { entries: request .entries @@ -896,14 +817,14 @@ impl AcpThread { cx.notify(); } - pub fn clear_completed_plan_entries(&mut self, cx: &mut Context) { + fn clear_completed_plan_entries(&mut self, cx: &mut Context) { self.plan .entries .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed)); cx.notify(); } - pub fn set_project_location(&self, location: ToolCallLocation, cx: &mut Context) { + pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context) { self.project.update(cx, |project, cx| { let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else { return; @@ -953,14 +874,8 @@ impl AcpThread { false } - pub fn initialize(&self) -> impl use<> + Future> { - self.request(acp::InitializeParams { - protocol_version: ProtocolVersion::latest(), - }) - } - - pub fn authenticate(&self) -> impl use<> + Future> { - self.request(acp::AuthenticateParams) + pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future> { + self.connection.authenticate(cx) } #[cfg(any(test, feature = "test-support"))] @@ -968,39 +883,50 @@ impl AcpThread { &mut self, message: &str, cx: &mut Context, - ) -> BoxFuture<'static, Result<(), acp::Error>> { + ) -> BoxFuture<'static, Result<()>> { self.send( - acp::SendUserMessageParams { - chunks: vec![acp::UserMessageChunk::Text { - text: message.to_string(), - }], - }, + vec![acp::ContentBlock::Text(acp::TextContent { + text: message.to_string(), + annotations: None, + })], cx, ) } pub fn send( &mut self, - message: acp::SendUserMessageParams, + message: Vec, cx: &mut Context, - ) -> BoxFuture<'static, Result<(), acp::Error>> { - self.push_entry( - AgentThreadEntry::UserMessage(UserMessage::from_acp( - &message, - self.project.read(cx).languages().clone(), - cx, - )), + ) -> BoxFuture<'static, Result<()>> { + let block = ContentBlock::new_combined( + message.clone(), + self.project.read(cx).languages().clone(), cx, ); + self.push_entry( + AgentThreadEntry::UserMessage(UserMessage { content: block }), + cx, + ); + self.clear_completed_plan_entries(cx); let (tx, rx) = oneshot::channel(); - let cancel = self.cancel(cx); + let cancel_task = self.cancel(cx); self.send_task = Some(cx.spawn(async move |this, cx| { async { - cancel.await.log_err(); + cancel_task.await; - let result = this.update(cx, |this, _| this.request(message))?.await; + let result = this + .update(cx, |this, cx| { + this.connection.prompt( + acp::PromptToolArguments { + prompt: message, + session_id: this.session_id.clone(), + }, + cx, + ) + })? + .await; tx.send(result).log_err(); this.update(cx, |this, _cx| this.send_task.take())?; anyhow::Ok(()) @@ -1018,48 +944,38 @@ impl AcpThread { .boxed() } - pub fn cancel(&mut self, cx: &mut Context) -> Task> { - if self.send_task.take().is_some() { - let request = self.request(acp::CancelSendMessageParams); - cx.spawn(async move |this, cx| { - request.await?; - this.update(cx, |this, _cx| { - for entry in this.entries.iter_mut() { - if let AgentThreadEntry::ToolCall(call) = entry { - let cancel = matches!( - call.status, - ToolCallStatus::WaitingForConfirmation { .. } - | ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running - } - ); + pub fn cancel(&mut self, cx: &mut Context) -> Task<()> { + let Some(send_task) = self.send_task.take() else { + return Task::ready(()); + }; - if cancel { - let curr_status = - mem::replace(&mut call.status, ToolCallStatus::Canceled); - - if let ToolCallStatus::WaitingForConfirmation { - respond_tx, .. - } = curr_status - { - respond_tx - .send(acp::ToolCallConfirmationOutcome::Cancel) - .ok(); - } - } + for entry in self.entries.iter_mut() { + if let AgentThreadEntry::ToolCall(call) = entry { + let cancel = matches!( + call.status, + ToolCallStatus::WaitingForConfirmation { .. } + | ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress } - } - })?; - Ok(()) - }) - } else { - Task::ready(Ok(())) + ); + + if cancel { + call.status = ToolCallStatus::Canceled; + } + } } + + self.connection.cancel(&self.session_id, cx); + + // Wait for the send task to complete + cx.foreground_executor().spawn(send_task) } pub fn read_text_file( &self, - request: acp::ReadTextFileParams, + path: PathBuf, + line: Option, + limit: Option, reuse_shared_snapshot: bool, cx: &mut Context, ) -> Task> { @@ -1068,7 +984,7 @@ impl AcpThread { cx.spawn(async move |this, cx| { let load = project.update(cx, |project, cx| { let path = project - .project_path_for_absolute_path(&request.path, cx) + .project_path_for_absolute_path(&path, cx) .context("invalid path")?; anyhow::Ok(project.open_buffer(path, cx)) }); @@ -1094,7 +1010,7 @@ impl AcpThread { let position = buffer .read(cx) .snapshot() - .anchor_before(Point::new(request.line.unwrap_or_default(), 0)); + .anchor_before(Point::new(line.unwrap_or_default(), 0)); project.set_agent_location( Some(AgentLocation { buffer: buffer.downgrade(), @@ -1110,11 +1026,11 @@ impl AcpThread { this.update(cx, |this, _| { let text = snapshot.text(); this.shared_buffers.insert(buffer.clone(), snapshot); - if request.line.is_none() && request.limit.is_none() { + if line.is_none() && limit.is_none() { return Ok(text); } - let limit = request.limit.unwrap_or(u32::MAX) as usize; - let Some(line) = request.line else { + let limit = limit.unwrap_or(u32::MAX) as usize; + let Some(line) = line else { return Ok(text.lines().take(limit).collect::()); }; @@ -1199,197 +1115,15 @@ impl AcpThread { }) } - pub fn child_status(&mut self) -> Option>> { - self.child_status.take() - } - pub fn to_markdown(&self, cx: &App) -> String { self.entries.iter().map(|e| e.to_markdown(cx)).collect() } } -#[derive(Clone)] -pub struct AcpClientDelegate { - thread: WeakEntity, - cx: AsyncApp, - // sent_buffer_versions: HashMap, HashMap>, -} - -impl AcpClientDelegate { - pub fn new(thread: WeakEntity, cx: AsyncApp) -> Self { - Self { thread, cx } - } - - pub async fn clear_completed_plan_entries(&self) -> Result<()> { - let cx = &mut self.cx.clone(); - cx.update(|cx| { - self.thread - .update(cx, |thread, cx| thread.clear_completed_plan_entries(cx)) - })? - .context("Failed to update thread")?; - - Ok(()) - } - - pub async fn request_existing_tool_call_confirmation( - &self, - tool_call_id: ToolCallId, - confirmation: acp::ToolCallConfirmation, - ) -> Result { - let cx = &mut self.cx.clone(); - let ToolCallRequest { outcome, .. } = cx - .update(|cx| { - self.thread.update(cx, |thread, cx| { - thread.request_tool_call_confirmation(tool_call_id, confirmation, cx) - }) - })? - .context("Failed to update thread")??; - - Ok(outcome.await?) - } - - pub async fn read_text_file_reusing_snapshot( - &self, - request: acp::ReadTextFileParams, - ) -> Result { - let content = self - .cx - .update(|cx| { - self.thread - .update(cx, |thread, cx| thread.read_text_file(request, true, cx)) - })? - .context("Failed to update thread")? - .await?; - Ok(acp::ReadTextFileResponse { content }) - } -} - -impl acp::Client for AcpClientDelegate { - async fn stream_assistant_message_chunk( - &self, - params: acp::StreamAssistantMessageChunkParams, - ) -> Result<(), acp::Error> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread - .update(cx, |thread, cx| { - thread.push_assistant_chunk(params.chunk, cx) - }) - .ok(); - })?; - - Ok(()) - } - - async fn request_tool_call_confirmation( - &self, - request: acp::RequestToolCallConfirmationParams, - ) -> Result { - let cx = &mut self.cx.clone(); - let ToolCallRequest { id, outcome } = cx - .update(|cx| { - self.thread - .update(cx, |thread, cx| thread.request_new_tool_call(request, cx)) - })? - .context("Failed to update thread")?; - - Ok(acp::RequestToolCallConfirmationResponse { - id, - outcome: outcome.await.map_err(acp::Error::into_internal_error)?, - }) - } - - async fn push_tool_call( - &self, - request: acp::PushToolCallParams, - ) -> Result { - let cx = &mut self.cx.clone(); - let id = cx - .update(|cx| { - self.thread - .update(cx, |thread, cx| thread.push_tool_call(request, cx)) - })? - .context("Failed to update thread")?; - - Ok(acp::PushToolCallResponse { id }) - } - - async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread.update(cx, |thread, cx| { - thread.update_tool_call(request.tool_call_id, request.status, request.content, cx) - }) - })? - .context("Failed to update thread")??; - - Ok(()) - } - - async fn update_plan(&self, request: acp::UpdatePlanParams) -> Result<(), acp::Error> { - let cx = &mut self.cx.clone(); - - cx.update(|cx| { - self.thread - .update(cx, |thread, cx| thread.update_plan(request, cx)) - })? - .context("Failed to update thread")?; - - Ok(()) - } - - async fn read_text_file( - &self, - request: acp::ReadTextFileParams, - ) -> Result { - let content = self - .cx - .update(|cx| { - self.thread - .update(cx, |thread, cx| thread.read_text_file(request, false, cx)) - })? - .context("Failed to update thread")? - .await?; - Ok(acp::ReadTextFileResponse { content }) - } - - async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> { - self.cx - .update(|cx| { - self.thread.update(cx, |thread, cx| { - thread.write_text_file(request.path, request.content, cx) - }) - })? - .context("Failed to update thread")? - .await?; - - Ok(()) - } -} - -fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName { - match icon { - acp::Icon::FileSearch => IconName::ToolSearch, - acp::Icon::Folder => IconName::ToolFolder, - acp::Icon::Globe => IconName::ToolWeb, - acp::Icon::Hammer => IconName::ToolHammer, - acp::Icon::LightBulb => IconName::ToolBulb, - acp::Icon::Pencil => IconName::ToolPencil, - acp::Icon::Regex => IconName::ToolRegex, - acp::Icon::Terminal => IconName::ToolTerminal, - } -} - -pub struct ToolCallRequest { - pub id: acp::ToolCallId, - pub outcome: oneshot::Receiver, -} - #[cfg(test)] mod tests { use super::*; + use agentic_coding_protocol as acp_old; use anyhow::anyhow; use async_pipe::{PipeReader, PipeWriter}; use futures::{channel::mpsc, future::LocalBoxFuture, select}; @@ -1400,6 +1134,7 @@ mod tests { use settings::SettingsStore; use smol::{future::BoxedLocal, stream::StreamExt as _}; use std::{cell::RefCell, rc::Rc, time::Duration}; + use util::path; fn init_test(cx: &mut TestAppContext) { @@ -1424,8 +1159,8 @@ mod tests { fake_server.on_user_message(move |_, server, mut cx| async move { server .update(&mut cx, |server, _| { - server.send_to_zed(acp::StreamAssistantMessageChunkParams { - chunk: acp::AssistantMessageChunk::Thought { + server.send_to_zed(acp_old::StreamAssistantMessageChunkParams { + chunk: acp_old::AssistantMessageChunk::Thought { thought: "Thinking ".into(), }, }) @@ -1434,8 +1169,8 @@ mod tests { .unwrap(); server .update(&mut cx, |server, _| { - server.send_to_zed(acp::StreamAssistantMessageChunkParams { - chunk: acp::AssistantMessageChunk::Thought { + server.send_to_zed(acp_old::StreamAssistantMessageChunkParams { + chunk: acp_old::AssistantMessageChunk::Thought { thought: "hard!".into(), }, }) @@ -1501,7 +1236,7 @@ mod tests { async move { let content = server .update(&mut cx, |server, _| { - server.send_to_zed(acp::ReadTextFileParams { + server.send_to_zed(acp_old::ReadTextFileParams { path: path!("/tmp/foo").into(), line: None, limit: None, @@ -1513,7 +1248,7 @@ mod tests { read_file_tx.take().unwrap().send(()).unwrap(); server .update(&mut cx, |server, _| { - server.send_to_zed(acp::WriteTextFileParams { + server.send_to_zed(acp_old::WriteTextFileParams { path: path!("/tmp/foo").into(), content: "one\ntwo\nthree\nfour\nfive\n".to_string(), }) @@ -1564,9 +1299,9 @@ mod tests { async move { let tool_call_result = server .update(&mut cx, |server, _| { - server.send_to_zed(acp::PushToolCallParams { + server.send_to_zed(acp_old::PushToolCallParams { label: "Fetch".to_string(), - icon: acp::Icon::Globe, + icon: acp_old::Icon::Globe, content: None, locations: vec![], }) @@ -1592,7 +1327,7 @@ mod tests { thread.entries[1], AgentThreadEntry::ToolCall(ToolCall { status: ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, + status: acp::ToolCallStatus::InProgress, .. }, .. @@ -1602,10 +1337,7 @@ mod tests { cx.run_until_parked(); - thread - .update(cx, |thread, cx| thread.cancel(cx)) - .await - .unwrap(); + thread.update(cx, |thread, cx| thread.cancel(cx)).await; thread.read_with(cx, |thread, _| { assert!(matches!( @@ -1619,9 +1351,9 @@ mod tests { fake_server .update(cx, |fake_server, _| { - fake_server.send_to_zed(acp::UpdateToolCallParams { + fake_server.send_to_zed(acp_old::UpdateToolCallParams { tool_call_id: tool_call_id.borrow().unwrap(), - status: acp::ToolCallStatus::Finished, + status: acp_old::ToolCallStatus::Finished, content: None, }) }) @@ -1629,14 +1361,14 @@ mod tests { .unwrap(); drop(end_turn_tx); - request.await.unwrap(); + assert!(request.await.unwrap_err().to_string().contains("canceled")); thread.read_with(cx, |thread, _| { assert!(matches!( thread.entries[1], AgentThreadEntry::ToolCall(ToolCall { status: ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Finished, + status: acp::ToolCallStatus::Completed, .. }, .. @@ -1681,8 +1413,10 @@ mod tests { let thread = cx.new(|cx| { let foreground_executor = cx.foreground_executor().clone(); - let (connection, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), + let thread_rc = Rc::new(RefCell::new(cx.entity().downgrade())); + + let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( + OldAcpClientDelegate::new(thread_rc.clone(), cx.to_async()), stdin_tx, stdout_rx, move |fut| { @@ -1696,23 +1430,34 @@ mod tests { Ok(()) } }); - AcpThread::new(connection, "Test".into(), Some(io_task), project, cx) + let connection = OldAcpAgentConnection { + name: "test", + connection, + child_status: io_task, + }; + + AcpThread::new( + Rc::new(connection), + project, + acp::SessionId("test".into()), + cx, + ) }); let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx))); (thread, agent) } pub struct FakeAcpServer { - connection: acp::ClientConnection, + connection: acp_old::ClientConnection, _io_task: Task<()>, on_user_message: Option< Rc< dyn Fn( - acp::SendUserMessageParams, + acp_old::SendUserMessageParams, Entity, AsyncApp, - ) -> LocalBoxFuture<'static, Result<(), acp::Error>>, + ) -> LocalBoxFuture<'static, Result<(), acp_old::Error>>, >, >, } @@ -1721,31 +1466,38 @@ mod tests { struct FakeAgent { server: Entity, cx: AsyncApp, + cancel_tx: Rc>>>, } - impl acp::Agent for FakeAgent { + impl acp_old::Agent for FakeAgent { async fn initialize( &self, - params: acp::InitializeParams, - ) -> Result { - Ok(acp::InitializeResponse { + params: acp_old::InitializeParams, + ) -> Result { + Ok(acp_old::InitializeResponse { protocol_version: params.protocol_version, is_authenticated: true, }) } - async fn authenticate(&self) -> Result<(), acp::Error> { + async fn authenticate(&self) -> Result<(), acp_old::Error> { Ok(()) } - async fn cancel_send_message(&self) -> Result<(), acp::Error> { + async fn cancel_send_message(&self) -> Result<(), acp_old::Error> { + if let Some(cancel_tx) = self.cancel_tx.take() { + cancel_tx.send(()).log_err(); + } Ok(()) } async fn send_user_message( &self, - request: acp::SendUserMessageParams, - ) -> Result<(), acp::Error> { + request: acp_old::SendUserMessageParams, + ) -> Result<(), acp_old::Error> { + let (cancel_tx, cancel_rx) = oneshot::channel(); + self.cancel_tx.replace(Some(cancel_tx)); + let mut cx = self.cx.clone(); let handler = self .server @@ -1753,7 +1505,10 @@ mod tests { .ok() .flatten(); if let Some(handler) = handler { - handler(request, self.server.clone(), self.cx.clone()).await + select! { + _ = cancel_rx.fuse() => Err(anyhow::anyhow!("Message sending canceled").into()), + _ = handler(request, self.server.clone(), self.cx.clone()).fuse() => Ok(()), + } } else { Err(anyhow::anyhow!("No handler for on_user_message").into()) } @@ -1765,10 +1520,11 @@ mod tests { let agent = FakeAgent { server: cx.entity(), cx: cx.to_async(), + cancel_tx: Default::default(), }; let foreground_executor = cx.foreground_executor().clone(); - let (connection, io_fut) = acp::ClientConnection::connect_to_client( + let (connection, io_fut) = acp_old::ClientConnection::connect_to_client( agent.clone(), stdout, stdin, @@ -1787,10 +1543,14 @@ mod tests { fn on_user_message( &mut self, - handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity, AsyncApp) -> F + handler: impl for<'a> Fn( + acp_old::SendUserMessageParams, + Entity, + AsyncApp, + ) -> F + 'static, ) where - F: Future> + 'static, + F: Future> + 'static, { self.on_user_message .replace(Rc::new(move |request, server, cx| { @@ -1798,7 +1558,7 @@ mod tests { })); } - fn send_to_zed( + fn send_to_zed( &self, message: T, ) -> BoxedLocal> { diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 7c0ba4f41c..fde167da5f 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,20 +1,26 @@ -use agentic_coding_protocol as acp; +use std::{path::Path, rc::Rc}; + +use agent_client_protocol as acp; use anyhow::Result; -use futures::future::{FutureExt as _, LocalBoxFuture}; +use gpui::{AsyncApp, Entity, Task}; +use project::Project; +use ui::App; + +use crate::AcpThread; pub trait AgentConnection { - fn request_any( - &self, - params: acp::AnyAgentRequest, - ) -> LocalBoxFuture<'static, Result>; -} + fn name(&self) -> &'static str; -impl AgentConnection for acp::AgentConnection { - fn request_any( - &self, - params: acp::AnyAgentRequest, - ) -> LocalBoxFuture<'static, Result> { - let task = self.request_any(params); - async move { Ok(task.await?) }.boxed_local() - } + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>>; + + fn authenticate(&self, cx: &mut App) -> Task>; + + fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task>; + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); } diff --git a/crates/acp_thread/src/old_acp_support.rs b/crates/acp_thread/src/old_acp_support.rs new file mode 100644 index 0000000000..316a5bcf25 --- /dev/null +++ b/crates/acp_thread/src/old_acp_support.rs @@ -0,0 +1,461 @@ +// Translates old acp agents into the new schema +use agent_client_protocol as acp; +use agentic_coding_protocol::{self as acp_old, AgentRequest as _}; +use anyhow::{Context as _, Result}; +use futures::channel::oneshot; +use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; +use project::Project; +use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc}; +use ui::App; + +use crate::{AcpThread, AcpThreadEvent, AgentConnection, ToolCallContent, ToolCallStatus}; + +#[derive(Clone)] +pub struct OldAcpClientDelegate { + thread: Rc>>, + cx: AsyncApp, + next_tool_call_id: Rc>, + // sent_buffer_versions: HashMap, HashMap>, +} + +impl OldAcpClientDelegate { + pub fn new(thread: Rc>>, cx: AsyncApp) -> Self { + Self { + thread, + cx, + next_tool_call_id: Rc::new(RefCell::new(0)), + } + } +} + +impl acp_old::Client for OldAcpClientDelegate { + async fn stream_assistant_message_chunk( + &self, + params: acp_old::StreamAssistantMessageChunkParams, + ) -> Result<(), acp_old::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread + .borrow() + .update(cx, |thread, cx| match params.chunk { + acp_old::AssistantMessageChunk::Text { text } => { + thread.push_assistant_chunk(text.into(), false, cx) + } + acp_old::AssistantMessageChunk::Thought { thought } => { + thread.push_assistant_chunk(thought.into(), true, cx) + } + }) + .ok(); + })?; + + Ok(()) + } + + async fn request_tool_call_confirmation( + &self, + request: acp_old::RequestToolCallConfirmationParams, + ) -> Result { + let cx = &mut self.cx.clone(); + + let old_acp_id = *self.next_tool_call_id.borrow() + 1; + self.next_tool_call_id.replace(old_acp_id); + + let tool_call = into_new_tool_call( + acp::ToolCallId(old_acp_id.to_string().into()), + request.tool_call, + ); + + let mut options = match request.confirmation { + acp_old::ToolCallConfirmation::Edit { .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + "Always Allow Edits".to_string(), + )], + acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + format!("Always Allow {}", root_command), + )], + acp_old::ToolCallConfirmation::Mcp { + server_name, + tool_name, + .. + } => vec![ + ( + acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer, + acp::PermissionOptionKind::AllowAlways, + format!("Always Allow {}", server_name), + ), + ( + acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool, + acp::PermissionOptionKind::AllowAlways, + format!("Always Allow {}", tool_name), + ), + ], + acp_old::ToolCallConfirmation::Fetch { .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + "Always Allow".to_string(), + )], + acp_old::ToolCallConfirmation::Other { .. } => vec![( + acp_old::ToolCallConfirmationOutcome::AlwaysAllow, + acp::PermissionOptionKind::AllowAlways, + "Always Allow".to_string(), + )], + }; + + options.extend([ + ( + acp_old::ToolCallConfirmationOutcome::Allow, + acp::PermissionOptionKind::AllowOnce, + "Allow".to_string(), + ), + ( + acp_old::ToolCallConfirmationOutcome::Reject, + acp::PermissionOptionKind::RejectOnce, + "Reject".to_string(), + ), + ]); + + let mut outcomes = Vec::with_capacity(options.len()); + let mut acp_options = Vec::with_capacity(options.len()); + + for (index, (outcome, kind, label)) in options.into_iter().enumerate() { + outcomes.push(outcome); + acp_options.push(acp::PermissionOption { + id: acp::PermissionOptionId(index.to_string().into()), + label, + kind, + }) + } + + let response = cx + .update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.request_tool_call_permission(tool_call, acp_options, cx) + }) + })? + .context("Failed to update thread")? + .await; + + let outcome = match response { + Ok(option_id) => outcomes[option_id.0.parse::().unwrap_or(0)], + Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel, + }; + + Ok(acp_old::RequestToolCallConfirmationResponse { + id: acp_old::ToolCallId(old_acp_id), + outcome: outcome, + }) + } + + async fn push_tool_call( + &self, + request: acp_old::PushToolCallParams, + ) -> Result { + let cx = &mut self.cx.clone(); + + let old_acp_id = *self.next_tool_call_id.borrow() + 1; + self.next_tool_call_id.replace(old_acp_id); + + cx.update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.upsert_tool_call( + into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request), + cx, + ) + }) + })? + .context("Failed to update thread")?; + + Ok(acp_old::PushToolCallResponse { + id: acp_old::ToolCallId(old_acp_id), + }) + } + + async fn update_tool_call( + &self, + request: acp_old::UpdateToolCallParams, + ) -> Result<(), acp_old::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + let languages = thread.project.read(cx).languages().clone(); + + if let Some((ix, tool_call)) = thread + .tool_call_mut(&acp::ToolCallId(request.tool_call_id.0.to_string().into())) + { + tool_call.status = ToolCallStatus::Allowed { + status: into_new_tool_call_status(request.status), + }; + tool_call.content = request + .content + .into_iter() + .map(|content| { + ToolCallContent::from_acp( + into_new_tool_call_content(content), + languages.clone(), + cx, + ) + }) + .collect(); + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + anyhow::Ok(()) + } else { + anyhow::bail!("Tool call not found") + } + }) + })? + .context("Failed to update thread")??; + + Ok(()) + } + + async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> { + let cx = &mut self.cx.clone(); + + cx.update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.update_plan( + acp::Plan { + entries: request + .entries + .into_iter() + .map(into_new_plan_entry) + .collect(), + }, + cx, + ) + }) + })? + .context("Failed to update thread")?; + + Ok(()) + } + + async fn read_text_file( + &self, + acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams, + ) -> Result { + let content = self + .cx + .update(|cx| { + self.thread.borrow().update(cx, |thread, cx| { + thread.read_text_file(path, line, limit, false, cx) + }) + })? + .context("Failed to update thread")? + .await?; + Ok(acp_old::ReadTextFileResponse { content }) + } + + async fn write_text_file( + &self, + acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams, + ) -> Result<(), acp_old::Error> { + self.cx + .update(|cx| { + self.thread + .borrow() + .update(cx, |thread, cx| thread.write_text_file(path, content, cx)) + })? + .context("Failed to update thread")? + .await?; + + Ok(()) + } +} + +fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall { + acp::ToolCall { + id: id, + label: request.label, + kind: acp_kind_from_old_icon(request.icon), + status: acp::ToolCallStatus::InProgress, + content: request + .content + .into_iter() + .map(into_new_tool_call_content) + .collect(), + locations: request + .locations + .into_iter() + .map(into_new_tool_call_location) + .collect(), + } +} + +fn acp_kind_from_old_icon(icon: acp_old::Icon) -> acp::ToolKind { + match icon { + acp_old::Icon::FileSearch => acp::ToolKind::Search, + acp_old::Icon::Folder => acp::ToolKind::Search, + acp_old::Icon::Globe => acp::ToolKind::Search, + acp_old::Icon::Hammer => acp::ToolKind::Other, + acp_old::Icon::LightBulb => acp::ToolKind::Think, + acp_old::Icon::Pencil => acp::ToolKind::Edit, + acp_old::Icon::Regex => acp::ToolKind::Search, + acp_old::Icon::Terminal => acp::ToolKind::Execute, + } +} + +fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallStatus { + match status { + acp_old::ToolCallStatus::Running => acp::ToolCallStatus::InProgress, + acp_old::ToolCallStatus::Finished => acp::ToolCallStatus::Completed, + acp_old::ToolCallStatus::Error => acp::ToolCallStatus::Failed, + } +} + +fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent { + match content { + acp_old::ToolCallContent::Markdown { markdown } => acp::ToolCallContent::ContentBlock { + content: acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: markdown, + }), + }, + acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff { + diff: into_new_diff(diff), + }, + } +} + +fn into_new_diff(diff: acp_old::Diff) -> acp::Diff { + acp::Diff { + path: diff.path, + old_text: diff.old_text, + new_text: diff.new_text, + } +} + +fn into_new_tool_call_location(location: acp_old::ToolCallLocation) -> acp::ToolCallLocation { + acp::ToolCallLocation { + path: location.path, + line: location.line, + } +} + +fn into_new_plan_entry(entry: acp_old::PlanEntry) -> acp::PlanEntry { + acp::PlanEntry { + content: entry.content, + priority: into_new_plan_priority(entry.priority), + status: into_new_plan_status(entry.status), + } +} + +fn into_new_plan_priority(priority: acp_old::PlanEntryPriority) -> acp::PlanEntryPriority { + match priority { + acp_old::PlanEntryPriority::Low => acp::PlanEntryPriority::Low, + acp_old::PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium, + acp_old::PlanEntryPriority::High => acp::PlanEntryPriority::High, + } +} + +fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatus { + match status { + acp_old::PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending, + acp_old::PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress, + acp_old::PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed, + } +} + +#[derive(Debug)] +pub struct Unauthenticated; + +impl Error for Unauthenticated {} +impl fmt::Display for Unauthenticated { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Unauthenticated") + } +} + +pub struct OldAcpAgentConnection { + pub name: &'static str, + pub connection: acp_old::AgentConnection, + pub child_status: Task>, +} + +impl AgentConnection for OldAcpAgentConnection { + fn name(&self) -> &'static str { + self.name + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let task = self.connection.request_any( + acp_old::InitializeParams { + protocol_version: acp_old::ProtocolVersion::latest(), + } + .into_any(), + ); + cx.spawn(async move |cx| { + let result = task.await?; + let result = acp_old::InitializeParams::response_from_any(result)?; + + if !result.is_authenticated { + anyhow::bail!(Unauthenticated) + } + + cx.update(|cx| { + let thread = cx.new(|cx| { + let session_id = acp::SessionId("acp-old-no-id".into()); + AcpThread::new(self.clone(), project, session_id, cx) + }); + thread + }) + }) + } + + fn authenticate(&self, cx: &mut App) -> Task> { + let task = self + .connection + .request_any(acp_old::AuthenticateParams.into_any()); + cx.foreground_executor().spawn(async move { + task.await?; + Ok(()) + }) + } + + fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task> { + let chunks = params + .prompt + .into_iter() + .filter_map(|block| match block { + acp::ContentBlock::Text(text) => { + Some(acp_old::UserMessageChunk::Text { text: text.text }) + } + acp::ContentBlock::ResourceLink(link) => Some(acp_old::UserMessageChunk::Path { + path: link.uri.into(), + }), + _ => None, + }) + .collect(); + + let task = self + .connection + .request_any(acp_old::SendUserMessageParams { chunks }.into_any()); + cx.foreground_executor().spawn(async move { + task.await?; + anyhow::Ok(()) + }) + } + + fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) { + let task = self + .connection + .request_any(acp_old::CancelSendMessageParams.into_any()); + cx.foreground_executor() + .spawn(async move { + task.await?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx) + } +} diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 4714245b94..4371f7684d 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -18,6 +18,7 @@ doctest = false [dependencies] acp_thread.workspace = true +agent-client-protocol.workspace = true agentic-coding-protocol.workspace = true anyhow.workspace = true collections.workspace = true diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 6d9c77f296..660f61f907 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -1,7 +1,6 @@ mod claude; mod gemini; mod settings; -mod stdio_agent_server; #[cfg(test)] mod e2e_tests; @@ -9,9 +8,8 @@ mod e2e_tests; pub use claude::*; pub use gemini::*; pub use settings::*; -pub use stdio_agent_server::*; -use acp_thread::AcpThread; +use acp_thread::AgentConnection; use anyhow::Result; use collections::HashMap; use gpui::{App, AsyncApp, Entity, SharedString, Task}; @@ -20,6 +18,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::{ path::{Path, PathBuf}, + rc::Rc, sync::Arc, }; use util::ResultExt as _; @@ -33,14 +32,14 @@ pub trait AgentServer: Send { fn name(&self) -> &'static str; fn empty_state_headline(&self) -> &'static str; fn empty_state_message(&self) -> &'static str; - fn supports_always_allow(&self) -> bool; - fn new_thread( + fn connect( &self, + // these will go away when old_acp is fully removed root_dir: &Path, project: &Entity, cx: &mut App, - ) -> Task>>; + ) -> Task>>; } impl std::fmt::Debug for AgentServerCommand { diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 835efbd655..5f35b4af73 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -1,5 +1,5 @@ mod mcp_server; -mod tools; +pub mod tools; use collections::HashMap; use project::Project; @@ -12,28 +12,24 @@ use std::pin::pin; use std::rc::Rc; use uuid::Uuid; -use agentic_coding_protocol::{ - self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion, - StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams, -}; +use agent_client_protocol as acp; use anyhow::{Result, anyhow}; use futures::channel::oneshot; -use futures::future::LocalBoxFuture; -use futures::{AsyncBufReadExt, AsyncWriteExt, SinkExt}; +use futures::{AsyncBufReadExt, AsyncWriteExt}; use futures::{ AsyncRead, AsyncWrite, FutureExt, StreamExt, channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, io::BufReader, select_biased, }; -use gpui::{App, AppContext, Entity, Task}; +use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use serde::{Deserialize, Serialize}; use util::ResultExt; -use crate::claude::mcp_server::ClaudeMcpServer; +use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; use crate::claude::tools::ClaudeTool; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; -use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection}; +use acp_thread::{AcpThread, AgentConnection}; #[derive(Clone)] pub struct ClaudeCode; @@ -55,29 +51,57 @@ impl AgentServer for ClaudeCode { ui::IconName::AiClaude } - fn supports_always_allow(&self) -> bool { - false + fn connect( + &self, + _root_dir: &Path, + _project: &Entity, + _cx: &mut App, + ) -> Task>> { + let connection = ClaudeAgentConnection { + sessions: Default::default(), + }; + + Task::ready(Ok(Rc::new(connection) as _)) + } +} + +#[cfg(unix)] +fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> { + let pid = nix::unistd::Pid::from_raw(pid); + + nix::sys::signal::kill(pid, nix::sys::signal::SIGINT) + .map_err(|e| anyhow!("Failed to interrupt process: {}", e)) +} + +#[cfg(windows)] +fn send_interrupt(_pid: i32) -> anyhow::Result<()> { + panic!("Cancel not implemented on Windows") +} + +struct ClaudeAgentConnection { + sessions: Rc>>, +} + +impl AgentConnection for ClaudeAgentConnection { + fn name(&self) -> &'static str { + ClaudeCode.name() } fn new_thread( - &self, - root_dir: &Path, - project: &Entity, - cx: &mut App, + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, ) -> Task>> { - let project = project.clone(); - let root_dir = root_dir.to_path_buf(); - let title = self.name().into(); + let cwd = cwd.to_owned(); cx.spawn(async move |cx| { - let (mut delegate_tx, delegate_rx) = watch::channel(None); - let tool_id_map = Rc::new(RefCell::new(HashMap::default())); - - let mcp_server = ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?; + let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); + let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?; let mut mcp_servers = HashMap::default(); mcp_servers.insert( mcp_server::SERVER_NAME.to_string(), - mcp_server.server_config()?, + permission_mcp_server.server_config()?, ); let mcp_config = McpConfig { mcp_servers }; @@ -104,177 +128,180 @@ impl AgentServer for ClaudeCode { let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); let (cancel_tx, mut cancel_rx) = mpsc::unbounded::>>(); - let session_id = Uuid::new_v4(); + let session_id = acp::SessionId(Uuid::new_v4().to_string().into()); log::trace!("Starting session with id: {}", session_id); - cx.background_spawn(async move { - let mut outgoing_rx = Some(outgoing_rx); - let mut mode = ClaudeSessionMode::Start; + cx.background_spawn({ + let session_id = session_id.clone(); + async move { + let mut outgoing_rx = Some(outgoing_rx); + let mut mode = ClaudeSessionMode::Start; - loop { - let mut child = - spawn_claude(&command, mode, session_id, &mcp_config_path, &root_dir) - .await?; - mode = ClaudeSessionMode::Resume; - - let pid = child.id(); - log::trace!("Spawned (pid: {})", pid); - - let mut io_fut = pin!( - ClaudeAgentConnection::handle_io( - outgoing_rx.take().unwrap(), - incoming_message_tx.clone(), - child.stdin.take().unwrap(), - child.stdout.take().unwrap(), + loop { + let mut child = spawn_claude( + &command, + mode, + session_id.clone(), + &mcp_config_path, + &cwd, ) - .fuse() - ); + .await?; + mode = ClaudeSessionMode::Resume; - select_biased! { - done_tx = cancel_rx.next() => { - if let Some(done_tx) = done_tx { - log::trace!("Interrupted (pid: {})", pid); - let result = send_interrupt(pid as i32); - outgoing_rx.replace(io_fut.await?); - done_tx.send(result).log_err(); - continue; + let pid = child.id(); + log::trace!("Spawned (pid: {})", pid); + + let mut io_fut = pin!( + ClaudeAgentSession::handle_io( + outgoing_rx.take().unwrap(), + incoming_message_tx.clone(), + child.stdin.take().unwrap(), + child.stdout.take().unwrap(), + ) + .fuse() + ); + + select_biased! { + done_tx = cancel_rx.next() => { + if let Some(done_tx) = done_tx { + log::trace!("Interrupted (pid: {})", pid); + let result = send_interrupt(pid as i32); + outgoing_rx.replace(io_fut.await?); + done_tx.send(result).log_err(); + continue; + } + } + result = io_fut => { + result?; } } - result = io_fut => { - result?; - } + + log::trace!("Stopped (pid: {})", pid); + break; } - log::trace!("Stopped (pid: {})", pid); - break; + drop(mcp_config_path); + anyhow::Ok(()) } - - drop(mcp_config_path); - anyhow::Ok(()) }) .detach(); - cx.new(|cx| { - let end_turn_tx = Rc::new(RefCell::new(None)); - let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()); - delegate_tx.send(Some(delegate.clone())).log_err(); - - let handler_task = cx.foreground_executor().spawn({ - let end_turn_tx = end_turn_tx.clone(); - let tool_id_map = tool_id_map.clone(); - let delegate = delegate.clone(); - async move { - while let Some(message) = incoming_message_rx.next().await { - ClaudeAgentConnection::handle_message( - delegate.clone(), - message, - end_turn_tx.clone(), - tool_id_map.clone(), - ) - .await - } + let end_turn_tx = Rc::new(RefCell::new(None)); + let handler_task = cx.spawn({ + let end_turn_tx = end_turn_tx.clone(); + let thread_rx = thread_rx.clone(); + async move |cx| { + while let Some(message) = incoming_message_rx.next().await { + ClaudeAgentSession::handle_message( + thread_rx.clone(), + message, + end_turn_tx.clone(), + cx, + ) + .await } - }); + } + }); - let mut connection = ClaudeAgentConnection { - delegate, - outgoing_tx, - end_turn_tx, - cancel_tx, - session_id, - _handler_task: handler_task, - _mcp_server: None, - }; + let thread = + cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?; - connection._mcp_server = Some(mcp_server); - acp_thread::AcpThread::new(connection, title, None, project.clone(), cx) - }) + thread_tx.send(thread.downgrade())?; + + let session = ClaudeAgentSession { + outgoing_tx, + end_turn_tx, + cancel_tx, + _handler_task: handler_task, + _mcp_server: Some(permission_mcp_server), + }; + + self.sessions.borrow_mut().insert(session_id, session); + + Ok(thread) }) } -} -#[cfg(unix)] -fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> { - let pid = nix::unistd::Pid::from_raw(pid); + fn authenticate(&self, _cx: &mut App) -> Task> { + Task::ready(Err(anyhow!("Authentication not supported"))) + } - nix::sys::signal::kill(pid, nix::sys::signal::SIGINT) - .map_err(|e| anyhow!("Failed to interrupt process: {}", e)) -} + fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task> { + let sessions = self.sessions.borrow(); + let Some(session) = sessions.get(¶ms.session_id) else { + return Task::ready(Err(anyhow!( + "Attempted to send message to nonexistent session {}", + params.session_id + ))); + }; -#[cfg(windows)] -fn send_interrupt(_pid: i32) -> anyhow::Result<()> { - panic!("Cancel not implemented on Windows") -} + let (tx, rx) = oneshot::channel(); + session.end_turn_tx.borrow_mut().replace(tx); -impl AgentConnection for ClaudeAgentConnection { - /// Send a request to the agent and wait for a response. - fn request_any( - &self, - params: AnyAgentRequest, - ) -> LocalBoxFuture<'static, Result> { - let delegate = self.delegate.clone(); - let end_turn_tx = self.end_turn_tx.clone(); - let outgoing_tx = self.outgoing_tx.clone(); - let mut cancel_tx = self.cancel_tx.clone(); - let session_id = self.session_id; - async move { - match params { - // todo: consider sending an empty request so we get the init response? - AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse( - acp::InitializeResponse { - is_authenticated: true, - protocol_version: ProtocolVersion::latest(), - }, - )), - AnyAgentRequest::AuthenticateParams(_) => { - Err(anyhow!("Authentication not supported")) + let mut content = String::new(); + for chunk in params.prompt { + match chunk { + acp::ContentBlock::Text(text_content) => { + content.push_str(&text_content.text); } - AnyAgentRequest::SendUserMessageParams(message) => { - delegate.clear_completed_plan_entries().await?; - - let (tx, rx) = oneshot::channel(); - end_turn_tx.borrow_mut().replace(tx); - let mut content = String::new(); - for chunk in message.chunks { - match chunk { - agentic_coding_protocol::UserMessageChunk::Text { text } => { - content.push_str(&text) - } - agentic_coding_protocol::UserMessageChunk::Path { path } => { - content.push_str(&format!("@{path:?}")) - } - } - } - outgoing_tx.unbounded_send(SdkMessage::User { - message: Message { - role: Role::User, - content: Content::UntaggedText(content), - id: None, - model: None, - stop_reason: None, - stop_sequence: None, - usage: None, - }, - session_id: Some(session_id), - })?; - rx.await??; - Ok(AnyAgentResult::SendUserMessageResponse( - acp::SendUserMessageResponse, - )) + acp::ContentBlock::ResourceLink(resource_link) => { + content.push_str(&format!("@{}", resource_link.uri)); } - AnyAgentRequest::CancelSendMessageParams(_) => { - let (done_tx, done_rx) = oneshot::channel(); - cancel_tx.send(done_tx).await?; - done_rx.await??; - - Ok(AnyAgentResult::CancelSendMessageResponse( - acp::CancelSendMessageResponse, - )) + acp::ContentBlock::Audio(_) + | acp::ContentBlock::Image(_) + | acp::ContentBlock::Resource(_) => { + // TODO } } } - .boxed_local() + + if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User { + message: Message { + role: Role::User, + content: Content::UntaggedText(content), + id: None, + model: None, + stop_reason: None, + stop_sequence: None, + usage: None, + }, + session_id: Some(params.session_id.to_string()), + }) { + return Task::ready(Err(anyhow!(err))); + } + + cx.foreground_executor().spawn(async move { + rx.await??; + Ok(()) + }) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + let sessions = self.sessions.borrow(); + let Some(session) = sessions.get(&session_id) else { + log::warn!("Attempted to cancel nonexistent session {}", session_id); + return; + }; + + let (done_tx, done_rx) = oneshot::channel(); + if session + .cancel_tx + .unbounded_send(done_tx) + .log_err() + .is_some() + { + let end_turn_tx = session.end_turn_tx.clone(); + cx.foreground_executor() + .spawn(async move { + done_rx.await??; + if let Some(end_turn_tx) = end_turn_tx.take() { + end_turn_tx.send(Ok(())).ok(); + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } } } @@ -287,7 +314,7 @@ enum ClaudeSessionMode { async fn spawn_claude( command: &AgentServerCommand, mode: ClaudeSessionMode, - session_id: Uuid, + session_id: acp::SessionId, mcp_config_path: &Path, root_dir: &Path, ) -> Result { @@ -327,88 +354,103 @@ async fn spawn_claude( Ok(child) } -struct ClaudeAgentConnection { - delegate: AcpClientDelegate, - session_id: Uuid, +struct ClaudeAgentSession { outgoing_tx: UnboundedSender, end_turn_tx: Rc>>>>, cancel_tx: UnboundedSender>>, - _mcp_server: Option, + _mcp_server: Option, _handler_task: Task<()>, } -impl ClaudeAgentConnection { +impl ClaudeAgentSession { async fn handle_message( - delegate: AcpClientDelegate, + mut thread_rx: watch::Receiver>, message: SdkMessage, end_turn_tx: Rc>>>>, - tool_id_map: Rc>>, + cx: &mut AsyncApp, ) { match message { - SdkMessage::Assistant { message, .. } | SdkMessage::User { message, .. } => { + SdkMessage::Assistant { + message, + session_id: _, + } + | SdkMessage::User { + message, + session_id: _, + } => { + let Some(thread) = thread_rx + .recv() + .await + .log_err() + .and_then(|entity| entity.upgrade()) + else { + log::error!("Received an SDK message but thread is gone"); + return; + }; + for chunk in message.content.chunks() { match chunk { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { - delegate - .stream_assistant_message_chunk(StreamAssistantMessageChunkParams { - chunk: acp::AssistantMessageChunk::Text { text }, + thread + .update(cx, |thread, cx| { + thread.push_assistant_chunk(text.into(), false, cx) }) - .await .log_err(); } ContentChunk::ToolUse { id, name, input } => { let claude_tool = ClaudeTool::infer(&name, input); - if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { - delegate - .update_plan(acp::UpdatePlanParams { - entries: params.todos.into_iter().map(Into::into).collect(), - }) - .await - .log_err(); - } else if let Some(resp) = delegate - .push_tool_call(claude_tool.as_acp()) - .await - .log_err() - { - tool_id_map.borrow_mut().insert(id, resp.id); - } + thread + .update(cx, |thread, cx| { + if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { + thread.update_plan( + acp::Plan { + entries: params + .todos + .into_iter() + .map(Into::into) + .collect(), + }, + cx, + ) + } else { + thread.upsert_tool_call( + claude_tool.as_acp(acp::ToolCallId(id.into())), + cx, + ); + } + }) + .log_err(); } ContentChunk::ToolResult { content, tool_use_id, } => { - let id = tool_id_map.borrow_mut().remove(&tool_use_id); - if let Some(id) = id { - let content = content.to_string(); - delegate - .update_tool_call(UpdateToolCallParams { - tool_call_id: id, - status: acp::ToolCallStatus::Finished, - // Don't unset existing content - content: (!content.is_empty()).then_some( - ToolCallContent::Markdown { - // For now we only include text content - markdown: content, - }, - ), - }) - .await - .log_err(); - } + let content = content.to_string(); + thread + .update(cx, |thread, cx| { + thread.update_tool_call( + acp::ToolCallId(tool_use_id.into()), + acp::ToolCallStatus::Completed, + (!content.is_empty()).then(|| vec![content.into()]), + cx, + ) + }) + .log_err(); } ContentChunk::Image | ContentChunk::Document | ContentChunk::Thinking | ContentChunk::RedactedThinking | ContentChunk::WebSearchToolResult => { - delegate - .stream_assistant_message_chunk(StreamAssistantMessageChunkParams { - chunk: acp::AssistantMessageChunk::Text { - text: format!("Unsupported content: {:?}", chunk), - }, + thread + .update(cx, |thread, cx| { + thread.push_assistant_chunk( + format!("Unsupported content: {:?}", chunk).into(), + false, + cx, + ) }) - .await .log_err(); } } @@ -592,14 +634,14 @@ enum SdkMessage { Assistant { message: Message, // from Anthropic SDK #[serde(skip_serializing_if = "Option::is_none")] - session_id: Option, + session_id: Option, }, // A user message User { message: Message, // from Anthropic SDK #[serde(skip_serializing_if = "Option::is_none")] - session_id: Option, + session_id: Option, }, // Emitted as the last message in a conversation @@ -661,21 +703,6 @@ enum PermissionMode { Plan, } -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct McpConfig { - mcp_servers: HashMap, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct McpServerConfig { - command: String, - args: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - env: Option>, -} - #[cfg(test)] pub(crate) mod tests { use super::*; diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs index 2405603550..0a39a02931 100644 --- a/crates/agent_servers/src/claude/mcp_server.rs +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -1,29 +1,22 @@ -use std::{cell::RefCell, rc::Rc}; +use std::path::PathBuf; -use acp_thread::AcpClientDelegate; -use agentic_coding_protocol::{self as acp, Client, ReadTextFileParams, WriteTextFileParams}; +use acp_thread::AcpThread; +use agent_client_protocol as acp; use anyhow::{Context, Result}; use collections::HashMap; -use context_server::{ - listener::McpServer, - types::{ - CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse, - ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations, - ToolResponseContent, ToolsCapabilities, requests, - }, +use context_server::types::{ + CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse, + ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations, + ToolResponseContent, ToolsCapabilities, requests, }; -use gpui::{App, AsyncApp, Task}; +use gpui::{App, AsyncApp, Entity, Task, WeakEntity}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use util::debug_panic; -use crate::claude::{ - McpServerConfig, - tools::{ClaudeTool, EditToolParams, ReadToolParams}, -}; +use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams}; -pub struct ClaudeMcpServer { - server: McpServer, +pub struct ClaudeZedMcpServer { + server: context_server::listener::McpServer, } pub const SERVER_NAME: &str = "zed"; @@ -52,17 +45,16 @@ enum PermissionToolBehavior { Deny, } -impl ClaudeMcpServer { +impl ClaudeZedMcpServer { pub async fn new( - delegate: watch::Receiver>, - tool_id_map: Rc>>, + thread_rx: watch::Receiver>, cx: &AsyncApp, ) -> Result { - let mut mcp_server = McpServer::new(cx).await?; + let mut mcp_server = context_server::listener::McpServer::new(cx).await?; mcp_server.handle_request::(Self::handle_initialize); mcp_server.handle_request::(Self::handle_list_tools); mcp_server.handle_request::(move |request, cx| { - Self::handle_call_tool(request, delegate.clone(), tool_id_map.clone(), cx) + Self::handle_call_tool(request, thread_rx.clone(), cx) }); Ok(Self { server: mcp_server }) @@ -70,9 +62,7 @@ impl ClaudeMcpServer { pub fn server_config(&self) -> Result { let zed_path = std::env::current_exe() - .context("finding current executable path for use in mcp_server")? - .to_string_lossy() - .to_string(); + .context("finding current executable path for use in mcp_server")?; Ok(McpServerConfig { command: zed_path, @@ -152,22 +142,19 @@ impl ClaudeMcpServer { fn handle_call_tool( request: CallToolParams, - mut delegate_watch: watch::Receiver>, - tool_id_map: Rc>>, + mut thread_rx: watch::Receiver>, cx: &App, ) -> Task> { cx.spawn(async move |cx| { - let Some(delegate) = delegate_watch.recv().await? else { - debug_panic!("Sent None delegate"); - anyhow::bail!("Server not available"); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); }; if request.name.as_str() == PERMISSION_TOOL { let input = serde_json::from_value(request.arguments.context("Arguments required")?)?; - let result = - Self::handle_permissions_tool_call(input, delegate, tool_id_map, cx).await?; + let result = Self::handle_permissions_tool_call(input, thread, cx).await?; Ok(CallToolResponse { content: vec![ToolResponseContent::Text { text: serde_json::to_string(&result)?, @@ -179,7 +166,7 @@ impl ClaudeMcpServer { let input = serde_json::from_value(request.arguments.context("Arguments required")?)?; - let content = Self::handle_read_tool_call(input, delegate, cx).await?; + let content = Self::handle_read_tool_call(input, thread, cx).await?; Ok(CallToolResponse { content, is_error: None, @@ -189,7 +176,7 @@ impl ClaudeMcpServer { let input = serde_json::from_value(request.arguments.context("Arguments required")?)?; - Self::handle_edit_tool_call(input, delegate, cx).await?; + Self::handle_edit_tool_call(input, thread, cx).await?; Ok(CallToolResponse { content: vec![], is_error: None, @@ -202,49 +189,46 @@ impl ClaudeMcpServer { } fn handle_read_tool_call( - params: ReadToolParams, - delegate: AcpClientDelegate, + ReadToolParams { + abs_path, + offset, + limit, + }: ReadToolParams, + thread: Entity, cx: &AsyncApp, ) -> Task>> { - cx.foreground_executor().spawn(async move { - let response = delegate - .read_text_file(ReadTextFileParams { - path: params.abs_path, - line: params.offset, - limit: params.limit, - }) + cx.spawn(async move |cx| { + let content = thread + .update(cx, |thread, cx| { + thread.read_text_file(abs_path, offset, limit, false, cx) + })? .await?; - Ok(vec![ToolResponseContent::Text { - text: response.content, - }]) + Ok(vec![ToolResponseContent::Text { text: content }]) }) } fn handle_edit_tool_call( params: EditToolParams, - delegate: AcpClientDelegate, + thread: Entity, cx: &AsyncApp, ) -> Task> { - cx.foreground_executor().spawn(async move { - let response = delegate - .read_text_file_reusing_snapshot(ReadTextFileParams { - path: params.abs_path.clone(), - line: None, - limit: None, - }) + cx.spawn(async move |cx| { + let content = thread + .update(cx, |threads, cx| { + threads.read_text_file(params.abs_path.clone(), None, None, true, cx) + })? .await?; - let new_content = response.content.replace(¶ms.old_text, ¶ms.new_text); - if new_content == response.content { + let new_content = content.replace(¶ms.old_text, ¶ms.new_text); + if new_content == content { return Err(anyhow::anyhow!("The old_text was not found in the content")); } - delegate - .write_text_file(WriteTextFileParams { - path: params.abs_path, - content: new_content, - }) + thread + .update(cx, |threads, cx| { + threads.write_text_file(params.abs_path, new_content, cx) + })? .await?; Ok(()) @@ -253,44 +237,65 @@ impl ClaudeMcpServer { fn handle_permissions_tool_call( params: PermissionToolParams, - delegate: AcpClientDelegate, - tool_id_map: Rc>>, + thread: Entity, cx: &AsyncApp, ) -> Task> { - cx.foreground_executor().spawn(async move { + cx.spawn(async move |cx| { let claude_tool = ClaudeTool::infer(¶ms.tool_name, params.input.clone()); - let tool_call_id = match params.tool_use_id { - Some(tool_use_id) => tool_id_map - .borrow() - .get(&tool_use_id) - .cloned() - .context("Tool call ID not found")?, + let tool_call_id = + acp::ToolCallId(params.tool_use_id.context("Tool ID required")?.into()); - None => delegate.push_tool_call(claude_tool.as_acp()).await?.id, - }; + let allow_option_id = acp::PermissionOptionId("allow".into()); + let reject_option_id = acp::PermissionOptionId("reject".into()); - let outcome = delegate - .request_existing_tool_call_confirmation( - tool_call_id, - claude_tool.confirmation(None), - ) + let chosen_option = thread + .update(cx, |thread, cx| { + thread.request_tool_call_permission( + claude_tool.as_acp(tool_call_id), + vec![ + acp::PermissionOption { + id: allow_option_id.clone(), + label: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }, + acp::PermissionOption { + id: reject_option_id, + label: "Reject".into(), + kind: acp::PermissionOptionKind::RejectOnce, + }, + ], + cx, + ) + })? .await?; - match outcome { - acp::ToolCallConfirmationOutcome::Allow - | acp::ToolCallConfirmationOutcome::AlwaysAllow - | acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer - | acp::ToolCallConfirmationOutcome::AlwaysAllowTool => Ok(PermissionToolResponse { + if chosen_option == allow_option_id { + Ok(PermissionToolResponse { behavior: PermissionToolBehavior::Allow, updated_input: params.input, - }), - acp::ToolCallConfirmationOutcome::Reject - | acp::ToolCallConfirmationOutcome::Cancel => Ok(PermissionToolResponse { + }) + } else { + Ok(PermissionToolResponse { behavior: PermissionToolBehavior::Deny, updated_input: params.input, - }), + }) } }) } } + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct McpConfig { + pub mcp_servers: HashMap, +} + +#[derive(Serialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct McpServerConfig { + pub command: PathBuf, + pub args: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub env: Option>, +} diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs index 75a26ee230..ed25f9af7f 100644 --- a/crates/agent_servers/src/claude/tools.rs +++ b/crates/agent_servers/src/claude/tools.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use agentic_coding_protocol::{self as acp, PushToolCallParams, ToolCallLocation}; +use agent_client_protocol as acp; use itertools::Itertools; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -115,51 +115,36 @@ impl ClaudeTool { Self::Other { name, .. } => name.clone(), } } - - pub fn content(&self) -> Option { + pub fn content(&self) -> Vec { match &self { - Self::Other { input, .. } => Some(acp::ToolCallContent::Markdown { - markdown: format!( + Self::Other { input, .. } => vec![ + format!( "```json\n{}```", serde_json::to_string_pretty(&input).unwrap_or("{}".to_string()) - ), - }), - Self::Task(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.prompt.clone(), - }), - Self::NotebookRead(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.notebook_path.display().to_string(), - }), - Self::NotebookEdit(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.new_source.clone(), - }), - Self::Terminal(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: format!( + ) + .into(), + ], + Self::Task(Some(params)) => vec![params.prompt.clone().into()], + Self::NotebookRead(Some(params)) => { + vec![params.notebook_path.display().to_string().into()] + } + Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()], + Self::Terminal(Some(params)) => vec![ + format!( "`{}`\n\n{}", params.command, params.description.as_deref().unwrap_or_default() - ), - }), - Self::ReadFile(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.abs_path.display().to_string(), - }), - Self::Ls(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.path.display().to_string(), - }), - Self::Glob(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.to_string(), - }), - Self::Grep(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: format!("`{params}`"), - }), - Self::WebFetch(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.prompt.clone(), - }), - Self::WebSearch(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.to_string(), - }), - Self::TodoWrite(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params + ) + .into(), + ], + Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()], + Self::Ls(Some(params)) => vec![params.path.display().to_string().into()], + Self::Glob(Some(params)) => vec![params.to_string().into()], + Self::Grep(Some(params)) => vec![format!("`{params}`").into()], + Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()], + Self::WebSearch(Some(params)) => vec![params.to_string().into()], + Self::TodoWrite(Some(params)) => vec![ + params .todos .iter() .map(|todo| { @@ -174,34 +159,39 @@ impl ClaudeTool { todo.content ) }) - .join("\n"), - }), - Self::ExitPlanMode(Some(params)) => Some(acp::ToolCallContent::Markdown { - markdown: params.plan.clone(), - }), - Self::Edit(Some(params)) => Some(acp::ToolCallContent::Diff { + .join("\n") + .into(), + ], + Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()], + Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff { diff: acp::Diff { path: params.abs_path.clone(), old_text: Some(params.old_text.clone()), new_text: params.new_text.clone(), }, - }), - Self::Write(Some(params)) => Some(acp::ToolCallContent::Diff { + }], + Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff { diff: acp::Diff { path: params.file_path.clone(), old_text: None, new_text: params.content.clone(), }, - }), + }], Self::MultiEdit(Some(params)) => { // todo: show multiple edits in a multibuffer? - params.edits.first().map(|edit| acp::ToolCallContent::Diff { - diff: acp::Diff { - path: params.file_path.clone(), - old_text: Some(edit.old_string.clone()), - new_text: edit.new_string.clone(), - }, - }) + params + .edits + .first() + .map(|edit| { + vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.file_path.clone(), + old_text: Some(edit.old_string.clone()), + new_text: edit.new_string.clone(), + }, + }] + }) + .unwrap_or_default() } Self::Task(None) | Self::NotebookRead(None) @@ -217,181 +207,80 @@ impl ClaudeTool { | Self::ExitPlanMode(None) | Self::Edit(None) | Self::Write(None) - | Self::MultiEdit(None) => None, + | Self::MultiEdit(None) => vec![], } } - pub fn icon(&self) -> acp::Icon { + pub fn kind(&self) -> acp::ToolKind { match self { - Self::Task(_) => acp::Icon::Hammer, - Self::NotebookRead(_) => acp::Icon::FileSearch, - Self::NotebookEdit(_) => acp::Icon::Pencil, - Self::Edit(_) => acp::Icon::Pencil, - Self::MultiEdit(_) => acp::Icon::Pencil, - Self::Write(_) => acp::Icon::Pencil, - Self::ReadFile(_) => acp::Icon::FileSearch, - Self::Ls(_) => acp::Icon::Folder, - Self::Glob(_) => acp::Icon::FileSearch, - Self::Grep(_) => acp::Icon::Regex, - Self::Terminal(_) => acp::Icon::Terminal, - Self::WebSearch(_) => acp::Icon::Globe, - Self::WebFetch(_) => acp::Icon::Globe, - Self::TodoWrite(_) => acp::Icon::LightBulb, - Self::ExitPlanMode(_) => acp::Icon::Hammer, - Self::Other { .. } => acp::Icon::Hammer, - } - } - - pub fn confirmation(&self, description: Option) -> acp::ToolCallConfirmation { - match &self { - Self::Edit(_) | Self::Write(_) | Self::NotebookEdit(_) | Self::MultiEdit(_) => { - acp::ToolCallConfirmation::Edit { description } - } - Self::WebFetch(params) => acp::ToolCallConfirmation::Fetch { - urls: params - .as_ref() - .map(|p| vec![p.url.clone()]) - .unwrap_or_default(), - description, - }, - Self::Terminal(Some(BashToolParams { - description, - command, - .. - })) => acp::ToolCallConfirmation::Execute { - command: command.clone(), - root_command: command.clone(), - description: description.clone(), - }, - Self::ExitPlanMode(Some(params)) => acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {}", params.plan) - } else { - params.plan.clone() - }, - }, - Self::Task(Some(params)) => acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {}", params.description) - } else { - params.description.clone() - }, - }, - Self::Ls(Some(LsToolParams { path, .. })) - | Self::ReadFile(Some(ReadToolParams { abs_path: path, .. })) => { - let path = path.display(); - acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {path}") - } else { - path.to_string() - }, - } - } - Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { - let path = notebook_path.display(); - acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {path}") - } else { - path.to_string() - }, - } - } - Self::Glob(Some(params)) => acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {params}") - } else { - params.to_string() - }, - }, - Self::Grep(Some(params)) => acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {params}") - } else { - params.to_string() - }, - }, - Self::WebSearch(Some(params)) => acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {params}") - } else { - params.to_string() - }, - }, - Self::TodoWrite(Some(params)) => { - let params = params.todos.iter().map(|todo| &todo.content).join(", "); - acp::ToolCallConfirmation::Other { - description: if let Some(description) = description { - format!("{description} {params}") - } else { - params - }, - } - } - Self::Terminal(None) - | Self::Task(None) - | Self::NotebookRead(None) - | Self::ExitPlanMode(None) - | Self::Ls(None) - | Self::Glob(None) - | Self::Grep(None) - | Self::ReadFile(None) - | Self::WebSearch(None) - | Self::TodoWrite(None) - | Self::Other { .. } => acp::ToolCallConfirmation::Other { - description: description.unwrap_or("".to_string()), - }, + Self::Task(_) => acp::ToolKind::Think, + Self::NotebookRead(_) => acp::ToolKind::Read, + Self::NotebookEdit(_) => acp::ToolKind::Edit, + Self::Edit(_) => acp::ToolKind::Edit, + Self::MultiEdit(_) => acp::ToolKind::Edit, + Self::Write(_) => acp::ToolKind::Edit, + Self::ReadFile(_) => acp::ToolKind::Read, + Self::Ls(_) => acp::ToolKind::Search, + Self::Glob(_) => acp::ToolKind::Search, + Self::Grep(_) => acp::ToolKind::Search, + Self::Terminal(_) => acp::ToolKind::Execute, + Self::WebSearch(_) => acp::ToolKind::Search, + Self::WebFetch(_) => acp::ToolKind::Fetch, + Self::TodoWrite(_) => acp::ToolKind::Think, + Self::ExitPlanMode(_) => acp::ToolKind::Think, + Self::Other { .. } => acp::ToolKind::Other, } } pub fn locations(&self) -> Vec { match &self { - Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![ToolCallLocation { + Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation { path: abs_path.clone(), line: None, }], Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => { - vec![ToolCallLocation { + vec![acp::ToolCallLocation { + path: file_path.clone(), + line: None, + }] + } + Self::Write(Some(WriteToolParams { file_path, .. })) => { + vec![acp::ToolCallLocation { path: file_path.clone(), line: None, }] } - Self::Write(Some(WriteToolParams { file_path, .. })) => vec![ToolCallLocation { - path: file_path.clone(), - line: None, - }], Self::ReadFile(Some(ReadToolParams { abs_path, offset, .. - })) => vec![ToolCallLocation { + })) => vec![acp::ToolCallLocation { path: abs_path.clone(), line: *offset, }], Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { - vec![ToolCallLocation { + vec![acp::ToolCallLocation { path: notebook_path.clone(), line: None, }] } Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => { - vec![ToolCallLocation { + vec![acp::ToolCallLocation { path: notebook_path.clone(), line: None, }] } Self::Glob(Some(GlobToolParams { path: Some(path), .. - })) => vec![ToolCallLocation { + })) => vec![acp::ToolCallLocation { path: path.clone(), line: None, }], - Self::Ls(Some(LsToolParams { path, .. })) => vec![ToolCallLocation { + Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation { path: path.clone(), line: None, }], Self::Grep(Some(GrepToolParams { path: Some(path), .. - })) => vec![ToolCallLocation { + })) => vec![acp::ToolCallLocation { path: PathBuf::from(path), line: None, }], @@ -414,11 +303,13 @@ impl ClaudeTool { } } - pub fn as_acp(&self) -> PushToolCallParams { - PushToolCallParams { + pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall { + acp::ToolCall { + id, + kind: self.kind(), + status: acp::ToolCallStatus::InProgress, label: self.label(), content: self.content(), - icon: self.icon(), locations: self.locations(), } } diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index 12f74cb13e..9bc6fd60fe 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -1,10 +1,9 @@ use std::{path::Path, sync::Arc, time::Duration}; use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings}; -use acp_thread::{ - AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus, -}; -use agentic_coding_protocol as acp; +use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; +use agent_client_protocol as acp; + use futures::{FutureExt, StreamExt, channel::mpsc, select}; use gpui::{Entity, TestAppContext}; use indoc::indoc; @@ -54,19 +53,25 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes thread .update(cx, |thread, cx| { thread.send( - acp::SendUserMessageParams { - chunks: vec![ - acp::UserMessageChunk::Text { - text: "Read the file ".into(), - }, - acp::UserMessageChunk::Path { - path: Path::new("foo.rs").into(), - }, - acp::UserMessageChunk::Text { - text: " and tell me what the content of the println! is".into(), - }, - ], - }, + vec![ + acp::ContentBlock::Text(acp::TextContent { + text: "Read the file ".into(), + annotations: None, + }), + acp::ContentBlock::ResourceLink(acp::ResourceLink { + uri: "foo.rs".into(), + name: "foo.rs".into(), + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + }), + acp::ContentBlock::Text(acp::TextContent { + text: " and tell me what the content of the println! is".into(), + annotations: None, + }), + ], cx, ) }) @@ -161,11 +166,8 @@ pub async fn test_tool_call_with_confirmation( let tool_call_id = thread.read_with(cx, |thread, _cx| { let AgentThreadEntry::ToolCall(ToolCall { id, - status: - ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::Execute { root_command, .. }, - .. - }, + content, + status: ToolCallStatus::WaitingForConfirmation { .. }, .. }) = &thread .entries() @@ -176,13 +178,18 @@ pub async fn test_tool_call_with_confirmation( panic!(); }; - assert!(root_command.contains("touch")); + assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch"))); - *id + id.clone() }); thread.update(cx, |thread, cx| { - thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx); + thread.authorize_tool_call( + tool_call_id, + acp::PermissionOptionId("0".into()), + acp::PermissionOptionKind::AllowOnce, + cx, + ); assert!(thread.entries().iter().any(|entry| matches!( entry, @@ -197,7 +204,7 @@ pub async fn test_tool_call_with_confirmation( thread.read_with(cx, |thread, cx| { let AgentThreadEntry::ToolCall(ToolCall { - content: Some(ToolCallContent::Markdown { markdown }), + content, status: ToolCallStatus::Allowed { .. }, .. }) = thread @@ -209,13 +216,10 @@ pub async fn test_tool_call_with_confirmation( panic!(); }; - markdown.read_with(cx, |md, _cx| { - assert!( - md.source().contains("Hello"), - r#"Expected '{}' to contain "Hello""#, - md.source() - ); - }); + assert!( + content.iter().any(|c| c.to_markdown(cx).contains("Hello")), + "Expected content to contain 'Hello'" + ); }); } @@ -249,26 +253,20 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon thread.read_with(cx, |thread, _cx| { let AgentThreadEntry::ToolCall(ToolCall { id, - status: - ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::Execute { root_command, .. }, - .. - }, + content, + status: ToolCallStatus::WaitingForConfirmation { .. }, .. }) = &thread.entries()[first_tool_call_ix] else { panic!("{:?}", thread.entries()[1]); }; - assert!(root_command.contains("touch")); + assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch"))); - *id + id.clone() }); - thread - .update(cx, |thread, cx| thread.cancel(cx)) - .await - .unwrap(); + let _ = thread.update(cx, |thread, cx| thread.cancel(cx)); full_turn.await.unwrap(); thread.read_with(cx, |thread, _| { let AgentThreadEntry::ToolCall(ToolCall { @@ -369,15 +367,16 @@ pub async fn new_test_thread( current_dir: impl AsRef, cx: &mut TestAppContext, ) -> Entity { - let thread = cx - .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx)) + let connection = cx + .update(|cx| server.connect(current_dir.as_ref(), &project, cx)) .await .unwrap(); - thread - .update(cx, |thread, _| thread.initialize()) + let thread = connection + .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async()) .await .unwrap(); + thread } diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 8ad147cbff..47b965cdad 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -1,9 +1,17 @@ -use crate::stdio_agent_server::StdioAgentServer; -use crate::{AgentServerCommand, AgentServerVersion}; +use anyhow::anyhow; +use std::cell::RefCell; +use std::path::Path; +use std::rc::Rc; +use util::ResultExt as _; + +use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; +use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate}; +use agentic_coding_protocol as acp_old; use anyhow::{Context as _, Result}; -use gpui::{AsyncApp, Entity}; +use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use project::Project; use settings::SettingsStore; +use ui::App; use crate::AllAgentServersSettings; @@ -12,7 +20,7 @@ pub struct Gemini; const ACP_ARG: &str = "--experimental-acp"; -impl StdioAgentServer for Gemini { +impl AgentServer for Gemini { fn name(&self) -> &'static str { "Gemini" } @@ -25,14 +33,88 @@ impl StdioAgentServer for Gemini { "Ask questions, edit files, run commands.\nBe specific for the best results." } - fn supports_always_allow(&self) -> bool { - true - } - fn logo(&self) -> ui::IconName { ui::IconName::AiGemini } + fn connect( + &self, + root_dir: &Path, + project: &Entity, + cx: &mut App, + ) -> Task>> { + let root_dir = root_dir.to_path_buf(); + let project = project.clone(); + let this = self.clone(); + let name = self.name(); + + cx.spawn(async move |cx| { + let command = this.command(&project, cx).await?; + + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + + let foreground_executor = cx.foreground_executor().clone(); + + let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); + + let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( + OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), + stdin, + stdout, + move |fut| foreground_executor.spawn(fut).detach(), + ); + + let io_task = cx.background_spawn(async move { + io_fut.await.log_err(); + }); + + let child_status = cx.background_spawn(async move { + let result = match child.status().await { + Err(e) => Err(anyhow!(e)), + Ok(result) if result.success() => Ok(()), + Ok(result) => { + if let Some(AgentServerVersion::Unsupported { + error_message, + upgrade_message, + upgrade_command, + }) = this.version(&command).await.log_err() + { + Err(anyhow!(LoadError::Unsupported { + error_message, + upgrade_message, + upgrade_command + })) + } else { + Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) + } + } + }; + drop(io_task); + result + }); + + let connection: Rc = Rc::new(OldAcpAgentConnection { + name, + connection, + child_status, + }); + + Ok(connection) + }) + } +} + +impl Gemini { async fn command( &self, project: &Entity, diff --git a/crates/agent_servers/src/stdio_agent_server.rs b/crates/agent_servers/src/stdio_agent_server.rs deleted file mode 100644 index e60dd39de4..0000000000 --- a/crates/agent_servers/src/stdio_agent_server.rs +++ /dev/null @@ -1,119 +0,0 @@ -use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; -use acp_thread::{AcpClientDelegate, AcpThread, LoadError}; -use agentic_coding_protocol as acp; -use anyhow::{Result, anyhow}; -use gpui::{App, AsyncApp, Entity, Task, prelude::*}; -use project::Project; -use std::path::Path; -use util::ResultExt; - -pub trait StdioAgentServer: Send + Clone { - fn logo(&self) -> ui::IconName; - fn name(&self) -> &'static str; - fn empty_state_headline(&self) -> &'static str; - fn empty_state_message(&self) -> &'static str; - fn supports_always_allow(&self) -> bool; - - fn command( - &self, - project: &Entity, - cx: &mut AsyncApp, - ) -> impl Future>; - - fn version( - &self, - command: &AgentServerCommand, - ) -> impl Future> + Send; -} - -impl AgentServer for T { - fn name(&self) -> &'static str { - self.name() - } - - fn empty_state_headline(&self) -> &'static str { - self.empty_state_headline() - } - - fn empty_state_message(&self) -> &'static str { - self.empty_state_message() - } - - fn logo(&self) -> ui::IconName { - self.logo() - } - - fn supports_always_allow(&self) -> bool { - self.supports_always_allow() - } - - fn new_thread( - &self, - root_dir: &Path, - project: &Entity, - cx: &mut App, - ) -> Task>> { - let root_dir = root_dir.to_path_buf(); - let project = project.clone(); - let this = self.clone(); - let title = self.name().into(); - - cx.spawn(async move |cx| { - let command = this.command(&project, cx).await?; - - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; - - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - - cx.new(|cx| { - let foreground_executor = cx.foreground_executor().clone(); - - let (connection, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), - stdin, - stdout, - move |fut| foreground_executor.spawn(fut).detach(), - ); - - let io_task = cx.background_spawn(async move { - io_fut.await.log_err(); - }); - - let child_status = cx.background_spawn(async move { - let result = match child.status().await { - Err(e) => Err(anyhow!(e)), - Ok(result) if result.success() => Ok(()), - Ok(result) => { - if let Some(AgentServerVersion::Unsupported { - error_message, - upgrade_message, - upgrade_command, - }) = this.version(&command).await.log_err() - { - Err(anyhow!(LoadError::Unsupported { - error_message, - upgrade_message, - upgrade_command - })) - } else { - Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) - } - } - }; - drop(io_task); - result - }); - - AcpThread::new(connection, title, Some(child_status), project.clone(), cx) - }) - }) - } -} diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 7d3b84e42e..fbd53e8d09 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -17,10 +17,10 @@ test-support = ["gpui/test-support", "language/test-support"] [dependencies] acp_thread.workspace = true +agent-client-protocol.workspace = true agent.workspace = true -agentic-coding-protocol.workspace = true -agent_settings.workspace = true agent_servers.workspace = true +agent_settings.workspace = true ai_onboarding.workspace = true anyhow.workspace = true assistant_context.workspace = true diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 95f4f81205..7f5de9db5f 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,4 +1,4 @@ -use acp_thread::Plan; +use acp_thread::{AgentConnection, Plan}; use agent_servers::AgentServer; use std::cell::RefCell; use std::collections::BTreeMap; @@ -7,7 +7,7 @@ use std::rc::Rc; use std::sync::Arc; use std::time::Duration; -use agentic_coding_protocol::{self as acp}; +use agent_client_protocol as acp; use assistant_tool::ActionLog; use buffer_diff::BufferDiff; use collections::{HashMap, HashSet}; @@ -16,7 +16,6 @@ use editor::{ EditorStyle, MinimapVisibility, MultiBuffer, PathKey, }; use file_icons::FileIcons; -use futures::channel::oneshot; use gpui::{ Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement, @@ -39,8 +38,7 @@ use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; use ::acp_thread::{ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff, - LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallConfirmation, ToolCallContent, - ToolCallId, ToolCallStatus, + LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, }; use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; @@ -64,12 +62,13 @@ pub struct AcpThreadView { last_error: Option>, list_state: ListState, auth_task: Option>, - expanded_tool_calls: HashSet, + expanded_tool_calls: HashSet, expanded_thinking_blocks: HashSet<(usize, usize)>, edits_expanded: bool, plan_expanded: bool, editor_expanded: bool, - message_history: Rc>>, + message_history: Rc>>>, + _cancel_task: Option>, } enum ThreadState { @@ -82,22 +81,16 @@ enum ThreadState { }, LoadError(LoadError), Unauthenticated { - thread: Entity, + connection: Rc, }, } -struct AlwaysAllowOption { - id: &'static str, - label: SharedString, - outcome: acp::ToolCallConfirmationOutcome, -} - impl AcpThreadView { pub fn new( agent: Rc, workspace: WeakEntity, project: Entity, - message_history: Rc>>, + message_history: Rc>>>, min_lines: usize, max_lines: Option, window: &mut Window, @@ -191,6 +184,7 @@ impl AcpThreadView { plan_expanded: false, editor_expanded: false, message_history, + _cancel_task: None, } } @@ -208,9 +202,9 @@ impl AcpThreadView { .map(|worktree| worktree.read(cx).abs_path()) .unwrap_or_else(|| paths::home_dir().as_path().into()); - let task = agent.new_thread(&root_dir, &project, cx); + let connect_task = agent.connect(&root_dir, &project, cx); let load_task = cx.spawn_in(window, async move |this, cx| { - let thread = match task.await { + let connection = match connect_task.await { Ok(thread) => thread, Err(err) => { this.update(cx, |this, cx| { @@ -222,48 +216,30 @@ impl AcpThreadView { } }; - let init_response = async { - let resp = thread - .read_with(cx, |thread, _cx| thread.initialize())? - .await?; - anyhow::Ok(resp) - }; - - let result = match init_response.await { + let result = match connection + .clone() + .new_thread(project.clone(), &root_dir, cx) + .await + { Err(e) => { let mut cx = cx.clone(); - if e.downcast_ref::().is_some() { - let child_status = thread - .update(&mut cx, |thread, _| thread.child_status()) - .ok() - .flatten(); - if let Some(child_status) = child_status { - match child_status.await { - Ok(_) => Err(e), - Err(e) => Err(e), - } - } else { - Err(e) - } + if e.downcast_ref::().is_some() { + this.update(&mut cx, |this, cx| { + this.thread_state = ThreadState::Unauthenticated { connection }; + cx.notify(); + }) + .ok(); + return; } else { Err(e) } } - Ok(response) => { - if !response.is_authenticated { - this.update(cx, |this, _| { - this.thread_state = ThreadState::Unauthenticated { thread }; - }) - .ok(); - return; - }; - Ok(()) - } + Ok(session_id) => Ok(session_id), }; this.update_in(cx, |this, window, cx| { match result { - Ok(()) => { + Ok(thread) => { let thread_subscription = cx.subscribe_in(&thread, window, Self::handle_thread_event); @@ -305,10 +281,10 @@ impl AcpThreadView { pub fn thread(&self) -> Option<&Entity> { match &self.thread_state { - ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => { - Some(thread) - } - ThreadState::Loading { .. } | ThreadState::LoadError(..) => None, + ThreadState::Ready { thread, .. } => Some(thread), + ThreadState::Unauthenticated { .. } + | ThreadState::Loading { .. } + | ThreadState::LoadError(..) => None, } } @@ -325,7 +301,7 @@ impl AcpThreadView { self.last_error.take(); if let Some(thread) = self.thread() { - thread.update(cx, |thread, cx| thread.cancel(cx)).detach(); + self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx))); } } @@ -362,7 +338,7 @@ impl AcpThreadView { self.last_error.take(); let mut ix = 0; - let mut chunks: Vec = Vec::new(); + let mut chunks: Vec = Vec::new(); let project = self.project.clone(); self.message_editor.update(cx, |editor, cx| { let text = editor.text(cx); @@ -374,12 +350,19 @@ impl AcpThreadView { { let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot); if crease_range.start > ix { - chunks.push(acp::UserMessageChunk::Text { - text: text[ix..crease_range.start].to_string(), - }); + chunks.push(text[ix..crease_range.start].into()); } if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) { - chunks.push(acp::UserMessageChunk::Path { path: abs_path }); + let path_str = abs_path.display().to_string(); + chunks.push(acp::ContentBlock::ResourceLink(acp::ResourceLink { + uri: path_str.clone(), + name: path_str, + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + })); } ix = crease_range.end; } @@ -388,9 +371,7 @@ impl AcpThreadView { if ix < text.len() { let last_chunk = text[ix..].trim(); if !last_chunk.is_empty() { - chunks.push(acp::UserMessageChunk::Text { - text: last_chunk.into(), - }); + chunks.push(last_chunk.into()); } } }) @@ -401,8 +382,7 @@ impl AcpThreadView { } let Some(thread) = self.thread() else { return }; - let message = acp::SendUserMessageParams { chunks }; - let task = thread.update(cx, |thread, cx| thread.send(message.clone(), cx)); + let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx)); cx.spawn(async move |this, cx| { let result = task.await; @@ -424,7 +404,7 @@ impl AcpThreadView { editor.remove_creases(mention_set.lock().drain(), cx) }); - self.message_history.borrow_mut().push(message); + self.message_history.borrow_mut().push(chunks); } fn previous_history_message( @@ -490,7 +470,7 @@ impl AcpThreadView { message_editor: Entity, mention_set: Arc>, project: Entity, - message: Option<&acp::SendUserMessageParams>, + message: Option<&Vec>, window: &mut Window, cx: &mut Context, ) -> bool { @@ -503,18 +483,19 @@ impl AcpThreadView { let mut text = String::new(); let mut mentions = Vec::new(); - for chunk in &message.chunks { + for chunk in message { match chunk { - acp::UserMessageChunk::Text { text: chunk } => { - text.push_str(&chunk); + acp::ContentBlock::Text(text_content) => { + text.push_str(&text_content.text); } - acp::UserMessageChunk::Path { path } => { + acp::ContentBlock::ResourceLink(resource_link) => { + let path = Path::new(&resource_link.uri); let start = text.len(); - let content = MentionPath::new(path).to_string(); + let content = MentionPath::new(&path).to_string(); text.push_str(&content); let end = text.len(); if let Some(project_path) = - project.read(cx).project_path_for_absolute_path(path, cx) + project.read(cx).project_path_for_absolute_path(&path, cx) { let filename: SharedString = path .file_name() @@ -525,6 +506,9 @@ impl AcpThreadView { mentions.push((start..end, project_path, filename)); } } + acp::ContentBlock::Image(_) + | acp::ContentBlock::Audio(_) + | acp::ContentBlock::Resource(_) => {} } } @@ -590,71 +574,79 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - let Some(multibuffer) = self.entry_diff_multibuffer(entry_ix, cx) else { + let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else { return; }; - if self.diff_editors.contains_key(&multibuffer.entity_id()) { - return; - } + let multibuffers = multibuffers.collect::>(); - let editor = cx.new(|cx| { - let mut editor = Editor::new( - EditorMode::Full { - scale_ui_elements_with_buffer_font_size: false, - show_active_line_background: false, - sized_by_content: true, - }, - multibuffer.clone(), - None, - window, - cx, - ); - editor.set_show_gutter(false, cx); - editor.disable_inline_diagnostics(); - editor.disable_expand_excerpt_buttons(cx); - editor.set_show_vertical_scrollbar(false, cx); - editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); - editor.set_soft_wrap_mode(SoftWrap::None, cx); - editor.scroll_manager.set_forbid_vertical_scroll(true); - editor.set_show_indent_guides(false, cx); - editor.set_read_only(true); - editor.set_show_breakpoints(false, cx); - editor.set_show_code_actions(false, cx); - editor.set_show_git_diff_gutter(false, cx); - editor.set_expand_all_diff_hunks(cx); - editor.set_text_style_refinement(TextStyleRefinement { - font_size: Some( - TextSize::Small - .rems(cx) - .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) - .into(), - ), - ..Default::default() + for multibuffer in multibuffers { + if self.diff_editors.contains_key(&multibuffer.entity_id()) { + return; + } + + let editor = cx.new(|cx| { + let mut editor = Editor::new( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + sized_by_content: true, + }, + multibuffer.clone(), + None, + window, + cx, + ); + editor.set_show_gutter(false, cx); + editor.disable_inline_diagnostics(); + editor.disable_expand_excerpt_buttons(cx); + editor.set_show_vertical_scrollbar(false, cx); + editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); + editor.set_soft_wrap_mode(SoftWrap::None, cx); + editor.scroll_manager.set_forbid_vertical_scroll(true); + editor.set_show_indent_guides(false, cx); + editor.set_read_only(true); + editor.set_show_breakpoints(false, cx); + editor.set_show_code_actions(false, cx); + editor.set_show_git_diff_gutter(false, cx); + editor.set_expand_all_diff_hunks(cx); + editor.set_text_style_refinement(TextStyleRefinement { + font_size: Some( + TextSize::Small + .rems(cx) + .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) + .into(), + ), + ..Default::default() + }); + editor }); - editor - }); - let entity_id = multibuffer.entity_id(); - cx.observe_release(&multibuffer, move |this, _, _| { - this.diff_editors.remove(&entity_id); - }) - .detach(); + let entity_id = multibuffer.entity_id(); + cx.observe_release(&multibuffer, move |this, _, _| { + this.diff_editors.remove(&entity_id); + }) + .detach(); - self.diff_editors.insert(entity_id, editor); + self.diff_editors.insert(entity_id, editor); + } } - fn entry_diff_multibuffer(&self, entry_ix: usize, cx: &App) -> Option> { + fn entry_diff_multibuffers( + &self, + entry_ix: usize, + cx: &App, + ) -> Option>> { let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - entry.diff().map(|diff| diff.multibuffer.clone()) + Some(entry.diffs().map(|diff| diff.multibuffer.clone())) } fn authenticate(&mut self, window: &mut Window, cx: &mut Context) { - let Some(thread) = self.thread().cloned() else { + let ThreadState::Unauthenticated { ref connection } = self.thread_state else { return; }; self.last_error.take(); - let authenticate = thread.read(cx).authenticate(); + let authenticate = connection.authenticate(cx); self.auth_task = Some(cx.spawn_in(window, { let project = self.project.clone(); let agent = self.agent.clone(); @@ -684,15 +676,16 @@ impl AcpThreadView { fn authorize_tool_call( &mut self, - id: ToolCallId, - outcome: acp::ToolCallConfirmationOutcome, + tool_call_id: acp::ToolCallId, + option_id: acp::PermissionOptionId, + option_kind: acp::PermissionOptionKind, cx: &mut Context, ) { let Some(thread) = self.thread() else { return; }; thread.update(cx, |thread, cx| { - thread.authorize_tool_call(id, outcome, cx); + thread.authorize_tool_call(tool_call_id, option_id, option_kind, cx); }); cx.notify(); } @@ -719,10 +712,12 @@ impl AcpThreadView { .border_1() .border_color(cx.theme().colors().border) .text_xs() - .child(self.render_markdown( - message.content.clone(), - user_message_markdown_style(window, cx), - )), + .children(message.content.markdown().map(|md| { + self.render_markdown( + md.clone(), + user_message_markdown_style(window, cx), + ) + })), ) .into_any(), AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => { @@ -730,20 +725,28 @@ impl AcpThreadView { let message_body = v_flex() .w_full() .gap_2p5() - .children(chunks.iter().enumerate().map(|(chunk_ix, chunk)| { - match chunk { - AssistantMessageChunk::Text { chunk } => self - .render_markdown(chunk.clone(), style.clone()) - .into_any_element(), - AssistantMessageChunk::Thought { chunk } => self.render_thinking_block( - index, - chunk_ix, - chunk.clone(), - window, - cx, - ), - } - })) + .children(chunks.iter().enumerate().filter_map( + |(chunk_ix, chunk)| match chunk { + AssistantMessageChunk::Message { block } => { + block.markdown().map(|md| { + self.render_markdown(md.clone(), style.clone()) + .into_any_element() + }) + } + AssistantMessageChunk::Thought { block } => { + block.markdown().map(|md| { + self.render_thinking_block( + index, + chunk_ix, + md.clone(), + window, + cx, + ) + .into_any_element() + }) + } + }, + )) .into_any(); v_flex() @@ -871,7 +874,7 @@ impl AcpThreadView { let status_icon = match &tool_call.status { ToolCallStatus::WaitingForConfirmation { .. } => None, ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Running, + status: acp::ToolCallStatus::InProgress, .. } => Some( Icon::new(IconName::ArrowCircle) @@ -885,13 +888,13 @@ impl AcpThreadView { .into_any(), ), ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Finished, + status: acp::ToolCallStatus::Completed, .. } => None, ToolCallStatus::Rejected | ToolCallStatus::Canceled | ToolCallStatus::Allowed { - status: acp::ToolCallStatus::Error, + status: acp::ToolCallStatus::Failed, .. } => Some( Icon::new(IconName::X) @@ -909,34 +912,9 @@ impl AcpThreadView { .any(|content| matches!(content, ToolCallContent::Diff { .. })), }; - let is_collapsible = tool_call.content.is_some() && !needs_confirmation; + let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation; let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id); - let content = if is_open { - match &tool_call.status { - ToolCallStatus::WaitingForConfirmation { confirmation, .. } => { - Some(self.render_tool_call_confirmation( - tool_call.id, - confirmation, - tool_call.content.as_ref(), - window, - cx, - )) - } - ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => { - tool_call.content.as_ref().map(|content| { - div() - .py_1p5() - .child(self.render_tool_call_content(content, window, cx)) - .into_any_element() - }) - } - ToolCallStatus::Rejected => None, - } - } else { - None - }; - v_flex() .when(needs_confirmation, |this| { this.rounded_lg() @@ -976,9 +954,17 @@ impl AcpThreadView { }) .gap_1p5() .child( - Icon::new(tool_call.icon) - .size(IconSize::Small) - .color(Color::Muted), + Icon::new(match tool_call.kind { + acp::ToolKind::Read => IconName::ToolRead, + acp::ToolKind::Edit => IconName::ToolPencil, + acp::ToolKind::Search => IconName::ToolSearch, + acp::ToolKind::Execute => IconName::ToolTerminal, + acp::ToolKind::Think => IconName::ToolBulb, + acp::ToolKind::Fetch => IconName::ToolWeb, + acp::ToolKind::Other => IconName::ToolHammer, + }) + .size(IconSize::Small) + .color(Color::Muted), ) .child(if tool_call.locations.len() == 1 { let name = tool_call.locations[0] @@ -1023,16 +1009,16 @@ impl AcpThreadView { .gap_0p5() .when(is_collapsible, |this| { this.child( - Disclosure::new(("expand", tool_call.id.0), is_open) + Disclosure::new(("expand", entry_ix), is_open) .opened_icon(IconName::ChevronUp) .closed_icon(IconName::ChevronDown) .on_click(cx.listener({ - let id = tool_call.id; + let id = tool_call.id.clone(); move |this: &mut Self, _, _, cx: &mut Context| { if is_open { this.expanded_tool_calls.remove(&id); } else { - this.expanded_tool_calls.insert(id); + this.expanded_tool_calls.insert(id.clone()); } cx.notify(); } @@ -1042,12 +1028,12 @@ impl AcpThreadView { .children(status_icon), ) .on_click(cx.listener({ - let id = tool_call.id; + let id = tool_call.id.clone(); move |this: &mut Self, _, _, cx: &mut Context| { if is_open { this.expanded_tool_calls.remove(&id); } else { - this.expanded_tool_calls.insert(id); + this.expanded_tool_calls.insert(id.clone()); } cx.notify(); } @@ -1055,7 +1041,7 @@ impl AcpThreadView { ) .when(is_open, |this| { this.child( - div() + v_flex() .text_xs() .when(is_collapsible, |this| { this.mt_1() @@ -1064,7 +1050,44 @@ impl AcpThreadView { .bg(cx.theme().colors().editor_background) .rounded_lg() }) - .children(content), + .map(|this| { + if is_open { + match &tool_call.status { + ToolCallStatus::WaitingForConfirmation { options, .. } => this + .children(tool_call.content.iter().map(|content| { + div() + .py_1p5() + .child( + self.render_tool_call_content( + content, window, cx, + ), + ) + .into_any_element() + })) + .child(self.render_permission_buttons( + options, + entry_ix, + tool_call.id.clone(), + cx, + )), + ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => { + this.children(tool_call.content.iter().map(|content| { + div() + .py_1p5() + .child( + self.render_tool_call_content( + content, window, cx, + ), + ) + .into_any_element() + })) + } + ToolCallStatus::Rejected => this, + } + } else { + this + } + }), ) }) } @@ -1076,14 +1099,20 @@ impl AcpThreadView { cx: &Context, ) -> AnyElement { match content { - ToolCallContent::Markdown { markdown } => { - div() - .p_2() - .child(self.render_markdown( - markdown.clone(), - default_markdown_style(false, window, cx), - )) - .into_any_element() + ToolCallContent::ContentBlock { content } => { + if let Some(md) = content.markdown() { + div() + .p_2() + .child( + self.render_markdown( + md.clone(), + default_markdown_style(false, window, cx), + ), + ) + .into_any_element() + } else { + Empty.into_any_element() + } } ToolCallContent::Diff { diff: Diff { multibuffer, .. }, @@ -1092,223 +1121,53 @@ impl AcpThreadView { } } - fn render_tool_call_confirmation( + fn render_permission_buttons( &self, - tool_call_id: ToolCallId, - confirmation: &ToolCallConfirmation, - content: Option<&ToolCallContent>, - window: &Window, - cx: &Context, - ) -> AnyElement { - let confirmation_container = v_flex().mt_1().py_1p5(); - - match confirmation { - ToolCallConfirmation::Edit { description } => confirmation_container - .child( - div() - .px_2() - .children(description.clone().map(|description| { - self.render_markdown( - description, - default_markdown_style(false, window, cx), - ) - })), - ) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child(self.render_confirmation_buttons( - &[AlwaysAllowOption { - id: "always_allow", - label: "Always Allow Edits".into(), - outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, - }], - tool_call_id, - cx, - )) - .into_any(), - ToolCallConfirmation::Execute { - command, - root_command, - description, - } => confirmation_container - .child(v_flex().px_2().pb_1p5().child(command.clone()).children( - description.clone().map(|description| { - self.render_markdown(description, default_markdown_style(false, window, cx)) - .on_url_click({ - let workspace = self.workspace.clone(); - move |text, window, cx| { - Self::open_link(text, &workspace, window, cx); - } - }) - }), - )) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child(self.render_confirmation_buttons( - &[AlwaysAllowOption { - id: "always_allow", - label: format!("Always Allow {root_command}").into(), - outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, - }], - tool_call_id, - cx, - )) - .into_any(), - ToolCallConfirmation::Mcp { - server_name, - tool_name: _, - tool_display_name, - description, - } => confirmation_container - .child( - v_flex() - .px_2() - .pb_1p5() - .child(format!("{server_name} - {tool_display_name}")) - .children(description.clone().map(|description| { - self.render_markdown( - description, - default_markdown_style(false, window, cx), - ) - })), - ) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child(self.render_confirmation_buttons( - &[ - AlwaysAllowOption { - id: "always_allow_server", - label: format!("Always Allow {server_name}").into(), - outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer, - }, - AlwaysAllowOption { - id: "always_allow_tool", - label: format!("Always Allow {tool_display_name}").into(), - outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowTool, - }, - ], - tool_call_id, - cx, - )) - .into_any(), - ToolCallConfirmation::Fetch { description, urls } => confirmation_container - .child( - v_flex() - .px_2() - .pb_1p5() - .gap_1() - .children(urls.iter().map(|url| { - h_flex().child( - Button::new(url.clone(), url) - .icon(IconName::ArrowUpRight) - .icon_color(Color::Muted) - .icon_size(IconSize::XSmall) - .on_click({ - let url = url.clone(); - move |_, _, cx| cx.open_url(&url) - }), - ) - })) - .children(description.clone().map(|description| { - self.render_markdown( - description, - default_markdown_style(false, window, cx), - ) - })), - ) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child(self.render_confirmation_buttons( - &[AlwaysAllowOption { - id: "always_allow", - label: "Always Allow".into(), - outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, - }], - tool_call_id, - cx, - )) - .into_any(), - ToolCallConfirmation::Other { description } => confirmation_container - .child(v_flex().px_2().pb_1p5().child(self.render_markdown( - description.clone(), - default_markdown_style(false, window, cx), - ))) - .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child(self.render_confirmation_buttons( - &[AlwaysAllowOption { - id: "always_allow", - label: "Always Allow".into(), - outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, - }], - tool_call_id, - cx, - )) - .into_any(), - } - } - - fn render_confirmation_buttons( - &self, - always_allow_options: &[AlwaysAllowOption], - tool_call_id: ToolCallId, + options: &[acp::PermissionOption], + entry_ix: usize, + tool_call_id: acp::ToolCallId, cx: &Context, ) -> Div { h_flex() - .pt_1p5() + .py_1p5() .px_1p5() .gap_1() .justify_end() .border_t_1() .border_color(self.tool_card_border_color(cx)) - .when(self.agent.supports_always_allow(), |this| { - this.children(always_allow_options.into_iter().map(|always_allow_option| { - let outcome = always_allow_option.outcome; - Button::new( - (always_allow_option.id, tool_call_id.0), - always_allow_option.label.clone(), - ) - .icon(IconName::CheckDouble) + .children(options.iter().map(|option| { + let option_id = SharedString::from(option.id.0.clone()); + Button::new((option_id, entry_ix), option.label.clone()) + .map(|this| match option.kind { + acp::PermissionOptionKind::AllowOnce => { + this.icon(IconName::Check).icon_color(Color::Success) + } + acp::PermissionOptionKind::AllowAlways => { + this.icon(IconName::CheckDouble).icon_color(Color::Success) + } + acp::PermissionOptionKind::RejectOnce => { + this.icon(IconName::X).icon_color(Color::Error) + } + acp::PermissionOptionKind::RejectAlways => { + this.icon(IconName::X).icon_color(Color::Error) + } + }) .icon_position(IconPosition::Start) .icon_size(IconSize::XSmall) - .icon_color(Color::Success) .on_click(cx.listener({ - let id = tool_call_id; + let tool_call_id = tool_call_id.clone(); + let option_id = option.id.clone(); + let option_kind = option.kind; move |this, _, _, cx| { - this.authorize_tool_call(id, outcome, cx); + this.authorize_tool_call( + tool_call_id.clone(), + option_id.clone(), + option_kind, + cx, + ); } })) - })) - }) - .child( - Button::new(("allow", tool_call_id.0), "Allow") - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Allow, - cx, - ); - } - })), - ) - .child( - Button::new(("reject", tool_call_id.0), "Reject") - .icon(IconName::X) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Error) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Reject, - cx, - ); - } - })), - ) + })) } fn render_diff_editor(&self, multibuffer: &Entity) -> AnyElement { @@ -2245,12 +2104,11 @@ impl AcpThreadView { .languages .language_for_name("Markdown"); - let (thread_summary, markdown) = match &self.thread_state { - ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => { - let thread = thread.read(cx); - (thread.title().to_string(), thread.to_markdown(cx)) - } - ThreadState::Loading { .. } | ThreadState::LoadError(..) => return Task::ready(Ok(())), + let (thread_summary, markdown) = if let Some(thread) = self.thread() { + let thread = thread.read(cx); + (thread.title().to_string(), thread.to_markdown(cx)) + } else { + return Task::ready(Ok(())); }; window.spawn(cx, async move |cx| { diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index e69664ce88..ec0a11f86b 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1506,8 +1506,7 @@ impl AgentDiff { .read(cx) .entries() .last() - .and_then(|entry| entry.diff()) - .is_some() + .map_or(false, |entry| entry.diffs().next().is_some()) { self.update_reviewing_editors(workspace, window, cx); } @@ -1517,8 +1516,7 @@ impl AgentDiff { .read(cx) .entries() .get(*ix) - .and_then(|entry| entry.diff()) - .is_some() + .map_or(false, |entry| entry.diffs().next().is_some()) { self.update_reviewing_editors(workspace, window, cx); } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index a0250816a0..4b3db4bc1d 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -440,7 +440,7 @@ pub struct AgentPanel { local_timezone: UtcOffset, active_view: ActiveView, acp_message_history: - Rc>>, + Rc>>>, previous_view: Option, history_store: Entity, history: Entity, diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 6b24d9b136..8c5e7da0f1 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result, anyhow}; use collections::HashMap; -use futures::{FutureExt, StreamExt, channel::oneshot, select}; +use futures::{FutureExt, StreamExt, channel::oneshot, future, select}; use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task}; use parking_lot::Mutex; use postage::barrier; @@ -10,15 +10,19 @@ use smol::channel; use std::{ fmt, path::PathBuf, + pin::pin, sync::{ Arc, atomic::{AtomicI32, Ordering::SeqCst}, }, time::{Duration, Instant}, }; -use util::TryFutureExt; +use util::{ResultExt, TryFutureExt}; -use crate::transport::{StdioTransport, Transport}; +use crate::{ + transport::{StdioTransport, Transport}, + types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled}, +}; const JSON_RPC_VERSION: &str = "2.0"; const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); @@ -32,6 +36,7 @@ pub const INTERNAL_ERROR: i32 = -32603; type ResponseHandler = Box)>; type NotificationHandler = Box; +type RequestHandler = Box; #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[serde(untagged)] @@ -78,6 +83,15 @@ pub struct Request<'a, T> { pub params: T, } +#[derive(Serialize, Deserialize)] +pub struct AnyRequest<'a> { + pub jsonrpc: &'a str, + pub id: RequestId, + pub method: &'a str, + #[serde(skip_serializing_if = "is_null_value")] + pub params: Option<&'a RawValue>, +} + #[derive(Serialize, Deserialize)] struct AnyResponse<'a> { jsonrpc: &'a str, @@ -176,15 +190,23 @@ impl Client { Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); let response_handlers = Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); + let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default())); let receive_input_task = cx.spawn({ let notification_handlers = notification_handlers.clone(); let response_handlers = response_handlers.clone(); + let request_handlers = request_handlers.clone(); let transport = transport.clone(); async move |cx| { - Self::handle_input(transport, notification_handlers, response_handlers, cx) - .log_err() - .await + Self::handle_input( + transport, + notification_handlers, + request_handlers, + response_handlers, + cx, + ) + .log_err() + .await } }); let receive_err_task = cx.spawn({ @@ -230,13 +252,24 @@ impl Client { async fn handle_input( transport: Arc, notification_handlers: Arc>>, + request_handlers: Arc>>, response_handlers: Arc>>>, cx: &mut AsyncApp, ) -> anyhow::Result<()> { let mut receiver = transport.receive(); while let Some(message) = receiver.next().await { - if let Ok(response) = serde_json::from_str::(&message) { + log::trace!("recv: {}", &message); + if let Ok(request) = serde_json::from_str::(&message) { + let mut request_handlers = request_handlers.lock(); + if let Some(handler) = request_handlers.get_mut(request.method) { + handler( + request.id, + request.params.unwrap_or(RawValue::NULL), + cx.clone(), + ); + } + } else if let Ok(response) = serde_json::from_str::(&message) { if let Some(handlers) = response_handlers.lock().as_mut() { if let Some(handler) = handlers.remove(&response.id) { handler(Ok(message.to_string())); @@ -247,6 +280,8 @@ impl Client { if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) { handler(notification.params.unwrap_or(Value::Null), cx.clone()); } + } else { + log::error!("Unhandled JSON from context_server: {}", message); } } @@ -294,6 +329,24 @@ impl Client { &self, method: &str, params: impl Serialize, + ) -> Result { + self.request_impl(method, params, None).await + } + + pub async fn cancellable_request( + &self, + method: &str, + params: impl Serialize, + cancel_rx: oneshot::Receiver<()>, + ) -> Result { + self.request_impl(method, params, Some(cancel_rx)).await + } + + pub async fn request_impl( + &self, + method: &str, + params: impl Serialize, + cancel_rx: Option>, ) -> Result { let id = self.next_id.fetch_add(1, SeqCst); let request = serde_json::to_string(&Request { @@ -330,6 +383,16 @@ impl Client { send?; let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse(); + let mut cancel_fut = pin!( + match cancel_rx { + Some(rx) => future::Either::Left(async { + rx.await.log_err(); + }), + None => future::Either::Right(future::pending()), + } + .fuse() + ); + select! { response = rx.fuse() => { let elapsed = started.elapsed(); @@ -348,6 +411,16 @@ impl Client { Err(_) => anyhow::bail!("cancelled") } } + _ = cancel_fut => { + self.notify( + Cancelled::METHOD, + ClientNotification::Cancelled(CancelledParams { + request_id: RequestId::Int(id), + reason: None + }) + ).log_err(); + anyhow::bail!("Request cancelled") + } _ = timeout => { log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT); anyhow::bail!("Context server request timeout"); diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index d8bbac60d6..7263f502fa 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -6,6 +6,9 @@ //! of messages. use anyhow::Result; +use futures::channel::oneshot; +use gpui::AsyncApp; +use serde_json::Value; use crate::client::Client; use crate::types::{self, Notification, Request}; @@ -95,7 +98,24 @@ impl InitializedContextServerProtocol { self.inner.request(T::METHOD, params).await } + pub async fn cancellable_request( + &self, + params: T::Params, + cancel_rx: oneshot::Receiver<()>, + ) -> Result { + self.inner + .cancellable_request(T::METHOD, params, cancel_rx) + .await + } + pub fn notify(&self, params: T::Params) -> Result<()> { self.inner.notify(T::METHOD, params) } + + pub fn on_notification(&self, method: &'static str, f: F) + where + F: 'static + Send + FnMut(Value, AsyncApp), + { + self.inner.on_notification(method, f); + } } diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 4a6fdcabd3..f92c86aa3c 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -3,6 +3,8 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use url::Url; +use crate::client::RequestId; + pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26"; pub const VERSION_2024_11_05: &str = "2024-11-05"; @@ -100,6 +102,7 @@ pub mod notifications { notification!("notifications/initialized", Initialized, ()); notification!("notifications/progress", Progress, ProgressParams); notification!("notifications/message", Message, MessageParams); + notification!("notifications/cancelled", Cancelled, CancelledParams); notification!( "notifications/resources/updated", ResourcesUpdated, @@ -617,11 +620,14 @@ pub enum ClientNotification { Initialized, Progress(ProgressParams), RootsListChanged, - Cancelled { - request_id: String, - #[serde(skip_serializing_if = "Option::is_none")] - reason: Option, - }, + Cancelled(CancelledParams), +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CancelledParams { + pub request_id: RequestId, + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, } #[derive(Debug, Serialize, Deserialize)]