Factor tool definitions out of assistant (#21189)

This PR factors the tool definitions out of the `assistant` crate so
that they can be shared between `assistant` and `assistant2`.

`ToolWorkingSet` now lives in `assistant_tool`. The tool definitions
themselves live in `assistant_tools`, with the exception of the
`ContextServerTool`, which has been moved to the `context_server` crate.

As part of this refactoring I needed to extract the
`ContextServerSettings` to a separate `context_server_settings` crate so
that the `extension_host`—which is referenced by the `remote_server`—can
name the `ContextServerSettings` type without pulling in some undesired
dependencies.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-11-25 18:26:34 -05:00 committed by GitHub
parent 321fd19763
commit 3901d46101
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 219 additions and 113 deletions

View file

@ -0,0 +1,34 @@
[package]
name = "context_server"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/context_server.rs"
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
context_server_settings.workspace = true
extension.workspace = true
futures.workspace = true
gpui.workspace = true
log.workspace = true
parking_lot.workspace = true
postage.workspace = true
project.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
ui.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
workspace.workspace = true

View file

@ -0,0 +1 @@
../../LICENSE-GPL

View file

@ -0,0 +1,440 @@
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt};
use gpui::{AsyncAppContext, BackgroundExecutor, Task};
use parking_lot::Mutex;
use postage::barrier;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
use smol::{
channel,
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
process::Child,
};
use std::{
fmt,
path::PathBuf,
sync::{
atomic::{AtomicI32, Ordering::SeqCst},
Arc,
},
time::{Duration, Instant},
};
use util::TryFutureExt;
const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
// Standard JSON-RPC error codes
pub const PARSE_ERROR: i32 = -32700;
pub const INVALID_REQUEST: i32 = -32600;
pub const METHOD_NOT_FOUND: i32 = -32601;
pub const INVALID_PARAMS: i32 = -32602;
pub const INTERNAL_ERROR: i32 = -32603;
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncAppContext)>;
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RequestId {
Int(i32),
Str(String),
}
pub struct Client {
server_id: ContextServerId,
next_id: AtomicI32,
outbound_tx: channel::Sender<String>,
name: Arc<str>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
#[allow(clippy::type_complexity)]
#[allow(dead_code)]
io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
#[allow(dead_code)]
output_done_rx: Mutex<Option<barrier::Receiver>>,
executor: BackgroundExecutor,
server: Arc<Mutex<Option<Child>>>,
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct ContextServerId(pub Arc<str>);
fn is_null_value<T: Serialize>(value: &T) -> bool {
if let Ok(Value::Null) = serde_json::to_value(value) {
true
} else {
false
}
}
#[derive(Serialize, Deserialize)]
struct Request<'a, T> {
jsonrpc: &'static str,
id: RequestId,
method: &'a str,
#[serde(skip_serializing_if = "is_null_value")]
params: T,
}
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
jsonrpc: &'a str,
id: RequestId,
#[serde(default)]
error: Option<Error>,
#[serde(borrow)]
result: Option<&'a RawValue>,
}
#[derive(Deserialize)]
#[allow(dead_code)]
struct Response<T> {
jsonrpc: &'static str,
id: RequestId,
#[serde(flatten)]
value: CspResult<T>,
}
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum CspResult<T> {
#[serde(rename = "result")]
Ok(Option<T>),
#[allow(dead_code)]
Error(Option<Error>),
}
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
jsonrpc: &'static str,
#[serde(borrow)]
method: &'a str,
params: T,
}
#[derive(Debug, Clone, Deserialize)]
struct AnyNotification<'a> {
jsonrpc: &'a str,
method: String,
#[serde(default)]
params: Option<Value>,
}
#[derive(Debug, Serialize, Deserialize)]
struct Error {
message: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelContextServerBinary {
pub executable: PathBuf,
pub args: Vec<String>,
pub env: Option<HashMap<String, String>>,
}
impl Client {
/// Creates a new Client instance for a context server.
///
/// This function initializes a new Client by spawning a child process for the context server,
/// setting up communication channels, and initializing handlers for input/output operations.
/// It takes a server ID, binary information, and an async app context as input.
pub fn new(
server_id: ContextServerId,
binary: ModelContextServerBinary,
cx: AsyncAppContext,
) -> Result<Self> {
log::info!(
"starting context server (executable={:?}, args={:?})",
binary.executable,
&binary.args
);
let mut command = util::command::new_smol_command(&binary.executable);
command
.args(&binary.args)
.envs(binary.env.unwrap_or_default())
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
let mut server = command.spawn().with_context(|| {
format!(
"failed to spawn command. (path={:?}, args={:?})",
binary.executable, &binary.args
)
})?;
let stdin = server.stdin.take().unwrap();
let stdout = server.stdout.take().unwrap();
let stderr = server.stderr.take().unwrap();
let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
let (output_done_tx, output_done_rx) = barrier::channel();
let notification_handlers =
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let stdout_input_task = cx.spawn({
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
move |cx| {
Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err()
}
});
let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err());
let input_task = cx.spawn(|_| async move {
let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
stdout.or(stderr)
});
let output_task = cx.background_executor().spawn({
Self::handle_output(
stdin,
outbound_rx,
output_done_tx,
response_handlers.clone(),
)
.log_err()
});
let mut context_server = Self {
server_id,
notification_handlers,
response_handlers,
name: "".into(),
next_id: Default::default(),
outbound_tx,
executor: cx.background_executor().clone(),
io_tasks: Mutex::new(Some((input_task, output_task))),
output_done_rx: Mutex::new(Some(output_done_rx)),
server: Arc::new(Mutex::new(Some(server))),
};
if let Some(name) = binary.executable.file_name() {
context_server.name = name.to_string_lossy().into();
}
Ok(context_server)
}
/// Handles input from the server's stdout.
///
/// This function continuously reads lines from the provided stdout stream,
/// parses them as JSON-RPC responses or notifications, and dispatches them
/// to the appropriate handlers. It processes both responses (which are matched
/// to pending requests) and notifications (which trigger registered handlers).
async fn handle_input<Stdout>(
stdout: Stdout,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
cx: AsyncAppContext,
) -> anyhow::Result<()>
where
Stdout: AsyncRead + Unpin + Send + 'static,
{
let mut stdout = BufReader::new(stdout);
let mut buffer = String::new();
loop {
buffer.clear();
if stdout.read_line(&mut buffer).await? == 0 {
return Ok(());
}
let content = buffer.trim();
if !content.is_empty() {
if let Ok(response) = serde_json::from_str::<AnyResponse>(content) {
if let Some(handlers) = response_handlers.lock().as_mut() {
if let Some(handler) = handlers.remove(&response.id) {
handler(Ok(content.to_string()));
}
}
} else if let Ok(notification) = serde_json::from_str::<AnyNotification>(content) {
let mut notification_handlers = notification_handlers.lock();
if let Some(handler) =
notification_handlers.get_mut(notification.method.as_str())
{
handler(notification.params.unwrap_or(Value::Null), cx.clone());
}
}
}
smol::future::yield_now().await;
}
}
/// Handles the stderr output from the context server.
/// Continuously reads and logs any error messages from the server.
async fn handle_stderr<Stderr>(stderr: Stderr) -> anyhow::Result<()>
where
Stderr: AsyncRead + Unpin + Send + 'static,
{
let mut stderr = BufReader::new(stderr);
let mut buffer = String::new();
loop {
buffer.clear();
if stderr.read_line(&mut buffer).await? == 0 {
return Ok(());
}
log::warn!("context server stderr: {}", buffer.trim());
smol::future::yield_now().await;
}
}
/// Handles the output to the context server's stdin.
/// This function continuously receives messages from the outbound channel,
/// writes them to the server's stdin, and manages the lifecycle of response handlers.
async fn handle_output<Stdin>(
stdin: Stdin,
outbound_rx: channel::Receiver<String>,
output_done_tx: barrier::Sender,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
) -> anyhow::Result<()>
where
Stdin: AsyncWrite + Unpin + Send + 'static,
{
let mut stdin = BufWriter::new(stdin);
let _clear_response_handlers = util::defer({
let response_handlers = response_handlers.clone();
move || {
response_handlers.lock().take();
}
});
while let Ok(message) = outbound_rx.recv().await {
log::trace!("outgoing message: {}", message);
stdin.write_all(message.as_bytes()).await?;
stdin.write_all(b"\n").await?;
stdin.flush().await?;
}
drop(output_done_tx);
Ok(())
}
/// Sends a JSON-RPC request to the context server and waits for a response.
/// This function handles serialization, deserialization, timeout, and error handling.
pub async fn request<T: DeserializeOwned>(
&self,
method: &str,
params: impl Serialize,
) -> Result<T> {
let id = self.next_id.fetch_add(1, SeqCst);
let request = serde_json::to_string(&Request {
jsonrpc: JSON_RPC_VERSION,
id: RequestId::Int(id),
method,
params,
})
.unwrap();
let (tx, rx) = oneshot::channel();
let handle_response = self
.response_handlers
.lock()
.as_mut()
.ok_or_else(|| anyhow!("server shut down"))
.map(|handlers| {
handlers.insert(
RequestId::Int(id),
Box::new(move |result| {
let _ = tx.send(result);
}),
);
});
let send = self
.outbound_tx
.try_send(request)
.context("failed to write to context server's stdin");
let executor = self.executor.clone();
let started = Instant::now();
handle_response?;
send?;
let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
select! {
response = rx.fuse() => {
let elapsed = started.elapsed();
log::trace!("took {elapsed:?} to receive response to {method:?} id {id}");
match response? {
Ok(response) => {
let parsed: AnyResponse = serde_json::from_str(&response)?;
if let Some(error) = parsed.error {
Err(anyhow!(error.message))
} else if let Some(result) = parsed.result {
Ok(serde_json::from_str(result.get())?)
} else {
Err(anyhow!("Invalid response: no result or error"))
}
}
Err(_) => anyhow::bail!("cancelled")
}
}
_ = timeout => {
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
anyhow::bail!("Context server request timeout");
}
}
}
/// Sends a notification to the context server without expecting a response.
/// This function serializes the notification and sends it through the outbound channel.
pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> {
let notification = serde_json::to_string(&Notification {
jsonrpc: JSON_RPC_VERSION,
method,
params,
})
.unwrap();
self.outbound_tx.try_send(notification)?;
Ok(())
}
pub fn on_notification<F>(&self, method: &'static str, f: F)
where
F: 'static + Send + FnMut(Value, AsyncAppContext),
{
self.notification_handlers
.lock()
.insert(method, Box::new(f));
}
pub fn name(&self) -> &str {
&self.name
}
pub fn server_id(&self) -> ContextServerId {
self.server_id.clone()
}
}
impl Drop for Client {
fn drop(&mut self) {
if let Some(mut server) = self.server.lock().take() {
let _ = server.kill();
}
}
}
impl fmt::Display for ContextServerId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl fmt::Debug for Client {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Context Server Client")
.field("id", &self.server_id.0)
.field("name", &self.name)
.finish_non_exhaustive()
}
}

