context_servers: Upgrade protocol to version 2024-11-05 (#20615)

This updates context servers to the most recent version

Release Notes:

- N/A

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
David Soria Parra 2024-11-14 18:03:30 +00:00 committed by GitHub
parent b5ce8e7aa5
commit 690a725667
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 180 additions and 34 deletions

View file

@ -152,7 +152,7 @@ impl SlashCommand for ContextServerSlashCommand {
if result if result
.messages .messages
.iter() .iter()
.any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User)) .any(|msg| !matches!(msg.role, context_servers::types::Role::User))
{ {
return Err(anyhow!( return Err(anyhow!(
"Prompt contains non-user roles, which is not supported" "Prompt contains non-user roles, which is not supported"
@ -164,7 +164,7 @@ impl SlashCommand for ContextServerSlashCommand {
.messages .messages
.into_iter() .into_iter()
.filter_map(|msg| match msg.content { .filter_map(|msg| match msg.content {
context_servers::types::SamplingContent::Text { text } => Some(text), context_servers::types::MessageContent::Text { text } => Some(text),
_ => None, _ => None,
}) })
.collect::<Vec<String>>() .collect::<Vec<String>>()

View file

@ -74,11 +74,21 @@ impl Tool for ContextServerTool {
); );
let response = protocol.run_tool(tool_name, arguments).await?; let response = protocol.run_tool(tool_name, arguments).await?;
let tool_result = match response.tool_result { let mut result = String::new();
serde_json::Value::String(s) => s, for content in response.content {
_ => serde_json::to_string(&response.tool_result)?, match content {
}; types::ToolResponseContent::Text { text } => {
Ok(tool_result) result.push_str(&text);
}
types::ToolResponseContent::Image { .. } => {
log::warn!("Ignoring image content from tool response");
}
types::ToolResponseContent::Resource { .. } => {
log::warn!("Ignoring resource content from tool response");
}
}
}
Ok(result)
} }
}) })
} else { } else {

View file

@ -25,6 +25,13 @@ use util::TryFutureExt;
const JSON_RPC_VERSION: &str = "2.0"; const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
// Standard JSON-RPC error codes
pub const PARSE_ERROR: i32 = -32700;
pub const INVALID_REQUEST: i32 = -32600;
pub const METHOD_NOT_FOUND: i32 = -32601;
pub const INVALID_PARAMS: i32 = -32602;
pub const INTERNAL_ERROR: i32 = -32603;
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>; type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncAppContext)>; type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncAppContext)>;

View file

