Prevent making further requests after language server shut down

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2023-01-06 13:40:32 -07:00
parent 8487ae77e7
commit 83c98ce049

View file

@ -40,7 +40,7 @@ pub struct LanguageServer {
name: String, name: String,
capabilities: ServerCapabilities, capabilities: ServerCapabilities,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>, notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>, response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
executor: Arc<executor::Background>, executor: Arc<executor::Background>,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>, io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
@ -170,12 +170,18 @@ impl LanguageServer {
let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>(); let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
let notification_handlers = let notification_handlers =
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::default())); let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let input_task = cx.spawn(|cx| { let input_task = cx.spawn(|cx| {
let notification_handlers = notification_handlers.clone(); let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone(); let response_handlers = response_handlers.clone();
async move { async move {
let _clear_response_handlers = ClearResponseHandlers(response_handlers.clone()); let _clear_response_handlers = util::defer({
let response_handlers = response_handlers.clone();
move || {
response_handlers.lock().take();
}
});
let mut buffer = Vec::new(); let mut buffer = Vec::new();
loop { loop {
buffer.clear(); buffer.clear();
@ -200,7 +206,11 @@ impl LanguageServer {
} else if let Ok(AnyResponse { id, error, result }) = } else if let Ok(AnyResponse { id, error, result }) =
serde_json::from_slice(&buffer) serde_json::from_slice(&buffer)
{ {
if let Some(handler) = response_handlers.lock().remove(&id) { if let Some(handler) = response_handlers
.lock()
.as_mut()
.and_then(|handlers| handlers.remove(&id))
{
if let Some(error) = error { if let Some(error) = error {
handler(Err(error)); handler(Err(error));
} else if let Some(result) = result { } else if let Some(result) = result {
@ -226,7 +236,12 @@ impl LanguageServer {
let output_task = cx.background().spawn({ let output_task = cx.background().spawn({
let response_handlers = response_handlers.clone(); let response_handlers = response_handlers.clone();
async move { async move {
let _clear_response_handlers = ClearResponseHandlers(response_handlers); let _clear_response_handlers = util::defer({
let response_handlers = response_handlers.clone();
move || {
response_handlers.lock().take();
}
});
let mut content_len_buffer = Vec::new(); let mut content_len_buffer = Vec::new();
while let Ok(message) = outbound_rx.recv().await { while let Ok(message) = outbound_rx.recv().await {
log::trace!("outgoing message:{}", String::from_utf8_lossy(&message)); log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
@ -366,7 +381,7 @@ impl LanguageServer {
async move { async move {
log::debug!("language server shutdown started"); log::debug!("language server shutdown started");
shutdown_request.await?; shutdown_request.await?;
response_handlers.lock().clear(); response_handlers.lock().take();
exit?; exit?;
output_done.recv().await; output_done.recv().await;
log::debug!("language server shutdown finished"); log::debug!("language server shutdown finished");
@ -521,7 +536,7 @@ impl LanguageServer {
fn request_internal<T: request::Request>( fn request_internal<T: request::Request>(
next_id: &AtomicUsize, next_id: &AtomicUsize,
response_handlers: &Mutex<HashMap<usize, ResponseHandler>>, response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>,
outbound_tx: &channel::Sender<Vec<u8>>, outbound_tx: &channel::Sender<Vec<u8>>,
params: T::Params, params: T::Params,
) -> impl 'static + Future<Output = Result<T::Result>> ) -> impl 'static + Future<Output = Result<T::Result>>
@ -537,25 +552,31 @@ impl LanguageServer {
}) })
.unwrap(); .unwrap();
let (tx, rx) = oneshot::channel();
let handle_response = response_handlers
.lock()
.as_mut()
.ok_or_else(|| anyhow!("server shut down"))
.map(|handlers| {
handlers.insert(
id,
Box::new(move |result| {
let response = match result {
Ok(response) => serde_json::from_str(response)
.context("failed to deserialize response"),
Err(error) => Err(anyhow!("{}", error.message)),
};
let _ = tx.send(response);
}),
);
});
let send = outbound_tx let send = outbound_tx
.try_send(message) .try_send(message)
.context("failed to write to language server's stdin"); .context("failed to write to language server's stdin");
let (tx, rx) = oneshot::channel();
response_handlers.lock().insert(
id,
Box::new(move |result| {
let response = match result {
Ok(response) => {
serde_json::from_str(response).context("failed to deserialize response")
}
Err(error) => Err(anyhow!("{}", error.message)),
};
let _ = tx.send(response);
}),
);
async move { async move {
handle_response?;
send?; send?;
rx.await? rx.await?
} }
@ -762,14 +783,6 @@ impl FakeLanguageServer {
} }
} }
struct ClearResponseHandlers(Arc<Mutex<HashMap<usize, ResponseHandler>>>);
impl Drop for ClearResponseHandlers {
fn drop(&mut self) {
self.0.lock().clear();
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;