diff --git a/crates/assistant/src/slash_command/context_server_command.rs b/crates/assistant/src/slash_command/context_server_command.rs index 672af37115..692b4f6ea7 100644 --- a/crates/assistant/src/slash_command/context_server_command.rs +++ b/crates/assistant/src/slash_command/context_server_command.rs @@ -152,7 +152,7 @@ impl SlashCommand for ContextServerSlashCommand { if result .messages .iter() - .any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User)) + .any(|msg| !matches!(msg.role, context_servers::types::Role::User)) { return Err(anyhow!( "Prompt contains non-user roles, which is not supported" @@ -164,7 +164,7 @@ impl SlashCommand for ContextServerSlashCommand { .messages .into_iter() .filter_map(|msg| match msg.content { - context_servers::types::SamplingContent::Text { text } => Some(text), + context_servers::types::MessageContent::Text { text } => Some(text), _ => None, }) .collect::>() diff --git a/crates/assistant/src/tools/context_server_tool.rs b/crates/assistant/src/tools/context_server_tool.rs index aa742bd9eb..8015d94df9 100644 --- a/crates/assistant/src/tools/context_server_tool.rs +++ b/crates/assistant/src/tools/context_server_tool.rs @@ -74,11 +74,21 @@ impl Tool for ContextServerTool { ); let response = protocol.run_tool(tool_name, arguments).await?; - let tool_result = match response.tool_result { - serde_json::Value::String(s) => s, - _ => serde_json::to_string(&response.tool_result)?, - }; - Ok(tool_result) + let mut result = String::new(); + for content in response.content { + match content { + types::ToolResponseContent::Text { text } => { + result.push_str(&text); + } + types::ToolResponseContent::Image { .. } => { + log::warn!("Ignoring image content from tool response"); + } + types::ToolResponseContent::Resource { .. } => { + log::warn!("Ignoring resource content from tool response"); + } + } + } + Ok(result) } }) } else { diff --git a/crates/context_servers/src/client.rs b/crates/context_servers/src/client.rs index 34878fc421..8202e950d6 100644 --- a/crates/context_servers/src/client.rs +++ b/crates/context_servers/src/client.rs @@ -25,6 +25,13 @@ use util::TryFutureExt; const JSON_RPC_VERSION: &str = "2.0"; const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); +// Standard JSON-RPC error codes +pub const PARSE_ERROR: i32 = -32700; +pub const INVALID_REQUEST: i32 = -32600; +pub const METHOD_NOT_FOUND: i32 = -32601; +pub const INVALID_PARAMS: i32 = -32602; +pub const INTERNAL_ERROR: i32 = -32603; + type ResponseHandler = Box)>; type NotificationHandler = Box; diff --git a/crates/context_servers/src/protocol.rs b/crates/context_servers/src/protocol.rs index 12fe18ecf5..91fa9289cc 100644 --- a/crates/context_servers/src/protocol.rs +++ b/crates/context_servers/src/protocol.rs @@ -11,8 +11,6 @@ use collections::HashMap; use crate::client::Client; use crate::types; -const PROTOCOL_VERSION: &str = "2024-10-07"; - pub struct ModelContextProtocol { inner: Client, } @@ -23,10 +21,9 @@ impl ModelContextProtocol { } fn supported_protocols() -> Vec { - vec![ - types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()), - types::ProtocolVersion::VersionNumber(1), - ] + vec![types::ProtocolVersion( + types::LATEST_PROTOCOL_VERSION.to_string(), + )] } pub async fn initialize( @@ -34,11 +31,13 @@ impl ModelContextProtocol { client_info: types::Implementation, ) -> Result { let params = types::InitializeParams { - protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()), + protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()), capabilities: types::ClientCapabilities { experimental: None, sampling: None, + roots: None, }, + meta: None, client_info, }; @@ -148,6 +147,7 @@ impl InitializedContextServerProtocol { let params = types::PromptsGetParams { name: prompt.as_ref().to_string(), arguments: Some(arguments), + meta: None, }; let response: types::PromptsGetResponse = self @@ -170,6 +170,7 @@ impl InitializedContextServerProtocol { name: argument.into(), value: value.into(), }, + meta: None, }; let result: types::CompletionCompleteResponse = self .inner @@ -210,6 +211,7 @@ impl InitializedContextServerProtocol { let params = types::CallToolParams { name: tool.as_ref().to_string(), arguments, + meta: None, }; let response: types::CallToolResponse = self diff --git a/crates/context_servers/src/types.rs b/crates/context_servers/src/types.rs index b6d8a958bb..851ebbf08b 100644 --- a/crates/context_servers/src/types.rs +++ b/crates/context_servers/src/types.rs @@ -2,8 +2,8 @@ use collections::HashMap; use serde::{Deserialize, Serialize}; use url::Url; -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] +pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05"; + pub enum RequestType { Initialize, CallTool, @@ -18,6 +18,7 @@ pub enum RequestType { Ping, ListTools, ListResourceTemplates, + ListRoots, } impl RequestType { @@ -36,16 +37,14 @@ impl RequestType { RequestType::Ping => "ping", RequestType::ListTools => "tools/list", RequestType::ListResourceTemplates => "resources/templates/list", + RequestType::ListRoots => "roots/list", } } } #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(untagged)] -pub enum ProtocolVersion { - VersionString(String), - VersionNumber(u32), -} +#[serde(transparent)] +pub struct ProtocolVersion(pub String); #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] @@ -53,6 +52,8 @@ pub struct InitializeParams { pub protocol_version: ProtocolVersion, pub capabilities: ClientCapabilities, pub client_info: Implementation, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Serialize)] @@ -61,30 +62,40 @@ pub struct CallToolParams { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] pub arguments: Option>, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesUnsubscribeParams { pub uri: Url, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesSubscribeParams { pub uri: Url, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesReadParams { pub uri: Url, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct LoggingSetLevelParams { pub level: LoggingLevel, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Serialize)] @@ -93,6 +104,8 @@ pub struct PromptsGetParams { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] pub arguments: Option>, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Serialize)] @@ -100,6 +113,8 @@ pub struct PromptsGetParams { pub struct CompletionCompleteParams { pub r#ref: CompletionReference, pub argument: CompletionArgument, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Serialize)] @@ -145,12 +160,16 @@ pub struct InitializeResponse { pub protocol_version: ProtocolVersion, pub capabilities: ServerCapabilities, pub server_info: Implementation, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesReadResponse { - pub contents: Vec, + pub contents: Vec, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Deserialize)] @@ -159,29 +178,39 @@ pub struct ResourcesListResponse { pub resources: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub next_cursor: Option, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, +} +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SamplingMessage { + pub role: Role, + pub content: MessageContent, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -pub struct SamplingMessage { - pub role: SamplingRole, - pub content: SamplingContent, +pub struct PromptMessage { + pub role: Role, + pub content: MessageContent, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] -pub enum SamplingRole { +pub enum Role { User, Assistant, } #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] -pub enum SamplingContent { +pub enum MessageContent { #[serde(rename = "text")] Text { text: String }, #[serde(rename = "image")] Image { data: String, mime_type: String }, + #[serde(rename = "resource")] + Resource { resource: ResourceContents }, } #[derive(Debug, Deserialize)] @@ -189,7 +218,9 @@ pub enum SamplingContent { pub struct PromptsGetResponse { #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, - pub messages: Vec, + pub messages: Vec, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Deserialize)] @@ -198,12 +229,16 @@ pub struct PromptsListResponse { pub prompts: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub next_cursor: Option, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionCompleteResponse { pub completion: CompletionResult, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Deserialize)] @@ -214,6 +249,8 @@ pub struct CompletionResult { pub total: Option, #[serde(skip_serializing_if = "Option::is_none")] pub has_more: Option, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } #[derive(Debug, Deserialize, Serialize)] @@ -243,6 +280,8 @@ pub struct ClientCapabilities { pub experimental: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub sampling: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub roots: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -283,6 +322,13 @@ pub struct ToolsCapabilities { pub list_changed: Option, } +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RootsCapabilities { + #[serde(skip_serializing_if = "Option::is_none")] + pub list_changed: Option, +} + #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Tool { @@ -312,14 +358,28 @@ pub struct Resource { #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -pub struct ResourceContent { +pub struct ResourceContents { pub uri: Url, #[serde(skip_serializing_if = "Option::is_none")] pub mime_type: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TextResourceContents { + pub uri: Url, #[serde(skip_serializing_if = "Option::is_none")] - pub text: Option, + pub mime_type: Option, + pub text: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct BlobResourceContents { + pub uri: Url, #[serde(skip_serializing_if = "Option::is_none")] - pub blob: Option, + pub mime_type: Option, + pub blob: String, } #[derive(Debug, Serialize, Deserialize)] @@ -338,8 +398,32 @@ pub struct ResourceTemplate { pub enum LoggingLevel { Debug, Info, + Notice, Warning, Error, + Critical, + Alert, + Emergency, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelPreferences { + #[serde(skip_serializing_if = "Option::is_none")] + pub hints: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub cost_priority: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub speed_priority: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub intelligence_priority: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelHint { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, } #[derive(Debug, Serialize)] @@ -352,6 +436,7 @@ pub enum NotificationType { ResourcesListChanged, ToolsListChanged, PromptsListChanged, + RootsListChanged, } impl NotificationType { @@ -364,6 +449,7 @@ impl NotificationType { NotificationType::ResourcesListChanged => "notifications/resources/list_changed", NotificationType::ToolsListChanged => "notifications/tools/list_changed", NotificationType::PromptsListChanged => "notifications/prompts/list_changed", + NotificationType::RootsListChanged => "notifications/roots/list_changed", } } } @@ -373,6 +459,14 @@ impl NotificationType { pub enum ClientNotification { Initialized, Progress(ProgressParams), + RootsListChanged, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ProgressToken { + String(String), + Number(f64), } #[derive(Debug, Serialize)] @@ -382,10 +476,10 @@ pub struct ProgressParams { pub progress: f64, #[serde(skip_serializing_if = "Option::is_none")] pub total: Option, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, } -pub type ProgressToken = String; - pub enum CompletionTotal { Exact(u32), HasMore, @@ -410,7 +504,22 @@ pub struct Completion { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CallToolResponse { - pub tool_result: serde_json::Value, + pub content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub is_error: Option, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ToolResponseContent { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image")] + Image { data: String, mime_type: String }, + #[serde(rename = "resource")] + Resource { resource: ResourceContents }, } #[derive(Debug, Deserialize)] @@ -419,4 +528,22 @@ pub struct ListToolsResponse { pub tools: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub next_cursor: Option, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListRootsResponse { + pub roots: Vec, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Root { + pub uri: Url, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, }