@ -11,8 +11,6 @@ use collections::HashMap;
use crate::client::Client; use crate::client::Client;
use crate::types; use crate::types;
const PROTOCOL_VERSION: &str = "2024-10-07";
pub struct ModelContextProtocol { pub struct ModelContextProtocol {
inner: Client, inner: Client,
} }
@ -23,10 +21,9 @@ impl ModelContextProtocol {
} }
fn supported_protocols() -> Vec<types::ProtocolVersion> { fn supported_protocols() -> Vec<types::ProtocolVersion> {
vec![ vec![types::ProtocolVersion(
types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()), types::LATEST_PROTOCOL_VERSION.to_string(),
types::ProtocolVersion::VersionNumber(1), )]
]
} }
pub async fn initialize( pub async fn initialize(
@ -34,11 +31,13 @@ impl ModelContextProtocol {
client_info: types::Implementation, client_info: types::Implementation,
) -> Result<InitializedContextServerProtocol> { ) -> Result<InitializedContextServerProtocol> {
let params = types::InitializeParams { let params = types::InitializeParams {
protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()), protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
capabilities: types::ClientCapabilities { capabilities: types::ClientCapabilities {
experimental: None, experimental: None,
sampling: None, sampling: None,
roots: None,
}, },
meta: None,
client_info, client_info,
}; };
@ -148,6 +147,7 @@ impl InitializedContextServerProtocol {
let params = types::PromptsGetParams { let params = types::PromptsGetParams {
name: prompt.as_ref().to_string(), name: prompt.as_ref().to_string(),
arguments: Some(arguments), arguments: Some(arguments),
meta: None,
}; };
let response: types::PromptsGetResponse = self let response: types::PromptsGetResponse = self
@ -170,6 +170,7 @@ impl InitializedContextServerProtocol {
name: argument.into(), name: argument.into(),
value: value.into(), value: value.into(),
}, },
meta: None,
}; };
let result: types::CompletionCompleteResponse = self let result: types::CompletionCompleteResponse = self
.inner .inner
@ -210,6 +211,7 @@ impl InitializedContextServerProtocol {
let params = types::CallToolParams { let params = types::CallToolParams {
name: tool.as_ref().to_string(), name: tool.as_ref().to_string(),
arguments, arguments,
meta: None,
}; };
let response: types::CallToolResponse = self let response: types::CallToolResponse = self

View file

@ -2,8 +2,8 @@ use collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use url::Url; use url::Url;
#[derive(Debug, Serialize)] pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";
#[serde(rename_all = "camelCase")]
pub enum RequestType { pub enum RequestType {
Initialize, Initialize,
CallTool, CallTool,
@ -18,6 +18,7 @@ pub enum RequestType {
Ping, Ping,
ListTools, ListTools,
ListResourceTemplates, ListResourceTemplates,
ListRoots,
} }
impl RequestType { impl RequestType {
@ -36,16 +37,14 @@ impl RequestType {
RequestType::Ping => "ping", RequestType::Ping => "ping",
RequestType::ListTools => "tools/list", RequestType::ListTools => "tools/list",
RequestType::ListResourceTemplates => "resources/templates/list", RequestType::ListResourceTemplates => "resources/templates/list",
RequestType::ListRoots => "roots/list",
} }
} }
} }
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)] #[serde(transparent)]
pub enum ProtocolVersion { pub struct ProtocolVersion(pub String);
VersionString(String),
VersionNumber(u32),
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
@ -53,6 +52,8 @@ pub struct InitializeParams {
pub protocol_version: ProtocolVersion, pub protocol_version: ProtocolVersion,
pub capabilities: ClientCapabilities, pub capabilities: ClientCapabilities,
pub client_info: Implementation, pub client_info: Implementation,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -61,30 +62,40 @@ pub struct CallToolParams {
pub name: String, pub name: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<HashMap<String, serde_json::Value>>, pub arguments: Option<HashMap<String, serde_json::Value>>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ResourcesUnsubscribeParams { pub struct ResourcesUnsubscribeParams {
pub uri: Url, pub uri: Url,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ResourcesSubscribeParams { pub struct ResourcesSubscribeParams {
pub uri: Url, pub uri: Url,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ResourcesReadParams { pub struct ResourcesReadParams {
pub uri: Url, pub uri: Url,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct LoggingSetLevelParams { pub struct LoggingSetLevelParams {
pub level: LoggingLevel, pub level: LoggingLevel,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -93,6 +104,8 @@ pub struct PromptsGetParams {
pub name: String, pub name: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<HashMap<String, String>>, pub arguments: Option<HashMap<String, String>>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -100,6 +113,8 @@ pub struct PromptsGetParams {
pub struct CompletionCompleteParams { pub struct CompletionCompleteParams {
pub r#ref: CompletionReference, pub r#ref: CompletionReference,
pub argument: CompletionArgument, pub argument: CompletionArgument,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -145,12 +160,16 @@ pub struct InitializeResponse {
pub protocol_version: ProtocolVersion, pub protocol_version: ProtocolVersion,
pub capabilities: ServerCapabilities, pub capabilities: ServerCapabilities,
pub server_info: Implementation, pub server_info: Implementation,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ResourcesReadResponse { pub struct ResourcesReadResponse {
pub contents: Vec<ResourceContent>, pub contents: Vec<ResourceContents>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -159,29 +178,39 @@ pub struct ResourcesListResponse {
pub resources: Vec<Resource>, pub resources: Vec<Resource>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>, pub next_cursor: Option<String>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingMessage {
pub role: Role,
pub content: MessageContent,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct SamplingMessage { pub struct PromptMessage {
pub role: SamplingRole, pub role: Role,
pub content: SamplingContent, pub content: MessageContent,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum SamplingRole { pub enum Role {
User, User,
Assistant, Assistant,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum SamplingContent { pub enum MessageContent {
#[serde(rename = "text")] #[serde(rename = "text")]
Text { text: String }, Text { text: String },
#[serde(rename = "image")] #[serde(rename = "image")]
Image { data: String, mime_type: String }, Image { data: String, mime_type: String },
#[serde(rename = "resource")]
Resource { resource: ResourceContents },
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -189,7 +218,9 @@ pub enum SamplingContent {
pub struct PromptsGetResponse { pub struct PromptsGetResponse {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>, pub description: Option<String>,
pub messages: Vec<SamplingMessage>, pub messages: Vec<PromptMessage>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -198,12 +229,16 @@ pub struct PromptsListResponse {
pub prompts: Vec<Prompt>, pub prompts: Vec<Prompt>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>, pub next_cursor: Option<String>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct CompletionCompleteResponse { pub struct CompletionCompleteResponse {
pub completion: CompletionResult, pub completion: CompletionResult,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -214,6 +249,8 @@ pub struct CompletionResult {
pub total: Option<u32>, pub total: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub has_more: Option<bool>, pub has_more: Option<bool>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
@ -243,6 +280,8 @@ pub struct ClientCapabilities {
pub experimental: Option<HashMap<String, serde_json::Value>>, pub experimental: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub sampling: Option<serde_json::Value>, pub sampling: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub roots: Option<RootsCapabilities>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@ -283,6 +322,13 @@ pub struct ToolsCapabilities {
pub list_changed: Option<bool>, pub list_changed: Option<bool>,
} }
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RootsCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct Tool { pub struct Tool {
@ -312,14 +358,28 @@ pub struct Resource {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ResourceContent { pub struct ResourceContents {
pub uri: Url, pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>, pub mime_type: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextResourceContents {
pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>, pub mime_type: Option<String>,
pub text: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BlobResourceContents {
pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub blob: Option<String>, pub mime_type: Option<String>,
pub blob: String,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@ -338,8 +398,32 @@ pub struct ResourceTemplate {
pub enum LoggingLevel { pub enum LoggingLevel {
Debug, Debug,
Info, Info,
Notice,
Warning, Warning,
Error, Error,
Critical,
Alert,
Emergency,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelPreferences {
#[serde(skip_serializing_if = "Option::is_none")]
pub hints: Option<Vec<ModelHint>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cost_priority: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub speed_priority: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub intelligence_priority: Option<f64>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelHint {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -352,6 +436,7 @@ pub enum NotificationType {
ResourcesListChanged, ResourcesListChanged,
ToolsListChanged, ToolsListChanged,
PromptsListChanged, PromptsListChanged,
RootsListChanged,
} }
impl NotificationType { impl NotificationType {
@ -364,6 +449,7 @@ impl NotificationType {
NotificationType::ResourcesListChanged => "notifications/resources/list_changed", NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
NotificationType::ToolsListChanged => "notifications/tools/list_changed", NotificationType::ToolsListChanged => "notifications/tools/list_changed",
NotificationType::PromptsListChanged => "notifications/prompts/list_changed", NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
NotificationType::RootsListChanged => "notifications/roots/list_changed",
} }
} }
} }
@ -373,6 +459,14 @@ impl NotificationType {
pub enum ClientNotification { pub enum ClientNotification {
Initialized, Initialized,
Progress(ProgressParams), Progress(ProgressParams),
RootsListChanged,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ProgressToken {
String(String),
Number(f64),
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -382,10 +476,10 @@ pub struct ProgressParams {
pub progress: f64, pub progress: f64,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<f64>, pub total: Option<f64>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
} }
pub type ProgressToken = String;
pub enum CompletionTotal { pub enum CompletionTotal {
Exact(u32), Exact(u32),
HasMore, HasMore,
@ -410,7 +504,22 @@ pub struct Completion {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct CallToolResponse { pub struct CallToolResponse {
pub tool_result: serde_json::Value, pub content: Vec<ToolResponseContent>,
#[serde(skip_serializing_if = "Option::is_none")]
pub is_error: Option<bool>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ToolResponseContent {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image")]
Image { data: String, mime_type: String },
#[serde(rename = "resource")]
Resource { resource: ResourceContents },
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -419,4 +528,22 @@ pub struct ListToolsResponse {
pub tools: Vec<Tool>, pub tools: Vec<Tool>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>, pub next_cursor: Option<String>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListRootsResponse {
pub roots: Vec<Root>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Root {
pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
} }