View file

@ -0,0 +1,29 @@
pub mod client;
mod context_server_tool;
mod extension_context_server;
pub mod manager;
pub mod protocol;
mod registry;
pub mod types;
use command_palette_hooks::CommandPaletteFilter;
pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerConfig};
use gpui::{actions, AppContext};
pub use crate::context_server_tool::ContextServerTool;
pub use crate::registry::ContextServerFactoryRegistry;
actions!(context_servers, [Restart]);
/// The namespace for the context servers actions.
pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
pub fn init(cx: &mut AppContext) {
context_server_settings::init(cx);
ContextServerFactoryRegistry::default_global(cx);
extension_context_server::init(cx);
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
});
}

View file

@ -0,0 +1,99 @@
use std::sync::Arc;
use anyhow::{anyhow, bail};
use assistant_tool::Tool;
use gpui::{Model, Task};
use crate::manager::ContextServerManager;
use crate::types;
pub struct ContextServerTool {
server_manager: Model<ContextServerManager>,
server_id: Arc<str>,
tool: types::Tool,
}
impl ContextServerTool {
pub fn new(
server_manager: Model<ContextServerManager>,
server_id: impl Into<Arc<str>>,
tool: types::Tool,
) -> Self {
Self {
server_manager,
server_id: server_id.into(),
tool,
}
}
}
impl Tool for ContextServerTool {
fn name(&self) -> String {
self.tool.name.clone()
}
fn description(&self) -> String {
self.tool.description.clone().unwrap_or_default()
}
fn input_schema(&self) -> serde_json::Value {
match &self.tool.input_schema {
serde_json::Value::Null => {
serde_json::json!({ "type": "object", "properties": [] })
}
serde_json::Value::Object(map) if map.is_empty() => {
serde_json::json!({ "type": "object", "properties": [] })
}
_ => self.tool.input_schema.clone(),
}
}
fn run(
self: std::sync::Arc<Self>,
input: serde_json::Value,
_workspace: gpui::WeakView<workspace::Workspace>,
cx: &mut ui::WindowContext,
) -> gpui::Task<gpui::Result<String>> {
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
cx.foreground_executor().spawn({
let tool_name = self.tool.name.clone();
async move {
let Some(protocol) = server.client() else {
bail!("Context server not initialized");
};
let arguments = if let serde_json::Value::Object(map) = input {
Some(map.into_iter().collect())
} else {
None
};
log::trace!(
"Running tool: {} with arguments: {:?}",
tool_name,
arguments
);
let response = protocol.run_tool(tool_name, arguments).await?;
let mut result = String::new();
for content in response.content {
match content {
types::ToolResponseContent::Text { text } => {
result.push_str(&text);
}
types::ToolResponseContent::Image { .. } => {
log::warn!("Ignoring image content from tool response");
}
types::ToolResponseContent::Resource { .. } => {
log::warn!("Ignoring resource content from tool response");
}
}
}
Ok(result)
}
})
} else {
Task::ready(Err(anyhow!("Context server not found")))
}
}
}

