Compare commits

...
Sign in to create a new pull request.

14 commits

Author SHA1 Message Date
Agus Zubiaga
02d3043ec5 Rename arg to experimental-mcp 2025-07-30 14:28:01 -03:00
Agus Zubiaga
30739041a4 Merge branch 'mcp-acp-gemini' of github.com:zed-industries/zed into mcp-acp-gemini 2025-07-30 13:33:03 -03:00
Agus Zubiaga
27708143ec Fix auth 2025-07-30 13:30:50 -03:00
Agus Zubiaga
738296345e Inline tool schemas 2025-07-30 11:46:11 -03:00
Ben Brandt
81c111510f
Refactor handling of ContextServer notifications
The notification handler registration is now more explicit, with
handlers set up before server initialization to avoid potential race
conditions.
2025-07-30 15:48:40 +02:00
Ben Brandt
f028ca4d1a
Merge branch 'main' into mcp-acp-gemini 2025-07-30 12:24:01 +02:00
Agus Zubiaga
6656403ce8 Auth WIP 2025-07-29 21:15:00 -03:00
Ben Brandt
254c6be42b
Fix broken test 2025-07-29 10:12:57 +02:00
Ben Brandt
745e4b5f1e
Merge branch 'main' into mcp-acp-gemini 2025-07-29 10:10:28 +02:00
Agus Zubiaga
912ab505b2 Connect to gemini over MCP 2025-07-28 20:04:32 -03:00
Agus Zubiaga
b48faddaf4 Restore gemini change 2025-07-28 18:45:05 -03:00
Agus Zubiaga
477731d77d Merge branch 'main' into mcp-acp-gemini 2025-07-28 18:43:25 -03:00
Agus Zubiaga
ced3d09f10 Extract acp_connection 2025-07-28 18:43:01 -03:00
Agus Zubiaga
0f395df9a8 Update to new schema 2025-07-28 18:02:21 -03:00
17 changed files with 442 additions and 482 deletions

4
Cargo.lock generated
View file

@ -138,9 +138,7 @@ dependencies = [
[[package]] [[package]]
name = "agent-client-protocol" name = "agent-client-protocol"
version = "0.0.11" version = "0.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72ec54650c1fc2d63498bab47eeeaa9eddc7d239d53f615b797a0e84f7ccc87b"
dependencies = [ dependencies = [
"schemars", "schemars",
"serde", "serde",

View file

@ -414,7 +414,7 @@ zlog_settings = { path = "crates/zlog_settings" }
# #
agentic-coding-protocol = "0.0.10" agentic-coding-protocol = "0.0.10"
agent-client-protocol = "0.0.11" agent-client-protocol = {path="../agent-client-protocol"}
aho-corasick = "1.1" aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14" any_vec = "0.14"

View file

@ -391,7 +391,7 @@ impl ToolCallContent {
cx: &mut App, cx: &mut App,
) -> Self { ) -> Self {
match content { match content {
acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock { acp::ToolCallContent::Content { content } => Self::ContentBlock {
content: ContentBlock::new(content, &language_registry, cx), content: ContentBlock::new(content, &language_registry, cx),
}, },
acp::ToolCallContent::Diff { diff } => Self::Diff { acp::ToolCallContent::Diff { diff } => Self::Diff {
@ -616,6 +616,7 @@ impl Error for LoadError {}
impl AcpThread { impl AcpThread {
pub fn new( pub fn new(
title: impl Into<SharedString>,
connection: Rc<dyn AgentConnection>, connection: Rc<dyn AgentConnection>,
project: Entity<Project>, project: Entity<Project>,
session_id: acp::SessionId, session_id: acp::SessionId,
@ -628,7 +629,7 @@ impl AcpThread {
shared_buffers: Default::default(), shared_buffers: Default::default(),
entries: Default::default(), entries: Default::default(),
plan: Default::default(), plan: Default::default(),
title: connection.name().into(), title: title.into(),
project, project,
send_task: None, send_task: None,
connection, connection,
@ -682,14 +683,14 @@ impl AcpThread {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Result<()> { ) -> Result<()> {
match update { match update {
acp::SessionUpdate::UserMessage(content_block) => { acp::SessionUpdate::UserMessageChunk { content } => {
self.push_user_content_block(content_block, cx); self.push_user_content_block(content, cx);
} }
acp::SessionUpdate::AgentMessageChunk(content_block) => { acp::SessionUpdate::AgentMessageChunk { content } => {
self.push_assistant_content_block(content_block, false, cx); self.push_assistant_content_block(content, false, cx);
} }
acp::SessionUpdate::AgentThoughtChunk(content_block) => { acp::SessionUpdate::AgentThoughtChunk { content } => {
self.push_assistant_content_block(content_block, true, cx); self.push_assistant_content_block(content, true, cx);
} }
acp::SessionUpdate::ToolCall(tool_call) => { acp::SessionUpdate::ToolCall(tool_call) => {
self.upsert_tool_call(tool_call, cx); self.upsert_tool_call(tool_call, cx);
@ -957,10 +958,6 @@ impl AcpThread {
cx.notify(); cx.notify();
} }
pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
self.connection.authenticate(cx)
}
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
pub fn send_raw( pub fn send_raw(
&mut self, &mut self,
@ -1601,6 +1598,7 @@ mod tests {
}; };
AcpThread::new( AcpThread::new(
"test",
Rc::new(connection), Rc::new(connection),
project, project,
acp::SessionId("test".into()), acp::SessionId("test".into()),

View file

@ -1,6 +1,6 @@
use std::{path::Path, rc::Rc}; use std::{error::Error, fmt, path::Path, rc::Rc};
use agent_client_protocol as acp; use agent_client_protocol::{self as acp};
use anyhow::Result; use anyhow::Result;
use gpui::{AsyncApp, Entity, Task}; use gpui::{AsyncApp, Entity, Task};
use project::Project; use project::Project;
@ -9,8 +9,6 @@ use ui::App;
use crate::AcpThread; use crate::AcpThread;
pub trait AgentConnection { pub trait AgentConnection {
fn name(&self) -> &'static str;
fn new_thread( fn new_thread(
self: Rc<Self>, self: Rc<Self>,
project: Entity<Project>, project: Entity<Project>,
@ -18,9 +16,21 @@ pub trait AgentConnection {
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>>; ) -> Task<Result<Entity<AcpThread>>>;
fn authenticate(&self, cx: &mut App) -> Task<Result<()>>; fn auth_methods(&self) -> Vec<acp::AuthMethod>;
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>; fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
} }
#[derive(Debug)]
pub struct AuthRequired;
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "AuthRequired")
}
}

View file

@ -5,11 +5,11 @@ use anyhow::{Context as _, Result};
use futures::channel::oneshot; use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project; use project::Project;
use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc}; use std::{cell::RefCell, path::Path, rc::Rc};
use ui::App; use ui::App;
use util::ResultExt as _; use util::ResultExt as _;
use crate::{AcpThread, AgentConnection}; use crate::{AcpThread, AgentConnection, AuthRequired};
#[derive(Clone)] #[derive(Clone)]
pub struct OldAcpClientDelegate { pub struct OldAcpClientDelegate {
@ -351,16 +351,6 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
} }
} }
#[derive(Debug)]
pub struct Unauthenticated;
impl Error for Unauthenticated {}
impl fmt::Display for Unauthenticated {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Unauthenticated")
}
}
pub struct OldAcpAgentConnection { pub struct OldAcpAgentConnection {
pub name: &'static str, pub name: &'static str,
pub connection: acp_old::AgentConnection, pub connection: acp_old::AgentConnection,
@ -369,10 +359,6 @@ pub struct OldAcpAgentConnection {
} }
impl AgentConnection for OldAcpAgentConnection { impl AgentConnection for OldAcpAgentConnection {
fn name(&self) -> &'static str {
self.name
}
fn new_thread( fn new_thread(
self: Rc<Self>, self: Rc<Self>,
project: Entity<Project>, project: Entity<Project>,
@ -391,13 +377,13 @@ impl AgentConnection for OldAcpAgentConnection {
let result = acp_old::InitializeParams::response_from_any(result)?; let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated { if !result.is_authenticated {
anyhow::bail!(Unauthenticated) anyhow::bail!(AuthRequired)
} }
cx.update(|cx| { cx.update(|cx| {
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into()); let session_id = acp::SessionId("acp-old-no-id".into());
AcpThread::new(self.clone(), project, session_id, cx) AcpThread::new("Gemini", self.clone(), project, session_id, cx)
}); });
current_thread.replace(thread.downgrade()); current_thread.replace(thread.downgrade());
thread thread
@ -405,7 +391,15 @@ impl AgentConnection for OldAcpAgentConnection {
}) })
} }
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> { fn auth_methods(&self) -> Vec<acp::AuthMethod> {
vec![acp::AuthMethod {
id: acp::AuthMethodId("acp-old-no-id".into()),
label: "Log in".into(),
description: None,
}]
}
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let task = self let task = self
.connection .connection
.request_any(acp_old::AuthenticateParams.into_any()); .request_any(acp_old::AuthenticateParams.into_any());

0
crates/agent_servers/acp Normal file
View file

View file

@ -0,0 +1,292 @@
use agent_client_protocol as acp;
use anyhow::anyhow;
use collections::HashMap;
use context_server::listener::McpServerTool;
use context_server::types::requests;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use futures::channel::{mpsc, oneshot};
use project::Project;
use smol::stream::StreamExt as _;
use std::cell::RefCell;
use std::rc::Rc;
use std::{path::Path, sync::Arc};
use util::ResultExt;
use anyhow::{Context, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::mcp_server::ZedMcpServer;
use crate::{AgentServerCommand, mcp_server};
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
pub struct AcpConnection {
auth_methods: Rc<RefCell<Vec<acp::AuthMethod>>>,
server_name: &'static str,
context_server: Arc<context_server::ContextServer>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
_session_update_task: Task<()>,
}
impl AcpConnection {
pub async fn stdio(
server_name: &'static str,
command: AgentServerCommand,
working_directory: Option<Arc<Path>>,
cx: &mut AsyncApp,
) -> Result<Self> {
let context_server: Arc<ContextServer> = ContextServer::stdio(
ContextServerId(format!("{}-mcp-server", server_name).into()),
ContextServerCommand {
path: command.path,
args: command.args,
env: command.env,
},
working_directory,
)
.into();
let (notification_tx, mut notification_rx) = mpsc::unbounded();
let sessions = Rc::new(RefCell::new(HashMap::default()));
let session_update_handler_task = cx.spawn({
let sessions = sessions.clone();
async move |cx| {
while let Some(notification) = notification_rx.next().await {
Self::handle_session_notification(notification, sessions.clone(), cx)
}
}
});
context_server
.start_with_handlers(
vec![(acp::AGENT_METHODS.session_update, {
Box::new(move |notification, _cx| {
let notification_tx = notification_tx.clone();
log::trace!(
"ACP Notification: {}",
serde_json::to_string_pretty(&notification).unwrap()
);
if let Some(notification) =
serde_json::from_value::<acp::SessionNotification>(notification)
.log_err()
{
notification_tx.unbounded_send(notification).ok();
}
})
})],
cx,
)
.await?;
Ok(Self {
auth_methods: Default::default(),
server_name,
context_server,
sessions,
_session_update_task: session_update_handler_task,
})
}
pub fn handle_session_notification(
notification: acp::SessionNotification,
threads: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
cx: &mut AsyncApp,
) {
let threads = threads.borrow();
let Some(thread) = threads
.get(&notification.session_id)
.and_then(|session| session.thread.upgrade())
else {
log::error!(
"Thread not found for session ID: {}",
notification.session_id
);
return;
};
thread
.update(cx, |thread, cx| {
thread.handle_session_update(notification.update, cx)
})
.log_err();
}
}
pub struct AcpSession {
thread: WeakEntity<AcpThread>,
cancel_tx: Option<oneshot::Sender<()>>,
_mcp_server: ZedMcpServer,
}
impl AgentConnection for AcpConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
let client = self.context_server.client();
let sessions = self.sessions.clone();
let auth_methods = self.auth_methods.clone();
let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| {
let client = client.context("MCP server is not initialized yet")?;
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
let response = client
.request::<requests::CallTool>(context_server::types::CallToolParams {
name: acp::AGENT_METHODS.new_session.into(),
arguments: Some(serde_json::to_value(acp::NewSessionArguments {
mcp_servers: vec![mcp_server.server_config()?],
client_tools: acp::ClientTools {
request_permission: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::RequestPermissionTool::NAME.into(),
}),
read_text_file: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::ReadTextFileTool::NAME.into(),
}),
write_text_file: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::WriteTextFileTool::NAME.into(),
}),
},
cwd,
})?),
meta: None,
})
.await?;
if response.is_error.unwrap_or_default() {
return Err(anyhow!(response.text_contents()));
}
let result = serde_json::from_value::<acp::NewSessionOutput>(
response.structured_content.context("Empty response")?,
)?;
auth_methods.replace(result.auth_methods);
let Some(session_id) = result.session_id else {
anyhow::bail!(AuthRequired);
};
let thread = cx.new(|cx| {
AcpThread::new(
self.server_name,
self.clone(),
project,
session_id.clone(),
cx,
)
})?;
thread_tx.send(thread.downgrade())?;
let session = AcpSession {
thread: thread.downgrade(),
cancel_tx: None,
_mcp_server: mcp_server,
};
sessions.borrow_mut().insert(session_id, session);
Ok(thread)
})
}
fn auth_methods(&self) -> Vec<agent_client_protocol::AuthMethod> {
self.auth_methods.borrow().clone()
}
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let client = self.context_server.client();
cx.foreground_executor().spawn(async move {
let params = acp::AuthenticateArguments { method_id };
let response = client
.context("MCP server is not initialized yet")?
.request::<requests::CallTool>(context_server::types::CallToolParams {
name: acp::AGENT_METHODS.authenticate.into(),
arguments: Some(serde_json::to_value(params)?),
meta: None,
})
.await?;
if response.is_error.unwrap_or_default() {
Err(anyhow!(response.text_contents()))
} else {
Ok(())
}
})
}
fn prompt(
&self,
params: agent_client_protocol::PromptArguments,
cx: &mut App,
) -> Task<Result<()>> {
let client = self.context_server.client();
let sessions = self.sessions.clone();
cx.foreground_executor().spawn(async move {
let client = client.context("MCP server is not initialized yet")?;
let (new_cancel_tx, cancel_rx) = oneshot::channel();
{
let mut sessions = sessions.borrow_mut();
let session = sessions
.get_mut(&params.session_id)
.context("Session not found")?;
session.cancel_tx.replace(new_cancel_tx);
}
let result = client
.request_with::<requests::CallTool>(
context_server::types::CallToolParams {
name: acp::AGENT_METHODS.prompt.into(),
arguments: Some(serde_json::to_value(params)?),
meta: None,
},
Some(cancel_rx),
None,
)
.await;
if let Err(err) = &result
&& err.is::<context_server::client::RequestCanceled>()
{
return Ok(());
}
let response = result?;
if response.is_error.unwrap_or_default() {
return Err(anyhow!(response.text_contents()));
}
Ok(())
})
}
fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
let mut sessions = self.sessions.borrow_mut();
if let Some(cancel_tx) = sessions
.get_mut(session_id)
.and_then(|session| session.cancel_tx.take())
{
cancel_tx.send(()).ok();
}
}
}
impl Drop for AcpConnection {
fn drop(&mut self) {
self.context_server.stop().log_err();
}
}

View file

@ -1,3 +1,4 @@
mod acp_connection;
mod claude; mod claude;
mod codex; mod codex;
mod gemini; mod gemini;

View file

@ -70,10 +70,6 @@ struct ClaudeAgentConnection {
} }
impl AgentConnection for ClaudeAgentConnection { impl AgentConnection for ClaudeAgentConnection {
fn name(&self) -> &'static str {
ClaudeCode.name()
}
fn new_thread( fn new_thread(
self: Rc<Self>, self: Rc<Self>,
project: Entity<Project>, project: Entity<Project>,
@ -168,8 +164,9 @@ impl AgentConnection for ClaudeAgentConnection {
} }
}); });
let thread = let thread = cx.new(|cx| {
cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?; AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
})?;
thread_tx.send(thread.downgrade())?; thread_tx.send(thread.downgrade())?;
@ -186,7 +183,11 @@ impl AgentConnection for ClaudeAgentConnection {
}) })
} }
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> { fn auth_methods(&self) -> Vec<acp::AuthMethod> {
vec![]
}
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Err(anyhow!("Authentication not supported"))) Task::ready(Err(anyhow!("Authentication not supported")))
} }

