From fa85238c6970ce828858535d1a0799af51500b03 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Tue, 8 Oct 2024 11:37:54 +0200 Subject: [PATCH] ssh: Limit amount of reconnect attempts (#18819) Co-Authored-by: Thorsten Release Notes: - N/A --------- Co-authored-by: Thorsten --- Cargo.lock | 1 + crates/project/src/project.rs | 6 +- crates/remote/src/remote.rs | 4 +- crates/remote/src/ssh_session.rs | 520 +++++++++++++++++++++++------- crates/title_bar/Cargo.toml | 1 + crates/title_bar/src/title_bar.rs | 10 +- 6 files changed, 415 insertions(+), 127 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9deb937370..dada7d97f8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11885,6 +11885,7 @@ dependencies = [ "pretty_assertions", "project", "recent_projects", + "remote", "rpc", "serde", "settings", diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index a0164dd981..8c2b4bd2a0 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -1263,8 +1263,10 @@ impl Project { .clone() } - pub fn ssh_is_connected(&self, cx: &AppContext) -> Option { - Some(!self.ssh_client.as_ref()?.read(cx).is_reconnect_underway()) + pub fn ssh_connection_state(&self, cx: &AppContext) -> Option { + self.ssh_client + .as_ref() + .map(|ssh| ssh.read(cx).connection_state()) } pub fn replica_id(&self) -> ReplicaId { diff --git a/crates/remote/src/remote.rs b/crates/remote/src/remote.rs index c3d9e8f9cc..3cfaf48a88 100644 --- a/crates/remote/src/remote.rs +++ b/crates/remote/src/remote.rs @@ -2,4 +2,6 @@ pub mod json_log; pub mod protocol; pub mod ssh_session; -pub use ssh_session::{SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteClient}; +pub use ssh_session::{ + ConnectionState, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteClient, +}; diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index 26ef8626ec..cf2a702231 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -31,7 +31,8 @@ use smol::{ use std::{ any::TypeId, ffi::OsStr, - mem, + fmt, + ops::ControlFlow, path::{Path, PathBuf}, sync::{ atomic::{AtomicU32, Ordering::SeqCst}, @@ -40,7 +41,7 @@ use std::{ time::{Duration, Instant}, }; use tempfile::TempDir; -use util::maybe; +use util::ResultExt; #[derive( Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize, @@ -234,19 +235,157 @@ impl ChannelForwarder { } } -struct SshRemoteClientState { - ssh_connection: SshRemoteConnection, - delegate: Arc, - forwarder: ChannelForwarder, - multiplex_task: Task>, - heartbeat_task: Task>, +const MAX_MISSED_HEARTBEATS: usize = 5; +const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); +const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5); + +const MAX_RECONNECT_ATTEMPTS: usize = 3; + +enum State { + Connecting, + Connected { + ssh_connection: SshRemoteConnection, + delegate: Arc, + forwarder: ChannelForwarder, + + multiplex_task: Task>, + heartbeat_task: Task>, + }, + HeartbeatMissed { + missed_heartbeats: usize, + + ssh_connection: SshRemoteConnection, + delegate: Arc, + forwarder: ChannelForwarder, + + multiplex_task: Task>, + heartbeat_task: Task>, + }, + Reconnecting, + ReconnectFailed { + ssh_connection: SshRemoteConnection, + delegate: Arc, + forwarder: ChannelForwarder, + + error: anyhow::Error, + attempts: usize, + }, + ReconnectExhausted, +} + +impl fmt::Display for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Connecting => write!(f, "connecting"), + Self::Connected { .. } => write!(f, "connected"), + Self::Reconnecting => write!(f, "reconnecting"), + Self::ReconnectFailed { .. } => write!(f, "reconnect failed"), + Self::ReconnectExhausted => write!(f, "reconnect exhausted"), + Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"), + } + } +} + +impl State { + fn ssh_connection(&self) -> Option<&SshRemoteConnection> { + match self { + Self::Connected { ssh_connection, .. } => Some(ssh_connection), + Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection), + Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection), + _ => None, + } + } + + fn can_reconnect(&self) -> bool { + matches!( + self, + Self::Connected { .. } | Self::HeartbeatMissed { .. } | Self::ReconnectFailed { .. } + ) + } + + fn heartbeat_recovered(self) -> Self { + match self { + Self::HeartbeatMissed { + ssh_connection, + delegate, + forwarder, + multiplex_task, + heartbeat_task, + .. + } => Self::Connected { + ssh_connection, + delegate, + forwarder, + multiplex_task, + heartbeat_task, + }, + _ => self, + } + } + + fn heartbeat_missed(self) -> Self { + match self { + Self::Connected { + ssh_connection, + delegate, + forwarder, + multiplex_task, + heartbeat_task, + } => Self::HeartbeatMissed { + missed_heartbeats: 1, + ssh_connection, + delegate, + forwarder, + multiplex_task, + heartbeat_task, + }, + Self::HeartbeatMissed { + missed_heartbeats, + ssh_connection, + delegate, + forwarder, + multiplex_task, + heartbeat_task, + } => Self::HeartbeatMissed { + missed_heartbeats: missed_heartbeats + 1, + ssh_connection, + delegate, + forwarder, + multiplex_task, + heartbeat_task, + }, + _ => self, + } + } +} + +/// The state of the ssh connection. +#[derive(Clone, Copy, Debug)] +pub enum ConnectionState { + Connecting, + Connected, + HeartbeatMissed, + Reconnecting, + Disconnected, +} + +impl From<&State> for ConnectionState { + fn from(value: &State) -> Self { + match value { + State::Connecting => Self::Connecting, + State::Connected { .. } => Self::Connected, + State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting, + State::HeartbeatMissed { .. } => Self::HeartbeatMissed, + State::ReconnectExhausted => Self::Disconnected, + } + } } pub struct SshRemoteClient { client: Arc, unique_identifier: String, connection_options: SshConnectionOptions, - inner_state: Arc>>, + state: Arc>>, } impl Drop for SshRemoteClient { @@ -266,6 +405,7 @@ impl SshRemoteClient { let (outgoing_tx, outgoing_rx) = mpsc::unbounded::(); let (incoming_tx, incoming_rx) = mpsc::unbounded::(); + let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?; let this = cx.new_model(|cx| { cx.on_app_quit(|this: &mut Self, _| { this.shutdown_processes(); @@ -273,47 +413,49 @@ impl SshRemoteClient { }) .detach(); - let client = ChannelClient::new(incoming_rx, outgoing_tx, cx); Self { - client, + client: client.clone(), unique_identifier: unique_identifier.clone(), - connection_options: SshConnectionOptions::default(), - inner_state: Arc::new(Mutex::new(None)), + connection_options: connection_options.clone(), + state: Arc::new(Mutex::new(Some(State::Connecting))), } })?; - let inner_state = { - let (proxy, proxy_incoming_tx, proxy_outgoing_rx) = - ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); + let (proxy, proxy_incoming_tx, proxy_outgoing_rx) = + ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); - let (ssh_connection, ssh_proxy_process) = Self::establish_connection( - unique_identifier, - connection_options, - delegate.clone(), - &mut cx, - ) - .await?; + let (ssh_connection, ssh_proxy_process) = Self::establish_connection( + unique_identifier, + connection_options, + delegate.clone(), + &mut cx, + ) + .await?; - let multiplex_task = Self::multiplex( - this.downgrade(), - ssh_proxy_process, - proxy_incoming_tx, - proxy_outgoing_rx, - &mut cx, - ); + let multiplex_task = Self::multiplex( + this.downgrade(), + ssh_proxy_process, + proxy_incoming_tx, + proxy_outgoing_rx, + &mut cx, + ); - SshRemoteClientState { + if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await { + log::error!("failed to establish connection: {}", error); + delegate.set_error(error.to_string(), &mut cx); + return Err(error); + } + + let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx); + + this.update(&mut cx, |this, _| { + *this.state.lock() = Some(State::Connected { ssh_connection, delegate, forwarder: proxy, multiplex_task, - heartbeat_task: Self::heartbeat(this.downgrade(), &mut cx), - } - }; - - this.update(&mut cx, |this, cx| { - this.inner_state.lock().replace(inner_state); - cx.notify(); + heartbeat_task, + }); })?; Ok(this) @@ -321,78 +463,192 @@ impl SshRemoteClient { } fn shutdown_processes(&self) { - let Some(mut state) = self.inner_state.lock().take() else { + let Some(state) = self.state.lock().take() else { return; }; log::info!("shutting down ssh processes"); - // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a - // child of master_process. - let task = mem::replace(&mut state.multiplex_task, Task::ready(Ok(()))); - drop(task); - // Now drop the rest of state, which kills master process. - drop(state); - } - fn reconnect(&self, cx: &ModelContext) -> Result<()> { - log::info!("Trying to reconnect to ssh server..."); - let Some(state) = self.inner_state.lock().take() else { - return Err(anyhow!("reconnect is already in progress")); - }; - - let workspace_identifier = self.unique_identifier.clone(); - - let SshRemoteClientState { - mut ssh_connection, - delegate, - forwarder: proxy, + let State::Connected { multiplex_task, heartbeat_task, - } = state; + .. + } = state + else { + return; + }; + // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a + // child of master_process. drop(multiplex_task); + // Now drop the rest of state, which kills master process. drop(heartbeat_task); + } - cx.spawn(|this, mut cx| async move { - let (incoming_tx, outgoing_rx) = proxy.into_channels().await; + fn reconnect(&mut self, cx: &mut ModelContext) -> Result<()> { + let mut lock = self.state.lock(); - ssh_connection.master_process.kill()?; - ssh_connection + let can_reconnect = lock + .as_ref() + .map(|state| state.can_reconnect()) + .unwrap_or(false); + if !can_reconnect { + let error = if let Some(state) = lock.as_ref() { + format!("invalid state, cannot reconnect while in state {state}") + } else { + "no state set".to_string() + }; + return Err(anyhow!(error)); + } + + let state = lock.take().unwrap(); + let (attempts, mut ssh_connection, delegate, forwarder) = match state { + State::Connected { + ssh_connection, + delegate, + forwarder, + multiplex_task, + heartbeat_task, + } + | State::HeartbeatMissed { + ssh_connection, + delegate, + forwarder, + multiplex_task, + heartbeat_task, + .. + } => { + drop(multiplex_task); + drop(heartbeat_task); + (0, ssh_connection, delegate, forwarder) + } + State::ReconnectFailed { + attempts, + ssh_connection, + delegate, + forwarder, + .. + } => (attempts, ssh_connection, delegate, forwarder), + State::Connecting | State::Reconnecting | State::ReconnectExhausted => unreachable!(), + }; + + let attempts = attempts + 1; + if attempts > MAX_RECONNECT_ATTEMPTS { + log::error!( + "Failed to reconnect to after {} attempts, giving up", + MAX_RECONNECT_ATTEMPTS + ); + *lock = Some(State::ReconnectExhausted); + return Ok(()); + } + *lock = Some(State::Reconnecting); + drop(lock); + + log::info!("Trying to reconnect to ssh server... Attempt {}", attempts); + + let identifier = self.unique_identifier.clone(); + let client = self.client.clone(); + let reconnect_task = cx.spawn(|this, mut cx| async move { + macro_rules! failed { + ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => { + return State::ReconnectFailed { + error: anyhow!($error), + attempts: $attempts, + ssh_connection: $ssh_connection, + delegate: $delegate, + forwarder: $forwarder, + }; + }; + } + + if let Err(error) = ssh_connection.master_process.kill() { + failed!(error, attempts, ssh_connection, delegate, forwarder); + }; + + if let Err(error) = ssh_connection .master_process .status() .await - .context("Failed to kill ssh process")?; + .context("Failed to kill ssh process") + { + failed!(error, attempts, ssh_connection, delegate, forwarder); + } let connection_options = ssh_connection.socket.connection_options.clone(); - let (ssh_connection, ssh_process) = Self::establish_connection( - workspace_identifier, + let (incoming_tx, outgoing_rx) = forwarder.into_channels().await; + let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) = + ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); + + let (ssh_connection, ssh_process) = match Self::establish_connection( + identifier, connection_options, delegate.clone(), &mut cx, ) - .await?; - - let (proxy, proxy_incoming_tx, proxy_outgoing_rx) = - ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); - - let inner_state = SshRemoteClientState { - ssh_connection, - delegate, - forwarder: proxy, - multiplex_task: Self::multiplex( - this.clone(), - ssh_process, - proxy_incoming_tx, - proxy_outgoing_rx, - &mut cx, - ), - heartbeat_task: Self::heartbeat(this.clone(), &mut cx), + .await + { + Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process), + Err(error) => { + failed!(error, attempts, ssh_connection, delegate, forwarder); + } }; - this.update(&mut cx, |this, _| { - this.inner_state.lock().replace(inner_state); + let multiplex_task = Self::multiplex( + this.clone(), + ssh_process, + proxy_incoming_tx, + proxy_outgoing_rx, + &mut cx, + ); + + if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await { + failed!(error, attempts, ssh_connection, delegate, forwarder); + }; + + State::Connected { + ssh_connection, + delegate, + forwarder, + multiplex_task, + heartbeat_task: Self::heartbeat(this.clone(), &mut cx), + } + }); + + cx.spawn(|this, mut cx| async move { + let new_state = reconnect_task.await; + this.update(&mut cx, |this, cx| { + match &new_state { + State::Connecting + | State::Reconnecting { .. } + | State::HeartbeatMissed { .. } => {} + State::Connected { .. } => { + log::info!("Successfully reconnected"); + } + State::ReconnectFailed { + error, attempts, .. + } => { + log::error!( + "Reconnect attempt {} failed: {:?}. Starting new attempt...", + attempts, + error + ); + } + State::ReconnectExhausted => { + log::error!("Reconnect attempt failed and all attempts exhausted"); + } + } + + let reconnect_failed = matches!(new_state, State::ReconnectFailed { .. }); + *this.state.lock() = Some(new_state); + cx.notify(); + if reconnect_failed { + this.reconnect(cx) + } else { + Ok(()) + } }) }) - .detach(); + .detach_and_log_err(cx); + Ok(()) } @@ -403,10 +659,6 @@ impl SshRemoteClient { cx.spawn(|mut cx| { let this = this.clone(); async move { - const MAX_MISSED_HEARTBEATS: usize = 5; - const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); - const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5); - let mut missed_heartbeats = 0; let mut timer = Timer::interval(HEARTBEAT_INTERVAL); @@ -415,19 +667,7 @@ impl SshRemoteClient { log::info!("Sending heartbeat to server..."); - let result = smol::future::or( - async { - client.request(proto::Ping {}).await?; - Ok(()) - }, - async { - smol::Timer::after(HEARTBEAT_TIMEOUT).await; - - Err(anyhow!("Timeout detected")) - }, - ) - .await; - + let result = client.ping(HEARTBEAT_TIMEOUT).await; if result.is_err() { missed_heartbeats += 1; log::warn!( @@ -440,17 +680,10 @@ impl SshRemoteClient { missed_heartbeats = 0; } - if missed_heartbeats >= MAX_MISSED_HEARTBEATS { - log::error!( - "Missed last {} hearbeats. Reconnecting...", - missed_heartbeats - ); - - this.update(&mut cx, |this, cx| { - this.reconnect(cx) - .context("failed to reconnect after missing heartbeats") - }) - .context("failed to update weak reference, SshRemoteClient lost?")??; + let result = this.update(&mut cx, |this, mut cx| { + this.handle_heartbeat_result(missed_heartbeats, &mut cx) + })?; + if result.is_break() { return Ok(()); } } @@ -458,6 +691,34 @@ impl SshRemoteClient { }) } + fn handle_heartbeat_result( + &mut self, + missed_heartbeats: usize, + cx: &mut ModelContext, + ) -> ControlFlow<()> { + let state = self.state.lock().take().unwrap(); + self.state.lock().replace(if missed_heartbeats > 0 { + state.heartbeat_missed() + } else { + state.heartbeat_recovered() + }); + cx.notify(); + + if missed_heartbeats >= MAX_MISSED_HEARTBEATS { + log::error!( + "Missed last {} heartbeats. Reconnecting...", + missed_heartbeats + ); + + self.reconnect(cx) + .context("failed to start reconnect process after missing heartbeats") + .log_err(); + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + } + fn multiplex( this: WeakModel, mut ssh_proxy_process: Child, @@ -611,10 +872,11 @@ impl SshRemoteClient { } pub fn ssh_args(&self) -> Option> { - let state = self.inner_state.lock(); - state + self.state + .lock() .as_ref() - .map(|state| state.ssh_connection.socket.ssh_args()) + .and_then(|state| state.ssh_connection()) + .map(|ssh_connection| ssh_connection.socket.ssh_args()) } pub fn to_proto_client(&self) -> AnyProtoClient { @@ -625,8 +887,12 @@ impl SshRemoteClient { self.connection_options.connection_string() } - pub fn is_reconnect_underway(&self) -> bool { - maybe!({ Some(self.inner_state.try_lock()?.is_none()) }).unwrap_or_default() + pub fn connection_state(&self) -> ConnectionState { + self.state + .lock() + .as_ref() + .map(ConnectionState::from) + .unwrap_or(ConnectionState::Disconnected) } #[cfg(any(test, feature = "test-support"))] @@ -646,7 +912,7 @@ impl SshRemoteClient { client, unique_identifier: "fake".to_string(), connection_options: SshConnectionOptions::default(), - inner_state: Arc::new(Mutex::new(None)), + state: Arc::new(Mutex::new(None)), }) }), server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)), @@ -1046,6 +1312,20 @@ impl ChannelClient { } } + pub async fn ping(&self, timeout: Duration) -> Result<()> { + smol::future::or( + async { + self.request(proto::Ping {}).await?; + Ok(()) + }, + async { + smol::Timer::after(timeout).await; + Err(anyhow!("Timeout detected")) + }, + ) + .await + } + pub fn send(&self, payload: T) -> Result<()> { log::debug!("ssh send name:{}", T::NAME); self.send_dynamic(payload.into_envelope(0, None, None)) diff --git a/crates/title_bar/Cargo.toml b/crates/title_bar/Cargo.toml index c837b74dca..e4d3d7fc5b 100644 --- a/crates/title_bar/Cargo.toml +++ b/crates/title_bar/Cargo.toml @@ -41,6 +41,7 @@ gpui.workspace = true notifications.workspace = true project.workspace = true recent_projects.workspace = true +remote.workspace = true rpc.workspace = true serde.workspace = true smallvec.workspace = true diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index 52dab68a2a..da0179fd64 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -265,10 +265,12 @@ impl TitleBar { fn render_ssh_project_host(&self, cx: &mut ViewContext) -> Option { let host = self.project.read(cx).ssh_connection_string(cx)?; let meta = SharedString::from(format!("Connected to: {host}")); - let indicator_color = if self.project.read(cx).ssh_is_connected(cx)? { - Color::Success - } else { - Color::Warning + let indicator_color = match self.project.read(cx).ssh_connection_state(cx)? { + remote::ConnectionState::Connecting => Color::Info, + remote::ConnectionState::Connected => Color::Success, + remote::ConnectionState::HeartbeatMissed => Color::Warning, + remote::ConnectionState::Reconnecting => Color::Warning, + remote::ConnectionState::Disconnected => Color::Error, }; let indicator = div() .absolute()