context_servers: Upgrade protocol to version 2024-11-05 (#20615)
This updates context servers to the most recent version Release Notes: - N/A Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
parent
b5ce8e7aa5
commit
690a725667
5 changed files with 180 additions and 34 deletions
|
@ -152,7 +152,7 @@ impl SlashCommand for ContextServerSlashCommand {
|
|||
if result
|
||||
.messages
|
||||
.iter()
|
||||
.any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User))
|
||||
.any(|msg| !matches!(msg.role, context_servers::types::Role::User))
|
||||
{
|
||||
return Err(anyhow!(
|
||||
"Prompt contains non-user roles, which is not supported"
|
||||
|
@ -164,7 +164,7 @@ impl SlashCommand for ContextServerSlashCommand {
|
|||
.messages
|
||||
.into_iter()
|
||||
.filter_map(|msg| match msg.content {
|
||||
context_servers::types::SamplingContent::Text { text } => Some(text),
|
||||
context_servers::types::MessageContent::Text { text } => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
|
|
|
@ -74,11 +74,21 @@ impl Tool for ContextServerTool {
|
|||
);
|
||||
let response = protocol.run_tool(tool_name, arguments).await?;
|
||||
|
||||
let tool_result = match response.tool_result {
|
||||
serde_json::Value::String(s) => s,
|
||||
_ => serde_json::to_string(&response.tool_result)?,
|
||||
};
|
||||
Ok(tool_result)
|
||||
let mut result = String::new();
|
||||
for content in response.content {
|
||||
match content {
|
||||
types::ToolResponseContent::Text { text } => {
|
||||
result.push_str(&text);
|
||||
}
|
||||
types::ToolResponseContent::Image { .. } => {
|
||||
log::warn!("Ignoring image content from tool response");
|
||||
}
|
||||
types::ToolResponseContent::Resource { .. } => {
|
||||
log::warn!("Ignoring resource content from tool response");
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
|
|
|
@ -25,6 +25,13 @@ use util::TryFutureExt;
|
|||
const JSON_RPC_VERSION: &str = "2.0";
|
||||
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
// Standard JSON-RPC error codes
|
||||
pub const PARSE_ERROR: i32 = -32700;
|
||||
pub const INVALID_REQUEST: i32 = -32600;
|
||||
pub const METHOD_NOT_FOUND: i32 = -32601;
|
||||
pub const INVALID_PARAMS: i32 = -32602;
|
||||
pub const INTERNAL_ERROR: i32 = -32603;
|
||||
|
||||
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
|
||||
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncAppContext)>;
|
||||
|
||||
|
|
|
@ -11,8 +11,6 @@ use collections::HashMap;
|
|||
use crate::client::Client;
|
||||
use crate::types;
|
||||
|
||||
const PROTOCOL_VERSION: &str = "2024-10-07";
|
||||
|
||||
pub struct ModelContextProtocol {
|
||||
inner: Client,
|
||||
}
|
||||
|
@ -23,10 +21,9 @@ impl ModelContextProtocol {
|
|||
}
|
||||
|
||||
fn supported_protocols() -> Vec<types::ProtocolVersion> {
|
||||
vec![
|
||||
types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
|
||||
types::ProtocolVersion::VersionNumber(1),
|
||||
]
|
||||
vec![types::ProtocolVersion(
|
||||
types::LATEST_PROTOCOL_VERSION.to_string(),
|
||||
)]
|
||||
}
|
||||
|
||||
pub async fn initialize(
|
||||
|
@ -34,11 +31,13 @@ impl ModelContextProtocol {
|
|||
client_info: types::Implementation,
|
||||
) -> Result<InitializedContextServerProtocol> {
|
||||
let params = types::InitializeParams {
|
||||
protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
|
||||
protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
|
||||
capabilities: types::ClientCapabilities {
|
||||
experimental: None,
|
||||
sampling: None,
|
||||
roots: None,
|
||||
},
|
||||
meta: None,
|
||||
client_info,
|
||||
};
|
||||
|
||||
|
@ -148,6 +147,7 @@ impl InitializedContextServerProtocol {
|
|||
let params = types::PromptsGetParams {
|
||||
name: prompt.as_ref().to_string(),
|
||||
arguments: Some(arguments),
|
||||
meta: None,
|
||||
};
|
||||
|
||||
let response: types::PromptsGetResponse = self
|
||||
|
@ -170,6 +170,7 @@ impl InitializedContextServerProtocol {
|
|||
name: argument.into(),
|
||||
value: value.into(),
|
||||
},
|
||||
meta: None,
|
||||
};
|
||||
let result: types::CompletionCompleteResponse = self
|
||||
.inner
|
||||
|
@ -210,6 +211,7 @@ impl InitializedContextServerProtocol {
|
|||
let params = types::CallToolParams {
|
||||
name: tool.as_ref().to_string(),
|
||||
arguments,
|
||||
meta: None,
|
||||
};
|
||||
|
||||
let response: types::CallToolResponse = self
|
||||
|
|
|
@ -2,8 +2,8 @@ use collections::HashMap;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";
|
||||
|
||||
pub enum RequestType {
|
||||
Initialize,
|
||||
CallTool,
|
||||
|
@ -18,6 +18,7 @@ pub enum RequestType {
|
|||
Ping,
|
||||
ListTools,
|
||||
ListResourceTemplates,
|
||||
ListRoots,
|
||||
}
|
||||
|
||||
impl RequestType {
|
||||
|
@ -36,16 +37,14 @@ impl RequestType {
|
|||
RequestType::Ping => "ping",
|
||||
RequestType::ListTools => "tools/list",
|
||||
RequestType::ListResourceTemplates => "resources/templates/list",
|
||||
RequestType::ListRoots => "roots/list",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ProtocolVersion {
|
||||
VersionString(String),
|
||||
VersionNumber(u32),
|
||||
}
|
||||
#[serde(transparent)]
|
||||
pub struct ProtocolVersion(pub String);
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
|
@ -53,6 +52,8 @@ pub struct InitializeParams {
|
|||
pub protocol_version: ProtocolVersion,
|
||||
pub capabilities: ClientCapabilities,
|
||||
pub client_info: Implementation,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
|
@ -61,30 +62,40 @@ pub struct CallToolParams {
|
|||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<HashMap<String, serde_json::Value>>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesUnsubscribeParams {
|
||||
pub uri: Url,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesSubscribeParams {
|
||||
pub uri: Url,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesReadParams {
|
||||
pub uri: Url,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct LoggingSetLevelParams {
|
||||
pub level: LoggingLevel,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
|
@ -93,6 +104,8 @@ pub struct PromptsGetParams {
|
|||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<HashMap<String, String>>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
|
@ -100,6 +113,8 @@ pub struct PromptsGetParams {
|
|||
pub struct CompletionCompleteParams {
|
||||
pub r#ref: CompletionReference,
|
||||
pub argument: CompletionArgument,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
|
@ -145,12 +160,16 @@ pub struct InitializeResponse {
|
|||
pub protocol_version: ProtocolVersion,
|
||||
pub capabilities: ServerCapabilities,
|
||||
pub server_info: Implementation,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesReadResponse {
|
||||
pub contents: Vec<ResourceContent>,
|
||||
pub contents: Vec<ResourceContents>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
@ -159,29 +178,39 @@ pub struct ResourcesListResponse {
|
|||
pub resources: Vec<Resource>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub next_cursor: Option<String>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SamplingMessage {
|
||||
pub role: Role,
|
||||
pub content: MessageContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SamplingMessage {
|
||||
pub role: SamplingRole,
|
||||
pub content: SamplingContent,
|
||||
pub struct PromptMessage {
|
||||
pub role: Role,
|
||||
pub content: MessageContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum SamplingRole {
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum SamplingContent {
|
||||
pub enum MessageContent {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image")]
|
||||
Image { data: String, mime_type: String },
|
||||
#[serde(rename = "resource")]
|
||||
Resource { resource: ResourceContents },
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
@ -189,7 +218,9 @@ pub enum SamplingContent {
|
|||
pub struct PromptsGetResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub messages: Vec<SamplingMessage>,
|
||||
pub messages: Vec<PromptMessage>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
@ -198,12 +229,16 @@ pub struct PromptsListResponse {
|
|||
pub prompts: Vec<Prompt>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub next_cursor: Option<String>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CompletionCompleteResponse {
|
||||
pub completion: CompletionResult,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
@ -214,6 +249,8 @@ pub struct CompletionResult {
|
|||
pub total: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub has_more: Option<bool>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
|
@ -243,6 +280,8 @@ pub struct ClientCapabilities {
|
|||
pub experimental: Option<HashMap<String, serde_json::Value>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sampling: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub roots: Option<RootsCapabilities>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
@ -283,6 +322,13 @@ pub struct ToolsCapabilities {
|
|||
pub list_changed: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RootsCapabilities {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub list_changed: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Tool {
|
||||
|
@ -312,14 +358,28 @@ pub struct Resource {
|
|||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourceContent {
|
||||
pub struct ResourceContents {
|
||||
pub uri: Url,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TextResourceContents {
|
||||
pub uri: Url,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<String>,
|
||||
pub mime_type: Option<String>,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct BlobResourceContents {
|
||||
pub uri: Url,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub blob: Option<String>,
|
||||
pub mime_type: Option<String>,
|
||||
pub blob: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
@ -338,8 +398,32 @@ pub struct ResourceTemplate {
|
|||
pub enum LoggingLevel {
|
||||
Debug,
|
||||
Info,
|
||||
Notice,
|
||||
Warning,
|
||||
Error,
|
||||
Critical,
|
||||
Alert,
|
||||
Emergency,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ModelPreferences {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub hints: Option<Vec<ModelHint>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cost_priority: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub speed_priority: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub intelligence_priority: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ModelHint {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
|
@ -352,6 +436,7 @@ pub enum NotificationType {
|
|||
ResourcesListChanged,
|
||||
ToolsListChanged,
|
||||
PromptsListChanged,
|
||||
RootsListChanged,
|
||||
}
|
||||
|
||||
impl NotificationType {
|
||||
|
@ -364,6 +449,7 @@ impl NotificationType {
|
|||
NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
|
||||
NotificationType::ToolsListChanged => "notifications/tools/list_changed",
|
||||
NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
|
||||
NotificationType::RootsListChanged => "notifications/roots/list_changed",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -373,6 +459,14 @@ impl NotificationType {
|
|||
pub enum ClientNotification {
|
||||
Initialized,
|
||||
Progress(ProgressParams),
|
||||
RootsListChanged,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ProgressToken {
|
||||
String(String),
|
||||
Number(f64),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
|
@ -382,10 +476,10 @@ pub struct ProgressParams {
|
|||
pub progress: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total: Option<f64>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
pub type ProgressToken = String;
|
||||
|
||||
pub enum CompletionTotal {
|
||||
Exact(u32),
|
||||
HasMore,
|
||||
|
@ -410,7 +504,22 @@ pub struct Completion {
|
|||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CallToolResponse {
|
||||
pub tool_result: serde_json::Value,
|
||||
pub content: Vec<ToolResponseContent>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub is_error: Option<bool>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ToolResponseContent {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image")]
|
||||
Image { data: String, mime_type: String },
|
||||
#[serde(rename = "resource")]
|
||||
Resource { resource: ResourceContents },
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
@ -419,4 +528,22 @@ pub struct ListToolsResponse {
|
|||
pub tools: Vec<Tool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub next_cursor: Option<String>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ListRootsResponse {
|
||||
pub roots: Vec<Root>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Root {
|
||||
pub uri: Url,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue