context_server: Abstract server transport (#24528)

This PR abstracts the communication layer for context servers, laying
the groundwork for supporting multiple transport mechanisms and taking
one step towards enabling remote servers.

Key changes centre around creating a new `Transport` trait with methods
for sending and receiving messages. I've implemented this trait for the
existing stdio-based communication, which is now encapsulated in a
`StdioTransport` struct. The `Client` struct has been refactored to use
this new `Transport` trait instead of directly managing stdin and
stdout.

The next steps will involve implementing an SSE + HTTP transport and
defining alternative context server settings for remote servers.

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
Federico Dionisi 2025-02-26 18:19:19 +01:00 committed by GitHub
parent 6d17546b1a
commit f11357db7c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 210 additions and 103 deletions

View file

@ -14,6 +14,7 @@ path = "src/context_server.rs"
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
async-trait.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
context_server_settings.workspace = true

View file

@ -1,16 +1,12 @@
use anyhow::{anyhow, Context as _, Result};
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt};
use futures::{channel::oneshot, select, FutureExt, StreamExt};
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
use parking_lot::Mutex;
use postage::barrier;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
use smol::{
channel,
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
process::Child,
};
use smol::channel;
use std::{
fmt,
path::PathBuf,
@ -22,6 +18,8 @@ use std::{
};
use util::TryFutureExt;
use crate::transport::{StdioTransport, Transport};
const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
@ -55,7 +53,8 @@ pub struct Client {
#[allow(dead_code)]
output_done_rx: Mutex<Option<barrier::Receiver>>,
executor: BackgroundExecutor,
server: Arc<Mutex<Option<Child>>>,
#[allow(dead_code)]
transport: Arc<dyn Transport>,
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
@ -152,25 +151,13 @@ impl Client {
&binary.args
);
let mut command = util::command::new_smol_command(&binary.executable);
command
.args(&binary.args)
.envs(binary.env.unwrap_or_default())
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
let server_name = binary
.executable
.file_name()
.map(|name| name.to_string_lossy().to_string())
.unwrap_or_else(String::new);
let mut server = command.spawn().with_context(|| {
format!(
"failed to spawn command. (path={:?}, args={:?})",
binary.executable, &binary.args
)
})?;
let stdin = server.stdin.take().unwrap();
let stdout = server.stdout.take().unwrap();
let stderr = server.stderr.take().unwrap();
let transport = Arc::new(StdioTransport::new(binary, &cx)?);
let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
let (output_done_tx, output_done_rx) = barrier::channel();
@ -183,18 +170,22 @@ impl Client {
let stdout_input_task = cx.spawn({
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
let transport = transport.clone();
move |cx| {
Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err()
Self::handle_input(transport, notification_handlers, response_handlers, cx)
.log_err()
}
});
let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err());
let stderr_input_task = cx.spawn(|_| Self::handle_stderr(transport.clone()).log_err());
let input_task = cx.spawn(|_| async move {
let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
stdout.or(stderr)
});
let output_task = cx.background_spawn({
let transport = transport.clone();
Self::handle_output(
stdin,
transport,
outbound_rx,
output_done_tx,
response_handlers.clone(),
@ -202,24 +193,18 @@ impl Client {
.log_err()
});
let mut context_server = Self {
Ok(Self {
server_id,
notification_handlers,
response_handlers,
name: "".into(),
name: server_name.into(),
next_id: Default::default(),
outbound_tx,
executor: cx.background_executor().clone(),
io_tasks: Mutex::new(Some((input_task, output_task))),
output_done_rx: Mutex::new(Some(output_done_rx)),
server: Arc::new(Mutex::new(Some(server))),
};
if let Some(name) = binary.executable.file_name() {
context_server.name = name.to_string_lossy().into();
}
Ok(context_server)
transport,
})
}
/// Handles input from the server's stdout.
@ -228,79 +213,53 @@ impl Client {
/// parses them as JSON-RPC responses or notifications, and dispatches them
/// to the appropriate handlers. It processes both responses (which are matched
/// to pending requests) and notifications (which trigger registered handlers).
async fn handle_input<Stdout>(
stdout: Stdout,
async fn handle_input(
transport: Arc<dyn Transport>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
cx: AsyncApp,
) -> anyhow::Result<()>
where
Stdout: AsyncRead + Unpin + Send + 'static,
{
let mut stdout = BufReader::new(stdout);
let mut buffer = String::new();
) -> anyhow::Result<()> {
let mut receiver = transport.receive();
loop {
buffer.clear();
if stdout.read_line(&mut buffer).await? == 0 {
return Ok(());
}
let content = buffer.trim();
if !content.is_empty() {
if let Ok(response) = serde_json::from_str::<AnyResponse>(content) {
if let Some(handlers) = response_handlers.lock().as_mut() {
if let Some(handler) = handlers.remove(&response.id) {
handler(Ok(content.to_string()));
}
}
} else if let Ok(notification) = serde_json::from_str::<AnyNotification>(content) {
let mut notification_handlers = notification_handlers.lock();
if let Some(handler) =
notification_handlers.get_mut(notification.method.as_str())
{
handler(notification.params.unwrap_or(Value::Null), cx.clone());
while let Some(message) = receiver.next().await {
if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
if let Some(handlers) = response_handlers.lock().as_mut() {
if let Some(handler) = handlers.remove(&response.id) {
handler(Ok(message.to_string()));
}
}
} else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&message) {
let mut notification_handlers = notification_handlers.lock();
if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
handler(notification.params.unwrap_or(Value::Null), cx.clone());
}
}
smol::future::yield_now().await;
}
smol::future::yield_now().await;
Ok(())
}
/// Handles the stderr output from the context server.
/// Continuously reads and logs any error messages from the server.
async fn handle_stderr<Stderr>(stderr: Stderr) -> anyhow::Result<()>
where
Stderr: AsyncRead + Unpin + Send + 'static,
{
let mut stderr = BufReader::new(stderr);
let mut buffer = String::new();
loop {
buffer.clear();
if stderr.read_line(&mut buffer).await? == 0 {
return Ok(());
}
log::warn!("context server stderr: {}", buffer.trim());
smol::future::yield_now().await;
async fn handle_stderr(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
while let Some(err) = transport.receive_err().next().await {
log::warn!("context server stderr: {}", err.trim());
}
Ok(())
}
/// Handles the output to the context server's stdin.
/// This function continuously receives messages from the outbound channel,
/// writes them to the server's stdin, and manages the lifecycle of response handlers.
async fn handle_output<Stdin>(
stdin: Stdin,
async fn handle_output(
transport: Arc<dyn Transport>,
outbound_rx: channel::Receiver<String>,
output_done_tx: barrier::Sender,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
) -> anyhow::Result<()>
where
Stdin: AsyncWrite + Unpin + Send + 'static,
{
let mut stdin = BufWriter::new(stdin);
) -> anyhow::Result<()> {
let _clear_response_handlers = util::defer({
let response_handlers = response_handlers.clone();
move || {
@ -309,10 +268,7 @@ impl Client {
});
while let Ok(message) = outbound_rx.recv().await {
log::trace!("outgoing message: {}", message);
stdin.write_all(message.as_bytes()).await?;
stdin.write_all(b"\n").await?;
stdin.flush().await?;
transport.send(message).await?;
}
drop(output_done_tx);
Ok(())
@ -416,14 +372,6 @@ impl Client {
}
}
impl Drop for Client {
fn drop(&mut self) {
if let Some(mut server) = self.server.lock().take() {
let _ = server.kill();
}
}
}
impl fmt::Display for ContextServerId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)

View file

@ -4,6 +4,7 @@ mod extension_context_server;
pub mod manager;
pub mod protocol;
mod registry;
mod transport;
pub mod types;
use command_palette_hooks::CommandPaletteFilter;

View file

@ -0,0 +1,16 @@
mod stdio_transport;
use std::pin::Pin;
use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;
pub use stdio_transport::*;
#[async_trait]
pub trait Transport: Send + Sync {
async fn send(&self, message: String) -> Result<()>;
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>>;
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>>;
}

View file

@ -0,0 +1,140 @@
use std::pin::Pin;
use anyhow::{Context as _, Result};
use async_trait::async_trait;
use futures::io::{BufReader, BufWriter};
use futures::{
AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
};
use gpui::AsyncApp;
use smol::channel;
use smol::process::Child;
use util::TryFutureExt as _;
use crate::client::ModelContextServerBinary;
use crate::transport::Transport;
pub struct StdioTransport {
stdout_sender: channel::Sender<String>,
stdin_receiver: channel::Receiver<String>,
stderr_receiver: channel::Receiver<String>,
server: Child,
}
impl StdioTransport {
pub fn new(binary: ModelContextServerBinary, cx: &AsyncApp) -> Result<Self> {
let mut command = util::command::new_smol_command(&binary.executable);
command
.args(&binary.args)
.envs(binary.env.unwrap_or_default())
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
let mut server = command.spawn().with_context(|| {
format!(
"failed to spawn command. (path={:?}, args={:?})",
binary.executable, &binary.args
)
})?;
let stdin = server.stdin.take().unwrap();
let stdout = server.stdout.take().unwrap();
let stderr = server.stderr.take().unwrap();
let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
cx.spawn(|_| Self::handle_output(stdin, stdout_receiver).log_err())
.detach();
cx.spawn(|_| async move { Self::handle_input(stdout, stdin_sender).await })
.detach();
cx.spawn(|_| async move { Self::handle_err(stderr, stderr_sender).await })
.detach();
Ok(Self {
stdout_sender,
stdin_receiver,
stderr_receiver,
server,
})
}
async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
where
Stdout: AsyncRead + Unpin + Send + 'static,
{
let mut stdin = BufReader::new(stdin);
let mut line = String::new();
while let Ok(n) = stdin.read_line(&mut line).await {
if n == 0 {
break;
}
if inbound_rx.send(line.clone()).await.is_err() {
break;
}
line.clear();
}
}
async fn handle_output<Stdin>(
stdin: Stdin,
outbound_rx: channel::Receiver<String>,
) -> Result<()>
where
Stdin: AsyncWrite + Unpin + Send + 'static,
{
let mut stdin = BufWriter::new(stdin);
let mut pinned_rx = Box::pin(outbound_rx);
while let Some(message) = pinned_rx.next().await {
log::trace!("outgoing message: {}", message);
stdin.write_all(message.as_bytes()).await?;
stdin.write_all(b"\n").await?;
stdin.flush().await?;
}
Ok(())
}
async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
where
Stderr: AsyncRead + Unpin + Send + 'static,
{
let mut stderr = BufReader::new(stderr);
let mut line = String::new();
while let Ok(n) = stderr.read_line(&mut line).await {
if n == 0 {
break;
}
if stderr_tx.send(line.clone()).await.is_err() {
break;
}
line.clear();
}
}
}
#[async_trait]
impl Transport for StdioTransport {
async fn send(&self, message: String) -> Result<()> {
Ok(self.stdout_sender.send(message).await?)
}
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(self.stdin_receiver.clone())
}
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(self.stderr_receiver.clone())
}
}
impl Drop for StdioTransport {
fn drop(&mut self) {
let _ = self.server.kill();
}
}