diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index 6e7def92e9..b7199a5287 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -40,7 +40,7 @@ pub struct LanguageServer { name: String, capabilities: ServerCapabilities, notification_handlers: Arc>>, - response_handlers: Arc>>, + response_handlers: Arc>>>, executor: Arc, #[allow(clippy::type_complexity)] io_tasks: Mutex>, Task>)>>, @@ -170,12 +170,18 @@ impl LanguageServer { let (outbound_tx, outbound_rx) = channel::unbounded::>(); let notification_handlers = Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); - let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::default())); + let response_handlers = + Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); let input_task = cx.spawn(|cx| { let notification_handlers = notification_handlers.clone(); let response_handlers = response_handlers.clone(); async move { - let _clear_response_handlers = ClearResponseHandlers(response_handlers.clone()); + let _clear_response_handlers = util::defer({ + let response_handlers = response_handlers.clone(); + move || { + response_handlers.lock().take(); + } + }); let mut buffer = Vec::new(); loop { buffer.clear(); @@ -200,7 +206,11 @@ impl LanguageServer { } else if let Ok(AnyResponse { id, error, result }) = serde_json::from_slice(&buffer) { - if let Some(handler) = response_handlers.lock().remove(&id) { + if let Some(handler) = response_handlers + .lock() + .as_mut() + .and_then(|handlers| handlers.remove(&id)) + { if let Some(error) = error { handler(Err(error)); } else if let Some(result) = result { @@ -226,7 +236,12 @@ impl LanguageServer { let output_task = cx.background().spawn({ let response_handlers = response_handlers.clone(); async move { - let _clear_response_handlers = ClearResponseHandlers(response_handlers); + let _clear_response_handlers = util::defer({ + let response_handlers = response_handlers.clone(); + move || { + response_handlers.lock().take(); + } + }); let mut content_len_buffer = Vec::new(); while let Ok(message) = outbound_rx.recv().await { log::trace!("outgoing message:{}", String::from_utf8_lossy(&message)); @@ -366,7 +381,7 @@ impl LanguageServer { async move { log::debug!("language server shutdown started"); shutdown_request.await?; - response_handlers.lock().clear(); + response_handlers.lock().take(); exit?; output_done.recv().await; log::debug!("language server shutdown finished"); @@ -521,7 +536,7 @@ impl LanguageServer { fn request_internal( next_id: &AtomicUsize, - response_handlers: &Mutex>, + response_handlers: &Mutex>>, outbound_tx: &channel::Sender>, params: T::Params, ) -> impl 'static + Future> @@ -537,25 +552,31 @@ impl LanguageServer { }) .unwrap(); + let (tx, rx) = oneshot::channel(); + let handle_response = response_handlers + .lock() + .as_mut() + .ok_or_else(|| anyhow!("server shut down")) + .map(|handlers| { + handlers.insert( + id, + Box::new(move |result| { + let response = match result { + Ok(response) => serde_json::from_str(response) + .context("failed to deserialize response"), + Err(error) => Err(anyhow!("{}", error.message)), + }; + let _ = tx.send(response); + }), + ); + }); + let send = outbound_tx .try_send(message) .context("failed to write to language server's stdin"); - let (tx, rx) = oneshot::channel(); - response_handlers.lock().insert( - id, - Box::new(move |result| { - let response = match result { - Ok(response) => { - serde_json::from_str(response).context("failed to deserialize response") - } - Err(error) => Err(anyhow!("{}", error.message)), - }; - let _ = tx.send(response); - }), - ); - async move { + handle_response?; send?; rx.await? } @@ -762,14 +783,6 @@ impl FakeLanguageServer { } } -struct ClearResponseHandlers(Arc>>); - -impl Drop for ClearResponseHandlers { - fn drop(&mut self) { - self.0.lock().clear(); - } -} - #[cfg(test)] mod tests { use super::*;