agent: Fix MCP server handler subscription race condition (#32133)
Closes #32132 Release Notes: - Fixed MCP server handler subscription race condition causing tools to not load. --------- Co-authored-by: Bennet Bo Fenner <bennet@zed.dev> Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
parent
380d8c5662
commit
6ea4d2b30d
2 changed files with 113 additions and 106 deletions
|
@ -514,11 +514,14 @@ impl ThreadStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
|
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
|
||||||
cx.subscribe(
|
let context_server_store = self.project.read(cx).context_server_store();
|
||||||
&self.project.read(cx).context_server_store(),
|
cx.subscribe(&context_server_store, Self::handle_context_server_event)
|
||||||
Self::handle_context_server_event,
|
.detach();
|
||||||
)
|
|
||||||
.detach();
|
// Check for any servers that were already running before the handler was registered
|
||||||
|
for server in context_server_store.read(cx).running_servers() {
|
||||||
|
self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_context_server_event(
|
fn handle_context_server_event(
|
||||||
|
@ -533,55 +536,7 @@ impl ThreadStore {
|
||||||
match status {
|
match status {
|
||||||
ContextServerStatus::Starting => {}
|
ContextServerStatus::Starting => {}
|
||||||
ContextServerStatus::Running => {
|
ContextServerStatus::Running => {
|
||||||
if let Some(server) =
|
self.load_context_server_tools(server_id.clone(), context_server_store, cx);
|
||||||
context_server_store.read(cx).get_running_server(server_id)
|
|
||||||
{
|
|
||||||
let context_server_manager = context_server_store.clone();
|
|
||||||
cx.spawn({
|
|
||||||
let server = server.clone();
|
|
||||||
let server_id = server_id.clone();
|
|
||||||
async move |this, cx| {
|
|
||||||
let Some(protocol) = server.client() else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
|
|
||||||
if let Some(tools) = protocol.list_tools().await.log_err() {
|
|
||||||
let tool_ids = tool_working_set
|
|
||||||
.update(cx, |tool_working_set, _| {
|
|
||||||
tools
|
|
||||||
.tools
|
|
||||||
.into_iter()
|
|
||||||
.map(|tool| {
|
|
||||||
log::info!(
|
|
||||||
"registering context server tool: {:?}",
|
|
||||||
tool.name
|
|
||||||
);
|
|
||||||
tool_working_set.insert(Arc::new(
|
|
||||||
ContextServerTool::new(
|
|
||||||
context_server_manager.clone(),
|
|
||||||
server.id(),
|
|
||||||
tool,
|
|
||||||
),
|
|
||||||
))
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
.log_err();
|
|
||||||
|
|
||||||
if let Some(tool_ids) = tool_ids {
|
|
||||||
this.update(cx, |this, _| {
|
|
||||||
this.context_server_tool_ids
|
|
||||||
.insert(server_id, tool_ids);
|
|
||||||
})
|
|
||||||
.log_err();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
|
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
|
||||||
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
|
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
|
||||||
|
@ -594,6 +549,52 @@ impl ThreadStore {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn load_context_server_tools(
|
||||||
|
&self,
|
||||||
|
server_id: ContextServerId,
|
||||||
|
context_server_store: Entity<ContextServerStore>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let tool_working_set = self.tools.clone();
|
||||||
|
cx.spawn(async move |this, cx| {
|
||||||
|
let Some(protocol) = server.client() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
|
||||||
|
if let Some(tools) = protocol.list_tools().await.log_err() {
|
||||||
|
let tool_ids = tool_working_set
|
||||||
|
.update(cx, |tool_working_set, _| {
|
||||||
|
tools
|
||||||
|
.tools
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool| {
|
||||||
|
log::info!("registering context server tool: {:?}", tool.name);
|
||||||
|
tool_working_set.insert(Arc::new(ContextServerTool::new(
|
||||||
|
context_server_store.clone(),
|
||||||
|
server.id(),
|
||||||
|
tool,
|
||||||
|
)))
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
.log_err();
|
||||||
|
|
||||||
|
if let Some(tool_ids) = tool_ids {
|
||||||
|
this.update(cx, |this, _| {
|
||||||
|
this.context_server_tool_ids.insert(server_id, tool_ids);
|
||||||
|
})
|
||||||
|
.log_err();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
|
@ -809,74 +809,37 @@ impl ContextStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
|
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
|
||||||
cx.subscribe(
|
let context_server_store = self.project.read(cx).context_server_store();
|
||||||
&self.project.read(cx).context_server_store(),
|
cx.subscribe(&context_server_store, Self::handle_context_server_event)
|
||||||
Self::handle_context_server_event,
|
.detach();
|
||||||
)
|
|
||||||
.detach();
|
// Check for any servers that were already running before the handler was registered
|
||||||
|
for server in context_server_store.read(cx).running_servers() {
|
||||||
|
self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_context_server_event(
|
fn handle_context_server_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
context_server_manager: Entity<ContextServerStore>,
|
context_server_store: Entity<ContextServerStore>,
|
||||||
event: &project::context_server_store::Event,
|
event: &project::context_server_store::Event,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
let slash_command_working_set = self.slash_commands.clone();
|
|
||||||
match event {
|
match event {
|
||||||
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
|
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
|
||||||
match status {
|
match status {
|
||||||
ContextServerStatus::Running => {
|
ContextServerStatus::Running => {
|
||||||
if let Some(server) = context_server_manager
|
self.load_context_server_slash_commands(
|
||||||
.read(cx)
|
server_id.clone(),
|
||||||
.get_running_server(server_id)
|
context_server_store.clone(),
|
||||||
{
|
cx,
|
||||||
let context_server_manager = context_server_manager.clone();
|
);
|
||||||
cx.spawn({
|
|
||||||
let server = server.clone();
|
|
||||||
let server_id = server_id.clone();
|
|
||||||
async move |this, cx| {
|
|
||||||
let Some(protocol) = server.client() else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
|
|
||||||
if let Some(prompts) = protocol.list_prompts().await.log_err() {
|
|
||||||
let slash_command_ids = prompts
|
|
||||||
.into_iter()
|
|
||||||
.filter(assistant_slash_commands::acceptable_prompt)
|
|
||||||
.map(|prompt| {
|
|
||||||
log::info!(
|
|
||||||
"registering context server command: {:?}",
|
|
||||||
prompt.name
|
|
||||||
);
|
|
||||||
slash_command_working_set.insert(Arc::new(
|
|
||||||
assistant_slash_commands::ContextServerSlashCommand::new(
|
|
||||||
context_server_manager.clone(),
|
|
||||||
server.id(),
|
|
||||||
prompt,
|
|
||||||
),
|
|
||||||
))
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
this.update( cx, |this, _cx| {
|
|
||||||
this.context_server_slash_command_ids
|
|
||||||
.insert(server_id.clone(), slash_command_ids);
|
|
||||||
})
|
|
||||||
.log_err();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
|
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
|
||||||
if let Some(slash_command_ids) =
|
if let Some(slash_command_ids) =
|
||||||
self.context_server_slash_command_ids.remove(server_id)
|
self.context_server_slash_command_ids.remove(server_id)
|
||||||
{
|
{
|
||||||
slash_command_working_set.remove(&slash_command_ids);
|
self.slash_commands.remove(&slash_command_ids);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
|
@ -884,4 +847,47 @@ impl ContextStore {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn load_context_server_slash_commands(
|
||||||
|
&self,
|
||||||
|
server_id: ContextServerId,
|
||||||
|
context_server_store: Entity<ContextServerStore>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let slash_command_working_set = self.slash_commands.clone();
|
||||||
|
cx.spawn(async move |this, cx| {
|
||||||
|
let Some(protocol) = server.client() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
|
||||||
|
if let Some(prompts) = protocol.list_prompts().await.log_err() {
|
||||||
|
let slash_command_ids = prompts
|
||||||
|
.into_iter()
|
||||||
|
.filter(assistant_slash_commands::acceptable_prompt)
|
||||||
|
.map(|prompt| {
|
||||||
|
log::info!("registering context server command: {:?}", prompt.name);
|
||||||
|
slash_command_working_set.insert(Arc::new(
|
||||||
|
assistant_slash_commands::ContextServerSlashCommand::new(
|
||||||
|
context_server_store.clone(),
|
||||||
|
server.id(),
|
||||||
|
prompt,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
this.update(cx, |this, _cx| {
|
||||||
|
this.context_server_slash_command_ids
|
||||||
|
.insert(server_id.clone(), slash_command_ids);
|
||||||
|
})
|
||||||
|
.log_err();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue