diff --git a/crates/assistant/src/slash_command/context_server_command.rs b/crates/assistant/src/slash_command/context_server_command.rs index 6b1ae39186..3db057d074 100644 --- a/crates/assistant/src/slash_command/context_server_command.rs +++ b/crates/assistant/src/slash_command/context_server_command.rs @@ -1,3 +1,4 @@ +use super::create_label_for_command; use anyhow::{anyhow, Result}; use assistant_slash_command::{ AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput, @@ -6,9 +7,9 @@ use assistant_slash_command::{ use collections::HashMap; use context_servers::{ manager::{ContextServer, ContextServerManager}, - protocol::PromptInfo, + types::Prompt, }; -use gpui::{Task, WeakView, WindowContext}; +use gpui::{AppContext, Task, WeakView, WindowContext}; use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate}; use std::sync::atomic::AtomicBool; use std::sync::Arc; @@ -18,11 +19,11 @@ use workspace::Workspace; pub struct ContextServerSlashCommand { server_id: String, - prompt: PromptInfo, + prompt: Prompt, } impl ContextServerSlashCommand { - pub fn new(server: &Arc, prompt: PromptInfo) -> Self { + pub fn new(server: &Arc, prompt: Prompt) -> Self { Self { server_id: server.id.clone(), prompt, @@ -35,12 +36,28 @@ impl SlashCommand for ContextServerSlashCommand { self.prompt.name.clone() } + fn label(&self, cx: &AppContext) -> language::CodeLabel { + let mut parts = vec![self.prompt.name.as_str()]; + if let Some(args) = &self.prompt.arguments { + if let Some(arg) = args.first() { + parts.push(arg.name.as_str()); + } + } + create_label_for_command(&parts[0], &parts[1..], cx) + } + fn description(&self) -> String { - format!("Run context server command: {}", self.prompt.name) + match &self.prompt.description { + Some(desc) => desc.clone(), + None => format!("Run '{}' from {}", self.prompt.name, self.server_id), + } } fn menu_text(&self) -> String { - format!("Run '{}' from {}", self.prompt.name, self.server_id) + match &self.prompt.description { + Some(desc) => desc.clone(), + None => format!("Run '{}' from {}", self.prompt.name, self.server_id), + } } fn requires_argument(&self) -> bool { @@ -154,7 +171,7 @@ impl SlashCommand for ContextServerSlashCommand { } } -fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> { +fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> { if arguments.is_empty() { return Err(anyhow!("No arguments given")); } @@ -170,7 +187,7 @@ fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(Str } } -fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result> { +fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result> { match &prompt.arguments { Some(args) if args.len() > 1 => Err(anyhow!( "Prompt has more than one argument, which is not supported" @@ -199,7 +216,7 @@ fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result bool { +pub fn acceptable_prompt(prompt: &Prompt) -> bool { match &prompt.arguments { None => true, Some(args) if args.len() <= 1 => true, diff --git a/crates/context_servers/src/client.rs b/crates/context_servers/src/client.rs index aff186b115..6681023c00 100644 --- a/crates/context_servers/src/client.rs +++ b/crates/context_servers/src/client.rs @@ -26,7 +26,7 @@ const JSON_RPC_VERSION: &str = "2.0"; const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); type ResponseHandler = Box)>; -type NotificationHandler = Box; +type NotificationHandler = Box; #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[serde(untagged)] @@ -94,7 +94,6 @@ enum CspResult { #[derive(Serialize, Deserialize)] struct Notification<'a, T> { jsonrpc: &'static str, - id: RequestId, #[serde(borrow)] method: &'a str, params: T, @@ -103,7 +102,6 @@ struct Notification<'a, T> { #[derive(Debug, Clone, Deserialize)] struct AnyNotification<'a> { jsonrpc: &'a str, - id: RequestId, method: String, #[serde(default)] params: Option, @@ -246,11 +244,7 @@ impl Client { if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) { - handler( - notification.id, - notification.params.unwrap_or(Value::Null), - cx.clone(), - ); + handler(notification.params.unwrap_or(Value::Null), cx.clone()); } } } @@ -378,10 +372,8 @@ impl Client { /// Sends a notification to the context server without expecting a response. /// This function serializes the notification and sends it through the outbound channel. pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> { - let id = self.next_id.fetch_add(1, SeqCst); let notification = serde_json::to_string(&Notification { jsonrpc: JSON_RPC_VERSION, - id: RequestId::Int(id), method, params, }) @@ -390,13 +382,13 @@ impl Client { Ok(()) } - pub fn on_notification(&self, method: &'static str, mut f: F) + pub fn on_notification(&self, method: &'static str, f: F) where F: 'static + Send + FnMut(Value, AsyncAppContext), { self.notification_handlers .lock() - .insert(method, Box::new(move |_, params, cx| f(params, cx))); + .insert(method, Box::new(f)); } pub fn name(&self) -> &str { diff --git a/crates/context_servers/src/manager.rs b/crates/context_servers/src/manager.rs index 08e403a434..3c21fd53fb 100644 --- a/crates/context_servers/src/manager.rs +++ b/crates/context_servers/src/manager.rs @@ -85,7 +85,7 @@ impl ContextServer { )?; let protocol = crate::protocol::ModelContextProtocol::new(client); - let client_info = types::EntityInfo { + let client_info = types::Implementation { name: "Zed".to_string(), version: env!("CARGO_PKG_VERSION").to_string(), }; diff --git a/crates/context_servers/src/protocol.rs b/crates/context_servers/src/protocol.rs index 87da217f7d..451db56ef3 100644 --- a/crates/context_servers/src/protocol.rs +++ b/crates/context_servers/src/protocol.rs @@ -11,8 +11,6 @@ use collections::HashMap; use crate::client::Client; use crate::types; -pub use types::PromptInfo; - const PROTOCOL_VERSION: u32 = 1; pub struct ModelContextProtocol { @@ -26,7 +24,7 @@ impl ModelContextProtocol { pub async fn initialize( self, - client_info: types::EntityInfo, + client_info: types::Implementation, ) -> Result { let params = types::InitializeParams { protocol_version: PROTOCOL_VERSION, @@ -96,7 +94,7 @@ impl InitializedContextServerProtocol { } /// List the MCP prompts. - pub async fn list_prompts(&self) -> Result> { + pub async fn list_prompts(&self) -> Result> { self.check_capability(ServerCapability::Prompts)?; let response: types::PromptsListResponse = self @@ -107,6 +105,18 @@ impl InitializedContextServerProtocol { Ok(response.prompts) } + /// List the MCP resources. + pub async fn list_resources(&self) -> Result { + self.check_capability(ServerCapability::Resources)?; + + let response: types::ResourcesListResponse = self + .inner + .request(types::RequestType::ResourcesList.as_str(), ()) + .await?; + + Ok(response) + } + /// Executes a prompt with the given arguments and returns the result. pub async fn run_prompt>( &self, diff --git a/crates/context_servers/src/types.rs b/crates/context_servers/src/types.rs index cd95ecd7ad..04ac87c704 100644 --- a/crates/context_servers/src/types.rs +++ b/crates/context_servers/src/types.rs @@ -15,6 +15,7 @@ pub enum RequestType { PromptsGet, PromptsList, CompletionComplete, + Ping, } impl RequestType { @@ -30,6 +31,7 @@ impl RequestType { RequestType::PromptsGet => "prompts/get", RequestType::PromptsList => "prompts/list", RequestType::CompletionComplete => "completion/complete", + RequestType::Ping => "ping", } } } @@ -39,14 +41,15 @@ impl RequestType { pub struct InitializeParams { pub protocol_version: u32, pub capabilities: ClientCapabilities, - pub client_info: EntityInfo, + pub client_info: Implementation, } #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct CallToolParams { pub name: String, - pub arguments: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option>, } #[derive(Debug, Serialize)] @@ -77,6 +80,7 @@ pub struct LoggingSetLevelParams { #[serde(rename_all = "camelCase")] pub struct PromptsGetParams { pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] pub arguments: Option>, } @@ -101,6 +105,13 @@ pub struct PromptReference { pub name: String, } +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourceReference { + pub r#type: PromptReferenceType, + pub uri: Url, +} + #[derive(Debug, Serialize)] #[serde(rename_all = "snake_case")] pub enum PromptReferenceType { @@ -110,13 +121,6 @@ pub enum PromptReferenceType { Resource, } -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct ResourceReference { - pub r#type: String, - pub uri: String, -} - #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct CompletionArgument { @@ -129,7 +133,7 @@ pub struct CompletionArgument { pub struct InitializeResponse { pub protocol_version: u32, pub capabilities: ServerCapabilities, - pub server_info: EntityInfo, + pub server_info: Implementation, } #[derive(Debug, Deserialize)] @@ -141,13 +145,39 @@ pub struct ResourcesReadResponse { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesListResponse { + #[serde(skip_serializing_if = "Option::is_none")] pub resource_templates: Option>, - pub resources: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub resources: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SamplingMessage { + pub role: SamplingRole, + pub content: SamplingContent, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum SamplingRole { + User, + Assistant, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum SamplingContent { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image")] + Image { data: String, mime_type: String }, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsGetResponse { + #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, pub prompt: String, } @@ -155,7 +185,7 @@ pub struct PromptsGetResponse { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsListResponse { - pub prompts: Vec, + pub prompts: Vec, } #[derive(Debug, Deserialize)] @@ -168,61 +198,91 @@ pub struct CompletionCompleteResponse { #[serde(rename_all = "camelCase")] pub struct CompletionResult { pub values: Vec, + #[serde(skip_serializing_if = "Option::is_none")] pub total: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub has_more: Option, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] -pub struct PromptInfo { +pub struct Prompt { pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub arguments: Option>, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct PromptArgument { pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub required: Option, } -// Shared Types - #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ClientCapabilities { + #[serde(skip_serializing_if = "Option::is_none")] pub experimental: Option>, - pub sampling: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub sampling: Option, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ServerCapabilities { + #[serde(skip_serializing_if = "Option::is_none")] pub experimental: Option>, - pub logging: Option>, - pub prompts: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub logging: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompts: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub resources: Option, - pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptsCapabilities { + #[serde(skip_serializing_if = "Option::is_none")] + pub list_changed: Option, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesCapabilities { + #[serde(skip_serializing_if = "Option::is_none")] pub subscribe: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub list_changed: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolsCapabilities { + #[serde(skip_serializing_if = "Option::is_none")] + pub list_changed: Option, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Tool { pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, pub input_schema: serde_json::Value, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -pub struct EntityInfo { +pub struct Implementation { pub name: String, pub version: String, } @@ -231,6 +291,10 @@ pub struct EntityInfo { #[serde(rename_all = "camelCase")] pub struct Resource { pub uri: Url, + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub mime_type: Option, } @@ -238,17 +302,23 @@ pub struct Resource { #[serde(rename_all = "camelCase")] pub struct ResourceContent { pub uri: Url, + #[serde(skip_serializing_if = "Option::is_none")] pub mime_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub text: Option, - pub data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub blob: Option, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourceTemplate { pub uri_template: String, - pub name: Option, + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -260,13 +330,16 @@ pub enum LoggingLevel { Error, } -// Client Notifications - #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub enum NotificationType { Initialized, Progress, + Message, + ResourcesUpdated, + ResourcesListChanged, + ToolsListChanged, + PromptsListChanged, } impl NotificationType { @@ -274,6 +347,11 @@ impl NotificationType { match self { NotificationType::Initialized => "notifications/initialized", NotificationType::Progress => "notifications/progress", + NotificationType::Message => "notifications/message", + NotificationType::ResourcesUpdated => "notifications/resources/updated", + NotificationType::ResourcesListChanged => "notifications/resources/list_changed", + NotificationType::ToolsListChanged => "notifications/tools/list_changed", + NotificationType::PromptsListChanged => "notifications/prompts/list_changed", } } } @@ -288,12 +366,13 @@ pub enum ClientNotification { #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct ProgressParams { - pub progress_token: String, + pub progress_token: ProgressToken, pub progress: f64, + #[serde(skip_serializing_if = "Option::is_none")] pub total: Option, } -// Helper Types that don't map directly to the protocol +pub type ProgressToken = String; pub enum CompletionTotal { Exact(u32),