From dab0b3509d2394cd073cbd8b4260454c1f1d5278 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Thu, 17 Jul 2025 14:34:38 -0300 Subject: [PATCH] Unify agent server settings and extract e2e tests out (#34642) Release Notes: - N/A --- crates/agent_servers/Cargo.toml | 2 +- crates/agent_servers/src/agent_servers.rs | 125 +++++- crates/agent_servers/src/claude.rs | 75 ++-- crates/agent_servers/src/e2e_tests.rs | 368 +++++++++++++++ crates/agent_servers/src/gemini.rs | 422 +----------------- crates/agent_servers/src/settings.rs | 1 + .../agent_servers/src/stdio_agent_server.rs | 54 +-- 7 files changed, 547 insertions(+), 500 deletions(-) create mode 100644 crates/agent_servers/src/e2e_tests.rs diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index d65235aee3..2d68148264 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -7,7 +7,7 @@ license = "GPL-3.0-or-later" [features] test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support"] -gemini = [] +e2e = [] [lints] workspace = true diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index ebebeca511..6d9c77f296 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -3,6 +3,9 @@ mod gemini; mod settings; mod stdio_agent_server; +#[cfg(test)] +mod e2e_tests; + pub use claude::*; pub use gemini::*; pub use settings::*; @@ -11,34 +14,20 @@ pub use stdio_agent_server::*; use acp_thread::AcpThread; use anyhow::Result; use collections::HashMap; -use gpui::{App, Entity, SharedString, Task}; +use gpui::{App, AsyncApp, Entity, SharedString, Task}; use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::path::{Path, PathBuf}; +use std::{ + path::{Path, PathBuf}, + sync::Arc, +}; +use util::ResultExt as _; pub fn init(cx: &mut App) { settings::init(cx); } -#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)] -pub struct AgentServerCommand { - #[serde(rename = "command")] - pub path: PathBuf, - #[serde(default)] - pub args: Vec, - pub env: Option>, -} - -pub enum AgentServerVersion { - Supported, - Unsupported { - error_message: SharedString, - upgrade_message: SharedString, - upgrade_command: String, - }, -} - pub trait AgentServer: Send { fn logo(&self) -> ui::IconName; fn name(&self) -> &'static str; @@ -78,3 +67,99 @@ impl std::fmt::Debug for AgentServerCommand { .finish() } } + +pub enum AgentServerVersion { + Supported, + Unsupported { + error_message: SharedString, + upgrade_message: SharedString, + upgrade_command: String, + }, +} + +#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)] +pub struct AgentServerCommand { + #[serde(rename = "command")] + pub path: PathBuf, + #[serde(default)] + pub args: Vec, + pub env: Option>, +} + +impl AgentServerCommand { + pub(crate) async fn resolve( + path_bin_name: &'static str, + extra_args: &[&'static str], + settings: Option, + project: &Entity, + cx: &mut AsyncApp, + ) -> Option { + if let Some(agent_settings) = settings { + return Some(Self { + path: agent_settings.command.path, + args: agent_settings + .command + .args + .into_iter() + .chain(extra_args.iter().map(|arg| arg.to_string())) + .collect(), + env: agent_settings.command.env, + }); + } else { + find_bin_in_path(path_bin_name, project, cx) + .await + .map(|path| Self { + path, + args: extra_args.iter().map(|arg| arg.to_string()).collect(), + env: None, + }) + } + } +} + +async fn find_bin_in_path( + bin_name: &'static str, + project: &Entity, + cx: &mut AsyncApp, +) -> Option { + let (env_task, root_dir) = project + .update(cx, |project, cx| { + let worktree = project.visible_worktrees(cx).next(); + match worktree { + Some(worktree) => { + let env_task = project.environment().update(cx, |env, cx| { + env.get_worktree_environment(worktree.clone(), cx) + }); + + let path = worktree.read(cx).abs_path(); + (env_task, path) + } + None => { + let path: Arc = paths::home_dir().as_path().into(); + let env_task = project.environment().update(cx, |env, cx| { + env.get_directory_environment(path.clone(), cx) + }); + (env_task, path) + } + } + }) + .log_err()?; + + cx.background_executor() + .spawn(async move { + let which_result = if cfg!(windows) { + which::which(bin_name) + } else { + let env = env_task.await.unwrap_or_default(); + let shell_path = env.get("PATH").cloned(); + which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref()) + }; + + if let Err(which::Error::CannotFindBinaryPath) = which_result { + return None; + } + + which_result.log_err() + }) + .await +} diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 897158dc57..5760a96d8c 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -3,6 +3,7 @@ mod tools; use collections::HashMap; use project::Project; +use settings::SettingsStore; use std::cell::RefCell; use std::fmt::Display; use std::path::Path; @@ -12,7 +13,7 @@ use agentic_coding_protocol::{ self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion, StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams, }; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use futures::channel::oneshot; use futures::future::LocalBoxFuture; use futures::{AsyncBufReadExt, AsyncWriteExt}; @@ -28,7 +29,7 @@ use util::ResultExt; use crate::claude::mcp_server::ClaudeMcpServer; use crate::claude::tools::ClaudeTool; -use crate::{AgentServer, find_bin_in_path}; +use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection}; #[derive(Clone)] @@ -87,31 +88,41 @@ impl AgentServer for ClaudeCode { .await?; mcp_config_file.flush().await?; - let command = find_bin_in_path("claude", &project, cx) - .await - .context("Failed to find claude binary")?; + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).claude.clone() + })?; - let mut child = util::command::new_smol_command(&command) - .args([ - "--input-format", - "stream-json", - "--output-format", - "stream-json", - "--print", - "--verbose", - "--mcp-config", - mcp_config_path.to_string_lossy().as_ref(), - "--permission-prompt-tool", - &format!( - "mcp__{}__{}", - mcp_server::SERVER_NAME, - mcp_server::PERMISSION_TOOL - ), - "--allowedTools", - "mcp__zed__Read,mcp__zed__Edit", - "--disallowedTools", - "Read,Edit", - ]) + let Some(command) = + AgentServerCommand::resolve("claude", &[], settings, &project, cx).await + else { + anyhow::bail!("Failed to find claude binary"); + }; + + let mut child = util::command::new_smol_command(&command.path) + .args( + [ + "--input-format", + "stream-json", + "--output-format", + "stream-json", + "--print", + "--verbose", + "--mcp-config", + mcp_config_path.to_string_lossy().as_ref(), + "--permission-prompt-tool", + &format!( + "mcp__{}__{}", + mcp_server::SERVER_NAME, + mcp_server::PERMISSION_TOOL + ), + "--allowedTools", + "mcp__zed__Read,mcp__zed__Edit", + "--disallowedTools", + "Read,Edit", + ] + .into_iter() + .chain(command.args.iter().map(|arg| arg.as_str())), + ) .current_dir(root_dir) .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) @@ -562,10 +573,20 @@ struct McpServerConfig { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; use serde_json::json; + // crate::common_e2e_tests!(ClaudeCode); + + pub fn local_command() -> AgentServerCommand { + AgentServerCommand { + path: "claude".into(), + args: vec![], + env: None, + } + } + #[test] fn test_deserialize_content_untagged_text() { let json = json!("Hello, world!"); diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs new file mode 100644 index 0000000000..923c6cdd6f --- /dev/null +++ b/crates/agent_servers/src/e2e_tests.rs @@ -0,0 +1,368 @@ +use std::{path::Path, sync::Arc, time::Duration}; + +use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings}; +use acp_thread::{ + AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus, +}; +use agentic_coding_protocol as acp; +use futures::{FutureExt, StreamExt, channel::mpsc, select}; +use gpui::{Entity, TestAppContext}; +use indoc::indoc; +use project::{FakeFs, Project}; +use serde_json::json; +use settings::{Settings, SettingsStore}; +use util::path; + +pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let fs = init_test(cx).await; + let project = Project::test(fs, [], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + + thread + .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) + .await + .unwrap(); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.entries().len(), 2); + assert!(matches!( + thread.entries()[0], + AgentThreadEntry::UserMessage(_) + )); + assert!(matches!( + thread.entries()[1], + AgentThreadEntry::AssistantMessage(_) + )); + }); +} + +pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let _fs = init_test(cx).await; + + let tempdir = tempfile::tempdir().unwrap(); + std::fs::write( + tempdir.path().join("foo.rs"), + indoc! {" + fn main() { + println!(\"Hello, world!\"); + } + "}, + ) + .expect("failed to write file"); + let project = Project::example([tempdir.path()], &mut cx.to_async()).await; + let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await; + thread + .update(cx, |thread, cx| { + thread.send( + acp::SendUserMessageParams { + chunks: vec![ + acp::UserMessageChunk::Text { + text: "Read the file ".into(), + }, + acp::UserMessageChunk::Path { + path: Path::new("foo.rs").into(), + }, + acp::UserMessageChunk::Text { + text: " and tell me what the content of the println! is".into(), + }, + ], + }, + cx, + ) + }) + .await + .unwrap(); + + thread.read_with(cx, |thread, cx| { + assert_eq!(thread.entries().len(), 3); + assert!(matches!( + thread.entries()[0], + AgentThreadEntry::UserMessage(_) + )); + assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_))); + let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else { + panic!("Expected AssistantMessage") + }; + assert!( + assistant_message.to_markdown(cx).contains("Hello, world!"), + "unexpected assistant message: {:?}", + assistant_message.to_markdown(cx) + ); + }); +} + +pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let fs = init_test(cx).await; + fs.insert_tree( + path!("/private/tmp"), + json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}), + ) + .await; + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + + thread + .update(cx, |thread, cx| { + thread.send_raw( + "Read the '/private/tmp/foo' file and tell me what you see.", + cx, + ) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, _cx| { + assert!(matches!( + &thread.entries()[2], + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { .. }, + .. + }) + )); + + assert!(matches!( + thread.entries()[3], + AgentThreadEntry::AssistantMessage(_) + )); + }); +} + +pub async fn test_tool_call_with_confirmation( + server: impl AgentServer + 'static, + cx: &mut TestAppContext, +) { + let fs = init_test(cx).await; + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + let full_turn = thread.update(cx, |thread, cx| { + thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx) + }); + + run_until_first_tool_call(&thread, cx).await; + + let tool_call_id = thread.read_with(cx, |thread, _cx| { + let AgentThreadEntry::ToolCall(ToolCall { + id, + status: + ToolCallStatus::WaitingForConfirmation { + confirmation: ToolCallConfirmation::Execute { root_command, .. }, + .. + }, + .. + }) = &thread.entries()[2] + else { + panic!(); + }; + + assert_eq!(root_command, "echo"); + + *id + }); + + thread.update(cx, |thread, cx| { + thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx); + + assert!(matches!( + &thread.entries()[2], + AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Allowed { .. }, + .. + }) + )); + }); + + full_turn.await.unwrap(); + + thread.read_with(cx, |thread, cx| { + let AgentThreadEntry::ToolCall(ToolCall { + content: Some(ToolCallContent::Markdown { markdown }), + status: ToolCallStatus::Allowed { .. }, + .. + }) = &thread.entries()[2] + else { + panic!(); + }; + + markdown.read_with(cx, |md, _cx| { + assert!( + md.source().contains("Hello, world!"), + r#"Expected '{}' to contain "Hello, world!""#, + md.source() + ); + }); + }); +} + +pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) { + let fs = init_test(cx).await; + + let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; + let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; + let full_turn = thread.update(cx, |thread, cx| { + thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx) + }); + + let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await; + + thread.read_with(cx, |thread, _cx| { + let AgentThreadEntry::ToolCall(ToolCall { + id, + status: + ToolCallStatus::WaitingForConfirmation { + confirmation: ToolCallConfirmation::Execute { root_command, .. }, + .. + }, + .. + }) = &thread.entries()[first_tool_call_ix] + else { + panic!("{:?}", thread.entries()[1]); + }; + + assert_eq!(root_command, "echo"); + + *id + }); + + thread + .update(cx, |thread, cx| thread.cancel(cx)) + .await + .unwrap(); + full_turn.await.unwrap(); + thread.read_with(cx, |thread, _| { + let AgentThreadEntry::ToolCall(ToolCall { + status: ToolCallStatus::Canceled, + .. + }) = &thread.entries()[first_tool_call_ix] + else { + panic!(); + }; + }); + + thread + .update(cx, |thread, cx| { + thread.send_raw(r#"Stop running and say goodbye to me."#, cx) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, _| { + assert!(matches!( + &thread.entries().last().unwrap(), + AgentThreadEntry::AssistantMessage(..), + )) + }); +} + +#[macro_export] +macro_rules! common_e2e_tests { + ($server:expr) => { + mod common_e2e { + use super::*; + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn basic(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_basic($server, cx).await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn path_mentions(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_path_mentions($server, cx).await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn tool_call(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_tool_call($server, cx).await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await; + } + + #[::gpui::test] + #[cfg_attr(not(feature = "e2e"), ignore)] + async fn cancel(cx: &mut ::gpui::TestAppContext) { + $crate::e2e_tests::test_cancel($server, cx).await; + } + } + }; +} + +// Helpers + +pub async fn init_test(cx: &mut TestAppContext) -> Arc { + env_logger::try_init().ok(); + + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + crate::settings::init(cx); + + crate::AllAgentServersSettings::override_global( + AllAgentServersSettings { + claude: Some(AgentServerSettings { + command: crate::claude::tests::local_command(), + }), + gemini: Some(AgentServerSettings { + command: crate::gemini::tests::local_command(), + }), + }, + cx, + ); + }); + + cx.executor().allow_parking(); + + FakeFs::new(cx.executor()) +} + +pub async fn new_test_thread( + server: impl AgentServer + 'static, + project: Entity, + current_dir: impl AsRef, + cx: &mut TestAppContext, +) -> Entity { + let thread = cx + .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx)) + .await + .unwrap(); + + thread + .update(cx, |thread, _| thread.initialize()) + .await + .unwrap(); + thread +} + +pub async fn run_until_first_tool_call( + thread: &Entity, + cx: &mut TestAppContext, +) -> usize { + let (mut tx, mut rx) = mpsc::channel::(1); + + let subscription = cx.update(|cx| { + cx.subscribe(thread, move |thread, _, cx| { + for (ix, entry) in thread.read(cx).entries().iter().enumerate() { + if matches!(entry, AgentThreadEntry::ToolCall(_)) { + return tx.try_send(ix).unwrap(); + } + } + }) + }); + + select! { + // We have to use a smol timer here because + // cx.background_executor().timer isn't real in the test context + _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => { + panic!("Timeout waiting for tool call") + } + ix = rx.next().fuse() => { + drop(subscription); + ix.unwrap() + } + } +} diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index bf1d13429e..8ad147cbff 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -1,4 +1,4 @@ -use crate::stdio_agent_server::{StdioAgentServer, find_bin_in_path}; +use crate::stdio_agent_server::StdioAgentServer; use crate::{AgentServerCommand, AgentServerVersion}; use anyhow::{Context as _, Result}; use gpui::{AsyncApp, Entity}; @@ -38,35 +38,15 @@ impl StdioAgentServer for Gemini { project: &Entity, cx: &mut AsyncApp, ) -> Result { - let custom_command = cx.read_global(|settings: &SettingsStore, _| { - let settings = settings.get::(None); - settings - .gemini - .as_ref() - .map(|gemini_settings| AgentServerCommand { - path: gemini_settings.command.path.clone(), - args: gemini_settings - .command - .args - .iter() - .cloned() - .chain(std::iter::once(ACP_ARG.into())) - .collect(), - env: gemini_settings.command.env.clone(), - }) + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).gemini.clone() })?; - if let Some(custom_command) = custom_command { - return Ok(custom_command); - } - - if let Some(path) = find_bin_in_path("gemini", project, cx).await { - return Ok(AgentServerCommand { - path, - args: vec![ACP_ARG.into()], - env: None, - }); - } + if let Some(command) = + AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await + { + return Ok(command); + }; let (fs, node_runtime) = project.update(cx, |project, _| { (project.fs().clone(), project.node_runtime().cloned()) @@ -121,381 +101,23 @@ impl StdioAgentServer for Gemini { } #[cfg(test)] -mod test { - use std::{path::Path, time::Duration}; +pub(crate) mod tests { + use super::*; + use crate::AgentServerCommand; + use std::path::Path; - use acp_thread::{ - AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, - ToolCallStatus, - }; - use agentic_coding_protocol as acp; - use anyhow::Result; - use futures::{FutureExt, StreamExt, channel::mpsc, select}; - use gpui::{AsyncApp, Entity, TestAppContext}; - use indoc::indoc; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use util::path; + crate::common_e2e_tests!(Gemini); - use crate::{AgentServer, AgentServerCommand, AgentServerVersion, StdioAgentServer}; + pub fn local_command() -> AgentServerCommand { + let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../../../gemini-cli/packages/cli") + .to_string_lossy() + .to_string(); - pub async fn gemini_acp_thread( - project: Entity, - current_dir: impl AsRef, - cx: &mut TestAppContext, - ) -> Entity { - #[derive(Clone)] - struct DevGemini; - - impl StdioAgentServer for DevGemini { - async fn command( - &self, - _project: &Entity, - _cx: &mut AsyncApp, - ) -> Result { - let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) - .join("../../../gemini-cli/packages/cli") - .to_string_lossy() - .to_string(); - - Ok(AgentServerCommand { - path: "node".into(), - args: vec![cli_path, "--experimental-acp".into()], - env: None, - }) - } - - async fn version(&self, _command: &AgentServerCommand) -> Result { - Ok(AgentServerVersion::Supported) - } - - fn logo(&self) -> ui::IconName { - ui::IconName::AiGemini - } - - fn name(&self) -> &'static str { - "test" - } - - fn empty_state_headline(&self) -> &'static str { - "test" - } - - fn empty_state_message(&self) -> &'static str { - "test" - } - - fn supports_always_allow(&self) -> bool { - true - } - } - - let thread = cx - .update(|cx| AgentServer::new_thread(&DevGemini, current_dir.as_ref(), &project, cx)) - .await - .unwrap(); - - thread - .update(cx, |thread, _| thread.initialize()) - .await - .unwrap(); - thread - } - - fn init_test(cx: &mut TestAppContext) { - env_logger::try_init().ok(); - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - Project::init_settings(cx); - language::init(cx); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_basic(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [], cx).await; - let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; - thread - .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx)) - .await - .unwrap(); - - thread.read_with(cx, |thread, _| { - assert_eq!(thread.entries().len(), 2); - assert!(matches!( - thread.entries()[0], - AgentThreadEntry::UserMessage(_) - )); - assert!(matches!( - thread.entries()[1], - AgentThreadEntry::AssistantMessage(_) - )); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_path_mentions(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - let tempdir = tempfile::tempdir().unwrap(); - std::fs::write( - tempdir.path().join("foo.rs"), - indoc! {" - fn main() { - println!(\"Hello, world!\"); - } - "}, - ) - .expect("failed to write file"); - let project = Project::example([tempdir.path()], &mut cx.to_async()).await; - let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await; - thread - .update(cx, |thread, cx| { - thread.send( - acp::SendUserMessageParams { - chunks: vec![ - acp::UserMessageChunk::Text { - text: "Read the file ".into(), - }, - acp::UserMessageChunk::Path { - path: Path::new("foo.rs").into(), - }, - acp::UserMessageChunk::Text { - text: " and tell me what the content of the println! is".into(), - }, - ], - }, - cx, - ) - }) - .await - .unwrap(); - - thread.read_with(cx, |thread, cx| { - assert_eq!(thread.entries().len(), 3); - assert!(matches!( - thread.entries()[0], - AgentThreadEntry::UserMessage(_) - )); - assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_))); - let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else { - panic!("Expected AssistantMessage") - }; - assert!( - assistant_message.to_markdown(cx).contains("Hello, world!"), - "unexpected assistant message: {:?}", - assistant_message.to_markdown(cx) - ); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_tool_call(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/private/tmp"), - json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}), - ) - .await; - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; - let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; - thread - .update(cx, |thread, cx| { - thread.send_raw( - "Read the '/private/tmp/foo' file and tell me what you see.", - cx, - ) - }) - .await - .unwrap(); - thread.read_with(cx, |thread, _cx| { - assert!(matches!( - &thread.entries()[2], - AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { .. }, - .. - }) - )); - - assert!(matches!( - thread.entries()[3], - AgentThreadEntry::AssistantMessage(_) - )); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; - let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; - let full_turn = thread.update(cx, |thread, cx| { - thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx) - }); - - run_until_first_tool_call(&thread, cx).await; - - let tool_call_id = thread.read_with(cx, |thread, _cx| { - let AgentThreadEntry::ToolCall(ToolCall { - id, - status: - ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::Execute { root_command, .. }, - .. - }, - .. - }) = &thread.entries()[2] - else { - panic!(); - }; - - assert_eq!(root_command, "echo"); - - *id - }); - - thread.update(cx, |thread, cx| { - thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx); - - assert!(matches!( - &thread.entries()[2], - AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Allowed { .. }, - .. - }) - )); - }); - - full_turn.await.unwrap(); - - thread.read_with(cx, |thread, cx| { - let AgentThreadEntry::ToolCall(ToolCall { - content: Some(ToolCallContent::Markdown { markdown }), - status: ToolCallStatus::Allowed { .. }, - .. - }) = &thread.entries()[2] - else { - panic!(); - }; - - markdown.read_with(cx, |md, _cx| { - assert!( - md.source().contains("Hello, world!"), - r#"Expected '{}' to contain "Hello, world!""#, - md.source() - ); - }); - }); - } - - #[gpui::test] - #[cfg_attr(not(feature = "gemini"), ignore)] - async fn test_gemini_cancel(cx: &mut TestAppContext) { - init_test(cx); - - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; - let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await; - let full_turn = thread.update(cx, |thread, cx| { - thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx) - }); - - let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await; - - thread.read_with(cx, |thread, _cx| { - let AgentThreadEntry::ToolCall(ToolCall { - id, - status: - ToolCallStatus::WaitingForConfirmation { - confirmation: ToolCallConfirmation::Execute { root_command, .. }, - .. - }, - .. - }) = &thread.entries()[first_tool_call_ix] - else { - panic!("{:?}", thread.entries()[1]); - }; - - assert_eq!(root_command, "echo"); - - *id - }); - - thread - .update(cx, |thread, cx| thread.cancel(cx)) - .await - .unwrap(); - full_turn.await.unwrap(); - thread.read_with(cx, |thread, _| { - let AgentThreadEntry::ToolCall(ToolCall { - status: ToolCallStatus::Canceled, - .. - }) = &thread.entries()[first_tool_call_ix] - else { - panic!(); - }; - }); - - thread - .update(cx, |thread, cx| { - thread.send_raw(r#"Stop running and say goodbye to me."#, cx) - }) - .await - .unwrap(); - thread.read_with(cx, |thread, _| { - assert!(matches!( - &thread.entries().last().unwrap(), - AgentThreadEntry::AssistantMessage(..), - )) - }); - } - - async fn run_until_first_tool_call( - thread: &Entity, - cx: &mut TestAppContext, - ) -> usize { - let (mut tx, mut rx) = mpsc::channel::(1); - - let subscription = cx.update(|cx| { - cx.subscribe(thread, move |thread, _, cx| { - for (ix, entry) in thread.read(cx).entries().iter().enumerate() { - if matches!(entry, AgentThreadEntry::ToolCall(_)) { - return tx.try_send(ix).unwrap(); - } - } - }) - }); - - select! { - _ = cx.executor().timer(Duration::from_secs(10)).fuse() => { - panic!("Timeout waiting for tool call") - } - ix = rx.next().fuse() => { - drop(subscription); - ix.unwrap() - } + AgentServerCommand { + path: "node".into(), + args: vec![cli_path, ACP_ARG.into()], + env: None, } } } diff --git a/crates/agent_servers/src/settings.rs b/crates/agent_servers/src/settings.rs index 8e6914352b..29dcf5eb8c 100644 --- a/crates/agent_servers/src/settings.rs +++ b/crates/agent_servers/src/settings.rs @@ -12,6 +12,7 @@ pub fn init(cx: &mut App) { #[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug)] pub struct AllAgentServersSettings { pub gemini: Option, + pub claude: Option, } #[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)] diff --git a/crates/agent_servers/src/stdio_agent_server.rs b/crates/agent_servers/src/stdio_agent_server.rs index d78506022d..e60dd39de4 100644 --- a/crates/agent_servers/src/stdio_agent_server.rs +++ b/crates/agent_servers/src/stdio_agent_server.rs @@ -4,11 +4,8 @@ use agentic_coding_protocol as acp; use anyhow::{Result, anyhow}; use gpui::{App, AsyncApp, Entity, Task, prelude::*}; use project::Project; -use std::{ - path::{Path, PathBuf}, - sync::Arc, -}; -use util::{ResultExt, paths}; +use std::path::Path; +use util::ResultExt; pub trait StdioAgentServer: Send + Clone { fn logo(&self) -> ui::IconName; @@ -120,50 +117,3 @@ impl AgentServer for T { }) } } - -pub async fn find_bin_in_path( - bin_name: &'static str, - project: &Entity, - cx: &mut AsyncApp, -) -> Option { - let (env_task, root_dir) = project - .update(cx, |project, cx| { - let worktree = project.visible_worktrees(cx).next(); - match worktree { - Some(worktree) => { - let env_task = project.environment().update(cx, |env, cx| { - env.get_worktree_environment(worktree.clone(), cx) - }); - - let path = worktree.read(cx).abs_path(); - (env_task, path) - } - None => { - let path: Arc = paths::home_dir().as_path().into(); - let env_task = project.environment().update(cx, |env, cx| { - env.get_directory_environment(path.clone(), cx) - }); - (env_task, path) - } - } - }) - .log_err()?; - - cx.background_executor() - .spawn(async move { - let which_result = if cfg!(windows) { - which::which(bin_name) - } else { - let env = env_task.await.unwrap_or_default(); - let shell_path = env.get("PATH").cloned(); - which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref()) - }; - - if let Err(which::Error::CannotFindBinaryPath) = which_result { - return None; - } - - which_result.log_err() - }) - .await -}