agent: Fix MCP server handler subscription race condition (#32133)

Closes #32132

Release Notes:

- Fixed MCP server handler subscription race condition causing tools to
not load.

---------

Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
Jonathan LEI 2025-06-06 21:32:06 +08:00 committed by GitHub
parent 380d8c5662
commit 6ea4d2b30d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 113 additions and 106 deletions

View file

@ -514,11 +514,14 @@ impl ThreadStore {
} }
fn register_context_server_handlers(&self, cx: &mut Context<Self>) { fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
cx.subscribe( let context_server_store = self.project.read(cx).context_server_store();
&self.project.read(cx).context_server_store(), cx.subscribe(&context_server_store, Self::handle_context_server_event)
Self::handle_context_server_event,
)
.detach(); .detach();
// Check for any servers that were already running before the handler was registered
for server in context_server_store.read(cx).running_servers() {
self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
}
} }
fn handle_context_server_event( fn handle_context_server_event(
@ -533,14 +536,31 @@ impl ThreadStore {
match status { match status {
ContextServerStatus::Starting => {} ContextServerStatus::Starting => {}
ContextServerStatus::Running => { ContextServerStatus::Running => {
if let Some(server) = self.load_context_server_tools(server_id.clone(), context_server_store, cx);
context_server_store.read(cx).get_running_server(server_id) }
{ ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
let context_server_manager = context_server_store.clone(); if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
cx.spawn({ tool_working_set.update(cx, |tool_working_set, _| {
let server = server.clone(); tool_working_set.remove(&tool_ids);
let server_id = server_id.clone(); });
async move |this, cx| { }
}
}
}
}
}
fn load_context_server_tools(
&self,
server_id: ContextServerId,
context_server_store: Entity<ContextServerStore>,
cx: &mut Context<Self>,
) {
let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
return;
};
let tool_working_set = self.tools.clone();
cx.spawn(async move |this, cx| {
let Some(protocol) = server.client() else { let Some(protocol) = server.client() else {
return; return;
}; };
@ -553,17 +573,12 @@ impl ThreadStore {
.tools .tools
.into_iter() .into_iter()
.map(|tool| { .map(|tool| {
log::info!( log::info!("registering context server tool: {:?}", tool.name);
"registering context server tool: {:?}", tool_working_set.insert(Arc::new(ContextServerTool::new(
tool.name context_server_store.clone(),
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(), server.id(),
tool, tool,
), )))
))
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
}) })
@ -571,30 +586,16 @@ impl ThreadStore {
if let Some(tool_ids) = tool_ids { if let Some(tool_ids) = tool_ids {
this.update(cx, |this, _| { this.update(cx, |this, _| {
this.context_server_tool_ids this.context_server_tool_ids.insert(server_id, tool_ids);
.insert(server_id, tool_ids);
}) })
.log_err(); .log_err();
} }
} }
} }
}
}) })
.detach(); .detach();
} }
} }
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
}
}
}
}
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata { pub struct SerializedThreadMetadata {

View file

@ -809,33 +809,56 @@ impl ContextStore {
} }
fn register_context_server_handlers(&self, cx: &mut Context<Self>) { fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
cx.subscribe( let context_server_store = self.project.read(cx).context_server_store();
&self.project.read(cx).context_server_store(), cx.subscribe(&context_server_store, Self::handle_context_server_event)
Self::handle_context_server_event,
)
.detach(); .detach();
// Check for any servers that were already running before the handler was registered
for server in context_server_store.read(cx).running_servers() {
self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
}
} }
fn handle_context_server_event( fn handle_context_server_event(
&mut self, &mut self,
context_server_manager: Entity<ContextServerStore>, context_server_store: Entity<ContextServerStore>,
event: &project::context_server_store::Event, event: &project::context_server_store::Event,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let slash_command_working_set = self.slash_commands.clone();
match event { match event {
project::context_server_store::Event::ServerStatusChanged { server_id, status } => { project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status { match status {
ContextServerStatus::Running => { ContextServerStatus::Running => {
if let Some(server) = context_server_manager self.load_context_server_slash_commands(
.read(cx) server_id.clone(),
.get_running_server(server_id) context_server_store.clone(),
cx,
);
}
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{ {
let context_server_manager = context_server_manager.clone(); self.slash_commands.remove(&slash_command_ids);
cx.spawn({ }
let server = server.clone(); }
let server_id = server_id.clone(); _ => {}
async move |this, cx| { }
}
}
}
fn load_context_server_slash_commands(
&self,
server_id: ContextServerId,
context_server_store: Entity<ContextServerStore>,
cx: &mut Context<Self>,
) {
let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
return;
};
let slash_command_working_set = self.slash_commands.clone();
cx.spawn(async move |this, cx| {
let Some(protocol) = server.client() else { let Some(protocol) = server.client() else {
return; return;
}; };
@ -846,13 +869,10 @@ impl ContextStore {
.into_iter() .into_iter()
.filter(assistant_slash_commands::acceptable_prompt) .filter(assistant_slash_commands::acceptable_prompt)
.map(|prompt| { .map(|prompt| {
log::info!( log::info!("registering context server command: {:?}", prompt.name);
"registering context server command: {:?}",
prompt.name
);
slash_command_working_set.insert(Arc::new( slash_command_working_set.insert(Arc::new(
assistant_slash_commands::ContextServerSlashCommand::new( assistant_slash_commands::ContextServerSlashCommand::new(
context_server_manager.clone(), context_server_store.clone(),
server.id(), server.id(),
prompt, prompt,
), ),
@ -867,21 +887,7 @@ impl ContextStore {
.log_err(); .log_err();
} }
} }
}
}) })
.detach(); .detach();
} }
} }
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{
slash_command_working_set.remove(&slash_command_ids);
}
}
_ => {}
}
}
}
}
}