From 8a96ea25c465697ec74ce3447bcd2ce9cb25b4f0 Mon Sep 17 00:00:00 2001 From: David Soria Parra <167242713+dsp-ant@users.noreply.github.com> Date: Mon, 28 Oct 2024 14:37:58 +0000 Subject: [PATCH] context_servers: Support tools (#19548) This PR depends on #19547 This PR adds support for tools from context servers. Context servers are free to expose tools that Zed can pass to models. When called by the model, Zed forwards the request to context servers. This allows for some interesting techniques. Context servers can easily expose tools such as querying local databases, reading or writing local files, reading resources over authenticated APIs (e.g. kubernetes, asana, etc). This is currently experimental. Things to discuss * I want to still add a confirm dialog asking people if a server is allows to use the tool. Should do this or just use the tool and assume trustworthyness of context servers? * Can we add tool use behind a local setting flag? Release Notes: - N/A --------- Co-authored-by: Marshall Bowers --- crates/assistant/src/assistant.rs | 85 ++++++++++++++----- crates/assistant/src/tools.rs | 1 + .../src/tools/context_server_tool.rs | 82 ++++++++++++++++++ crates/context_servers/src/protocol.rs | 33 +++++++ crates/context_servers/src/registry.rs | 32 +++++-- crates/context_servers/src/types.rs | 18 ++++ .../language_model/src/provider/anthropic.rs | 12 ++- 7 files changed, 235 insertions(+), 28 deletions(-) create mode 100644 crates/assistant/src/tools/context_server_tool.rs diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index e1e574744f..a48f6d6c29 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -298,25 +298,64 @@ fn register_context_server_handlers(cx: &mut AppContext) { return; }; - if let Some(prompts) = protocol.list_prompts().await.log_err() { - for prompt in prompts - .into_iter() - .filter(context_server_command::acceptable_prompt) - { - log::info!( - "registering context server command: {:?}", - prompt.name - ); - context_server_registry.register_command( - server.id.clone(), - prompt.name.as_str(), - ); - slash_command_registry.register_command( - context_server_command::ContextServerSlashCommand::new( - &server, prompt, - ), - true, - ); + if protocol.capable(context_servers::protocol::ServerCapability::Prompts) { + if let Some(prompts) = protocol.list_prompts().await.log_err() { + for prompt in prompts + .into_iter() + .filter(context_server_command::acceptable_prompt) + { + log::info!( + "registering context server command: {:?}", + prompt.name + ); + context_server_registry.register_command( + server.id.clone(), + prompt.name.as_str(), + ); + slash_command_registry.register_command( + context_server_command::ContextServerSlashCommand::new( + &server, prompt, + ), + true, + ); + } + } + } + }) + .detach(); + } + }, + ); + + cx.update_model( + &manager, + |manager: &mut context_servers::manager::ContextServerManager, cx| { + let tool_registry = ToolRegistry::global(cx); + let context_server_registry = ContextServerRegistry::global(cx); + if let Some(server) = manager.get_server(server_id) { + cx.spawn(|_, _| async move { + let Some(protocol) = server.client.read().clone() else { + return; + }; + + if protocol.capable(context_servers::protocol::ServerCapability::Tools) { + if let Some(tools) = protocol.list_tools().await.log_err() { + for tool in tools.tools { + log::info!( + "registering context server tool: {:?}", + tool.name + ); + context_server_registry.register_tool( + server.id.clone(), + tool.name.as_str(), + ); + tool_registry.register_tool( + tools::context_server_tool::ContextServerTool::new( + server.id.clone(), + tool + ), + ); + } } } }) @@ -334,6 +373,14 @@ fn register_context_server_handlers(cx: &mut AppContext) { context_server_registry.unregister_command(&server_id, &command_name); } } + + if let Some(tools) = context_server_registry.get_tools(server_id) { + let tool_registry = ToolRegistry::global(cx); + for tool_name in tools { + tool_registry.unregister_tool_by_name(&tool_name); + context_server_registry.unregister_tool(&server_id, &tool_name); + } + } } }, ) diff --git a/crates/assistant/src/tools.rs b/crates/assistant/src/tools.rs index abde04e760..83a396c020 100644 --- a/crates/assistant/src/tools.rs +++ b/crates/assistant/src/tools.rs @@ -1 +1,2 @@ +pub mod context_server_tool; pub mod now_tool; diff --git a/crates/assistant/src/tools/context_server_tool.rs b/crates/assistant/src/tools/context_server_tool.rs new file mode 100644 index 0000000000..93edb32b75 --- /dev/null +++ b/crates/assistant/src/tools/context_server_tool.rs @@ -0,0 +1,82 @@ +use anyhow::{anyhow, bail}; +use assistant_tool::Tool; +use context_servers::manager::ContextServerManager; +use context_servers::types; +use gpui::Task; + +pub struct ContextServerTool { + server_id: String, + tool: types::Tool, +} + +impl ContextServerTool { + pub fn new(server_id: impl Into, tool: types::Tool) -> Self { + Self { + server_id: server_id.into(), + tool, + } + } +} + +impl Tool for ContextServerTool { + fn name(&self) -> String { + self.tool.name.clone() + } + + fn description(&self) -> String { + self.tool.description.clone().unwrap_or_default() + } + + fn input_schema(&self) -> serde_json::Value { + match &self.tool.input_schema { + serde_json::Value::Null => { + serde_json::json!({ "type": "object", "properties": [] }) + } + serde_json::Value::Object(map) if map.is_empty() => { + serde_json::json!({ "type": "object", "properties": [] }) + } + _ => self.tool.input_schema.clone(), + } + } + + fn run( + self: std::sync::Arc, + input: serde_json::Value, + _workspace: gpui::WeakView, + cx: &mut ui::WindowContext, + ) -> gpui::Task> { + let manager = ContextServerManager::global(cx); + let manager = manager.read(cx); + if let Some(server) = manager.get_server(&self.server_id) { + cx.foreground_executor().spawn({ + let tool_name = self.tool.name.clone(); + async move { + let Some(protocol) = server.client.read().clone() else { + bail!("Context server not initialized"); + }; + + let arguments = if let serde_json::Value::Object(map) = input { + Some(map.into_iter().collect()) + } else { + None + }; + + log::trace!( + "Running tool: {} with arguments: {:?}", + tool_name, + arguments + ); + let response = protocol.run_tool(tool_name, arguments).await?; + + let tool_result = match response.tool_result { + serde_json::Value::String(s) => s, + _ => serde_json::to_string(&response.tool_result)?, + }; + Ok(tool_result) + } + }) + } else { + Task::ready(Err(anyhow!("Context server not found"))) + } + } +} diff --git a/crates/context_servers/src/protocol.rs b/crates/context_servers/src/protocol.rs index 80a7a7f991..996fc34f46 100644 --- a/crates/context_servers/src/protocol.rs +++ b/crates/context_servers/src/protocol.rs @@ -180,6 +180,39 @@ impl InitializedContextServerProtocol { Ok(completion) } + + /// List MCP tools. + pub async fn list_tools(&self) -> Result { + self.check_capability(ServerCapability::Tools)?; + + let response = self + .inner + .request::(types::RequestType::ListTools.as_str(), ()) + .await?; + + Ok(response) + } + + /// Executes a tool with the given arguments + pub async fn run_tool>( + &self, + tool: P, + arguments: Option>, + ) -> Result { + self.check_capability(ServerCapability::Tools)?; + + let params = types::CallToolParams { + name: tool.as_ref().to_string(), + arguments, + }; + + let response: types::CallToolResponse = self + .inner + .request(types::RequestType::CallTool.as_str(), params) + .await?; + + Ok(response) + } } impl InitializedContextServerProtocol { diff --git a/crates/context_servers/src/registry.rs b/crates/context_servers/src/registry.rs index 625f308c15..5490187034 100644 --- a/crates/context_servers/src/registry.rs +++ b/crates/context_servers/src/registry.rs @@ -9,7 +9,8 @@ struct GlobalContextServerRegistry(Arc); impl Global for GlobalContextServerRegistry {} pub struct ContextServerRegistry { - registry: RwLock>>>, + command_registry: RwLock>>>, + tool_registry: RwLock>>>, } impl ContextServerRegistry { @@ -20,13 +21,14 @@ impl ContextServerRegistry { pub fn register(cx: &mut AppContext) { cx.set_global(GlobalContextServerRegistry(Arc::new( ContextServerRegistry { - registry: RwLock::new(HashMap::default()), + command_registry: RwLock::new(HashMap::default()), + tool_registry: RwLock::new(HashMap::default()), }, ))) } pub fn register_command(&self, server_id: String, command_name: &str) { - let mut registry = self.registry.write(); + let mut registry = self.command_registry.write(); registry .entry(server_id) .or_default() @@ -34,14 +36,34 @@ impl ContextServerRegistry { } pub fn unregister_command(&self, server_id: &str, command_name: &str) { - let mut registry = self.registry.write(); + let mut registry = self.command_registry.write(); if let Some(commands) = registry.get_mut(server_id) { commands.retain(|name| name.as_ref() != command_name); } } pub fn get_commands(&self, server_id: &str) -> Option>> { - let registry = self.registry.read(); + let registry = self.command_registry.read(); + registry.get(server_id).cloned() + } + + pub fn register_tool(&self, server_id: String, tool_name: &str) { + let mut registry = self.tool_registry.write(); + registry + .entry(server_id) + .or_default() + .push(tool_name.into()); + } + + pub fn unregister_tool(&self, server_id: &str, tool_name: &str) { + let mut registry = self.tool_registry.write(); + if let Some(tools) = registry.get_mut(server_id) { + tools.retain(|name| name.as_ref() != tool_name); + } + } + + pub fn get_tools(&self, server_id: &str) -> Option>> { + let registry = self.tool_registry.read(); registry.get(server_id).cloned() } } diff --git a/crates/context_servers/src/types.rs b/crates/context_servers/src/types.rs index 2bca0a021a..b6d8a958bb 100644 --- a/crates/context_servers/src/types.rs +++ b/crates/context_servers/src/types.rs @@ -16,6 +16,8 @@ pub enum RequestType { PromptsList, CompletionComplete, Ping, + ListTools, + ListResourceTemplates, } impl RequestType { @@ -32,6 +34,8 @@ impl RequestType { RequestType::PromptsList => "prompts/list", RequestType::CompletionComplete => "completion/complete", RequestType::Ping => "ping", + RequestType::ListTools => "tools/list", + RequestType::ListResourceTemplates => "resources/templates/list", } } } @@ -402,3 +406,17 @@ pub struct Completion { pub values: Vec, pub total: CompletionTotal, } + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CallToolResponse { + pub tool_result: serde_json::Value, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListToolsResponse { + pub tools: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, +} diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index fe88c73b90..b7e65650b5 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -505,10 +505,14 @@ pub fn map_to_language_model_completion_events( LanguageModelToolUse { id: tool_use.id, name: tool_use.name, - input: serde_json::Value::from_str( - &tool_use.input_json, - ) - .map_err(|err| anyhow!(err))?, + input: if tool_use.input_json.is_empty() { + serde_json::Value::Null + } else { + serde_json::Value::from_str( + &tool_use.input_json, + ) + .map_err(|err| anyhow!(err))? + }, }, )) })),