ssh: Limit amount of reconnect attempts (#18819)

Co-Authored-by: Thorsten <thorsten@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2024-10-08 11:37:54 +02:00 committed by GitHub
parent 910a773b89
commit fa85238c69
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 415 additions and 127 deletions

1
Cargo.lock generated
View file

@ -11885,6 +11885,7 @@ dependencies = [
"pretty_assertions", "pretty_assertions",
"project", "project",
"recent_projects", "recent_projects",
"remote",
"rpc", "rpc",
"serde", "serde",
"settings", "settings",

View file

@ -1263,8 +1263,10 @@ impl Project {
.clone() .clone()
} }
pub fn ssh_is_connected(&self, cx: &AppContext) -> Option<bool> { pub fn ssh_connection_state(&self, cx: &AppContext) -> Option<remote::ConnectionState> {
Some(!self.ssh_client.as_ref()?.read(cx).is_reconnect_underway()) self.ssh_client
.as_ref()
.map(|ssh| ssh.read(cx).connection_state())
} }
pub fn replica_id(&self) -> ReplicaId { pub fn replica_id(&self) -> ReplicaId {

View file

@ -2,4 +2,6 @@ pub mod json_log;
pub mod protocol; pub mod protocol;
pub mod ssh_session; pub mod ssh_session;
pub use ssh_session::{SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteClient}; pub use ssh_session::{
ConnectionState, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteClient,
};

View file

@ -31,7 +31,8 @@ use smol::{
use std::{ use std::{
any::TypeId, any::TypeId,
ffi::OsStr, ffi::OsStr,
mem, fmt,
ops::ControlFlow,
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::{ sync::{
atomic::{AtomicU32, Ordering::SeqCst}, atomic::{AtomicU32, Ordering::SeqCst},
@ -40,7 +41,7 @@ use std::{
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use tempfile::TempDir; use tempfile::TempDir;
use util::maybe; use util::ResultExt;
#[derive( #[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
@ -234,19 +235,157 @@ impl ChannelForwarder {
} }
} }
struct SshRemoteClientState { const MAX_MISSED_HEARTBEATS: usize = 5;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
const MAX_RECONNECT_ATTEMPTS: usize = 3;
enum State {
Connecting,
Connected {
ssh_connection: SshRemoteConnection, ssh_connection: SshRemoteConnection,
delegate: Arc<dyn SshClientDelegate>, delegate: Arc<dyn SshClientDelegate>,
forwarder: ChannelForwarder, forwarder: ChannelForwarder,
multiplex_task: Task<Result<()>>, multiplex_task: Task<Result<()>>,
heartbeat_task: Task<Result<()>>, heartbeat_task: Task<Result<()>>,
},
HeartbeatMissed {
missed_heartbeats: usize,
ssh_connection: SshRemoteConnection,
delegate: Arc<dyn SshClientDelegate>,
forwarder: ChannelForwarder,
multiplex_task: Task<Result<()>>,
heartbeat_task: Task<Result<()>>,
},
Reconnecting,
ReconnectFailed {
ssh_connection: SshRemoteConnection,
delegate: Arc<dyn SshClientDelegate>,
forwarder: ChannelForwarder,
error: anyhow::Error,
attempts: usize,
},
ReconnectExhausted,
}
impl fmt::Display for State {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Connecting => write!(f, "connecting"),
Self::Connected { .. } => write!(f, "connected"),
Self::Reconnecting => write!(f, "reconnecting"),
Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
}
}
}
impl State {
fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
match self {
Self::Connected { ssh_connection, .. } => Some(ssh_connection),
Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
_ => None,
}
}
fn can_reconnect(&self) -> bool {
matches!(
self,
Self::Connected { .. } | Self::HeartbeatMissed { .. } | Self::ReconnectFailed { .. }
)
}
fn heartbeat_recovered(self) -> Self {
match self {
Self::HeartbeatMissed {
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
..
} => Self::Connected {
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
},
_ => self,
}
}
fn heartbeat_missed(self) -> Self {
match self {
Self::Connected {
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
} => Self::HeartbeatMissed {
missed_heartbeats: 1,
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
},
Self::HeartbeatMissed {
missed_heartbeats,
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
} => Self::HeartbeatMissed {
missed_heartbeats: missed_heartbeats + 1,
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
},
_ => self,
}
}
}
/// The state of the ssh connection.
#[derive(Clone, Copy, Debug)]
pub enum ConnectionState {
Connecting,
Connected,
HeartbeatMissed,
Reconnecting,
Disconnected,
}
impl From<&State> for ConnectionState {
fn from(value: &State) -> Self {
match value {
State::Connecting => Self::Connecting,
State::Connected { .. } => Self::Connected,
State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
State::ReconnectExhausted => Self::Disconnected,
}
}
} }
pub struct SshRemoteClient { pub struct SshRemoteClient {
client: Arc<ChannelClient>, client: Arc<ChannelClient>,
unique_identifier: String, unique_identifier: String,
connection_options: SshConnectionOptions, connection_options: SshConnectionOptions,
inner_state: Arc<Mutex<Option<SshRemoteClientState>>>, state: Arc<Mutex<Option<State>>>,
} }
impl Drop for SshRemoteClient { impl Drop for SshRemoteClient {
@ -266,6 +405,7 @@ impl SshRemoteClient {
let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>(); let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>(); let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
let this = cx.new_model(|cx| { let this = cx.new_model(|cx| {
cx.on_app_quit(|this: &mut Self, _| { cx.on_app_quit(|this: &mut Self, _| {
this.shutdown_processes(); this.shutdown_processes();
@ -273,16 +413,14 @@ impl SshRemoteClient {
}) })
.detach(); .detach();
let client = ChannelClient::new(incoming_rx, outgoing_tx, cx);
Self { Self {
client, client: client.clone(),
unique_identifier: unique_identifier.clone(), unique_identifier: unique_identifier.clone(),
connection_options: SshConnectionOptions::default(), connection_options: connection_options.clone(),
inner_state: Arc::new(Mutex::new(None)), state: Arc::new(Mutex::new(Some(State::Connecting))),
} }
})?; })?;
let inner_state = {
let (proxy, proxy_incoming_tx, proxy_outgoing_rx) = let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
@ -302,18 +440,22 @@ impl SshRemoteClient {
&mut cx, &mut cx,
); );
SshRemoteClientState { if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
log::error!("failed to establish connection: {}", error);
delegate.set_error(error.to_string(), &mut cx);
return Err(error);
}
let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);
this.update(&mut cx, |this, _| {
*this.state.lock() = Some(State::Connected {
ssh_connection, ssh_connection,
delegate, delegate,
forwarder: proxy, forwarder: proxy,
multiplex_task, multiplex_task,
heartbeat_task: Self::heartbeat(this.downgrade(), &mut cx), heartbeat_task,
} });
};
this.update(&mut cx, |this, cx| {
this.inner_state.lock().replace(inner_state);
cx.notify();
})?; })?;
Ok(this) Ok(this)
@ -321,78 +463,192 @@ impl SshRemoteClient {
} }
fn shutdown_processes(&self) { fn shutdown_processes(&self) {
let Some(mut state) = self.inner_state.lock().take() else { let Some(state) = self.state.lock().take() else {
return; return;
}; };
log::info!("shutting down ssh processes"); log::info!("shutting down ssh processes");
// Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
// child of master_process.
let task = mem::replace(&mut state.multiplex_task, Task::ready(Ok(())));
drop(task);
// Now drop the rest of state, which kills master process.
drop(state);
}
fn reconnect(&self, cx: &ModelContext<Self>) -> Result<()> { let State::Connected {
log::info!("Trying to reconnect to ssh server...");
let Some(state) = self.inner_state.lock().take() else {
return Err(anyhow!("reconnect is already in progress"));
};
let workspace_identifier = self.unique_identifier.clone();
let SshRemoteClientState {
mut ssh_connection,
delegate,
forwarder: proxy,
multiplex_task, multiplex_task,
heartbeat_task, heartbeat_task,
} = state; ..
} = state
else {
return;
};
// Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
// child of master_process.
drop(multiplex_task);
// Now drop the rest of state, which kills master process.
drop(heartbeat_task);
}
fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
let mut lock = self.state.lock();
let can_reconnect = lock
.as_ref()
.map(|state| state.can_reconnect())
.unwrap_or(false);
if !can_reconnect {
let error = if let Some(state) = lock.as_ref() {
format!("invalid state, cannot reconnect while in state {state}")
} else {
"no state set".to_string()
};
return Err(anyhow!(error));
}
let state = lock.take().unwrap();
let (attempts, mut ssh_connection, delegate, forwarder) = match state {
State::Connected {
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
}
| State::HeartbeatMissed {
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
..
} => {
drop(multiplex_task); drop(multiplex_task);
drop(heartbeat_task); drop(heartbeat_task);
(0, ssh_connection, delegate, forwarder)
}
State::ReconnectFailed {
attempts,
ssh_connection,
delegate,
forwarder,
..
} => (attempts, ssh_connection, delegate, forwarder),
State::Connecting | State::Reconnecting | State::ReconnectExhausted => unreachable!(),
};
cx.spawn(|this, mut cx| async move { let attempts = attempts + 1;
let (incoming_tx, outgoing_rx) = proxy.into_channels().await; if attempts > MAX_RECONNECT_ATTEMPTS {
log::error!(
"Failed to reconnect to after {} attempts, giving up",
MAX_RECONNECT_ATTEMPTS
);
*lock = Some(State::ReconnectExhausted);
return Ok(());
}
*lock = Some(State::Reconnecting);
drop(lock);
ssh_connection.master_process.kill()?; log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
ssh_connection
let identifier = self.unique_identifier.clone();
let client = self.client.clone();
let reconnect_task = cx.spawn(|this, mut cx| async move {
macro_rules! failed {
($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
return State::ReconnectFailed {
error: anyhow!($error),
attempts: $attempts,
ssh_connection: $ssh_connection,
delegate: $delegate,
forwarder: $forwarder,
};
};
}
if let Err(error) = ssh_connection.master_process.kill() {
failed!(error, attempts, ssh_connection, delegate, forwarder);
};
if let Err(error) = ssh_connection
.master_process .master_process
.status() .status()
.await .await
.context("Failed to kill ssh process")?; .context("Failed to kill ssh process")
{
failed!(error, attempts, ssh_connection, delegate, forwarder);
}
let connection_options = ssh_connection.socket.connection_options.clone(); let connection_options = ssh_connection.socket.connection_options.clone();
let (ssh_connection, ssh_process) = Self::establish_connection( let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
workspace_identifier, let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
let (ssh_connection, ssh_process) = match Self::establish_connection(
identifier,
connection_options, connection_options,
delegate.clone(), delegate.clone(),
&mut cx, &mut cx,
) )
.await?; .await
{
Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
Err(error) => {
failed!(error, attempts, ssh_connection, delegate, forwarder);
}
};
let (proxy, proxy_incoming_tx, proxy_outgoing_rx) = let multiplex_task = Self::multiplex(
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
let inner_state = SshRemoteClientState {
ssh_connection,
delegate,
forwarder: proxy,
multiplex_task: Self::multiplex(
this.clone(), this.clone(),
ssh_process, ssh_process,
proxy_incoming_tx, proxy_incoming_tx,
proxy_outgoing_rx, proxy_outgoing_rx,
&mut cx, &mut cx,
), );
heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
failed!(error, attempts, ssh_connection, delegate, forwarder);
}; };
this.update(&mut cx, |this, _| { State::Connected {
this.inner_state.lock().replace(inner_state); ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
}
});
cx.spawn(|this, mut cx| async move {
let new_state = reconnect_task.await;
this.update(&mut cx, |this, cx| {
match &new_state {
State::Connecting
| State::Reconnecting { .. }
| State::HeartbeatMissed { .. } => {}
State::Connected { .. } => {
log::info!("Successfully reconnected");
}
State::ReconnectFailed {
error, attempts, ..
} => {
log::error!(
"Reconnect attempt {} failed: {:?}. Starting new attempt...",
attempts,
error
);
}
State::ReconnectExhausted => {
log::error!("Reconnect attempt failed and all attempts exhausted");
}
}
let reconnect_failed = matches!(new_state, State::ReconnectFailed { .. });
*this.state.lock() = Some(new_state);
cx.notify();
if reconnect_failed {
this.reconnect(cx)
} else {
Ok(())
}
}) })
}) })
.detach(); .detach_and_log_err(cx);
Ok(()) Ok(())
} }
@ -403,10 +659,6 @@ impl SshRemoteClient {
cx.spawn(|mut cx| { cx.spawn(|mut cx| {
let this = this.clone(); let this = this.clone();
async move { async move {
const MAX_MISSED_HEARTBEATS: usize = 5;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
let mut missed_heartbeats = 0; let mut missed_heartbeats = 0;
let mut timer = Timer::interval(HEARTBEAT_INTERVAL); let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
@ -415,19 +667,7 @@ impl SshRemoteClient {
log::info!("Sending heartbeat to server..."); log::info!("Sending heartbeat to server...");
let result = smol::future::or( let result = client.ping(HEARTBEAT_TIMEOUT).await;
async {
client.request(proto::Ping {}).await?;
Ok(())
},
async {
smol::Timer::after(HEARTBEAT_TIMEOUT).await;
Err(anyhow!("Timeout detected"))
},
)
.await;
if result.is_err() { if result.is_err() {
missed_heartbeats += 1; missed_heartbeats += 1;
log::warn!( log::warn!(
@ -440,17 +680,10 @@ impl SshRemoteClient {
missed_heartbeats = 0; missed_heartbeats = 0;
} }
if missed_heartbeats >= MAX_MISSED_HEARTBEATS { let result = this.update(&mut cx, |this, mut cx| {
log::error!( this.handle_heartbeat_result(missed_heartbeats, &mut cx)
"Missed last {} hearbeats. Reconnecting...", })?;
missed_heartbeats if result.is_break() {
);
this.update(&mut cx, |this, cx| {
this.reconnect(cx)
.context("failed to reconnect after missing heartbeats")
})
.context("failed to update weak reference, SshRemoteClient lost?")??;
return Ok(()); return Ok(());
} }
} }
@ -458,6 +691,34 @@ impl SshRemoteClient {
}) })
} }
fn handle_heartbeat_result(
&mut self,
missed_heartbeats: usize,
cx: &mut ModelContext<Self>,
) -> ControlFlow<()> {
let state = self.state.lock().take().unwrap();
self.state.lock().replace(if missed_heartbeats > 0 {
state.heartbeat_missed()
} else {
state.heartbeat_recovered()
});
cx.notify();
if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
log::error!(
"Missed last {} heartbeats. Reconnecting...",
missed_heartbeats
);
self.reconnect(cx)
.context("failed to start reconnect process after missing heartbeats")
.log_err();
ControlFlow::Break(())
} else {
ControlFlow::Continue(())
}
}
fn multiplex( fn multiplex(
this: WeakModel<Self>, this: WeakModel<Self>,
mut ssh_proxy_process: Child, mut ssh_proxy_process: Child,
@ -611,10 +872,11 @@ impl SshRemoteClient {
} }
pub fn ssh_args(&self) -> Option<Vec<String>> { pub fn ssh_args(&self) -> Option<Vec<String>> {
let state = self.inner_state.lock(); self.state
state .lock()
.as_ref() .as_ref()
.map(|state| state.ssh_connection.socket.ssh_args()) .and_then(|state| state.ssh_connection())
.map(|ssh_connection| ssh_connection.socket.ssh_args())
} }
pub fn to_proto_client(&self) -> AnyProtoClient { pub fn to_proto_client(&self) -> AnyProtoClient {
@ -625,8 +887,12 @@ impl SshRemoteClient {
self.connection_options.connection_string() self.connection_options.connection_string()
} }
pub fn is_reconnect_underway(&self) -> bool { pub fn connection_state(&self) -> ConnectionState {
maybe!({ Some(self.inner_state.try_lock()?.is_none()) }).unwrap_or_default() self.state
.lock()
.as_ref()
.map(ConnectionState::from)
.unwrap_or(ConnectionState::Disconnected)
} }
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
@ -646,7 +912,7 @@ impl SshRemoteClient {
client, client,
unique_identifier: "fake".to_string(), unique_identifier: "fake".to_string(),
connection_options: SshConnectionOptions::default(), connection_options: SshConnectionOptions::default(),
inner_state: Arc::new(Mutex::new(None)), state: Arc::new(Mutex::new(None)),
}) })
}), }),
server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)), server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
@ -1046,6 +1312,20 @@ impl ChannelClient {
} }
} }
pub async fn ping(&self, timeout: Duration) -> Result<()> {
smol::future::or(
async {
self.request(proto::Ping {}).await?;
Ok(())
},
async {
smol::Timer::after(timeout).await;
Err(anyhow!("Timeout detected"))
},
)
.await
}
pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> { pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
log::debug!("ssh send name:{}", T::NAME); log::debug!("ssh send name:{}", T::NAME);
self.send_dynamic(payload.into_envelope(0, None, None)) self.send_dynamic(payload.into_envelope(0, None, None))

View file

@ -41,6 +41,7 @@ gpui.workspace = true
notifications.workspace = true notifications.workspace = true
project.workspace = true project.workspace = true
recent_projects.workspace = true recent_projects.workspace = true
remote.workspace = true
rpc.workspace = true rpc.workspace = true
serde.workspace = true serde.workspace = true
smallvec.workspace = true smallvec.workspace = true

View file

@ -265,10 +265,12 @@ impl TitleBar {
fn render_ssh_project_host(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> { fn render_ssh_project_host(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
let host = self.project.read(cx).ssh_connection_string(cx)?; let host = self.project.read(cx).ssh_connection_string(cx)?;
let meta = SharedString::from(format!("Connected to: {host}")); let meta = SharedString::from(format!("Connected to: {host}"));
let indicator_color = if self.project.read(cx).ssh_is_connected(cx)? { let indicator_color = match self.project.read(cx).ssh_connection_state(cx)? {
Color::Success remote::ConnectionState::Connecting => Color::Info,
} else { remote::ConnectionState::Connected => Color::Success,
Color::Warning remote::ConnectionState::HeartbeatMissed => Color::Warning,
remote::ConnectionState::Reconnecting => Color::Warning,
remote::ConnectionState::Disconnected => Color::Error,
}; };
let indicator = div() let indicator = div()
.absolute() .absolute()