From f11357db7ca457518a1ad71fa7fc9479cb98fa8e Mon Sep 17 00:00:00 2001 From: Federico Dionisi Date: Wed, 26 Feb 2025 18:19:19 +0100 Subject: [PATCH] context_server: Abstract server transport (#24528) This PR abstracts the communication layer for context servers, laying the groundwork for supporting multiple transport mechanisms and taking one step towards enabling remote servers. Key changes centre around creating a new `Transport` trait with methods for sending and receiving messages. I've implemented this trait for the existing stdio-based communication, which is now encapsulated in a `StdioTransport` struct. The `Client` struct has been refactored to use this new `Transport` trait instead of directly managing stdin and stdout. The next steps will involve implementing an SSE + HTTP transport and defining alternative context server settings for remote servers. Release Notes: - N/A --------- Co-authored-by: Marshall Bowers --- Cargo.lock | 1 + crates/context_server/Cargo.toml | 1 + crates/context_server/src/client.rs | 154 ++++++------------ crates/context_server/src/context_server.rs | 1 + crates/context_server/src/transport.rs | 16 ++ .../src/transport/stdio_transport.rs | 140 ++++++++++++++++ 6 files changed, 210 insertions(+), 103 deletions(-) create mode 100644 crates/context_server/src/transport.rs create mode 100644 crates/context_server/src/transport/stdio_transport.rs diff --git a/Cargo.lock b/Cargo.lock index a2bab321d9..6a0de841af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3101,6 +3101,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assistant_tool", + "async-trait", "collections", "command_palette_hooks", "context_server_settings", diff --git a/crates/context_server/Cargo.toml b/crates/context_server/Cargo.toml index dab1a76a96..33630c2b29 100644 --- a/crates/context_server/Cargo.toml +++ b/crates/context_server/Cargo.toml @@ -14,6 +14,7 @@ path = "src/context_server.rs" [dependencies] anyhow.workspace = true assistant_tool.workspace = true +async-trait.workspace = true collections.workspace = true command_palette_hooks.workspace = true context_server_settings.workspace = true diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 4e0b799fb8..83517fc238 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -1,16 +1,12 @@ -use anyhow::{anyhow, Context as _, Result}; +use anyhow::{anyhow, Context, Result}; use collections::HashMap; -use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt}; +use futures::{channel::oneshot, select, FutureExt, StreamExt}; use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task}; use parking_lot::Mutex; use postage::barrier; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::{value::RawValue, Value}; -use smol::{ - channel, - io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, - process::Child, -}; +use smol::channel; use std::{ fmt, path::PathBuf, @@ -22,6 +18,8 @@ use std::{ }; use util::TryFutureExt; +use crate::transport::{StdioTransport, Transport}; + const JSON_RPC_VERSION: &str = "2.0"; const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); @@ -55,7 +53,8 @@ pub struct Client { #[allow(dead_code)] output_done_rx: Mutex>, executor: BackgroundExecutor, - server: Arc>>, + #[allow(dead_code)] + transport: Arc, } #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -152,25 +151,13 @@ impl Client { &binary.args ); - let mut command = util::command::new_smol_command(&binary.executable); - command - .args(&binary.args) - .envs(binary.env.unwrap_or_default()) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()) - .kill_on_drop(true); + let server_name = binary + .executable + .file_name() + .map(|name| name.to_string_lossy().to_string()) + .unwrap_or_else(String::new); - let mut server = command.spawn().with_context(|| { - format!( - "failed to spawn command. (path={:?}, args={:?})", - binary.executable, &binary.args - ) - })?; - - let stdin = server.stdin.take().unwrap(); - let stdout = server.stdout.take().unwrap(); - let stderr = server.stderr.take().unwrap(); + let transport = Arc::new(StdioTransport::new(binary, &cx)?); let (outbound_tx, outbound_rx) = channel::unbounded::(); let (output_done_tx, output_done_rx) = barrier::channel(); @@ -183,18 +170,22 @@ impl Client { let stdout_input_task = cx.spawn({ let notification_handlers = notification_handlers.clone(); let response_handlers = response_handlers.clone(); + let transport = transport.clone(); move |cx| { - Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err() + Self::handle_input(transport, notification_handlers, response_handlers, cx) + .log_err() } }); - let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err()); + let stderr_input_task = cx.spawn(|_| Self::handle_stderr(transport.clone()).log_err()); let input_task = cx.spawn(|_| async move { let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task); stdout.or(stderr) }); + let output_task = cx.background_spawn({ + let transport = transport.clone(); Self::handle_output( - stdin, + transport, outbound_rx, output_done_tx, response_handlers.clone(), @@ -202,24 +193,18 @@ impl Client { .log_err() }); - let mut context_server = Self { + Ok(Self { server_id, notification_handlers, response_handlers, - name: "".into(), + name: server_name.into(), next_id: Default::default(), outbound_tx, executor: cx.background_executor().clone(), io_tasks: Mutex::new(Some((input_task, output_task))), output_done_rx: Mutex::new(Some(output_done_rx)), - server: Arc::new(Mutex::new(Some(server))), - }; - - if let Some(name) = binary.executable.file_name() { - context_server.name = name.to_string_lossy().into(); - } - - Ok(context_server) + transport, + }) } /// Handles input from the server's stdout. @@ -228,79 +213,53 @@ impl Client { /// parses them as JSON-RPC responses or notifications, and dispatches them /// to the appropriate handlers. It processes both responses (which are matched /// to pending requests) and notifications (which trigger registered handlers). - async fn handle_input( - stdout: Stdout, + async fn handle_input( + transport: Arc, notification_handlers: Arc>>, response_handlers: Arc>>>, cx: AsyncApp, - ) -> anyhow::Result<()> - where - Stdout: AsyncRead + Unpin + Send + 'static, - { - let mut stdout = BufReader::new(stdout); - let mut buffer = String::new(); + ) -> anyhow::Result<()> { + let mut receiver = transport.receive(); - loop { - buffer.clear(); - if stdout.read_line(&mut buffer).await? == 0 { - return Ok(()); - } - - let content = buffer.trim(); - - if !content.is_empty() { - if let Ok(response) = serde_json::from_str::(content) { - if let Some(handlers) = response_handlers.lock().as_mut() { - if let Some(handler) = handlers.remove(&response.id) { - handler(Ok(content.to_string())); - } - } - } else if let Ok(notification) = serde_json::from_str::(content) { - let mut notification_handlers = notification_handlers.lock(); - if let Some(handler) = - notification_handlers.get_mut(notification.method.as_str()) - { - handler(notification.params.unwrap_or(Value::Null), cx.clone()); + while let Some(message) = receiver.next().await { + if let Ok(response) = serde_json::from_str::(&message) { + if let Some(handlers) = response_handlers.lock().as_mut() { + if let Some(handler) = handlers.remove(&response.id) { + handler(Ok(message.to_string())); } } + } else if let Ok(notification) = serde_json::from_str::(&message) { + let mut notification_handlers = notification_handlers.lock(); + if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) { + handler(notification.params.unwrap_or(Value::Null), cx.clone()); + } } - - smol::future::yield_now().await; } + + smol::future::yield_now().await; + + Ok(()) } /// Handles the stderr output from the context server. /// Continuously reads and logs any error messages from the server. - async fn handle_stderr(stderr: Stderr) -> anyhow::Result<()> - where - Stderr: AsyncRead + Unpin + Send + 'static, - { - let mut stderr = BufReader::new(stderr); - let mut buffer = String::new(); - - loop { - buffer.clear(); - if stderr.read_line(&mut buffer).await? == 0 { - return Ok(()); - } - log::warn!("context server stderr: {}", buffer.trim()); - smol::future::yield_now().await; + async fn handle_stderr(transport: Arc) -> anyhow::Result<()> { + while let Some(err) = transport.receive_err().next().await { + log::warn!("context server stderr: {}", err.trim()); } + + Ok(()) } /// Handles the output to the context server's stdin. /// This function continuously receives messages from the outbound channel, /// writes them to the server's stdin, and manages the lifecycle of response handlers. - async fn handle_output( - stdin: Stdin, + async fn handle_output( + transport: Arc, outbound_rx: channel::Receiver, output_done_tx: barrier::Sender, response_handlers: Arc>>>, - ) -> anyhow::Result<()> - where - Stdin: AsyncWrite + Unpin + Send + 'static, - { - let mut stdin = BufWriter::new(stdin); + ) -> anyhow::Result<()> { let _clear_response_handlers = util::defer({ let response_handlers = response_handlers.clone(); move || { @@ -309,10 +268,7 @@ impl Client { }); while let Ok(message) = outbound_rx.recv().await { log::trace!("outgoing message: {}", message); - - stdin.write_all(message.as_bytes()).await?; - stdin.write_all(b"\n").await?; - stdin.flush().await?; + transport.send(message).await?; } drop(output_done_tx); Ok(()) @@ -416,14 +372,6 @@ impl Client { } } -impl Drop for Client { - fn drop(&mut self) { - if let Some(mut server) = self.server.lock().take() { - let _ = server.kill(); - } - } -} - impl fmt::Display for ContextServerId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index 6189ea4e82..bbd0cc6a7c 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -4,6 +4,7 @@ mod extension_context_server; pub mod manager; pub mod protocol; mod registry; +mod transport; pub mod types; use command_palette_hooks::CommandPaletteFilter; diff --git a/crates/context_server/src/transport.rs b/crates/context_server/src/transport.rs new file mode 100644 index 0000000000..b4f56b0ef0 --- /dev/null +++ b/crates/context_server/src/transport.rs @@ -0,0 +1,16 @@ +mod stdio_transport; + +use std::pin::Pin; + +use anyhow::Result; +use async_trait::async_trait; +use futures::Stream; + +pub use stdio_transport::*; + +#[async_trait] +pub trait Transport: Send + Sync { + async fn send(&self, message: String) -> Result<()>; + fn receive(&self) -> Pin + Send>>; + fn receive_err(&self) -> Pin + Send>>; +} diff --git a/crates/context_server/src/transport/stdio_transport.rs b/crates/context_server/src/transport/stdio_transport.rs new file mode 100644 index 0000000000..cdfe58a4bd --- /dev/null +++ b/crates/context_server/src/transport/stdio_transport.rs @@ -0,0 +1,140 @@ +use std::pin::Pin; + +use anyhow::{Context as _, Result}; +use async_trait::async_trait; +use futures::io::{BufReader, BufWriter}; +use futures::{ + AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _, +}; +use gpui::AsyncApp; +use smol::channel; +use smol::process::Child; +use util::TryFutureExt as _; + +use crate::client::ModelContextServerBinary; +use crate::transport::Transport; + +pub struct StdioTransport { + stdout_sender: channel::Sender, + stdin_receiver: channel::Receiver, + stderr_receiver: channel::Receiver, + server: Child, +} + +impl StdioTransport { + pub fn new(binary: ModelContextServerBinary, cx: &AsyncApp) -> Result { + let mut command = util::command::new_smol_command(&binary.executable); + command + .args(&binary.args) + .envs(binary.env.unwrap_or_default()) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true); + + let mut server = command.spawn().with_context(|| { + format!( + "failed to spawn command. (path={:?}, args={:?})", + binary.executable, &binary.args + ) + })?; + + let stdin = server.stdin.take().unwrap(); + let stdout = server.stdout.take().unwrap(); + let stderr = server.stderr.take().unwrap(); + + let (stdin_sender, stdin_receiver) = channel::unbounded::(); + let (stdout_sender, stdout_receiver) = channel::unbounded::(); + let (stderr_sender, stderr_receiver) = channel::unbounded::(); + + cx.spawn(|_| Self::handle_output(stdin, stdout_receiver).log_err()) + .detach(); + + cx.spawn(|_| async move { Self::handle_input(stdout, stdin_sender).await }) + .detach(); + + cx.spawn(|_| async move { Self::handle_err(stderr, stderr_sender).await }) + .detach(); + + Ok(Self { + stdout_sender, + stdin_receiver, + stderr_receiver, + server, + }) + } + + async fn handle_input(stdin: Stdout, inbound_rx: channel::Sender) + where + Stdout: AsyncRead + Unpin + Send + 'static, + { + let mut stdin = BufReader::new(stdin); + let mut line = String::new(); + while let Ok(n) = stdin.read_line(&mut line).await { + if n == 0 { + break; + } + if inbound_rx.send(line.clone()).await.is_err() { + break; + } + line.clear(); + } + } + + async fn handle_output( + stdin: Stdin, + outbound_rx: channel::Receiver, + ) -> Result<()> + where + Stdin: AsyncWrite + Unpin + Send + 'static, + { + let mut stdin = BufWriter::new(stdin); + let mut pinned_rx = Box::pin(outbound_rx); + while let Some(message) = pinned_rx.next().await { + log::trace!("outgoing message: {}", message); + + stdin.write_all(message.as_bytes()).await?; + stdin.write_all(b"\n").await?; + stdin.flush().await?; + } + Ok(()) + } + + async fn handle_err(stderr: Stderr, stderr_tx: channel::Sender) + where + Stderr: AsyncRead + Unpin + Send + 'static, + { + let mut stderr = BufReader::new(stderr); + let mut line = String::new(); + while let Ok(n) = stderr.read_line(&mut line).await { + if n == 0 { + break; + } + if stderr_tx.send(line.clone()).await.is_err() { + break; + } + line.clear(); + } + } +} + +#[async_trait] +impl Transport for StdioTransport { + async fn send(&self, message: String) -> Result<()> { + Ok(self.stdout_sender.send(message).await?) + } + + fn receive(&self) -> Pin + Send>> { + Box::pin(self.stdin_receiver.clone()) + } + + fn receive_err(&self) -> Pin + Send>> { + Box::pin(self.stderr_receiver.clone()) + } +} + +impl Drop for StdioTransport { + fn drop(&mut self) { + let _ = self.server.kill(); + } +}