Merge branch 'mcp-acp-gemini' of github.com:zed-industries/zed into mcp-acp-gemini

This commit is contained in:
Agus Zubiaga 2025-07-30 13:33:03 -03:00
commit 30739041a4
4 changed files with 63 additions and 41 deletions

View file

@ -22,7 +22,7 @@ use acp_thread::{AcpThread, AgentConnection, AuthRequired};
pub struct AcpConnection { pub struct AcpConnection {
auth_methods: Rc<RefCell<Vec<acp::AuthMethod>>>, auth_methods: Rc<RefCell<Vec<acp::AuthMethod>>>,
server_name: &'static str, server_name: &'static str,
client: Arc<context_server::ContextServer>, context_server: Arc<context_server::ContextServer>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>, sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
_session_update_task: Task<()>, _session_update_task: Task<()>,
} }
@ -34,7 +34,7 @@ impl AcpConnection {
working_directory: Option<Arc<Path>>, working_directory: Option<Arc<Path>>,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> Result<Self> { ) -> Result<Self> {
let client: Arc<ContextServer> = ContextServer::stdio( let context_server: Arc<ContextServer> = ContextServer::stdio(
ContextServerId(format!("{}-mcp-server", server_name).into()), ContextServerId(format!("{}-mcp-server", server_name).into()),
ContextServerCommand { ContextServerCommand {
path: command.path, path: command.path,
@ -44,26 +44,8 @@ impl AcpConnection {
working_directory, working_directory,
) )
.into(); .into();
ContextServer::start(client.clone(), cx).await?;
let mcp_client = client.client().context("Failed to subscribe")?;
let (notification_tx, mut notification_rx) = mpsc::unbounded(); let (notification_tx, mut notification_rx) = mpsc::unbounded();
mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
move |notification, _cx| {
let notification_tx = notification_tx.clone();
log::trace!(
"ACP Notification: {}",
serde_json::to_string_pretty(&notification).unwrap()
);
if let Some(notification) =
serde_json::from_value::<acp::SessionNotification>(notification).log_err()
{
notification_tx.unbounded_send(notification).ok();
}
}
});
let sessions = Rc::new(RefCell::new(HashMap::default())); let sessions = Rc::new(RefCell::new(HashMap::default()));
@ -76,10 +58,32 @@ impl AcpConnection {
} }
}); });
context_server
.start_with_handlers(
vec![(acp::AGENT_METHODS.session_update, {
Box::new(move |notification, _cx| {
let notification_tx = notification_tx.clone();
log::trace!(
"ACP Notification: {}",
serde_json::to_string_pretty(&notification).unwrap()
);
if let Some(notification) =
serde_json::from_value::<acp::SessionNotification>(notification)
.log_err()
{
notification_tx.unbounded_send(notification).ok();
}
})
})],
cx,
)
.await?;
Ok(Self { Ok(Self {
auth_methods: Default::default(), auth_methods: Default::default(),
server_name, server_name,
client, context_server,
sessions, sessions,
_session_update_task: session_update_handler_task, _session_update_task: session_update_handler_task,
}) })
@ -123,7 +127,7 @@ impl AgentConnection for AcpConnection {
cwd: &Path, cwd: &Path,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> { ) -> Task<Result<Entity<AcpThread>>> {
let client = self.client.client(); let client = self.context_server.client();
let sessions = self.sessions.clone(); let sessions = self.sessions.clone();
let auth_methods = self.auth_methods.clone(); let auth_methods = self.auth_methods.clone();
let cwd = cwd.to_path_buf(); let cwd = cwd.to_path_buf();
@ -200,7 +204,7 @@ impl AgentConnection for AcpConnection {
} }
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> { fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let client = self.client.client(); let client = self.context_server.client();
cx.foreground_executor().spawn(async move { cx.foreground_executor().spawn(async move {
let params = acp::AuthenticateArguments { method_id }; let params = acp::AuthenticateArguments { method_id };
@ -226,7 +230,7 @@ impl AgentConnection for AcpConnection {
params: agent_client_protocol::PromptArguments, params: agent_client_protocol::PromptArguments,
cx: &mut App, cx: &mut App,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
let client = self.client.client(); let client = self.context_server.client();
let sessions = self.sessions.clone(); let sessions = self.sessions.clone();
cx.foreground_executor().spawn(async move { cx.foreground_executor().spawn(async move {
@ -283,6 +287,6 @@ impl AgentConnection for AcpConnection {
impl Drop for AcpConnection { impl Drop for AcpConnection {
fn drop(&mut self) { fn drop(&mut self) {
self.client.stop().log_err(); self.context_server.stop().log_err();
} }
} }

View file

@ -441,14 +441,12 @@ impl Client {
Ok(()) Ok(())
} }
#[allow(unused)] pub fn on_notification(
pub fn on_notification<F>(&self, method: &'static str, f: F) &self,
where method: &'static str,
F: 'static + Send + FnMut(Value, AsyncApp), f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
{ ) {
self.notification_handlers self.notification_handlers.lock().insert(method, f);
.lock()
.insert(method, Box::new(f));
} }
} }

View file

@ -95,8 +95,28 @@ impl ContextServer {
self.client.read().clone() self.client.read().clone()
} }
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> { pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
let client = match &self.configuration { self.initialize(self.new_client(cx)?).await
}
/// Starts the context server, making sure handlers are registered before initialization happens
pub async fn start_with_handlers(
&self,
notification_handlers: Vec<(
&'static str,
Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
)>,
cx: &AsyncApp,
) -> Result<()> {
let client = self.new_client(cx)?;
for (method, handler) in notification_handlers {
client.on_notification(method, handler);
}
self.initialize(client).await
}
fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
Ok(match &self.configuration {
ContextServerTransport::Stdio(command, working_directory) => Client::stdio( ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
client::ContextServerId(self.id.0.clone()), client::ContextServerId(self.id.0.clone()),
client::ModelContextServerBinary { client::ModelContextServerBinary {
@ -113,8 +133,7 @@ impl ContextServer {
transport.clone(), transport.clone(),
cx.clone(), cx.clone(),
)?, )?,
}; })
self.initialize(client).await
} }
async fn initialize(&self, client: Client) -> Result<()> { async fn initialize(&self, client: Client) -> Result<()> {

View file

@ -115,10 +115,11 @@ impl InitializedContextServerProtocol {
self.inner.notify(T::METHOD, params) self.inner.notify(T::METHOD, params)
} }
pub fn on_notification<F>(&self, method: &'static str, f: F) pub fn on_notification(
where &self,
F: 'static + Send + FnMut(Value, AsyncApp), method: &'static str,
{ f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
) {
self.inner.on_notification(method, f); self.inner.on_notification(method, f);
} }
} }