Prevent making further requests after language server shut down

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2023-01-06 13:40:32 -07:00
parent 8487ae77e7
commit 83c98ce049

View file

@ -40,7 +40,7 @@ pub struct LanguageServer {
name: String,
capabilities: ServerCapabilities,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
executor: Arc<executor::Background>,
#[allow(clippy::type_complexity)]
io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
@ -170,12 +170,18 @@ impl LanguageServer {
let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
let notification_handlers =
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::default()));
let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let input_task = cx.spawn(|cx| {
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
async move {
let _clear_response_handlers = ClearResponseHandlers(response_handlers.clone());
let _clear_response_handlers = util::defer({
let response_handlers = response_handlers.clone();
move || {
response_handlers.lock().take();
}
});
let mut buffer = Vec::new();
loop {
buffer.clear();
@ -200,7 +206,11 @@ impl LanguageServer {
} else if let Ok(AnyResponse { id, error, result }) =
serde_json::from_slice(&buffer)
{
if let Some(handler) = response_handlers.lock().remove(&id) {
if let Some(handler) = response_handlers
.lock()
.as_mut()
.and_then(|handlers| handlers.remove(&id))
{
if let Some(error) = error {
handler(Err(error));
} else if let Some(result) = result {
@ -226,7 +236,12 @@ impl LanguageServer {
let output_task = cx.background().spawn({
let response_handlers = response_handlers.clone();
async move {
let _clear_response_handlers = ClearResponseHandlers(response_handlers);
let _clear_response_handlers = util::defer({
let response_handlers = response_handlers.clone();
move || {
response_handlers.lock().take();
}
});
let mut content_len_buffer = Vec::new();
while let Ok(message) = outbound_rx.recv().await {
log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
@ -366,7 +381,7 @@ impl LanguageServer {
async move {
log::debug!("language server shutdown started");
shutdown_request.await?;
response_handlers.lock().clear();
response_handlers.lock().take();
exit?;
output_done.recv().await;
log::debug!("language server shutdown finished");
@ -521,7 +536,7 @@ impl LanguageServer {
fn request_internal<T: request::Request>(
next_id: &AtomicUsize,
response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>,
outbound_tx: &channel::Sender<Vec<u8>>,
params: T::Params,
) -> impl 'static + Future<Output = Result<T::Result>>
@ -537,25 +552,31 @@ impl LanguageServer {
})
.unwrap();
let send = outbound_tx
.try_send(message)
.context("failed to write to language server's stdin");
let (tx, rx) = oneshot::channel();
response_handlers.lock().insert(
let handle_response = response_handlers
.lock()
.as_mut()
.ok_or_else(|| anyhow!("server shut down"))
.map(|handlers| {
handlers.insert(
id,
Box::new(move |result| {
let response = match result {
Ok(response) => {
serde_json::from_str(response).context("failed to deserialize response")
}
Ok(response) => serde_json::from_str(response)
.context("failed to deserialize response"),
Err(error) => Err(anyhow!("{}", error.message)),
};
let _ = tx.send(response);
}),
);
});
let send = outbound_tx
.try_send(message)
.context("failed to write to language server's stdin");
async move {
handle_response?;
send?;
rx.await?
}
@ -762,14 +783,6 @@ impl FakeLanguageServer {
}
}
struct ClearResponseHandlers(Arc<Mutex<HashMap<usize, ResponseHandler>>>);
impl Drop for ClearResponseHandlers {
fn drop(&mut self) {
self.0.lock().clear();
}
}
#[cfg(test)]
mod tests {
use super::*;