agent: Fix MCP server handler subscription race condition (#32133)

Closes #32132

Release Notes:

- Fixed MCP server handler subscription race condition causing tools to
not load.

---------

Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
Jonathan LEI 2025-06-06 21:32:06 +08:00 committed by GitHub
parent 380d8c5662
commit 6ea4d2b30d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 113 additions and 106 deletions

View file

@ -514,11 +514,14 @@ impl ThreadStore {
} }
fn register_context_server_handlers(&self, cx: &mut Context<Self>) { fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
cx.subscribe( let context_server_store = self.project.read(cx).context_server_store();
&self.project.read(cx).context_server_store(), cx.subscribe(&context_server_store, Self::handle_context_server_event)
Self::handle_context_server_event, .detach();
)
.detach(); // Check for any servers that were already running before the handler was registered
for server in context_server_store.read(cx).running_servers() {
self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
}
} }
fn handle_context_server_event( fn handle_context_server_event(
@ -533,55 +536,7 @@ impl ThreadStore {
match status { match status {
ContextServerStatus::Starting => {} ContextServerStatus::Starting => {}
ContextServerStatus::Running => { ContextServerStatus::Running => {
if let Some(server) = self.load_context_server_tools(server_id.clone(), context_server_store, cx);
context_server_store.read(cx).get_running_server(server_id)
{
let context_server_manager = context_server_store.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
async move |this, cx| {
let Some(protocol) = server.client() else {
return;
};
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| {
tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
})
.collect::<Vec<_>>()
})
.log_err();
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, _| {
this.context_server_tool_ids
.insert(server_id, tool_ids);
})
.log_err();
}
}
}
}
})
.detach();
}
} }
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
@ -594,6 +549,52 @@ impl ThreadStore {
} }
} }
} }
fn load_context_server_tools(
&self,
server_id: ContextServerId,
context_server_store: Entity<ContextServerStore>,
cx: &mut Context<Self>,
) {
let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
return;
};
let tool_working_set = self.tools.clone();
cx.spawn(async move |this, cx| {
let Some(protocol) = server.client() else {
return;
};
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| {
tools
.tools
.into_iter()
.map(|tool| {
log::info!("registering context server tool: {:?}", tool.name);
tool_working_set.insert(Arc::new(ContextServerTool::new(
context_server_store.clone(),
server.id(),
tool,
)))
})
.collect::<Vec<_>>()
})
.log_err();
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, _| {
this.context_server_tool_ids.insert(server_id, tool_ids);
})
.log_err();
}
}
}
})
.detach();
}
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -809,74 +809,37 @@ impl ContextStore {
} }
fn register_context_server_handlers(&self, cx: &mut Context<Self>) { fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
cx.subscribe( let context_server_store = self.project.read(cx).context_server_store();
&self.project.read(cx).context_server_store(), cx.subscribe(&context_server_store, Self::handle_context_server_event)
Self::handle_context_server_event, .detach();
)
.detach(); // Check for any servers that were already running before the handler was registered
for server in context_server_store.read(cx).running_servers() {
self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
}
} }
fn handle_context_server_event( fn handle_context_server_event(
&mut self, &mut self,
context_server_manager: Entity<ContextServerStore>, context_server_store: Entity<ContextServerStore>,
event: &project::context_server_store::Event, event: &project::context_server_store::Event,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let slash_command_working_set = self.slash_commands.clone();
match event { match event {
project::context_server_store::Event::ServerStatusChanged { server_id, status } => { project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status { match status {
ContextServerStatus::Running => { ContextServerStatus::Running => {
if let Some(server) = context_server_manager self.load_context_server_slash_commands(
.read(cx) server_id.clone(),
.get_running_server(server_id) context_server_store.clone(),
{ cx,
let context_server_manager = context_server_manager.clone(); );
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
async move |this, cx| {
let Some(protocol) = server.client() else {
return;
};
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
if let Some(prompts) = protocol.list_prompts().await.log_err() {
let slash_command_ids = prompts
.into_iter()
.filter(assistant_slash_commands::acceptable_prompt)
.map(|prompt| {
log::info!(
"registering context server command: {:?}",
prompt.name
);
slash_command_working_set.insert(Arc::new(
assistant_slash_commands::ContextServerSlashCommand::new(
context_server_manager.clone(),
server.id(),
prompt,
),
))
})
.collect::<Vec<_>>();
this.update( cx, |this, _cx| {
this.context_server_slash_command_ids
.insert(server_id.clone(), slash_command_ids);
})
.log_err();
}
}
}
})
.detach();
}
} }
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(slash_command_ids) = if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id) self.context_server_slash_command_ids.remove(server_id)
{ {
slash_command_working_set.remove(&slash_command_ids); self.slash_commands.remove(&slash_command_ids);
} }
} }
_ => {} _ => {}
@ -884,4 +847,47 @@ impl ContextStore {
} }
} }
} }
fn load_context_server_slash_commands(
&self,
server_id: ContextServerId,
context_server_store: Entity<ContextServerStore>,
cx: &mut Context<Self>,
) {
let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
return;
};
let slash_command_working_set = self.slash_commands.clone();
cx.spawn(async move |this, cx| {
let Some(protocol) = server.client() else {
return;
};
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
if let Some(prompts) = protocol.list_prompts().await.log_err() {
let slash_command_ids = prompts
.into_iter()
.filter(assistant_slash_commands::acceptable_prompt)
.map(|prompt| {
log::info!("registering context server command: {:?}", prompt.name);
slash_command_working_set.insert(Arc::new(
assistant_slash_commands::ContextServerSlashCommand::new(
context_server_store.clone(),
server.id(),
prompt,
),
))
})
.collect::<Vec<_>>();
this.update(cx, |this, _cx| {
this.context_server_slash_command_ids
.insert(server_id.clone(), slash_command_ids);
})
.log_err();
}
}
})
.detach();
}
} }