context_servers: Add initial implementation (#16103)

This commit proposes the addition of "context serveres" and the
underlying protocol (model context protocol). Context servers allow
simple definition of slash commands in another language and running
local on the user machines. This aims to quickly prototype new commands,
and provide a way to add personal (or company wide) customizations to
the assistant panel, without having to maintain an extension. We can
use this to reuse our existing codebase, with authenticators, etc and
easily have it provide context into the assistant panel.

As such it occupies a different design space as extensions, which I
think are
more aimed towards long-term, well maintained pieces of code that can be
easily distributed.

It's implemented as a central crate for easy reusability across the
codebase
and to easily hook into the assistant panel at all points.

Design wise there are a few pieces:
1. client.rs: A simple JSON-RPC client talking over stdio to a spawned
server. This is
very close to how LSP work and likely there could be a combined client
down the line.
2. types.rs: Serialization and deserialization client for the underlying
model context protocol.
3. protocol.rs: Handling the session between client and server.
4. manager.rs: Manages settings and adding and deleting servers from a
central pool.

A server can be defined in the settings.json as:

```
"context_servers": [
   {"id": "test", "executable": "python", "args": ["-m", "context_server"]
]
```

## Quick Example
A quick example of how a theoretical backend site can look like. With
roughly 100 lines
of code (nicely generated by Claude) and a bit of decorator magic (200
lines in total), one
can come up with a framework that makes it as easy as:

```python
@context_server.slash_command(name="rot13", description="Perform a rot13 transformation")
@context_server.argument(name="input", type=str, help="String to rot13")
async def rot13(input: str) -> str:
    return ''.join(chr((ord(c) - 97 + 13) % 26 + 97) if c.isalpha() else c for c in echo.lower())
```

to define a new slash_command.

## Todo:
 - Allow context servers to be defined in workspace settings.
 - Allow passing env variables to context_servers


Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
David Soria Parra 2024-08-15 15:49:30 +01:00 committed by GitHub
parent d54818fd9e
commit 02ea6ac845
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 1433 additions and 7 deletions

View file

@ -0,0 +1,278 @@
//! This module implements a context server management system for Zed.
//!
//! It provides functionality to:
//! - Define and load context server settings
//! - Manage individual context servers (start, stop, restart)
//! - Maintain a global manager for all context servers
//!
//! Key components:
//! - `ContextServerSettings`: Defines the structure for server configurations
//! - `ContextServer`: Represents an individual context server
//! - `ContextServerManager`: Manages multiple context servers
//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
//!
//! The module also includes initialization logic to set up the context server system
//! and react to changes in settings.
use collections::{HashMap, HashSet};
use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Global, Model, ModelContext, Task};
use log;
use parking_lot::RwLock;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources, SettingsStore};
use std::path::Path;
use std::sync::Arc;
use crate::{
client::{self, Client},
types,
};
#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
pub struct ContextServerSettings {
pub servers: Vec<ServerConfig>,
}
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
pub struct ServerConfig {
pub id: String,
pub executable: String,
pub args: Vec<String>,
}
impl Settings for ContextServerSettings {
const KEY: Option<&'static str> = Some("experimental.context_servers");
type FileContent = Self;
fn load(
sources: SettingsSources<Self::FileContent>,
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
sources.json_merge()
}
}
pub struct ContextServer {
pub id: String,
pub config: ServerConfig,
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
}
impl ContextServer {
fn new(config: ServerConfig) -> Self {
Self {
id: config.id.clone(),
config,
client: RwLock::new(None),
}
}
async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> {
log::info!("starting context server {}", self.config.id);
let client = Client::new(
client::ContextServerId(self.config.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&self.config.executable).to_path_buf(),
args: self.config.args.clone(),
env: None,
},
cx.clone(),
)?;
let protocol = crate::protocol::ModelContextProtocol::new(client);
let client_info = types::EntityInfo {
name: "Zed".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
};
let initialized_protocol = protocol.initialize(client_info).await?;
log::debug!(
"context server {} initialized: {:?}",
self.config.id,
initialized_protocol.initialize,
);
*self.client.write() = Some(Arc::new(initialized_protocol));
Ok(())
}
async fn stop(&self) -> anyhow::Result<()> {
let mut client = self.client.write();
if let Some(protocol) = client.take() {
drop(protocol);
}
Ok(())
}
}
/// A Context server manager manages the starting and stopping
/// of all servers. To obtain a server to interact with, a crate
/// must go through the `GlobalContextServerManager` which holds
/// a model to the ContextServerManager.
pub struct ContextServerManager {
servers: HashMap<String, Arc<ContextServer>>,
pending_servers: HashSet<String>,
}
pub enum Event {
ServerStarted { server_id: String },
ServerStopped { server_id: String },
}
impl Global for ContextServerManager {}
impl EventEmitter<Event> for ContextServerManager {}
impl ContextServerManager {
pub fn new() -> Self {
Self {
servers: HashMap::default(),
pending_servers: HashSet::default(),
}
}
pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<GlobalContextServerManager>().0.clone()
}
pub fn add_server(
&mut self,
config: ServerConfig,
cx: &mut ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
let server_id = config.id.clone();
let server_id2 = config.id.clone();
if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
return Task::ready(Ok(()));
}
let task = cx.spawn(|this, mut cx| async move {
let server = Arc::new(ContextServer::new(config));
server.start(&cx).await?;
this.update(&mut cx, |this, cx| {
this.servers.insert(server_id.clone(), server);
this.pending_servers.remove(&server_id);
cx.emit(Event::ServerStarted {
server_id: server_id.clone(),
});
})?;
Ok(())
});
self.pending_servers.insert(server_id2);
task
}
pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
self.servers.get(id).cloned()
}
pub fn remove_server(
&mut self,
id: &str,
cx: &mut ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
let id = id.to_string();
cx.spawn(|this, mut cx| async move {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
server.stop().await?;
}
this.update(&mut cx, |this, cx| {
this.pending_servers.remove(&id);
cx.emit(Event::ServerStopped {
server_id: id.clone(),
})
})?;
Ok(())
})
}
pub fn restart_server(
&mut self,
id: &str,
cx: &mut ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
let id = id.to_string();
cx.spawn(|this, mut cx| async move {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
server.stop().await?;
let config = server.config.clone();
let new_server = Arc::new(ContextServer::new(config));
new_server.start(&cx).await?;
this.update(&mut cx, |this, cx| {
this.servers.insert(id.clone(), new_server);
cx.emit(Event::ServerStopped {
server_id: id.clone(),
});
cx.emit(Event::ServerStarted {
server_id: id.clone(),
});
})?;
}
Ok(())
})
}
pub fn servers(&self) -> Vec<Arc<ContextServer>> {
self.servers.values().cloned().collect()
}
pub fn model(cx: &mut AppContext) -> Model<Self> {
cx.new_model(|_cx| ContextServerManager::new())
}
}
pub struct GlobalContextServerManager(Model<ContextServerManager>);
impl Global for GlobalContextServerManager {}
impl GlobalContextServerManager {
fn register(cx: &mut AppContext) {
let model = ContextServerManager::model(cx);
cx.set_global(Self(model));
}
}
pub fn init(cx: &mut AppContext) {
ContextServerSettings::register(cx);
GlobalContextServerManager::register(cx);
cx.observe_global::<SettingsStore>(|cx| {
let manager = ContextServerManager::global(cx);
cx.update_model(&manager, |manager, cx| {
let settings = ContextServerSettings::get_global(cx);
let current_servers: HashMap<String, ServerConfig> = manager
.servers()
.into_iter()
.map(|server| (server.id.clone(), server.config.clone()))
.collect();
let new_servers = settings
.servers
.iter()
.map(|config| (config.id.clone(), config.clone()))
.collect::<HashMap<_, _>>();
let servers_to_add = new_servers
.values()
.filter(|config| !current_servers.contains_key(&config.id))
.cloned()
.collect::<Vec<_>>();
let servers_to_remove = current_servers
.keys()
.filter(|id| !new_servers.contains_key(*id))
.cloned()
.collect::<Vec<_>>();
log::trace!("servers_to_add={:?}", servers_to_add);
for config in servers_to_add {
manager.add_server(config, cx).detach();
}
for id in servers_to_remove {
manager.remove_server(&id, cx).detach();
}
})
})
.detach();
}