diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs index 026911128e..17571fca04 100644 --- a/crates/agent/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -105,7 +105,7 @@ impl Tool for ContextServerTool { arguments ); let response = protocol - .request::( + .request::( context_server::types::CallToolParams { name: tool_name, arguments, diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 620279249e..db87bdd3a5 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -562,7 +562,7 @@ impl ThreadStore { if protocol.capable(context_server::protocol::ServerCapability::Tools) { if let Some(response) = protocol - .request::(()) + .request::(()) .await .log_err() { diff --git a/crates/assistant_context_editor/src/context_store.rs b/crates/assistant_context_editor/src/context_store.rs index 128aa87008..30268420dd 100644 --- a/crates/assistant_context_editor/src/context_store.rs +++ b/crates/assistant_context_editor/src/context_store.rs @@ -869,7 +869,7 @@ impl ContextStore { if protocol.capable(context_server::protocol::ServerCapability::Prompts) { if let Some(response) = protocol - .request::(()) + .request::(()) .await .log_err() { diff --git a/crates/assistant_slash_commands/src/context_server_command.rs b/crates/assistant_slash_commands/src/context_server_command.rs index 509076c167..f223d3b184 100644 --- a/crates/assistant_slash_commands/src/context_server_command.rs +++ b/crates/assistant_slash_commands/src/context_server_command.rs @@ -87,7 +87,7 @@ impl SlashCommand for ContextServerSlashCommand { let protocol = server.client().context("Context server not initialized")?; let response = protocol - .request::( + .request::( context_server::types::CompletionCompleteParams { reference: context_server::types::CompletionReference::Prompt( context_server::types::PromptReference { @@ -145,7 +145,7 @@ impl SlashCommand for ContextServerSlashCommand { cx.foreground_executor().spawn(async move { let protocol = server.client().context("Context server not initialized")?; let response = protocol - .request::( + .request::( context_server::types::PromptsGetParams { name: prompt_name.clone(), arguments: Some(prompt_args), diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 8f50cd8fa5..d8bbac60d6 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -8,7 +8,7 @@ use anyhow::Result; use crate::client::Client; -use crate::types::{self, Request}; +use crate::types::{self, Notification, Request}; pub struct ModelContextProtocol { inner: Client, @@ -43,7 +43,7 @@ impl ModelContextProtocol { let response: types::InitializeResponse = self .inner - .request(types::request::Initialize::METHOD, params) + .request(types::requests::Initialize::METHOD, params) .await?; anyhow::ensure!( @@ -54,16 +54,13 @@ impl ModelContextProtocol { log::trace!("mcp server info {:?}", response.server_info); - self.inner.notify( - types::NotificationType::Initialized.as_str(), - serde_json::json!({}), - )?; - let initialized_protocol = InitializedContextServerProtocol { inner: self.inner, initialize: response, }; + initialized_protocol.notify::(())?; + Ok(initialized_protocol) } } @@ -97,4 +94,8 @@ impl InitializedContextServerProtocol { pub async fn request(&self, params: T::Params) -> Result { self.inner.request(T::METHOD, params).await } + + pub fn notify(&self, params: T::Params) -> Result<()> { + self.inner.notify(T::METHOD, params) + } } diff --git a/crates/context_server/src/test.rs b/crates/context_server/src/test.rs index d882a56984..dedf589664 100644 --- a/crates/context_server/src/test.rs +++ b/crates/context_server/src/test.rs @@ -14,7 +14,7 @@ pub fn create_fake_transport( executor: BackgroundExecutor, ) -> FakeTransport { let name = name.into(); - FakeTransport::new(executor).on_request::(move |_params| { + FakeTransport::new(executor).on_request::(move |_params| { create_initialize_response(name.clone()) }) } diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 1ab3225e1e..8e3daf9e22 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -6,7 +6,7 @@ use url::Url; pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26"; pub const VERSION_2024_11_05: &str = "2024-11-05"; -pub mod request { +pub mod requests { use super::*; macro_rules! request { @@ -83,6 +83,57 @@ pub trait Request { const METHOD: &'static str; } +pub mod notifications { + use super::*; + + macro_rules! notification { + ($method:expr, $name:ident, $params:ty) => { + pub struct $name; + + impl Notification for $name { + type Params = $params; + const METHOD: &'static str = $method; + } + }; + } + + notification!("notifications/initialized", Initialized, ()); + notification!("notifications/progress", Progress, ProgressParams); + notification!("notifications/message", Message, MessageParams); + notification!( + "notifications/resources/updated", + ResourcesUpdated, + ResourcesUpdatedParams + ); + notification!( + "notifications/resources/list_changed", + ResourcesListChanged, + () + ); + notification!("notifications/tools/list_changed", ToolsListChanged, ()); + notification!("notifications/prompts/list_changed", PromptsListChanged, ()); + notification!("notifications/roots/list_changed", RootsListChanged, ()); +} + +pub trait Notification { + type Params: DeserializeOwned + Serialize + Send + Sync + 'static; + const METHOD: &'static str; +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct MessageParams { + pub level: LoggingLevel, + pub logger: Option, + pub data: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourcesUpdatedParams { + pub uri: String, +} + #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(transparent)] pub struct ProtocolVersion(pub String); @@ -560,34 +611,6 @@ pub struct ModelHint { pub name: Option, } -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub enum NotificationType { - Initialized, - Progress, - Message, - ResourcesUpdated, - ResourcesListChanged, - ToolsListChanged, - PromptsListChanged, - RootsListChanged, -} - -impl NotificationType { - pub fn as_str(&self) -> &'static str { - match self { - NotificationType::Initialized => "notifications/initialized", - NotificationType::Progress => "notifications/progress", - NotificationType::Message => "notifications/message", - NotificationType::ResourcesUpdated => "notifications/resources/updated", - NotificationType::ResourcesListChanged => "notifications/resources/list_changed", - NotificationType::ToolsListChanged => "notifications/tools/list_changed", - NotificationType::PromptsListChanged => "notifications/prompts/list_changed", - NotificationType::RootsListChanged => "notifications/roots/list_changed", - } - } -} - #[derive(Debug, Serialize)] #[serde(untagged)] pub enum ClientNotification { @@ -608,7 +631,7 @@ pub enum ProgressToken { Number(f64), } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ProgressParams { pub progress_token: ProgressToken,