context_server: Abstract server transport (#24528)
This PR abstracts the communication layer for context servers, laying the groundwork for supporting multiple transport mechanisms and taking one step towards enabling remote servers. Key changes centre around creating a new `Transport` trait with methods for sending and receiving messages. I've implemented this trait for the existing stdio-based communication, which is now encapsulated in a `StdioTransport` struct. The `Client` struct has been refactored to use this new `Transport` trait instead of directly managing stdin and stdout. The next steps will involve implementing an SSE + HTTP transport and defining alternative context server settings for remote servers. Release Notes: - N/A --------- Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
6d17546b1a
commit
f11357db7c
6 changed files with 210 additions and 103 deletions
|
@ -14,6 +14,7 @@ path = "src/context_server.rs"
|
|||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
async-trait.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
context_server_settings.workspace = true
|
||||
|
|
|
@ -1,16 +1,12 @@
|
|||
use anyhow::{anyhow, Context as _, Result};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use collections::HashMap;
|
||||
use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt};
|
||||
use futures::{channel::oneshot, select, FutureExt, StreamExt};
|
||||
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
|
||||
use parking_lot::Mutex;
|
||||
use postage::barrier;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::{value::RawValue, Value};
|
||||
use smol::{
|
||||
channel,
|
||||
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
||||
process::Child,
|
||||
};
|
||||
use smol::channel;
|
||||
use std::{
|
||||
fmt,
|
||||
path::PathBuf,
|
||||
|
@ -22,6 +18,8 @@ use std::{
|
|||
};
|
||||
use util::TryFutureExt;
|
||||
|
||||
use crate::transport::{StdioTransport, Transport};
|
||||
|
||||
const JSON_RPC_VERSION: &str = "2.0";
|
||||
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
|
@ -55,7 +53,8 @@ pub struct Client {
|
|||
#[allow(dead_code)]
|
||||
output_done_rx: Mutex<Option<barrier::Receiver>>,
|
||||
executor: BackgroundExecutor,
|
||||
server: Arc<Mutex<Option<Child>>>,
|
||||
#[allow(dead_code)]
|
||||
transport: Arc<dyn Transport>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
|
@ -152,25 +151,13 @@ impl Client {
|
|||
&binary.args
|
||||
);
|
||||
|
||||
let mut command = util::command::new_smol_command(&binary.executable);
|
||||
command
|
||||
.args(&binary.args)
|
||||
.envs(binary.env.unwrap_or_default())
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.kill_on_drop(true);
|
||||
let server_name = binary
|
||||
.executable
|
||||
.file_name()
|
||||
.map(|name| name.to_string_lossy().to_string())
|
||||
.unwrap_or_else(String::new);
|
||||
|
||||
let mut server = command.spawn().with_context(|| {
|
||||
format!(
|
||||
"failed to spawn command. (path={:?}, args={:?})",
|
||||
binary.executable, &binary.args
|
||||
)
|
||||
})?;
|
||||
|
||||
let stdin = server.stdin.take().unwrap();
|
||||
let stdout = server.stdout.take().unwrap();
|
||||
let stderr = server.stderr.take().unwrap();
|
||||
let transport = Arc::new(StdioTransport::new(binary, &cx)?);
|
||||
|
||||
let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
|
||||
let (output_done_tx, output_done_rx) = barrier::channel();
|
||||
|
@ -183,18 +170,22 @@ impl Client {
|
|||
let stdout_input_task = cx.spawn({
|
||||
let notification_handlers = notification_handlers.clone();
|
||||
let response_handlers = response_handlers.clone();
|
||||
let transport = transport.clone();
|
||||
move |cx| {
|
||||
Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err()
|
||||
Self::handle_input(transport, notification_handlers, response_handlers, cx)
|
||||
.log_err()
|
||||
}
|
||||
});
|
||||
let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err());
|
||||
let stderr_input_task = cx.spawn(|_| Self::handle_stderr(transport.clone()).log_err());
|
||||
let input_task = cx.spawn(|_| async move {
|
||||
let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
|
||||
stdout.or(stderr)
|
||||
});
|
||||
|
||||
let output_task = cx.background_spawn({
|
||||
let transport = transport.clone();
|
||||
Self::handle_output(
|
||||
stdin,
|
||||
transport,
|
||||
outbound_rx,
|
||||
output_done_tx,
|
||||
response_handlers.clone(),
|
||||
|
@ -202,24 +193,18 @@ impl Client {
|
|||
.log_err()
|
||||
});
|
||||
|
||||
let mut context_server = Self {
|
||||
Ok(Self {
|
||||
server_id,
|
||||
notification_handlers,
|
||||
response_handlers,
|
||||
name: "".into(),
|
||||
name: server_name.into(),
|
||||
next_id: Default::default(),
|
||||
outbound_tx,
|
||||
executor: cx.background_executor().clone(),
|
||||
io_tasks: Mutex::new(Some((input_task, output_task))),
|
||||
output_done_rx: Mutex::new(Some(output_done_rx)),
|
||||
server: Arc::new(Mutex::new(Some(server))),
|
||||
};
|
||||
|
||||
if let Some(name) = binary.executable.file_name() {
|
||||
context_server.name = name.to_string_lossy().into();
|
||||
}
|
||||
|
||||
Ok(context_server)
|
||||
transport,
|
||||
})
|
||||
}
|
||||
|
||||
/// Handles input from the server's stdout.
|
||||
|
@ -228,79 +213,53 @@ impl Client {
|
|||
/// parses them as JSON-RPC responses or notifications, and dispatches them
|
||||
/// to the appropriate handlers. It processes both responses (which are matched
|
||||
/// to pending requests) and notifications (which trigger registered handlers).
|
||||
async fn handle_input<Stdout>(
|
||||
stdout: Stdout,
|
||||
async fn handle_input(
|
||||
transport: Arc<dyn Transport>,
|
||||
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
|
||||
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
|
||||
cx: AsyncApp,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
Stdout: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let mut stdout = BufReader::new(stdout);
|
||||
let mut buffer = String::new();
|
||||
) -> anyhow::Result<()> {
|
||||
let mut receiver = transport.receive();
|
||||
|
||||
loop {
|
||||
buffer.clear();
|
||||
if stdout.read_line(&mut buffer).await? == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let content = buffer.trim();
|
||||
|
||||
if !content.is_empty() {
|
||||
if let Ok(response) = serde_json::from_str::<AnyResponse>(content) {
|
||||
if let Some(handlers) = response_handlers.lock().as_mut() {
|
||||
if let Some(handler) = handlers.remove(&response.id) {
|
||||
handler(Ok(content.to_string()));
|
||||
}
|
||||
}
|
||||
} else if let Ok(notification) = serde_json::from_str::<AnyNotification>(content) {
|
||||
let mut notification_handlers = notification_handlers.lock();
|
||||
if let Some(handler) =
|
||||
notification_handlers.get_mut(notification.method.as_str())
|
||||
{
|
||||
handler(notification.params.unwrap_or(Value::Null), cx.clone());
|
||||
while let Some(message) = receiver.next().await {
|
||||
if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
|
||||
if let Some(handlers) = response_handlers.lock().as_mut() {
|
||||
if let Some(handler) = handlers.remove(&response.id) {
|
||||
handler(Ok(message.to_string()));
|
||||
}
|
||||
}
|
||||
} else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&message) {
|
||||
let mut notification_handlers = notification_handlers.lock();
|
||||
if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
|
||||
handler(notification.params.unwrap_or(Value::Null), cx.clone());
|
||||
}
|
||||
}
|
||||
|
||||
smol::future::yield_now().await;
|
||||
}
|
||||
|
||||
smol::future::yield_now().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handles the stderr output from the context server.
|
||||
/// Continuously reads and logs any error messages from the server.
|
||||
async fn handle_stderr<Stderr>(stderr: Stderr) -> anyhow::Result<()>
|
||||
where
|
||||
Stderr: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let mut stderr = BufReader::new(stderr);
|
||||
let mut buffer = String::new();
|
||||
|
||||
loop {
|
||||
buffer.clear();
|
||||
if stderr.read_line(&mut buffer).await? == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
log::warn!("context server stderr: {}", buffer.trim());
|
||||
smol::future::yield_now().await;
|
||||
async fn handle_stderr(transport: Arc<dyn Transport>) -> anyhow::Result<()> {
|
||||
while let Some(err) = transport.receive_err().next().await {
|
||||
log::warn!("context server stderr: {}", err.trim());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handles the output to the context server's stdin.
|
||||
/// This function continuously receives messages from the outbound channel,
|
||||
/// writes them to the server's stdin, and manages the lifecycle of response handlers.
|
||||
async fn handle_output<Stdin>(
|
||||
stdin: Stdin,
|
||||
async fn handle_output(
|
||||
transport: Arc<dyn Transport>,
|
||||
outbound_rx: channel::Receiver<String>,
|
||||
output_done_tx: barrier::Sender,
|
||||
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
Stdin: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let mut stdin = BufWriter::new(stdin);
|
||||
) -> anyhow::Result<()> {
|
||||
let _clear_response_handlers = util::defer({
|
||||
let response_handlers = response_handlers.clone();
|
||||
move || {
|
||||
|
@ -309,10 +268,7 @@ impl Client {
|
|||
});
|
||||
while let Ok(message) = outbound_rx.recv().await {
|
||||
log::trace!("outgoing message: {}", message);
|
||||
|
||||
stdin.write_all(message.as_bytes()).await?;
|
||||
stdin.write_all(b"\n").await?;
|
||||
stdin.flush().await?;
|
||||
transport.send(message).await?;
|
||||
}
|
||||
drop(output_done_tx);
|
||||
Ok(())
|
||||
|
@ -416,14 +372,6 @@ impl Client {
|
|||
}
|
||||
}
|
||||
|
||||
impl Drop for Client {
|
||||
fn drop(&mut self) {
|
||||
if let Some(mut server) = self.server.lock().take() {
|
||||
let _ = server.kill();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ContextServerId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
|
|
|
@ -4,6 +4,7 @@ mod extension_context_server;
|
|||
pub mod manager;
|
||||
pub mod protocol;
|
||||
mod registry;
|
||||
mod transport;
|
||||
pub mod types;
|
||||
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
|
|
16
crates/context_server/src/transport.rs
Normal file
16
crates/context_server/src/transport.rs
Normal file
|
@ -0,0 +1,16 @@
|
|||
mod stdio_transport;
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
|
||||
pub use stdio_transport::*;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Transport: Send + Sync {
|
||||
async fn send(&self, message: String) -> Result<()>;
|
||||
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>>;
|
||||
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>>;
|
||||
}
|
140
crates/context_server/src/transport/stdio_transport.rs
Normal file
140
crates/context_server/src/transport/stdio_transport.rs
Normal file
|
@ -0,0 +1,140 @@
|
|||
use std::pin::Pin;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use async_trait::async_trait;
|
||||
use futures::io::{BufReader, BufWriter};
|
||||
use futures::{
|
||||
AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
|
||||
};
|
||||
use gpui::AsyncApp;
|
||||
use smol::channel;
|
||||
use smol::process::Child;
|
||||
use util::TryFutureExt as _;
|
||||
|
||||
use crate::client::ModelContextServerBinary;
|
||||
use crate::transport::Transport;
|
||||
|
||||
pub struct StdioTransport {
|
||||
stdout_sender: channel::Sender<String>,
|
||||
stdin_receiver: channel::Receiver<String>,
|
||||
stderr_receiver: channel::Receiver<String>,
|
||||
server: Child,
|
||||
}
|
||||
|
||||
impl StdioTransport {
|
||||
pub fn new(binary: ModelContextServerBinary, cx: &AsyncApp) -> Result<Self> {
|
||||
let mut command = util::command::new_smol_command(&binary.executable);
|
||||
command
|
||||
.args(&binary.args)
|
||||
.envs(binary.env.unwrap_or_default())
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.kill_on_drop(true);
|
||||
|
||||
let mut server = command.spawn().with_context(|| {
|
||||
format!(
|
||||
"failed to spawn command. (path={:?}, args={:?})",
|
||||
binary.executable, &binary.args
|
||||
)
|
||||
})?;
|
||||
|
||||
let stdin = server.stdin.take().unwrap();
|
||||
let stdout = server.stdout.take().unwrap();
|
||||
let stderr = server.stderr.take().unwrap();
|
||||
|
||||
let (stdin_sender, stdin_receiver) = channel::unbounded::<String>();
|
||||
let (stdout_sender, stdout_receiver) = channel::unbounded::<String>();
|
||||
let (stderr_sender, stderr_receiver) = channel::unbounded::<String>();
|
||||
|
||||
cx.spawn(|_| Self::handle_output(stdin, stdout_receiver).log_err())
|
||||
.detach();
|
||||
|
||||
cx.spawn(|_| async move { Self::handle_input(stdout, stdin_sender).await })
|
||||
.detach();
|
||||
|
||||
cx.spawn(|_| async move { Self::handle_err(stderr, stderr_sender).await })
|
||||
.detach();
|
||||
|
||||
Ok(Self {
|
||||
stdout_sender,
|
||||
stdin_receiver,
|
||||
stderr_receiver,
|
||||
server,
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_input<Stdout>(stdin: Stdout, inbound_rx: channel::Sender<String>)
|
||||
where
|
||||
Stdout: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let mut stdin = BufReader::new(stdin);
|
||||
let mut line = String::new();
|
||||
while let Ok(n) = stdin.read_line(&mut line).await {
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
if inbound_rx.send(line.clone()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
line.clear();
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_output<Stdin>(
|
||||
stdin: Stdin,
|
||||
outbound_rx: channel::Receiver<String>,
|
||||
) -> Result<()>
|
||||
where
|
||||
Stdin: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let mut stdin = BufWriter::new(stdin);
|
||||
let mut pinned_rx = Box::pin(outbound_rx);
|
||||
while let Some(message) = pinned_rx.next().await {
|
||||
log::trace!("outgoing message: {}", message);
|
||||
|
||||
stdin.write_all(message.as_bytes()).await?;
|
||||
stdin.write_all(b"\n").await?;
|
||||
stdin.flush().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_err<Stderr>(stderr: Stderr, stderr_tx: channel::Sender<String>)
|
||||
where
|
||||
Stderr: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let mut stderr = BufReader::new(stderr);
|
||||
let mut line = String::new();
|
||||
while let Ok(n) = stderr.read_line(&mut line).await {
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
if stderr_tx.send(line.clone()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
line.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for StdioTransport {
|
||||
async fn send(&self, message: String) -> Result<()> {
|
||||
Ok(self.stdout_sender.send(message).await?)
|
||||
}
|
||||
|
||||
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
Box::pin(self.stdin_receiver.clone())
|
||||
}
|
||||
|
||||
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
Box::pin(self.stderr_receiver.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StdioTransport {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.server.kill();
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue