context_server: Make notifications type safe (#32396)

Follow up to #32254 

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-06-09 17:11:01 +02:00 committed by GitHub
parent 3853e83da7
commit 6801b9137f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 67 additions and 43 deletions

View file

@ -105,7 +105,7 @@ impl Tool for ContextServerTool {
arguments arguments
); );
let response = protocol let response = protocol
.request::<context_server::types::request::CallTool>( .request::<context_server::types::requests::CallTool>(
context_server::types::CallToolParams { context_server::types::CallToolParams {
name: tool_name, name: tool_name,
arguments, arguments,

View file

@ -562,7 +562,7 @@ impl ThreadStore {
if protocol.capable(context_server::protocol::ServerCapability::Tools) { if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(response) = protocol if let Some(response) = protocol
.request::<context_server::types::request::ListTools>(()) .request::<context_server::types::requests::ListTools>(())
.await .await
.log_err() .log_err()
{ {

View file

@ -869,7 +869,7 @@ impl ContextStore {
if protocol.capable(context_server::protocol::ServerCapability::Prompts) { if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
if let Some(response) = protocol if let Some(response) = protocol
.request::<context_server::types::request::PromptsList>(()) .request::<context_server::types::requests::PromptsList>(())
.await .await
.log_err() .log_err()
{ {

View file

@ -87,7 +87,7 @@ impl SlashCommand for ContextServerSlashCommand {
let protocol = server.client().context("Context server not initialized")?; let protocol = server.client().context("Context server not initialized")?;
let response = protocol let response = protocol
.request::<context_server::types::request::CompletionComplete>( .request::<context_server::types::requests::CompletionComplete>(
context_server::types::CompletionCompleteParams { context_server::types::CompletionCompleteParams {
reference: context_server::types::CompletionReference::Prompt( reference: context_server::types::CompletionReference::Prompt(
context_server::types::PromptReference { context_server::types::PromptReference {
@ -145,7 +145,7 @@ impl SlashCommand for ContextServerSlashCommand {
cx.foreground_executor().spawn(async move { cx.foreground_executor().spawn(async move {
let protocol = server.client().context("Context server not initialized")?; let protocol = server.client().context("Context server not initialized")?;
let response = protocol let response = protocol
.request::<context_server::types::request::PromptsGet>( .request::<context_server::types::requests::PromptsGet>(
context_server::types::PromptsGetParams { context_server::types::PromptsGetParams {
name: prompt_name.clone(), name: prompt_name.clone(),
arguments: Some(prompt_args), arguments: Some(prompt_args),

View file

@ -8,7 +8,7 @@
use anyhow::Result; use anyhow::Result;
use crate::client::Client; use crate::client::Client;
use crate::types::{self, Request}; use crate::types::{self, Notification, Request};
pub struct ModelContextProtocol { pub struct ModelContextProtocol {
inner: Client, inner: Client,
@ -43,7 +43,7 @@ impl ModelContextProtocol {
let response: types::InitializeResponse = self let response: types::InitializeResponse = self
.inner .inner
.request(types::request::Initialize::METHOD, params) .request(types::requests::Initialize::METHOD, params)
.await?; .await?;
anyhow::ensure!( anyhow::ensure!(
@ -54,16 +54,13 @@ impl ModelContextProtocol {
log::trace!("mcp server info {:?}", response.server_info); log::trace!("mcp server info {:?}", response.server_info);
self.inner.notify(
types::NotificationType::Initialized.as_str(),
serde_json::json!({}),
)?;
let initialized_protocol = InitializedContextServerProtocol { let initialized_protocol = InitializedContextServerProtocol {
inner: self.inner, inner: self.inner,
initialize: response, initialize: response,
}; };
initialized_protocol.notify::<types::notifications::Initialized>(())?;
Ok(initialized_protocol) Ok(initialized_protocol)
} }
} }
@ -97,4 +94,8 @@ impl InitializedContextServerProtocol {
pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> { pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
self.inner.request(T::METHOD, params).await self.inner.request(T::METHOD, params).await
} }
pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
self.inner.notify(T::METHOD, params)
}
} }

View file

@ -14,7 +14,7 @@ pub fn create_fake_transport(
executor: BackgroundExecutor, executor: BackgroundExecutor,
) -> FakeTransport { ) -> FakeTransport {
let name = name.into(); let name = name.into();
FakeTransport::new(executor).on_request::<crate::types::request::Initialize>(move |_params| { FakeTransport::new(executor).on_request::<crate::types::requests::Initialize>(move |_params| {
create_initialize_response(name.clone()) create_initialize_response(name.clone())
}) })
} }

View file

@ -6,7 +6,7 @@ use url::Url;
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26"; pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
pub const VERSION_2024_11_05: &str = "2024-11-05"; pub const VERSION_2024_11_05: &str = "2024-11-05";
pub mod request { pub mod requests {
use super::*; use super::*;
macro_rules! request { macro_rules! request {
@ -83,6 +83,57 @@ pub trait Request {
const METHOD: &'static str; const METHOD: &'static str;
} }
pub mod notifications {
use super::*;
macro_rules! notification {
($method:expr, $name:ident, $params:ty) => {
pub struct $name;
impl Notification for $name {
type Params = $params;
const METHOD: &'static str = $method;
}
};
}
notification!("notifications/initialized", Initialized, ());
notification!("notifications/progress", Progress, ProgressParams);
notification!("notifications/message", Message, MessageParams);
notification!(
"notifications/resources/updated",
ResourcesUpdated,
ResourcesUpdatedParams
);
notification!(
"notifications/resources/list_changed",
ResourcesListChanged,
()
);
notification!("notifications/tools/list_changed", ToolsListChanged, ());
notification!("notifications/prompts/list_changed", PromptsListChanged, ());
notification!("notifications/roots/list_changed", RootsListChanged, ());
}
pub trait Notification {
type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
const METHOD: &'static str;
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MessageParams {
pub level: LoggingLevel,
pub logger: Option<String>,
pub data: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesUpdatedParams {
pub uri: String,
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)] #[serde(transparent)]
pub struct ProtocolVersion(pub String); pub struct ProtocolVersion(pub String);
@ -560,34 +611,6 @@ pub struct ModelHint {
pub name: Option<String>, pub name: Option<String>,
} }
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum NotificationType {
Initialized,
Progress,
Message,
ResourcesUpdated,
ResourcesListChanged,
ToolsListChanged,
PromptsListChanged,
RootsListChanged,
}
impl NotificationType {
pub fn as_str(&self) -> &'static str {
match self {
NotificationType::Initialized => "notifications/initialized",
NotificationType::Progress => "notifications/progress",
NotificationType::Message => "notifications/message",
NotificationType::ResourcesUpdated => "notifications/resources/updated",
NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
NotificationType::ToolsListChanged => "notifications/tools/list_changed",
NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
NotificationType::RootsListChanged => "notifications/roots/list_changed",
}
}
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum ClientNotification { pub enum ClientNotification {
@ -608,7 +631,7 @@ pub enum ProgressToken {
Number(f64), Number(f64),
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ProgressParams { pub struct ProgressParams {
pub progress_token: ProgressToken, pub progress_token: ProgressToken,