View file

@ -0,0 +1,77 @@
use std::sync::Arc;
use extension::{Extension, ExtensionContextServerProxy, ExtensionHostProxy, ProjectDelegate};
use gpui::{AppContext, Model};
use crate::{ContextServerFactoryRegistry, ServerCommand};
struct ExtensionProject {
worktree_ids: Vec<u64>,
}
impl ProjectDelegate for ExtensionProject {
fn worktree_ids(&self) -> Vec<u64> {
self.worktree_ids.clone()
}
}
pub fn init(cx: &mut AppContext) {
let proxy = ExtensionHostProxy::default_global(cx);
proxy.register_context_server_proxy(ContextServerFactoryRegistryProxy {
context_server_factory_registry: ContextServerFactoryRegistry::global(cx),
});
}
struct ContextServerFactoryRegistryProxy {
context_server_factory_registry: Model<ContextServerFactoryRegistry>,
}
impl ExtensionContextServerProxy for ContextServerFactoryRegistryProxy {
fn register_context_server(
&self,
extension: Arc<dyn Extension>,
id: Arc<str>,
cx: &mut AppContext,
) {
self.context_server_factory_registry
.update(cx, |registry, _| {
registry.register_server_factory(
id.clone(),
Arc::new({
move |project, cx| {
log::info!(
"loading command for context server {id} from extension {}",
extension.manifest().id
);
let id = id.clone();
let extension = extension.clone();
cx.spawn(|mut cx| async move {
let extension_project =
project.update(&mut cx, |project, cx| {
Arc::new(ExtensionProject {
worktree_ids: project
.visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto())
.collect(),
})
})?;
let command = extension
.context_server_command(id.clone(), extension_project)
.await?;
log::info!("loaded command for context server {id}: {command:?}");
Ok(ServerCommand {
path: command.command,
args: command.args,
env: Some(command.env.into_iter().collect()),
})
})
}
}),
)
});
}
}

