Unify agent server settings and extract e2e tests out (#34642)
Release Notes: - N/A
This commit is contained in:
parent
0f72d7ed52
commit
dab0b3509d
7 changed files with 547 additions and 500 deletions
|
@ -7,7 +7,7 @@ license = "GPL-3.0-or-later"
|
|||
|
||||
[features]
|
||||
test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support"]
|
||||
gemini = []
|
||||
e2e = []
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
|
|
@ -3,6 +3,9 @@ mod gemini;
|
|||
mod settings;
|
||||
mod stdio_agent_server;
|
||||
|
||||
#[cfg(test)]
|
||||
mod e2e_tests;
|
||||
|
||||
pub use claude::*;
|
||||
pub use gemini::*;
|
||||
pub use settings::*;
|
||||
|
@ -11,34 +14,20 @@ pub use stdio_agent_server::*;
|
|||
use acp_thread::AcpThread;
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use gpui::{App, AsyncApp, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
settings::init(cx);
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
|
||||
pub struct AgentServerCommand {
|
||||
#[serde(rename = "command")]
|
||||
pub path: PathBuf,
|
||||
#[serde(default)]
|
||||
pub args: Vec<String>,
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
pub enum AgentServerVersion {
|
||||
Supported,
|
||||
Unsupported {
|
||||
error_message: SharedString,
|
||||
upgrade_message: SharedString,
|
||||
upgrade_command: String,
|
||||
},
|
||||
}
|
||||
|
||||
pub trait AgentServer: Send {
|
||||
fn logo(&self) -> ui::IconName;
|
||||
fn name(&self) -> &'static str;
|
||||
|
@ -78,3 +67,99 @@ impl std::fmt::Debug for AgentServerCommand {
|
|||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub enum AgentServerVersion {
|
||||
Supported,
|
||||
Unsupported {
|
||||
error_message: SharedString,
|
||||
upgrade_message: SharedString,
|
||||
upgrade_command: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
|
||||
pub struct AgentServerCommand {
|
||||
#[serde(rename = "command")]
|
||||
pub path: PathBuf,
|
||||
#[serde(default)]
|
||||
pub args: Vec<String>,
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl AgentServerCommand {
|
||||
pub(crate) async fn resolve(
|
||||
path_bin_name: &'static str,
|
||||
extra_args: &[&'static str],
|
||||
settings: Option<AgentServerSettings>,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Option<Self> {
|
||||
if let Some(agent_settings) = settings {
|
||||
return Some(Self {
|
||||
path: agent_settings.command.path,
|
||||
args: agent_settings
|
||||
.command
|
||||
.args
|
||||
.into_iter()
|
||||
.chain(extra_args.iter().map(|arg| arg.to_string()))
|
||||
.collect(),
|
||||
env: agent_settings.command.env,
|
||||
});
|
||||
} else {
|
||||
find_bin_in_path(path_bin_name, project, cx)
|
||||
.await
|
||||
.map(|path| Self {
|
||||
path,
|
||||
args: extra_args.iter().map(|arg| arg.to_string()).collect(),
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn find_bin_in_path(
|
||||
bin_name: &'static str,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Option<PathBuf> {
|
||||
let (env_task, root_dir) = project
|
||||
.update(cx, |project, cx| {
|
||||
let worktree = project.visible_worktrees(cx).next();
|
||||
match worktree {
|
||||
Some(worktree) => {
|
||||
let env_task = project.environment().update(cx, |env, cx| {
|
||||
env.get_worktree_environment(worktree.clone(), cx)
|
||||
});
|
||||
|
||||
let path = worktree.read(cx).abs_path();
|
||||
(env_task, path)
|
||||
}
|
||||
None => {
|
||||
let path: Arc<Path> = paths::home_dir().as_path().into();
|
||||
let env_task = project.environment().update(cx, |env, cx| {
|
||||
env.get_directory_environment(path.clone(), cx)
|
||||
});
|
||||
(env_task, path)
|
||||
}
|
||||
}
|
||||
})
|
||||
.log_err()?;
|
||||
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
let which_result = if cfg!(windows) {
|
||||
which::which(bin_name)
|
||||
} else {
|
||||
let env = env_task.await.unwrap_or_default();
|
||||
let shell_path = env.get("PATH").cloned();
|
||||
which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref())
|
||||
};
|
||||
|
||||
if let Err(which::Error::CannotFindBinaryPath) = which_result {
|
||||
return None;
|
||||
}
|
||||
|
||||
which_result.log_err()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ mod tools;
|
|||
|
||||
use collections::HashMap;
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use std::cell::RefCell;
|
||||
use std::fmt::Display;
|
||||
use std::path::Path;
|
||||
|
@ -12,7 +13,7 @@ use agentic_coding_protocol::{
|
|||
self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion,
|
||||
StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams,
|
||||
};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::{Result, anyhow};
|
||||
use futures::channel::oneshot;
|
||||
use futures::future::LocalBoxFuture;
|
||||
use futures::{AsyncBufReadExt, AsyncWriteExt};
|
||||
|
@ -28,7 +29,7 @@ use util::ResultExt;
|
|||
|
||||
use crate::claude::mcp_server::ClaudeMcpServer;
|
||||
use crate::claude::tools::ClaudeTool;
|
||||
use crate::{AgentServer, find_bin_in_path};
|
||||
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
|
||||
use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -87,12 +88,19 @@ impl AgentServer for ClaudeCode {
|
|||
.await?;
|
||||
mcp_config_file.flush().await?;
|
||||
|
||||
let command = find_bin_in_path("claude", &project, cx)
|
||||
.await
|
||||
.context("Failed to find claude binary")?;
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).claude.clone()
|
||||
})?;
|
||||
|
||||
let mut child = util::command::new_smol_command(&command)
|
||||
.args([
|
||||
let Some(command) =
|
||||
AgentServerCommand::resolve("claude", &[], settings, &project, cx).await
|
||||
else {
|
||||
anyhow::bail!("Failed to find claude binary");
|
||||
};
|
||||
|
||||
let mut child = util::command::new_smol_command(&command.path)
|
||||
.args(
|
||||
[
|
||||
"--input-format",
|
||||
"stream-json",
|
||||
"--output-format",
|
||||
|
@ -111,7 +119,10 @@ impl AgentServer for ClaudeCode {
|
|||
"mcp__zed__Read,mcp__zed__Edit",
|
||||
"--disallowedTools",
|
||||
"Read,Edit",
|
||||
])
|
||||
]
|
||||
.into_iter()
|
||||
.chain(command.args.iter().map(|arg| arg.as_str())),
|
||||
)
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
|
@ -562,10 +573,20 @@ struct McpServerConfig {
|
|||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
// crate::common_e2e_tests!(ClaudeCode);
|
||||
|
||||
pub fn local_command() -> AgentServerCommand {
|
||||
AgentServerCommand {
|
||||
path: "claude".into(),
|
||||
args: vec![],
|
||||
env: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_content_untagged_text() {
|
||||
let json = json!("Hello, world!");
|
||||
|
|
368
crates/agent_servers/src/e2e_tests.rs
Normal file
368
crates/agent_servers/src/e2e_tests.rs
Normal file
|
@ -0,0 +1,368 @@
|
|||
use std::{path::Path, sync::Arc, time::Duration};
|
||||
|
||||
use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
|
||||
use acp_thread::{
|
||||
AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus,
|
||||
};
|
||||
use agentic_coding_protocol as acp;
|
||||
use futures::{FutureExt, StreamExt, channel::mpsc, select};
|
||||
use gpui::{Entity, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use util::path;
|
||||
|
||||
pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx).await;
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.entries().len(), 2);
|
||||
assert!(matches!(
|
||||
thread.entries()[0],
|
||||
AgentThreadEntry::UserMessage(_)
|
||||
));
|
||||
assert!(matches!(
|
||||
thread.entries()[1],
|
||||
AgentThreadEntry::AssistantMessage(_)
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
|
||||
let _fs = init_test(cx).await;
|
||||
|
||||
let tempdir = tempfile::tempdir().unwrap();
|
||||
std::fs::write(
|
||||
tempdir.path().join("foo.rs"),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
println!(\"Hello, world!\");
|
||||
}
|
||||
"},
|
||||
)
|
||||
.expect("failed to write file");
|
||||
let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
|
||||
let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(
|
||||
acp::SendUserMessageParams {
|
||||
chunks: vec![
|
||||
acp::UserMessageChunk::Text {
|
||||
text: "Read the file ".into(),
|
||||
},
|
||||
acp::UserMessageChunk::Path {
|
||||
path: Path::new("foo.rs").into(),
|
||||
},
|
||||
acp::UserMessageChunk::Text {
|
||||
text: " and tell me what the content of the println! is".into(),
|
||||
},
|
||||
],
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(thread.entries().len(), 3);
|
||||
assert!(matches!(
|
||||
thread.entries()[0],
|
||||
AgentThreadEntry::UserMessage(_)
|
||||
));
|
||||
assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
|
||||
let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
|
||||
panic!("Expected AssistantMessage")
|
||||
};
|
||||
assert!(
|
||||
assistant_message.to_markdown(cx).contains("Hello, world!"),
|
||||
"unexpected assistant message: {:?}",
|
||||
assistant_message.to_markdown(cx)
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx).await;
|
||||
fs.insert_tree(
|
||||
path!("/private/tmp"),
|
||||
json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
"Read the '/private/tmp/foo' file and tell me what you see.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
assert!(matches!(
|
||||
&thread.entries()[2],
|
||||
AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
|
||||
assert!(matches!(
|
||||
thread.entries()[3],
|
||||
AgentThreadEntry::AssistantMessage(_)
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn test_tool_call_with_confirmation(
|
||||
server: impl AgentServer + 'static,
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
let fs = init_test(cx).await;
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
|
||||
});
|
||||
|
||||
run_until_first_tool_call(&thread, cx).await;
|
||||
|
||||
let tool_call_id = thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
id,
|
||||
status:
|
||||
ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation: ToolCallConfirmation::Execute { root_command, .. },
|
||||
..
|
||||
},
|
||||
..
|
||||
}) = &thread.entries()[2]
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
|
||||
assert_eq!(root_command, "echo");
|
||||
|
||||
*id
|
||||
});
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
|
||||
|
||||
assert!(matches!(
|
||||
&thread.entries()[2],
|
||||
AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
});
|
||||
|
||||
full_turn.await.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
content: Some(ToolCallContent::Markdown { markdown }),
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
}) = &thread.entries()[2]
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
|
||||
markdown.read_with(cx, |md, _cx| {
|
||||
assert!(
|
||||
md.source().contains("Hello, world!"),
|
||||
r#"Expected '{}' to contain "Hello, world!""#,
|
||||
md.source()
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx).await;
|
||||
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
|
||||
});
|
||||
|
||||
let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
|
||||
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
id,
|
||||
status:
|
||||
ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation: ToolCallConfirmation::Execute { root_command, .. },
|
||||
..
|
||||
},
|
||||
..
|
||||
}) = &thread.entries()[first_tool_call_ix]
|
||||
else {
|
||||
panic!("{:?}", thread.entries()[1]);
|
||||
};
|
||||
|
||||
assert_eq!(root_command, "echo");
|
||||
|
||||
*id
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.cancel(cx))
|
||||
.await
|
||||
.unwrap();
|
||||
full_turn.await.unwrap();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Canceled,
|
||||
..
|
||||
}) = &thread.entries()[first_tool_call_ix]
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(
|
||||
&thread.entries().last().unwrap(),
|
||||
AgentThreadEntry::AssistantMessage(..),
|
||||
))
|
||||
});
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! common_e2e_tests {
|
||||
($server:expr) => {
|
||||
mod common_e2e {
|
||||
use super::*;
|
||||
|
||||
#[::gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn basic(cx: &mut ::gpui::TestAppContext) {
|
||||
$crate::e2e_tests::test_basic($server, cx).await;
|
||||
}
|
||||
|
||||
#[::gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
|
||||
$crate::e2e_tests::test_path_mentions($server, cx).await;
|
||||
}
|
||||
|
||||
#[::gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn tool_call(cx: &mut ::gpui::TestAppContext) {
|
||||
$crate::e2e_tests::test_tool_call($server, cx).await;
|
||||
}
|
||||
|
||||
#[::gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
|
||||
$crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await;
|
||||
}
|
||||
|
||||
#[::gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn cancel(cx: &mut ::gpui::TestAppContext) {
|
||||
$crate::e2e_tests::test_cancel($server, cx).await;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Helpers
|
||||
|
||||
pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
|
||||
env_logger::try_init().ok();
|
||||
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
language::init(cx);
|
||||
crate::settings::init(cx);
|
||||
|
||||
crate::AllAgentServersSettings::override_global(
|
||||
AllAgentServersSettings {
|
||||
claude: Some(AgentServerSettings {
|
||||
command: crate::claude::tests::local_command(),
|
||||
}),
|
||||
gemini: Some(AgentServerSettings {
|
||||
command: crate::gemini::tests::local_command(),
|
||||
}),
|
||||
},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
FakeFs::new(cx.executor())
|
||||
}
|
||||
|
||||
pub async fn new_test_thread(
|
||||
server: impl AgentServer + 'static,
|
||||
project: Entity<Project>,
|
||||
current_dir: impl AsRef<Path>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> Entity<AcpThread> {
|
||||
let thread = cx
|
||||
.update(|cx| server.new_thread(current_dir.as_ref(), &project, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread
|
||||
.update(cx, |thread, _| thread.initialize())
|
||||
.await
|
||||
.unwrap();
|
||||
thread
|
||||
}
|
||||
|
||||
pub async fn run_until_first_tool_call(
|
||||
thread: &Entity<AcpThread>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> usize {
|
||||
let (mut tx, mut rx) = mpsc::channel::<usize>(1);
|
||||
|
||||
let subscription = cx.update(|cx| {
|
||||
cx.subscribe(thread, move |thread, _, cx| {
|
||||
for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
|
||||
if matches!(entry, AgentThreadEntry::ToolCall(_)) {
|
||||
return tx.try_send(ix).unwrap();
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
select! {
|
||||
// We have to use a smol timer here because
|
||||
// cx.background_executor().timer isn't real in the test context
|
||||
_ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
|
||||
panic!("Timeout waiting for tool call")
|
||||
}
|
||||
ix = rx.next().fuse() => {
|
||||
drop(subscription);
|
||||
ix.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use crate::stdio_agent_server::{StdioAgentServer, find_bin_in_path};
|
||||
use crate::stdio_agent_server::StdioAgentServer;
|
||||
use crate::{AgentServerCommand, AgentServerVersion};
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
|
@ -38,35 +38,15 @@ impl StdioAgentServer for Gemini {
|
|||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<AgentServerCommand> {
|
||||
let custom_command = cx.read_global(|settings: &SettingsStore, _| {
|
||||
let settings = settings.get::<AllAgentServersSettings>(None);
|
||||
settings
|
||||
.gemini
|
||||
.as_ref()
|
||||
.map(|gemini_settings| AgentServerCommand {
|
||||
path: gemini_settings.command.path.clone(),
|
||||
args: gemini_settings
|
||||
.command
|
||||
.args
|
||||
.iter()
|
||||
.cloned()
|
||||
.chain(std::iter::once(ACP_ARG.into()))
|
||||
.collect(),
|
||||
env: gemini_settings.command.env.clone(),
|
||||
})
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
||||
})?;
|
||||
|
||||
if let Some(custom_command) = custom_command {
|
||||
return Ok(custom_command);
|
||||
}
|
||||
|
||||
if let Some(path) = find_bin_in_path("gemini", project, cx).await {
|
||||
return Ok(AgentServerCommand {
|
||||
path,
|
||||
args: vec![ACP_ARG.into()],
|
||||
env: None,
|
||||
});
|
||||
}
|
||||
if let Some(command) =
|
||||
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
|
||||
{
|
||||
return Ok(command);
|
||||
};
|
||||
|
||||
let (fs, node_runtime) = project.update(cx, |project, _| {
|
||||
(project.fs().clone(), project.node_runtime().cloned())
|
||||
|
@ -121,381 +101,23 @@ impl StdioAgentServer for Gemini {
|
|||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::{path::Path, time::Duration};
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
use crate::AgentServerCommand;
|
||||
use std::path::Path;
|
||||
|
||||
use acp_thread::{
|
||||
AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent,
|
||||
ToolCallStatus,
|
||||
};
|
||||
use agentic_coding_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use futures::{FutureExt, StreamExt, channel::mpsc, select};
|
||||
use gpui::{AsyncApp, Entity, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
crate::common_e2e_tests!(Gemini);
|
||||
|
||||
use crate::{AgentServer, AgentServerCommand, AgentServerVersion, StdioAgentServer};
|
||||
|
||||
pub async fn gemini_acp_thread(
|
||||
project: Entity<Project>,
|
||||
current_dir: impl AsRef<Path>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> Entity<AcpThread> {
|
||||
#[derive(Clone)]
|
||||
struct DevGemini;
|
||||
|
||||
impl StdioAgentServer for DevGemini {
|
||||
async fn command(
|
||||
&self,
|
||||
_project: &Entity<Project>,
|
||||
_cx: &mut AsyncApp,
|
||||
) -> Result<AgentServerCommand> {
|
||||
pub fn local_command() -> AgentServerCommand {
|
||||
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../../../gemini-cli/packages/cli")
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
Ok(AgentServerCommand {
|
||||
AgentServerCommand {
|
||||
path: "node".into(),
|
||||
args: vec![cli_path, "--experimental-acp".into()],
|
||||
args: vec![cli_path, ACP_ARG.into()],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn version(&self, _command: &AgentServerCommand) -> Result<AgentServerVersion> {
|
||||
Ok(AgentServerVersion::Supported)
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
ui::IconName::AiGemini
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"test"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
"test"
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
"test"
|
||||
}
|
||||
|
||||
fn supports_always_allow(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
let thread = cx
|
||||
.update(|cx| AgentServer::new_thread(&DevGemini, current_dir.as_ref(), &project, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread
|
||||
.update(cx, |thread, _| thread.initialize())
|
||||
.await
|
||||
.unwrap();
|
||||
thread
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
env_logger::try_init().ok();
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
language::init(cx);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_basic(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.entries().len(), 2);
|
||||
assert!(matches!(
|
||||
thread.entries()[0],
|
||||
AgentThreadEntry::UserMessage(_)
|
||||
));
|
||||
assert!(matches!(
|
||||
thread.entries()[1],
|
||||
AgentThreadEntry::AssistantMessage(_)
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_path_mentions(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
let tempdir = tempfile::tempdir().unwrap();
|
||||
std::fs::write(
|
||||
tempdir.path().join("foo.rs"),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
println!(\"Hello, world!\");
|
||||
}
|
||||
"},
|
||||
)
|
||||
.expect("failed to write file");
|
||||
let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
|
||||
let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await;
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(
|
||||
acp::SendUserMessageParams {
|
||||
chunks: vec![
|
||||
acp::UserMessageChunk::Text {
|
||||
text: "Read the file ".into(),
|
||||
},
|
||||
acp::UserMessageChunk::Path {
|
||||
path: Path::new("foo.rs").into(),
|
||||
},
|
||||
acp::UserMessageChunk::Text {
|
||||
text: " and tell me what the content of the println! is".into(),
|
||||
},
|
||||
],
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(thread.entries().len(), 3);
|
||||
assert!(matches!(
|
||||
thread.entries()[0],
|
||||
AgentThreadEntry::UserMessage(_)
|
||||
));
|
||||
assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
|
||||
let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
|
||||
panic!("Expected AssistantMessage")
|
||||
};
|
||||
assert!(
|
||||
assistant_message.to_markdown(cx).contains("Hello, world!"),
|
||||
"unexpected assistant message: {:?}",
|
||||
assistant_message.to_markdown(cx)
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_tool_call(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/private/tmp"),
|
||||
json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
"Read the '/private/tmp/foo' file and tell me what you see.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
assert!(matches!(
|
||||
&thread.entries()[2],
|
||||
AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
|
||||
assert!(matches!(
|
||||
thread.entries()[3],
|
||||
AgentThreadEntry::AssistantMessage(_)
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
|
||||
});
|
||||
|
||||
run_until_first_tool_call(&thread, cx).await;
|
||||
|
||||
let tool_call_id = thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
id,
|
||||
status:
|
||||
ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation: ToolCallConfirmation::Execute { root_command, .. },
|
||||
..
|
||||
},
|
||||
..
|
||||
}) = &thread.entries()[2]
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
|
||||
assert_eq!(root_command, "echo");
|
||||
|
||||
*id
|
||||
});
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
|
||||
|
||||
assert!(matches!(
|
||||
&thread.entries()[2],
|
||||
AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
});
|
||||
|
||||
full_turn.await.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
content: Some(ToolCallContent::Markdown { markdown }),
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
}) = &thread.entries()[2]
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
|
||||
markdown.read_with(cx, |md, _cx| {
|
||||
assert!(
|
||||
md.source().contains("Hello, world!"),
|
||||
r#"Expected '{}' to contain "Hello, world!""#,
|
||||
md.source()
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_cancel(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
|
||||
});
|
||||
|
||||
let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
|
||||
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
id,
|
||||
status:
|
||||
ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation: ToolCallConfirmation::Execute { root_command, .. },
|
||||
..
|
||||
},
|
||||
..
|
||||
}) = &thread.entries()[first_tool_call_ix]
|
||||
else {
|
||||
panic!("{:?}", thread.entries()[1]);
|
||||
};
|
||||
|
||||
assert_eq!(root_command, "echo");
|
||||
|
||||
*id
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.cancel(cx))
|
||||
.await
|
||||
.unwrap();
|
||||
full_turn.await.unwrap();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Canceled,
|
||||
..
|
||||
}) = &thread.entries()[first_tool_call_ix]
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(
|
||||
&thread.entries().last().unwrap(),
|
||||
AgentThreadEntry::AssistantMessage(..),
|
||||
))
|
||||
});
|
||||
}
|
||||
|
||||
async fn run_until_first_tool_call(
|
||||
thread: &Entity<AcpThread>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> usize {
|
||||
let (mut tx, mut rx) = mpsc::channel::<usize>(1);
|
||||
|
||||
let subscription = cx.update(|cx| {
|
||||
cx.subscribe(thread, move |thread, _, cx| {
|
||||
for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
|
||||
if matches!(entry, AgentThreadEntry::ToolCall(_)) {
|
||||
return tx.try_send(ix).unwrap();
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
select! {
|
||||
_ = cx.executor().timer(Duration::from_secs(10)).fuse() => {
|
||||
panic!("Timeout waiting for tool call")
|
||||
}
|
||||
ix = rx.next().fuse() => {
|
||||
drop(subscription);
|
||||
ix.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ pub fn init(cx: &mut App) {
|
|||
#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
||||
pub struct AllAgentServersSettings {
|
||||
pub gemini: Option<AgentServerSettings>,
|
||||
pub claude: Option<AgentServerSettings>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
||||
|
|
|
@ -4,11 +4,8 @@ use agentic_coding_protocol as acp;
|
|||
use anyhow::{Result, anyhow};
|
||||
use gpui::{App, AsyncApp, Entity, Task, prelude::*};
|
||||
use project::Project;
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::{ResultExt, paths};
|
||||
use std::path::Path;
|
||||
use util::ResultExt;
|
||||
|
||||
pub trait StdioAgentServer: Send + Clone {
|
||||
fn logo(&self) -> ui::IconName;
|
||||
|
@ -120,50 +117,3 @@ impl<T: StdioAgentServer + 'static> AgentServer for T {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn find_bin_in_path(
|
||||
bin_name: &'static str,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Option<PathBuf> {
|
||||
let (env_task, root_dir) = project
|
||||
.update(cx, |project, cx| {
|
||||
let worktree = project.visible_worktrees(cx).next();
|
||||
match worktree {
|
||||
Some(worktree) => {
|
||||
let env_task = project.environment().update(cx, |env, cx| {
|
||||
env.get_worktree_environment(worktree.clone(), cx)
|
||||
});
|
||||
|
||||
let path = worktree.read(cx).abs_path();
|
||||
(env_task, path)
|
||||
}
|
||||
None => {
|
||||
let path: Arc<Path> = paths::home_dir().as_path().into();
|
||||
let env_task = project.environment().update(cx, |env, cx| {
|
||||
env.get_directory_environment(path.clone(), cx)
|
||||
});
|
||||
(env_task, path)
|
||||
}
|
||||
}
|
||||
})
|
||||
.log_err()?;
|
||||
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
let which_result = if cfg!(windows) {
|
||||
which::which(bin_name)
|
||||
} else {
|
||||
let env = env_task.await.unwrap_or_default();
|
||||
let shell_path = env.get("PATH").cloned();
|
||||
which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref())
|
||||
};
|
||||
|
||||
if let Err(which::Error::CannotFindBinaryPath) = which_result {
|
||||
return None;
|
||||
}
|
||||
|
||||
which_result.log_err()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue