context_server: Make notifications type safe (#32396)

Follow up to #32254 

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-06-09 17:11:01 +02:00 committed by GitHub
parent 3853e83da7
commit 6801b9137f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 67 additions and 43 deletions

View file

@ -8,7 +8,7 @@
use anyhow::Result;
use crate::client::Client;
use crate::types::{self, Request};
use crate::types::{self, Notification, Request};
pub struct ModelContextProtocol {
inner: Client,
@ -43,7 +43,7 @@ impl ModelContextProtocol {
let response: types::InitializeResponse = self
.inner
.request(types::request::Initialize::METHOD, params)
.request(types::requests::Initialize::METHOD, params)
.await?;
anyhow::ensure!(
@ -54,16 +54,13 @@ impl ModelContextProtocol {
log::trace!("mcp server info {:?}", response.server_info);
self.inner.notify(
types::NotificationType::Initialized.as_str(),
serde_json::json!({}),
)?;
let initialized_protocol = InitializedContextServerProtocol {
inner: self.inner,
initialize: response,
};
initialized_protocol.notify::<types::notifications::Initialized>(())?;
Ok(initialized_protocol)
}
}
@ -97,4 +94,8 @@ impl InitializedContextServerProtocol {
pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
self.inner.request(T::METHOD, params).await
}
pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
self.inner.notify(T::METHOD, params)
}
}

View file

@ -14,7 +14,7 @@ pub fn create_fake_transport(
executor: BackgroundExecutor,
) -> FakeTransport {
let name = name.into();
FakeTransport::new(executor).on_request::<crate::types::request::Initialize>(move |_params| {
FakeTransport::new(executor).on_request::<crate::types::requests::Initialize>(move |_params| {
create_initialize_response(name.clone())
})
}

View file

@ -6,7 +6,7 @@ use url::Url;
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
pub const VERSION_2024_11_05: &str = "2024-11-05";
pub mod request {
pub mod requests {
use super::*;
macro_rules! request {
@ -83,6 +83,57 @@ pub trait Request {
const METHOD: &'static str;
}
pub mod notifications {
use super::*;
macro_rules! notification {
($method:expr, $name:ident, $params:ty) => {
pub struct $name;
impl Notification for $name {
type Params = $params;
const METHOD: &'static str = $method;
}
};
}
notification!("notifications/initialized", Initialized, ());
notification!("notifications/progress", Progress, ProgressParams);
notification!("notifications/message", Message, MessageParams);
notification!(
"notifications/resources/updated",
ResourcesUpdated,
ResourcesUpdatedParams
);
notification!(
"notifications/resources/list_changed",
ResourcesListChanged,
()
);
notification!("notifications/tools/list_changed", ToolsListChanged, ());
notification!("notifications/prompts/list_changed", PromptsListChanged, ());
notification!("notifications/roots/list_changed", RootsListChanged, ());
}
pub trait Notification {
type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
const METHOD: &'static str;
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MessageParams {
pub level: LoggingLevel,
pub logger: Option<String>,
pub data: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesUpdatedParams {
pub uri: String,
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ProtocolVersion(pub String);
@ -560,34 +611,6 @@ pub struct ModelHint {
pub name: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum NotificationType {
Initialized,
Progress,
Message,
ResourcesUpdated,
ResourcesListChanged,
ToolsListChanged,
PromptsListChanged,
RootsListChanged,
}
impl NotificationType {
pub fn as_str(&self) -> &'static str {
match self {
NotificationType::Initialized => "notifications/initialized",
NotificationType::Progress => "notifications/progress",
NotificationType::Message => "notifications/message",
NotificationType::ResourcesUpdated => "notifications/resources/updated",
NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
NotificationType::ToolsListChanged => "notifications/tools/list_changed",
NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
NotificationType::RootsListChanged => "notifications/roots/list_changed",
}
}
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum ClientNotification {
@ -608,7 +631,7 @@ pub enum ProgressToken {
Number(f64),
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ProgressParams {
pub progress_token: ProgressToken,