context_server: Make notifications type safe (#32396)
Follow up to #32254 Release Notes: - N/A
This commit is contained in:
parent
3853e83da7
commit
6801b9137f
7 changed files with 67 additions and 43 deletions
|
@ -105,7 +105,7 @@ impl Tool for ContextServerTool {
|
||||||
arguments
|
arguments
|
||||||
);
|
);
|
||||||
let response = protocol
|
let response = protocol
|
||||||
.request::<context_server::types::request::CallTool>(
|
.request::<context_server::types::requests::CallTool>(
|
||||||
context_server::types::CallToolParams {
|
context_server::types::CallToolParams {
|
||||||
name: tool_name,
|
name: tool_name,
|
||||||
arguments,
|
arguments,
|
||||||
|
|
|
@ -562,7 +562,7 @@ impl ThreadStore {
|
||||||
|
|
||||||
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
|
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
|
||||||
if let Some(response) = protocol
|
if let Some(response) = protocol
|
||||||
.request::<context_server::types::request::ListTools>(())
|
.request::<context_server::types::requests::ListTools>(())
|
||||||
.await
|
.await
|
||||||
.log_err()
|
.log_err()
|
||||||
{
|
{
|
||||||
|
|
|
@ -869,7 +869,7 @@ impl ContextStore {
|
||||||
|
|
||||||
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
|
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
|
||||||
if let Some(response) = protocol
|
if let Some(response) = protocol
|
||||||
.request::<context_server::types::request::PromptsList>(())
|
.request::<context_server::types::requests::PromptsList>(())
|
||||||
.await
|
.await
|
||||||
.log_err()
|
.log_err()
|
||||||
{
|
{
|
||||||
|
|
|
@ -87,7 +87,7 @@ impl SlashCommand for ContextServerSlashCommand {
|
||||||
let protocol = server.client().context("Context server not initialized")?;
|
let protocol = server.client().context("Context server not initialized")?;
|
||||||
|
|
||||||
let response = protocol
|
let response = protocol
|
||||||
.request::<context_server::types::request::CompletionComplete>(
|
.request::<context_server::types::requests::CompletionComplete>(
|
||||||
context_server::types::CompletionCompleteParams {
|
context_server::types::CompletionCompleteParams {
|
||||||
reference: context_server::types::CompletionReference::Prompt(
|
reference: context_server::types::CompletionReference::Prompt(
|
||||||
context_server::types::PromptReference {
|
context_server::types::PromptReference {
|
||||||
|
@ -145,7 +145,7 @@ impl SlashCommand for ContextServerSlashCommand {
|
||||||
cx.foreground_executor().spawn(async move {
|
cx.foreground_executor().spawn(async move {
|
||||||
let protocol = server.client().context("Context server not initialized")?;
|
let protocol = server.client().context("Context server not initialized")?;
|
||||||
let response = protocol
|
let response = protocol
|
||||||
.request::<context_server::types::request::PromptsGet>(
|
.request::<context_server::types::requests::PromptsGet>(
|
||||||
context_server::types::PromptsGetParams {
|
context_server::types::PromptsGetParams {
|
||||||
name: prompt_name.clone(),
|
name: prompt_name.clone(),
|
||||||
arguments: Some(prompt_args),
|
arguments: Some(prompt_args),
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
|
||||||
use crate::client::Client;
|
use crate::client::Client;
|
||||||
use crate::types::{self, Request};
|
use crate::types::{self, Notification, Request};
|
||||||
|
|
||||||
pub struct ModelContextProtocol {
|
pub struct ModelContextProtocol {
|
||||||
inner: Client,
|
inner: Client,
|
||||||
|
@ -43,7 +43,7 @@ impl ModelContextProtocol {
|
||||||
|
|
||||||
let response: types::InitializeResponse = self
|
let response: types::InitializeResponse = self
|
||||||
.inner
|
.inner
|
||||||
.request(types::request::Initialize::METHOD, params)
|
.request(types::requests::Initialize::METHOD, params)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
anyhow::ensure!(
|
anyhow::ensure!(
|
||||||
|
@ -54,16 +54,13 @@ impl ModelContextProtocol {
|
||||||
|
|
||||||
log::trace!("mcp server info {:?}", response.server_info);
|
log::trace!("mcp server info {:?}", response.server_info);
|
||||||
|
|
||||||
self.inner.notify(
|
|
||||||
types::NotificationType::Initialized.as_str(),
|
|
||||||
serde_json::json!({}),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let initialized_protocol = InitializedContextServerProtocol {
|
let initialized_protocol = InitializedContextServerProtocol {
|
||||||
inner: self.inner,
|
inner: self.inner,
|
||||||
initialize: response,
|
initialize: response,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
initialized_protocol.notify::<types::notifications::Initialized>(())?;
|
||||||
|
|
||||||
Ok(initialized_protocol)
|
Ok(initialized_protocol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -97,4 +94,8 @@ impl InitializedContextServerProtocol {
|
||||||
pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
|
pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
|
||||||
self.inner.request(T::METHOD, params).await
|
self.inner.request(T::METHOD, params).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
|
||||||
|
self.inner.notify(T::METHOD, params)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ pub fn create_fake_transport(
|
||||||
executor: BackgroundExecutor,
|
executor: BackgroundExecutor,
|
||||||
) -> FakeTransport {
|
) -> FakeTransport {
|
||||||
let name = name.into();
|
let name = name.into();
|
||||||
FakeTransport::new(executor).on_request::<crate::types::request::Initialize>(move |_params| {
|
FakeTransport::new(executor).on_request::<crate::types::requests::Initialize>(move |_params| {
|
||||||
create_initialize_response(name.clone())
|
create_initialize_response(name.clone())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ use url::Url;
|
||||||
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
|
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
|
||||||
pub const VERSION_2024_11_05: &str = "2024-11-05";
|
pub const VERSION_2024_11_05: &str = "2024-11-05";
|
||||||
|
|
||||||
pub mod request {
|
pub mod requests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
macro_rules! request {
|
macro_rules! request {
|
||||||
|
@ -83,6 +83,57 @@ pub trait Request {
|
||||||
const METHOD: &'static str;
|
const METHOD: &'static str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub mod notifications {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
macro_rules! notification {
|
||||||
|
($method:expr, $name:ident, $params:ty) => {
|
||||||
|
pub struct $name;
|
||||||
|
|
||||||
|
impl Notification for $name {
|
||||||
|
type Params = $params;
|
||||||
|
const METHOD: &'static str = $method;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
notification!("notifications/initialized", Initialized, ());
|
||||||
|
notification!("notifications/progress", Progress, ProgressParams);
|
||||||
|
notification!("notifications/message", Message, MessageParams);
|
||||||
|
notification!(
|
||||||
|
"notifications/resources/updated",
|
||||||
|
ResourcesUpdated,
|
||||||
|
ResourcesUpdatedParams
|
||||||
|
);
|
||||||
|
notification!(
|
||||||
|
"notifications/resources/list_changed",
|
||||||
|
ResourcesListChanged,
|
||||||
|
()
|
||||||
|
);
|
||||||
|
notification!("notifications/tools/list_changed", ToolsListChanged, ());
|
||||||
|
notification!("notifications/prompts/list_changed", PromptsListChanged, ());
|
||||||
|
notification!("notifications/roots/list_changed", RootsListChanged, ());
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Notification {
|
||||||
|
type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
|
||||||
|
const METHOD: &'static str;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct MessageParams {
|
||||||
|
pub level: LoggingLevel,
|
||||||
|
pub logger: Option<String>,
|
||||||
|
pub data: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct ResourcesUpdatedParams {
|
||||||
|
pub uri: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
#[serde(transparent)]
|
#[serde(transparent)]
|
||||||
pub struct ProtocolVersion(pub String);
|
pub struct ProtocolVersion(pub String);
|
||||||
|
@ -560,34 +611,6 @@ pub struct ModelHint {
|
||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "camelCase")]
|
|
||||||
pub enum NotificationType {
|
|
||||||
Initialized,
|
|
||||||
Progress,
|
|
||||||
Message,
|
|
||||||
ResourcesUpdated,
|
|
||||||
ResourcesListChanged,
|
|
||||||
ToolsListChanged,
|
|
||||||
PromptsListChanged,
|
|
||||||
RootsListChanged,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl NotificationType {
|
|
||||||
pub fn as_str(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
NotificationType::Initialized => "notifications/initialized",
|
|
||||||
NotificationType::Progress => "notifications/progress",
|
|
||||||
NotificationType::Message => "notifications/message",
|
|
||||||
NotificationType::ResourcesUpdated => "notifications/resources/updated",
|
|
||||||
NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
|
|
||||||
NotificationType::ToolsListChanged => "notifications/tools/list_changed",
|
|
||||||
NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
|
|
||||||
NotificationType::RootsListChanged => "notifications/roots/list_changed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum ClientNotification {
|
pub enum ClientNotification {
|
||||||
|
@ -608,7 +631,7 @@ pub enum ProgressToken {
|
||||||
Number(f64),
|
Number(f64),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ProgressParams {
|
pub struct ProgressParams {
|
||||||
pub progress_token: ProgressToken,
|
pub progress_token: ProgressToken,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue