Factor tool definitions out of assistant
(#21189)
This PR factors the tool definitions out of the `assistant` crate so that they can be shared between `assistant` and `assistant2`. `ToolWorkingSet` now lives in `assistant_tool`. The tool definitions themselves live in `assistant_tools`, with the exception of the `ContextServerTool`, which has been moved to the `context_server` crate. As part of this refactoring I needed to extract the `ContextServerSettings` to a separate `context_server_settings` crate so that the `extension_host`—which is referenced by the `remote_server`—can name the `ContextServerSettings` type without pulling in some undesired dependencies. Release Notes: - N/A
This commit is contained in:
parent
321fd19763
commit
3901d46101
35 changed files with 219 additions and 113 deletions
234
crates/context_server/src/protocol.rs
Normal file
234
crates/context_server/src/protocol.rs
Normal file
|
@ -0,0 +1,234 @@
|
|||
//! This module implements parts of the Model Context Protocol.
|
||||
//!
|
||||
//! It handles the lifecycle messages, and provides a general interface to
|
||||
//! interacting with an MCP server. It uses the generic JSON-RPC client to
|
||||
//! read/write messages and the types from types.rs for serialization/deserialization
|
||||
//! of messages.
|
||||
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
|
||||
use crate::client::Client;
|
||||
use crate::types;
|
||||
|
||||
pub struct ModelContextProtocol {
|
||||
inner: Client,
|
||||
}
|
||||
|
||||
impl ModelContextProtocol {
|
||||
pub fn new(inner: Client) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
fn supported_protocols() -> Vec<types::ProtocolVersion> {
|
||||
vec![types::ProtocolVersion(
|
||||
types::LATEST_PROTOCOL_VERSION.to_string(),
|
||||
)]
|
||||
}
|
||||
|
||||
pub async fn initialize(
|
||||
self,
|
||||
client_info: types::Implementation,
|
||||
) -> Result<InitializedContextServerProtocol> {
|
||||
let params = types::InitializeParams {
|
||||
protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
|
||||
capabilities: types::ClientCapabilities {
|
||||
experimental: None,
|
||||
sampling: None,
|
||||
roots: None,
|
||||
},
|
||||
meta: None,
|
||||
client_info,
|
||||
};
|
||||
|
||||
let response: types::InitializeResponse = self
|
||||
.inner
|
||||
.request(types::RequestType::Initialize.as_str(), params)
|
||||
.await?;
|
||||
|
||||
if !Self::supported_protocols().contains(&response.protocol_version) {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Unsupported protocol version: {:?}",
|
||||
response.protocol_version
|
||||
));
|
||||
}
|
||||
|
||||
log::trace!("mcp server info {:?}", response.server_info);
|
||||
|
||||
self.inner.notify(
|
||||
types::NotificationType::Initialized.as_str(),
|
||||
serde_json::json!({}),
|
||||
)?;
|
||||
|
||||
let initialized_protocol = InitializedContextServerProtocol {
|
||||
inner: self.inner,
|
||||
initialize: response,
|
||||
};
|
||||
|
||||
Ok(initialized_protocol)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct InitializedContextServerProtocol {
|
||||
inner: Client,
|
||||
pub initialize: types::InitializeResponse,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
pub enum ServerCapability {
|
||||
Experimental,
|
||||
Logging,
|
||||
Prompts,
|
||||
Resources,
|
||||
Tools,
|
||||
}
|
||||
|
||||
impl InitializedContextServerProtocol {
|
||||
/// Check if the server supports a specific capability
|
||||
pub fn capable(&self, capability: ServerCapability) -> bool {
|
||||
match capability {
|
||||
ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
|
||||
ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
|
||||
ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
|
||||
ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
|
||||
ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_capability(&self, capability: ServerCapability) -> Result<()> {
|
||||
if self.capable(capability) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!(
|
||||
"Server does not support {:?} capability",
|
||||
capability
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// List the MCP prompts.
|
||||
pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
|
||||
self.check_capability(ServerCapability::Prompts)?;
|
||||
|
||||
let response: types::PromptsListResponse = self
|
||||
.inner
|
||||
.request(
|
||||
types::RequestType::PromptsList.as_str(),
|
||||
serde_json::json!({}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(response.prompts)
|
||||
}
|
||||
|
||||
/// List the MCP resources.
|
||||
pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
|
||||
self.check_capability(ServerCapability::Resources)?;
|
||||
|
||||
let response: types::ResourcesListResponse = self
|
||||
.inner
|
||||
.request(
|
||||
types::RequestType::ResourcesList.as_str(),
|
||||
serde_json::json!({}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Executes a prompt with the given arguments and returns the result.
|
||||
pub async fn run_prompt<P: AsRef<str>>(
|
||||
&self,
|
||||
prompt: P,
|
||||
arguments: HashMap<String, String>,
|
||||
) -> Result<types::PromptsGetResponse> {
|
||||
self.check_capability(ServerCapability::Prompts)?;
|
||||
|
||||
let params = types::PromptsGetParams {
|
||||
name: prompt.as_ref().to_string(),
|
||||
arguments: Some(arguments),
|
||||
meta: None,
|
||||
};
|
||||
|
||||
let response: types::PromptsGetResponse = self
|
||||
.inner
|
||||
.request(types::RequestType::PromptsGet.as_str(), params)
|
||||
.await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn completion<P: Into<String>>(
|
||||
&self,
|
||||
reference: types::CompletionReference,
|
||||
argument: P,
|
||||
value: P,
|
||||
) -> Result<types::Completion> {
|
||||
let params = types::CompletionCompleteParams {
|
||||
r#ref: reference,
|
||||
argument: types::CompletionArgument {
|
||||
name: argument.into(),
|
||||
value: value.into(),
|
||||
},
|
||||
meta: None,
|
||||
};
|
||||
let result: types::CompletionCompleteResponse = self
|
||||
.inner
|
||||
.request(types::RequestType::CompletionComplete.as_str(), params)
|
||||
.await?;
|
||||
|
||||
let completion = types::Completion {
|
||||
values: result.completion.values,
|
||||
total: types::CompletionTotal::from_options(
|
||||
result.completion.has_more,
|
||||
result.completion.total,
|
||||
),
|
||||
};
|
||||
|
||||
Ok(completion)
|
||||
}
|
||||
|
||||
/// List MCP tools.
|
||||
pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
|
||||
self.check_capability(ServerCapability::Tools)?;
|
||||
|
||||
let response = self
|
||||
.inner
|
||||
.request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
|
||||
.await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Executes a tool with the given arguments
|
||||
pub async fn run_tool<P: AsRef<str>>(
|
||||
&self,
|
||||
tool: P,
|
||||
arguments: Option<HashMap<String, serde_json::Value>>,
|
||||
) -> Result<types::CallToolResponse> {
|
||||
self.check_capability(ServerCapability::Tools)?;
|
||||
|
||||
let params = types::CallToolParams {
|
||||
name: tool.as_ref().to_string(),
|
||||
arguments,
|
||||
meta: None,
|
||||
};
|
||||
|
||||
let response: types::CallToolResponse = self
|
||||
.inner
|
||||
.request(types::RequestType::CallTool.as_str(), params)
|
||||
.await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitializedContextServerProtocol {
|
||||
pub async fn request<R: serde::de::DeserializeOwned>(
|
||||
&self,
|
||||
method: &str,
|
||||
params: impl serde::Serialize,
|
||||
) -> Result<R> {
|
||||
self.inner.request(method, params).await
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue