diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 504280fac4..cb3a0d3c63 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -514,11 +514,14 @@ impl ThreadStore { } fn register_context_server_handlers(&self, cx: &mut Context) { - cx.subscribe( - &self.project.read(cx).context_server_store(), - Self::handle_context_server_event, - ) - .detach(); + let context_server_store = self.project.read(cx).context_server_store(); + cx.subscribe(&context_server_store, Self::handle_context_server_event) + .detach(); + + // Check for any servers that were already running before the handler was registered + for server in context_server_store.read(cx).running_servers() { + self.load_context_server_tools(server.id(), context_server_store.clone(), cx); + } } fn handle_context_server_event( @@ -533,55 +536,7 @@ impl ThreadStore { match status { ContextServerStatus::Starting => {} ContextServerStatus::Running => { - if let Some(server) = - context_server_store.read(cx).get_running_server(server_id) - { - let context_server_manager = context_server_store.clone(); - cx.spawn({ - let server = server.clone(); - let server_id = server_id.clone(); - async move |this, cx| { - let Some(protocol) = server.client() else { - return; - }; - - if protocol.capable(context_server::protocol::ServerCapability::Tools) { - if let Some(tools) = protocol.list_tools().await.log_err() { - let tool_ids = tool_working_set - .update(cx, |tool_working_set, _| { - tools - .tools - .into_iter() - .map(|tool| { - log::info!( - "registering context server tool: {:?}", - tool.name - ); - tool_working_set.insert(Arc::new( - ContextServerTool::new( - context_server_manager.clone(), - server.id(), - tool, - ), - )) - }) - .collect::>() - }) - .log_err(); - - if let Some(tool_ids) = tool_ids { - this.update(cx, |this, _| { - this.context_server_tool_ids - .insert(server_id, tool_ids); - }) - .log_err(); - } - } - } - } - }) - .detach(); - } + self.load_context_server_tools(server_id.clone(), context_server_store, cx); } ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { @@ -594,6 +549,52 @@ impl ThreadStore { } } } + + fn load_context_server_tools( + &self, + server_id: ContextServerId, + context_server_store: Entity, + cx: &mut Context, + ) { + let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else { + return; + }; + let tool_working_set = self.tools.clone(); + cx.spawn(async move |this, cx| { + let Some(protocol) = server.client() else { + return; + }; + + if protocol.capable(context_server::protocol::ServerCapability::Tools) { + if let Some(tools) = protocol.list_tools().await.log_err() { + let tool_ids = tool_working_set + .update(cx, |tool_working_set, _| { + tools + .tools + .into_iter() + .map(|tool| { + log::info!("registering context server tool: {:?}", tool.name); + tool_working_set.insert(Arc::new(ContextServerTool::new( + context_server_store.clone(), + server.id(), + tool, + ))) + }) + .collect::>() + }) + .log_err(); + + if let Some(tool_ids) = tool_ids { + this.update(cx, |this, _| { + this.context_server_tool_ids.insert(server_id, tool_ids); + }) + .log_err(); + } + } + } + }) + .detach(); + } } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/assistant_context_editor/src/context_store.rs b/crates/assistant_context_editor/src/context_store.rs index e6a771ae68..7af97b62a9 100644 --- a/crates/assistant_context_editor/src/context_store.rs +++ b/crates/assistant_context_editor/src/context_store.rs @@ -809,74 +809,37 @@ impl ContextStore { } fn register_context_server_handlers(&self, cx: &mut Context) { - cx.subscribe( - &self.project.read(cx).context_server_store(), - Self::handle_context_server_event, - ) - .detach(); + let context_server_store = self.project.read(cx).context_server_store(); + cx.subscribe(&context_server_store, Self::handle_context_server_event) + .detach(); + + // Check for any servers that were already running before the handler was registered + for server in context_server_store.read(cx).running_servers() { + self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx); + } } fn handle_context_server_event( &mut self, - context_server_manager: Entity, + context_server_store: Entity, event: &project::context_server_store::Event, cx: &mut Context, ) { - let slash_command_working_set = self.slash_commands.clone(); match event { project::context_server_store::Event::ServerStatusChanged { server_id, status } => { match status { ContextServerStatus::Running => { - if let Some(server) = context_server_manager - .read(cx) - .get_running_server(server_id) - { - let context_server_manager = context_server_manager.clone(); - cx.spawn({ - let server = server.clone(); - let server_id = server_id.clone(); - async move |this, cx| { - let Some(protocol) = server.client() else { - return; - }; - - if protocol.capable(context_server::protocol::ServerCapability::Prompts) { - if let Some(prompts) = protocol.list_prompts().await.log_err() { - let slash_command_ids = prompts - .into_iter() - .filter(assistant_slash_commands::acceptable_prompt) - .map(|prompt| { - log::info!( - "registering context server command: {:?}", - prompt.name - ); - slash_command_working_set.insert(Arc::new( - assistant_slash_commands::ContextServerSlashCommand::new( - context_server_manager.clone(), - server.id(), - prompt, - ), - )) - }) - .collect::>(); - - this.update( cx, |this, _cx| { - this.context_server_slash_command_ids - .insert(server_id.clone(), slash_command_ids); - }) - .log_err(); - } - } - } - }) - .detach(); - } + self.load_context_server_slash_commands( + server_id.clone(), + context_server_store.clone(), + cx, + ); } ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { if let Some(slash_command_ids) = self.context_server_slash_command_ids.remove(server_id) { - slash_command_working_set.remove(&slash_command_ids); + self.slash_commands.remove(&slash_command_ids); } } _ => {} @@ -884,4 +847,47 @@ impl ContextStore { } } } + + fn load_context_server_slash_commands( + &self, + server_id: ContextServerId, + context_server_store: Entity, + cx: &mut Context, + ) { + let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else { + return; + }; + let slash_command_working_set = self.slash_commands.clone(); + cx.spawn(async move |this, cx| { + let Some(protocol) = server.client() else { + return; + }; + + if protocol.capable(context_server::protocol::ServerCapability::Prompts) { + if let Some(prompts) = protocol.list_prompts().await.log_err() { + let slash_command_ids = prompts + .into_iter() + .filter(assistant_slash_commands::acceptable_prompt) + .map(|prompt| { + log::info!("registering context server command: {:?}", prompt.name); + slash_command_working_set.insert(Arc::new( + assistant_slash_commands::ContextServerSlashCommand::new( + context_server_store.clone(), + server.id(), + prompt, + ), + )) + }) + .collect::>(); + + this.update(cx, |this, _cx| { + this.context_server_slash_command_ids + .insert(server_id.clone(), slash_command_ids); + }) + .log_err(); + } + } + }) + .detach(); + } }