View file

@ -1,24 +1,14 @@
use agent_client_protocol as acp;
use anyhow::anyhow;
use collections::HashMap;
use context_server::listener::McpServerTool;
use context_server::types::requests;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use futures::channel::{mpsc, oneshot};
use project::Project; use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
use smol::stream::StreamExt as _; use std::path::Path;
use std::cell::RefCell;
use std::rc::Rc; use std::rc::Rc;
use std::{path::Path, sync::Arc};
use util::ResultExt;
use anyhow::{Context, Result}; use anyhow::Result;
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use gpui::{App, Entity, Task};
use crate::mcp_server::ZedMcpServer; use crate::acp_connection::AcpConnection;
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server}; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
use acp_thread::{AcpThread, AgentConnection}; use acp_thread::AgentConnection;
#[derive(Clone)] #[derive(Clone)]
pub struct Codex; pub struct Codex;
@ -47,6 +37,7 @@ impl AgentServer for Codex {
cx: &mut App, cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> { ) -> Task<Result<Rc<dyn AgentConnection>>> {
let project = project.clone(); let project = project.clone();
let server_name = self.name();
let working_directory = project.read(cx).active_project_directory(cx); let working_directory = project.read(cx).active_project_directory(cx);
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let settings = cx.read_global(|settings: &SettingsStore, _| { let settings = cx.read_global(|settings: &SettingsStore, _| {
@ -58,246 +49,14 @@ impl AgentServer for Codex {
else { else {
anyhow::bail!("Failed to find codex binary"); anyhow::bail!("Failed to find codex binary");
}; };
// todo! check supported version
let client: Arc<ContextServer> = ContextServer::stdio( let conn = AcpConnection::stdio(server_name, command, working_directory, cx).await?;
ContextServerId("codex-mcp-server".into()), Ok(Rc::new(conn) as _)
ContextServerCommand {
path: command.path,
args: command.args,
env: command.env,
},
working_directory,
)
.into();
ContextServer::start(client.clone(), cx).await?;
let (notification_tx, mut notification_rx) = mpsc::unbounded();
client
.client()
.context("Failed to subscribe")?
.on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
move |notification, _cx| {
let notification_tx = notification_tx.clone();
log::trace!(
"ACP Notification: {}",
serde_json::to_string_pretty(&notification).unwrap()
);
if let Some(notification) =
serde_json::from_value::<acp::SessionNotification>(notification)
.log_err()
{
notification_tx.unbounded_send(notification).ok();
}
}
});
let sessions = Rc::new(RefCell::new(HashMap::default()));
let notification_handler_task = cx.spawn({
let sessions = sessions.clone();
async move |cx| {
while let Some(notification) = notification_rx.next().await {
CodexConnection::handle_session_notification(
notification,
sessions.clone(),
cx,
)
}
}
});
let connection = CodexConnection {
client,
sessions,
_notification_handler_task: notification_handler_task,
};
Ok(Rc::new(connection) as _)
}) })
} }
} }
struct CodexConnection {
client: Arc<context_server::ContextServer>,
sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
_notification_handler_task: Task<()>,
}
struct CodexSession {
thread: WeakEntity<AcpThread>,
cancel_tx: Option<oneshot::Sender<()>>,
_mcp_server: ZedMcpServer,
}
impl AgentConnection for CodexConnection {
fn name(&self) -> &'static str {
"Codex"
}
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
let client = self.client.client();
let sessions = self.sessions.clone();
let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| {
let client = client.context("MCP server is not initialized yet")?;
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
let response = client
.request::<requests::CallTool>(context_server::types::CallToolParams {
name: acp::NEW_SESSION_TOOL_NAME.into(),
arguments: Some(serde_json::to_value(acp::NewSessionArguments {
mcp_servers: [(
mcp_server::SERVER_NAME.to_string(),
mcp_server.server_config()?,
)]
.into(),
client_tools: acp::ClientTools {
request_permission: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::RequestPermissionTool::NAME.into(),
}),
read_text_file: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::ReadTextFileTool::NAME.into(),
}),
write_text_file: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::WriteTextFileTool::NAME.into(),
}),
},
cwd,
})?),
meta: None,
})
.await?;
if response.is_error.unwrap_or_default() {
return Err(anyhow!(response.text_contents()));
}
let result = serde_json::from_value::<acp::NewSessionOutput>(
response.structured_content.context("Empty response")?,
)?;
let thread =
cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
thread_tx.send(thread.downgrade())?;
let session = CodexSession {
thread: thread.downgrade(),
cancel_tx: None,
_mcp_server: mcp_server,
};
sessions.borrow_mut().insert(result.session_id, session);
Ok(thread)
})
}
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Err(anyhow!("Authentication not supported")))
}
fn prompt(
&self,
params: agent_client_protocol::PromptArguments,
cx: &mut App,
) -> Task<Result<()>> {
let client = self.client.client();
let sessions = self.sessions.clone();
cx.foreground_executor().spawn(async move {
let client = client.context("MCP server is not initialized yet")?;
let (new_cancel_tx, cancel_rx) = oneshot::channel();
{
let mut sessions = sessions.borrow_mut();
let session = sessions
.get_mut(&params.session_id)
.context("Session not found")?;
session.cancel_tx.replace(new_cancel_tx);
}
let result = client
.request_with::<requests::CallTool>(
context_server::types::CallToolParams {
name: acp::PROMPT_TOOL_NAME.into(),
arguments: Some(serde_json::to_value(params)?),
meta: None,
},
Some(cancel_rx),
None,
)
.await;
if let Err(err) = &result
&& err.is::<context_server::client::RequestCanceled>()
{
return Ok(());
}
let response = result?;
if response.is_error.unwrap_or_default() {
return Err(anyhow!(response.text_contents()));
}
Ok(())
})
}
fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
let mut sessions = self.sessions.borrow_mut();
if let Some(cancel_tx) = sessions
.get_mut(session_id)
.and_then(|session| session.cancel_tx.take())
{
cancel_tx.send(()).ok();
}
}
}
impl CodexConnection {
pub fn handle_session_notification(
notification: acp::SessionNotification,
threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
cx: &mut AsyncApp,
) {
let threads = threads.borrow();
let Some(thread) = threads
.get(&notification.session_id)
.and_then(|session| session.thread.upgrade())
else {
log::error!(
"Thread not found for session ID: {}",
notification.session_id
);
return;
};
thread
.update(cx, |thread, cx| {
thread.handle_session_update(notification.update, cx)
})
.log_err();
}
}
impl Drop for CodexConnection {
fn drop(&mut self) {
self.client.stop().log_err();
}
}
#[cfg(test)] #[cfg(test)]
pub(crate) mod tests { pub(crate) mod tests {
use super::*; use super::*;

View file

@ -1,14 +1,10 @@
use anyhow::anyhow;
use std::cell::RefCell;
use std::path::Path; use std::path::Path;
use std::rc::Rc; use std::rc::Rc;
use util::ResultExt as _;
use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; use crate::{AgentServer, AgentServerCommand, acp_connection::AcpConnection};
use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate}; use acp_thread::AgentConnection;
use agentic_coding_protocol as acp_old; use anyhow::Result;
use anyhow::{Context as _, Result}; use gpui::{Entity, Task};
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project; use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
use ui::App; use ui::App;
@ -18,7 +14,7 @@ use crate::AllAgentServersSettings;
#[derive(Clone)] #[derive(Clone)]
pub struct Gemini; pub struct Gemini;
const ACP_ARG: &str = "--experimental-acp"; const MCP_ARG: &str = "--experimental-mcp";
impl AgentServer for Gemini { impl AgentServer for Gemini {
fn name(&self) -> &'static str { fn name(&self) -> &'static str {
@ -39,150 +35,31 @@ impl AgentServer for Gemini {
fn connect( fn connect(
&self, &self,
root_dir: &Path, _root_dir: &Path,
project: &Entity<Project>, project: &Entity<Project>,
cx: &mut App, cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> { ) -> Task<Result<Rc<dyn AgentConnection>>> {
let root_dir = root_dir.to_path_buf();
let project = project.clone(); let project = project.clone();
let this = self.clone(); let server_name = self.name();
let name = self.name(); let working_directory = project.read(cx).active_project_directory(cx);
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let command = this.command(&project, cx).await?; let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).gemini.clone()
})?;
let mut child = util::command::new_smol_command(&command.path) let Some(command) =
.args(command.args.iter()) AgentServerCommand::resolve("gemini", &[MCP_ARG], settings, &project, cx).await
.current_dir(root_dir) else {
.stdin(std::process::Stdio::piped()) anyhow::bail!("Failed to find gemini binary");
.stdout(std::process::Stdio::piped()) };
.stderr(std::process::Stdio::inherit()) // todo! check supported version
.kill_on_drop(true)
.spawn()?;
let stdin = child.stdin.take().unwrap(); let conn = AcpConnection::stdio(server_name, command, working_directory, cx).await?;
let stdout = child.stdout.take().unwrap(); Ok(Rc::new(conn) as _)
let foreground_executor = cx.foreground_executor().clone();
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
stdin,
stdout,
move |fut| foreground_executor.spawn(fut).detach(),
);
let io_task = cx.background_spawn(async move {
io_fut.await.log_err();
});
let child_status = cx.background_spawn(async move {
let result = match child.status().await {
Err(e) => Err(anyhow!(e)),
Ok(result) if result.success() => Ok(()),
Ok(result) => {
if let Some(AgentServerVersion::Unsupported {
error_message,
upgrade_message,
upgrade_command,
}) = this.version(&command).await.log_err()
{
Err(anyhow!(LoadError::Unsupported {
error_message,
upgrade_message,
upgrade_command
}))
} else {
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
}
}
};
drop(io_task);
result
});
let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
name,
connection,
child_status,
current_thread: thread_rc,
});
Ok(connection)
}) })
} }
} }
impl Gemini {
async fn command(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<AgentServerCommand> {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).gemini.clone()
})?;
if let Some(command) =
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
{
return Ok(command);
};
let (fs, node_runtime) = project.update(cx, |project, _| {
(project.fs().clone(), project.node_runtime().cloned())
})?;
let node_runtime = node_runtime.context("gemini not found on path")?;
let directory = ::paths::agent_servers_dir().join("gemini");
fs.create_dir(&directory).await?;
node_runtime
.npm_install_packages(&directory, &[("@google/gemini-cli", "latest")])
.await?;
let path = directory.join("node_modules/.bin/gemini");
Ok(AgentServerCommand {
path,
args: vec![ACP_ARG.into()],
env: None,
})
}
async fn version(&self, command: &AgentServerCommand) -> Result<AgentServerVersion> {
let version_fut = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.arg("--version")
.kill_on_drop(true)
.output();
let help_fut = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.arg("--help")
.kill_on_drop(true)
.output();
let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
let current_version = String::from_utf8(version_output?.stdout)?;
let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
if supported {
Ok(AgentServerVersion::Supported)
} else {
Ok(AgentServerVersion::Unsupported {
error_message: format!(
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
current_version
).into(),
upgrade_message: "Upgrade Gemini to Latest".into(),
upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
})
}
}
}
#[cfg(test)] #[cfg(test)]
pub(crate) mod tests { pub(crate) mod tests {
use super::*; use super::*;
@ -199,7 +76,7 @@ pub(crate) mod tests {
AgentServerCommand { AgentServerCommand {
path: "node".into(), path: "node".into(),
args: vec![cli_path, ACP_ARG.into()], args: vec![cli_path],
env: None, env: None,
} }
} }

View file

@ -37,7 +37,7 @@ impl ZedMcpServer {
Ok(Self { server: mcp_server }) Ok(Self { server: mcp_server })
} }
pub fn server_config(&self) -> Result<acp::McpServerConfig> { pub fn server_config(&self) -> Result<acp::McpServer> {
#[cfg(not(test))] #[cfg(not(test))]
let zed_path = anyhow::Context::context( let zed_path = anyhow::Context::context(
std::env::current_exe(), std::env::current_exe(),
@ -47,13 +47,14 @@ impl ZedMcpServer {
#[cfg(test)] #[cfg(test)]
let zed_path = crate::e2e_tests::get_zed_path(); let zed_path = crate::e2e_tests::get_zed_path();
Ok(acp::McpServerConfig { Ok(acp::McpServer {
name: SERVER_NAME.into(),
command: zed_path, command: zed_path,
args: vec![ args: vec![
"--nc".into(), "--nc".into(),
self.server.socket_path().display().to_string(), self.server.socket_path().display().to_string(),
], ],
env: None, env: vec![],
}) })
} }

View file

@ -223,7 +223,8 @@ impl AcpThreadView {
{ {
Err(e) => { Err(e) => {
let mut cx = cx.clone(); let mut cx = cx.clone();
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() { // todo! remove duplication
if e.downcast_ref::<acp_thread::AuthRequired>().is_some() {
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated { connection }; this.thread_state = ThreadState::Unauthenticated { connection };
cx.notify(); cx.notify();
@ -640,13 +641,18 @@ impl AcpThreadView {
Some(entry.diffs().map(|diff| diff.multibuffer.clone())) Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
} }
fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) { fn authenticate(
&mut self,
method: acp::AuthMethodId,
window: &mut Window,
cx: &mut Context<Self>,
) {
let ThreadState::Unauthenticated { ref connection } = self.thread_state else { let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
return; return;
}; };
self.last_error.take(); self.last_error.take();
let authenticate = connection.authenticate(cx); let authenticate = connection.authenticate(method, cx);
self.auth_task = Some(cx.spawn_in(window, { self.auth_task = Some(cx.spawn_in(window, {
let project = self.project.clone(); let project = self.project.clone();
let agent = self.agent.clone(); let agent = self.agent.clone();
@ -2197,22 +2203,23 @@ impl Render for AcpThreadView {
.on_action(cx.listener(Self::next_history_message)) .on_action(cx.listener(Self::next_history_message))
.on_action(cx.listener(Self::open_agent_diff)) .on_action(cx.listener(Self::open_agent_diff))
.child(match &self.thread_state { .child(match &self.thread_state {
ThreadState::Unauthenticated { .. } => { ThreadState::Unauthenticated { connection } => v_flex()
v_flex() .p_2()
.p_2() .flex_1()
.flex_1() .items_center()
.items_center() .justify_center()
.justify_center() .child(self.render_pending_auth_state())
.child(self.render_pending_auth_state()) .child(h_flex().mt_1p5().justify_center().children(
.child( connection.auth_methods().into_iter().map(|method| {
h_flex().mt_1p5().justify_center().child( Button::new(SharedString::from(method.id.0.clone()), method.label)
Button::new("sign-in", format!("Sign in to {}", self.agent.name())) .on_click({
.on_click(cx.listener(|this, _, window, cx| { let method_id = method.id.clone();
this.authenticate(window, cx) cx.listener(move |this, _, window, cx| {
})), this.authenticate(method_id.clone(), window, cx)
), })
) })
} }),
)),
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
ThreadState::LoadError(e) => v_flex() ThreadState::LoadError(e) => v_flex()
.p_2() .p_2()

View file

@ -441,14 +441,12 @@ impl Client {
Ok(()) Ok(())
} }
#[allow(unused)] pub fn on_notification(
pub fn on_notification<F>(&self, method: &'static str, f: F) &self,
where method: &'static str,
F: 'static + Send + FnMut(Value, AsyncApp), f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
{ ) {
self.notification_handlers self.notification_handlers.lock().insert(method, f);
.lock()
.insert(method, Box::new(f));
} }
} }

View file

@ -95,8 +95,28 @@ impl ContextServer {
self.client.read().clone() self.client.read().clone()
} }
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> { pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
let client = match &self.configuration { self.initialize(self.new_client(cx)?).await
}
/// Starts the context server, making sure handlers are registered before initialization happens
pub async fn start_with_handlers(
&self,
notification_handlers: Vec<(
&'static str,
Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
)>,
cx: &AsyncApp,
) -> Result<()> {
let client = self.new_client(cx)?;
for (method, handler) in notification_handlers {
client.on_notification(method, handler);
}
self.initialize(client).await
}
fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
Ok(match &self.configuration {
ContextServerTransport::Stdio(command, working_directory) => Client::stdio( ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
client::ContextServerId(self.id.0.clone()), client::ContextServerId(self.id.0.clone()),
client::ModelContextServerBinary { client::ModelContextServerBinary {
@ -113,8 +133,7 @@ impl ContextServer {
transport.clone(), transport.clone(),
cx.clone(), cx.clone(),
)?, )?,
}; })
self.initialize(client).await
} }
async fn initialize(&self, client: Client) -> Result<()> { async fn initialize(&self, client: Client) -> Result<()> {

View file

@ -83,14 +83,18 @@ impl McpServer {
} }
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) { pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
let output_schema = schemars::schema_for!(T::Output); let mut settings = schemars::generate::SchemaSettings::draft07();
let unit_schema = schemars::schema_for!(()); settings.inline_subschemas = true;
let mut generator = settings.into_generator();
let output_schema = generator.root_schema_for::<T::Output>();
let unit_schema = generator.root_schema_for::<T::Output>();
let registered_tool = RegisteredTool { let registered_tool = RegisteredTool {
tool: Tool { tool: Tool {
name: T::NAME.into(), name: T::NAME.into(),
description: Some(tool.description().into()), description: Some(tool.description().into()),
input_schema: schemars::schema_for!(T::Input).into(), input_schema: generator.root_schema_for::<T::Input>().into(),
output_schema: if output_schema == unit_schema { output_schema: if output_schema == unit_schema {
None None
} else { } else {

View file

@ -115,10 +115,11 @@ impl InitializedContextServerProtocol {
self.inner.notify(T::METHOD, params) self.inner.notify(T::METHOD, params)
} }
pub fn on_notification<F>(&self, method: &'static str, f: F) pub fn on_notification(
where &self,
F: 'static + Send + FnMut(Value, AsyncApp), method: &'static str,
{ f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
) {
self.inner.on_notification(method, f); self.inner.on_notification(method, f);
} }
} }