context_server: Make notifications type safe (#32396)
Follow up to #32254 Release Notes: - N/A
This commit is contained in:
parent
3853e83da7
commit
6801b9137f
7 changed files with 67 additions and 43 deletions
|
@ -8,7 +8,7 @@
|
|||
use anyhow::Result;
|
||||
|
||||
use crate::client::Client;
|
||||
use crate::types::{self, Request};
|
||||
use crate::types::{self, Notification, Request};
|
||||
|
||||
pub struct ModelContextProtocol {
|
||||
inner: Client,
|
||||
|
@ -43,7 +43,7 @@ impl ModelContextProtocol {
|
|||
|
||||
let response: types::InitializeResponse = self
|
||||
.inner
|
||||
.request(types::request::Initialize::METHOD, params)
|
||||
.request(types::requests::Initialize::METHOD, params)
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
|
@ -54,16 +54,13 @@ impl ModelContextProtocol {
|
|||
|
||||
log::trace!("mcp server info {:?}", response.server_info);
|
||||
|
||||
self.inner.notify(
|
||||
types::NotificationType::Initialized.as_str(),
|
||||
serde_json::json!({}),
|
||||
)?;
|
||||
|
||||
let initialized_protocol = InitializedContextServerProtocol {
|
||||
inner: self.inner,
|
||||
initialize: response,
|
||||
};
|
||||
|
||||
initialized_protocol.notify::<types::notifications::Initialized>(())?;
|
||||
|
||||
Ok(initialized_protocol)
|
||||
}
|
||||
}
|
||||
|
@ -97,4 +94,8 @@ impl InitializedContextServerProtocol {
|
|||
pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
|
||||
self.inner.request(T::METHOD, params).await
|
||||
}
|
||||
|
||||
pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
|
||||
self.inner.notify(T::METHOD, params)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ pub fn create_fake_transport(
|
|||
executor: BackgroundExecutor,
|
||||
) -> FakeTransport {
|
||||
let name = name.into();
|
||||
FakeTransport::new(executor).on_request::<crate::types::request::Initialize>(move |_params| {
|
||||
FakeTransport::new(executor).on_request::<crate::types::requests::Initialize>(move |_params| {
|
||||
create_initialize_response(name.clone())
|
||||
})
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ use url::Url;
|
|||
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
|
||||
pub const VERSION_2024_11_05: &str = "2024-11-05";
|
||||
|
||||
pub mod request {
|
||||
pub mod requests {
|
||||
use super::*;
|
||||
|
||||
macro_rules! request {
|
||||
|
@ -83,6 +83,57 @@ pub trait Request {
|
|||
const METHOD: &'static str;
|
||||
}
|
||||
|
||||
pub mod notifications {
|
||||
use super::*;
|
||||
|
||||
macro_rules! notification {
|
||||
($method:expr, $name:ident, $params:ty) => {
|
||||
pub struct $name;
|
||||
|
||||
impl Notification for $name {
|
||||
type Params = $params;
|
||||
const METHOD: &'static str = $method;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
notification!("notifications/initialized", Initialized, ());
|
||||
notification!("notifications/progress", Progress, ProgressParams);
|
||||
notification!("notifications/message", Message, MessageParams);
|
||||
notification!(
|
||||
"notifications/resources/updated",
|
||||
ResourcesUpdated,
|
||||
ResourcesUpdatedParams
|
||||
);
|
||||
notification!(
|
||||
"notifications/resources/list_changed",
|
||||
ResourcesListChanged,
|
||||
()
|
||||
);
|
||||
notification!("notifications/tools/list_changed", ToolsListChanged, ());
|
||||
notification!("notifications/prompts/list_changed", PromptsListChanged, ());
|
||||
notification!("notifications/roots/list_changed", RootsListChanged, ());
|
||||
}
|
||||
|
||||
pub trait Notification {
|
||||
type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
|
||||
const METHOD: &'static str;
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct MessageParams {
|
||||
pub level: LoggingLevel,
|
||||
pub logger: Option<String>,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesUpdatedParams {
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct ProtocolVersion(pub String);
|
||||
|
@ -560,34 +611,6 @@ pub struct ModelHint {
|
|||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum NotificationType {
|
||||
Initialized,
|
||||
Progress,
|
||||
Message,
|
||||
ResourcesUpdated,
|
||||
ResourcesListChanged,
|
||||
ToolsListChanged,
|
||||
PromptsListChanged,
|
||||
RootsListChanged,
|
||||
}
|
||||
|
||||
impl NotificationType {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
NotificationType::Initialized => "notifications/initialized",
|
||||
NotificationType::Progress => "notifications/progress",
|
||||
NotificationType::Message => "notifications/message",
|
||||
NotificationType::ResourcesUpdated => "notifications/resources/updated",
|
||||
NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
|
||||
NotificationType::ToolsListChanged => "notifications/tools/list_changed",
|
||||
NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
|
||||
NotificationType::RootsListChanged => "notifications/roots/list_changed",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ClientNotification {
|
||||
|
@ -608,7 +631,7 @@ pub enum ProgressToken {
|
|||
Number(f64),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ProgressParams {
|
||||
pub progress_token: ProgressToken,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue