From 81c111510f43241cef933567fc2905051dfc5fa3 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 30 Jul 2025 15:48:40 +0200 Subject: [PATCH] Refactor handling of ContextServer notifications The notification handler registration is now more explicit, with handlers set up before server initialization to avoid potential race conditions. --- crates/agent_servers/src/acp_connection.rs | 85 +++++++++++---------- crates/context_server/src/client.rs | 14 ++-- crates/context_server/src/context_server.rs | 27 ++++++- crates/context_server/src/protocol.rs | 9 ++- 4 files changed, 79 insertions(+), 56 deletions(-) diff --git a/crates/agent_servers/src/acp_connection.rs b/crates/agent_servers/src/acp_connection.rs index 95c09e2c52..5883f6ac45 100644 --- a/crates/agent_servers/src/acp_connection.rs +++ b/crates/agent_servers/src/acp_connection.rs @@ -22,7 +22,7 @@ use acp_thread::{AcpThread, AgentConnection}; pub struct AcpConnection { agent_state: Rc>, server_name: &'static str, - client: Arc, + context_server: Arc, sessions: Rc>>, _agent_state_task: Task<()>, _session_update_task: Task<()>, @@ -35,7 +35,7 @@ impl AcpConnection { working_directory: Option>, cx: &mut AsyncApp, ) -> Result { - let client: Arc = ContextServer::stdio( + let context_server: Arc = ContextServer::stdio( ContextServerId(format!("{}-mcp-server", server_name).into()), ContextServerCommand { path: command.path, @@ -45,42 +45,9 @@ impl AcpConnection { working_directory, ) .into(); - ContextServer::start(client.clone(), cx).await?; let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default()); - let mcp_client = client.client().context("Failed to subscribe")?; - - mcp_client.on_notification(acp::AGENT_METHODS.agent_state, { - move |notification, _cx| { - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); - - if let Some(state) = - serde_json::from_value::(notification).log_err() - { - state_tx.send(state).log_err(); - } - } - }); - let (notification_tx, mut notification_rx) = mpsc::unbounded(); - mcp_client.on_notification(acp::AGENT_METHODS.session_update, { - move |notification, _cx| { - let notification_tx = notification_tx.clone(); - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); - - if let Some(notification) = - serde_json::from_value::(notification).log_err() - { - notification_tx.unbounded_send(notification).ok(); - } - } - }); let sessions = Rc::new(RefCell::new(HashMap::default())); let initial_state = state_rx.recv().await?; @@ -104,9 +71,47 @@ impl AcpConnection { } }); + context_server + .start_with_handlers( + vec![ + (acp::AGENT_METHODS.agent_state, { + Box::new(move |notification, _cx| { + log::trace!( + "ACP Notification: {}", + serde_json::to_string_pretty(¬ification).unwrap() + ); + + if let Some(state) = + serde_json::from_value::(notification).log_err() + { + state_tx.send(state).log_err(); + } + }) + }), + (acp::AGENT_METHODS.session_update, { + Box::new(move |notification, _cx| { + let notification_tx = notification_tx.clone(); + log::trace!( + "ACP Notification: {}", + serde_json::to_string_pretty(¬ification).unwrap() + ); + + if let Some(notification) = + serde_json::from_value::(notification) + .log_err() + { + notification_tx.unbounded_send(notification).ok(); + } + }) + }), + ], + cx, + ) + .await?; + Ok(Self { server_name, - client, + context_server, sessions, agent_state, _agent_state_task: agent_state_task, @@ -152,7 +157,7 @@ impl AgentConnection for AcpConnection { cwd: &Path, cx: &mut AsyncApp, ) -> Task>> { - let client = self.client.client(); + let client = self.context_server.client(); let sessions = self.sessions.clone(); let cwd = cwd.to_path_buf(); cx.spawn(async move |cx| { @@ -222,7 +227,7 @@ impl AgentConnection for AcpConnection { } fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { - let client = self.client.client(); + let client = self.context_server.client(); cx.foreground_executor().spawn(async move { let params = acp::AuthenticateArguments { method_id }; @@ -248,7 +253,7 @@ impl AgentConnection for AcpConnection { params: agent_client_protocol::PromptArguments, cx: &mut App, ) -> Task> { - let client = self.client.client(); + let client = self.context_server.client(); let sessions = self.sessions.clone(); cx.foreground_executor().spawn(async move { @@ -305,6 +310,6 @@ impl AgentConnection for AcpConnection { impl Drop for AcpConnection { fn drop(&mut self) { - self.client.stop().log_err(); + self.context_server.stop().log_err(); } } diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 1eb29bbbf9..65283afa87 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -441,14 +441,12 @@ impl Client { Ok(()) } - #[allow(unused)] - pub fn on_notification(&self, method: &'static str, f: F) - where - F: 'static + Send + FnMut(Value, AsyncApp), - { - self.notification_handlers - .lock() - .insert(method, Box::new(f)); + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { + self.notification_handlers.lock().insert(method, f); } } diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index e76e7972f7..34fa29678d 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -95,8 +95,28 @@ impl ContextServer { self.client.read().clone() } - pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { - let client = match &self.configuration { + pub async fn start(&self, cx: &AsyncApp) -> Result<()> { + self.initialize(self.new_client(cx)?).await + } + + /// Starts the context server, making sure handlers are registered before initialization happens + pub async fn start_with_handlers( + &self, + notification_handlers: Vec<( + &'static str, + Box, + )>, + cx: &AsyncApp, + ) -> Result<()> { + let client = self.new_client(cx)?; + for (method, handler) in notification_handlers { + client.on_notification(method, handler); + } + self.initialize(client).await + } + + fn new_client(&self, cx: &AsyncApp) -> Result { + Ok(match &self.configuration { ContextServerTransport::Stdio(command, working_directory) => Client::stdio( client::ContextServerId(self.id.0.clone()), client::ModelContextServerBinary { @@ -113,8 +133,7 @@ impl ContextServer { transport.clone(), cx.clone(), )?, - }; - self.initialize(client).await + }) } async fn initialize(&self, client: Client) -> Result<()> { diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 9ccbc8a553..5355f20f62 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -115,10 +115,11 @@ impl InitializedContextServerProtocol { self.inner.notify(T::METHOD, params) } - pub fn on_notification(&self, method: &'static str, f: F) - where - F: 'static + Send + FnMut(Value, AsyncApp), - { + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { self.inner.on_notification(method, f); } }