From 8e4555455c6469b33803c8c653e38b2755ad9aad Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Thu, 17 Jul 2025 11:25:55 -0300 Subject: [PATCH] Claude experiment (#34577) Release Notes: - N/A --------- Co-authored-by: Conrad Irwin Co-authored-by: Anthony Eid Co-authored-by: Ben Brandt Co-authored-by: Nathan Sobo Co-authored-by: Oleksiy Syvokon --- Cargo.lock | 33 +- Cargo.toml | 8 +- assets/icons/ai_claude.svg | 3 + assets/keymaps/default-linux.json | 4 +- assets/keymaps/default-macos.json | 4 +- crates/{acp => acp_thread}/Cargo.toml | 9 +- crates/{acp => acp_thread}/LICENSE-GPL | 0 .../acp.rs => acp_thread/src/acp_thread.rs} | 670 +++++------------ crates/acp_thread/src/connection.rs | 20 + crates/agent_servers/Cargo.toml | 21 + crates/agent_servers/src/agent_servers.rs | 213 +----- crates/agent_servers/src/claude.rs | 680 +++++++++++++++++ crates/agent_servers/src/claude/mcp_server.rs | 303 ++++++++ crates/agent_servers/src/claude/tools.rs | 670 +++++++++++++++++ crates/agent_servers/src/gemini.rs | 501 +++++++++++++ crates/agent_servers/src/settings.rs | 41 + .../agent_servers/src/stdio_agent_server.rs | 169 +++++ crates/agent_ui/Cargo.toml | 2 +- crates/agent_ui/src/acp/thread_view.rs | 704 +++++++----------- crates/agent_ui/src/agent_diff.rs | 4 +- crates/agent_ui/src/agent_panel.rs | 152 ++-- crates/agent_ui/src/agent_ui.rs | 31 +- crates/context_server/Cargo.toml | 2 + crates/context_server/src/client.rs | 29 +- crates/context_server/src/context_server.rs | 1 + crates/context_server/src/listener.rs | 236 ++++++ crates/context_server/src/types.rs | 2 +- crates/icons/src/icons.rs | 1 + crates/nc/Cargo.toml | 20 + crates/nc/LICENSE-GPL | 1 + crates/nc/src/nc.rs | 56 ++ crates/zed/Cargo.toml | 1 + crates/zed/src/main.rs | 16 + 33 files changed, 3437 insertions(+), 1170 deletions(-) create mode 100644 assets/icons/ai_claude.svg rename crates/{acp => acp_thread}/Cargo.toml (92%) rename crates/{acp => acp_thread}/LICENSE-GPL (100%) rename crates/{acp/src/acp.rs => acp_thread/src/acp_thread.rs} (75%) create mode 100644 crates/acp_thread/src/connection.rs create mode 100644 crates/agent_servers/src/claude.rs create mode 100644 crates/agent_servers/src/claude/mcp_server.rs create mode 100644 crates/agent_servers/src/claude/tools.rs create mode 100644 crates/agent_servers/src/gemini.rs create mode 100644 crates/agent_servers/src/settings.rs create mode 100644 crates/agent_servers/src/stdio_agent_server.rs create mode 100644 crates/context_server/src/listener.rs create mode 100644 crates/nc/Cargo.toml create mode 120000 crates/nc/LICENSE-GPL create mode 100644 crates/nc/src/nc.rs diff --git a/Cargo.lock b/Cargo.lock index 59e444f1f8..540e3039ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3,10 +3,9 @@ version = 4 [[package]] -name = "acp" +name = "acp_thread" version = "0.1.0" dependencies = [ - "agent_servers", "agentic-coding-protocol", "anyhow", "assistant_tool", @@ -21,6 +20,7 @@ dependencies = [ "language", "markdown", "project", + "serde", "serde_json", "settings", "smol", @@ -139,16 +139,29 @@ dependencies = [ name = "agent_servers" version = "0.1.0" dependencies = [ + "acp_thread", + "agentic-coding-protocol", "anyhow", "collections", + "context_server", + "env_logger 0.11.8", "futures 0.3.31", "gpui", + "indoc", + "itertools 0.14.0", + "language", + "log", "paths", "project", "schemars", "serde", + "serde_json", "settings", + "smol", + "tempfile", + "ui", "util", + "watch", "which 6.0.3", "workspace-hack", ] @@ -176,7 +189,7 @@ dependencies = [ name = "agent_ui" version = "0.1.0" dependencies = [ - "acp", + "acp_thread", "agent", "agent_servers", "agent_settings", @@ -3411,12 +3424,14 @@ dependencies = [ "futures 0.3.31", "gpui", "log", + "net", "parking_lot", "postage", "schemars", "serde", "serde_json", "smol", + "tempfile", "url", "util", "workspace-hack", @@ -10288,6 +10303,17 @@ dependencies = [ "uuid", ] +[[package]] +name = "nc" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.31", + "net", + "smol", + "workspace-hack", +] + [[package]] name = "ndk" version = "0.8.0" @@ -20171,6 +20197,7 @@ dependencies = [ "menu", "migrator", "mimalloc", + "nc", "nix 0.29.0", "node_runtime", "notifications", diff --git a/Cargo.toml b/Cargo.toml index afb47c006e..1c79f4c1c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ resolver = "2" members = [ "crates/activity_indicator", - "crates/acp", + "crates/acp_thread", "crates/agent_ui", "crates/agent", "crates/agent_settings", @@ -102,6 +102,7 @@ members = [ "crates/migrator", "crates/mistral", "crates/multi_buffer", + "crates/nc", "crates/net", "crates/node_runtime", "crates/notifications", @@ -219,7 +220,7 @@ edition = "2024" # Workspace member crates # -acp = { path = "crates/acp" } +acp_thread = { path = "crates/acp_thread" } agent = { path = "crates/agent" } activity_indicator = { path = "crates/activity_indicator" } agent_ui = { path = "crates/agent_ui" } @@ -317,6 +318,7 @@ menu = { path = "crates/menu" } migrator = { path = "crates/migrator" } mistral = { path = "crates/mistral" } multi_buffer = { path = "crates/multi_buffer" } +nc = { path = "crates/nc" } net = { path = "crates/net" } node_runtime = { path = "crates/node_runtime" } notifications = { path = "crates/notifications" } @@ -406,7 +408,7 @@ zlog_settings = { path = "crates/zlog_settings" } # External crates # -agentic-coding-protocol = { version = "0.0.9" } +agentic-coding-protocol = "0.0.9" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/assets/icons/ai_claude.svg b/assets/icons/ai_claude.svg new file mode 100644 index 0000000000..423a963eba --- /dev/null +++ b/assets/icons/ai_claude.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index da4d79eca1..b52b6c614d 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -269,10 +269,10 @@ } }, { - "context": "AgentPanel && acp_thread", + "context": "AgentPanel && external_agent_thread", "use_key_equivalents": true, "bindings": { - "ctrl-n": "agent::NewAcpThread", + "ctrl-n": "agent::NewExternalAgentThread", "ctrl-alt-t": "agent::NewThread" } }, diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 962760098b..240b42fd1f 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -310,10 +310,10 @@ } }, { - "context": "AgentPanel && acp_thread", + "context": "AgentPanel && external_agent_thread", "use_key_equivalents": true, "bindings": { - "cmd-n": "agent::NewAcpThread", + "cmd-n": "agent::NewExternalAgentThread", "cmd-alt-t": "agent::NewThread" } }, diff --git a/crates/acp/Cargo.toml b/crates/acp_thread/Cargo.toml similarity index 92% rename from crates/acp/Cargo.toml rename to crates/acp_thread/Cargo.toml index 1570aeaef0..b44c25ccc9 100644 --- a/crates/acp/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "acp" +name = "acp_thread" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,15 +9,13 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/acp.rs" +path = "src/acp_thread.rs" doctest = false [features] test-support = ["gpui/test-support", "project/test-support"] -gemini = [] [dependencies] -agent_servers.workspace = true agentic-coding-protocol.workspace = true anyhow.workspace = true assistant_tool.workspace = true @@ -29,6 +27,8 @@ itertools.workspace = true language.workspace = true markdown.workspace = true project.workspace = true +serde.workspace = true +serde_json.workspace = true settings.workspace = true smol.workspace = true ui.workspace = true @@ -41,7 +41,6 @@ env_logger.workspace = true gpui = { workspace = true, "features" = ["test-support"] } indoc.workspace = true project = { workspace = true, "features" = ["test-support"] } -serde_json.workspace = true tempfile.workspace = true util.workspace = true settings.workspace = true diff --git a/crates/acp/LICENSE-GPL b/crates/acp_thread/LICENSE-GPL similarity index 100% rename from crates/acp/LICENSE-GPL rename to crates/acp_thread/LICENSE-GPL diff --git a/crates/acp/src/acp.rs b/crates/acp_thread/src/acp_thread.rs similarity index 75% rename from crates/acp/src/acp.rs rename to crates/acp_thread/src/acp_thread.rs index a7e72b0c2d..1e3947351a 100644 --- a/crates/acp/src/acp.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1,7 +1,12 @@ +mod connection; +pub use connection::*; + pub use acp::ToolCallId; -use agent_servers::AgentServer; -use agentic_coding_protocol::{self as acp, ToolCallLocation, UserMessageChunk}; -use anyhow::{Context as _, Result, anyhow}; +use agentic_coding_protocol::{ + self as acp, AgentRequest, ProtocolVersion, ToolCallConfirmationOutcome, ToolCallLocation, + UserMessageChunk, +}; +use anyhow::{Context as _, Result}; use assistant_tool::ActionLog; use buffer_diff::BufferDiff; use editor::{Bias, MultiBuffer, PathKey}; @@ -97,7 +102,7 @@ pub struct AssistantMessage { } impl AssistantMessage { - fn to_markdown(&self, cx: &App) -> String { + pub fn to_markdown(&self, cx: &App) -> String { format!( "## Assistant\n\n{}\n\n", self.chunks @@ -455,9 +460,8 @@ pub struct AcpThread { action_log: Entity, shared_buffers: HashMap, BufferSnapshot>, send_task: Option>, - connection: Arc, + connection: Arc, child_status: Option>>, - _io_task: Task<()>, } pub enum AcpThreadEvent { @@ -476,7 +480,11 @@ pub enum ThreadStatus { #[derive(Debug, Clone)] pub enum LoadError { - Unsupported { current_version: SharedString }, + Unsupported { + error_message: SharedString, + upgrade_message: SharedString, + upgrade_command: String, + }, Exited(i32), Other(SharedString), } @@ -484,13 +492,7 @@ pub enum LoadError { impl Display for LoadError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - LoadError::Unsupported { current_version } => { - write!( - f, - "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", - current_version - ) - } + LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message), LoadError::Exited(status) => write!(f, "Server exited with status {}", status), LoadError::Other(msg) => write!(f, "{}", msg), } @@ -500,75 +502,38 @@ impl Display for LoadError { impl Error for LoadError {} impl AcpThread { - pub async fn spawn( - server: impl AgentServer + 'static, - root_dir: &Path, + pub fn new( + connection: impl AgentConnection + 'static, + title: SharedString, + child_status: Option>>, project: Entity, - cx: &mut AsyncApp, - ) -> Result> { - let command = match server.command(&project, cx).await { - Ok(command) => command, - Err(e) => return Err(anyhow!(LoadError::Other(format!("{e}").into()))), - }; + cx: &mut Context, + ) -> Self { + let action_log = cx.new(|_| ActionLog::new(project.clone())); - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; + Self { + action_log, + shared_buffers: Default::default(), + entries: Default::default(), + title, + project, + send_task: None, + connection: Arc::new(connection), + child_status, + } + } - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - - cx.new(|cx| { - let foreground_executor = cx.foreground_executor().clone(); - - let (connection, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), - stdin, - stdout, - move |fut| foreground_executor.spawn(fut).detach(), - ); - - let io_task = cx.background_spawn(async move { - io_fut.await.log_err(); - }); - - let child_status = cx.background_spawn(async move { - match child.status().await { - Err(e) => Err(anyhow!(e)), - Ok(result) if result.success() => Ok(()), - Ok(result) => { - if let Some(version) = server.version(&command).await.log_err() - && !version.supported - { - Err(anyhow!(LoadError::Unsupported { - current_version: version.current_version - })) - } else { - Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) - } - } - } - }); - - let action_log = cx.new(|_| ActionLog::new(project.clone())); - - Self { - action_log, - shared_buffers: Default::default(), - entries: Default::default(), - title: "ACP Thread".into(), - project, - send_task: None, - connection: Arc::new(connection), - child_status: Some(child_status), - _io_task: io_task, - } - }) + /// Send a request to the agent and wait for a response. + pub fn request( + &self, + params: R, + ) -> impl use + Future> { + let params = params.into_any(); + let result = self.connection.request_any(params); + async move { + let result = result.await?; + Ok(R::response_from_any(result)?) + } } pub fn action_log(&self) -> &Entity { @@ -579,45 +544,6 @@ impl AcpThread { &self.project } - #[cfg(test)] - pub fn fake( - stdin: async_pipe::PipeWriter, - stdout: async_pipe::PipeReader, - project: Entity, - cx: &mut Context, - ) -> Self { - let foreground_executor = cx.foreground_executor().clone(); - - let (connection, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), - stdin, - stdout, - move |fut| { - foreground_executor.spawn(fut).detach(); - }, - ); - - let io_task = cx.background_spawn({ - async move { - io_fut.await.log_err(); - } - }); - - let action_log = cx.new(|_| ActionLog::new(project.clone())); - - Self { - action_log, - shared_buffers: Default::default(), - entries: Default::default(), - title: "ACP Thread".into(), - project, - send_task: None, - connection: Arc::new(connection), - child_status: None, - _io_task: io_task, - } - } - pub fn title(&self) -> SharedString { self.title.clone() } @@ -711,7 +637,7 @@ impl AcpThread { } } - pub fn request_tool_call( + pub fn request_new_tool_call( &mut self, tool_call: acp::RequestToolCallConfirmationParams, cx: &mut Context, @@ -731,6 +657,30 @@ impl AcpThread { ToolCallRequest { id, outcome: rx } } + pub fn request_tool_call_confirmation( + &mut self, + tool_call_id: ToolCallId, + confirmation: acp::ToolCallConfirmation, + cx: &mut Context, + ) -> Result { + let project = self.project.read(cx).languages().clone(); + let Some((_, call)) = self.tool_call_mut(tool_call_id) else { + anyhow::bail!("Tool call not found"); + }; + + let (tx, rx) = oneshot::channel(); + + call.status = ToolCallStatus::WaitingForConfirmation { + confirmation: ToolCallConfirmation::from_acp(confirmation, project, cx), + respond_tx: tx, + }; + + Ok(ToolCallRequest { + id: tool_call_id, + outcome: rx, + }) + } + pub fn push_tool_call( &mut self, request: acp::PushToolCallParams, @@ -912,19 +862,17 @@ impl AcpThread { false } - pub fn initialize( - &self, - ) -> impl use<> + Future> { - let connection = self.connection.clone(); - async move { connection.initialize().await } + pub fn initialize(&self) -> impl use<> + Future> { + self.request(acp::InitializeParams { + protocol_version: ProtocolVersion::latest(), + }) } - pub fn authenticate(&self) -> impl use<> + Future> { - let connection = self.connection.clone(); - async move { connection.request(acp::AuthenticateParams).await } + pub fn authenticate(&self) -> impl use<> + Future> { + self.request(acp::AuthenticateParams) } - #[cfg(test)] + #[cfg(any(test, feature = "test-support"))] pub fn send_raw( &mut self, message: &str, @@ -945,7 +893,6 @@ impl AcpThread { message: acp::SendUserMessageParams, cx: &mut Context, ) -> BoxFuture<'static, Result<(), acp::Error>> { - let agent = self.connection.clone(); self.push_entry( AgentThreadEntry::UserMessage(UserMessage::from_acp( &message, @@ -959,11 +906,16 @@ impl AcpThread { let cancel = self.cancel(cx); self.send_task = Some(cx.spawn(async move |this, cx| { - cancel.await.log_err(); + async { + cancel.await.log_err(); - let result = agent.request(message).await; - tx.send(result).log_err(); - this.update(cx, |this, _cx| this.send_task.take()).log_err(); + let result = this.update(cx, |this, _| this.request(message))?.await; + tx.send(result).log_err(); + this.update(cx, |this, _cx| this.send_task.take())?; + anyhow::Ok(()) + } + .await + .log_err(); })); async move { @@ -976,12 +928,10 @@ impl AcpThread { } pub fn cancel(&mut self, cx: &mut Context) -> Task> { - let agent = self.connection.clone(); - if self.send_task.take().is_some() { + let request = self.request(acp::CancelSendMessageParams); cx.spawn(async move |this, cx| { - agent.request(acp::CancelSendMessageParams).await?; - + request.await?; this.update(cx, |this, _cx| { for entry in this.entries.iter_mut() { if let AgentThreadEntry::ToolCall(call) = entry { @@ -1019,6 +969,7 @@ impl AcpThread { pub fn read_text_file( &self, request: acp::ReadTextFileParams, + reuse_shared_snapshot: bool, cx: &mut Context, ) -> Task> { let project = self.project.clone(); @@ -1032,28 +983,60 @@ impl AcpThread { }); let buffer = load??.await?; - action_log.update(cx, |action_log, cx| { - action_log.buffer_read(buffer.clone(), cx); - })?; - project.update(cx, |project, cx| { - let position = buffer - .read(cx) - .snapshot() - .anchor_before(Point::new(request.line.unwrap_or_default(), 0)); - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position, - }), - cx, - ); - })?; - let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?; + let snapshot = if reuse_shared_snapshot { + this.read_with(cx, |this, _| { + this.shared_buffers.get(&buffer.clone()).cloned() + }) + .log_err() + .flatten() + } else { + None + }; + + let snapshot = if let Some(snapshot) = snapshot { + snapshot + } else { + action_log.update(cx, |action_log, cx| { + action_log.buffer_read(buffer.clone(), cx); + })?; + project.update(cx, |project, cx| { + let position = buffer + .read(cx) + .snapshot() + .anchor_before(Point::new(request.line.unwrap_or_default(), 0)); + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position, + }), + cx, + ); + })?; + + buffer.update(cx, |buffer, _| buffer.snapshot())? + }; + this.update(cx, |this, _| { let text = snapshot.text(); this.shared_buffers.insert(buffer.clone(), snapshot); - text - }) + if request.line.is_none() && request.limit.is_none() { + return Ok(text); + } + let limit = request.limit.unwrap_or(u32::MAX) as usize; + let Some(line) = request.line else { + return Ok(text.lines().take(limit).collect::()); + }; + + let count = text.lines().count(); + if count < line as usize { + anyhow::bail!("There are only {} lines", count); + } + Ok(text + .lines() + .skip(line as usize + 1) + .take(limit) + .collect::()) + })? }) } @@ -1134,16 +1117,49 @@ impl AcpThread { } } -struct AcpClientDelegate { +#[derive(Clone)] +pub struct AcpClientDelegate { thread: WeakEntity, cx: AsyncApp, // sent_buffer_versions: HashMap, HashMap>, } impl AcpClientDelegate { - fn new(thread: WeakEntity, cx: AsyncApp) -> Self { + pub fn new(thread: WeakEntity, cx: AsyncApp) -> Self { Self { thread, cx } } + + pub async fn request_existing_tool_call_confirmation( + &self, + tool_call_id: ToolCallId, + confirmation: acp::ToolCallConfirmation, + ) -> Result { + let cx = &mut self.cx.clone(); + let ToolCallRequest { outcome, .. } = cx + .update(|cx| { + self.thread.update(cx, |thread, cx| { + thread.request_tool_call_confirmation(tool_call_id, confirmation, cx) + }) + })? + .context("Failed to update thread")??; + + Ok(outcome.await?) + } + + pub async fn read_text_file_reusing_snapshot( + &self, + request: acp::ReadTextFileParams, + ) -> Result { + let content = self + .cx + .update(|cx| { + self.thread + .update(cx, |thread, cx| thread.read_text_file(request, true, cx)) + })? + .context("Failed to update thread")? + .await?; + Ok(acp::ReadTextFileResponse { content }) + } } impl acp::Client for AcpClientDelegate { @@ -1172,7 +1188,7 @@ impl acp::Client for AcpClientDelegate { let ToolCallRequest { id, outcome } = cx .update(|cx| { self.thread - .update(cx, |thread, cx| thread.request_tool_call(request, cx)) + .update(cx, |thread, cx| thread.request_new_tool_call(request, cx)) })? .context("Failed to update thread")?; @@ -1218,7 +1234,7 @@ impl acp::Client for AcpClientDelegate { .cx .update(|cx| { self.thread - .update(cx, |thread, cx| thread.read_text_file(request, cx)) + .update(cx, |thread, cx| thread.read_text_file(request, false, cx)) })? .context("Failed to update thread")? .await?; @@ -1260,7 +1276,7 @@ pub struct ToolCallRequest { #[cfg(test)] mod tests { use super::*; - use agent_servers::{AgentServerCommand, AgentServerVersion}; + use anyhow::anyhow; use async_pipe::{PipeReader, PipeWriter}; use futures::{channel::mpsc, future::LocalBoxFuture, select}; use gpui::{AsyncApp, TestAppContext}; @@ -1269,7 +1285,7 @@ mod tests { use serde_json::json; use settings::SettingsStore; use smol::{future::BoxedLocal, stream::StreamExt as _}; - use std::{cell::RefCell, env, path::Path, rc::Rc, time::Duration}; + use std::{cell::RefCell, rc::Rc, time::Duration}; use util::path; fn init_test(cx: &mut TestAppContext) { @@ -1515,265 +1531,6 @@ mod tests { }); } - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_basic(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [], cx).await; - let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; - thread - .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) - .await - .unwrap(); - - thread.read_with(cx, |thread, _| { - assert_eq!(thread.entries.len(), 2); - assert!(matches!( - thread.entries[0], - AgentThreadEntry::UserMessage(_) - )); - assert!(matches!( - thread.entries[1], - AgentThreadEntry::AssistantMessage(_) - )); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_path_mentions(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - let tempdir = tempfile::tempdir().unwrap(); - std::fs::write( - tempdir.path().join("foo.rs"), - indoc! {" - fn main() { - println!(\"Hello, world!\"); - } - "}, - ) - .expect("failed to write file"); - let project = Project::example([tempdir.path()], &mut cx.to_async()).await; - let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await; - thread - .update(cx, |thread, cx| { - thread.send( - acp::SendUserMessageParams { - chunks: vec![ - acp::UserMessageChunk::Text { - text: "Read the file ".into(), - }, - acp::UserMessageChunk::Path { - path: Path::new("foo.rs").into(), - }, - acp::UserMessageChunk::Text { - text: " and tell me what the content of the println! is".into(), - }, - ], - }, - cx, - ) - }) - .await - .unwrap(); - - thread.read_with(cx, |thread, cx| { - assert_eq!(thread.entries.len(), 3); - assert!(matches!( - thread.entries[0], - AgentThreadEntry::UserMessage(_) - )); - assert!(matches!(thread.entries[1], AgentThreadEntry::ToolCall(_))); - let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries[2] else { - panic!("Expected AssistantMessage") - }; - assert!( - assistant_message.to_markdown(cx).contains("Hello, world!"), - "unexpected assistant message: {:?}", - assistant_message.to_markdown(cx) - ); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_tool_call(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/private/tmp"), - json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}), - ) - .await; - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; - let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; - thread - .update(cx, |thread, cx| { - thread.send_raw( - "Read the '/private/tmp/foo' file and tell me what you see.", - cx, - ) - }) - .await - .unwrap(); - thread.read_with(cx, |thread, _cx| { - assert!(matches!( - &thread.entries()[2], - AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { .. }, - .. - }) - )); - - assert!(matches!( - thread.entries[3], - AgentThreadEntry::AssistantMessage(_) - )); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; - let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; - let full_turn = thread.update(cx, |thread, cx| { - thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx) - }); - - run_until_first_tool_call(&thread, cx).await; - - let tool_call_id = thread.read_with(cx, |thread, _cx| { - let AgentThreadEntry::ToolCall(ToolCall { - id, - status: - ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::Execute { root_command, .. }, - .. - }, - .. - }) = &thread.entries()[2] - else { - panic!(); - }; - - assert_eq!(root_command, "echo"); - - *id - }); - - thread.update(cx, |thread, cx| { - thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx); - - assert!(matches!( - &thread.entries()[2], - AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { .. }, - .. - }) - )); - }); - - full_turn.await.unwrap(); - - thread.read_with(cx, |thread, cx| { - let AgentThreadEntry::ToolCall(ToolCall { - content: Some(ToolCallContent::Markdown { markdown }), - status: ToolCallStatus::Allowed { .. }, - .. - }) = &thread.entries()[2] - else { - panic!(); - }; - - markdown.read_with(cx, |md, _cx| { - assert!( - md.source().contains("Hello, world!"), - r#"Expected '{}' to contain "Hello, world!""#, - md.source() - ); - }); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_cancel(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; - let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; - let full_turn = thread.update(cx, |thread, cx| { - thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx) - }); - - let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await; - - thread.read_with(cx, |thread, _cx| { - let AgentThreadEntry::ToolCall(ToolCall { - id, - status: - ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::Execute { root_command, .. }, - .. - }, - .. - }) = &thread.entries()[first_tool_call_ix] - else { - panic!("{:?}", thread.entries()[1]); - }; - - assert_eq!(root_command, "echo"); - - *id - }); - - thread - .update(cx, |thread, cx| thread.cancel(cx)) - .await - .unwrap(); - full_turn.await.unwrap(); - thread.read_with(cx, |thread, _| { - let AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Canceled, - .. - }) = &thread.entries()[first_tool_call_ix] - else { - panic!(); - }; - }); - - thread - .update(cx, |thread, cx| { - thread.send_raw(r#"Stop running and say goodbye to me."#, cx) - }) - .await - .unwrap(); - thread.read_with(cx, |thread, _| { - assert!(matches!( - &thread.entries().last().unwrap(), - AgentThreadEntry::AssistantMessage(..), - )) - }); - } - async fn run_until_first_tool_call( thread: &Entity, cx: &mut TestAppContext, @@ -1801,66 +1558,39 @@ mod tests { } } - pub async fn gemini_acp_thread( - project: Entity, - current_dir: impl AsRef, - cx: &mut TestAppContext, - ) -> Entity { - struct DevGemini; - - impl agent_servers::AgentServer for DevGemini { - async fn command( - &self, - _project: &Entity, - _cx: &mut AsyncApp, - ) -> Result { - let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) - .join("../../../gemini-cli/packages/cli") - .to_string_lossy() - .to_string(); - - Ok(AgentServerCommand { - path: "node".into(), - args: vec![cli_path, "--experimental-acp".into()], - env: None, - }) - } - - async fn version( - &self, - _command: &agent_servers::AgentServerCommand, - ) -> Result { - Ok(AgentServerVersion { - current_version: "0.1.0".into(), - supported: true, - }) - } - } - - let thread = AcpThread::spawn(DevGemini, current_dir.as_ref(), project, &mut cx.to_async()) - .await - .unwrap(); - - thread - .update(cx, |thread, _| thread.initialize()) - .await - .unwrap(); - thread - } - pub fn fake_acp_thread( project: Entity, cx: &mut TestAppContext, ) -> (Entity, Entity) { let (stdin_tx, stdin_rx) = async_pipe::pipe(); let (stdout_tx, stdout_rx) = async_pipe::pipe(); - let thread = cx.update(|cx| cx.new(|cx| AcpThread::fake(stdin_tx, stdout_rx, project, cx))); + + let thread = cx.new(|cx| { + let foreground_executor = cx.foreground_executor().clone(); + let (connection, io_fut) = acp::AgentConnection::connect_to_agent( + AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), + stdin_tx, + stdout_rx, + move |fut| { + foreground_executor.spawn(fut).detach(); + }, + ); + + let io_task = cx.background_spawn({ + async move { + io_fut.await.log_err(); + Ok(()) + } + }); + AcpThread::new(connection, "Test".into(), Some(io_task), project, cx) + }); let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx))); (thread, agent) } pub struct FakeAcpServer { connection: acp::ClientConnection, + _io_task: Task<()>, on_user_message: Option< Rc< diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs new file mode 100644 index 0000000000..7c0ba4f41c --- /dev/null +++ b/crates/acp_thread/src/connection.rs @@ -0,0 +1,20 @@ +use agentic_coding_protocol as acp; +use anyhow::Result; +use futures::future::{FutureExt as _, LocalBoxFuture}; + +pub trait AgentConnection { + fn request_any( + &self, + params: acp::AnyAgentRequest, + ) -> LocalBoxFuture<'static, Result>; +} + +impl AgentConnection for acp::AgentConnection { + fn request_any( + &self, + params: acp::AnyAgentRequest, + ) -> LocalBoxFuture<'static, Result> { + let task = self.request_any(params); + async move { Ok(task.await?) }.boxed_local() + } +} diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 549162c5dd..d65235aee3 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -5,6 +5,10 @@ edition.workspace = true publish.workspace = true license = "GPL-3.0-or-later" +[features] +test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support"] +gemini = [] + [lints] workspace = true @@ -13,15 +17,32 @@ path = "src/agent_servers.rs" doctest = false [dependencies] +acp_thread.workspace = true +agentic-coding-protocol.workspace = true anyhow.workspace = true collections.workspace = true +context_server.workspace = true futures.workspace = true gpui.workspace = true +itertools.workspace = true +log.workspace = true paths.workspace = true project.workspace = true schemars.workspace = true serde.workspace = true +serde_json.workspace = true settings.workspace = true +smol.workspace = true +tempfile.workspace = true +ui.workspace = true util.workspace = true +watch.workspace = true which.workspace = true workspace-hack.workspace = true + +[dev-dependencies] +env_logger.workspace = true +language.workspace = true +indoc.workspace = true +acp_thread = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index ba43122570..ebebeca511 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -1,30 +1,24 @@ -use std::{ - path::{Path, PathBuf}, - sync::Arc, -}; +mod claude; +mod gemini; +mod settings; +mod stdio_agent_server; -use anyhow::{Context as _, Result}; +pub use claude::*; +pub use gemini::*; +pub use settings::*; +pub use stdio_agent_server::*; + +use acp_thread::AcpThread; +use anyhow::Result; use collections::HashMap; -use gpui::{App, AsyncApp, Entity, SharedString}; +use gpui::{App, Entity, SharedString, Task}; use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources, SettingsStore}; -use util::{ResultExt, paths}; +use std::path::{Path, PathBuf}; pub fn init(cx: &mut App) { - AllAgentServersSettings::register(cx); -} - -#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug)] -pub struct AllAgentServersSettings { - gemini: Option, -} - -#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)] -pub struct AgentServerSettings { - #[serde(flatten)] - command: AgentServerCommand, + settings::init(cx); } #[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)] @@ -36,153 +30,28 @@ pub struct AgentServerCommand { pub env: Option>, } -pub struct Gemini; - -pub struct AgentServerVersion { - pub current_version: SharedString, - pub supported: bool, +pub enum AgentServerVersion { + Supported, + Unsupported { + error_message: SharedString, + upgrade_message: SharedString, + upgrade_command: String, + }, } pub trait AgentServer: Send { - fn command( + fn logo(&self) -> ui::IconName; + fn name(&self) -> &'static str; + fn empty_state_headline(&self) -> &'static str; + fn empty_state_message(&self) -> &'static str; + fn supports_always_allow(&self) -> bool; + + fn new_thread( &self, + root_dir: &Path, project: &Entity, - cx: &mut AsyncApp, - ) -> impl Future>; - - fn version( - &self, - command: &AgentServerCommand, - ) -> impl Future> + Send; -} - -const GEMINI_ACP_ARG: &str = "--experimental-acp"; - -impl AgentServer for Gemini { - async fn command( - &self, - project: &Entity, - cx: &mut AsyncApp, - ) -> Result { - let custom_command = cx.read_global(|settings: &SettingsStore, _| { - let settings = settings.get::(None); - settings - .gemini - .as_ref() - .map(|gemini_settings| AgentServerCommand { - path: gemini_settings.command.path.clone(), - args: gemini_settings - .command - .args - .iter() - .cloned() - .chain(std::iter::once(GEMINI_ACP_ARG.into())) - .collect(), - env: gemini_settings.command.env.clone(), - }) - })?; - - if let Some(custom_command) = custom_command { - return Ok(custom_command); - } - - if let Some(path) = find_bin_in_path("gemini", project, cx).await { - return Ok(AgentServerCommand { - path, - args: vec![GEMINI_ACP_ARG.into()], - env: None, - }); - } - - let (fs, node_runtime) = project.update(cx, |project, _| { - (project.fs().clone(), project.node_runtime().cloned()) - })?; - let node_runtime = node_runtime.context("gemini not found on path")?; - - let directory = ::paths::agent_servers_dir().join("gemini"); - fs.create_dir(&directory).await?; - node_runtime - .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")]) - .await?; - let path = directory.join("node_modules/.bin/gemini"); - - Ok(AgentServerCommand { - path, - args: vec![GEMINI_ACP_ARG.into()], - env: None, - }) - } - - async fn version(&self, command: &AgentServerCommand) -> Result { - let version_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--version") - .kill_on_drop(true) - .output(); - - let help_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--help") - .kill_on_drop(true) - .output(); - - let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; - - let current_version = String::from_utf8(version_output?.stdout)?.into(); - let supported = String::from_utf8(help_output?.stdout)?.contains(GEMINI_ACP_ARG); - - Ok(AgentServerVersion { - current_version, - supported, - }) - } -} - -async fn find_bin_in_path( - bin_name: &'static str, - project: &Entity, - cx: &mut AsyncApp, -) -> Option { - let (env_task, root_dir) = project - .update(cx, |project, cx| { - let worktree = project.visible_worktrees(cx).next(); - match worktree { - Some(worktree) => { - let env_task = project.environment().update(cx, |env, cx| { - env.get_worktree_environment(worktree.clone(), cx) - }); - - let path = worktree.read(cx).abs_path(); - (env_task, path) - } - None => { - let path: Arc = paths::home_dir().as_path().into(); - let env_task = project.environment().update(cx, |env, cx| { - env.get_directory_environment(path.clone(), cx) - }); - (env_task, path) - } - } - }) - .log_err()?; - - cx.background_executor() - .spawn(async move { - let which_result = if cfg!(windows) { - which::which(bin_name) - } else { - let env = env_task.await.unwrap_or_default(); - let shell_path = env.get("PATH").cloned(); - which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref()) - }; - - if let Err(which::Error::CannotFindBinaryPath) = which_result { - return None; - } - - which_result.log_err() - }) - .await + cx: &mut App, + ) -> Task>>; } impl std::fmt::Debug for AgentServerCommand { @@ -209,23 +78,3 @@ impl std::fmt::Debug for AgentServerCommand { .finish() } } - -impl settings::Settings for AllAgentServersSettings { - const KEY: Option<&'static str> = Some("agent_servers"); - - type FileContent = Self; - - fn load(sources: SettingsSources, _: &mut App) -> Result { - let mut settings = AllAgentServersSettings::default(); - - for value in sources.defaults_and_customizations() { - if value.gemini.is_some() { - settings.gemini = value.gemini.clone(); - } - } - - Ok(settings) - } - - fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} -} diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs new file mode 100644 index 0000000000..897158dc57 --- /dev/null +++ b/crates/agent_servers/src/claude.rs @@ -0,0 +1,680 @@ +mod mcp_server; +mod tools; + +use collections::HashMap; +use project::Project; +use std::cell::RefCell; +use std::fmt::Display; +use std::path::Path; +use std::rc::Rc; + +use agentic_coding_protocol::{ + self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion, + StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams, +}; +use anyhow::{Context as _, Result, anyhow}; +use futures::channel::oneshot; +use futures::future::LocalBoxFuture; +use futures::{AsyncBufReadExt, AsyncWriteExt}; +use futures::{ + AsyncRead, AsyncWrite, FutureExt, StreamExt, + channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, + io::BufReader, + select_biased, +}; +use gpui::{App, AppContext, Entity, Task}; +use serde::{Deserialize, Serialize}; +use util::ResultExt; + +use crate::claude::mcp_server::ClaudeMcpServer; +use crate::claude::tools::ClaudeTool; +use crate::{AgentServer, find_bin_in_path}; +use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection}; + +#[derive(Clone)] +pub struct ClaudeCode; + +impl AgentServer for ClaudeCode { + fn name(&self) -> &'static str { + "Claude Code" + } + + fn empty_state_headline(&self) -> &'static str { + self.name() + } + + fn empty_state_message(&self) -> &'static str { + "" + } + + fn logo(&self) -> ui::IconName { + ui::IconName::AiClaude + } + + fn supports_always_allow(&self) -> bool { + false + } + + fn new_thread( + &self, + root_dir: &Path, + project: &Entity, + cx: &mut App, + ) -> Task>> { + let project = project.clone(); + let root_dir = root_dir.to_path_buf(); + let title = self.name().into(); + cx.spawn(async move |cx| { + let (mut delegate_tx, delegate_rx) = watch::channel(None); + let tool_id_map = Rc::new(RefCell::new(HashMap::default())); + + let permission_mcp_server = + ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?; + + let mut mcp_servers = HashMap::default(); + mcp_servers.insert( + mcp_server::SERVER_NAME.to_string(), + permission_mcp_server.server_config()?, + ); + let mcp_config = McpConfig { mcp_servers }; + + let mcp_config_file = tempfile::NamedTempFile::new()?; + let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts(); + + let mut mcp_config_file = smol::fs::File::from(mcp_config_file); + mcp_config_file + .write_all(serde_json::to_string(&mcp_config)?.as_bytes()) + .await?; + mcp_config_file.flush().await?; + + let command = find_bin_in_path("claude", &project, cx) + .await + .context("Failed to find claude binary")?; + + let mut child = util::command::new_smol_command(&command) + .args([ + "--input-format", + "stream-json", + "--output-format", + "stream-json", + "--print", + "--verbose", + "--mcp-config", + mcp_config_path.to_string_lossy().as_ref(), + "--permission-prompt-tool", + &format!( + "mcp__{}__{}", + mcp_server::SERVER_NAME, + mcp_server::PERMISSION_TOOL + ), + "--allowedTools", + "mcp__zed__Read,mcp__zed__Edit", + "--disallowedTools", + "Read,Edit", + ]) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + + let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded(); + let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); + + let io_task = + ClaudeAgentConnection::handle_io(outgoing_rx, incoming_message_tx, stdin, stdout); + cx.background_spawn(async move { + io_task.await.log_err(); + drop(mcp_config_path); + drop(child); + }) + .detach(); + + cx.new(|cx| { + let end_turn_tx = Rc::new(RefCell::new(None)); + let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()); + delegate_tx.send(Some(delegate.clone())).log_err(); + + let handler_task = cx.foreground_executor().spawn({ + let end_turn_tx = end_turn_tx.clone(); + let tool_id_map = tool_id_map.clone(); + async move { + while let Some(message) = incoming_message_rx.next().await { + ClaudeAgentConnection::handle_message( + delegate.clone(), + message, + end_turn_tx.clone(), + tool_id_map.clone(), + ) + .await + } + } + }); + + let mut connection = ClaudeAgentConnection { + outgoing_tx, + end_turn_tx, + _handler_task: handler_task, + _mcp_server: None, + }; + + connection._mcp_server = Some(permission_mcp_server); + acp_thread::AcpThread::new(connection, title, None, project.clone(), cx) + }) + }) + } +} + +impl AgentConnection for ClaudeAgentConnection { + /// Send a request to the agent and wait for a response. + fn request_any( + &self, + params: AnyAgentRequest, + ) -> LocalBoxFuture<'static, Result> { + let end_turn_tx = self.end_turn_tx.clone(); + let outgoing_tx = self.outgoing_tx.clone(); + async move { + match params { + // todo: consider sending an empty request so we get the init response? + AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse( + acp::InitializeResponse { + is_authenticated: true, + protocol_version: ProtocolVersion::latest(), + }, + )), + AnyAgentRequest::AuthenticateParams(_) => { + Err(anyhow!("Authentication not supported")) + } + AnyAgentRequest::SendUserMessageParams(message) => { + let (tx, rx) = oneshot::channel(); + end_turn_tx.borrow_mut().replace(tx); + let mut content = String::new(); + for chunk in message.chunks { + match chunk { + agentic_coding_protocol::UserMessageChunk::Text { text } => { + content.push_str(&text) + } + agentic_coding_protocol::UserMessageChunk::Path { path } => { + content.push_str(&format!("@{path:?}")) + } + } + } + outgoing_tx.unbounded_send(SdkMessage::User { + message: Message { + role: Role::User, + content: Content::UntaggedText(content), + id: None, + model: None, + stop_reason: None, + stop_sequence: None, + usage: None, + }, + session_id: None, + })?; + rx.await??; + Ok(AnyAgentResult::SendUserMessageResponse( + acp::SendUserMessageResponse, + )) + } + AnyAgentRequest::CancelSendMessageParams(_) => Ok( + AnyAgentResult::CancelSendMessageResponse(acp::CancelSendMessageResponse), + ), + } + } + .boxed_local() + } +} + +struct ClaudeAgentConnection { + outgoing_tx: UnboundedSender, + end_turn_tx: Rc>>>>, + _mcp_server: Option, + _handler_task: Task<()>, +} + +impl ClaudeAgentConnection { + async fn handle_message( + delegate: AcpClientDelegate, + message: SdkMessage, + end_turn_tx: Rc>>>>, + tool_id_map: Rc>>, + ) { + match message { + SdkMessage::Assistant { message, .. } | SdkMessage::User { message, .. } => { + for chunk in message.content.chunks() { + match chunk { + ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { + delegate + .stream_assistant_message_chunk(StreamAssistantMessageChunkParams { + chunk: acp::AssistantMessageChunk::Text { text }, + }) + .await + .log_err(); + } + ContentChunk::ToolUse { id, name, input } => { + if let Some(resp) = delegate + .push_tool_call(ClaudeTool::infer(&name, input).as_acp()) + .await + .log_err() + { + tool_id_map.borrow_mut().insert(id, resp.id); + } + } + ContentChunk::ToolResult { + content, + tool_use_id, + } => { + let id = tool_id_map.borrow_mut().remove(&tool_use_id); + if let Some(id) = id { + delegate + .update_tool_call(UpdateToolCallParams { + tool_call_id: id, + status: acp::ToolCallStatus::Finished, + content: Some(ToolCallContent::Markdown { + // For now we only include text content + markdown: content.to_string(), + }), + }) + .await + .log_err(); + } + } + ContentChunk::Image + | ContentChunk::Document + | ContentChunk::Thinking + | ContentChunk::RedactedThinking + | ContentChunk::WebSearchToolResult => { + delegate + .stream_assistant_message_chunk(StreamAssistantMessageChunkParams { + chunk: acp::AssistantMessageChunk::Text { + text: format!("Unsupported content: {:?}", chunk), + }, + }) + .await + .log_err(); + } + } + } + } + SdkMessage::Result { + is_error, subtype, .. + } => { + if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { + if is_error { + end_turn_tx.send(Err(anyhow!("Error: {subtype}"))).ok(); + } else { + end_turn_tx.send(Ok(())).ok(); + } + } + } + SdkMessage::System { .. } => {} + } + } + + async fn handle_io( + mut outgoing_rx: UnboundedReceiver, + incoming_tx: UnboundedSender, + mut outgoing_bytes: impl Unpin + AsyncWrite, + incoming_bytes: impl Unpin + AsyncRead, + ) -> Result<()> { + let mut output_reader = BufReader::new(incoming_bytes); + let mut outgoing_line = Vec::new(); + let mut incoming_line = String::new(); + loop { + select_biased! { + message = outgoing_rx.next() => { + if let Some(message) = message { + outgoing_line.clear(); + serde_json::to_writer(&mut outgoing_line, &message)?; + log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line)); + outgoing_line.push(b'\n'); + outgoing_bytes.write_all(&outgoing_line).await.ok(); + } else { + break; + } + } + bytes_read = output_reader.read_line(&mut incoming_line).fuse() => { + if bytes_read? == 0 { + break + } + log::trace!("recv: {}", &incoming_line); + match serde_json::from_str::(&incoming_line) { + Ok(message) => { + incoming_tx.unbounded_send(message).log_err(); + } + Err(error) => { + log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}"); + } + } + incoming_line.clear(); + } + } + } + Ok(()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Message { + role: Role, + content: Content, + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + stop_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + stop_sequence: Option, + #[serde(skip_serializing_if = "Option::is_none")] + usage: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +enum Content { + UntaggedText(String), + Chunks(Vec), +} + +impl Content { + pub fn chunks(self) -> impl Iterator { + match self { + Self::Chunks(chunks) => chunks.into_iter(), + Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(), + } + } +} + +impl Display for Content { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Content::UntaggedText(txt) => write!(f, "{}", txt), + Content::Chunks(chunks) => { + for chunk in chunks { + write!(f, "{}", chunk)?; + } + Ok(()) + } + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ContentChunk { + Text { + text: String, + }, + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + ToolResult { + content: Content, + tool_use_id: String, + }, + // TODO + Image, + Document, + Thinking, + RedactedThinking, + WebSearchToolResult, + #[serde(untagged)] + UntaggedText(String), +} + +impl Display for ContentChunk { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ContentChunk::Text { text } => write!(f, "{}", text), + ContentChunk::UntaggedText(text) => write!(f, "{}", text), + ContentChunk::ToolResult { content, .. } => write!(f, "{}", content), + ContentChunk::Image + | ContentChunk::Document + | ContentChunk::Thinking + | ContentChunk::RedactedThinking + | ContentChunk::ToolUse { .. } + | ContentChunk::WebSearchToolResult => { + write!(f, "\n{:?}\n", &self) + } + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Usage { + input_tokens: u32, + cache_creation_input_tokens: u32, + cache_read_input_tokens: u32, + output_tokens: u32, + service_tier: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +enum Role { + System, + Assistant, + User, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct MessageParam { + role: Role, + content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum SdkMessage { + // An assistant message + Assistant { + message: Message, // from Anthropic SDK + #[serde(skip_serializing_if = "Option::is_none")] + session_id: Option, + }, + + // A user message + User { + message: Message, // from Anthropic SDK + #[serde(skip_serializing_if = "Option::is_none")] + session_id: Option, + }, + + // Emitted as the last message in a conversation + Result { + subtype: ResultErrorType, + duration_ms: f64, + duration_api_ms: f64, + is_error: bool, + num_turns: i32, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + session_id: String, + total_cost_usd: f64, + }, + // Emitted as the first message at the start of a conversation + System { + cwd: String, + session_id: String, + tools: Vec, + model: String, + mcp_servers: Vec, + #[serde(rename = "apiKeySource")] + api_key_source: String, + #[serde(rename = "permissionMode")] + permission_mode: PermissionMode, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +enum ResultErrorType { + Success, + ErrorMaxTurns, + ErrorDuringExecution, +} + +impl Display for ResultErrorType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ResultErrorType::Success => write!(f, "success"), + ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"), + ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct McpServer { + name: String, + status: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +enum PermissionMode { + Default, + AcceptEdits, + BypassPermissions, + Plan, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct McpConfig { + mcp_servers: HashMap, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct McpServerConfig { + command: String, + args: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + env: Option>, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_deserialize_content_untagged_text() { + let json = json!("Hello, world!"); + let content: Content = serde_json::from_value(json).unwrap(); + match content { + Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"), + _ => panic!("Expected UntaggedText variant"), + } + } + + #[test] + fn test_deserialize_content_chunks() { + let json = json!([ + { + "type": "text", + "text": "Hello" + }, + { + "type": "tool_use", + "id": "tool_123", + "name": "calculator", + "input": {"operation": "add", "a": 1, "b": 2} + } + ]); + let content: Content = serde_json::from_value(json).unwrap(); + match content { + Content::Chunks(chunks) => { + assert_eq!(chunks.len(), 2); + match &chunks[0] { + ContentChunk::Text { text } => assert_eq!(text, "Hello"), + _ => panic!("Expected Text chunk"), + } + match &chunks[1] { + ContentChunk::ToolUse { id, name, input } => { + assert_eq!(id, "tool_123"); + assert_eq!(name, "calculator"); + assert_eq!(input["operation"], "add"); + assert_eq!(input["a"], 1); + assert_eq!(input["b"], 2); + } + _ => panic!("Expected ToolUse chunk"), + } + } + _ => panic!("Expected Chunks variant"), + } + } + + #[test] + fn test_deserialize_tool_result_untagged_text() { + let json = json!({ + "type": "tool_result", + "content": "Result content", + "tool_use_id": "tool_456" + }); + let chunk: ContentChunk = serde_json::from_value(json).unwrap(); + match chunk { + ContentChunk::ToolResult { + content, + tool_use_id, + } => { + match content { + Content::UntaggedText(text) => assert_eq!(text, "Result content"), + _ => panic!("Expected UntaggedText content"), + } + assert_eq!(tool_use_id, "tool_456"); + } + _ => panic!("Expected ToolResult variant"), + } + } + + #[test] + fn test_deserialize_tool_result_chunks() { + let json = json!({ + "type": "tool_result", + "content": [ + { + "type": "text", + "text": "Processing complete" + }, + { + "type": "text", + "text": "Result: 42" + } + ], + "tool_use_id": "tool_789" + }); + let chunk: ContentChunk = serde_json::from_value(json).unwrap(); + match chunk { + ContentChunk::ToolResult { + content, + tool_use_id, + } => { + match content { + Content::Chunks(chunks) => { + assert_eq!(chunks.len(), 2); + match &chunks[0] { + ContentChunk::Text { text } => assert_eq!(text, "Processing complete"), + _ => panic!("Expected Text chunk"), + } + match &chunks[1] { + ContentChunk::Text { text } => assert_eq!(text, "Result: 42"), + _ => panic!("Expected Text chunk"), + } + } + _ => panic!("Expected Chunks content"), + } + assert_eq!(tool_use_id, "tool_789"); + } + _ => panic!("Expected ToolResult variant"), + } + } +} diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs new file mode 100644 index 0000000000..fa61e67112 --- /dev/null +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -0,0 +1,303 @@ +use std::{cell::RefCell, rc::Rc}; + +use acp_thread::AcpClientDelegate; +use agentic_coding_protocol::{self as acp, Client, ReadTextFileParams, WriteTextFileParams}; +use anyhow::{Context, Result}; +use collections::HashMap; +use context_server::{ + listener::McpServer, + types::{ + CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse, + ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations, + ToolResponseContent, ToolsCapabilities, requests, + }, +}; +use gpui::{App, AsyncApp, Task}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use util::debug_panic; + +use crate::claude::{ + McpServerConfig, + tools::{ClaudeTool, EditToolParams, EditToolResponse, ReadToolParams, ReadToolResponse}, +}; + +pub struct ClaudeMcpServer { + server: McpServer, +} + +pub const SERVER_NAME: &str = "zed"; +pub const READ_TOOL: &str = "Read"; +pub const EDIT_TOOL: &str = "Edit"; +pub const PERMISSION_TOOL: &str = "Confirmation"; + +#[derive(Deserialize, JsonSchema, Debug)] +struct PermissionToolParams { + tool_name: String, + input: serde_json::Value, + tool_use_id: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct PermissionToolResponse { + behavior: PermissionToolBehavior, + updated_input: serde_json::Value, +} + +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +enum PermissionToolBehavior { + Allow, + Deny, +} + +impl ClaudeMcpServer { + pub async fn new( + delegate: watch::Receiver>, + tool_id_map: Rc>>, + cx: &AsyncApp, + ) -> Result { + let mut mcp_server = McpServer::new(cx).await?; + mcp_server.handle_request::(Self::handle_initialize); + mcp_server.handle_request::(Self::handle_list_tools); + mcp_server.handle_request::(move |request, cx| { + Self::handle_call_tool(request, delegate.clone(), tool_id_map.clone(), cx) + }); + + Ok(Self { server: mcp_server }) + } + + pub fn server_config(&self) -> Result { + #[cfg(not(target_os = "windows"))] + let zed_path = util::get_shell_safe_zed_path()?; + #[cfg(target_os = "windows")] + let zed_path = std::env::current_exe() + .context("finding current executable path for use in mcp_server")? + .to_string_lossy() + .to_string(); + + Ok(McpServerConfig { + command: zed_path, + args: vec![ + "--nc".into(), + self.server.socket_path().display().to_string(), + ], + env: None, + }) + } + + fn handle_initialize(_: InitializeParams, cx: &App) -> Task> { + cx.foreground_executor().spawn(async move { + Ok(InitializeResponse { + protocol_version: ProtocolVersion("2025-06-18".into()), + capabilities: ServerCapabilities { + experimental: None, + logging: None, + completions: None, + prompts: None, + resources: None, + tools: Some(ToolsCapabilities { + list_changed: Some(false), + }), + }, + server_info: Implementation { + name: SERVER_NAME.into(), + version: "0.1.0".into(), + }, + meta: None, + }) + }) + } + + fn handle_list_tools(_: (), cx: &App) -> Task> { + cx.foreground_executor().spawn(async move { + Ok(ListToolsResponse { + tools: vec![ + Tool { + name: PERMISSION_TOOL.into(), + input_schema: schemars::schema_for!(PermissionToolParams).into(), + description: None, + annotations: None, + }, + Tool { + name: READ_TOOL.into(), + input_schema: schemars::schema_for!(ReadToolParams).into(), + description: Some("Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.".to_string()), + annotations: Some(ToolAnnotations { + title: Some("Read file".to_string()), + read_only_hint: Some(true), + destructive_hint: Some(false), + open_world_hint: Some(false), + // if time passes the contents might change, but it's not going to do anything different + // true or false seem too strong, let's try a none. + idempotent_hint: None, + }), + }, + Tool { + name: EDIT_TOOL.into(), + input_schema: schemars::schema_for!(EditToolParams).into(), + description: Some("Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better.".to_string()), + annotations: Some(ToolAnnotations { + title: Some("Edit file".to_string()), + read_only_hint: Some(false), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: Some(false), + }), + }, + ], + next_cursor: None, + meta: None, + }) + }) + } + + fn handle_call_tool( + request: CallToolParams, + mut delegate_watch: watch::Receiver>, + tool_id_map: Rc>>, + cx: &App, + ) -> Task> { + cx.spawn(async move |cx| { + let Some(delegate) = delegate_watch.recv().await? else { + debug_panic!("Sent None delegate"); + anyhow::bail!("Server not available"); + }; + + if request.name.as_str() == PERMISSION_TOOL { + let input = + serde_json::from_value(request.arguments.context("Arguments required")?)?; + + let result = + Self::handle_permissions_tool_call(input, delegate, tool_id_map, cx).await?; + Ok(CallToolResponse { + content: vec![ToolResponseContent::Text { + text: serde_json::to_string(&result)?, + }], + is_error: None, + meta: None, + }) + } else if request.name.as_str() == READ_TOOL { + let input = + serde_json::from_value(request.arguments.context("Arguments required")?)?; + + let result = Self::handle_read_tool_call(input, delegate, cx).await?; + Ok(CallToolResponse { + content: vec![ToolResponseContent::Text { + text: serde_json::to_string(&result)?, + }], + is_error: None, + meta: None, + }) + } else if request.name.as_str() == EDIT_TOOL { + let input = + serde_json::from_value(request.arguments.context("Arguments required")?)?; + + let result = Self::handle_edit_tool_call(input, delegate, cx).await?; + Ok(CallToolResponse { + content: vec![ToolResponseContent::Text { + text: serde_json::to_string(&result)?, + }], + is_error: None, + meta: None, + }) + } else { + anyhow::bail!("Unsupported tool"); + } + }) + } + + fn handle_read_tool_call( + params: ReadToolParams, + delegate: AcpClientDelegate, + cx: &AsyncApp, + ) -> Task> { + cx.foreground_executor().spawn(async move { + let response = delegate + .read_text_file(ReadTextFileParams { + path: params.abs_path, + line: params.offset, + limit: params.limit, + }) + .await?; + + Ok(ReadToolResponse { + content: response.content, + }) + }) + } + + fn handle_edit_tool_call( + params: EditToolParams, + delegate: AcpClientDelegate, + cx: &AsyncApp, + ) -> Task> { + cx.foreground_executor().spawn(async move { + let response = delegate + .read_text_file_reusing_snapshot(ReadTextFileParams { + path: params.abs_path.clone(), + line: None, + limit: None, + }) + .await?; + + let new_content = response.content.replace(¶ms.old_text, ¶ms.new_text); + if new_content == response.content { + return Err(anyhow::anyhow!("The old_text was not found in the content")); + } + + delegate + .write_text_file(WriteTextFileParams { + path: params.abs_path, + content: new_content, + }) + .await?; + + Ok(EditToolResponse) + }) + } + + fn handle_permissions_tool_call( + params: PermissionToolParams, + delegate: AcpClientDelegate, + tool_id_map: Rc>>, + cx: &AsyncApp, + ) -> Task> { + cx.foreground_executor().spawn(async move { + let claude_tool = ClaudeTool::infer(¶ms.tool_name, params.input.clone()); + + let tool_call_id = match params.tool_use_id { + Some(tool_use_id) => tool_id_map + .borrow() + .get(&tool_use_id) + .cloned() + .context("Tool call ID not found")?, + + None => delegate.push_tool_call(claude_tool.as_acp()).await?.id, + }; + + let outcome = delegate + .request_existing_tool_call_confirmation( + tool_call_id, + claude_tool.confirmation(None), + ) + .await?; + + match outcome { + acp::ToolCallConfirmationOutcome::Allow + | acp::ToolCallConfirmationOutcome::AlwaysAllow + | acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer + | acp::ToolCallConfirmationOutcome::AlwaysAllowTool => Ok(PermissionToolResponse { + behavior: PermissionToolBehavior::Allow, + updated_input: params.input, + }), + acp::ToolCallConfirmationOutcome::Reject + | acp::ToolCallConfirmationOutcome::Cancel => Ok(PermissionToolResponse { + behavior: PermissionToolBehavior::Deny, + updated_input: params.input, + }), + } + }) + } +} diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs new file mode 100644 index 0000000000..89d42c0daa --- /dev/null +++ b/crates/agent_servers/src/claude/tools.rs @@ -0,0 +1,670 @@ +use std::path::PathBuf; + +use agentic_coding_protocol::{self as acp, PushToolCallParams, ToolCallLocation}; +use itertools::Itertools; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use util::ResultExt; + +pub enum ClaudeTool { + Task(Option), + NotebookRead(Option), + NotebookEdit(Option), + Edit(Option), + MultiEdit(Option), + ReadFile(Option), + Write(Option), + Ls(Option), + Glob(Option), + Grep(Option), + Terminal(Option), + WebFetch(Option), + WebSearch(Option), + TodoWrite(Option), + ExitPlanMode(Option), + Other { + name: String, + input: serde_json::Value, + }, +} + +impl ClaudeTool { + pub fn infer(tool_name: &str, input: serde_json::Value) -> Self { + match tool_name { + // Known tools + "mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()), + "mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()), + "MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()), + "Write" => Self::Write(serde_json::from_value(input).log_err()), + "LS" => Self::Ls(serde_json::from_value(input).log_err()), + "Glob" => Self::Glob(serde_json::from_value(input).log_err()), + "Grep" => Self::Grep(serde_json::from_value(input).log_err()), + "Bash" => Self::Terminal(serde_json::from_value(input).log_err()), + "WebFetch" => Self::WebFetch(serde_json::from_value(input).log_err()), + "WebSearch" => Self::WebSearch(serde_json::from_value(input).log_err()), + "TodoWrite" => Self::TodoWrite(serde_json::from_value(input).log_err()), + "exit_plan_mode" => Self::ExitPlanMode(serde_json::from_value(input).log_err()), + "Task" => Self::Task(serde_json::from_value(input).log_err()), + "NotebookRead" => Self::NotebookRead(serde_json::from_value(input).log_err()), + "NotebookEdit" => Self::NotebookEdit(serde_json::from_value(input).log_err()), + // Inferred from name + _ => { + let tool_name = tool_name.to_lowercase(); + + if tool_name.contains("edit") || tool_name.contains("write") { + Self::Edit(None) + } else if tool_name.contains("terminal") { + Self::Terminal(None) + } else { + Self::Other { + name: tool_name.to_string(), + input, + } + } + } + } + } + + pub fn label(&self) -> String { + match &self { + Self::Task(Some(params)) => params.description.clone(), + Self::Task(None) => "Task".into(), + Self::NotebookRead(Some(params)) => { + format!("Read Notebook {}", params.notebook_path.display()) + } + Self::NotebookRead(None) => "Read Notebook".into(), + Self::NotebookEdit(Some(params)) => { + format!("Edit Notebook {}", params.notebook_path.display()) + } + Self::NotebookEdit(None) => "Edit Notebook".into(), + Self::Terminal(Some(params)) => format!("`{}`", params.command), + Self::Terminal(None) => "Terminal".into(), + Self::ReadFile(_) => "Read File".into(), + Self::Ls(Some(params)) => { + format!("List Directory {}", params.path.display()) + } + Self::Ls(None) => "List Directory".into(), + Self::Edit(Some(params)) => { + format!("Edit {}", params.abs_path.display()) + } + Self::Edit(None) => "Edit".into(), + Self::MultiEdit(Some(params)) => { + format!("Multi Edit {}", params.file_path.display()) + } + Self::MultiEdit(None) => "Multi Edit".into(), + Self::Write(Some(params)) => { + format!("Write {}", params.file_path.display()) + } + Self::Write(None) => "Write".into(), + Self::Glob(Some(params)) => { + format!("Glob {params}") + } + Self::Glob(None) => "Glob".into(), + Self::Grep(Some(params)) => params.to_string(), + Self::Grep(None) => "Grep".into(), + Self::WebFetch(Some(params)) => format!("Fetch {}", params.url), + Self::WebFetch(None) => "Fetch".into(), + Self::WebSearch(Some(params)) => format!("Web Search: {}", params), + Self::WebSearch(None) => "Web Search".into(), + Self::TodoWrite(Some(params)) => format!( + "Update TODOs: {}", + params.todos.iter().map(|todo| &todo.content).join(", ") + ), + Self::TodoWrite(None) => "Update TODOs".into(), + Self::ExitPlanMode(_) => "Exit Plan Mode".into(), + Self::Other { name, .. } => name.clone(), + } + } + + pub fn content(&self) -> Option { + match &self { + ClaudeTool::Other { input, .. } => Some(acp::ToolCallContent::Markdown { + markdown: format!( + "```json\n{}```", + serde_json::to_string_pretty(&input).unwrap_or("{}".to_string()) + ), + }), + _ => None, + } + } + + pub fn icon(&self) -> acp::Icon { + match self { + Self::Task(_) => acp::Icon::Hammer, + Self::NotebookRead(_) => acp::Icon::FileSearch, + Self::NotebookEdit(_) => acp::Icon::Pencil, + Self::Edit(_) => acp::Icon::Pencil, + Self::MultiEdit(_) => acp::Icon::Pencil, + Self::Write(_) => acp::Icon::Pencil, + Self::ReadFile(_) => acp::Icon::FileSearch, + Self::Ls(_) => acp::Icon::Folder, + Self::Glob(_) => acp::Icon::FileSearch, + Self::Grep(_) => acp::Icon::Regex, + Self::Terminal(_) => acp::Icon::Terminal, + Self::WebSearch(_) => acp::Icon::Globe, + Self::WebFetch(_) => acp::Icon::Globe, + Self::TodoWrite(_) => acp::Icon::LightBulb, + Self::ExitPlanMode(_) => acp::Icon::Hammer, + Self::Other { .. } => acp::Icon::Hammer, + } + } + + pub fn confirmation(&self, description: Option) -> acp::ToolCallConfirmation { + match &self { + Self::Edit(_) | Self::Write(_) | Self::NotebookEdit(_) | Self::MultiEdit(_) => { + acp::ToolCallConfirmation::Edit { description } + } + Self::WebFetch(params) => acp::ToolCallConfirmation::Fetch { + urls: params + .as_ref() + .map(|p| vec![p.url.clone()]) + .unwrap_or_default(), + description, + }, + Self::Terminal(Some(BashToolParams { + description, + command, + .. + })) => acp::ToolCallConfirmation::Execute { + command: command.clone(), + root_command: command.clone(), + description: description.clone(), + }, + Self::ExitPlanMode(Some(params)) => acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {}", params.plan) + } else { + params.plan.clone() + }, + }, + Self::Task(Some(params)) => acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {}", params.description) + } else { + params.description.clone() + }, + }, + Self::Ls(Some(LsToolParams { path, .. })) + | Self::ReadFile(Some(ReadToolParams { abs_path: path, .. })) => { + let path = path.display(); + acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {path}") + } else { + path.to_string() + }, + } + } + Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { + let path = notebook_path.display(); + acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {path}") + } else { + path.to_string() + }, + } + } + Self::Glob(Some(params)) => acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {params}") + } else { + params.to_string() + }, + }, + Self::Grep(Some(params)) => acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {params}") + } else { + params.to_string() + }, + }, + Self::WebSearch(Some(params)) => acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {params}") + } else { + params.to_string() + }, + }, + Self::TodoWrite(Some(params)) => { + let params = params.todos.iter().map(|todo| &todo.content).join(", "); + acp::ToolCallConfirmation::Other { + description: if let Some(description) = description { + format!("{description} {params}") + } else { + params + }, + } + } + Self::Terminal(None) + | Self::Task(None) + | Self::NotebookRead(None) + | Self::ExitPlanMode(None) + | Self::Ls(None) + | Self::Glob(None) + | Self::Grep(None) + | Self::ReadFile(None) + | Self::WebSearch(None) + | Self::TodoWrite(None) + | Self::Other { .. } => acp::ToolCallConfirmation::Other { + description: description.unwrap_or("".to_string()), + }, + } + } + + pub fn locations(&self) -> Vec { + match &self { + Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![ToolCallLocation { + path: abs_path.clone(), + line: None, + }], + Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => { + vec![ToolCallLocation { + path: file_path.clone(), + line: None, + }] + } + Self::Write(Some(WriteToolParams { file_path, .. })) => vec![ToolCallLocation { + path: file_path.clone(), + line: None, + }], + Self::ReadFile(Some(ReadToolParams { + abs_path, offset, .. + })) => vec![ToolCallLocation { + path: abs_path.clone(), + line: *offset, + }], + Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { + vec![ToolCallLocation { + path: notebook_path.clone(), + line: None, + }] + } + Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => { + vec![ToolCallLocation { + path: notebook_path.clone(), + line: None, + }] + } + Self::Glob(Some(GlobToolParams { + path: Some(path), .. + })) => vec![ToolCallLocation { + path: path.clone(), + line: None, + }], + Self::Ls(Some(LsToolParams { path, .. })) => vec![ToolCallLocation { + path: path.clone(), + line: None, + }], + Self::Grep(Some(GrepToolParams { + path: Some(path), .. + })) => vec![ToolCallLocation { + path: PathBuf::from(path), + line: None, + }], + Self::Task(_) + | Self::NotebookRead(None) + | Self::NotebookEdit(None) + | Self::Edit(None) + | Self::MultiEdit(None) + | Self::Write(None) + | Self::ReadFile(None) + | Self::Ls(None) + | Self::Glob(_) + | Self::Grep(_) + | Self::Terminal(_) + | Self::WebFetch(_) + | Self::WebSearch(_) + | Self::TodoWrite(_) + | Self::ExitPlanMode(_) + | Self::Other { .. } => vec![], + } + } + + pub fn as_acp(&self) -> PushToolCallParams { + PushToolCallParams { + label: self.label(), + content: self.content(), + icon: self.icon(), + locations: self.locations(), + } + } +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct EditToolParams { + /// The absolute path to the file to read. + pub abs_path: PathBuf, + /// The old text to replace (must be unique in the file) + pub old_text: String, + /// The new text. + pub new_text: String, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct EditToolResponse; + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct ReadToolParams { + /// The absolute path to the file to read. + pub abs_path: PathBuf, + /// Which line to start reading from. Omit to start from the beginning. + #[serde(skip_serializing_if = "Option::is_none")] + pub offset: Option, + /// How many lines to read. Omit for the whole file. + #[serde(skip_serializing_if = "Option::is_none")] + pub limit: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ReadToolResponse { + pub content: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct WriteToolParams { + /// Absolute path for new file + pub file_path: PathBuf, + /// File content + pub content: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct BashToolParams { + /// Shell command to execute + pub command: String, + /// 5-10 word description of what command does + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Timeout in ms (max 600000ms/10min, default 120000ms) + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct GlobToolParams { + /// Glob pattern like **/*.js or src/**/*.ts + pub pattern: String, + /// Directory to search in (omit for current directory) + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, +} + +impl std::fmt::Display for GlobToolParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(path) = &self.path { + write!(f, "{}", path.display())?; + } + write!(f, "{}", self.pattern) + } +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct LsToolParams { + /// Absolute path to directory + pub path: PathBuf, + /// Array of glob patterns to ignore + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub ignore: Vec, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct GrepToolParams { + /// Regex pattern to search for + pub pattern: String, + /// File/directory to search (defaults to current directory) + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, + /// "content" (shows lines), "files_with_matches" (default), "count" + #[serde(skip_serializing_if = "Option::is_none")] + pub output_mode: Option, + /// Filter files with glob pattern like "*.js" + #[serde(skip_serializing_if = "Option::is_none")] + pub glob: Option, + /// File type filter like "js", "py", "rust" + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub file_type: Option, + /// Case insensitive search + #[serde(rename = "-i", default, skip_serializing_if = "is_false")] + pub case_insensitive: bool, + /// Show line numbers (content mode only) + #[serde(rename = "-n", default, skip_serializing_if = "is_false")] + pub line_numbers: bool, + /// Lines after match (content mode only) + #[serde(rename = "-A", skip_serializing_if = "Option::is_none")] + pub after_context: Option, + /// Lines before match (content mode only) + #[serde(rename = "-B", skip_serializing_if = "Option::is_none")] + pub before_context: Option, + /// Lines before and after match (content mode only) + #[serde(rename = "-C", skip_serializing_if = "Option::is_none")] + pub context: Option, + /// Enable multiline/cross-line matching + #[serde(default, skip_serializing_if = "is_false")] + pub multiline: bool, + /// Limit output to first N results + #[serde(skip_serializing_if = "Option::is_none")] + pub head_limit: Option, +} + +impl std::fmt::Display for GrepToolParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "grep")?; + + // Boolean flags + if self.case_insensitive { + write!(f, " -i")?; + } + if self.line_numbers { + write!(f, " -n")?; + } + + // Context options + if let Some(after) = self.after_context { + write!(f, " -A {}", after)?; + } + if let Some(before) = self.before_context { + write!(f, " -B {}", before)?; + } + if let Some(context) = self.context { + write!(f, " -C {}", context)?; + } + + // Output mode + if let Some(mode) = &self.output_mode { + match mode { + GrepOutputMode::FilesWithMatches => write!(f, " -l")?, + GrepOutputMode::Count => write!(f, " -c")?, + GrepOutputMode::Content => {} // Default mode + } + } + + // Head limit + if let Some(limit) = self.head_limit { + write!(f, " | head -{}", limit)?; + } + + // Glob pattern + if let Some(glob) = &self.glob { + write!(f, " --include=\"{}\"", glob)?; + } + + // File type + if let Some(file_type) = &self.file_type { + write!(f, " --type={}", file_type)?; + } + + // Multiline + if self.multiline { + write!(f, " -P")?; // Perl-compatible regex for multiline + } + + // Pattern (escaped if contains special characters) + write!(f, " \"{}\"", self.pattern)?; + + // Path + if let Some(path) = &self.path { + write!(f, " {}", path)?; + } + + Ok(()) + } +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum TodoPriority { + High, + Medium, + Low, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum TodoStatus { + Pending, + InProgress, + Completed, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +pub struct Todo { + /// Unique identifier + pub id: String, + /// Task description + pub content: String, + /// Priority level of the todo + pub priority: TodoPriority, + /// Current status of the todo + pub status: TodoStatus, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct TodoWriteToolParams { + pub todos: Vec, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct ExitPlanModeToolParams { + /// Implementation plan in markdown format + pub plan: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct TaskToolParams { + /// Short 3-5 word description of task + pub description: String, + /// Detailed task for agent to perform + pub prompt: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct NotebookReadToolParams { + /// Absolute path to .ipynb file + pub notebook_path: PathBuf, + /// Specific cell ID to read + #[serde(skip_serializing_if = "Option::is_none")] + pub cell_id: Option, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum CellType { + Code, + Markdown, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum EditMode { + Replace, + Insert, + Delete, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct NotebookEditToolParams { + /// Absolute path to .ipynb file + pub notebook_path: PathBuf, + /// New cell content + pub new_source: String, + /// Cell ID to edit + #[serde(skip_serializing_if = "Option::is_none")] + pub cell_id: Option, + /// Type of cell (code or markdown) + #[serde(skip_serializing_if = "Option::is_none")] + pub cell_type: Option, + /// Edit operation mode + #[serde(skip_serializing_if = "Option::is_none")] + pub edit_mode: Option, +} + +#[derive(Deserialize, Serialize, JsonSchema, Debug)] +pub struct MultiEditItem { + /// The text to search for and replace + pub old_string: String, + /// The replacement text + pub new_string: String, + /// Whether to replace all occurrences or just the first + #[serde(default, skip_serializing_if = "is_false")] + pub replace_all: bool, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct MultiEditToolParams { + /// Absolute path to file + pub file_path: PathBuf, + /// List of edits to apply + pub edits: Vec, +} + +fn is_false(v: &bool) -> bool { + !*v +} + +#[derive(Deserialize, JsonSchema, Debug)] +#[serde(rename_all = "snake_case")] +pub enum GrepOutputMode { + Content, + FilesWithMatches, + Count, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct WebFetchToolParams { + /// Valid URL to fetch + #[serde(rename = "url")] + pub url: String, + /// What to extract from content + pub prompt: String, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct WebSearchToolParams { + /// Search query (min 2 chars) + pub query: String, + /// Only include these domains + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub allowed_domains: Vec, + /// Exclude these domains + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub blocked_domains: Vec, +} + +impl std::fmt::Display for WebSearchToolParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "\"{}\"", self.query)?; + + if !self.allowed_domains.is_empty() { + write!(f, " (allowed: {})", self.allowed_domains.join(", "))?; + } + + if !self.blocked_domains.is_empty() { + write!(f, " (blocked: {})", self.blocked_domains.join(", "))?; + } + + Ok(()) + } +} diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs new file mode 100644 index 0000000000..bf1d13429e --- /dev/null +++ b/crates/agent_servers/src/gemini.rs @@ -0,0 +1,501 @@ +use crate::stdio_agent_server::{StdioAgentServer, find_bin_in_path}; +use crate::{AgentServerCommand, AgentServerVersion}; +use anyhow::{Context as _, Result}; +use gpui::{AsyncApp, Entity}; +use project::Project; +use settings::SettingsStore; + +use crate::AllAgentServersSettings; + +#[derive(Clone)] +pub struct Gemini; + +const ACP_ARG: &str = "--experimental-acp"; + +impl StdioAgentServer for Gemini { + fn name(&self) -> &'static str { + "Gemini" + } + + fn empty_state_headline(&self) -> &'static str { + "Welcome to Gemini" + } + + fn empty_state_message(&self) -> &'static str { + "Ask questions, edit files, run commands.\nBe specific for the best results." + } + + fn supports_always_allow(&self) -> bool { + true + } + + fn logo(&self) -> ui::IconName { + ui::IconName::AiGemini + } + + async fn command( + &self, + project: &Entity, + cx: &mut AsyncApp, + ) -> Result { + let custom_command = cx.read_global(|settings: &SettingsStore, _| { + let settings = settings.get::(None); + settings + .gemini + .as_ref() + .map(|gemini_settings| AgentServerCommand { + path: gemini_settings.command.path.clone(), + args: gemini_settings + .command + .args + .iter() + .cloned() + .chain(std::iter::once(ACP_ARG.into())) + .collect(), + env: gemini_settings.command.env.clone(), + }) + })?; + + if let Some(custom_command) = custom_command { + return Ok(custom_command); + } + + if let Some(path) = find_bin_in_path("gemini", project, cx).await { + return Ok(AgentServerCommand { + path, + args: vec![ACP_ARG.into()], + env: None, + }); + } + + let (fs, node_runtime) = project.update(cx, |project, _| { + (project.fs().clone(), project.node_runtime().cloned()) + })?; + let node_runtime = node_runtime.context("gemini not found on path")?; + + let directory = ::paths::agent_servers_dir().join("gemini"); + fs.create_dir(&directory).await?; + node_runtime + .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")]) + .await?; + let path = directory.join("node_modules/.bin/gemini"); + + Ok(AgentServerCommand { + path, + args: vec![ACP_ARG.into()], + env: None, + }) + } + + async fn version(&self, command: &AgentServerCommand) -> Result { + let version_fut = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .arg("--version") + .kill_on_drop(true) + .output(); + + let help_fut = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .arg("--help") + .kill_on_drop(true) + .output(); + + let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; + + let current_version = String::from_utf8(version_output?.stdout)?; + let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG); + + if supported { + Ok(AgentServerVersion::Supported) + } else { + Ok(AgentServerVersion::Unsupported { + error_message: format!( + "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", + current_version + ).into(), + upgrade_message: "Upgrade Gemini to Latest".into(), + upgrade_command: "npm install -g @google/gemini-cli@latest".into(), + }) + } + } +} + +#[cfg(test)] +mod test { + use std::{path::Path, time::Duration}; + + use acp_thread::{ + AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, + ToolCallStatus, + }; + use agentic_coding_protocol as acp; + use anyhow::Result; + use futures::{FutureExt, StreamExt, channel::mpsc, select}; + use gpui::{AsyncApp, Entity, TestAppContext}; + use indoc::indoc; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + use crate::{AgentServer, AgentServerCommand, AgentServerVersion, StdioAgentServer}; + + pub async fn gemini_acp_thread( + project: Entity, + current_dir: impl AsRef, + cx: &mut TestAppContext, + ) -> Entity { + #[derive(Clone)] + struct DevGemini; + + impl StdioAgentServer for DevGemini { + async fn command( + &self, + _project: &Entity, + _cx: &mut AsyncApp, + ) -> Result { + let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../../../gemini-cli/packages/cli") + .to_string_lossy() + .to_string(); + + Ok(AgentServerCommand { + path: "node".into(), + args: vec![cli_path, "--experimental-acp".into()], + env: None, + }) + } + + async fn version(&self, _command: &AgentServerCommand) -> Result { + Ok(AgentServerVersion::Supported) + } + + fn logo(&self) -> ui::IconName { + ui::IconName::AiGemini + } + + fn name(&self) -> &'static str { + "test" + } + + fn empty_state_headline(&self) -> &'static str { + "test" + } + + fn empty_state_message(&self) -> &'static str { + "test" + } + + fn supports_always_allow(&self) -> bool { + true + } + } + + let thread = cx + .update(|cx| AgentServer::new_thread(&DevGemini, current_dir.as_ref(), &project, cx)) + .await + .unwrap(); + + thread + .update(cx, |thread, _| thread.initialize()) + .await + .unwrap(); + thread + } + + fn init_test(cx: &mut TestAppContext) { + env_logger::try_init().ok(); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + }); + } + + #[gpui::test] + #[cfg_attr(not(feature = "gemini"), ignore)] + async fn test_gemini_basic(cx: &mut TestAppContext) { + init_test(cx); + + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; + thread + .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries().len(), 2); + assert!(matches!( + thread.entries()[0], + AgentThreadEntry::UserMessage(_) + )); + assert!(matches!( + thread.entries()[1], + AgentThreadEntry::AssistantMessage(_) + )); + }); + } + + #[gpui::test] + #[cfg_attr(not(feature = "gemini"), ignore)] + async fn test_gemini_path_mentions(cx: &mut TestAppContext) { + init_test(cx); + + cx.executor().allow_parking(); + let tempdir = tempfile::tempdir().unwrap(); + std::fs::write( + tempdir.path().join("foo.rs"), + indoc! {" + fn main() { + println!(\"Hello, world!\"); + } + "}, + ) + .expect("failed to write file"); + let project = Project::example([tempdir.path()], &mut cx.to_async()).await; + let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await; + thread + .update(cx, |thread, cx| { + thread.send( + acp::SendUserMessageParams { + chunks: vec![ + acp::UserMessageChunk::Text { + text: "Read the file ".into(), + }, + acp::UserMessageChunk::Path { + path: Path::new("foo.rs").into(), + }, + acp::UserMessageChunk::Text { + text: " and tell me what the content of the println! is".into(), + }, + ], + }, + cx, + ) + }) + .await + .unwrap(); + + thread.read_with(cx, |thread, cx| { + assert_eq!(thread.entries().len(), 3); + assert!(matches!( + thread.entries()[0], + AgentThreadEntry::UserMessage(_) + )); + assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_))); + let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else { + panic!("Expected AssistantMessage") + }; + assert!( + assistant_message.to_markdown(cx).contains("Hello, world!"), + "unexpected assistant message: {:?}", + assistant_message.to_markdown(cx) + ); + }); + } + + #[gpui::test] + #[cfg_attr(not(feature = "gemini"), ignore)] + async fn test_gemini_tool_call(cx: &mut TestAppContext) { + init_test(cx); + + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/private/tmp"), + json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}), + ) + .await; + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; + thread + .update(cx, |thread, cx| { + thread.send_raw( + "Read the '/private/tmp/foo' file and tell me what you see.", + cx, + ) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, _cx| { + assert!(matches!( + &thread.entries()[2], + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { .. }, + .. + }) + )); + + assert!(matches!( + thread.entries()[3], + AgentThreadEntry::AssistantMessage(_) + )); + }); + } + + #[gpui::test] + #[cfg_attr(not(feature = "gemini"), ignore)] + async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) { + init_test(cx); + + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; + let full_turn = thread.update(cx, |thread, cx| { + thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx) + }); + + run_until_first_tool_call(&thread, cx).await; + + let tool_call_id = thread.read_with(cx, |thread, _cx| { + let AgentThreadEntry::ToolCall(ToolCall { + id, + status: + ToolCallStatus::WaitingForConfirmation { + confirmation: ToolCallConfirmation::Execute { root_command, .. }, + .. + }, + .. + }) = &thread.entries()[2] + else { + panic!(); + }; + + assert_eq!(root_command, "echo"); + + *id + }); + + thread.update(cx, |thread, cx| { + thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx); + + assert!(matches!( + &thread.entries()[2], + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { .. }, + .. + }) + )); + }); + + full_turn.await.unwrap(); + + thread.read_with(cx, |thread, cx| { + let AgentThreadEntry::ToolCall(ToolCall { + content: Some(ToolCallContent::Markdown { markdown }), + status: ToolCallStatus::Allowed { .. }, + .. + }) = &thread.entries()[2] + else { + panic!(); + }; + + markdown.read_with(cx, |md, _cx| { + assert!( + md.source().contains("Hello, world!"), + r#"Expected '{}' to contain "Hello, world!""#, + md.source() + ); + }); + }); + } + + #[gpui::test] + #[cfg_attr(not(feature = "gemini"), ignore)] + async fn test_gemini_cancel(cx: &mut TestAppContext) { + init_test(cx); + + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; + let full_turn = thread.update(cx, |thread, cx| { + thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx) + }); + + let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await; + + thread.read_with(cx, |thread, _cx| { + let AgentThreadEntry::ToolCall(ToolCall { + id, + status: + ToolCallStatus::WaitingForConfirmation { + confirmation: ToolCallConfirmation::Execute { root_command, .. }, + .. + }, + .. + }) = &thread.entries()[first_tool_call_ix] + else { + panic!("{:?}", thread.entries()[1]); + }; + + assert_eq!(root_command, "echo"); + + *id + }); + + thread + .update(cx, |thread, cx| thread.cancel(cx)) + .await + .unwrap(); + full_turn.await.unwrap(); + thread.read_with(cx, |thread, _| { + let AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Canceled, + .. + }) = &thread.entries()[first_tool_call_ix] + else { + panic!(); + }; + }); + + thread + .update(cx, |thread, cx| { + thread.send_raw(r#"Stop running and say goodbye to me."#, cx) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, _| { + assert!(matches!( + &thread.entries().last().unwrap(), + AgentThreadEntry::AssistantMessage(..), + )) + }); + } + + async fn run_until_first_tool_call( + thread: &Entity, + cx: &mut TestAppContext, + ) -> usize { + let (mut tx, mut rx) = mpsc::channel::(1); + + let subscription = cx.update(|cx| { + cx.subscribe(thread, move |thread, _, cx| { + for (ix, entry) in thread.read(cx).entries().iter().enumerate() { + if matches!(entry, AgentThreadEntry::ToolCall(_)) { + return tx.try_send(ix).unwrap(); + } + } + }) + }); + + select! { + _ = cx.executor().timer(Duration::from_secs(10)).fuse() => { + panic!("Timeout waiting for tool call") + } + ix = rx.next().fuse() => { + drop(subscription); + ix.unwrap() + } + } + } +} diff --git a/crates/agent_servers/src/settings.rs b/crates/agent_servers/src/settings.rs new file mode 100644 index 0000000000..8e6914352b --- /dev/null +++ b/crates/agent_servers/src/settings.rs @@ -0,0 +1,41 @@ +use crate::AgentServerCommand; +use anyhow::Result; +use gpui::App; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsSources}; + +pub fn init(cx: &mut App) { + AllAgentServersSettings::register(cx); +} + +#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug)] +pub struct AllAgentServersSettings { + pub gemini: Option, +} + +#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)] +pub struct AgentServerSettings { + #[serde(flatten)] + pub command: AgentServerCommand, +} + +impl settings::Settings for AllAgentServersSettings { + const KEY: Option<&'static str> = Some("agent_servers"); + + type FileContent = Self; + + fn load(sources: SettingsSources, _: &mut App) -> Result { + let mut settings = AllAgentServersSettings::default(); + + for value in sources.defaults_and_customizations() { + if value.gemini.is_some() { + settings.gemini = value.gemini.clone(); + } + } + + Ok(settings) + } + + fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} +} diff --git a/crates/agent_servers/src/stdio_agent_server.rs b/crates/agent_servers/src/stdio_agent_server.rs new file mode 100644 index 0000000000..d78506022d --- /dev/null +++ b/crates/agent_servers/src/stdio_agent_server.rs @@ -0,0 +1,169 @@ +use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; +use acp_thread::{AcpClientDelegate, AcpThread, LoadError}; +use agentic_coding_protocol as acp; +use anyhow::{Result, anyhow}; +use gpui::{App, AsyncApp, Entity, Task, prelude::*}; +use project::Project; +use std::{ + path::{Path, PathBuf}, + sync::Arc, +}; +use util::{ResultExt, paths}; + +pub trait StdioAgentServer: Send + Clone { + fn logo(&self) -> ui::IconName; + fn name(&self) -> &'static str; + fn empty_state_headline(&self) -> &'static str; + fn empty_state_message(&self) -> &'static str; + fn supports_always_allow(&self) -> bool; + + fn command( + &self, + project: &Entity, + cx: &mut AsyncApp, + ) -> impl Future>; + + fn version( + &self, + command: &AgentServerCommand, + ) -> impl Future> + Send; +} + +impl AgentServer for T { + fn name(&self) -> &'static str { + self.name() + } + + fn empty_state_headline(&self) -> &'static str { + self.empty_state_headline() + } + + fn empty_state_message(&self) -> &'static str { + self.empty_state_message() + } + + fn logo(&self) -> ui::IconName { + self.logo() + } + + fn supports_always_allow(&self) -> bool { + self.supports_always_allow() + } + + fn new_thread( + &self, + root_dir: &Path, + project: &Entity, + cx: &mut App, + ) -> Task>> { + let root_dir = root_dir.to_path_buf(); + let project = project.clone(); + let this = self.clone(); + let title = self.name().into(); + + cx.spawn(async move |cx| { + let command = this.command(&project, cx).await?; + + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + + cx.new(|cx| { + let foreground_executor = cx.foreground_executor().clone(); + + let (connection, io_fut) = acp::AgentConnection::connect_to_agent( + AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), + stdin, + stdout, + move |fut| foreground_executor.spawn(fut).detach(), + ); + + let io_task = cx.background_spawn(async move { + io_fut.await.log_err(); + }); + + let child_status = cx.background_spawn(async move { + let result = match child.status().await { + Err(e) => Err(anyhow!(e)), + Ok(result) if result.success() => Ok(()), + Ok(result) => { + if let Some(AgentServerVersion::Unsupported { + error_message, + upgrade_message, + upgrade_command, + }) = this.version(&command).await.log_err() + { + Err(anyhow!(LoadError::Unsupported { + error_message, + upgrade_message, + upgrade_command + })) + } else { + Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) + } + } + }; + drop(io_task); + result + }); + + AcpThread::new(connection, title, Some(child_status), project.clone(), cx) + }) + }) + } +} + +pub async fn find_bin_in_path( + bin_name: &'static str, + project: &Entity, + cx: &mut AsyncApp, +) -> Option { + let (env_task, root_dir) = project + .update(cx, |project, cx| { + let worktree = project.visible_worktrees(cx).next(); + match worktree { + Some(worktree) => { + let env_task = project.environment().update(cx, |env, cx| { + env.get_worktree_environment(worktree.clone(), cx) + }); + + let path = worktree.read(cx).abs_path(); + (env_task, path) + } + None => { + let path: Arc = paths::home_dir().as_path().into(); + let env_task = project.environment().update(cx, |env, cx| { + env.get_directory_environment(path.clone(), cx) + }); + (env_task, path) + } + } + }) + .log_err()?; + + cx.background_executor() + .spawn(async move { + let which_result = if cfg!(windows) { + which::which(bin_name) + } else { + let env = env_task.await.unwrap_or_default(); + let shell_path = env.get("PATH").cloned(); + which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref()) + }; + + if let Err(which::Error::CannotFindBinaryPath) = which_result { + return None; + } + + which_result.log_err() + }) + .await +} diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 72466fe8e7..d4feceb0b6 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -16,7 +16,7 @@ doctest = false test-support = ["gpui/test-support", "language/test-support"] [dependencies] -acp.workspace = true +acp_thread.workspace = true agent.workspace = true agentic-coding-protocol.workspace = true agent_settings.workspace = true diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 7ab395815f..765f4fe6c0 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,3 +1,4 @@ +use agent_servers::AgentServer; use std::cell::RefCell; use std::collections::BTreeMap; use std::path::Path; @@ -35,7 +36,7 @@ use util::ResultExt; use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; -use ::acp::{ +use ::acp_thread::{ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff, LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallId, ToolCallStatus, @@ -49,6 +50,7 @@ use crate::{AgentDiffPane, Follow, KeepAll, OpenAgentDiff, RejectAll}; const RESPONSE_PADDING_X: Pixels = px(19.); pub struct AcpThreadView { + agent: Rc, workspace: WeakEntity, project: Entity, thread_state: ThreadState, @@ -80,8 +82,15 @@ enum ThreadState { }, } +struct AlwaysAllowOption { + id: &'static str, + label: SharedString, + outcome: acp::ToolCallConfirmationOutcome, +} + impl AcpThreadView { pub fn new( + agent: Rc, workspace: WeakEntity, project: Entity, message_history: Rc>>, @@ -158,9 +167,10 @@ impl AcpThreadView { ); Self { + agent: agent.clone(), workspace: workspace.clone(), project: project.clone(), - thread_state: Self::initial_state(workspace, project, window, cx), + thread_state: Self::initial_state(agent, workspace, project, window, cx), message_editor, message_set_from_history: false, _message_editor_subscription: message_editor_subscription, @@ -177,6 +187,7 @@ impl AcpThreadView { } fn initial_state( + agent: Rc, workspace: WeakEntity, project: Entity, window: &mut Window, @@ -189,9 +200,9 @@ impl AcpThreadView { .map(|worktree| worktree.read(cx).abs_path()) .unwrap_or_else(|| paths::home_dir().as_path().into()); + let task = agent.new_thread(&root_dir, &project, cx); let load_task = cx.spawn_in(window, async move |this, cx| { - let thread = match AcpThread::spawn(agent_servers::Gemini, &root_dir, project, cx).await - { + let thread = match task.await { Ok(thread) => thread, Err(err) => { this.update(cx, |this, cx| { @@ -410,6 +421,33 @@ impl AcpThreadView { ); } + fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context) { + if let Some(thread) = self.thread() { + AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err(); + } + } + + fn open_edited_buffer( + &mut self, + buffer: &Entity, + window: &mut Window, + cx: &mut Context, + ) { + let Some(thread) = self.thread() else { + return; + }; + + let Some(diff) = + AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err() + else { + return; + }; + + diff.update(cx, |diff, cx| { + diff.move_to_path(PathKey::for_buffer(&buffer, cx), window, cx) + }) + } + fn set_draft_message( message_editor: Entity, mention_set: Arc>, @@ -485,33 +523,6 @@ impl AcpThreadView { true } - fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context) { - if let Some(thread) = self.thread() { - AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err(); - } - } - - fn open_edited_buffer( - &mut self, - buffer: &Entity, - window: &mut Window, - cx: &mut Context, - ) { - let Some(thread) = self.thread() else { - return; - }; - - let Some(diff) = - AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err() - else { - return; - }; - - diff.update(cx, |diff, cx| { - diff.move_to_path(PathKey::for_buffer(&buffer, cx), window, cx) - }) - } - fn handle_thread_event( &mut self, thread: &Entity, @@ -608,6 +619,7 @@ impl AcpThreadView { let authenticate = thread.read(cx).authenticate(); self.auth_task = Some(cx.spawn_in(window, { let project = self.project.clone(); + let agent = self.agent.clone(); async move |this, cx| { let result = authenticate.await; @@ -617,8 +629,13 @@ impl AcpThreadView { Markdown::new(format!("Error: {err}").into(), None, None, cx) })) } else { - this.thread_state = - Self::initial_state(this.workspace.clone(), project.clone(), window, cx) + this.thread_state = Self::initial_state( + agent, + this.workspace.clone(), + project.clone(), + window, + cx, + ) } this.auth_task.take() }) @@ -1047,14 +1064,6 @@ impl AcpThreadView { ) -> AnyElement { let confirmation_container = v_flex().mt_1().py_1p5(); - let button_container = h_flex() - .pt_1p5() - .px_1p5() - .gap_1() - .justify_end() - .border_t_1() - .border_color(self.tool_card_border_color(cx)); - match confirmation { ToolCallConfirmation::Edit { description } => confirmation_container .child( @@ -1068,60 +1077,15 @@ impl AcpThreadView { })), ) .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child( - button_container - .child( - Button::new(("always_allow", tool_call_id.0), "Always Allow Edits") - .icon(IconName::CheckDouble) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::AlwaysAllow, - cx, - ); - } - })), - ) - .child( - Button::new(("allow", tool_call_id.0), "Allow") - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Allow, - cx, - ); - } - })), - ) - .child( - Button::new(("reject", tool_call_id.0), "Reject") - .icon(IconName::X) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Error) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Reject, - cx, - ); - } - })), - ), - ) + .child(self.render_confirmation_buttons( + &[AlwaysAllowOption { + id: "always_allow", + label: "Always Allow Edits".into(), + outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, + }], + tool_call_id, + cx, + )) .into_any(), ToolCallConfirmation::Execute { command, @@ -1140,66 +1104,15 @@ impl AcpThreadView { }), )) .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child( - button_container - .child( - Button::new( - ("always_allow", tool_call_id.0), - format!("Always Allow {root_command}"), - ) - .icon(IconName::CheckDouble) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::AlwaysAllow, - cx, - ); - } - })), - ) - .child( - Button::new(("allow", tool_call_id.0), "Allow") - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Allow, - cx, - ); - } - })), - ) - .child( - Button::new(("reject", tool_call_id.0), "Reject") - .icon(IconName::X) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Error) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Reject, - cx, - ); - } - })), - ), - ) + .child(self.render_confirmation_buttons( + &[AlwaysAllowOption { + id: "always_allow", + label: format!("Always Allow {root_command}").into(), + outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, + }], + tool_call_id, + cx, + )) .into_any(), ToolCallConfirmation::Mcp { server_name, @@ -1220,87 +1133,22 @@ impl AcpThreadView { })), ) .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child( - button_container - .child( - Button::new( - ("always_allow_server", tool_call_id.0), - format!("Always Allow {server_name}"), - ) - .icon(IconName::CheckDouble) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer, - cx, - ); - } - })), - ) - .child( - Button::new( - ("always_allow_tool", tool_call_id.0), - format!("Always Allow {tool_display_name}"), - ) - .icon(IconName::CheckDouble) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::AlwaysAllowTool, - cx, - ); - } - })), - ) - .child( - Button::new(("allow", tool_call_id.0), "Allow") - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Allow, - cx, - ); - } - })), - ) - .child( - Button::new(("reject", tool_call_id.0), "Reject") - .icon(IconName::X) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Error) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Reject, - cx, - ); - } - })), - ), - ) + .child(self.render_confirmation_buttons( + &[ + AlwaysAllowOption { + id: "always_allow_server", + label: format!("Always Allow {server_name}").into(), + outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer, + }, + AlwaysAllowOption { + id: "always_allow_tool", + label: format!("Always Allow {tool_display_name}").into(), + outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowTool, + }, + ], + tool_call_id, + cx, + )) .into_any(), ToolCallConfirmation::Fetch { description, urls } => confirmation_container .child( @@ -1328,63 +1176,15 @@ impl AcpThreadView { })), ) .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child( - button_container - .child( - Button::new(("always_allow", tool_call_id.0), "Always Allow") - .icon(IconName::CheckDouble) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::AlwaysAllow, - cx, - ); - } - })), - ) - .child( - Button::new(("allow", tool_call_id.0), "Allow") - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Allow, - cx, - ); - } - })), - ) - .child( - Button::new(("reject", tool_call_id.0), "Reject") - .icon(IconName::X) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Error) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Reject, - cx, - ); - } - })), - ), - ) + .child(self.render_confirmation_buttons( + &[AlwaysAllowOption { + id: "always_allow", + label: "Always Allow".into(), + outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, + }], + tool_call_id, + cx, + )) .into_any(), ToolCallConfirmation::Other { description } => confirmation_container .child(v_flex().px_2().pb_1p5().child(self.render_markdown( @@ -1392,67 +1192,87 @@ impl AcpThreadView { default_markdown_style(false, window, cx), ))) .children(content.map(|content| self.render_tool_call_content(content, window, cx))) - .child( - button_container - .child( - Button::new(("always_allow", tool_call_id.0), "Always Allow") - .icon(IconName::CheckDouble) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::AlwaysAllow, - cx, - ); - } - })), - ) - .child( - Button::new(("allow", tool_call_id.0), "Allow") - .icon(IconName::Check) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Success) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Allow, - cx, - ); - } - })), - ) - .child( - Button::new(("reject", tool_call_id.0), "Reject") - .icon(IconName::X) - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon_color(Color::Error) - .label_size(LabelSize::Small) - .on_click(cx.listener({ - let id = tool_call_id; - move |this, _, _, cx| { - this.authorize_tool_call( - id, - acp::ToolCallConfirmationOutcome::Reject, - cx, - ); - } - })), - ), - ) + .child(self.render_confirmation_buttons( + &[AlwaysAllowOption { + id: "always_allow", + label: "Always Allow".into(), + outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow, + }], + tool_call_id, + cx, + )) .into_any(), } } + fn render_confirmation_buttons( + &self, + always_allow_options: &[AlwaysAllowOption], + tool_call_id: ToolCallId, + cx: &Context, + ) -> Div { + h_flex() + .pt_1p5() + .px_1p5() + .gap_1() + .justify_end() + .border_t_1() + .border_color(self.tool_card_border_color(cx)) + .when(self.agent.supports_always_allow(), |this| { + this.children(always_allow_options.into_iter().map(|always_allow_option| { + let outcome = always_allow_option.outcome; + Button::new( + (always_allow_option.id, tool_call_id.0), + always_allow_option.label.clone(), + ) + .icon(IconName::CheckDouble) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon_color(Color::Success) + .on_click(cx.listener({ + let id = tool_call_id; + move |this, _, _, cx| { + this.authorize_tool_call(id, outcome, cx); + } + })) + })) + }) + .child( + Button::new(("allow", tool_call_id.0), "Allow") + .icon(IconName::Check) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon_color(Color::Success) + .on_click(cx.listener({ + let id = tool_call_id; + move |this, _, _, cx| { + this.authorize_tool_call( + id, + acp::ToolCallConfirmationOutcome::Allow, + cx, + ); + } + })), + ) + .child( + Button::new(("reject", tool_call_id.0), "Reject") + .icon(IconName::X) + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon_color(Color::Error) + .on_click(cx.listener({ + let id = tool_call_id; + move |this, _, _, cx| { + this.authorize_tool_call( + id, + acp::ToolCallConfirmationOutcome::Reject, + cx, + ); + } + })), + ) + } + fn render_diff_editor(&self, multibuffer: &Entity) -> AnyElement { v_flex() .h_full() @@ -1466,15 +1286,15 @@ impl AcpThreadView { .into_any() } - fn render_gemini_logo(&self) -> AnyElement { - Icon::new(IconName::AiGemini) + fn render_agent_logo(&self) -> AnyElement { + Icon::new(self.agent.logo()) .color(Color::Muted) .size(IconSize::XLarge) .into_any_element() } - fn render_error_gemini_logo(&self) -> AnyElement { - let logo = Icon::new(IconName::AiGemini) + fn render_error_agent_logo(&self) -> AnyElement { + let logo = Icon::new(self.agent.logo()) .color(Color::Muted) .size(IconSize::XLarge) .into_any_element(); @@ -1493,49 +1313,50 @@ impl AcpThreadView { .into_any_element() } - fn render_empty_state(&self, loading: bool, cx: &App) -> AnyElement { + fn render_empty_state(&self, cx: &App) -> AnyElement { + let loading = matches!(&self.thread_state, ThreadState::Loading { .. }); + v_flex() .size_full() .items_center() .justify_center() - .child( - if loading { - h_flex() - .justify_center() - .child(self.render_gemini_logo()) - .with_animation( - "pulsating_icon", - Animation::new(Duration::from_secs(2)) - .repeat() - .with_easing(pulsating_between(0.4, 1.0)), - |icon, delta| icon.opacity(delta), - ).into_any() - } else { - self.render_gemini_logo().into_any_element() - } - ) - .child( + .child(if loading { h_flex() - .mt_4() - .mb_1() .justify_center() - .child(Headline::new(if loading { - "Connecting to Gemini…" - } else { - "Welcome to Gemini" - }).size(HeadlineSize::Medium)), - ) + .child(self.render_agent_logo()) + .with_animation( + "pulsating_icon", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.4, 1.0)), + |icon, delta| icon.opacity(delta), + ) + .into_any() + } else { + self.render_agent_logo().into_any_element() + }) + .child(h_flex().mt_4().mb_1().justify_center().child(if loading { + div() + .child(LoadingLabel::new("").size(LabelSize::Large)) + .into_any_element() + } else { + Headline::new(self.agent.empty_state_headline()) + .size(HeadlineSize::Medium) + .into_any_element() + })) .child( div() .max_w_1_2() .text_sm() .text_center() - .map(|this| if loading { - this.invisible() - } else { - this.text_color(cx.theme().colors().text_muted) + .map(|this| { + if loading { + this.invisible() + } else { + this.text_color(cx.theme().colors().text_muted) + } }) - .child("Ask questions, edit files, run commands.\nBe specific for the best results.") + .child(self.agent.empty_state_message()), ) .into_any() } @@ -1544,7 +1365,7 @@ impl AcpThreadView { v_flex() .items_center() .justify_center() - .child(self.render_error_gemini_logo()) + .child(self.render_error_agent_logo()) .child( h_flex() .mt_4() @@ -1559,7 +1380,7 @@ impl AcpThreadView { let mut container = v_flex() .items_center() .justify_center() - .child(self.render_error_gemini_logo()) + .child(self.render_error_agent_logo()) .child( v_flex() .mt_4() @@ -1575,43 +1396,47 @@ impl AcpThreadView { ), ); - if matches!(e, LoadError::Unsupported { .. }) { - container = - container.child(Button::new("upgrade", "Upgrade Gemini to Latest").on_click( - cx.listener(|this, _, window, cx| { - this.workspace - .update(cx, |workspace, cx| { - let project = workspace.project().read(cx); - let cwd = project.first_project_directory(cx); - let shell = project.terminal_settings(&cwd, cx).shell.clone(); - let command = - "npm install -g @google/gemini-cli@latest".to_string(); - let spawn_in_terminal = task::SpawnInTerminal { - id: task::TaskId("install".to_string()), - full_label: command.clone(), - label: command.clone(), - command: Some(command.clone()), - args: Vec::new(), - command_label: command.clone(), - cwd, - env: Default::default(), - use_new_terminal: true, - allow_concurrent_runs: true, - reveal: Default::default(), - reveal_target: Default::default(), - hide: Default::default(), - shell, - show_summary: true, - show_command: true, - show_rerun: false, - }; - workspace - .spawn_in_terminal(spawn_in_terminal, window, cx) - .detach(); - }) - .ok(); - }), - )); + if let LoadError::Unsupported { + upgrade_message, + upgrade_command, + .. + } = &e + { + let upgrade_message = upgrade_message.clone(); + let upgrade_command = upgrade_command.clone(); + container = container.child(Button::new("upgrade", upgrade_message).on_click( + cx.listener(move |this, _, window, cx| { + this.workspace + .update(cx, |workspace, cx| { + let project = workspace.project().read(cx); + let cwd = project.first_project_directory(cx); + let shell = project.terminal_settings(&cwd, cx).shell.clone(); + let spawn_in_terminal = task::SpawnInTerminal { + id: task::TaskId("install".to_string()), + full_label: upgrade_command.clone(), + label: upgrade_command.clone(), + command: Some(upgrade_command.clone()), + args: Vec::new(), + command_label: upgrade_command.clone(), + cwd, + env: Default::default(), + use_new_terminal: true, + allow_concurrent_runs: true, + reveal: Default::default(), + reveal_target: Default::default(), + hide: Default::default(), + shell, + show_summary: true, + show_command: true, + show_rerun: false, + }; + workspace + .spawn_in_terminal(spawn_in_terminal, window, cx) + .detach(); + }) + .ok(); + }), + )); } container.into_any() @@ -2267,20 +2092,23 @@ impl Render for AcpThreadView { .on_action(cx.listener(Self::next_history_message)) .on_action(cx.listener(Self::open_agent_diff)) .child(match &self.thread_state { - ThreadState::Unauthenticated { .. } => v_flex() - .p_2() - .flex_1() - .items_center() - .justify_center() - .child(self.render_pending_auth_state()) - .child(h_flex().mt_1p5().justify_center().child( - Button::new("sign-in", "Sign in to Gemini").on_click( - cx.listener(|this, _, window, cx| this.authenticate(window, cx)), - ), - )), - ThreadState::Loading { .. } => { - v_flex().flex_1().child(self.render_empty_state(true, cx)) + ThreadState::Unauthenticated { .. } => { + v_flex() + .p_2() + .flex_1() + .items_center() + .justify_center() + .child(self.render_pending_auth_state()) + .child( + h_flex().mt_1p5().justify_center().child( + Button::new("sign-in", format!("Sign in to {}", self.agent.name())) + .on_click(cx.listener(|this, _, window, cx| { + this.authenticate(window, cx) + })), + ), + ) } + ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::LoadError(e) => v_flex() .p_2() .flex_1() @@ -2321,7 +2149,7 @@ impl Render for AcpThreadView { }) .children(self.render_edits_bar(&thread, window, cx)) } else { - this.child(self.render_empty_state(false, cx)) + this.child(self.render_empty_state(cx)) } }), }) diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 000e270322..e69664ce88 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1,5 +1,5 @@ use crate::{Keep, KeepAll, OpenAgentDiff, Reject, RejectAll}; -use acp::{AcpThread, AcpThreadEvent}; +use acp_thread::{AcpThread, AcpThreadEvent}; use agent::{Thread, ThreadEvent, ThreadSummary}; use agent_settings::AgentSettings; use anyhow::Result; @@ -81,7 +81,7 @@ impl AgentDiffThread { match self { AgentDiffThread::Native(thread) => thread.read(cx).is_generating(), AgentDiffThread::AcpThread(thread) => { - thread.read(cx).status() == acp::ThreadStatus::Generating + thread.read(cx).status() == acp_thread::ThreadStatus::Generating } } } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 2caa9dab42..895a499502 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -5,10 +5,11 @@ use std::rc::Rc; use std::sync::Arc; use std::time::Duration; +use agent_servers::AgentServer; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use serde::{Deserialize, Serialize}; -use crate::NewAcpThread; +use crate::NewExternalAgentThread; use crate::agent_diff::AgentDiffThread; use crate::language_model_selector::ToggleModelSelector; use crate::{ @@ -114,10 +115,12 @@ pub fn init(cx: &mut App) { panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx)); } }) - .register_action(|workspace, _: &NewAcpThread, window, cx| { + .register_action(|workspace, action: &NewExternalAgentThread, window, cx| { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); - panel.update(cx, |panel, cx| panel.new_gemini_thread(window, cx)); + panel.update(cx, |panel, cx| { + panel.new_external_thread(action.agent, window, cx) + }); } }) .register_action(|workspace, action: &OpenRulesLibrary, window, cx| { @@ -136,7 +139,7 @@ pub fn init(cx: &mut App) { let thread = thread.read(cx).thread().clone(); AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx); } - ActiveView::AcpThread { .. } + ActiveView::ExternalAgentThread { .. } | ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} @@ -200,7 +203,7 @@ enum ActiveView { message_editor: Entity, _subscriptions: Vec, }, - AcpThread { + ExternalAgentThread { thread_view: Entity, }, TextThread { @@ -222,9 +225,9 @@ enum WhichFontSize { impl ActiveView { pub fn which_font_size_used(&self) -> WhichFontSize { match self { - ActiveView::Thread { .. } | ActiveView::AcpThread { .. } | ActiveView::History => { - WhichFontSize::AgentFont - } + ActiveView::Thread { .. } + | ActiveView::ExternalAgentThread { .. } + | ActiveView::History => WhichFontSize::AgentFont, ActiveView::TextThread { .. } => WhichFontSize::BufferFont, ActiveView::Configuration => WhichFontSize::None, } @@ -255,7 +258,7 @@ impl ActiveView { thread.scroll_to_bottom(cx); }); } - ActiveView::AcpThread { .. } => {} + ActiveView::ExternalAgentThread { .. } => {} ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} @@ -674,7 +677,7 @@ impl AgentPanel { .clone() .update(cx, |thread, cx| thread.get_or_init_configured_model(cx)); } - ActiveView::AcpThread { .. } + ActiveView::ExternalAgentThread { .. } | ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} @@ -757,7 +760,7 @@ impl AgentPanel { ActiveView::Thread { thread, .. } => { thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx)); } - ActiveView::AcpThread { thread_view, .. } => { + ActiveView::ExternalAgentThread { thread_view, .. } => { thread_view.update(cx, |thread_element, cx| thread_element.cancel(cx)); } ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} @@ -767,7 +770,7 @@ impl AgentPanel { fn active_message_editor(&self) -> Option<&Entity> { match &self.active_view { ActiveView::Thread { message_editor, .. } => Some(message_editor), - ActiveView::AcpThread { .. } + ActiveView::ExternalAgentThread { .. } | ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None, @@ -889,35 +892,77 @@ impl AgentPanel { context_editor.focus_handle(cx).focus(window); } - fn new_gemini_thread(&mut self, window: &mut Window, cx: &mut Context) { + fn new_external_thread( + &mut self, + agent_choice: Option, + window: &mut Window, + cx: &mut Context, + ) { let workspace = self.workspace.clone(); let project = self.project.clone(); let message_history = self.acp_message_history.clone(); + const LAST_USED_EXTERNAL_AGENT_KEY: &str = "agent_panel__last_used_external_agent"; + + #[derive(Default, Serialize, Deserialize)] + struct LastUsedExternalAgent { + agent: crate::ExternalAgent, + } + cx.spawn_in(window, async move |this, cx| { - let thread_view = cx.new_window_entity(|window, cx| { - crate::acp::AcpThreadView::new( - workspace.clone(), - project, - message_history, - window, - cx, - ) - })?; + let server: Rc = match agent_choice { + Some(agent) => { + cx.background_spawn(async move { + if let Some(serialized) = + serde_json::to_string(&LastUsedExternalAgent { agent }).log_err() + { + KEY_VALUE_STORE + .write_kvp(LAST_USED_EXTERNAL_AGENT_KEY.to_string(), serialized) + .await + .log_err(); + } + }) + .detach(); + + agent.server() + } + None => cx + .background_spawn(async move { + KEY_VALUE_STORE.read_kvp(LAST_USED_EXTERNAL_AGENT_KEY) + }) + .await + .log_err() + .flatten() + .and_then(|value| { + serde_json::from_str::(&value).log_err() + }) + .unwrap_or_default() + .agent + .server(), + }; + this.update_in(cx, |this, window, cx| { + let thread_view = cx.new(|cx| { + crate::acp::AcpThreadView::new( + server, + workspace.clone(), + project, + message_history, + window, + cx, + ) + }); + this.set_active_view( - ActiveView::AcpThread { + ActiveView::ExternalAgentThread { thread_view: thread_view.clone(), }, window, cx, ); }) - .log_err(); - - anyhow::Ok(()) }) - .detach(); + .detach_and_log_err(cx); } fn deploy_rules_library( @@ -1084,7 +1129,7 @@ impl AgentPanel { ActiveView::Thread { message_editor, .. } => { message_editor.focus_handle(cx).focus(window); } - ActiveView::AcpThread { thread_view } => { + ActiveView::ExternalAgentThread { thread_view } => { thread_view.focus_handle(cx).focus(window); } ActiveView::TextThread { context_editor, .. } => { @@ -1211,7 +1256,7 @@ impl AgentPanel { }) .log_err(); } - ActiveView::AcpThread { .. } + ActiveView::ExternalAgentThread { .. } | ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} @@ -1267,7 +1312,7 @@ impl AgentPanel { ) .detach_and_log_err(cx); } - ActiveView::AcpThread { thread_view } => { + ActiveView::ExternalAgentThread { thread_view } => { thread_view .update(cx, |thread_view, cx| { thread_view.open_thread_as_markdown(workspace, window, cx) @@ -1428,7 +1473,7 @@ impl AgentPanel { } }) } - ActiveView::AcpThread { .. } => {} + ActiveView::ExternalAgentThread { .. } => {} ActiveView::History | ActiveView::Configuration => {} } @@ -1517,7 +1562,7 @@ impl Focusable for AgentPanel { fn focus_handle(&self, cx: &App) -> FocusHandle { match &self.active_view { ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx), - ActiveView::AcpThread { thread_view, .. } => thread_view.focus_handle(cx), + ActiveView::ExternalAgentThread { thread_view, .. } => thread_view.focus_handle(cx), ActiveView::History => self.history.focus_handle(cx), ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx), ActiveView::Configuration => { @@ -1674,9 +1719,11 @@ impl AgentPanel { .into_any_element(), } } - ActiveView::AcpThread { thread_view } => Label::new(thread_view.read(cx).title(cx)) - .truncate() - .into_any_element(), + ActiveView::ExternalAgentThread { thread_view } => { + Label::new(thread_view.read(cx).title(cx)) + .truncate() + .into_any_element() + } ActiveView::TextThread { title_editor, context_editor, @@ -1811,7 +1858,7 @@ impl AgentPanel { let active_thread = match &self.active_view { ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()), - ActiveView::AcpThread { .. } + ActiveView::ExternalAgentThread { .. } | ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None, @@ -1849,7 +1896,20 @@ impl AgentPanel { .when(cx.has_flag::(), |this| { this.separator() .header("External Agents") - .action("New Gemini Thread", NewAcpThread.boxed_clone()) + .action( + "New Gemini Thread", + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::Gemini), + } + .boxed_clone(), + ) + .action( + "New Claude Code Thread", + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::ClaudeCode), + } + .boxed_clone(), + ) }); menu })) @@ -2090,7 +2150,11 @@ impl AgentPanel { Some(element.into_any_element()) } - _ => None, + ActiveView::ExternalAgentThread { .. } + | ActiveView::History + | ActiveView::Configuration => { + return None; + } } } @@ -2119,7 +2183,7 @@ impl AgentPanel { return false; } } - ActiveView::AcpThread { .. } => { + ActiveView::ExternalAgentThread { .. } => { return false; } ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => { @@ -2706,7 +2770,7 @@ impl AgentPanel { ) -> Option { let active_thread = match &self.active_view { ActiveView::Thread { thread, .. } => thread, - ActiveView::AcpThread { .. } => { + ActiveView::ExternalAgentThread { .. } => { return None; } ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => { @@ -3055,7 +3119,7 @@ impl AgentPanel { .detach(); }); } - ActiveView::AcpThread { .. } => { + ActiveView::ExternalAgentThread { .. } => { unimplemented!() } ActiveView::TextThread { context_editor, .. } => { @@ -3077,7 +3141,7 @@ impl AgentPanel { let mut key_context = KeyContext::new_with_defaults(); key_context.add("AgentPanel"); match &self.active_view { - ActiveView::AcpThread { .. } => key_context.add("acp_thread"), + ActiveView::ExternalAgentThread { .. } => key_context.add("external_agent_thread"), ActiveView::TextThread { .. } => key_context.add("prompt_editor"), ActiveView::Thread { .. } | ActiveView::History | ActiveView::Configuration => {} } @@ -3133,7 +3197,7 @@ impl Render for AgentPanel { }); this.continue_conversation(window, cx); } - ActiveView::AcpThread { .. } => {} + ActiveView::ExternalAgentThread { .. } => {} ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {} @@ -3175,7 +3239,7 @@ impl Render for AgentPanel { }) .child(h_flex().child(message_editor.clone())) .child(self.render_drag_target(cx)), - ActiveView::AcpThread { thread_view, .. } => parent + ActiveView::ExternalAgentThread { thread_view, .. } => parent .relative() .child(thread_view.clone()) .child(self.render_drag_target(cx)), diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 3170ec4a26..7f69e8f66e 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -25,6 +25,7 @@ mod thread_history; mod tool_compatibility; mod ui; +use std::rc::Rc; use std::sync::Arc; use agent::{Thread, ThreadId}; @@ -40,7 +41,7 @@ use language_model::{ }; use prompt_store::PromptBuilder; use schemars::JsonSchema; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use settings::{Settings as _, SettingsStore}; pub use crate::active_thread::ActiveThread; @@ -57,8 +58,6 @@ actions!( [ /// Creates a new text-based conversation thread. NewTextThread, - /// Creates a new external agent conversation thread. - NewAcpThread, /// Toggles the context picker interface for adding files, symbols, or other context. ToggleContextPicker, /// Toggles the navigation menu for switching between threads and views. @@ -133,6 +132,32 @@ pub struct NewThread { from_thread_id: Option, } +/// Creates a new external agent conversation thread. +#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)] +#[action(namespace = agent)] +#[serde(deny_unknown_fields)] +pub struct NewExternalAgentThread { + /// Which agent to use for the conversation. + agent: Option, +} + +#[derive(Default, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +enum ExternalAgent { + #[default] + Gemini, + ClaudeCode, +} + +impl ExternalAgent { + pub fn server(&self) -> Rc { + match self { + ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), + ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), + } + } +} + /// Opens the profile management interface for configuring agent tools and settings. #[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)] #[action(namespace = agent)] diff --git a/crates/context_server/Cargo.toml b/crates/context_server/Cargo.toml index 96bb9e071f..5e4f8369c4 100644 --- a/crates/context_server/Cargo.toml +++ b/crates/context_server/Cargo.toml @@ -21,12 +21,14 @@ collections.workspace = true futures.workspace = true gpui.workspace = true log.workspace = true +net.workspace = true parking_lot.workspace = true postage.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true smol.workspace = true +tempfile.workspace = true url = { workspace = true, features = ["serde"] } util.workspace = true workspace-hack.workspace = true diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 83d815432d..6b24d9b136 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -70,12 +70,12 @@ fn is_null_value(value: &T) -> bool { } #[derive(Serialize, Deserialize)] -struct Request<'a, T> { - jsonrpc: &'static str, - id: RequestId, - method: &'a str, +pub struct Request<'a, T> { + pub jsonrpc: &'static str, + pub id: RequestId, + pub method: &'a str, #[serde(skip_serializing_if = "is_null_value")] - params: T, + pub params: T, } #[derive(Serialize, Deserialize)] @@ -88,18 +88,18 @@ struct AnyResponse<'a> { result: Option<&'a RawValue>, } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] #[allow(dead_code)] -struct Response { - jsonrpc: &'static str, - id: RequestId, +pub(crate) struct Response { + pub jsonrpc: &'static str, + pub id: RequestId, #[serde(flatten)] - value: CspResult, + pub value: CspResult, } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] #[serde(rename_all = "snake_case")] -enum CspResult { +pub(crate) enum CspResult { #[serde(rename = "result")] Ok(Option), #[allow(dead_code)] @@ -123,8 +123,9 @@ struct AnyNotification<'a> { } #[derive(Debug, Serialize, Deserialize)] -struct Error { - message: String, +pub(crate) struct Error { + pub message: String, + pub code: i32, } #[derive(Debug, Clone, Deserialize)] diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index 905435fcce..807b17f1ca 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -1,4 +1,5 @@ pub mod client; +pub mod listener; pub mod protocol; #[cfg(any(test, feature = "test-support"))] pub mod test; diff --git a/crates/context_server/src/listener.rs b/crates/context_server/src/listener.rs new file mode 100644 index 0000000000..9295ad979c --- /dev/null +++ b/crates/context_server/src/listener.rs @@ -0,0 +1,236 @@ +use ::serde::{Deserialize, Serialize}; +use anyhow::{Context as _, Result}; +use collections::HashMap; +use futures::{ + AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, + channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded}, + io::BufReader, + select_biased, +}; +use gpui::{App, AppContext, AsyncApp, Task}; +use net::async_net::{UnixListener, UnixStream}; +use serde_json::{json, value::RawValue}; +use smol::stream::StreamExt; +use std::{ + cell::RefCell, + path::{Path, PathBuf}, + rc::Rc, +}; +use util::ResultExt; + +use crate::{ + client::{CspResult, RequestId, Response}, + types::Request, +}; + +pub struct McpServer { + socket_path: PathBuf, + handlers: Rc>>, + _server_task: Task<()>, +} + +type McpHandler = Box>, &App) -> Task>; + +impl McpServer { + pub fn new(cx: &AsyncApp) -> Task> { + let task = cx.background_spawn(async move { + let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?; + let socket_path = temp_dir.path().join("mcp.sock"); + let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?; + + anyhow::Ok((temp_dir, socket_path, listener)) + }); + + cx.spawn(async move |cx| { + let (temp_dir, socket_path, listener) = task.await?; + let handlers = Rc::new(RefCell::new(HashMap::default())); + let server_task = cx.spawn({ + let handlers = handlers.clone(); + async move |cx| { + while let Ok((stream, _)) = listener.accept().await { + Self::serve_connection(stream, handlers.clone(), cx); + } + drop(temp_dir) + } + }); + Ok(Self { + socket_path, + _server_task: server_task, + handlers: handlers.clone(), + }) + }) + } + + pub fn handle_request( + &mut self, + f: impl Fn(R::Params, &App) -> Task> + 'static, + ) { + let f = Box::new(f); + self.handlers.borrow_mut().insert( + R::METHOD, + Box::new(move |req_id, opt_params, cx| { + let result = match opt_params { + Some(params) => serde_json::from_str(params.get()), + None => serde_json::from_value(serde_json::Value::Null), + }; + + let params: R::Params = match result { + Ok(params) => params, + Err(e) => { + return Task::ready( + serde_json::to_string(&Response:: { + jsonrpc: "2.0", + id: req_id, + value: CspResult::Error(Some(crate::client::Error { + message: format!("{e}"), + code: -32700, + })), + }) + .unwrap(), + ); + } + }; + let task = f(params, cx); + cx.background_spawn(async move { + match task.await { + Ok(result) => serde_json::to_string(&Response { + jsonrpc: "2.0", + id: req_id, + value: CspResult::Ok(Some(result)), + }) + .unwrap(), + Err(e) => serde_json::to_string(&Response { + jsonrpc: "2.0", + id: req_id, + value: CspResult::Error::(Some(crate::client::Error { + message: format!("{e}"), + code: -32603, + })), + }) + .unwrap(), + } + }) + }), + ); + } + + pub fn socket_path(&self) -> &Path { + &self.socket_path + } + + fn serve_connection( + stream: UnixStream, + handlers: Rc>>, + cx: &mut AsyncApp, + ) { + let (read, write) = smol::io::split(stream); + let (incoming_tx, mut incoming_rx) = unbounded(); + let (outgoing_tx, outgoing_rx) = unbounded(); + + cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read)) + .detach(); + + cx.spawn(async move |cx| { + while let Some(request) = incoming_rx.next().await { + let Some(request_id) = request.id.clone() else { + continue; + }; + if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) { + let outgoing_tx = outgoing_tx.clone(); + + if let Some(task) = cx + .update(|cx| handler(request_id, request.params, cx)) + .log_err() + { + cx.spawn(async move |_| { + let response = task.await; + outgoing_tx.unbounded_send(response).ok(); + }) + .detach(); + } + } else { + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response::<()> { + jsonrpc: "2.0", + id: request.id.unwrap(), + value: CspResult::Error(Some(crate::client::Error { + message: format!("unhandled method {}", request.method), + code: -32601, + })), + }) + .unwrap(), + ) + .ok(); + } + } + }) + .detach(); + } + + async fn handle_io( + mut outgoing_rx: UnboundedReceiver, + incoming_tx: UnboundedSender, + mut outgoing_bytes: impl Unpin + AsyncWrite, + incoming_bytes: impl Unpin + AsyncRead, + ) -> Result<()> { + let mut output_reader = BufReader::new(incoming_bytes); + let mut incoming_line = String::new(); + loop { + select_biased! { + message = outgoing_rx.next().fuse() => { + if let Some(message) = message { + log::trace!("send: {}", &message); + outgoing_bytes.write_all(message.as_bytes()).await?; + outgoing_bytes.write_all(&[b'\n']).await?; + } else { + break; + } + } + bytes_read = output_reader.read_line(&mut incoming_line).fuse() => { + if bytes_read? == 0 { + break + } + log::trace!("recv: {}", &incoming_line); + match serde_json::from_str(&incoming_line) { + Ok(message) => { + incoming_tx.unbounded_send(message).log_err(); + } + Err(error) => { + outgoing_bytes.write_all(serde_json::to_string(&json!({ + "jsonrpc": "2.0", + "error": json!({ + "code": -32603, + "message": format!("Failed to parse: {error}"), + }), + }))?.as_bytes()).await?; + outgoing_bytes.write_all(&[b'\n']).await?; + log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}"); + } + } + incoming_line.clear(); + } + } + } + Ok(()) + } +} + +#[derive(Serialize, Deserialize)] +struct RawRequest { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + method: String, + #[serde(skip_serializing_if = "Option::is_none")] + params: Option>, +} + +#[derive(Serialize, Deserialize)] +struct RawResponse { + jsonrpc: &'static str, + id: RequestId, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option>, +} diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 8e3daf9e22..4a6fdcabd3 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -153,7 +153,7 @@ pub struct InitializeParams { pub struct CallToolParams { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option>, + pub arguments: Option, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option>, } diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index b29a8b78e6..6834d56215 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -11,6 +11,7 @@ pub enum IconName { Ai, AiAnthropic, AiBedrock, + AiClaude, AiDeepSeek, AiEdit, AiGemini, diff --git a/crates/nc/Cargo.toml b/crates/nc/Cargo.toml new file mode 100644 index 0000000000..46ef2d3c62 --- /dev/null +++ b/crates/nc/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "nc" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/nc.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +futures.workspace = true +net.workspace = true +smol.workspace = true +workspace-hack.workspace = true diff --git a/crates/nc/LICENSE-GPL b/crates/nc/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/nc/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/nc/src/nc.rs b/crates/nc/src/nc.rs new file mode 100644 index 0000000000..fccb4d726c --- /dev/null +++ b/crates/nc/src/nc.rs @@ -0,0 +1,56 @@ +use anyhow::Result; + +#[cfg(windows)] +pub fn main(_socket: &str) -> Result<()> { + // It looks like we can't get an async stdio stream on Windows from smol. + // + // We decided to merge this with a panic on Windows since this is only used + // by the experimental Claude Code Agent Server. + // + // We're tracking this internally, and we will address it before shipping the integration. + panic!("--nc isn't yet supported on Windows"); +} + +/// The main function for when Zed is running in netcat mode +#[cfg(not(windows))] +pub fn main(socket: &str) -> Result<()> { + use futures::{AsyncReadExt as _, AsyncWriteExt as _, FutureExt as _, io::BufReader, select}; + use net::async_net::UnixStream; + use smol::{Async, io::AsyncBufReadExt}; + + smol::block_on(async { + let socket_stream = UnixStream::connect(socket).await?; + let (socket_read, mut socket_write) = socket_stream.split(); + let mut socket_reader = BufReader::new(socket_read); + + let mut stdout = Async::new(std::io::stdout())?; + let stdin = Async::new(std::io::stdin())?; + let mut stdin_reader = BufReader::new(stdin); + + let mut socket_line = Vec::new(); + let mut stdin_line = Vec::new(); + + loop { + select! { + bytes_read = socket_reader.read_until(b'\n', &mut socket_line).fuse() => { + if bytes_read? == 0 { + break + } + stdout.write_all(&socket_line).await?; + stdout.flush().await?; + socket_line.clear(); + } + bytes_read = stdin_reader.read_until(b'\n', &mut stdin_line).fuse() => { + if bytes_read? == 0 { + break + } + socket_write.write_all(&stdin_line).await?; + socket_write.flush().await?; + stdin_line.clear(); + } + } + } + + anyhow::Ok(()) + }) +} diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index ae96a48b53..bbceb3f101 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -95,6 +95,7 @@ svg_preview.workspace = true menu.workspace = true migrator.workspace = true mimalloc = { version = "0.1", optional = true } +nc.workspace = true nix = { workspace = true, features = ["pthread", "signal"] } node_runtime.workspace = true notifications.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 5eb96f21a4..89b9fad6bf 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -175,6 +175,17 @@ pub fn main() { return; } + // `zed --nc` Makes zed operate in nc/netcat mode for use with MCP + if let Some(socket) = &args.nc { + match nc::main(socket) { + Ok(()) => return, + Err(err) => { + eprintln!("Error: {}", err); + process::exit(1); + } + } + } + // `zed --printenv` Outputs environment variables as JSON to stdout if args.printenv { util::shell_env::print_env(); @@ -1168,6 +1179,11 @@ struct Args { #[arg(long, hide = true)] askpass: Option, + /// Used for the MCP Server, to remove the need for netcat as a dependency, + /// by having Zed act like netcat communicating over a Unix socket. + #[arg(long, hide = true)] + nc: Option, + /// Run zed in the foreground, only used on Windows, to match the behavior on macOS. #[arg(long)] #[cfg(target_os = "windows")]