From 8a912726d7dd4136aae10fcc868e36e59db9c434 Mon Sep 17 00:00:00 2001 From: Mikayla Maki Date: Fri, 18 Oct 2024 15:41:43 -0700 Subject: [PATCH] Fix flakey SSH connection (#19439) Fixes a bug due to the `select!` macro tossing futures that had partially read messages, causing us to desync our message reading with the input stream. Release Notes: - N/A --------- Co-authored-by: Conrad Irwin Co-authored-by: conrad --- crates/remote/src/protocol.rs | 5 +- crates/remote/src/ssh_session.rs | 148 +++++++++++++++---------------- crates/remote_server/src/unix.rs | 23 +++-- 3 files changed, 92 insertions(+), 84 deletions(-) diff --git a/crates/remote/src/protocol.rs b/crates/remote/src/protocol.rs index 311385f73b..787094781d 100644 --- a/crates/remote/src/protocol.rs +++ b/crates/remote/src/protocol.rs @@ -2,7 +2,6 @@ use anyhow::Result; use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use prost::Message as _; use rpc::proto::Envelope; -use std::mem::size_of; #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] pub struct MessageId(pub u32); @@ -30,8 +29,10 @@ pub async fn read_message( ) -> Result { buffer.resize(MESSAGE_LEN_SIZE, 0); stream.read_exact(buffer).await?; + let len = message_len_from_buffer(buffer); - read_message_with_len(stream, buffer, len).await + let result = read_message_with_len(stream, buffer, len).await; + result } pub async fn write_message( diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index f7ef74ce39..1d8006a060 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -13,8 +13,7 @@ use futures::{ oneshot, }, future::BoxFuture, - select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt, - StreamExt as _, + select_biased, AsyncReadExt as _, Future, FutureExt as _, SinkExt, StreamExt as _, }; use gpui::{ AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task, @@ -922,92 +921,90 @@ impl SshRemoteClient { let mut child_stdout = ssh_proxy_process.stdout.take().unwrap(); let mut child_stdin = ssh_proxy_process.stdin.take().unwrap(); - let io_task = cx.background_executor().spawn(async move { - let mut stdin_buffer = Vec::new(); - let mut stdout_buffer = Vec::new(); - let mut stderr_buffer = Vec::new(); - let mut stderr_offset = 0; + let mut stdin_buffer = Vec::new(); + let mut stdout_buffer = Vec::new(); + let mut stderr_buffer = Vec::new(); + let mut stderr_offset = 0; - loop { - stdout_buffer.resize(MESSAGE_LEN_SIZE, 0); - stderr_buffer.resize(stderr_offset + 1024, 0); + let stdin_task = cx.background_executor().spawn(async move { + while let Some(outgoing) = outgoing_rx.next().await { + write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?; + } + anyhow::Ok(()) + }); - select_biased! { - outgoing = outgoing_rx.next().fuse() => { - let Some(outgoing) = outgoing else { - return anyhow::Ok(None); - }; + let stdout_task = cx.background_executor().spawn({ + let mut connection_activity_tx = connection_activity_tx.clone(); + async move { + loop { + stdout_buffer.resize(MESSAGE_LEN_SIZE, 0); + let len = child_stdout.read(&mut stdout_buffer).await?; - write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?; + if len == 0 { + return anyhow::Ok(()); } - result = child_stdout.read(&mut stdout_buffer).fuse() => { - match result { - Ok(0) => { - child_stdin.close().await?; - outgoing_rx.close(); - let status = ssh_proxy_process.status().await?; - // If we don't have a code, we assume process - // has been killed and treat it as non-zero exit - // code - return Ok(status.code().or_else(|| Some(1))); - } - Ok(len) => { - if len < stdout_buffer.len() { - child_stdout.read_exact(&mut stdout_buffer[len..]).await?; - } - - let message_len = message_len_from_buffer(&stdout_buffer); - match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await { - Ok(envelope) => { - connection_activity_tx.try_send(()).ok(); - incoming_tx.unbounded_send(envelope).ok(); - } - Err(error) => { - log::error!("error decoding message {error:?}"); - } - } - } - Err(error) => { - Err(anyhow!("error reading stdout: {error:?}"))?; - } - } + if len < MESSAGE_LEN_SIZE { + child_stdout.read_exact(&mut stdout_buffer[len..]).await?; } - result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => { - match result { - Ok(len) => { - stderr_offset += len; - let mut start_ix = 0; - while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') { - let line_ix = start_ix + ix; - let content = &stderr_buffer[start_ix..line_ix]; - start_ix = line_ix + 1; - if let Ok(record) = serde_json::from_slice::(content) { - record.log(log::logger()) - } else { - eprintln!("(remote) {}", String::from_utf8_lossy(content)); - } - } - stderr_buffer.drain(0..start_ix); - stderr_offset -= start_ix; - - connection_activity_tx.try_send(()).ok(); - } - Err(error) => { - Err(anyhow!("error reading stderr: {error:?}"))?; - } - } - } + let message_len = message_len_from_buffer(&stdout_buffer); + let envelope = + read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len) + .await?; + connection_activity_tx.try_send(()).ok(); + incoming_tx.unbounded_send(envelope).ok(); } } }); + let stderr_task: Task> = cx.background_executor().spawn(async move { + loop { + stderr_buffer.resize(stderr_offset + 1024, 0); + + let len = child_stderr + .read(&mut stderr_buffer[stderr_offset..]) + .await?; + + stderr_offset += len; + let mut start_ix = 0; + while let Some(ix) = stderr_buffer[start_ix..stderr_offset] + .iter() + .position(|b| b == &b'\n') + { + let line_ix = start_ix + ix; + let content = &stderr_buffer[start_ix..line_ix]; + start_ix = line_ix + 1; + if let Ok(record) = serde_json::from_slice::(content) { + record.log(log::logger()) + } else { + eprintln!("(remote) {}", String::from_utf8_lossy(content)); + } + } + stderr_buffer.drain(0..start_ix); + stderr_offset -= start_ix; + + connection_activity_tx.try_send(()).ok(); + } + }); + cx.spawn(|mut cx| async move { - let result = io_task.await; + let result = futures::select! { + result = stdin_task.fuse() => { + result.context("stdin") + } + result = stdout_task.fuse() => { + result.context("stdout") + } + result = stderr_task.fuse() => { + result.context("stderr") + } + }; match result { - Ok(Some(exit_code)) => { + Ok(_) => { + let exit_code = ssh_proxy_process.status().await?.code().unwrap_or(1); + if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) { match error { ProxyLaunchError::ServerNotRunning => { @@ -1025,7 +1022,6 @@ impl SshRemoteClient { })?; } } - Ok(None) => {} Err(error) => { log::warn!("ssh io task died with error: {:?}. reconnecting...", error); this.update(&mut cx, |this, cx| { @@ -1033,6 +1029,7 @@ impl SshRemoteClient { })?; } } + Ok(()) }) } @@ -1206,6 +1203,7 @@ impl SshRemoteConnection { delegate: Arc, cx: &mut AsyncAppContext, ) -> Result { + use futures::AsyncWriteExt as _; use futures::{io::BufReader, AsyncBufReadExt as _}; use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener}; use util::ResultExt as _; diff --git a/crates/remote_server/src/unix.rs b/crates/remote_server/src/unix.rs index 60b7fc458d..30b2bacd0a 100644 --- a/crates/remote_server/src/unix.rs +++ b/crates/remote_server/src/unix.rs @@ -12,6 +12,7 @@ use language::LanguageRegistry; use node_runtime::{NodeBinaryOptions, NodeRuntime}; use paths::logs_dir; use project::project_settings::ProjectSettings; + use remote::proxy::ProxyLaunchError; use remote::ssh_session::ChannelClient; use remote::{ @@ -213,19 +214,27 @@ fn start_server( let mut input_buffer = Vec::new(); let mut output_buffer = Vec::new(); + + let (mut stdin_msg_tx, mut stdin_msg_rx) = mpsc::unbounded::(); + cx.background_executor().spawn(async move { + while let Ok(msg) = read_message(&mut stdin_stream, &mut input_buffer).await { + if let Err(_) = stdin_msg_tx.send(msg).await { + break; + } + } + }).detach(); + loop { + select_biased! { _ = app_quit_rx.next().fuse() => { return anyhow::Ok(()); } - stdin_message = read_message(&mut stdin_stream, &mut input_buffer).fuse() => { - let message = match stdin_message { - Ok(message) => message, - Err(error) => { - log::warn!("error reading message on stdin: {}. exiting.", error); - break; - } + stdin_message = stdin_msg_rx.next().fuse() => { + let Some(message) = stdin_message else { + log::warn!("error reading message on stdin. exiting."); + break; }; if let Err(error) = incoming_tx.unbounded_send(message) { log::error!("failed to send message to application: {:?}. exiting.", error);