context_servers: Upgrade protocol to version 2024-11-05 (#20615)

This updates context servers to the most recent version

Release Notes:

- N/A

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
David Soria Parra 2024-11-14 18:03:30 +00:00 committed by GitHub
parent b5ce8e7aa5
commit 690a725667
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 180 additions and 34 deletions

View file

@ -152,7 +152,7 @@ impl SlashCommand for ContextServerSlashCommand {
if result
.messages
.iter()
.any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User))
.any(|msg| !matches!(msg.role, context_servers::types::Role::User))
{
return Err(anyhow!(
"Prompt contains non-user roles, which is not supported"
@ -164,7 +164,7 @@ impl SlashCommand for ContextServerSlashCommand {
.messages
.into_iter()
.filter_map(|msg| match msg.content {
context_servers::types::SamplingContent::Text { text } => Some(text),
context_servers::types::MessageContent::Text { text } => Some(text),
_ => None,
})
.collect::<Vec<String>>()

View file

@ -74,11 +74,21 @@ impl Tool for ContextServerTool {
);
let response = protocol.run_tool(tool_name, arguments).await?;
let tool_result = match response.tool_result {
serde_json::Value::String(s) => s,
_ => serde_json::to_string(&response.tool_result)?,
};
Ok(tool_result)
let mut result = String::new();
for content in response.content {
match content {
types::ToolResponseContent::Text { text } => {
result.push_str(&text);
}
types::ToolResponseContent::Image { .. } => {
log::warn!("Ignoring image content from tool response");
}
types::ToolResponseContent::Resource { .. } => {
log::warn!("Ignoring resource content from tool response");
}
}
}
Ok(result)
}
})
} else {

View file

@ -25,6 +25,13 @@ use util::TryFutureExt;
const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
// Standard JSON-RPC error codes
pub const PARSE_ERROR: i32 = -32700;
pub const INVALID_REQUEST: i32 = -32600;
pub const METHOD_NOT_FOUND: i32 = -32601;
pub const INVALID_PARAMS: i32 = -32602;
pub const INTERNAL_ERROR: i32 = -32603;
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncAppContext)>;

View file

@ -11,8 +11,6 @@ use collections::HashMap;
use crate::client::Client;
use crate::types;
const PROTOCOL_VERSION: &str = "2024-10-07";
pub struct ModelContextProtocol {
inner: Client,
}
@ -23,10 +21,9 @@ impl ModelContextProtocol {
}
fn supported_protocols() -> Vec<types::ProtocolVersion> {
vec![
types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
types::ProtocolVersion::VersionNumber(1),
]
vec![types::ProtocolVersion(
types::LATEST_PROTOCOL_VERSION.to_string(),
)]
}
pub async fn initialize(
@ -34,11 +31,13 @@ impl ModelContextProtocol {
client_info: types::Implementation,
) -> Result<InitializedContextServerProtocol> {
let params = types::InitializeParams {
protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
capabilities: types::ClientCapabilities {
experimental: None,
sampling: None,
roots: None,
},
meta: None,
client_info,
};
@ -148,6 +147,7 @@ impl InitializedContextServerProtocol {
let params = types::PromptsGetParams {
name: prompt.as_ref().to_string(),
arguments: Some(arguments),
meta: None,
};
let response: types::PromptsGetResponse = self
@ -170,6 +170,7 @@ impl InitializedContextServerProtocol {
name: argument.into(),
value: value.into(),
},
meta: None,
};
let result: types::CompletionCompleteResponse = self
.inner
@ -210,6 +211,7 @@ impl InitializedContextServerProtocol {
let params = types::CallToolParams {
name: tool.as_ref().to_string(),
arguments,
meta: None,
};
let response: types::CallToolResponse = self

View file

@ -2,8 +2,8 @@ use collections::HashMap;
use serde::{Deserialize, Serialize};
use url::Url;
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";
pub enum RequestType {
Initialize,
CallTool,
@ -18,6 +18,7 @@ pub enum RequestType {
Ping,
ListTools,
ListResourceTemplates,
ListRoots,
}
impl RequestType {
@ -36,16 +37,14 @@ impl RequestType {
RequestType::Ping => "ping",
RequestType::ListTools => "tools/list",
RequestType::ListResourceTemplates => "resources/templates/list",
RequestType::ListRoots => "roots/list",
}
}
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ProtocolVersion {
VersionString(String),
VersionNumber(u32),
}
#[serde(transparent)]
pub struct ProtocolVersion(pub String);
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
@ -53,6 +52,8 @@ pub struct InitializeParams {
pub protocol_version: ProtocolVersion,
pub capabilities: ClientCapabilities,
pub client_info: Implementation,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
@ -61,30 +62,40 @@ pub struct CallToolParams {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<HashMap<String, serde_json::Value>>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesUnsubscribeParams {
pub uri: Url,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesSubscribeParams {
pub uri: Url,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesReadParams {
pub uri: Url,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct LoggingSetLevelParams {
pub level: LoggingLevel,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
@ -93,6 +104,8 @@ pub struct PromptsGetParams {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<HashMap<String, String>>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
@ -100,6 +113,8 @@ pub struct PromptsGetParams {
pub struct CompletionCompleteParams {
pub r#ref: CompletionReference,
pub argument: CompletionArgument,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
@ -145,12 +160,16 @@ pub struct InitializeResponse {
pub protocol_version: ProtocolVersion,
pub capabilities: ServerCapabilities,
pub server_info: Implementation,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesReadResponse {
pub contents: Vec<ResourceContent>,
pub contents: Vec<ResourceContents>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
@ -159,29 +178,39 @@ pub struct ResourcesListResponse {
pub resources: Vec<Resource>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingMessage {
pub role: Role,
pub content: MessageContent,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingMessage {
pub role: SamplingRole,
pub content: SamplingContent,
pub struct PromptMessage {
pub role: Role,
pub content: MessageContent,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SamplingRole {
pub enum Role {
User,
Assistant,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SamplingContent {
pub enum MessageContent {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image")]
Image { data: String, mime_type: String },
#[serde(rename = "resource")]
Resource { resource: ResourceContents },
}
#[derive(Debug, Deserialize)]
@ -189,7 +218,9 @@ pub enum SamplingContent {
pub struct PromptsGetResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub messages: Vec<SamplingMessage>,
pub messages: Vec<PromptMessage>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
@ -198,12 +229,16 @@ pub struct PromptsListResponse {
pub prompts: Vec<Prompt>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionCompleteResponse {
pub completion: CompletionResult,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
@ -214,6 +249,8 @@ pub struct CompletionResult {
pub total: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub has_more: Option<bool>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize, Serialize)]
@ -243,6 +280,8 @@ pub struct ClientCapabilities {
pub experimental: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub roots: Option<RootsCapabilities>,
}
#[derive(Debug, Serialize, Deserialize)]
@ -283,6 +322,13 @@ pub struct ToolsCapabilities {
pub list_changed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RootsCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
@ -312,14 +358,28 @@ pub struct Resource {
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceContent {
pub struct ResourceContents {
pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextResourceContents {
pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
pub mime_type: Option<String>,
pub text: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BlobResourceContents {
pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")]
pub blob: Option<String>,
pub mime_type: Option<String>,
pub blob: String,
}
#[derive(Debug, Serialize, Deserialize)]
@ -338,8 +398,32 @@ pub struct ResourceTemplate {
pub enum LoggingLevel {
Debug,
Info,
Notice,
Warning,
Error,
Critical,
Alert,
Emergency,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelPreferences {
#[serde(skip_serializing_if = "Option::is_none")]
pub hints: Option<Vec<ModelHint>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cost_priority: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub speed_priority: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub intelligence_priority: Option<f64>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelHint {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
#[derive(Debug, Serialize)]
@ -352,6 +436,7 @@ pub enum NotificationType {
ResourcesListChanged,
ToolsListChanged,
PromptsListChanged,
RootsListChanged,
}
impl NotificationType {
@ -364,6 +449,7 @@ impl NotificationType {
NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
NotificationType::ToolsListChanged => "notifications/tools/list_changed",
NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
NotificationType::RootsListChanged => "notifications/roots/list_changed",
}
}
}
@ -373,6 +459,14 @@ impl NotificationType {
pub enum ClientNotification {
Initialized,
Progress(ProgressParams),
RootsListChanged,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ProgressToken {
String(String),
Number(f64),
}
#[derive(Debug, Serialize)]
@ -382,10 +476,10 @@ pub struct ProgressParams {
pub progress: f64,
#[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<f64>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
pub type ProgressToken = String;
pub enum CompletionTotal {
Exact(u32),
HasMore,
@ -410,7 +504,22 @@ pub struct Completion {
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CallToolResponse {
pub tool_result: serde_json::Value,
pub content: Vec<ToolResponseContent>,
#[serde(skip_serializing_if = "Option::is_none")]
pub is_error: Option<bool>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ToolResponseContent {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image")]
Image { data: String, mime_type: String },
#[serde(rename = "resource")]
Resource { resource: ResourceContents },
}
#[derive(Debug, Deserialize)]
@ -419,4 +528,22 @@ pub struct ListToolsResponse {
pub tools: Vec<Tool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListRootsResponse {
pub roots: Vec<Root>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Root {
pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}