From 25a97a6a2be277c2c0897a651658254de34d0bcd Mon Sep 17 00:00:00 2001
From: Bennet Bo Fenner
Date: Mon, 7 Oct 2024 15:08:16 +0200
Subject: [PATCH] ssh: Detect timeouts when server is unresponsive (#18808)

To detect connection timeouts we ping the remote server every X seconds
and attempt to reconnect if the server fails to respond. Next up is
showing feedback in the UI to make this visible to the user, and
stopping the reconnection attempts after X retries.

Release Notes:

- N/A

---------

Co-authored-by: Thorsten
---
 crates/remote/src/ssh_session.rs             | 75 +++++++++++++++++++-
 crates/remote_server/src/headless_project.rs | 10 +++
 2 files changed, 82 insertions(+), 3 deletions(-)

diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs
index 05208dabd7..0a1cd00992 100644
--- a/crates/remote/src/ssh_session.rs
+++ b/crates/remote/src/ssh_session.rs
@@ -26,6 +26,7 @@ use rpc::{
 use smol::{
     fs,
     process::{self, Child, Stdio},
+    Timer,
 };
 use std::{
     any::TypeId,
@@ -36,7 +37,7 @@ use std::{
         atomic::{AtomicU32, Ordering::SeqCst},
         Arc,
     },
-    time::Instant,
+    time::{Duration, Instant},
 };
 use tempfile::TempDir;
 use util::maybe;
@@ -173,7 +174,7 @@ async fn run_cmd(command: &mut process::Command) -> Result<String> {
 #[cfg(unix)]
 async fn read_with_timeout(
     stdout: &mut process::ChildStdout,
-    timeout: std::time::Duration,
+    timeout: Duration,
     output: &mut Vec<u8>,
 ) -> Result<(), std::io::Error> {
     smol::future::or(
@@ -260,6 +261,7 @@ struct SshRemoteClientState {
     delegate: Arc<dyn SshClientDelegate>,
     forwarder: ChannelForwarder,
     multiplex_task: Task<Result<()>>,
+    heartbeat_task: Task<Result<()>>,
 }
 
 pub struct SshRemoteClient {
@@ -327,6 +329,7 @@ impl SshRemoteClient {
                 delegate,
                 forwarder: proxy,
                 multiplex_task,
+                heartbeat_task: Self::heartbeat(this.downgrade(), &mut cx),
             }
         };
 
@@ -353,6 +356,7 @@ impl SshRemoteClient {
     }
 
     fn reconnect(&self, cx: &ModelContext<Self>) -> Result<()> {
+        log::info!("Trying to reconnect to ssh server...");
         let Some(state) = self.inner_state.lock().take() else {
             return Err(anyhow!("reconnect is already in progress"));
         };
@@ -364,8 +368,10 @@ impl SshRemoteClient {
             delegate,
             forwarder: proxy,
             multiplex_task,
+            heartbeat_task,
         } = state;
         drop(multiplex_task);
+        drop(heartbeat_task);
 
         cx.spawn(|this, mut cx| async move {
             let (incoming_tx, outgoing_rx) = proxy.into_channels().await;
@@ -401,6 +407,7 @@ impl SshRemoteClient {
                     proxy_outgoing_rx,
                     &mut cx,
                 ),
+                heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
             };
 
             this.update(&mut cx, |this, _| {
@@ -411,6 +418,68 @@ impl SshRemoteClient {
         Ok(())
     }
 
+    fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
+        let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
+            return Task::ready(Err(anyhow!("SshRemoteClient lost")));
+        };
+        cx.spawn(|mut cx| {
+            let this = this.clone();
+            async move {
+                const MAX_MISSED_HEARTBEATS: usize = 5;
+                const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
+                const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
+
+                let mut missed_heartbeats = 0;
+
+                let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
+                loop {
+                    timer.next().await;
+
+                    log::info!("Sending heartbeat to server...");
+
+                    let result = smol::future::or(
+                        async {
+                            client.request(proto::Ping {}).await?;
+                            Ok(())
+                        },
+                        async {
+                            smol::Timer::after(HEARTBEAT_TIMEOUT).await;
+
+                            Err(anyhow!("Timeout detected"))
+                        },
+                    )
+                    .await;
+
+                    if result.is_err() {
+                        missed_heartbeats += 1;
+                        log::warn!(
+                            "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
+                            HEARTBEAT_TIMEOUT,
+                            missed_heartbeats,
+                            MAX_MISSED_HEARTBEATS
+                        );
+                    } else {
+                        missed_heartbeats = 0;
+                    }
+
+                    if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
+                        log::error!(
+                            "Missed last {} heartbeats. Reconnecting...",
+                            missed_heartbeats
+                        );
+
+                        this.update(&mut cx, |this, cx| {
+                            this.reconnect(cx)
+                                .context("failed to reconnect after missing heartbeats")
+                        })
+                        .context("failed to update weak reference, SshRemoteClient lost?")??;
+                        return Ok(());
+                    }
+                }
+            }
+        })
+    }
+
     fn multiplex(
         this: WeakModel<Self>,
         mut ssh_proxy_process: Child,
@@ -712,7 +781,7 @@ impl SshRemoteConnection {
         // has completed.
         let stdout = master_process.stdout.as_mut().unwrap();
         let mut output = Vec::new();
-        let connection_timeout = std::time::Duration::from_secs(10);
+        let connection_timeout = Duration::from_secs(10);
         let result = read_with_timeout(stdout, connection_timeout, &mut output).await;
         if let Err(e) = result {
             let error_message = if e.kind() == std::io::ErrorKind::TimedOut {
diff --git a/crates/remote_server/src/headless_project.rs b/crates/remote_server/src/headless_project.rs
index 66f9ca7ed5..0ad16caacc 100644
--- a/crates/remote_server/src/headless_project.rs
+++ b/crates/remote_server/src/headless_project.rs
@@ -113,6 +113,7 @@ impl HeadlessProject {
         client.add_request_handler(cx.weak_model(), Self::handle_list_remote_directory);
         client.add_request_handler(cx.weak_model(), Self::handle_check_file_exists);
         client.add_request_handler(cx.weak_model(), Self::handle_shutdown_remote_server);
+        client.add_request_handler(cx.weak_model(), Self::handle_ping);
 
         client.add_model_request_handler(Self::handle_add_worktree);
         client.add_model_request_handler(Self::handle_open_buffer_by_path);
@@ -354,4 +355,13 @@ impl HeadlessProject {
 
         Ok(proto::Ack {})
     }
+
+    pub async fn handle_ping(
+        _this: Model<Self>,
+        _envelope: TypedEnvelope<proto::Ping>,
+        _cx: AsyncAppContext,
+    ) -> Result<proto::Ack> {
+        log::debug!("Received ping from client");
+        Ok(proto::Ack {})
+    }
 }