View file

@ -0,0 +1,288 @@
//! This module implements a context server management system for Zed.
//!
//! It provides functionality to:
//! - Define and load context server settings
//! - Manage individual context servers (start, stop, restart)
//! - Maintain a global manager for all context servers
//!
//! Key components:
//! - `ContextServerSettings`: Defines the structure for server configurations
//! - `ContextServer`: Represents an individual context server
//! - `ContextServerManager`: Manages multiple context servers
//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
//!
//! The module also includes initialization logic to set up the context server system
//! and react to changes in settings.
use std::path::Path;
use std::sync::Arc;
use anyhow::{bail, Result};
use collections::HashMap;
use command_palette_hooks::CommandPaletteFilter;
use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel};
use log;
use parking_lot::RwLock;
use project::Project;
use settings::{Settings, SettingsStore};
use util::ResultExt as _;
use crate::{ContextServerSettings, ServerConfig};
use crate::{
client::{self, Client},
types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE,
};
pub struct ContextServer {
pub id: Arc<str>,
pub config: Arc<ServerConfig>,
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
}
impl ContextServer {
pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
Self {
id,
config,
client: RwLock::new(None),
}
}
pub fn id(&self) -> Arc<str> {
self.id.clone()
}
pub fn config(&self) -> Arc<ServerConfig> {
self.config.clone()
}
pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
self.client.read().clone()
}
pub async fn start(self: Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
log::info!("starting context server {}", self.id);
let Some(command) = &self.config.command else {
bail!("no command specified for server {}", self.id);
};
let client = Client::new(
client::ContextServerId(self.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
},
cx.clone(),
)?;
let protocol = crate::protocol::ModelContextProtocol::new(client);
let client_info = types::Implementation {
name: "Zed".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
};
let initialized_protocol = protocol.initialize(client_info).await?;
log::debug!(
"context server {} initialized: {:?}",
self.id,
initialized_protocol.initialize,
);
*self.client.write() = Some(Arc::new(initialized_protocol));
Ok(())
}
pub fn stop(&self) -> Result<()> {
let mut client = self.client.write();
if let Some(protocol) = client.take() {
drop(protocol);
}
Ok(())
}
}
pub struct ContextServerManager {
servers: HashMap<Arc<str>, Arc<ContextServer>>,
project: Model<Project>,
registry: Model<ContextServerFactoryRegistry>,
update_servers_task: Option<Task<Result<()>>>,
needs_server_update: bool,
_subscriptions: Vec<Subscription>,
}
pub enum Event {
ServerStarted { server_id: Arc<str> },
ServerStopped { server_id: Arc<str> },
}
impl EventEmitter<Event> for ContextServerManager {}
impl ContextServerManager {
pub fn new(
registry: Model<ContextServerFactoryRegistry>,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Self {
let mut this = Self {
_subscriptions: vec![
cx.observe(&registry, |this, _registry, cx| {
this.available_context_servers_changed(cx);
}),
cx.observe_global::<SettingsStore>(|this, cx| {
this.available_context_servers_changed(cx);
}),
],
project,
registry,
needs_server_update: false,
servers: HashMap::default(),
update_servers_task: None,
};
this.available_context_servers_changed(cx);
this
}
fn available_context_servers_changed(&mut self, cx: &mut ModelContext<Self>) {
if self.update_servers_task.is_some() {
self.needs_server_update = true;
} else {
self.update_servers_task = Some(cx.spawn(|this, mut cx| async move {
this.update(&mut cx, |this, _| {
this.needs_server_update = false;
})?;
Self::maintain_servers(this.clone(), cx.clone()).await?;
this.update(&mut cx, |this, cx| {
let has_any_context_servers = !this.servers().is_empty();
if has_any_context_servers {
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
});
}
this.update_servers_task.take();
if this.needs_server_update {
this.available_context_servers_changed(cx);
}
})?;
Ok(())
}));
}
}
pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
self.servers
.get(id)
.filter(|server| server.client().is_some())
.cloned()
}
pub fn restart_server(
&mut self,
id: &Arc<str>,
cx: &mut ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
let id = id.clone();
cx.spawn(|this, mut cx| async move {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
server.stop()?;
let config = server.config();
let new_server = Arc::new(ContextServer::new(id.clone(), config));
new_server.clone().start(&cx).await?;
this.update(&mut cx, |this, cx| {
this.servers.insert(id.clone(), new_server);
cx.emit(Event::ServerStopped {
server_id: id.clone(),
});
cx.emit(Event::ServerStarted {
server_id: id.clone(),
});
})?;
}
Ok(())
})
}
pub fn servers(&self) -> Vec<Arc<ContextServer>> {
self.servers
.values()
.filter(|server| server.client().is_some())
.cloned()
.collect()
}
async fn maintain_servers(this: WeakModel<Self>, mut cx: AsyncAppContext) -> Result<()> {
let mut desired_servers = HashMap::default();
let (registry, project) = this.update(&mut cx, |this, cx| {
let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
settings::SettingsLocation {
worktree_id: worktree.read(cx).id(),
path: Path::new(""),
}
});
let settings = ContextServerSettings::get(location, cx);
desired_servers = settings.context_servers.clone();
(this.registry.clone(), this.project.clone())
})?;
for (id, factory) in
registry.read_with(&cx, |registry, _| registry.context_server_factories())?
{
let config = desired_servers.entry(id).or_default();
if config.command.is_none() {
if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
config.command = Some(extension_command);
}
}
}
let mut servers_to_start = HashMap::default();
let mut servers_to_stop = HashMap::default();
this.update(&mut cx, |this, _cx| {
this.servers.retain(|id, server| {
if desired_servers.contains_key(id) {
true
} else {
servers_to_stop.insert(id.clone(), server.clone());
false
}
});
for (id, config) in desired_servers {
let existing_config = this.servers.get(&id).map(|server| server.config());
if existing_config.as_deref() != Some(&config) {
let config = Arc::new(config);
let server = Arc::new(ContextServer::new(id.clone(), config));
servers_to_start.insert(id.clone(), server.clone());
let old_server = this.servers.insert(id.clone(), server);
if let Some(old_server) = old_server {
servers_to_stop.insert(id, old_server);
}
}
}
})?;
for (id, server) in servers_to_stop {
server.stop().log_err();
this.update(&mut cx, |_, cx| {
cx.emit(Event::ServerStopped { server_id: id })
})?;
}
for (id, server) in servers_to_start {
if server.start(&cx).await.log_err().is_some() {
this.update(&mut cx, |_, cx| {
cx.emit(Event::ServerStarted { server_id: id })
})?;
}
}
Ok(())
}
}

