context servers: Show configuration modal when extension is installed (#29309)

WIP

Release Notes:

- N/A

---------

Co-authored-by: Danilo Leal <67129314+danilo-leal@users.noreply.github.com>
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
Co-authored-by: Cole Miller <m@cole-miller.net>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
This commit is contained in:
Bennet Bo Fenner 2025-05-01 20:02:14 +02:00 committed by GitHub
parent bffa53d706
commit 24eb039752
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 1866 additions and 437 deletions

View file

@ -34,3 +34,7 @@ smol.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }

View file

@ -140,7 +140,7 @@ impl Client {
/// This function initializes a new Client by spawning a child process for the context server,
/// setting up communication channels, and initializing handlers for input/output operations.
/// It takes a server ID, binary information, and an async app context as input.
pub fn new(
pub fn stdio(
server_id: ContextServerId,
binary: ModelContextServerBinary,
cx: AsyncApp,
@ -158,7 +158,16 @@ impl Client {
.unwrap_or_else(String::new);
let transport = Arc::new(StdioTransport::new(binary, &cx)?);
Self::new(server_id, server_name.into(), transport, cx)
}
/// Creates a new Client instance for a context server.
pub fn new(
server_id: ContextServerId,
server_name: Arc<str>,
transport: Arc<dyn Transport>,
cx: AsyncApp,
) -> Result<Self> {
let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
let (output_done_tx, output_done_rx) = barrier::channel();
@ -167,7 +176,7 @@ impl Client {
let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let stdout_input_task = cx.spawn({
let receive_input_task = cx.spawn({
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
let transport = transport.clone();
@ -177,13 +186,13 @@ impl Client {
.await
}
});
let stderr_input_task = cx.spawn({
let receive_err_task = cx.spawn({
let transport = transport.clone();
async move |_| Self::handle_stderr(transport).log_err().await
async move |_| Self::handle_err(transport).log_err().await
});
let input_task = cx.spawn(async move |_| {
let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
stdout.or(stderr)
let (input, err) = futures::join!(receive_input_task, receive_err_task);
input.or(err)
});
let output_task = cx.background_spawn({
@ -201,7 +210,7 @@ impl Client {
server_id,
notification_handlers,
response_handlers,
name: server_name.into(),
name: server_name,
next_id: Default::default(),
outbound_tx,
executor: cx.background_executor().clone(),
@ -247,7 +256,7 @@ impl Client {
/// Handles the stderr output from the context server.
/// Continuously reads and logs any error messages from the server.
async fn handle_stderr(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
async fn handle_err(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
while let Some(err) = transport.receive_err().next().await {
log::warn!("context server stderr: {}", err.trim());
}

View file

@ -12,7 +12,7 @@ pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerCo
use gpui::{App, actions};
pub use crate::context_server_tool::ContextServerTool;
pub use crate::registry::ContextServerFactoryRegistry;
pub use crate::registry::ContextServerDescriptorRegistry;
actions!(context_servers, [Restart]);
@ -21,7 +21,7 @@ pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
pub fn init(cx: &mut App) {
context_server_settings::init(cx);
ContextServerFactoryRegistry::default_global(cx);
ContextServerDescriptorRegistry::default_global(cx);
extension_context_server::init(cx);
CommandPaletteFilter::update_global(cx, |filter, _cx| {

View file

@ -1,9 +1,21 @@
use std::sync::Arc;
use extension::{Extension, ExtensionContextServerProxy, ExtensionHostProxy, ProjectDelegate};
use gpui::{App, Entity};
use anyhow::Result;
use extension::{
ContextServerConfiguration, Extension, ExtensionContextServerProxy, ExtensionHostProxy,
ProjectDelegate,
};
use gpui::{App, AsyncApp, Entity, Task};
use project::Project;
use crate::{ContextServerFactoryRegistry, ServerCommand};
use crate::{ContextServerDescriptorRegistry, ServerCommand, registry};
pub fn init(cx: &mut App) {
let proxy = ExtensionHostProxy::default_global(cx);
proxy.register_context_server_proxy(ContextServerDescriptorRegistryProxy {
context_server_factory_registry: ContextServerDescriptorRegistry::global(cx),
});
}
struct ExtensionProject {
worktree_ids: Vec<u64>,
@ -15,60 +27,78 @@ impl ProjectDelegate for ExtensionProject {
}
}
pub fn init(cx: &mut App) {
let proxy = ExtensionHostProxy::default_global(cx);
proxy.register_context_server_proxy(ContextServerFactoryRegistryProxy {
context_server_factory_registry: ContextServerFactoryRegistry::global(cx),
});
struct ContextServerDescriptor {
id: Arc<str>,
extension: Arc<dyn Extension>,
}
struct ContextServerFactoryRegistryProxy {
context_server_factory_registry: Entity<ContextServerFactoryRegistry>,
fn extension_project(project: Entity<Project>, cx: &mut AsyncApp) -> Result<Arc<ExtensionProject>> {
project.update(cx, |project, cx| {
Arc::new(ExtensionProject {
worktree_ids: project
.visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto())
.collect(),
})
})
}
impl ExtensionContextServerProxy for ContextServerFactoryRegistryProxy {
impl registry::ContextServerDescriptor for ContextServerDescriptor {
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>> {
let id = self.id.clone();
let extension = self.extension.clone();
cx.spawn(async move |cx| {
let extension_project = extension_project(project, cx)?;
let mut command = extension
.context_server_command(id.clone(), extension_project.clone())
.await?;
command.command = extension
.path_from_extension(command.command.as_ref())
.to_string_lossy()
.to_string();
log::info!("loaded command for context server {id}: {command:?}");
Ok(ServerCommand {
path: command.command,
args: command.args,
env: Some(command.env.into_iter().collect()),
})
})
}
fn configuration(
&self,
project: Entity<Project>,
cx: &AsyncApp,
) -> Task<Result<Option<ContextServerConfiguration>>> {
let id = self.id.clone();
let extension = self.extension.clone();
cx.spawn(async move |cx| {
let extension_project = extension_project(project, cx)?;
let configuration = extension
.context_server_configuration(id.clone(), extension_project)
.await?;
log::debug!("loaded configuration for context server {id}: {configuration:?}");
Ok(configuration)
})
}
}
struct ContextServerDescriptorRegistryProxy {
context_server_factory_registry: Entity<ContextServerDescriptorRegistry>,
}
impl ExtensionContextServerProxy for ContextServerDescriptorRegistryProxy {
fn register_context_server(&self, extension: Arc<dyn Extension>, id: Arc<str>, cx: &mut App) {
self.context_server_factory_registry
.update(cx, |registry, _| {
registry.register_server_factory(
registry.register_context_server_descriptor(
id.clone(),
Arc::new({
move |project, cx| {
log::info!(
"loading command for context server {id} from extension {}",
extension.manifest().id
);
let id = id.clone();
let extension = extension.clone();
cx.spawn(async move |cx| {
let extension_project = project.update(cx, |project, cx| {
Arc::new(ExtensionProject {
worktree_ids: project
.visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto())
.collect(),
})
})?;
let mut command = extension
.context_server_command(id.clone(), extension_project)
.await?;
command.command = extension
.path_from_extension(command.command.as_ref())
.to_string_lossy()
.to_string();
log::info!("loaded command for context server {id}: {command:?}");
Ok(ServerCommand {
path: command.command,
args: command.args,
env: Some(command.env.into_iter().collect()),
})
})
}
}),
Arc::new(ContextServerDescriptor { id, extension })
as Arc<dyn registry::ContextServerDescriptor>,
)
});
}

View file

@ -27,18 +27,27 @@ use project::Project;
use settings::{Settings, SettingsStore};
use util::ResultExt as _;
use crate::transport::Transport;
use crate::{ContextServerSettings, ServerConfig};
use crate::{
CONTEXT_SERVERS_NAMESPACE, ContextServerFactoryRegistry,
CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry,
client::{self, Client},
types,
};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
Starting,
Running,
Error(Arc<str>),
}
pub struct ContextServer {
pub id: Arc<str>,
pub config: Arc<ServerConfig>,
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
transport: Option<Arc<dyn Transport>>,
}
impl ContextServer {
@ -47,9 +56,20 @@ impl ContextServer {
id,
config,
client: RwLock::new(None),
transport: None,
}
}
#[cfg(any(test, feature = "test-support"))]
pub fn test(id: Arc<str>, transport: Arc<dyn crate::transport::Transport>) -> Arc<Self> {
Arc::new(Self {
id,
client: RwLock::new(None),
config: Arc::new(ServerConfig::default()),
transport: Some(transport),
})
}
pub fn id(&self) -> Arc<str> {
self.id.clone()
}
@ -63,20 +83,32 @@ impl ContextServer {
}
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
log::info!("starting context server {}", self.id);
let Some(command) = &self.config.command else {
bail!("no command specified for server {}", self.id);
let client = if let Some(transport) = self.transport.clone() {
Client::new(
client::ContextServerId(self.id.clone()),
self.id(),
transport,
cx.clone(),
)?
} else {
let Some(command) = &self.config.command else {
bail!("no command specified for server {}", self.id);
};
Client::stdio(
client::ContextServerId(self.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
},
cx.clone(),
)?
};
let client = Client::new(
client::ContextServerId(self.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
},
cx.clone(),
)?;
self.initialize(client).await
}
async fn initialize(&self, client: Client) -> Result<()> {
log::info!("starting context server {}", self.id);
let protocol = crate::protocol::ModelContextProtocol::new(client);
let client_info = types::Implementation {
name: "Zed".to_string(),
@ -105,23 +137,26 @@ impl ContextServer {
pub struct ContextServerManager {
servers: HashMap<Arc<str>, Arc<ContextServer>>,
server_status: HashMap<Arc<str>, ContextServerStatus>,
project: Entity<Project>,
registry: Entity<ContextServerFactoryRegistry>,
registry: Entity<ContextServerDescriptorRegistry>,
update_servers_task: Option<Task<Result<()>>>,
needs_server_update: bool,
_subscriptions: Vec<Subscription>,
}
pub enum Event {
ServerStarted { server_id: Arc<str> },
ServerStopped { server_id: Arc<str> },
ServerStatusChanged {
server_id: Arc<str>,
status: Option<ContextServerStatus>,
},
}
impl EventEmitter<Event> for ContextServerManager {}
impl ContextServerManager {
pub fn new(
registry: Entity<ContextServerFactoryRegistry>,
registry: Entity<ContextServerDescriptorRegistry>,
project: Entity<Project>,
cx: &mut Context<Self>,
) -> Self {
@ -138,6 +173,7 @@ impl ContextServerManager {
registry,
needs_server_update: false,
servers: HashMap::default(),
server_status: HashMap::default(),
update_servers_task: None,
};
this.available_context_servers_changed(cx);
@ -153,7 +189,9 @@ impl ContextServerManager {
this.needs_server_update = false;
})?;
Self::maintain_servers(this.clone(), cx).await?;
if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
log::error!("Error maintaining context servers: {}", err);
}
this.update(cx, |this, cx| {
let has_any_context_servers = !this.running_servers().is_empty();
@ -181,52 +219,37 @@ impl ContextServerManager {
.cloned()
}
pub fn status_for_server(&self, id: &str) -> Option<ContextServerStatus> {
self.server_status.get(id).cloned()
}
pub fn start_server(
&self,
server: Arc<ContextServer>,
cx: &mut Context<Self>,
) -> Task<anyhow::Result<()>> {
cx.spawn(async move |this, cx| {
let id = server.id.clone();
server.start(&cx).await?;
this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
Ok(())
})
) -> Task<Result<()>> {
cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await)
}
pub fn stop_server(
&self,
&mut self,
server: Arc<ContextServer>,
cx: &mut Context<Self>,
) -> anyhow::Result<()> {
server.stop()?;
cx.emit(Event::ServerStopped {
server_id: server.id(),
});
) -> Result<()> {
server.stop().log_err();
self.update_server_status(server.id().clone(), None, cx);
Ok(())
}
pub fn restart_server(
&mut self,
id: &Arc<str>,
cx: &mut Context<Self>,
) -> Task<anyhow::Result<()>> {
pub fn restart_server(&mut self, id: &Arc<str>, cx: &mut Context<Self>) -> Task<Result<()>> {
let id = id.clone();
cx.spawn(async move |this, cx| {
if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
server.stop()?;
let config = server.config();
this.update(cx, |this, cx| this.stop_server(server, cx))??;
let new_server = Arc::new(ContextServer::new(id.clone(), config));
new_server.clone().start(&cx).await?;
this.update(cx, |this, cx| {
this.servers.insert(id.clone(), new_server);
cx.emit(Event::ServerStopped {
server_id: id.clone(),
});
cx.emit(Event::ServerStarted {
server_id: id.clone(),
});
})?;
Self::run_server(this, new_server, cx).await?;
}
Ok(())
})
@ -263,12 +286,14 @@ impl ContextServerManager {
(this.registry.clone(), this.project.clone())
})?;
for (id, factory) in
registry.read_with(cx, |registry, _| registry.context_server_factories())?
for (id, descriptor) in
registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
{
let config = desired_servers.entry(id).or_default();
if config.command.is_none() {
if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
if let Some(extension_command) =
descriptor.command(project.clone(), &cx).await.log_err()
{
config.command = Some(extension_command);
}
}
@ -290,28 +315,270 @@ impl ContextServerManager {
for (id, config) in desired_servers {
let existing_config = this.servers.get(&id).map(|server| server.config());
if existing_config.as_deref() != Some(&config) {
let config = Arc::new(config);
let server = Arc::new(ContextServer::new(id.clone(), config));
let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config)));
servers_to_start.insert(id.clone(), server.clone());
let old_server = this.servers.insert(id.clone(), server);
if let Some(old_server) = old_server {
if let Some(old_server) = this.servers.remove(&id) {
servers_to_stop.insert(id, old_server);
}
}
}
})?;
for (id, server) in servers_to_stop {
server.stop().log_err();
this.update(cx, |_, cx| cx.emit(Event::ServerStopped { server_id: id }))?;
for (_, server) in servers_to_stop {
this.update(cx, |this, cx| this.stop_server(server, cx).ok())?;
}
for (id, server) in servers_to_start {
if server.start(&cx).await.log_err().is_some() {
this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
}
for (_, server) in servers_to_start {
Self::run_server(this.clone(), server, cx).await.ok();
}
Ok(())
}
async fn run_server(
this: WeakEntity<Self>,
server: Arc<ContextServer>,
cx: &mut AsyncApp,
) -> Result<()> {
let id = server.id();
this.update(cx, |this, cx| {
this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx);
this.servers.insert(id.clone(), server.clone());
})?;
match server.start(&cx).await {
Ok(_) => {
log::debug!("`{}` context server started", id);
this.update(cx, |this, cx| {
this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx)
})?;
Ok(())
}
Err(err) => {
log::error!("`{}` context server failed to start\n{}", id, err);
this.update(cx, |this, cx| {
this.update_server_status(
id.clone(),
Some(ContextServerStatus::Error(err.to_string().into())),
cx,
)
})?;
Err(err)
}
}
}
fn update_server_status(
&mut self,
id: Arc<str>,
status: Option<ContextServerStatus>,
cx: &mut Context<Self>,
) {
if let Some(status) = status.clone() {
self.server_status.insert(id.clone(), status);
} else {
self.server_status.remove(&id);
}
cx.emit(Event::ServerStatusChanged {
server_id: id,
status,
});
}
}
#[cfg(test)]
mod tests {
use std::pin::Pin;
use crate::types::{
Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities,
};
use super::*;
use futures::{Stream, StreamExt as _, lock::Mutex};
use gpui::{AppContext as _, TestAppContext};
use project::FakeFs;
use serde_json::json;
use util::path;
#[gpui::test]
async fn test_context_server_status(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(cx, json!({"code.rs": ""})).await;
let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx));
let server_1_id: Arc<str> = "mcp-1".into();
let server_2_id: Arc<str> = "mcp-2".into();
let transport_1 = Arc::new(FakeTransport::new(
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-1".to_string()))
}
_ => None,
},
));
let transport_2 = Arc::new(FakeTransport::new(
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-2".to_string()))
}
_ => None,
},
));
let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone());
let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone());
manager
.update(cx, |manager, cx| manager.start_server(server_1, cx))
.await
.unwrap();
cx.update(|cx| {
assert_eq!(
manager.read(cx).status_for_server(&server_1_id),
Some(ContextServerStatus::Running)
);
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
});
manager
.update(cx, |manager, cx| manager.start_server(server_2.clone(), cx))
.await
.unwrap();
cx.update(|cx| {
assert_eq!(
manager.read(cx).status_for_server(&server_1_id),
Some(ContextServerStatus::Running)
);
assert_eq!(
manager.read(cx).status_for_server(&server_2_id),
Some(ContextServerStatus::Running)
);
});
manager
.update(cx, |manager, cx| manager.stop_server(server_2, cx))
.unwrap();
cx.update(|cx| {
assert_eq!(
manager.read(cx).status_for_server(&server_1_id),
Some(ContextServerStatus::Running)
);
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
});
}
async fn create_test_project(
cx: &mut TestAppContext,
files: serde_json::Value,
) -> Entity<Project> {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/test"), files).await;
Project::test(fs, [path!("/test").as_ref()], cx).await
}
fn init_test_settings(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
ContextServerSettings::register(cx);
});
}
fn create_initialize_response(server_name: String) -> serde_json::Value {
serde_json::to_value(&InitializeResponse {
protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
server_info: Implementation {
name: server_name,
version: "1.0.0".to_string(),
},
capabilities: ServerCapabilities::default(),
meta: None,
})
.unwrap()
}
struct FakeTransport {
on_request: Arc<
dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
+ Send
+ Sync,
>,
tx: futures::channel::mpsc::UnboundedSender<String>,
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
}
impl FakeTransport {
fn new(
on_request: impl Fn(
u64,
Option<RequestType>,
serde_json::Value,
) -> Option<serde_json::Value>
+ 'static
+ Send
+ Sync,
) -> Self {
let (tx, rx) = futures::channel::mpsc::unbounded();
Self {
on_request: Arc::new(on_request),
tx,
rx: Arc::new(Mutex::new(rx)),
}
}
}
#[async_trait::async_trait]
impl Transport for FakeTransport {
async fn send(&self, message: String) -> Result<()> {
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
if let Some(method) = msg.get("method") {
let request_type = method
.as_str()
.and_then(|method| types::RequestType::try_from(method).ok());
if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": payload
});
self.tx
.unbounded_send(response.to_string())
.map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
}
}
}
Ok(())
}
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
let rx = self.rx.clone();
Box::pin(futures::stream::unfold(rx, |rx| async move {
let mut rx_guard = rx.lock().await;
if let Some(message) = rx_guard.next().await {
drop(rx_guard);
Some((message, rx))
} else {
None
}
}))
}
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(futures::stream::empty())
}
}
}

View file

@ -2,38 +2,47 @@ use std::sync::Arc;
use anyhow::Result;
use collections::HashMap;
use extension::ContextServerConfiguration;
use gpui::{App, AppContext as _, AsyncApp, Entity, Global, ReadGlobal, Task};
use project::Project;
use crate::ServerCommand;
pub type ContextServerFactory =
Arc<dyn Fn(Entity<Project>, &AsyncApp) -> Task<Result<ServerCommand>> + Send + Sync + 'static>;
struct GlobalContextServerFactoryRegistry(Entity<ContextServerFactoryRegistry>);
impl Global for GlobalContextServerFactoryRegistry {}
#[derive(Default)]
pub struct ContextServerFactoryRegistry {
context_servers: HashMap<Arc<str>, ContextServerFactory>,
pub trait ContextServerDescriptor {
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>>;
fn configuration(
&self,
project: Entity<Project>,
cx: &AsyncApp,
) -> Task<Result<Option<ContextServerConfiguration>>>;
}
impl ContextServerFactoryRegistry {
/// Returns the global [`ContextServerFactoryRegistry`].
struct GlobalContextServerDescriptorRegistry(Entity<ContextServerDescriptorRegistry>);
impl Global for GlobalContextServerDescriptorRegistry {}
#[derive(Default)]
pub struct ContextServerDescriptorRegistry {
context_servers: HashMap<Arc<str>, Arc<dyn ContextServerDescriptor>>,
}
impl ContextServerDescriptorRegistry {
/// Returns the global [`ContextServerDescriptorRegistry`].
pub fn global(cx: &App) -> Entity<Self> {
GlobalContextServerFactoryRegistry::global(cx).0.clone()
GlobalContextServerDescriptorRegistry::global(cx).0.clone()
}
/// Returns the global [`ContextServerFactoryRegistry`].
/// Returns the global [`ContextServerDescriptorRegistry`].
///
/// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist.
/// Inserts a default [`ContextServerDescriptorRegistry`] if one does not yet exist.
pub fn default_global(cx: &mut App) -> Entity<Self> {
if !cx.has_global::<GlobalContextServerFactoryRegistry>() {
if !cx.has_global::<GlobalContextServerDescriptorRegistry>() {
let registry = cx.new(|_| Self::new());
cx.set_global(GlobalContextServerFactoryRegistry(registry));
cx.set_global(GlobalContextServerDescriptorRegistry(registry));
}
cx.global::<GlobalContextServerFactoryRegistry>().0.clone()
cx.global::<GlobalContextServerDescriptorRegistry>()
.0
.clone()
}
pub fn new() -> Self {
@ -42,20 +51,28 @@ impl ContextServerFactoryRegistry {
}
}
pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> {
pub fn context_server_descriptors(&self) -> Vec<(Arc<str>, Arc<dyn ContextServerDescriptor>)> {
self.context_servers
.iter()
.map(|(id, factory)| (id.clone(), factory.clone()))
.collect()
}
/// Registers the provided [`ContextServerFactory`].
pub fn register_server_factory(&mut self, id: Arc<str>, factory: ContextServerFactory) {
self.context_servers.insert(id, factory);
pub fn context_server_descriptor(&self, id: &str) -> Option<Arc<dyn ContextServerDescriptor>> {
self.context_servers.get(id).cloned()
}
/// Unregisters the [`ContextServerFactory`] for the server with the given ID.
pub fn unregister_server_factory_by_id(&mut self, server_id: &str) {
/// Registers the provided [`ContextServerDescriptor`].
pub fn register_context_server_descriptor(
&mut self,
id: Arc<str>,
descriptor: Arc<dyn ContextServerDescriptor>,
) {
self.context_servers.insert(id, descriptor);
}
/// Unregisters the [`ContextServerDescriptor`] for the server with the given ID.
pub fn unregister_context_server_descriptor_by_id(&mut self, server_id: &str) {
self.context_servers.remove(server_id);
}
}

View file

@ -42,6 +42,30 @@ impl RequestType {
}
}
impl TryFrom<&str> for RequestType {
type Error = ();
fn try_from(s: &str) -> Result<Self, Self::Error> {
match s {
"initialize" => Ok(RequestType::Initialize),
"tools/call" => Ok(RequestType::CallTool),
"resources/unsubscribe" => Ok(RequestType::ResourcesUnsubscribe),
"resources/subscribe" => Ok(RequestType::ResourcesSubscribe),
"resources/read" => Ok(RequestType::ResourcesRead),
"resources/list" => Ok(RequestType::ResourcesList),
"logging/setLevel" => Ok(RequestType::LoggingSetLevel),
"prompts/get" => Ok(RequestType::PromptsGet),
"prompts/list" => Ok(RequestType::PromptsList),
"completion/complete" => Ok(RequestType::CompletionComplete),
"ping" => Ok(RequestType::Ping),
"tools/list" => Ok(RequestType::ListTools),
"resources/templates/list" => Ok(RequestType::ListResourceTemplates),
"roots/list" => Ok(RequestType::ListRoots),
_ => Err(()),
}
}
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ProtocolVersion(pub String);
@ -154,7 +178,7 @@ pub struct CompletionArgument {
pub value: String,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResponse {
pub protocol_version: ProtocolVersion,
@ -343,7 +367,7 @@ pub struct ClientCapabilities {
pub roots: Option<RootsCapabilities>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]