Add LanguageServer::on_io method, for observing JSON sent back and forth

This commit is contained in:
Max Brunsfeld 2023-04-20 17:47:51 -07:00
parent abdccf7393
commit 2dd4920625

View file

@ -20,10 +20,10 @@ use std::{
future::Future, future::Future,
io::Write, io::Write,
path::PathBuf, path::PathBuf,
str::FromStr, str::{self, FromStr as _},
sync::{ sync::{
atomic::{AtomicUsize, Ordering::SeqCst}, atomic::{AtomicUsize, Ordering::SeqCst},
Arc, Arc, Weak,
}, },
}; };
use std::{path::Path, process::Stdio}; use std::{path::Path, process::Stdio};
@ -34,16 +34,18 @@ const CONTENT_LEN_HEADER: &str = "Content-Length: ";
type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>; type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>;
type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>; type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
type IoHandler = Box<dyn Send + FnMut(bool, &str)>;
pub struct LanguageServer { pub struct LanguageServer {
server_id: LanguageServerId, server_id: LanguageServerId,
next_id: AtomicUsize, next_id: AtomicUsize,
outbound_tx: channel::Sender<Vec<u8>>, outbound_tx: channel::Sender<String>,
name: String, name: String,
capabilities: ServerCapabilities, capabilities: ServerCapabilities,
code_action_kinds: Option<Vec<CodeActionKind>>, code_action_kinds: Option<Vec<CodeActionKind>>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>, notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>, response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
executor: Arc<executor::Background>, executor: Arc<executor::Background>,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>, io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
@ -56,9 +58,16 @@ pub struct LanguageServer {
#[repr(transparent)] #[repr(transparent)]
pub struct LanguageServerId(pub usize); pub struct LanguageServerId(pub usize);
pub struct Subscription { pub enum Subscription {
method: &'static str, Detached,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>, Notification {
method: &'static str,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
},
Io {
id: usize,
io_handlers: Weak<Mutex<HashMap<usize, IoHandler>>>,
},
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -177,33 +186,40 @@ impl LanguageServer {
Stdout: AsyncRead + Unpin + Send + 'static, Stdout: AsyncRead + Unpin + Send + 'static,
F: FnMut(AnyNotification) + 'static + Send, F: FnMut(AnyNotification) + 'static + Send,
{ {
let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>(); let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
let (output_done_tx, output_done_rx) = barrier::channel();
let notification_handlers = let notification_handlers =
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
let response_handlers = let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let io_handlers = Arc::new(Mutex::new(HashMap::default()));
let input_task = cx.spawn(|cx| { let input_task = cx.spawn(|cx| {
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
Self::handle_input( Self::handle_input(
stdout, stdout,
on_unhandled_notification, on_unhandled_notification,
notification_handlers, notification_handlers.clone(),
response_handlers, response_handlers.clone(),
io_handlers.clone(),
cx, cx,
) )
.log_err() .log_err()
}); });
let (output_done_tx, output_done_rx) = barrier::channel();
let output_task = cx.background().spawn({ let output_task = cx.background().spawn({
let response_handlers = response_handlers.clone(); Self::handle_output(
Self::handle_output(stdin, outbound_rx, output_done_tx, response_handlers).log_err() stdin,
outbound_rx,
output_done_tx,
response_handlers.clone(),
io_handlers.clone(),
)
.log_err()
}); });
Self { Self {
server_id, server_id,
notification_handlers, notification_handlers,
response_handlers, response_handlers,
io_handlers,
name: Default::default(), name: Default::default(),
capabilities: Default::default(), capabilities: Default::default(),
code_action_kinds, code_action_kinds,
@ -226,6 +242,7 @@ impl LanguageServer {
mut on_unhandled_notification: F, mut on_unhandled_notification: F,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>, notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>, response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
cx: AsyncAppContext, cx: AsyncAppContext,
) -> anyhow::Result<()> ) -> anyhow::Result<()>
where where
@ -252,7 +269,13 @@ impl LanguageServer {
buffer.resize(message_len, 0); buffer.resize(message_len, 0);
stdout.read_exact(&mut buffer).await?; stdout.read_exact(&mut buffer).await?;
log::trace!("incoming message:{}", String::from_utf8_lossy(&buffer));
if let Ok(message) = str::from_utf8(&buffer) {
log::trace!("incoming message:{}", message);
for handler in io_handlers.lock().values_mut() {
handler(true, message);
}
}
if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) { if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
if let Some(handler) = notification_handlers.lock().get_mut(msg.method) { if let Some(handler) = notification_handlers.lock().get_mut(msg.method) {
@ -291,9 +314,10 @@ impl LanguageServer {
async fn handle_output<Stdin>( async fn handle_output<Stdin>(
stdin: Stdin, stdin: Stdin,
outbound_rx: channel::Receiver<Vec<u8>>, outbound_rx: channel::Receiver<String>,
output_done_tx: barrier::Sender, output_done_tx: barrier::Sender,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>, response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>,
) -> anyhow::Result<()> ) -> anyhow::Result<()>
where where
Stdin: AsyncWrite + Unpin + Send + 'static, Stdin: AsyncWrite + Unpin + Send + 'static,
@ -307,13 +331,17 @@ impl LanguageServer {
}); });
let mut content_len_buffer = Vec::new(); let mut content_len_buffer = Vec::new();
while let Ok(message) = outbound_rx.recv().await { while let Ok(message) = outbound_rx.recv().await {
log::trace!("outgoing message:{}", String::from_utf8_lossy(&message)); log::trace!("outgoing message:{}", message);
for handler in io_handlers.lock().values_mut() {
handler(false, &message);
}
content_len_buffer.clear(); content_len_buffer.clear();
write!(content_len_buffer, "{}", message.len()).unwrap(); write!(content_len_buffer, "{}", message.len()).unwrap();
stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?; stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
stdin.write_all(&content_len_buffer).await?; stdin.write_all(&content_len_buffer).await?;
stdin.write_all("\r\n\r\n".as_bytes()).await?; stdin.write_all("\r\n\r\n".as_bytes()).await?;
stdin.write_all(&message).await?; stdin.write_all(message.as_bytes()).await?;
stdin.flush().await?; stdin.flush().await?;
} }
drop(output_done_tx); drop(output_done_tx);
@ -464,6 +492,19 @@ impl LanguageServer {
self.on_custom_request(T::METHOD, f) self.on_custom_request(T::METHOD, f)
} }
#[must_use]
pub fn on_io<F>(&self, f: F) -> Subscription
where
F: 'static + Send + FnMut(bool, &str),
{
let id = self.next_id.fetch_add(1, SeqCst);
self.io_handlers.lock().insert(id, Box::new(f));
Subscription::Io {
id,
io_handlers: Arc::downgrade(&self.io_handlers),
}
}
pub fn remove_request_handler<T: request::Request>(&self) { pub fn remove_request_handler<T: request::Request>(&self) {
self.notification_handlers.lock().remove(T::METHOD); self.notification_handlers.lock().remove(T::METHOD);
} }
@ -490,7 +531,7 @@ impl LanguageServer {
prev_handler.is_none(), prev_handler.is_none(),
"registered multiple handlers for the same LSP method" "registered multiple handlers for the same LSP method"
); );
Subscription { Subscription::Notification {
method, method,
notification_handlers: self.notification_handlers.clone(), notification_handlers: self.notification_handlers.clone(),
} }
@ -537,7 +578,7 @@ impl LanguageServer {
}, },
}; };
if let Some(response) = if let Some(response) =
serde_json::to_vec(&response).log_err() serde_json::to_string(&response).log_err()
{ {
outbound_tx.try_send(response).ok(); outbound_tx.try_send(response).ok();
} }
@ -560,7 +601,7 @@ impl LanguageServer {
message: error.to_string(), message: error.to_string(),
}), }),
}; };
if let Some(response) = serde_json::to_vec(&response).log_err() { if let Some(response) = serde_json::to_string(&response).log_err() {
outbound_tx.try_send(response).ok(); outbound_tx.try_send(response).ok();
} }
} }
@ -572,7 +613,7 @@ impl LanguageServer {
prev_handler.is_none(), prev_handler.is_none(),
"registered multiple handlers for the same LSP method" "registered multiple handlers for the same LSP method"
); );
Subscription { Subscription::Notification {
method, method,
notification_handlers: self.notification_handlers.clone(), notification_handlers: self.notification_handlers.clone(),
} }
@ -612,14 +653,14 @@ impl LanguageServer {
fn request_internal<T: request::Request>( fn request_internal<T: request::Request>(
next_id: &AtomicUsize, next_id: &AtomicUsize,
response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>, response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>,
outbound_tx: &channel::Sender<Vec<u8>>, outbound_tx: &channel::Sender<String>,
params: T::Params, params: T::Params,
) -> impl 'static + Future<Output = Result<T::Result>> ) -> impl 'static + Future<Output = Result<T::Result>>
where where
T::Result: 'static + Send, T::Result: 'static + Send,
{ {
let id = next_id.fetch_add(1, SeqCst); let id = next_id.fetch_add(1, SeqCst);
let message = serde_json::to_vec(&Request { let message = serde_json::to_string(&Request {
jsonrpc: JSON_RPC_VERSION, jsonrpc: JSON_RPC_VERSION,
id, id,
method: T::METHOD, method: T::METHOD,
@ -662,10 +703,10 @@ impl LanguageServer {
} }
fn notify_internal<T: notification::Notification>( fn notify_internal<T: notification::Notification>(
outbound_tx: &channel::Sender<Vec<u8>>, outbound_tx: &channel::Sender<String>,
params: T::Params, params: T::Params,
) -> Result<()> { ) -> Result<()> {
let message = serde_json::to_vec(&Notification { let message = serde_json::to_string(&Notification {
jsonrpc: JSON_RPC_VERSION, jsonrpc: JSON_RPC_VERSION,
method: T::METHOD, method: T::METHOD,
params, params,
@ -686,7 +727,7 @@ impl Drop for LanguageServer {
impl Subscription { impl Subscription {
pub fn detach(mut self) { pub fn detach(mut self) {
self.method = ""; *(&mut self) = Self::Detached;
} }
} }
@ -698,7 +739,20 @@ impl fmt::Display for LanguageServerId {
impl Drop for Subscription { impl Drop for Subscription {
fn drop(&mut self) { fn drop(&mut self) {
self.notification_handlers.lock().remove(self.method); match self {
Subscription::Detached => {}
Subscription::Notification {
method,
notification_handlers,
} => {
notification_handlers.lock().remove(method);
}
Subscription::Io { id, io_handlers } => {
if let Some(io_handlers) = io_handlers.upgrade() {
io_handlers.lock().remove(id);
}
}
}
} }
} }