View file

@ -0,0 +1,234 @@
//! This module implements parts of the Model Context Protocol.
//!
//! It handles the lifecycle messages, and provides a general interface to
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read/write messages and the types from types.rs for serialization/deserialization
//! of messages.
use anyhow::Result;
use collections::HashMap;
use crate::client::Client;
use crate::types;
pub struct ModelContextProtocol {
inner: Client,
}
impl ModelContextProtocol {
pub fn new(inner: Client) -> Self {
Self { inner }
}
fn supported_protocols() -> Vec<types::ProtocolVersion> {
vec![types::ProtocolVersion(
types::LATEST_PROTOCOL_VERSION.to_string(),
)]
}
pub async fn initialize(
self,
client_info: types::Implementation,
) -> Result<InitializedContextServerProtocol> {
let params = types::InitializeParams {
protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
capabilities: types::ClientCapabilities {
experimental: None,
sampling: None,
roots: None,
},
meta: None,
client_info,
};
let response: types::InitializeResponse = self
.inner
.request(types::RequestType::Initialize.as_str(), params)
.await?;
if !Self::supported_protocols().contains(&response.protocol_version) {
return Err(anyhow::anyhow!(
"Unsupported protocol version: {:?}",
response.protocol_version
));
}
log::trace!("mcp server info {:?}", response.server_info);
self.inner.notify(
types::NotificationType::Initialized.as_str(),
serde_json::json!({}),
)?;
let initialized_protocol = InitializedContextServerProtocol {
inner: self.inner,
initialize: response,
};
Ok(initialized_protocol)
}
}
pub struct InitializedContextServerProtocol {
inner: Client,
pub initialize: types::InitializeResponse,
}
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum ServerCapability {
Experimental,
Logging,
Prompts,
Resources,
Tools,
}
impl InitializedContextServerProtocol {
/// Check if the server supports a specific capability
pub fn capable(&self, capability: ServerCapability) -> bool {
match capability {
ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
}
}
fn check_capability(&self, capability: ServerCapability) -> Result<()> {
if self.capable(capability) {
Ok(())
} else {
Err(anyhow::anyhow!(
"Server does not support {:?} capability",
capability
))
}
}
/// List the MCP prompts.
pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
self.check_capability(ServerCapability::Prompts)?;
let response: types::PromptsListResponse = self
.inner
.request(
types::RequestType::PromptsList.as_str(),
serde_json::json!({}),
)
.await?;
Ok(response.prompts)
}
/// List the MCP resources.
pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
self.check_capability(ServerCapability::Resources)?;
let response: types::ResourcesListResponse = self
.inner
.request(
types::RequestType::ResourcesList.as_str(),
serde_json::json!({}),
)
.await?;
Ok(response)
}
/// Executes a prompt with the given arguments and returns the result.
pub async fn run_prompt<P: AsRef<str>>(
&self,
prompt: P,
arguments: HashMap<String, String>,
) -> Result<types::PromptsGetResponse> {
self.check_capability(ServerCapability::Prompts)?;
let params = types::PromptsGetParams {
name: prompt.as_ref().to_string(),
arguments: Some(arguments),
meta: None,
};
let response: types::PromptsGetResponse = self
.inner
.request(types::RequestType::PromptsGet.as_str(), params)
.await?;
Ok(response)
}
pub async fn completion<P: Into<String>>(
&self,
reference: types::CompletionReference,
argument: P,
value: P,
) -> Result<types::Completion> {
let params = types::CompletionCompleteParams {
r#ref: reference,
argument: types::CompletionArgument {
name: argument.into(),
value: value.into(),
},
meta: None,
};
let result: types::CompletionCompleteResponse = self
.inner
.request(types::RequestType::CompletionComplete.as_str(), params)
.await?;
let completion = types::Completion {
values: result.completion.values,
total: types::CompletionTotal::from_options(
result.completion.has_more,
result.completion.total,
),
};
Ok(completion)
}
/// List MCP tools.
pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
self.check_capability(ServerCapability::Tools)?;
let response = self
.inner
.request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
.await?;
Ok(response)
}
/// Executes a tool with the given arguments
pub async fn run_tool<P: AsRef<str>>(
&self,
tool: P,
arguments: Option<HashMap<String, serde_json::Value>>,
) -> Result<types::CallToolResponse> {
self.check_capability(ServerCapability::Tools)?;
let params = types::CallToolParams {
name: tool.as_ref().to_string(),
arguments,
meta: None,
};
let response: types::CallToolResponse = self
.inner
.request(types::RequestType::CallTool.as_str(), params)
.await?;
Ok(response)
}
}
impl InitializedContextServerProtocol {
pub async fn request<R: serde::de::DeserializeOwned>(
&self,
method: &str,
params: impl serde::Serialize,
) -> Result<R> {
self.inner.request(method, params).await
}
}

