Update to new agent schema (#35578)
Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
parent
dea64d3373
commit
f17943e4a3
23 changed files with 741 additions and 1168 deletions
23
Cargo.lock
generated
23
Cargo.lock
generated
|
@ -7,10 +7,8 @@ name = "acp_thread"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"agent-client-protocol",
|
"agent-client-protocol",
|
||||||
"agentic-coding-protocol",
|
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"assistant_tool",
|
"assistant_tool",
|
||||||
"async-pipe",
|
|
||||||
"buffer_diff",
|
"buffer_diff",
|
||||||
"editor",
|
"editor",
|
||||||
"env_logger 0.11.8",
|
"env_logger 0.11.8",
|
||||||
|
@ -139,10 +137,14 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "agent-client-protocol"
|
name = "agent-client-protocol"
|
||||||
version = "0.0.11"
|
version = "0.0.17"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "72ec54650c1fc2d63498bab47eeeaa9eddc7d239d53f615b797a0e84f7ccc87b"
|
checksum = "22c5180e40d31a9998ffa5f8eb067667f0870908a4aeed65a6a299e2d1d95443"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"futures 0.3.31",
|
||||||
|
"log",
|
||||||
|
"parking_lot",
|
||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
@ -177,6 +179,7 @@ dependencies = [
|
||||||
"smol",
|
"smol",
|
||||||
"strum 0.27.1",
|
"strum 0.27.1",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
|
"thiserror 2.0.12",
|
||||||
"ui",
|
"ui",
|
||||||
"util",
|
"util",
|
||||||
"uuid",
|
"uuid",
|
||||||
|
@ -9572,9 +9575,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lock_api"
|
name = "lock_api"
|
||||||
version = "0.4.12"
|
version = "0.4.13"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
|
checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"autocfg",
|
"autocfg",
|
||||||
"scopeguard",
|
"scopeguard",
|
||||||
|
@ -11288,9 +11291,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "parking_lot"
|
name = "parking_lot"
|
||||||
version = "0.12.3"
|
version = "0.12.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
|
checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"lock_api",
|
"lock_api",
|
||||||
"parking_lot_core",
|
"parking_lot_core",
|
||||||
|
@ -11298,9 +11301,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "parking_lot_core"
|
name = "parking_lot_core"
|
||||||
version = "0.9.10"
|
version = "0.9.11"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
|
checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"libc",
|
"libc",
|
||||||
|
|
|
@ -421,7 +421,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
||||||
#
|
#
|
||||||
|
|
||||||
agentic-coding-protocol = "0.0.10"
|
agentic-coding-protocol = "0.0.10"
|
||||||
agent-client-protocol = "0.0.11"
|
agent-client-protocol = "0.0.17"
|
||||||
aho-corasick = "1.1"
|
aho-corasick = "1.1"
|
||||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||||
any_vec = "0.14"
|
any_vec = "0.14"
|
||||||
|
|
|
@ -17,7 +17,6 @@ test-support = ["gpui/test-support", "project/test-support"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
agent-client-protocol.workspace = true
|
agent-client-protocol.workspace = true
|
||||||
agentic-coding-protocol.workspace = true
|
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
assistant_tool.workspace = true
|
assistant_tool.workspace = true
|
||||||
buffer_diff.workspace = true
|
buffer_diff.workspace = true
|
||||||
|
@ -37,7 +36,6 @@ util.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
async-pipe.workspace = true
|
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
gpui = { workspace = true, "features" = ["test-support"] }
|
gpui = { workspace = true, "features" = ["test-support"] }
|
||||||
indoc.workspace = true
|
indoc.workspace = true
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
mod connection;
|
mod connection;
|
||||||
mod old_acp_support;
|
|
||||||
pub use connection::*;
|
pub use connection::*;
|
||||||
pub use old_acp_support::*;
|
|
||||||
|
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result};
|
||||||
|
@ -391,7 +389,7 @@ impl ToolCallContent {
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
match content {
|
match content {
|
||||||
acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock {
|
acp::ToolCallContent::Content { content } => Self::ContentBlock {
|
||||||
content: ContentBlock::new(content, &language_registry, cx),
|
content: ContentBlock::new(content, &language_registry, cx),
|
||||||
},
|
},
|
||||||
acp::ToolCallContent::Diff { diff } => Self::Diff {
|
acp::ToolCallContent::Diff { diff } => Self::Diff {
|
||||||
|
@ -619,6 +617,7 @@ impl Error for LoadError {}
|
||||||
|
|
||||||
impl AcpThread {
|
impl AcpThread {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
|
title: impl Into<SharedString>,
|
||||||
connection: Rc<dyn AgentConnection>,
|
connection: Rc<dyn AgentConnection>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
session_id: acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
|
@ -631,7 +630,7 @@ impl AcpThread {
|
||||||
shared_buffers: Default::default(),
|
shared_buffers: Default::default(),
|
||||||
entries: Default::default(),
|
entries: Default::default(),
|
||||||
plan: Default::default(),
|
plan: Default::default(),
|
||||||
title: connection.name().into(),
|
title: title.into(),
|
||||||
project,
|
project,
|
||||||
send_task: None,
|
send_task: None,
|
||||||
connection,
|
connection,
|
||||||
|
@ -708,14 +707,14 @@ impl AcpThread {
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
match update {
|
match update {
|
||||||
acp::SessionUpdate::UserMessage(content_block) => {
|
acp::SessionUpdate::UserMessageChunk { content } => {
|
||||||
self.push_user_content_block(content_block, cx);
|
self.push_user_content_block(content, cx);
|
||||||
}
|
}
|
||||||
acp::SessionUpdate::AgentMessageChunk(content_block) => {
|
acp::SessionUpdate::AgentMessageChunk { content } => {
|
||||||
self.push_assistant_content_block(content_block, false, cx);
|
self.push_assistant_content_block(content, false, cx);
|
||||||
}
|
}
|
||||||
acp::SessionUpdate::AgentThoughtChunk(content_block) => {
|
acp::SessionUpdate::AgentThoughtChunk { content } => {
|
||||||
self.push_assistant_content_block(content_block, true, cx);
|
self.push_assistant_content_block(content, true, cx);
|
||||||
}
|
}
|
||||||
acp::SessionUpdate::ToolCall(tool_call) => {
|
acp::SessionUpdate::ToolCall(tool_call) => {
|
||||||
self.upsert_tool_call(tool_call, cx);
|
self.upsert_tool_call(tool_call, cx);
|
||||||
|
@ -984,10 +983,6 @@ impl AcpThread {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
|
|
||||||
self.connection.authenticate(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
pub fn send_raw(
|
pub fn send_raw(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
@ -1029,7 +1024,7 @@ impl AcpThread {
|
||||||
let result = this
|
let result = this
|
||||||
.update(cx, |this, cx| {
|
.update(cx, |this, cx| {
|
||||||
this.connection.prompt(
|
this.connection.prompt(
|
||||||
acp::PromptArguments {
|
acp::PromptRequest {
|
||||||
prompt: message,
|
prompt: message,
|
||||||
session_id: this.session_id.clone(),
|
session_id: this.session_id.clone(),
|
||||||
},
|
},
|
||||||
|
@ -1239,21 +1234,15 @@ impl AcpThread {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use agentic_coding_protocol as acp_old;
|
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use async_pipe::{PipeReader, PipeWriter};
|
use futures::{channel::mpsc, future::LocalBoxFuture, select};
|
||||||
use futures::{
|
|
||||||
channel::mpsc,
|
|
||||||
future::{LocalBoxFuture, try_join_all},
|
|
||||||
select,
|
|
||||||
};
|
|
||||||
use gpui::{AsyncApp, TestAppContext, WeakEntity};
|
use gpui::{AsyncApp, TestAppContext, WeakEntity};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use project::FakeFs;
|
use project::FakeFs;
|
||||||
use rand::Rng as _;
|
use rand::Rng as _;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::{future::BoxedLocal, stream::StreamExt as _};
|
use smol::stream::StreamExt as _;
|
||||||
use std::{cell::RefCell, rc::Rc, time::Duration};
|
use std::{cell::RefCell, rc::Rc, time::Duration};
|
||||||
|
|
||||||
use util::path;
|
use util::path;
|
||||||
|
@ -1274,7 +1263,15 @@ mod tests {
|
||||||
|
|
||||||
let fs = FakeFs::new(cx.executor());
|
let fs = FakeFs::new(cx.executor());
|
||||||
let project = Project::test(fs, [], cx).await;
|
let project = Project::test(fs, [], cx).await;
|
||||||
let (thread, _fake_server) = fake_acp_thread(project, cx);
|
let connection = Rc::new(FakeAgentConnection::new());
|
||||||
|
let thread = cx
|
||||||
|
.spawn(async move |mut cx| {
|
||||||
|
connection
|
||||||
|
.new_thread(project, Path::new(path!("/test")), &mut cx)
|
||||||
|
.await
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Test creating a new user message
|
// Test creating a new user message
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
|
@ -1354,34 +1351,40 @@ mod tests {
|
||||||
|
|
||||||
let fs = FakeFs::new(cx.executor());
|
let fs = FakeFs::new(cx.executor());
|
||||||
let project = Project::test(fs, [], cx).await;
|
let project = Project::test(fs, [], cx).await;
|
||||||
let (thread, fake_server) = fake_acp_thread(project, cx);
|
let connection = Rc::new(FakeAgentConnection::new().on_user_message(
|
||||||
|
|_, thread, mut cx| {
|
||||||
|
async move {
|
||||||
|
thread.update(&mut cx, |thread, cx| {
|
||||||
|
thread
|
||||||
|
.handle_session_update(
|
||||||
|
acp::SessionUpdate::AgentThoughtChunk {
|
||||||
|
content: "Thinking ".into(),
|
||||||
|
},
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
thread
|
||||||
|
.handle_session_update(
|
||||||
|
acp::SessionUpdate::AgentThoughtChunk {
|
||||||
|
content: "hard!".into(),
|
||||||
|
},
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
})
|
||||||
|
}
|
||||||
|
.boxed_local()
|
||||||
|
},
|
||||||
|
));
|
||||||
|
|
||||||
fake_server.update(cx, |fake_server, _| {
|
let thread = cx
|
||||||
fake_server.on_user_message(move |_, server, mut cx| async move {
|
.spawn(async move |mut cx| {
|
||||||
server
|
connection
|
||||||
.update(&mut cx, |server, _| {
|
.new_thread(project, Path::new(path!("/test")), &mut cx)
|
||||||
server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
|
|
||||||
chunk: acp_old::AssistantMessageChunk::Thought {
|
|
||||||
thought: "Thinking ".into(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
})?
|
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
|
||||||
server
|
|
||||||
.update(&mut cx, |server, _| {
|
|
||||||
server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
|
|
||||||
chunk: acp_old::AssistantMessageChunk::Thought {
|
|
||||||
thought: "hard!".into(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
})?
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
})
|
})
|
||||||
});
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
|
.update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
|
||||||
|
@ -1414,7 +1417,38 @@ mod tests {
|
||||||
fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
|
fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
|
||||||
.await;
|
.await;
|
||||||
let project = Project::test(fs.clone(), [], cx).await;
|
let project = Project::test(fs.clone(), [], cx).await;
|
||||||
let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
|
let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
|
||||||
|
let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
|
||||||
|
let connection = Rc::new(FakeAgentConnection::new().on_user_message(
|
||||||
|
move |_, thread, mut cx| {
|
||||||
|
let read_file_tx = read_file_tx.clone();
|
||||||
|
async move {
|
||||||
|
let content = thread
|
||||||
|
.update(&mut cx, |thread, cx| {
|
||||||
|
thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(content, "one\ntwo\nthree\n");
|
||||||
|
read_file_tx.take().unwrap().send(()).unwrap();
|
||||||
|
thread
|
||||||
|
.update(&mut cx, |thread, cx| {
|
||||||
|
thread.write_text_file(
|
||||||
|
path!("/tmp/foo").into(),
|
||||||
|
"one\ntwo\nthree\nfour\nfive\n".to_string(),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
.boxed_local()
|
||||||
|
},
|
||||||
|
));
|
||||||
|
|
||||||
let (worktree, pathbuf) = project
|
let (worktree, pathbuf) = project
|
||||||
.update(cx, |project, cx| {
|
.update(cx, |project, cx| {
|
||||||
project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
|
project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
|
||||||
|
@ -1428,38 +1462,10 @@ mod tests {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
|
let thread = cx
|
||||||
let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
|
.spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
|
||||||
|
.await
|
||||||
fake_server.update(cx, |fake_server, _| {
|
.unwrap();
|
||||||
fake_server.on_user_message(move |_, server, mut cx| {
|
|
||||||
let read_file_tx = read_file_tx.clone();
|
|
||||||
async move {
|
|
||||||
let content = server
|
|
||||||
.update(&mut cx, |server, _| {
|
|
||||||
server.send_to_zed(acp_old::ReadTextFileParams {
|
|
||||||
path: path!("/tmp/foo").into(),
|
|
||||||
line: None,
|
|
||||||
limit: None,
|
|
||||||
})
|
|
||||||
})?
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(content.content, "one\ntwo\nthree\n");
|
|
||||||
read_file_tx.take().unwrap().send(()).unwrap();
|
|
||||||
server
|
|
||||||
.update(&mut cx, |server, _| {
|
|
||||||
server.send_to_zed(acp_old::WriteTextFileParams {
|
|
||||||
path: path!("/tmp/foo").into(),
|
|
||||||
content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
|
|
||||||
})
|
|
||||||
})?
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
let request = thread.update(cx, |thread, cx| {
|
let request = thread.update(cx, |thread, cx| {
|
||||||
thread.send_raw("Extend the count in /tmp/foo", cx)
|
thread.send_raw("Extend the count in /tmp/foo", cx)
|
||||||
|
@ -1486,36 +1492,44 @@ mod tests {
|
||||||
|
|
||||||
let fs = FakeFs::new(cx.executor());
|
let fs = FakeFs::new(cx.executor());
|
||||||
let project = Project::test(fs, [], cx).await;
|
let project = Project::test(fs, [], cx).await;
|
||||||
let (thread, fake_server) = fake_acp_thread(project, cx);
|
let id = acp::ToolCallId("test".into());
|
||||||
|
|
||||||
let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();
|
let connection = Rc::new(FakeAgentConnection::new().on_user_message({
|
||||||
|
let id = id.clone();
|
||||||
let tool_call_id = Rc::new(RefCell::new(None));
|
move |_, thread, mut cx| {
|
||||||
let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
|
let id = id.clone();
|
||||||
fake_server.update(cx, |fake_server, _| {
|
|
||||||
let tool_call_id = tool_call_id.clone();
|
|
||||||
fake_server.on_user_message(move |_, server, mut cx| {
|
|
||||||
let end_turn_rx = end_turn_rx.clone();
|
|
||||||
let tool_call_id = tool_call_id.clone();
|
|
||||||
async move {
|
async move {
|
||||||
let tool_call_result = server
|
thread
|
||||||
.update(&mut cx, |server, _| {
|
.update(&mut cx, |thread, cx| {
|
||||||
server.send_to_zed(acp_old::PushToolCallParams {
|
thread.handle_session_update(
|
||||||
label: "Fetch".to_string(),
|
acp::SessionUpdate::ToolCall(acp::ToolCall {
|
||||||
icon: acp_old::Icon::Globe,
|
id: id.clone(),
|
||||||
content: None,
|
label: "Label".into(),
|
||||||
locations: vec![],
|
kind: acp::ToolKind::Fetch,
|
||||||
})
|
status: acp::ToolCallStatus::InProgress,
|
||||||
})?
|
content: vec![],
|
||||||
.await
|
locations: vec![],
|
||||||
|
raw_input: None,
|
||||||
|
}),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
*tool_call_id.clone().borrow_mut() = Some(tool_call_result.id);
|
|
||||||
end_turn_rx.take().unwrap().await.ok();
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
.boxed_local()
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
|
let thread = cx
|
||||||
|
.spawn(async move |mut cx| {
|
||||||
|
connection
|
||||||
|
.new_thread(project, Path::new(path!("/test")), &mut cx)
|
||||||
|
.await
|
||||||
})
|
})
|
||||||
});
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let request = thread.update(cx, |thread, cx| {
|
let request = thread.update(cx, |thread, cx| {
|
||||||
thread.send_raw("Fetch https://example.com", cx)
|
thread.send_raw("Fetch https://example.com", cx)
|
||||||
|
@ -1536,8 +1550,6 @@ mod tests {
|
||||||
));
|
));
|
||||||
});
|
});
|
||||||
|
|
||||||
cx.run_until_parked();
|
|
||||||
|
|
||||||
thread.update(cx, |thread, cx| thread.cancel(cx)).await;
|
thread.update(cx, |thread, cx| thread.cancel(cx)).await;
|
||||||
|
|
||||||
thread.read_with(cx, |thread, _| {
|
thread.read_with(cx, |thread, _| {
|
||||||
|
@ -1550,19 +1562,22 @@ mod tests {
|
||||||
));
|
));
|
||||||
});
|
});
|
||||||
|
|
||||||
fake_server
|
thread
|
||||||
.update(cx, |fake_server, _| {
|
.update(cx, |thread, cx| {
|
||||||
fake_server.send_to_zed(acp_old::UpdateToolCallParams {
|
thread.handle_session_update(
|
||||||
tool_call_id: tool_call_id.borrow().unwrap(),
|
acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
|
||||||
status: acp_old::ToolCallStatus::Finished,
|
id,
|
||||||
content: None,
|
fields: acp::ToolCallUpdateFields {
|
||||||
})
|
status: Some(acp::ToolCallStatus::Completed),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
.await
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
drop(end_turn_tx);
|
request.await.unwrap();
|
||||||
assert!(request.await.unwrap_err().to_string().contains("canceled"));
|
|
||||||
|
|
||||||
thread.read_with(cx, |thread, _| {
|
thread.read_with(cx, |thread, _| {
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
|
@ -1585,23 +1600,37 @@ mod tests {
|
||||||
fs.insert_tree(path!("/test"), json!({})).await;
|
fs.insert_tree(path!("/test"), json!({})).await;
|
||||||
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
||||||
|
|
||||||
let connection = Rc::new(StubAgentConnection::new(vec![
|
let connection = Rc::new(FakeAgentConnection::new().on_user_message({
|
||||||
acp::SessionUpdate::ToolCall(acp::ToolCall {
|
move |_, thread, mut cx| {
|
||||||
id: acp::ToolCallId("test".into()),
|
async move {
|
||||||
label: "Label".into(),
|
thread
|
||||||
kind: acp::ToolKind::Edit,
|
.update(&mut cx, |thread, cx| {
|
||||||
status: acp::ToolCallStatus::Completed,
|
thread.handle_session_update(
|
||||||
content: vec![acp::ToolCallContent::Diff {
|
acp::SessionUpdate::ToolCall(acp::ToolCall {
|
||||||
diff: acp::Diff {
|
id: acp::ToolCallId("test".into()),
|
||||||
path: "/test/test.txt".into(),
|
label: "Label".into(),
|
||||||
old_text: None,
|
kind: acp::ToolKind::Edit,
|
||||||
new_text: "foo".into(),
|
status: acp::ToolCallStatus::Completed,
|
||||||
},
|
content: vec![acp::ToolCallContent::Diff {
|
||||||
}],
|
diff: acp::Diff {
|
||||||
locations: vec![],
|
path: "/test/test.txt".into(),
|
||||||
raw_input: None,
|
old_text: None,
|
||||||
}),
|
new_text: "foo".into(),
|
||||||
]));
|
},
|
||||||
|
}],
|
||||||
|
locations: vec![],
|
||||||
|
raw_input: None,
|
||||||
|
}),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
.boxed_local()
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
let thread = connection
|
let thread = connection
|
||||||
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
|
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
|
||||||
|
@ -1642,25 +1671,53 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Default)]
|
#[derive(Clone, Default)]
|
||||||
struct StubAgentConnection {
|
struct FakeAgentConnection {
|
||||||
|
auth_methods: Vec<acp::AuthMethod>,
|
||||||
sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
|
sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
|
||||||
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
|
on_user_message: Option<
|
||||||
updates: Vec<acp::SessionUpdate>,
|
Rc<
|
||||||
|
dyn Fn(
|
||||||
|
acp::PromptRequest,
|
||||||
|
WeakEntity<AcpThread>,
|
||||||
|
AsyncApp,
|
||||||
|
) -> LocalBoxFuture<'static, Result<()>>
|
||||||
|
+ 'static,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StubAgentConnection {
|
impl FakeAgentConnection {
|
||||||
fn new(updates: Vec<acp::SessionUpdate>) -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
updates,
|
auth_methods: Vec::new(),
|
||||||
permission_requests: HashMap::default(),
|
on_user_message: None,
|
||||||
sessions: Arc::default(),
|
sessions: Arc::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[expect(unused)]
|
||||||
|
fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
|
||||||
|
self.auth_methods = auth_methods;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_user_message(
|
||||||
|
mut self,
|
||||||
|
handler: impl Fn(
|
||||||
|
acp::PromptRequest,
|
||||||
|
WeakEntity<AcpThread>,
|
||||||
|
AsyncApp,
|
||||||
|
) -> LocalBoxFuture<'static, Result<()>>
|
||||||
|
+ 'static,
|
||||||
|
) -> Self {
|
||||||
|
self.on_user_message.replace(Rc::new(handler));
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentConnection for StubAgentConnection {
|
impl AgentConnection for FakeAgentConnection {
|
||||||
fn name(&self) -> &'static str {
|
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||||
"StubAgentConnection"
|
&self.auth_methods
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
|
@ -1678,222 +1735,43 @@ mod tests {
|
||||||
.into(),
|
.into(),
|
||||||
);
|
);
|
||||||
let thread = cx
|
let thread = cx
|
||||||
.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))
|
.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||||
Task::ready(Ok(thread))
|
Task::ready(Ok(thread))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
|
fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
|
||||||
unimplemented!()
|
if self.auth_methods().iter().any(|m| m.id == method) {
|
||||||
|
Task::ready(Ok(()))
|
||||||
|
} else {
|
||||||
|
Task::ready(Err(anyhow!("Invalid Auth Method")))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<gpui::Result<()>> {
|
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
|
||||||
let sessions = self.sessions.lock();
|
let sessions = self.sessions.lock();
|
||||||
let thread = sessions.get(¶ms.session_id).unwrap();
|
let thread = sessions.get(¶ms.session_id).unwrap();
|
||||||
let mut tasks = vec![];
|
if let Some(handler) = &self.on_user_message {
|
||||||
for update in &self.updates {
|
let handler = handler.clone();
|
||||||
let thread = thread.clone();
|
let thread = thread.clone();
|
||||||
let update = update.clone();
|
cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
|
||||||
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
|
|
||||||
&& let Some(options) = self.permission_requests.get(&tool_call.id)
|
|
||||||
{
|
|
||||||
Some((tool_call.clone(), options.clone()))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
let task = cx.spawn(async move |cx| {
|
|
||||||
if let Some((tool_call, options)) = permission_request {
|
|
||||||
let permission = thread.update(cx, |thread, cx| {
|
|
||||||
thread.request_tool_call_permission(
|
|
||||||
tool_call.clone(),
|
|
||||||
options.clone(),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
permission.await?;
|
|
||||||
}
|
|
||||||
thread.update(cx, |thread, cx| {
|
|
||||||
thread.handle_session_update(update.clone(), cx).unwrap();
|
|
||||||
})?;
|
|
||||||
anyhow::Ok(())
|
|
||||||
});
|
|
||||||
tasks.push(task);
|
|
||||||
}
|
|
||||||
cx.spawn(async move |_| {
|
|
||||||
try_join_all(tasks).await?;
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fake_acp_thread(
|
|
||||||
project: Entity<Project>,
|
|
||||||
cx: &mut TestAppContext,
|
|
||||||
) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
|
|
||||||
let (stdin_tx, stdin_rx) = async_pipe::pipe();
|
|
||||||
let (stdout_tx, stdout_rx) = async_pipe::pipe();
|
|
||||||
|
|
||||||
let thread = cx.new(|cx| {
|
|
||||||
let foreground_executor = cx.foreground_executor().clone();
|
|
||||||
let thread_rc = Rc::new(RefCell::new(cx.entity().downgrade()));
|
|
||||||
|
|
||||||
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
|
|
||||||
OldAcpClientDelegate::new(thread_rc.clone(), cx.to_async()),
|
|
||||||
stdin_tx,
|
|
||||||
stdout_rx,
|
|
||||||
move |fut| {
|
|
||||||
foreground_executor.spawn(fut).detach();
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
let io_task = cx.background_spawn({
|
|
||||||
async move {
|
|
||||||
io_fut.await.log_err();
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
});
|
|
||||||
let connection = OldAcpAgentConnection {
|
|
||||||
name: "test",
|
|
||||||
connection,
|
|
||||||
child_status: io_task,
|
|
||||||
current_thread: thread_rc,
|
|
||||||
};
|
|
||||||
|
|
||||||
AcpThread::new(
|
|
||||||
Rc::new(connection),
|
|
||||||
project,
|
|
||||||
acp::SessionId("test".into()),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
|
|
||||||
(thread, agent)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct FakeAcpServer {
|
|
||||||
connection: acp_old::ClientConnection,
|
|
||||||
|
|
||||||
_io_task: Task<()>,
|
|
||||||
on_user_message: Option<
|
|
||||||
Rc<
|
|
||||||
dyn Fn(
|
|
||||||
acp_old::SendUserMessageParams,
|
|
||||||
Entity<FakeAcpServer>,
|
|
||||||
AsyncApp,
|
|
||||||
) -> LocalBoxFuture<'static, Result<(), acp_old::Error>>,
|
|
||||||
>,
|
|
||||||
>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct FakeAgent {
|
|
||||||
server: Entity<FakeAcpServer>,
|
|
||||||
cx: AsyncApp,
|
|
||||||
cancel_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl acp_old::Agent for FakeAgent {
|
|
||||||
async fn initialize(
|
|
||||||
&self,
|
|
||||||
params: acp_old::InitializeParams,
|
|
||||||
) -> Result<acp_old::InitializeResponse, acp_old::Error> {
|
|
||||||
Ok(acp_old::InitializeResponse {
|
|
||||||
protocol_version: params.protocol_version,
|
|
||||||
is_authenticated: true,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn authenticate(&self) -> Result<(), acp_old::Error> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn cancel_send_message(&self) -> Result<(), acp_old::Error> {
|
|
||||||
if let Some(cancel_tx) = self.cancel_tx.take() {
|
|
||||||
cancel_tx.send(()).log_err();
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_user_message(
|
|
||||||
&self,
|
|
||||||
request: acp_old::SendUserMessageParams,
|
|
||||||
) -> Result<(), acp_old::Error> {
|
|
||||||
let (cancel_tx, cancel_rx) = oneshot::channel();
|
|
||||||
self.cancel_tx.replace(Some(cancel_tx));
|
|
||||||
|
|
||||||
let mut cx = self.cx.clone();
|
|
||||||
let handler = self
|
|
||||||
.server
|
|
||||||
.update(&mut cx, |server, _| server.on_user_message.clone())
|
|
||||||
.ok()
|
|
||||||
.flatten();
|
|
||||||
if let Some(handler) = handler {
|
|
||||||
select! {
|
|
||||||
_ = cancel_rx.fuse() => Err(anyhow::anyhow!("Message sending canceled").into()),
|
|
||||||
_ = handler(request, self.server.clone(), self.cx.clone()).fuse() => Ok(()),
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
Err(anyhow::anyhow!("No handler for on_user_message").into())
|
Task::ready(Ok(()))
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FakeAcpServer {
|
|
||||||
fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
|
|
||||||
let agent = FakeAgent {
|
|
||||||
server: cx.entity(),
|
|
||||||
cx: cx.to_async(),
|
|
||||||
cancel_tx: Default::default(),
|
|
||||||
};
|
|
||||||
let foreground_executor = cx.foreground_executor().clone();
|
|
||||||
|
|
||||||
let (connection, io_fut) = acp_old::ClientConnection::connect_to_client(
|
|
||||||
agent.clone(),
|
|
||||||
stdout,
|
|
||||||
stdin,
|
|
||||||
move |fut| {
|
|
||||||
foreground_executor.spawn(fut).detach();
|
|
||||||
},
|
|
||||||
);
|
|
||||||
FakeAcpServer {
|
|
||||||
connection: connection,
|
|
||||||
on_user_message: None,
|
|
||||||
_io_task: cx.background_spawn(async move {
|
|
||||||
io_fut.await.log_err();
|
|
||||||
}),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_user_message<F>(
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||||
&mut self,
|
let sessions = self.sessions.lock();
|
||||||
handler: impl for<'a> Fn(
|
let thread = sessions.get(&session_id).unwrap().clone();
|
||||||
acp_old::SendUserMessageParams,
|
|
||||||
Entity<FakeAcpServer>,
|
|
||||||
AsyncApp,
|
|
||||||
) -> F
|
|
||||||
+ 'static,
|
|
||||||
) where
|
|
||||||
F: Future<Output = Result<(), acp_old::Error>> + 'static,
|
|
||||||
{
|
|
||||||
self.on_user_message
|
|
||||||
.replace(Rc::new(move |request, server, cx| {
|
|
||||||
handler(request, server, cx).boxed_local()
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
fn send_to_zed<T: acp_old::ClientRequest + 'static>(
|
cx.spawn(async move |cx| {
|
||||||
&self,
|
thread
|
||||||
message: T,
|
.update(cx, |thread, cx| thread.cancel(cx))
|
||||||
) -> BoxedLocal<Result<T::Response>> {
|
.unwrap()
|
||||||
self.connection
|
.await
|
||||||
.request(message)
|
})
|
||||||
.map(|f| f.map_err(|err| anyhow!(err)))
|
.detach();
|
||||||
.boxed_local()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use std::{path::Path, rc::Rc};
|
use std::{error::Error, fmt, path::Path, rc::Rc};
|
||||||
|
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol::{self as acp};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use gpui::{AsyncApp, Entity, Task};
|
use gpui::{AsyncApp, Entity, Task};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
@ -9,8 +9,6 @@ use ui::App;
|
||||||
use crate::AcpThread;
|
use crate::AcpThread;
|
||||||
|
|
||||||
pub trait AgentConnection {
|
pub trait AgentConnection {
|
||||||
fn name(&self) -> &'static str;
|
|
||||||
|
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
self: Rc<Self>,
|
self: Rc<Self>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
|
@ -18,9 +16,21 @@ pub trait AgentConnection {
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) -> Task<Result<Entity<AcpThread>>>;
|
) -> Task<Result<Entity<AcpThread>>>;
|
||||||
|
|
||||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
|
fn auth_methods(&self) -> &[acp::AuthMethod];
|
||||||
|
|
||||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
|
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||||
|
|
||||||
|
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AuthRequired;
|
||||||
|
|
||||||
|
impl Error for AuthRequired {}
|
||||||
|
impl fmt::Display for AuthRequired {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "AuthRequired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ collections.workspace = true
|
||||||
context_server.workspace = true
|
context_server.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
|
indoc.workspace = true
|
||||||
itertools.workspace = true
|
itertools.workspace = true
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
paths.workspace = true
|
paths.workspace = true
|
||||||
|
@ -37,11 +38,11 @@ settings.workspace = true
|
||||||
smol.workspace = true
|
smol.workspace = true
|
||||||
strum.workspace = true
|
strum.workspace = true
|
||||||
tempfile.workspace = true
|
tempfile.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
ui.workspace = true
|
ui.workspace = true
|
||||||
util.workspace = true
|
util.workspace = true
|
||||||
uuid.workspace = true
|
uuid.workspace = true
|
||||||
watch.workspace = true
|
watch.workspace = true
|
||||||
indoc.workspace = true
|
|
||||||
which.workspace = true
|
which.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
|
|
||||||
|
|
34
crates/agent_servers/src/acp.rs
Normal file
34
crates/agent_servers/src/acp.rs
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
use std::{path::Path, rc::Rc};
|
||||||
|
|
||||||
|
use crate::AgentServerCommand;
|
||||||
|
use acp_thread::AgentConnection;
|
||||||
|
use anyhow::Result;
|
||||||
|
use gpui::AsyncApp;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
mod v0;
|
||||||
|
mod v1;
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
#[error("Unsupported version")]
|
||||||
|
pub struct UnsupportedVersion;
|
||||||
|
|
||||||
|
pub async fn connect(
|
||||||
|
server_name: &'static str,
|
||||||
|
command: AgentServerCommand,
|
||||||
|
root_dir: &Path,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Result<Rc<dyn AgentConnection>> {
|
||||||
|
let conn = v1::AcpConnection::stdio(server_name, command.clone(), &root_dir, cx).await;
|
||||||
|
|
||||||
|
match conn {
|
||||||
|
Ok(conn) => Ok(Rc::new(conn) as _),
|
||||||
|
Err(err) if err.is::<UnsupportedVersion>() => {
|
||||||
|
// Consider re-using initialize response and subprocess when adding another version here
|
||||||
|
let conn: Rc<dyn AgentConnection> =
|
||||||
|
Rc::new(v0::AcpConnection::stdio(server_name, command, &root_dir, cx).await?);
|
||||||
|
Ok(conn)
|
||||||
|
}
|
||||||
|
Err(err) => Err(err),
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,18 +1,19 @@
|
||||||
// Translates old acp agents into the new schema
|
// Translates old acp agents into the new schema
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
|
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use futures::channel::oneshot;
|
use futures::channel::oneshot;
|
||||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
|
use std::{cell::RefCell, path::Path, rc::Rc};
|
||||||
use ui::App;
|
use ui::App;
|
||||||
use util::ResultExt as _;
|
use util::ResultExt as _;
|
||||||
|
|
||||||
use crate::{AcpThread, AgentConnection};
|
use crate::AgentServerCommand;
|
||||||
|
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct OldAcpClientDelegate {
|
struct OldAcpClientDelegate {
|
||||||
thread: Rc<RefCell<WeakEntity<AcpThread>>>,
|
thread: Rc<RefCell<WeakEntity<AcpThread>>>,
|
||||||
cx: AsyncApp,
|
cx: AsyncApp,
|
||||||
next_tool_call_id: Rc<RefCell<u64>>,
|
next_tool_call_id: Rc<RefCell<u64>>,
|
||||||
|
@ -20,7 +21,7 @@ pub struct OldAcpClientDelegate {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OldAcpClientDelegate {
|
impl OldAcpClientDelegate {
|
||||||
pub fn new(thread: Rc<RefCell<WeakEntity<AcpThread>>>, cx: AsyncApp) -> Self {
|
fn new(thread: Rc<RefCell<WeakEntity<AcpThread>>>, cx: AsyncApp) -> Self {
|
||||||
Self {
|
Self {
|
||||||
thread,
|
thread,
|
||||||
cx,
|
cx,
|
||||||
|
@ -351,28 +352,71 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
pub struct AcpConnection {
|
||||||
pub struct Unauthenticated;
|
|
||||||
|
|
||||||
impl Error for Unauthenticated {}
|
|
||||||
impl fmt::Display for Unauthenticated {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
write!(f, "Unauthenticated")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct OldAcpAgentConnection {
|
|
||||||
pub name: &'static str,
|
pub name: &'static str,
|
||||||
pub connection: acp_old::AgentConnection,
|
pub connection: acp_old::AgentConnection,
|
||||||
pub child_status: Task<Result<()>>,
|
pub _child_status: Task<Result<()>>,
|
||||||
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
|
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentConnection for OldAcpAgentConnection {
|
impl AcpConnection {
|
||||||
fn name(&self) -> &'static str {
|
pub fn stdio(
|
||||||
self.name
|
name: &'static str,
|
||||||
}
|
command: AgentServerCommand,
|
||||||
|
root_dir: &Path,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Task<Result<Self>> {
|
||||||
|
let root_dir = root_dir.to_path_buf();
|
||||||
|
|
||||||
|
cx.spawn(async move |cx| {
|
||||||
|
let mut child = util::command::new_smol_command(&command.path)
|
||||||
|
.args(command.args.iter())
|
||||||
|
.current_dir(root_dir)
|
||||||
|
.stdin(std::process::Stdio::piped())
|
||||||
|
.stdout(std::process::Stdio::piped())
|
||||||
|
.stderr(std::process::Stdio::inherit())
|
||||||
|
.kill_on_drop(true)
|
||||||
|
.spawn()?;
|
||||||
|
|
||||||
|
let stdin = child.stdin.take().unwrap();
|
||||||
|
let stdout = child.stdout.take().unwrap();
|
||||||
|
|
||||||
|
let foreground_executor = cx.foreground_executor().clone();
|
||||||
|
|
||||||
|
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
|
||||||
|
|
||||||
|
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
|
||||||
|
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
|
||||||
|
stdin,
|
||||||
|
stdout,
|
||||||
|
move |fut| foreground_executor.spawn(fut).detach(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let io_task = cx.background_spawn(async move {
|
||||||
|
io_fut.await.log_err();
|
||||||
|
});
|
||||||
|
|
||||||
|
let child_status = cx.background_spawn(async move {
|
||||||
|
let result = match child.status().await {
|
||||||
|
Err(e) => Err(anyhow!(e)),
|
||||||
|
Ok(result) if result.success() => Ok(()),
|
||||||
|
Ok(result) => Err(anyhow!(result)),
|
||||||
|
};
|
||||||
|
drop(io_task);
|
||||||
|
result
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
name,
|
||||||
|
connection,
|
||||||
|
_child_status: child_status,
|
||||||
|
current_thread: thread_rc,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentConnection for AcpConnection {
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
self: Rc<Self>,
|
self: Rc<Self>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
|
@ -391,13 +435,13 @@ impl AgentConnection for OldAcpAgentConnection {
|
||||||
let result = acp_old::InitializeParams::response_from_any(result)?;
|
let result = acp_old::InitializeParams::response_from_any(result)?;
|
||||||
|
|
||||||
if !result.is_authenticated {
|
if !result.is_authenticated {
|
||||||
anyhow::bail!(Unauthenticated)
|
anyhow::bail!(AuthRequired)
|
||||||
}
|
}
|
||||||
|
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
let session_id = acp::SessionId("acp-old-no-id".into());
|
let session_id = acp::SessionId("acp-old-no-id".into());
|
||||||
AcpThread::new(self.clone(), project, session_id, cx)
|
AcpThread::new(self.name, self.clone(), project, session_id, cx)
|
||||||
});
|
});
|
||||||
current_thread.replace(thread.downgrade());
|
current_thread.replace(thread.downgrade());
|
||||||
thread
|
thread
|
||||||
|
@ -405,7 +449,11 @@ impl AgentConnection for OldAcpAgentConnection {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
|
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||||
|
&[]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||||
let task = self
|
let task = self
|
||||||
.connection
|
.connection
|
||||||
.request_any(acp_old::AuthenticateParams.into_any());
|
.request_any(acp_old::AuthenticateParams.into_any());
|
||||||
|
@ -415,7 +463,7 @@ impl AgentConnection for OldAcpAgentConnection {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
|
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||||
let chunks = params
|
let chunks = params
|
||||||
.prompt
|
.prompt
|
||||||
.into_iter()
|
.into_iter()
|
254
crates/agent_servers/src/acp/v1.rs
Normal file
254
crates/agent_servers/src/acp/v1.rs
Normal file
|
@ -0,0 +1,254 @@
|
||||||
|
use agent_client_protocol::{self as acp, Agent as _};
|
||||||
|
use collections::HashMap;
|
||||||
|
use futures::channel::oneshot;
|
||||||
|
use project::Project;
|
||||||
|
use std::cell::RefCell;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::rc::Rc;
|
||||||
|
|
||||||
|
use anyhow::{Context as _, Result};
|
||||||
|
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||||
|
|
||||||
|
use crate::{AgentServerCommand, acp::UnsupportedVersion};
|
||||||
|
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
|
||||||
|
|
||||||
|
pub struct AcpConnection {
|
||||||
|
server_name: &'static str,
|
||||||
|
connection: Rc<acp::ClientSideConnection>,
|
||||||
|
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||||
|
auth_methods: Vec<acp::AuthMethod>,
|
||||||
|
_io_task: Task<Result<()>>,
|
||||||
|
_child: smol::process::Child,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AcpSession {
|
||||||
|
thread: WeakEntity<AcpThread>,
|
||||||
|
}
|
||||||
|
|
||||||
|
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
|
||||||
|
|
||||||
|
impl AcpConnection {
|
||||||
|
pub async fn stdio(
|
||||||
|
server_name: &'static str,
|
||||||
|
command: AgentServerCommand,
|
||||||
|
root_dir: &Path,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let mut child = util::command::new_smol_command(&command.path)
|
||||||
|
.args(command.args.iter().map(|arg| arg.as_str()))
|
||||||
|
.envs(command.env.iter().flatten())
|
||||||
|
.current_dir(root_dir)
|
||||||
|
.stdin(std::process::Stdio::piped())
|
||||||
|
.stdout(std::process::Stdio::piped())
|
||||||
|
.stderr(std::process::Stdio::inherit())
|
||||||
|
.kill_on_drop(true)
|
||||||
|
.spawn()?;
|
||||||
|
|
||||||
|
let stdout = child.stdout.take().expect("Failed to take stdout");
|
||||||
|
let stdin = child.stdin.take().expect("Failed to take stdin");
|
||||||
|
|
||||||
|
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||||
|
|
||||||
|
let client = ClientDelegate {
|
||||||
|
sessions: sessions.clone(),
|
||||||
|
cx: cx.clone(),
|
||||||
|
};
|
||||||
|
let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
|
||||||
|
let foreground_executor = cx.foreground_executor().clone();
|
||||||
|
move |fut| {
|
||||||
|
foreground_executor.spawn(fut).detach();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let io_task = cx.background_spawn(io_task);
|
||||||
|
|
||||||
|
let response = connection
|
||||||
|
.initialize(acp::InitializeRequest {
|
||||||
|
protocol_version: acp::VERSION,
|
||||||
|
client_capabilities: acp::ClientCapabilities {
|
||||||
|
fs: acp::FileSystemCapability {
|
||||||
|
read_text_file: true,
|
||||||
|
write_text_file: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
|
||||||
|
return Err(UnsupportedVersion.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
auth_methods: response.auth_methods,
|
||||||
|
connection: connection.into(),
|
||||||
|
server_name,
|
||||||
|
sessions,
|
||||||
|
_child: child,
|
||||||
|
_io_task: io_task,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentConnection for AcpConnection {
|
||||||
|
fn new_thread(
|
||||||
|
self: Rc<Self>,
|
||||||
|
project: Entity<Project>,
|
||||||
|
cwd: &Path,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Task<Result<Entity<AcpThread>>> {
|
||||||
|
let conn = self.connection.clone();
|
||||||
|
let sessions = self.sessions.clone();
|
||||||
|
let cwd = cwd.to_path_buf();
|
||||||
|
cx.spawn(async move |cx| {
|
||||||
|
let response = conn
|
||||||
|
.new_session(acp::NewSessionRequest {
|
||||||
|
mcp_servers: vec![],
|
||||||
|
cwd,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let Some(session_id) = response.session_id else {
|
||||||
|
anyhow::bail!(AuthRequired);
|
||||||
|
};
|
||||||
|
|
||||||
|
let thread = cx.new(|cx| {
|
||||||
|
AcpThread::new(
|
||||||
|
self.server_name,
|
||||||
|
self.clone(),
|
||||||
|
project,
|
||||||
|
session_id.clone(),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let session = AcpSession {
|
||||||
|
thread: thread.downgrade(),
|
||||||
|
};
|
||||||
|
sessions.borrow_mut().insert(session_id, session);
|
||||||
|
|
||||||
|
Ok(thread)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||||
|
&self.auth_methods
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||||
|
let conn = self.connection.clone();
|
||||||
|
cx.foreground_executor().spawn(async move {
|
||||||
|
let result = conn
|
||||||
|
.authenticate(acp::AuthenticateRequest {
|
||||||
|
method_id: method_id.clone(),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||||
|
let conn = self.connection.clone();
|
||||||
|
cx.foreground_executor()
|
||||||
|
.spawn(async move { Ok(conn.prompt(params).await?) })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||||
|
let conn = self.connection.clone();
|
||||||
|
let params = acp::CancelledNotification {
|
||||||
|
session_id: session_id.clone(),
|
||||||
|
};
|
||||||
|
cx.foreground_executor()
|
||||||
|
.spawn(async move { conn.cancelled(params).await })
|
||||||
|
.detach();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ClientDelegate {
|
||||||
|
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||||
|
cx: AsyncApp,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl acp::Client for ClientDelegate {
|
||||||
|
async fn request_permission(
|
||||||
|
&self,
|
||||||
|
arguments: acp::RequestPermissionRequest,
|
||||||
|
) -> Result<acp::RequestPermissionResponse, acp::Error> {
|
||||||
|
let cx = &mut self.cx.clone();
|
||||||
|
let rx = self
|
||||||
|
.sessions
|
||||||
|
.borrow()
|
||||||
|
.get(&arguments.session_id)
|
||||||
|
.context("Failed to get session")?
|
||||||
|
.thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let result = rx.await;
|
||||||
|
|
||||||
|
let outcome = match result {
|
||||||
|
Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
|
||||||
|
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(acp::RequestPermissionResponse { outcome })
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn write_text_file(
|
||||||
|
&self,
|
||||||
|
arguments: acp::WriteTextFileRequest,
|
||||||
|
) -> Result<(), acp::Error> {
|
||||||
|
let cx = &mut self.cx.clone();
|
||||||
|
let task = self
|
||||||
|
.sessions
|
||||||
|
.borrow()
|
||||||
|
.get(&arguments.session_id)
|
||||||
|
.context("Failed to get session")?
|
||||||
|
.thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.write_text_file(arguments.path, arguments.content, cx)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
task.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn read_text_file(
|
||||||
|
&self,
|
||||||
|
arguments: acp::ReadTextFileRequest,
|
||||||
|
) -> Result<acp::ReadTextFileResponse, acp::Error> {
|
||||||
|
let cx = &mut self.cx.clone();
|
||||||
|
let task = self
|
||||||
|
.sessions
|
||||||
|
.borrow()
|
||||||
|
.get(&arguments.session_id)
|
||||||
|
.context("Failed to get session")?
|
||||||
|
.thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let content = task.await?;
|
||||||
|
|
||||||
|
Ok(acp::ReadTextFileResponse { content })
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn session_notification(
|
||||||
|
&self,
|
||||||
|
notification: acp::SessionNotification,
|
||||||
|
) -> Result<(), acp::Error> {
|
||||||
|
let cx = &mut self.cx.clone();
|
||||||
|
let sessions = self.sessions.borrow();
|
||||||
|
let session = sessions
|
||||||
|
.get(¬ification.session_id)
|
||||||
|
.context("Failed to get session")?;
|
||||||
|
|
||||||
|
session.thread.update(cx, |thread, cx| {
|
||||||
|
thread.handle_session_update(notification.update, cx)
|
||||||
|
})??;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,14 +1,12 @@
|
||||||
|
mod acp;
|
||||||
mod claude;
|
mod claude;
|
||||||
mod codex;
|
|
||||||
mod gemini;
|
mod gemini;
|
||||||
mod mcp_server;
|
|
||||||
mod settings;
|
mod settings;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod e2e_tests;
|
mod e2e_tests;
|
||||||
|
|
||||||
pub use claude::*;
|
pub use claude::*;
|
||||||
pub use codex::*;
|
|
||||||
pub use gemini::*;
|
pub use gemini::*;
|
||||||
pub use settings::*;
|
pub use settings::*;
|
||||||
|
|
||||||
|
@ -38,7 +36,6 @@ pub trait AgentServer: Send {
|
||||||
|
|
||||||
fn connect(
|
fn connect(
|
||||||
&self,
|
&self,
|
||||||
// these will go away when old_acp is fully removed
|
|
||||||
root_dir: &Path,
|
root_dir: &Path,
|
||||||
project: &Entity<Project>,
|
project: &Entity<Project>,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
|
|
|
@ -70,10 +70,6 @@ struct ClaudeAgentConnection {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentConnection for ClaudeAgentConnection {
|
impl AgentConnection for ClaudeAgentConnection {
|
||||||
fn name(&self) -> &'static str {
|
|
||||||
ClaudeCode.name()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
self: Rc<Self>,
|
self: Rc<Self>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
|
@ -168,8 +164,9 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let thread =
|
let thread = cx.new(|cx| {
|
||||||
cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?;
|
AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
|
||||||
|
})?;
|
||||||
|
|
||||||
thread_tx.send(thread.downgrade())?;
|
thread_tx.send(thread.downgrade())?;
|
||||||
|
|
||||||
|
@ -186,11 +183,15 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||||
|
&[]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
||||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
|
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||||
let sessions = self.sessions.borrow();
|
let sessions = self.sessions.borrow();
|
||||||
let Some(session) = sessions.get(¶ms.session_id) else {
|
let Some(session) = sessions.get(¶ms.session_id) else {
|
||||||
return Task::ready(Err(anyhow!(
|
return Task::ready(Err(anyhow!(
|
||||||
|
|
|
@ -1,319 +0,0 @@
|
||||||
use agent_client_protocol as acp;
|
|
||||||
use anyhow::anyhow;
|
|
||||||
use collections::HashMap;
|
|
||||||
use context_server::listener::McpServerTool;
|
|
||||||
use context_server::types::requests;
|
|
||||||
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
|
||||||
use futures::channel::{mpsc, oneshot};
|
|
||||||
use project::Project;
|
|
||||||
use settings::SettingsStore;
|
|
||||||
use smol::stream::StreamExt as _;
|
|
||||||
use std::cell::RefCell;
|
|
||||||
use std::rc::Rc;
|
|
||||||
use std::{path::Path, sync::Arc};
|
|
||||||
use util::ResultExt;
|
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
|
||||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
|
||||||
|
|
||||||
use crate::mcp_server::ZedMcpServer;
|
|
||||||
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
|
|
||||||
use acp_thread::{AcpThread, AgentConnection};
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct Codex;
|
|
||||||
|
|
||||||
impl AgentServer for Codex {
|
|
||||||
fn name(&self) -> &'static str {
|
|
||||||
"Codex"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn empty_state_headline(&self) -> &'static str {
|
|
||||||
"Welcome to Codex"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn empty_state_message(&self) -> &'static str {
|
|
||||||
"What can I help with?"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn logo(&self) -> ui::IconName {
|
|
||||||
ui::IconName::AiOpenAi
|
|
||||||
}
|
|
||||||
|
|
||||||
fn connect(
|
|
||||||
&self,
|
|
||||||
_root_dir: &Path,
|
|
||||||
project: &Entity<Project>,
|
|
||||||
cx: &mut App,
|
|
||||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
|
||||||
let project = project.clone();
|
|
||||||
let working_directory = project.read(cx).active_project_directory(cx);
|
|
||||||
cx.spawn(async move |cx| {
|
|
||||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
|
||||||
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let Some(command) =
|
|
||||||
AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
|
|
||||||
else {
|
|
||||||
anyhow::bail!("Failed to find codex binary");
|
|
||||||
};
|
|
||||||
|
|
||||||
let client: Arc<ContextServer> = ContextServer::stdio(
|
|
||||||
ContextServerId("codex-mcp-server".into()),
|
|
||||||
ContextServerCommand {
|
|
||||||
path: command.path,
|
|
||||||
args: command.args,
|
|
||||||
env: command.env,
|
|
||||||
},
|
|
||||||
working_directory,
|
|
||||||
)
|
|
||||||
.into();
|
|
||||||
ContextServer::start(client.clone(), cx).await?;
|
|
||||||
|
|
||||||
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
|
||||||
client
|
|
||||||
.client()
|
|
||||||
.context("Failed to subscribe")?
|
|
||||||
.on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
|
|
||||||
move |notification, _cx| {
|
|
||||||
let notification_tx = notification_tx.clone();
|
|
||||||
log::trace!(
|
|
||||||
"ACP Notification: {}",
|
|
||||||
serde_json::to_string_pretty(¬ification).unwrap()
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(notification) =
|
|
||||||
serde_json::from_value::<acp::SessionNotification>(notification)
|
|
||||||
.log_err()
|
|
||||||
{
|
|
||||||
notification_tx.unbounded_send(notification).ok();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
|
||||||
|
|
||||||
let notification_handler_task = cx.spawn({
|
|
||||||
let sessions = sessions.clone();
|
|
||||||
async move |cx| {
|
|
||||||
while let Some(notification) = notification_rx.next().await {
|
|
||||||
CodexConnection::handle_session_notification(
|
|
||||||
notification,
|
|
||||||
sessions.clone(),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let connection = CodexConnection {
|
|
||||||
client,
|
|
||||||
sessions,
|
|
||||||
_notification_handler_task: notification_handler_task,
|
|
||||||
};
|
|
||||||
Ok(Rc::new(connection) as _)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CodexConnection {
|
|
||||||
client: Arc<context_server::ContextServer>,
|
|
||||||
sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
|
|
||||||
_notification_handler_task: Task<()>,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CodexSession {
|
|
||||||
thread: WeakEntity<AcpThread>,
|
|
||||||
cancel_tx: Option<oneshot::Sender<()>>,
|
|
||||||
_mcp_server: ZedMcpServer,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AgentConnection for CodexConnection {
|
|
||||||
fn name(&self) -> &'static str {
|
|
||||||
"Codex"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_thread(
|
|
||||||
self: Rc<Self>,
|
|
||||||
project: Entity<Project>,
|
|
||||||
cwd: &Path,
|
|
||||||
cx: &mut AsyncApp,
|
|
||||||
) -> Task<Result<Entity<AcpThread>>> {
|
|
||||||
let client = self.client.client();
|
|
||||||
let sessions = self.sessions.clone();
|
|
||||||
let cwd = cwd.to_path_buf();
|
|
||||||
cx.spawn(async move |cx| {
|
|
||||||
let client = client.context("MCP server is not initialized yet")?;
|
|
||||||
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
|
|
||||||
|
|
||||||
let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
|
|
||||||
|
|
||||||
let response = client
|
|
||||||
.request::<requests::CallTool>(context_server::types::CallToolParams {
|
|
||||||
name: acp::NEW_SESSION_TOOL_NAME.into(),
|
|
||||||
arguments: Some(serde_json::to_value(acp::NewSessionArguments {
|
|
||||||
mcp_servers: [(
|
|
||||||
mcp_server::SERVER_NAME.to_string(),
|
|
||||||
mcp_server.server_config()?,
|
|
||||||
)]
|
|
||||||
.into(),
|
|
||||||
client_tools: acp::ClientTools {
|
|
||||||
request_permission: Some(acp::McpToolId {
|
|
||||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
|
||||||
tool_name: mcp_server::RequestPermissionTool::NAME.into(),
|
|
||||||
}),
|
|
||||||
read_text_file: Some(acp::McpToolId {
|
|
||||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
|
||||||
tool_name: mcp_server::ReadTextFileTool::NAME.into(),
|
|
||||||
}),
|
|
||||||
write_text_file: Some(acp::McpToolId {
|
|
||||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
|
||||||
tool_name: mcp_server::WriteTextFileTool::NAME.into(),
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
cwd,
|
|
||||||
})?),
|
|
||||||
meta: None,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if response.is_error.unwrap_or_default() {
|
|
||||||
return Err(anyhow!(response.text_contents()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = serde_json::from_value::<acp::NewSessionOutput>(
|
|
||||||
response.structured_content.context("Empty response")?,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let thread =
|
|
||||||
cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
|
|
||||||
|
|
||||||
thread_tx.send(thread.downgrade())?;
|
|
||||||
|
|
||||||
let session = CodexSession {
|
|
||||||
thread: thread.downgrade(),
|
|
||||||
cancel_tx: None,
|
|
||||||
_mcp_server: mcp_server,
|
|
||||||
};
|
|
||||||
sessions.borrow_mut().insert(result.session_id, session);
|
|
||||||
|
|
||||||
Ok(thread)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
|
||||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prompt(
|
|
||||||
&self,
|
|
||||||
params: agent_client_protocol::PromptArguments,
|
|
||||||
cx: &mut App,
|
|
||||||
) -> Task<Result<()>> {
|
|
||||||
let client = self.client.client();
|
|
||||||
let sessions = self.sessions.clone();
|
|
||||||
|
|
||||||
cx.foreground_executor().spawn(async move {
|
|
||||||
let client = client.context("MCP server is not initialized yet")?;
|
|
||||||
|
|
||||||
let (new_cancel_tx, cancel_rx) = oneshot::channel();
|
|
||||||
{
|
|
||||||
let mut sessions = sessions.borrow_mut();
|
|
||||||
let session = sessions
|
|
||||||
.get_mut(¶ms.session_id)
|
|
||||||
.context("Session not found")?;
|
|
||||||
session.cancel_tx.replace(new_cancel_tx);
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = client
|
|
||||||
.request_with::<requests::CallTool>(
|
|
||||||
context_server::types::CallToolParams {
|
|
||||||
name: acp::PROMPT_TOOL_NAME.into(),
|
|
||||||
arguments: Some(serde_json::to_value(params)?),
|
|
||||||
meta: None,
|
|
||||||
},
|
|
||||||
Some(cancel_rx),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
if let Err(err) = &result
|
|
||||||
&& err.is::<context_server::client::RequestCanceled>()
|
|
||||||
{
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = result?;
|
|
||||||
|
|
||||||
if response.is_error.unwrap_or_default() {
|
|
||||||
return Err(anyhow!(response.text_contents()));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
|
|
||||||
let mut sessions = self.sessions.borrow_mut();
|
|
||||||
|
|
||||||
if let Some(cancel_tx) = sessions
|
|
||||||
.get_mut(session_id)
|
|
||||||
.and_then(|session| session.cancel_tx.take())
|
|
||||||
{
|
|
||||||
cancel_tx.send(()).ok();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CodexConnection {
|
|
||||||
pub fn handle_session_notification(
|
|
||||||
notification: acp::SessionNotification,
|
|
||||||
threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
|
|
||||||
cx: &mut AsyncApp,
|
|
||||||
) {
|
|
||||||
let threads = threads.borrow();
|
|
||||||
let Some(thread) = threads
|
|
||||||
.get(¬ification.session_id)
|
|
||||||
.and_then(|session| session.thread.upgrade())
|
|
||||||
else {
|
|
||||||
log::error!(
|
|
||||||
"Thread not found for session ID: {}",
|
|
||||||
notification.session_id
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
thread
|
|
||||||
.update(cx, |thread, cx| {
|
|
||||||
thread.handle_session_update(notification.update, cx)
|
|
||||||
})
|
|
||||||
.log_err();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for CodexConnection {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
self.client.stop().log_err();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
pub(crate) mod tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::AgentServerCommand;
|
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
crate::common_e2e_tests!(Codex, allow_option_id = "approve");
|
|
||||||
|
|
||||||
pub fn local_command() -> AgentServerCommand {
|
|
||||||
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
|
|
||||||
.join("../../../codex/codex-rs/target/debug/codex");
|
|
||||||
|
|
||||||
AgentServerCommand {
|
|
||||||
path: cli_path,
|
|
||||||
args: vec![],
|
|
||||||
env: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -375,9 +375,6 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
|
||||||
gemini: Some(AgentServerSettings {
|
gemini: Some(AgentServerSettings {
|
||||||
command: crate::gemini::tests::local_command(),
|
command: crate::gemini::tests::local_command(),
|
||||||
}),
|
}),
|
||||||
codex: Some(AgentServerSettings {
|
|
||||||
command: crate::codex::tests::local_command(),
|
|
||||||
}),
|
|
||||||
},
|
},
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
|
|
@ -1,14 +1,10 @@
|
||||||
use anyhow::anyhow;
|
|
||||||
use std::cell::RefCell;
|
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use util::ResultExt as _;
|
|
||||||
|
|
||||||
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
|
use crate::{AgentServer, AgentServerCommand};
|
||||||
use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
|
use acp_thread::AgentConnection;
|
||||||
use agentic_coding_protocol as acp_old;
|
use anyhow::Result;
|
||||||
use anyhow::{Context as _, Result};
|
use gpui::{Entity, Task};
|
||||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use ui::App;
|
use ui::App;
|
||||||
|
@ -43,146 +39,25 @@ impl AgentServer for Gemini {
|
||||||
project: &Entity<Project>,
|
project: &Entity<Project>,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||||
let root_dir = root_dir.to_path_buf();
|
|
||||||
let project = project.clone();
|
let project = project.clone();
|
||||||
let this = self.clone();
|
let root_dir = root_dir.to_path_buf();
|
||||||
let name = self.name();
|
let server_name = self.name();
|
||||||
|
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
let command = this.command(&project, cx).await?;
|
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||||
|
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
||||||
|
})?;
|
||||||
|
|
||||||
let mut child = util::command::new_smol_command(&command.path)
|
let Some(command) =
|
||||||
.args(command.args.iter())
|
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
|
||||||
.current_dir(root_dir)
|
else {
|
||||||
.stdin(std::process::Stdio::piped())
|
anyhow::bail!("Failed to find gemini binary");
|
||||||
.stdout(std::process::Stdio::piped())
|
};
|
||||||
.stderr(std::process::Stdio::inherit())
|
|
||||||
.kill_on_drop(true)
|
|
||||||
.spawn()?;
|
|
||||||
|
|
||||||
let stdin = child.stdin.take().unwrap();
|
crate::acp::connect(server_name, command, &root_dir, cx).await
|
||||||
let stdout = child.stdout.take().unwrap();
|
|
||||||
|
|
||||||
let foreground_executor = cx.foreground_executor().clone();
|
|
||||||
|
|
||||||
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
|
|
||||||
|
|
||||||
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
|
|
||||||
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
|
|
||||||
stdin,
|
|
||||||
stdout,
|
|
||||||
move |fut| foreground_executor.spawn(fut).detach(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let io_task = cx.background_spawn(async move {
|
|
||||||
io_fut.await.log_err();
|
|
||||||
});
|
|
||||||
|
|
||||||
let child_status = cx.background_spawn(async move {
|
|
||||||
let result = match child.status().await {
|
|
||||||
Err(e) => Err(anyhow!(e)),
|
|
||||||
Ok(result) if result.success() => Ok(()),
|
|
||||||
Ok(result) => {
|
|
||||||
if let Some(AgentServerVersion::Unsupported {
|
|
||||||
error_message,
|
|
||||||
upgrade_message,
|
|
||||||
upgrade_command,
|
|
||||||
}) = this.version(&command).await.log_err()
|
|
||||||
{
|
|
||||||
Err(anyhow!(LoadError::Unsupported {
|
|
||||||
error_message,
|
|
||||||
upgrade_message,
|
|
||||||
upgrade_command
|
|
||||||
}))
|
|
||||||
} else {
|
|
||||||
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
drop(io_task);
|
|
||||||
result
|
|
||||||
});
|
|
||||||
|
|
||||||
let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
|
|
||||||
name,
|
|
||||||
connection,
|
|
||||||
child_status,
|
|
||||||
current_thread: thread_rc,
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(connection)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Gemini {
|
|
||||||
async fn command(
|
|
||||||
&self,
|
|
||||||
project: &Entity<Project>,
|
|
||||||
cx: &mut AsyncApp,
|
|
||||||
) -> Result<AgentServerCommand> {
|
|
||||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
|
||||||
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
|
||||||
})?;
|
|
||||||
|
|
||||||
if let Some(command) =
|
|
||||||
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
|
|
||||||
{
|
|
||||||
return Ok(command);
|
|
||||||
};
|
|
||||||
|
|
||||||
let (fs, node_runtime) = project.update(cx, |project, _| {
|
|
||||||
(project.fs().clone(), project.node_runtime().cloned())
|
|
||||||
})?;
|
|
||||||
let node_runtime = node_runtime.context("gemini not found on path")?;
|
|
||||||
|
|
||||||
let directory = ::paths::agent_servers_dir().join("gemini");
|
|
||||||
fs.create_dir(&directory).await?;
|
|
||||||
node_runtime
|
|
||||||
.npm_install_packages(&directory, &[("@google/gemini-cli", "latest")])
|
|
||||||
.await?;
|
|
||||||
let path = directory.join("node_modules/.bin/gemini");
|
|
||||||
|
|
||||||
Ok(AgentServerCommand {
|
|
||||||
path,
|
|
||||||
args: vec![ACP_ARG.into()],
|
|
||||||
env: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn version(&self, command: &AgentServerCommand) -> Result<AgentServerVersion> {
|
|
||||||
let version_fut = util::command::new_smol_command(&command.path)
|
|
||||||
.args(command.args.iter())
|
|
||||||
.arg("--version")
|
|
||||||
.kill_on_drop(true)
|
|
||||||
.output();
|
|
||||||
|
|
||||||
let help_fut = util::command::new_smol_command(&command.path)
|
|
||||||
.args(command.args.iter())
|
|
||||||
.arg("--help")
|
|
||||||
.kill_on_drop(true)
|
|
||||||
.output();
|
|
||||||
|
|
||||||
let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
|
|
||||||
|
|
||||||
let current_version = String::from_utf8(version_output?.stdout)?;
|
|
||||||
let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
|
|
||||||
|
|
||||||
if supported {
|
|
||||||
Ok(AgentServerVersion::Supported)
|
|
||||||
} else {
|
|
||||||
Ok(AgentServerVersion::Unsupported {
|
|
||||||
error_message: format!(
|
|
||||||
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
|
|
||||||
current_version
|
|
||||||
).into(),
|
|
||||||
upgrade_message: "Upgrade Gemini to Latest".into(),
|
|
||||||
upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub(crate) mod tests {
|
pub(crate) mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -199,7 +74,7 @@ pub(crate) mod tests {
|
||||||
|
|
||||||
AgentServerCommand {
|
AgentServerCommand {
|
||||||
path: "node".into(),
|
path: "node".into(),
|
||||||
args: vec![cli_path, ACP_ARG.into()],
|
args: vec![cli_path],
|
||||||
env: None,
|
env: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,207 +0,0 @@
|
||||||
use acp_thread::AcpThread;
|
|
||||||
use agent_client_protocol as acp;
|
|
||||||
use anyhow::Result;
|
|
||||||
use context_server::listener::{McpServerTool, ToolResponse};
|
|
||||||
use context_server::types::{
|
|
||||||
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
|
|
||||||
ToolsCapabilities, requests,
|
|
||||||
};
|
|
||||||
use futures::channel::oneshot;
|
|
||||||
use gpui::{App, AsyncApp, Task, WeakEntity};
|
|
||||||
use indoc::indoc;
|
|
||||||
|
|
||||||
pub struct ZedMcpServer {
|
|
||||||
server: context_server::listener::McpServer,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const SERVER_NAME: &str = "zed";
|
|
||||||
|
|
||||||
impl ZedMcpServer {
|
|
||||||
pub async fn new(
|
|
||||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
|
||||||
cx: &AsyncApp,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
|
|
||||||
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
|
|
||||||
|
|
||||||
mcp_server.add_tool(RequestPermissionTool {
|
|
||||||
thread_rx: thread_rx.clone(),
|
|
||||||
});
|
|
||||||
mcp_server.add_tool(ReadTextFileTool {
|
|
||||||
thread_rx: thread_rx.clone(),
|
|
||||||
});
|
|
||||||
mcp_server.add_tool(WriteTextFileTool {
|
|
||||||
thread_rx: thread_rx.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(Self { server: mcp_server })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn server_config(&self) -> Result<acp::McpServerConfig> {
|
|
||||||
#[cfg(not(test))]
|
|
||||||
let zed_path = anyhow::Context::context(
|
|
||||||
std::env::current_exe(),
|
|
||||||
"finding current executable path for use in mcp_server",
|
|
||||||
)?;
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
let zed_path = crate::e2e_tests::get_zed_path();
|
|
||||||
|
|
||||||
Ok(acp::McpServerConfig {
|
|
||||||
command: zed_path,
|
|
||||||
args: vec![
|
|
||||||
"--nc".into(),
|
|
||||||
self.server.socket_path().display().to_string(),
|
|
||||||
],
|
|
||||||
env: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn handle_initialize(_: InitializeParams, cx: &App) -> Task<Result<InitializeResponse>> {
|
|
||||||
cx.foreground_executor().spawn(async move {
|
|
||||||
Ok(InitializeResponse {
|
|
||||||
protocol_version: ProtocolVersion("2025-06-18".into()),
|
|
||||||
capabilities: ServerCapabilities {
|
|
||||||
experimental: None,
|
|
||||||
logging: None,
|
|
||||||
completions: None,
|
|
||||||
prompts: None,
|
|
||||||
resources: None,
|
|
||||||
tools: Some(ToolsCapabilities {
|
|
||||||
list_changed: Some(false),
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
server_info: Implementation {
|
|
||||||
name: SERVER_NAME.into(),
|
|
||||||
version: "0.1.0".into(),
|
|
||||||
},
|
|
||||||
meta: None,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tools
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct RequestPermissionTool {
|
|
||||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl McpServerTool for RequestPermissionTool {
|
|
||||||
type Input = acp::RequestPermissionArguments;
|
|
||||||
type Output = acp::RequestPermissionOutput;
|
|
||||||
|
|
||||||
const NAME: &'static str = "Confirmation";
|
|
||||||
|
|
||||||
fn description(&self) -> &'static str {
|
|
||||||
indoc! {"
|
|
||||||
Request permission for tool calls.
|
|
||||||
|
|
||||||
This tool is meant to be called programmatically by the agent loop, not the LLM.
|
|
||||||
"}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run(
|
|
||||||
&self,
|
|
||||||
input: Self::Input,
|
|
||||||
cx: &mut AsyncApp,
|
|
||||||
) -> Result<ToolResponse<Self::Output>> {
|
|
||||||
let mut thread_rx = self.thread_rx.clone();
|
|
||||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
|
||||||
anyhow::bail!("Thread closed");
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = thread
|
|
||||||
.update(cx, |thread, cx| {
|
|
||||||
thread.request_tool_call_permission(input.tool_call, input.options, cx)
|
|
||||||
})?
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let outcome = match result {
|
|
||||||
Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id },
|
|
||||||
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(ToolResponse {
|
|
||||||
content: vec![],
|
|
||||||
structured_content: acp::RequestPermissionOutput { outcome },
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct ReadTextFileTool {
|
|
||||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl McpServerTool for ReadTextFileTool {
|
|
||||||
type Input = acp::ReadTextFileArguments;
|
|
||||||
type Output = acp::ReadTextFileOutput;
|
|
||||||
|
|
||||||
const NAME: &'static str = "Read";
|
|
||||||
|
|
||||||
fn description(&self) -> &'static str {
|
|
||||||
"Reads the content of the given file in the project including unsaved changes."
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run(
|
|
||||||
&self,
|
|
||||||
input: Self::Input,
|
|
||||||
cx: &mut AsyncApp,
|
|
||||||
) -> Result<ToolResponse<Self::Output>> {
|
|
||||||
let mut thread_rx = self.thread_rx.clone();
|
|
||||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
|
||||||
anyhow::bail!("Thread closed");
|
|
||||||
};
|
|
||||||
|
|
||||||
let content = thread
|
|
||||||
.update(cx, |thread, cx| {
|
|
||||||
thread.read_text_file(input.path, input.line, input.limit, false, cx)
|
|
||||||
})?
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(ToolResponse {
|
|
||||||
content: vec![],
|
|
||||||
structured_content: acp::ReadTextFileOutput { content },
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct WriteTextFileTool {
|
|
||||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl McpServerTool for WriteTextFileTool {
|
|
||||||
type Input = acp::WriteTextFileArguments;
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
const NAME: &'static str = "Write";
|
|
||||||
|
|
||||||
fn description(&self) -> &'static str {
|
|
||||||
"Write to a file replacing its contents"
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run(
|
|
||||||
&self,
|
|
||||||
input: Self::Input,
|
|
||||||
cx: &mut AsyncApp,
|
|
||||||
) -> Result<ToolResponse<Self::Output>> {
|
|
||||||
let mut thread_rx = self.thread_rx.clone();
|
|
||||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
|
||||||
anyhow::bail!("Thread closed");
|
|
||||||
};
|
|
||||||
|
|
||||||
thread
|
|
||||||
.update(cx, |thread, cx| {
|
|
||||||
thread.write_text_file(input.path, input.content, cx)
|
|
||||||
})?
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(ToolResponse {
|
|
||||||
content: vec![],
|
|
||||||
structured_content: (),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -13,7 +13,6 @@ pub fn init(cx: &mut App) {
|
||||||
pub struct AllAgentServersSettings {
|
pub struct AllAgentServersSettings {
|
||||||
pub gemini: Option<AgentServerSettings>,
|
pub gemini: Option<AgentServerSettings>,
|
||||||
pub claude: Option<AgentServerSettings>,
|
pub claude: Option<AgentServerSettings>,
|
||||||
pub codex: Option<AgentServerSettings>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
||||||
|
@ -30,21 +29,13 @@ impl settings::Settings for AllAgentServersSettings {
|
||||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||||
let mut settings = AllAgentServersSettings::default();
|
let mut settings = AllAgentServersSettings::default();
|
||||||
|
|
||||||
for AllAgentServersSettings {
|
for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() {
|
||||||
gemini,
|
|
||||||
claude,
|
|
||||||
codex,
|
|
||||||
} in sources.defaults_and_customizations()
|
|
||||||
{
|
|
||||||
if gemini.is_some() {
|
if gemini.is_some() {
|
||||||
settings.gemini = gemini.clone();
|
settings.gemini = gemini.clone();
|
||||||
}
|
}
|
||||||
if claude.is_some() {
|
if claude.is_some() {
|
||||||
settings.claude = claude.clone();
|
settings.claude = claude.clone();
|
||||||
}
|
}
|
||||||
if codex.is_some() {
|
|
||||||
settings.codex = codex.clone();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(settings)
|
Ok(settings)
|
||||||
|
|
|
@ -246,7 +246,7 @@ impl AcpThreadView {
|
||||||
{
|
{
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let mut cx = cx.clone();
|
let mut cx = cx.clone();
|
||||||
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
|
if e.is::<acp_thread::AuthRequired>() {
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.thread_state = ThreadState::Unauthenticated { connection };
|
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
@ -719,13 +719,18 @@ impl AcpThreadView {
|
||||||
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
|
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
fn authenticate(
|
||||||
|
&mut self,
|
||||||
|
method: acp::AuthMethodId,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
|
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
self.last_error.take();
|
self.last_error.take();
|
||||||
let authenticate = connection.authenticate(cx);
|
let authenticate = connection.authenticate(method, cx);
|
||||||
self.auth_task = Some(cx.spawn_in(window, {
|
self.auth_task = Some(cx.spawn_in(window, {
|
||||||
let project = self.project.clone();
|
let project = self.project.clone();
|
||||||
let agent = self.agent.clone();
|
let agent = self.agent.clone();
|
||||||
|
@ -2424,22 +2429,26 @@ impl Render for AcpThreadView {
|
||||||
.on_action(cx.listener(Self::next_history_message))
|
.on_action(cx.listener(Self::next_history_message))
|
||||||
.on_action(cx.listener(Self::open_agent_diff))
|
.on_action(cx.listener(Self::open_agent_diff))
|
||||||
.child(match &self.thread_state {
|
.child(match &self.thread_state {
|
||||||
ThreadState::Unauthenticated { .. } => {
|
ThreadState::Unauthenticated { connection } => v_flex()
|
||||||
v_flex()
|
.p_2()
|
||||||
.p_2()
|
.flex_1()
|
||||||
.flex_1()
|
.items_center()
|
||||||
.items_center()
|
.justify_center()
|
||||||
.justify_center()
|
.child(self.render_pending_auth_state())
|
||||||
.child(self.render_pending_auth_state())
|
.child(h_flex().mt_1p5().justify_center().children(
|
||||||
.child(
|
connection.auth_methods().into_iter().map(|method| {
|
||||||
h_flex().mt_1p5().justify_center().child(
|
Button::new(
|
||||||
Button::new("sign-in", format!("Sign in to {}", self.agent.name()))
|
SharedString::from(method.id.0.clone()),
|
||||||
.on_click(cx.listener(|this, _, window, cx| {
|
method.label.clone(),
|
||||||
this.authenticate(window, cx)
|
)
|
||||||
})),
|
.on_click({
|
||||||
),
|
let method_id = method.id.clone();
|
||||||
)
|
cx.listener(move |this, _, window, cx| {
|
||||||
}
|
this.authenticate(method_id.clone(), window, cx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
)),
|
||||||
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
|
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
|
||||||
ThreadState::LoadError(e) => v_flex()
|
ThreadState::LoadError(e) => v_flex()
|
||||||
.p_2()
|
.p_2()
|
||||||
|
@ -2878,8 +2887,8 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentConnection for StubAgentConnection {
|
impl AgentConnection for StubAgentConnection {
|
||||||
fn name(&self) -> &'static str {
|
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||||
"StubAgentConnection"
|
&[]
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
|
@ -2897,17 +2906,21 @@ mod tests {
|
||||||
.into(),
|
.into(),
|
||||||
);
|
);
|
||||||
let thread = cx
|
let thread = cx
|
||||||
.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))
|
.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||||
Task::ready(Ok(thread))
|
Task::ready(Ok(thread))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
|
fn authenticate(
|
||||||
|
&self,
|
||||||
|
_method_id: acp::AuthMethodId,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Task<gpui::Result<()>> {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<gpui::Result<()>> {
|
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
|
||||||
let sessions = self.sessions.lock();
|
let sessions = self.sessions.lock();
|
||||||
let thread = sessions.get(¶ms.session_id).unwrap();
|
let thread = sessions.get(¶ms.session_id).unwrap();
|
||||||
let mut tasks = vec![];
|
let mut tasks = vec![];
|
||||||
|
@ -2954,10 +2967,6 @@ mod tests {
|
||||||
struct SaboteurAgentConnection;
|
struct SaboteurAgentConnection;
|
||||||
|
|
||||||
impl AgentConnection for SaboteurAgentConnection {
|
impl AgentConnection for SaboteurAgentConnection {
|
||||||
fn name(&self) -> &'static str {
|
|
||||||
"SaboteurAgentConnection"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
self: Rc<Self>,
|
self: Rc<Self>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
|
@ -2965,15 +2974,31 @@ mod tests {
|
||||||
cx: &mut gpui::AsyncApp,
|
cx: &mut gpui::AsyncApp,
|
||||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||||
Task::ready(Ok(cx
|
Task::ready(Ok(cx
|
||||||
.new(|cx| AcpThread::new(self, project, SessionId("test".into()), cx))
|
.new(|cx| {
|
||||||
|
AcpThread::new(
|
||||||
|
"SaboteurAgentConnection",
|
||||||
|
self,
|
||||||
|
project,
|
||||||
|
SessionId("test".into()),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})
|
||||||
.unwrap()))
|
.unwrap()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
|
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||||
|
&[]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authenticate(
|
||||||
|
&self,
|
||||||
|
_method_id: acp::AuthMethodId,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Task<gpui::Result<()>> {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt(&self, _params: acp::PromptArguments, _cx: &mut App) -> Task<gpui::Result<()>> {
|
fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task<gpui::Result<()>> {
|
||||||
Task::ready(Err(anyhow::anyhow!("Error prompting")))
|
Task::ready(Err(anyhow::anyhow!("Error prompting")))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1987,20 +1987,6 @@ impl AgentPanel {
|
||||||
);
|
);
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
.item(
|
|
||||||
ContextMenuEntry::new("New Codex Thread")
|
|
||||||
.icon(IconName::AiOpenAi)
|
|
||||||
.icon_color(Color::Muted)
|
|
||||||
.handler(move |window, cx| {
|
|
||||||
window.dispatch_action(
|
|
||||||
NewExternalAgentThread {
|
|
||||||
agent: Some(crate::ExternalAgent::Codex),
|
|
||||||
}
|
|
||||||
.boxed_clone(),
|
|
||||||
cx,
|
|
||||||
);
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
});
|
});
|
||||||
menu
|
menu
|
||||||
}))
|
}))
|
||||||
|
@ -2662,25 +2648,6 @@ impl AgentPanel {
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
)
|
|
||||||
.child(
|
|
||||||
NewThreadButton::new(
|
|
||||||
"new-codex-thread-btn",
|
|
||||||
"New Codex Thread",
|
|
||||||
IconName::AiOpenAi,
|
|
||||||
)
|
|
||||||
.on_click(
|
|
||||||
|window, cx| {
|
|
||||||
window.dispatch_action(
|
|
||||||
Box::new(NewExternalAgentThread {
|
|
||||||
agent: Some(
|
|
||||||
crate::ExternalAgent::Codex,
|
|
||||||
),
|
|
||||||
}),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -150,7 +150,6 @@ enum ExternalAgent {
|
||||||
#[default]
|
#[default]
|
||||||
Gemini,
|
Gemini,
|
||||||
ClaudeCode,
|
ClaudeCode,
|
||||||
Codex,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ExternalAgent {
|
impl ExternalAgent {
|
||||||
|
@ -158,7 +157,6 @@ impl ExternalAgent {
|
||||||
match self {
|
match self {
|
||||||
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
|
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
|
||||||
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
|
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
|
||||||
ExternalAgent::Codex => Rc::new(agent_servers::Codex),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -441,14 +441,12 @@ impl Client {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(unused)]
|
pub fn on_notification(
|
||||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
&self,
|
||||||
where
|
method: &'static str,
|
||||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
|
||||||
{
|
) {
|
||||||
self.notification_handlers
|
self.notification_handlers.lock().insert(method, f);
|
||||||
.lock()
|
|
||||||
.insert(method, Box::new(f));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -95,8 +95,28 @@ impl ContextServer {
|
||||||
self.client.read().clone()
|
self.client.read().clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
|
pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
|
||||||
let client = match &self.configuration {
|
self.initialize(self.new_client(cx)?).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Starts the context server, making sure handlers are registered before initialization happens
|
||||||
|
pub async fn start_with_handlers(
|
||||||
|
&self,
|
||||||
|
notification_handlers: Vec<(
|
||||||
|
&'static str,
|
||||||
|
Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
|
||||||
|
)>,
|
||||||
|
cx: &AsyncApp,
|
||||||
|
) -> Result<()> {
|
||||||
|
let client = self.new_client(cx)?;
|
||||||
|
for (method, handler) in notification_handlers {
|
||||||
|
client.on_notification(method, handler);
|
||||||
|
}
|
||||||
|
self.initialize(client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
|
||||||
|
Ok(match &self.configuration {
|
||||||
ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
|
ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
|
||||||
client::ContextServerId(self.id.0.clone()),
|
client::ContextServerId(self.id.0.clone()),
|
||||||
client::ModelContextServerBinary {
|
client::ModelContextServerBinary {
|
||||||
|
@ -113,8 +133,7 @@ impl ContextServer {
|
||||||
transport.clone(),
|
transport.clone(),
|
||||||
cx.clone(),
|
cx.clone(),
|
||||||
)?,
|
)?,
|
||||||
};
|
})
|
||||||
self.initialize(client).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn initialize(&self, client: Client) -> Result<()> {
|
async fn initialize(&self, client: Client) -> Result<()> {
|
||||||
|
|
|
@ -83,14 +83,18 @@ impl McpServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
|
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
|
||||||
let output_schema = schemars::schema_for!(T::Output);
|
let mut settings = schemars::generate::SchemaSettings::draft07();
|
||||||
let unit_schema = schemars::schema_for!(());
|
settings.inline_subschemas = true;
|
||||||
|
let mut generator = settings.into_generator();
|
||||||
|
|
||||||
|
let output_schema = generator.root_schema_for::<T::Output>();
|
||||||
|
let unit_schema = generator.root_schema_for::<T::Output>();
|
||||||
|
|
||||||
let registered_tool = RegisteredTool {
|
let registered_tool = RegisteredTool {
|
||||||
tool: Tool {
|
tool: Tool {
|
||||||
name: T::NAME.into(),
|
name: T::NAME.into(),
|
||||||
description: Some(tool.description().into()),
|
description: Some(tool.description().into()),
|
||||||
input_schema: schemars::schema_for!(T::Input).into(),
|
input_schema: generator.root_schema_for::<T::Input>().into(),
|
||||||
output_schema: if output_schema == unit_schema {
|
output_schema: if output_schema == unit_schema {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -115,10 +115,11 @@ impl InitializedContextServerProtocol {
|
||||||
self.inner.notify(T::METHOD, params)
|
self.inner.notify(T::METHOD, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
pub fn on_notification(
|
||||||
where
|
&self,
|
||||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
method: &'static str,
|
||||||
{
|
f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
|
||||||
|
) {
|
||||||
self.inner.on_notification(method, f);
|
self.inner.on_notification(method, f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue