use crate::slash_command::context_server_command; use crate::{ prompts::PromptBuilder, slash_command_working_set::SlashCommandWorkingSet, Context, ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, SavedContextMetadata, }; use crate::{tools, SlashCommandId, ToolId, ToolWorkingSet}; use anyhow::{anyhow, Context as _, Result}; use client::{proto, telemetry::Telemetry, Client, TypedEnvelope}; use clock::ReplicaId; use collections::HashMap; use command_palette_hooks::CommandPaletteFilter; use context_servers::manager::{ContextServerManager, ContextServerSettings}; use context_servers::CONTEXT_SERVERS_NAMESPACE; use fs::Fs; use futures::StreamExt; use fuzzy::StringMatchCandidate; use gpui::{ AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task, WeakModel, }; use language::LanguageRegistry; use paths::contexts_dir; use project::Project; use regex::Regex; use rpc::AnyProtoClient; use settings::{Settings as _, SettingsStore}; use std::{ cmp::Reverse, ffi::OsStr, mem, path::{Path, PathBuf}, sync::Arc, time::Duration, }; use util::{ResultExt, TryFutureExt}; pub fn init(client: &AnyProtoClient) { client.add_model_message_handler(ContextStore::handle_advertise_contexts); client.add_model_request_handler(ContextStore::handle_open_context); client.add_model_request_handler(ContextStore::handle_create_context); client.add_model_message_handler(ContextStore::handle_update_context); client.add_model_request_handler(ContextStore::handle_synchronize_contexts); } #[derive(Clone)] pub struct RemoteContextMetadata { pub id: ContextId, pub summary: Option, } pub struct ContextStore { contexts: Vec, contexts_metadata: Vec, context_server_manager: Model, context_server_slash_command_ids: HashMap>, context_server_tool_ids: HashMap>, host_contexts: Vec, fs: Arc, languages: Arc, slash_commands: Arc, tools: Arc, telemetry: Arc, _watch_updates: Task>, client: Arc, project: Model, project_is_shared: bool, client_subscription: Option, _project_subscriptions: Vec, prompt_builder: Arc, } pub enum ContextStoreEvent { ContextCreated(ContextId), } impl EventEmitter for ContextStore {} enum ContextHandle { Weak(WeakModel), Strong(Model), } impl ContextHandle { fn upgrade(&self) -> Option> { match self { ContextHandle::Weak(weak) => weak.upgrade(), ContextHandle::Strong(strong) => Some(strong.clone()), } } fn downgrade(&self) -> WeakModel { match self { ContextHandle::Weak(weak) => weak.clone(), ContextHandle::Strong(strong) => strong.downgrade(), } } } impl ContextStore { pub fn new( project: Model, prompt_builder: Arc, slash_commands: Arc, tools: Arc, cx: &mut AppContext, ) -> Task>> { let fs = project.read(cx).fs().clone(); let languages = project.read(cx).languages().clone(); let telemetry = project.read(cx).client().telemetry().clone(); cx.spawn(|mut cx| async move { const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100); let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await; let this = cx.new_model(|cx: &mut ModelContext| { let context_server_manager = cx.new_model(|_cx| ContextServerManager::new()); let mut this = Self { contexts: Vec::new(), contexts_metadata: Vec::new(), context_server_manager, context_server_slash_command_ids: HashMap::default(), context_server_tool_ids: HashMap::default(), host_contexts: Vec::new(), fs, languages, slash_commands, tools, telemetry, _watch_updates: cx.spawn(|this, mut cx| { async move { while events.next().await.is_some() { this.update(&mut cx, |this, cx| this.reload(cx))? .await .log_err(); } anyhow::Ok(()) } .log_err() }), client_subscription: None, _project_subscriptions: vec![ cx.observe(&project, Self::handle_project_changed), cx.subscribe(&project, Self::handle_project_event), ], project_is_shared: false, client: project.read(cx).client(), project: project.clone(), prompt_builder, }; this.handle_project_changed(project, cx); this.synchronize_contexts(cx); this.register_context_server_handlers(cx); this })?; this.update(&mut cx, |this, cx| this.reload(cx))? .await .log_err(); this.update(&mut cx, |this, cx| { this.watch_context_server_settings(cx); }) .log_err(); Ok(this) }) } fn watch_context_server_settings(&self, cx: &mut ModelContext) { cx.observe_global::(move |this, cx| { this.context_server_manager.update(cx, |manager, cx| { let location = this.project.read(cx).worktrees(cx).next().map(|worktree| { settings::SettingsLocation { worktree_id: worktree.read(cx).id(), path: Path::new(""), } }); let settings = ContextServerSettings::get(location, cx); manager.maintain_servers(settings, cx); let has_any_context_servers = !manager.servers().is_empty(); CommandPaletteFilter::update_global(cx, |filter, _cx| { if has_any_context_servers { filter.show_namespace(CONTEXT_SERVERS_NAMESPACE); } else { filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE); } }); }) }) .detach(); } async fn handle_advertise_contexts( this: Model, envelope: TypedEnvelope, mut cx: AsyncAppContext, ) -> Result<()> { this.update(&mut cx, |this, cx| { this.host_contexts = envelope .payload .contexts .into_iter() .map(|context| RemoteContextMetadata { id: ContextId::from_proto(context.context_id), summary: context.summary, }) .collect(); cx.notify(); }) } async fn handle_open_context( this: Model, envelope: TypedEnvelope, mut cx: AsyncAppContext, ) -> Result { let context_id = ContextId::from_proto(envelope.payload.context_id); let operations = this.update(&mut cx, |this, cx| { if this.project.read(cx).is_via_collab() { return Err(anyhow!("only the host contexts can be opened")); } let context = this .loaded_context_for_id(&context_id, cx) .context("context not found")?; if context.read(cx).replica_id() != ReplicaId::default() { return Err(anyhow!("context must be opened via the host")); } anyhow::Ok( context .read(cx) .serialize_ops(&ContextVersion::default(), cx), ) })??; let operations = operations.await; Ok(proto::OpenContextResponse { context: Some(proto::Context { operations }), }) } async fn handle_create_context( this: Model, _: TypedEnvelope, mut cx: AsyncAppContext, ) -> Result { let (context_id, operations) = this.update(&mut cx, |this, cx| { if this.project.read(cx).is_via_collab() { return Err(anyhow!("can only create contexts as the host")); } let context = this.create(cx); let context_id = context.read(cx).id().clone(); cx.emit(ContextStoreEvent::ContextCreated(context_id.clone())); anyhow::Ok(( context_id, context .read(cx) .serialize_ops(&ContextVersion::default(), cx), )) })??; let operations = operations.await; Ok(proto::CreateContextResponse { context_id: context_id.to_proto(), context: Some(proto::Context { operations }), }) } async fn handle_update_context( this: Model, envelope: TypedEnvelope, mut cx: AsyncAppContext, ) -> Result<()> { this.update(&mut cx, |this, cx| { let context_id = ContextId::from_proto(envelope.payload.context_id); if let Some(context) = this.loaded_context_for_id(&context_id, cx) { let operation_proto = envelope.payload.operation.context("invalid operation")?; let operation = ContextOperation::from_proto(operation_proto)?; context.update(cx, |context, cx| context.apply_ops([operation], cx)); } Ok(()) })? } async fn handle_synchronize_contexts( this: Model, envelope: TypedEnvelope, mut cx: AsyncAppContext, ) -> Result { this.update(&mut cx, |this, cx| { if this.project.read(cx).is_via_collab() { return Err(anyhow!("only the host can synchronize contexts")); } let mut local_versions = Vec::new(); for remote_version_proto in envelope.payload.contexts { let remote_version = ContextVersion::from_proto(&remote_version_proto); let context_id = ContextId::from_proto(remote_version_proto.context_id); if let Some(context) = this.loaded_context_for_id(&context_id, cx) { let context = context.read(cx); let operations = context.serialize_ops(&remote_version, cx); local_versions.push(context.version(cx).to_proto(context_id.clone())); let client = this.client.clone(); let project_id = envelope.payload.project_id; cx.background_executor() .spawn(async move { let operations = operations.await; for operation in operations { client.send(proto::UpdateContext { project_id, context_id: context_id.to_proto(), operation: Some(operation), })?; } anyhow::Ok(()) }) .detach_and_log_err(cx); } } this.advertise_contexts(cx); anyhow::Ok(proto::SynchronizeContextsResponse { contexts: local_versions, }) })? } fn handle_project_changed(&mut self, _: Model, cx: &mut ModelContext) { let is_shared = self.project.read(cx).is_shared(); let was_shared = mem::replace(&mut self.project_is_shared, is_shared); if is_shared == was_shared { return; } if is_shared { self.contexts.retain_mut(|context| { if let Some(strong_context) = context.upgrade() { *context = ContextHandle::Strong(strong_context); true } else { false } }); let remote_id = self.project.read(cx).remote_id().unwrap(); self.client_subscription = self .client .subscribe_to_entity(remote_id) .log_err() .map(|subscription| subscription.set_model(&cx.handle(), &mut cx.to_async())); self.advertise_contexts(cx); } else { self.client_subscription = None; } } fn handle_project_event( &mut self, _: Model, event: &project::Event, cx: &mut ModelContext, ) { match event { project::Event::Reshared => { self.advertise_contexts(cx); } project::Event::HostReshared | project::Event::Rejoined => { self.synchronize_contexts(cx); } project::Event::DisconnectedFromHost => { self.contexts.retain_mut(|context| { if let Some(strong_context) = context.upgrade() { *context = ContextHandle::Weak(context.downgrade()); strong_context.update(cx, |context, cx| { if context.replica_id() != ReplicaId::default() { context.set_capability(language::Capability::ReadOnly, cx); } }); true } else { false } }); self.host_contexts.clear(); cx.notify(); } _ => {} } } pub fn create(&mut self, cx: &mut ModelContext) -> Model { let context = cx.new_model(|cx| { Context::local( self.languages.clone(), Some(self.project.clone()), Some(self.telemetry.clone()), self.prompt_builder.clone(), self.slash_commands.clone(), self.tools.clone(), cx, ) }); self.register_context(&context, cx); context } pub fn create_remote_context( &mut self, cx: &mut ModelContext, ) -> Task>> { let project = self.project.read(cx); let Some(project_id) = project.remote_id() else { return Task::ready(Err(anyhow!("project was not remote"))); }; let replica_id = project.replica_id(); let capability = project.capability(); let language_registry = self.languages.clone(); let project = self.project.clone(); let telemetry = self.telemetry.clone(); let prompt_builder = self.prompt_builder.clone(); let slash_commands = self.slash_commands.clone(); let tools = self.tools.clone(); let request = self.client.request(proto::CreateContext { project_id }); cx.spawn(|this, mut cx| async move { let response = request.await?; let context_id = ContextId::from_proto(response.context_id); let context_proto = response.context.context("invalid context")?; let context = cx.new_model(|cx| { Context::new( context_id.clone(), replica_id, capability, language_registry, prompt_builder, slash_commands, tools, Some(project), Some(telemetry), cx, ) })?; let operations = cx .background_executor() .spawn(async move { context_proto .operations .into_iter() .map(ContextOperation::from_proto) .collect::>>() }) .await?; context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?; this.update(&mut cx, |this, cx| { if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) { existing_context } else { this.register_context(&context, cx); this.synchronize_contexts(cx); context } }) }) } pub fn open_local_context( &mut self, path: PathBuf, cx: &ModelContext, ) -> Task>> { if let Some(existing_context) = self.loaded_context_for_path(&path, cx) { return Task::ready(Ok(existing_context)); } let fs = self.fs.clone(); let languages = self.languages.clone(); let project = self.project.clone(); let telemetry = self.telemetry.clone(); let load = cx.background_executor().spawn({ let path = path.clone(); async move { let saved_context = fs.load(&path).await?; SavedContext::from_json(&saved_context) } }); let prompt_builder = self.prompt_builder.clone(); let slash_commands = self.slash_commands.clone(); let tools = self.tools.clone(); cx.spawn(|this, mut cx| async move { let saved_context = load.await?; let context = cx.new_model(|cx| { Context::deserialize( saved_context, path.clone(), languages, prompt_builder, slash_commands, tools, Some(project), Some(telemetry), cx, ) })?; this.update(&mut cx, |this, cx| { if let Some(existing_context) = this.loaded_context_for_path(&path, cx) { existing_context } else { this.register_context(&context, cx); context } }) }) } fn loaded_context_for_path(&self, path: &Path, cx: &AppContext) -> Option> { self.contexts.iter().find_map(|context| { let context = context.upgrade()?; if context.read(cx).path() == Some(path) { Some(context) } else { None } }) } pub(super) fn loaded_context_for_id( &self, id: &ContextId, cx: &AppContext, ) -> Option> { self.contexts.iter().find_map(|context| { let context = context.upgrade()?; if context.read(cx).id() == id { Some(context) } else { None } }) } pub fn open_remote_context( &mut self, context_id: ContextId, cx: &mut ModelContext, ) -> Task>> { let project = self.project.read(cx); let Some(project_id) = project.remote_id() else { return Task::ready(Err(anyhow!("project was not remote"))); }; if let Some(context) = self.loaded_context_for_id(&context_id, cx) { return Task::ready(Ok(context)); } let replica_id = project.replica_id(); let capability = project.capability(); let language_registry = self.languages.clone(); let project = self.project.clone(); let telemetry = self.telemetry.clone(); let request = self.client.request(proto::OpenContext { project_id, context_id: context_id.to_proto(), }); let prompt_builder = self.prompt_builder.clone(); let slash_commands = self.slash_commands.clone(); let tools = self.tools.clone(); cx.spawn(|this, mut cx| async move { let response = request.await?; let context_proto = response.context.context("invalid context")?; let context = cx.new_model(|cx| { Context::new( context_id.clone(), replica_id, capability, language_registry, prompt_builder, slash_commands, tools, Some(project), Some(telemetry), cx, ) })?; let operations = cx .background_executor() .spawn(async move { context_proto .operations .into_iter() .map(ContextOperation::from_proto) .collect::>>() }) .await?; context.update(&mut cx, |context, cx| context.apply_ops(operations, cx))?; this.update(&mut cx, |this, cx| { if let Some(existing_context) = this.loaded_context_for_id(&context_id, cx) { existing_context } else { this.register_context(&context, cx); this.synchronize_contexts(cx); context } }) }) } fn register_context(&mut self, context: &Model, cx: &mut ModelContext) { let handle = if self.project_is_shared { ContextHandle::Strong(context.clone()) } else { ContextHandle::Weak(context.downgrade()) }; self.contexts.push(handle); self.advertise_contexts(cx); cx.subscribe(context, Self::handle_context_event).detach(); } fn handle_context_event( &mut self, context: Model, event: &ContextEvent, cx: &mut ModelContext, ) { let Some(project_id) = self.project.read(cx).remote_id() else { return; }; match event { ContextEvent::SummaryChanged => { self.advertise_contexts(cx); } ContextEvent::Operation(operation) => { let context_id = context.read(cx).id().to_proto(); let operation = operation.to_proto(); self.client .send(proto::UpdateContext { project_id, context_id, operation: Some(operation), }) .log_err(); } _ => {} } } fn advertise_contexts(&self, cx: &AppContext) { let Some(project_id) = self.project.read(cx).remote_id() else { return; }; // For now, only the host can advertise their open contexts. if self.project.read(cx).is_via_collab() { return; } let contexts = self .contexts .iter() .rev() .filter_map(|context| { let context = context.upgrade()?.read(cx); if context.replica_id() == ReplicaId::default() { Some(proto::ContextMetadata { context_id: context.id().to_proto(), summary: context.summary().map(|summary| summary.text.clone()), }) } else { None } }) .collect(); self.client .send(proto::AdvertiseContexts { project_id, contexts, }) .ok(); } fn synchronize_contexts(&mut self, cx: &mut ModelContext) { let Some(project_id) = self.project.read(cx).remote_id() else { return; }; let contexts = self .contexts .iter() .filter_map(|context| { let context = context.upgrade()?.read(cx); if context.replica_id() != ReplicaId::default() { Some(context.version(cx).to_proto(context.id().clone())) } else { None } }) .collect(); let client = self.client.clone(); let request = self.client.request(proto::SynchronizeContexts { project_id, contexts, }); cx.spawn(|this, cx| async move { let response = request.await?; let mut context_ids = Vec::new(); let mut operations = Vec::new(); this.read_with(&cx, |this, cx| { for context_version_proto in response.contexts { let context_version = ContextVersion::from_proto(&context_version_proto); let context_id = ContextId::from_proto(context_version_proto.context_id); if let Some(context) = this.loaded_context_for_id(&context_id, cx) { context_ids.push(context_id); operations.push(context.read(cx).serialize_ops(&context_version, cx)); } } })?; let operations = futures::future::join_all(operations).await; for (context_id, operations) in context_ids.into_iter().zip(operations) { for operation in operations { client.send(proto::UpdateContext { project_id, context_id: context_id.to_proto(), operation: Some(operation), })?; } } anyhow::Ok(()) }) .detach_and_log_err(cx); } pub fn search(&self, query: String, cx: &AppContext) -> Task> { let metadata = self.contexts_metadata.clone(); let executor = cx.background_executor().clone(); cx.background_executor().spawn(async move { if query.is_empty() { metadata } else { let candidates = metadata .iter() .enumerate() .map(|(id, metadata)| StringMatchCandidate::new(id, metadata.title.clone())) .collect::>(); let matches = fuzzy::match_strings( &candidates, &query, false, 100, &Default::default(), executor, ) .await; matches .into_iter() .map(|mat| metadata[mat.candidate_id].clone()) .collect() } }) } pub fn host_contexts(&self) -> &[RemoteContextMetadata] { &self.host_contexts } fn reload(&mut self, cx: &mut ModelContext) -> Task> { let fs = self.fs.clone(); cx.spawn(|this, mut cx| async move { fs.create_dir(contexts_dir()).await?; let mut paths = fs.read_dir(contexts_dir()).await?; let mut contexts = Vec::::new(); while let Some(path) = paths.next().await { let path = path?; if path.extension() != Some(OsStr::new("json")) { continue; } let pattern = r" - \d+.zed.json$"; let re = Regex::new(pattern).unwrap(); let metadata = fs.metadata(&path).await?; if let Some((file_name, metadata)) = path .file_name() .and_then(|name| name.to_str()) .zip(metadata) { // This is used to filter out contexts saved by the new assistant. if !re.is_match(file_name) { continue; } if let Some(title) = re.replace(file_name, "").lines().next() { contexts.push(SavedContextMetadata { title: title.to_string(), path, mtime: metadata.mtime.into(), }); } } } contexts.sort_unstable_by_key(|context| Reverse(context.mtime)); this.update(&mut cx, |this, cx| { this.contexts_metadata = contexts; cx.notify(); }) }) } pub fn restart_context_servers(&mut self, cx: &mut ModelContext) { cx.update_model( &self.context_server_manager, |context_server_manager, cx| { for server in context_server_manager.servers() { context_server_manager .restart_server(&server.id, cx) .detach_and_log_err(cx); } }, ); } fn register_context_server_handlers(&self, cx: &mut ModelContext) { cx.subscribe( &self.context_server_manager.clone(), Self::handle_context_server_event, ) .detach(); } fn handle_context_server_event( &mut self, context_server_manager: Model, event: &context_servers::manager::Event, cx: &mut ModelContext, ) { let slash_command_working_set = self.slash_commands.clone(); let tool_working_set = self.tools.clone(); match event { context_servers::manager::Event::ServerStarted { server_id } => { if let Some(server) = context_server_manager.read(cx).get_server(server_id) { let context_server_manager = context_server_manager.clone(); cx.spawn({ let server = server.clone(); let server_id = server_id.clone(); |this, mut cx| async move { let Some(protocol) = server.client.read().clone() else { return; }; if protocol.capable(context_servers::protocol::ServerCapability::Prompts) { if let Some(prompts) = protocol.list_prompts().await.log_err() { let slash_command_ids = prompts .into_iter() .filter(context_server_command::acceptable_prompt) .map(|prompt| { log::info!( "registering context server command: {:?}", prompt.name ); slash_command_working_set.insert(Arc::new( context_server_command::ContextServerSlashCommand::new( context_server_manager.clone(), &server, prompt, ), )) }) .collect::>(); this.update(&mut cx, |this, _cx| { this.context_server_slash_command_ids .insert(server_id.clone(), slash_command_ids); }) .log_err(); } } if protocol.capable(context_servers::protocol::ServerCapability::Tools) { if let Some(tools) = protocol.list_tools().await.log_err() { let tool_ids = tools.tools.into_iter().map(|tool| { log::info!("registering context server tool: {:?}", tool.name); tool_working_set.insert( Arc::new(tools::context_server_tool::ContextServerTool::new( context_server_manager.clone(), server.id.clone(), tool, )), ) }).collect::>(); this.update(&mut cx, |this, _cx| { this.context_server_tool_ids .insert(server_id, tool_ids); }) .log_err(); } } } }) .detach(); } } context_servers::manager::Event::ServerStopped { server_id } => { if let Some(slash_command_ids) = self.context_server_slash_command_ids.remove(server_id) { slash_command_working_set.remove(&slash_command_ids); } if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { tool_working_set.remove(&tool_ids); } } } } }