View file

@ -0,0 +1,62 @@
use std::sync::Arc;
use anyhow::Result;
use collections::HashMap;
use gpui::{AppContext, AsyncAppContext, Context, Global, Model, ReadGlobal, Task};
use project::Project;
use crate::ServerCommand;
pub type ContextServerFactory = Arc<
dyn Fn(Model<Project>, &AsyncAppContext) -> Task<Result<ServerCommand>> + Send + Sync + 'static,
>;
struct GlobalContextServerFactoryRegistry(Model<ContextServerFactoryRegistry>);
impl Global for GlobalContextServerFactoryRegistry {}
#[derive(Default)]
pub struct ContextServerFactoryRegistry {
context_servers: HashMap<Arc<str>, ContextServerFactory>,
}
impl ContextServerFactoryRegistry {
/// Returns the global [`ContextServerFactoryRegistry`].
pub fn global(cx: &AppContext) -> Model<Self> {
GlobalContextServerFactoryRegistry::global(cx).0.clone()
}
/// Returns the global [`ContextServerFactoryRegistry`].
///
/// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist.
pub fn default_global(cx: &mut AppContext) -> Model<Self> {
if !cx.has_global::<GlobalContextServerFactoryRegistry>() {
let registry = cx.new_model(|_| Self::new());
cx.set_global(GlobalContextServerFactoryRegistry(registry));
}
cx.global::<GlobalContextServerFactoryRegistry>().0.clone()
}
pub fn new() -> Self {
Self {
context_servers: HashMap::default(),
}
}
pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> {
self.context_servers
.iter()
.map(|(id, factory)| (id.clone(), factory.clone()))
.collect()
}
/// Registers the provided [`ContextServerFactory`].
pub fn register_server_factory(&mut self, id: Arc<str>, factory: ContextServerFactory) {
self.context_servers.insert(id, factory);
}
/// Unregisters the [`ContextServerFactory`] for the server with the given ID.
pub fn unregister_server_factory_by_id(&mut self, server_id: &str) {
self.context_servers.remove(server_id);
}
}

