Allow both integer and string request IDs in LSP (#7662)

Zed's LSP support expects request messages to have integer ID, Metals
LSP uses string. According to specification, both is acceptable:

interface RequestMessage extends Message {

	/**
	 * The request id.
	 */
	id: integer | string;

...
This pull requests modifies the types and serialization/deserialization
so that string IDs are accepted.

Release Notes:

- Make Zed LSP request ids compliant to the LSP specification
This commit is contained in:
Michal Příhoda 2024-02-15 19:26:23 +01:00 committed by GitHub
parent 2dffc5f6e1
commit f01763a1fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -23,7 +23,7 @@ use std::{
path::PathBuf, path::PathBuf,
str::{self, FromStr as _}, str::{self, FromStr as _},
sync::{ sync::{
atomic::{AtomicUsize, Ordering::SeqCst}, atomic::{AtomicI32, Ordering::SeqCst},
Arc, Weak, Arc, Weak,
}, },
time::{Duration, Instant}, time::{Duration, Instant},
@ -36,7 +36,7 @@ const JSON_RPC_VERSION: &str = "2.0";
const CONTENT_LEN_HEADER: &str = "Content-Length: "; const CONTENT_LEN_HEADER: &str = "Content-Length: ";
const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2); const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2);
type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>; type NotificationHandler = Box<dyn Send + FnMut(Option<RequestId>, &str, AsyncAppContext)>;
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>; type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
type IoHandler = Box<dyn Send + FnMut(IoKind, &str)>; type IoHandler = Box<dyn Send + FnMut(IoKind, &str)>;
@ -59,14 +59,14 @@ pub struct LanguageServerBinary {
/// A running language server process. /// A running language server process.
pub struct LanguageServer { pub struct LanguageServer {
server_id: LanguageServerId, server_id: LanguageServerId,
next_id: AtomicUsize, next_id: AtomicI32,
outbound_tx: channel::Sender<String>, outbound_tx: channel::Sender<String>,
name: String, name: String,
capabilities: ServerCapabilities, capabilities: ServerCapabilities,
code_action_kinds: Option<Vec<CodeActionKind>>, code_action_kinds: Option<Vec<CodeActionKind>>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>, notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>, response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>, io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
executor: BackgroundExecutor, executor: BackgroundExecutor,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>, io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
@ -87,18 +87,28 @@ pub enum Subscription {
notification_handlers: Option<Arc<Mutex<HashMap<&'static str, NotificationHandler>>>>, notification_handlers: Option<Arc<Mutex<HashMap<&'static str, NotificationHandler>>>>,
}, },
Io { Io {
id: usize, id: i32,
io_handlers: Option<Weak<Mutex<HashMap<usize, IoHandler>>>>, io_handlers: Option<Weak<Mutex<HashMap<i32, IoHandler>>>>,
}, },
} }
/// Language server protocol RPC request message ID.
///
/// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage)
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RequestId {
Int(i32),
Str(String),
}
/// Language server protocol RPC request message. /// Language server protocol RPC request message.
/// ///
/// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage) /// [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage)
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct Request<'a, T> { pub struct Request<'a, T> {
jsonrpc: &'static str, jsonrpc: &'static str,
id: usize, id: RequestId,
method: &'a str, method: &'a str,
params: T, params: T,
} }
@ -107,7 +117,7 @@ pub struct Request<'a, T> {
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct AnyResponse<'a> { struct AnyResponse<'a> {
jsonrpc: &'a str, jsonrpc: &'a str,
id: usize, id: RequestId,
#[serde(default)] #[serde(default)]
error: Option<Error>, error: Option<Error>,
#[serde(borrow)] #[serde(borrow)]
@ -120,7 +130,7 @@ struct AnyResponse<'a> {
#[derive(Serialize)] #[derive(Serialize)]
struct Response<T> { struct Response<T> {
jsonrpc: &'static str, jsonrpc: &'static str,
id: usize, id: RequestId,
result: Option<T>, result: Option<T>,
error: Option<Error>, error: Option<Error>,
} }
@ -140,7 +150,7 @@ struct Notification<'a, T> {
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
struct AnyNotification<'a> { struct AnyNotification<'a> {
#[serde(default)] #[serde(default)]
id: Option<usize>, id: Option<RequestId>,
#[serde(borrow)] #[serde(borrow)]
method: &'a str, method: &'a str,
#[serde(borrow, default)] #[serde(borrow, default)]
@ -305,8 +315,8 @@ impl LanguageServer {
stdout: Stdout, stdout: Stdout,
mut on_unhandled_notification: F, mut on_unhandled_notification: F,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>, notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>, response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>, io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
cx: AsyncAppContext, cx: AsyncAppContext,
) -> anyhow::Result<()> ) -> anyhow::Result<()>
where where
@ -387,7 +397,7 @@ impl LanguageServer {
async fn handle_stderr<Stderr>( async fn handle_stderr<Stderr>(
stderr: Stderr, stderr: Stderr,
io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>, io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
stderr_capture: Arc<Mutex<Option<String>>>, stderr_capture: Arc<Mutex<Option<String>>>,
) -> anyhow::Result<()> ) -> anyhow::Result<()>
where where
@ -424,8 +434,8 @@ impl LanguageServer {
stdin: Stdin, stdin: Stdin,
outbound_rx: channel::Receiver<String>, outbound_rx: channel::Receiver<String>,
output_done_tx: barrier::Sender, output_done_tx: barrier::Sender,
response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>, response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
io_handlers: Arc<Mutex<HashMap<usize, IoHandler>>>, io_handlers: Arc<Mutex<HashMap<i32, IoHandler>>>,
) -> anyhow::Result<()> ) -> anyhow::Result<()>
where where
Stdin: AsyncWrite + Unpin + Send + 'static, Stdin: AsyncWrite + Unpin + Send + 'static,
@ -621,7 +631,7 @@ impl LanguageServer {
pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Option<()>>> { pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Option<()>>> {
if let Some(tasks) = self.io_tasks.lock().take() { if let Some(tasks) = self.io_tasks.lock().take() {
let response_handlers = self.response_handlers.clone(); let response_handlers = self.response_handlers.clone();
let next_id = AtomicUsize::new(self.next_id.load(SeqCst)); let next_id = AtomicI32::new(self.next_id.load(SeqCst));
let outbound_tx = self.outbound_tx.clone(); let outbound_tx = self.outbound_tx.clone();
let executor = self.executor.clone(); let executor = self.executor.clone();
let mut output_done = self.output_done_rx.lock().take().unwrap(); let mut output_done = self.output_done_rx.lock().take().unwrap();
@ -850,8 +860,8 @@ impl LanguageServer {
} }
fn request_internal<T: request::Request>( fn request_internal<T: request::Request>(
next_id: &AtomicUsize, next_id: &AtomicI32,
response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>, response_handlers: &Mutex<Option<HashMap<RequestId, ResponseHandler>>>,
outbound_tx: &channel::Sender<String>, outbound_tx: &channel::Sender<String>,
executor: &BackgroundExecutor, executor: &BackgroundExecutor,
params: T::Params, params: T::Params,
@ -862,7 +872,7 @@ impl LanguageServer {
let id = next_id.fetch_add(1, SeqCst); let id = next_id.fetch_add(1, SeqCst);
let message = serde_json::to_string(&Request { let message = serde_json::to_string(&Request {
jsonrpc: JSON_RPC_VERSION, jsonrpc: JSON_RPC_VERSION,
id, id: RequestId::Int(id),
method: T::METHOD, method: T::METHOD,
params, params,
}) })
@ -876,7 +886,7 @@ impl LanguageServer {
.map(|handlers| { .map(|handlers| {
let executor = executor.clone(); let executor = executor.clone();
handlers.insert( handlers.insert(
id, RequestId::Int(id),
Box::new(move |result| { Box::new(move |result| {
executor executor
.spawn(async move { .spawn(async move {
@ -1340,4 +1350,31 @@ mod tests {
b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n" b"Content-Length: 1235\r\nContent-Type: application/vscode-jsonrpc\r\n\r\n"
); );
} }
#[gpui::test]
fn test_deserialize_string_digit_id() {
let json = r#"{"jsonrpc":"2.0","id":"2","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
let notification = serde_json::from_str::<AnyNotification>(json)
.expect("message with string id should be parsed");
let expected_id = RequestId::Str("2".to_string());
assert_eq!(notification.id, Some(expected_id));
}
#[gpui::test]
fn test_deserialize_string_id() {
let json = r#"{"jsonrpc":"2.0","id":"anythingAtAll","method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
let notification = serde_json::from_str::<AnyNotification>(json)
.expect("message with string id should be parsed");
let expected_id = RequestId::Str("anythingAtAll".to_string());
assert_eq!(notification.id, Some(expected_id));
}
#[gpui::test]
fn test_deserialize_int_id() {
let json = r#"{"jsonrpc":"2.0","id":2,"method":"workspace/configuration","params":{"items":[{"scopeUri":"file:///Users/mph/Devel/personal/hello-scala/","section":"metals"}]}}"#;
let notification = serde_json::from_str::<AnyNotification>(json)
.expect("message with string id should be parsed");
let expected_id = RequestId::Int(2);
assert_eq!(notification.id, Some(expected_id));
}
} }