View file

@ -0,0 +1,549 @@
use collections::HashMap;
use serde::{Deserialize, Serialize};
use url::Url;
pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";
pub enum RequestType {
Initialize,
CallTool,
ResourcesUnsubscribe,
ResourcesSubscribe,
ResourcesRead,
ResourcesList,
LoggingSetLevel,
PromptsGet,
PromptsList,
CompletionComplete,
Ping,
ListTools,
ListResourceTemplates,
ListRoots,
}
impl RequestType {
pub fn as_str(&self) -> &'static str {
match self {
RequestType::Initialize => "initialize",
RequestType::CallTool => "tools/call",
RequestType::ResourcesUnsubscribe => "resources/unsubscribe",
RequestType::ResourcesSubscribe => "resources/subscribe",
RequestType::ResourcesRead => "resources/read",
RequestType::ResourcesList => "resources/list",
RequestType::LoggingSetLevel => "logging/setLevel",
RequestType::PromptsGet => "prompts/get",
RequestType::PromptsList => "prompts/list",
RequestType::CompletionComplete => "completion/complete",
RequestType::Ping => "ping",
RequestType::ListTools => "tools/list",
RequestType::ListResourceTemplates => "resources/templates/list",
RequestType::ListRoots => "roots/list",
}
}
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ProtocolVersion(pub String);
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeParams {
pub protocol_version: ProtocolVersion,
pub capabilities: ClientCapabilities,
pub client_info: Implementation,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CallToolParams {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<HashMap<String, serde_json::Value>>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesUnsubscribeParams {
pub uri: Url,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesSubscribeParams {
pub uri: Url,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesReadParams {
pub uri: Url,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct LoggingSetLevelParams {
pub level: LoggingLevel,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsGetParams {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<HashMap<String, String>>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionCompleteParams {
pub r#ref: CompletionReference,
pub argument: CompletionArgument,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum CompletionReference {
Prompt(PromptReference),
Resource(ResourceReference),
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptReference {
pub r#type: PromptReferenceType,
pub name: String,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceReference {
pub r#type: PromptReferenceType,
pub uri: Url,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum PromptReferenceType {
#[serde(rename = "ref/prompt")]
Prompt,
#[serde(rename = "ref/resource")]
Resource,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionArgument {
pub name: String,
pub value: String,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResponse {
pub protocol_version: ProtocolVersion,
pub capabilities: ServerCapabilities,
pub server_info: Implementation,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesReadResponse {
pub contents: Vec<ResourceContents>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesListResponse {
pub resources: Vec<Resource>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingMessage {
pub role: Role,
pub content: MessageContent,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptMessage {
pub role: Role,
pub content: MessageContent,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum MessageContent {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image")]
Image { data: String, mime_type: String },
#[serde(rename = "resource")]
Resource { resource: ResourceContents },
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsGetResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub messages: Vec<PromptMessage>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsListResponse {
pub prompts: Vec<Prompt>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionCompleteResponse {
pub completion: CompletionResult,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionResult {
pub values: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub has_more: Option<bool>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Prompt {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<Vec<PromptArgument>>,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptArgument {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub experimental: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub roots: Option<RootsCapabilities>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub experimental: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logging: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompts: Option<PromptsCapabilities>,
#[serde(skip_serializing_if = "Option::is_none")]
pub resources: Option<ResourcesCapabilities>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<ToolsCapabilities>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub subscribe: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolsCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RootsCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub input_schema: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Implementation {
pub name: String,
pub version: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Resource {
pub uri: Url,
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceContents {
pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextResourceContents {
pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
pub text: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BlobResourceContents {
pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
pub blob: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceTemplate {
pub uri_template: String,
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LoggingLevel {
Debug,
Info,
Notice,
Warning,
Error,
Critical,
Alert,
Emergency,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelPreferences {
#[serde(skip_serializing_if = "Option::is_none")]
pub hints: Option<Vec<ModelHint>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cost_priority: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub speed_priority: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub intelligence_priority: Option<f64>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelHint {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum NotificationType {
Initialized,
Progress,
Message,
ResourcesUpdated,
ResourcesListChanged,
ToolsListChanged,
PromptsListChanged,
RootsListChanged,
}
impl NotificationType {
pub fn as_str(&self) -> &'static str {
match self {
NotificationType::Initialized => "notifications/initialized",
NotificationType::Progress => "notifications/progress",
NotificationType::Message => "notifications/message",
NotificationType::ResourcesUpdated => "notifications/resources/updated",
NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
NotificationType::ToolsListChanged => "notifications/tools/list_changed",
NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
NotificationType::RootsListChanged => "notifications/roots/list_changed",
}
}
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum ClientNotification {
Initialized,
Progress(ProgressParams),
RootsListChanged,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ProgressToken {
String(String),
Number(f64),
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ProgressParams {
pub progress_token: ProgressToken,
pub progress: f64,
#[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<f64>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
pub enum CompletionTotal {
Exact(u32),
HasMore,
Unknown,
}
impl CompletionTotal {
pub fn from_options(has_more: Option<bool>, total: Option<u32>) -> Self {
match (has_more, total) {
(_, Some(count)) => CompletionTotal::Exact(count),
(Some(true), _) => CompletionTotal::HasMore,
_ => CompletionTotal::Unknown,
}
}
}
pub struct Completion {
pub values: Vec<String>,
pub total: CompletionTotal,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CallToolResponse {
pub content: Vec<ToolResponseContent>,
#[serde(skip_serializing_if = "Option::is_none")]
pub is_error: Option<bool>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ToolResponseContent {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image")]
Image { data: String, mime_type: String },
#[serde(rename = "resource")]
Resource { resource: ResourceContents },
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListToolsResponse {
pub tools: Vec<Tool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListRootsResponse {
pub roots: Vec<Root>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Root {
pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}