ssh: Limit amount of reconnect attempts (#18819)

Co-Authored-by: Thorsten <thorsten@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2024-10-08 11:37:54 +02:00 committed by GitHub
parent 910a773b89
commit fa85238c69
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 415 additions and 127 deletions

1
Cargo.lock generated
View file

@ -11885,6 +11885,7 @@ dependencies = [
"pretty_assertions",
"project",
"recent_projects",
"remote",
"rpc",
"serde",
"settings",

View file

@ -1263,8 +1263,10 @@ impl Project {
.clone()
}
pub fn ssh_is_connected(&self, cx: &AppContext) -> Option<bool> {
Some(!self.ssh_client.as_ref()?.read(cx).is_reconnect_underway())
pub fn ssh_connection_state(&self, cx: &AppContext) -> Option<remote::ConnectionState> {
self.ssh_client
.as_ref()
.map(|ssh| ssh.read(cx).connection_state())
}
pub fn replica_id(&self) -> ReplicaId {

View file

@ -2,4 +2,6 @@ pub mod json_log;
pub mod protocol;
pub mod ssh_session;
pub use ssh_session::{SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteClient};
pub use ssh_session::{
ConnectionState, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteClient,
};

View file

@ -31,7 +31,8 @@ use smol::{
use std::{
any::TypeId,
ffi::OsStr,
mem,
fmt,
ops::ControlFlow,
path::{Path, PathBuf},
sync::{
atomic::{AtomicU32, Ordering::SeqCst},
@ -40,7 +41,7 @@ use std::{
time::{Duration, Instant},
};
use tempfile::TempDir;
use util::maybe;
use util::ResultExt;
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
@ -234,19 +235,157 @@ impl ChannelForwarder {
}
}
struct SshRemoteClientState {
ssh_connection: SshRemoteConnection,
delegate: Arc<dyn SshClientDelegate>,
forwarder: ChannelForwarder,
multiplex_task: Task<Result<()>>,
heartbeat_task: Task<Result<()>>,
const MAX_MISSED_HEARTBEATS: usize = 5;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
const MAX_RECONNECT_ATTEMPTS: usize = 3;
enum State {
Connecting,
Connected {
ssh_connection: SshRemoteConnection,
delegate: Arc<dyn SshClientDelegate>,
forwarder: ChannelForwarder,
multiplex_task: Task<Result<()>>,
heartbeat_task: Task<Result<()>>,
},
HeartbeatMissed {
missed_heartbeats: usize,
ssh_connection: SshRemoteConnection,
delegate: Arc<dyn SshClientDelegate>,
forwarder: ChannelForwarder,
multiplex_task: Task<Result<()>>,
heartbeat_task: Task<Result<()>>,
},
Reconnecting,
ReconnectFailed {
ssh_connection: SshRemoteConnection,
delegate: Arc<dyn SshClientDelegate>,
forwarder: ChannelForwarder,
error: anyhow::Error,
attempts: usize,
},
ReconnectExhausted,
}
impl fmt::Display for State {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Connecting => write!(f, "connecting"),
Self::Connected { .. } => write!(f, "connected"),
Self::Reconnecting => write!(f, "reconnecting"),
Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
}
}
}
impl State {
fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
match self {
Self::Connected { ssh_connection, .. } => Some(ssh_connection),
Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
_ => None,
}
}
fn can_reconnect(&self) -> bool {
matches!(
self,
Self::Connected { .. } | Self::HeartbeatMissed { .. } | Self::ReconnectFailed { .. }
)
}
fn heartbeat_recovered(self) -> Self {
match self {
Self::HeartbeatMissed {
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
..
} => Self::Connected {
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
},
_ => self,
}
}
fn heartbeat_missed(self) -> Self {
match self {
Self::Connected {
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
} => Self::HeartbeatMissed {
missed_heartbeats: 1,
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
},
Self::HeartbeatMissed {
missed_heartbeats,
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
} => Self::HeartbeatMissed {
missed_heartbeats: missed_heartbeats + 1,
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
},
_ => self,
}
}
}
/// The state of the ssh connection.
#[derive(Clone, Copy, Debug)]
pub enum ConnectionState {
Connecting,
Connected,
HeartbeatMissed,
Reconnecting,
Disconnected,
}
impl From<&State> for ConnectionState {
fn from(value: &State) -> Self {
match value {
State::Connecting => Self::Connecting,
State::Connected { .. } => Self::Connected,
State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
State::ReconnectExhausted => Self::Disconnected,
}
}
}
pub struct SshRemoteClient {
client: Arc<ChannelClient>,
unique_identifier: String,
connection_options: SshConnectionOptions,
inner_state: Arc<Mutex<Option<SshRemoteClientState>>>,
state: Arc<Mutex<Option<State>>>,
}
impl Drop for SshRemoteClient {
@ -266,6 +405,7 @@ impl SshRemoteClient {
let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
let this = cx.new_model(|cx| {
cx.on_app_quit(|this: &mut Self, _| {
this.shutdown_processes();
@ -273,47 +413,49 @@ impl SshRemoteClient {
})
.detach();
let client = ChannelClient::new(incoming_rx, outgoing_tx, cx);
Self {
client,
client: client.clone(),
unique_identifier: unique_identifier.clone(),
connection_options: SshConnectionOptions::default(),
inner_state: Arc::new(Mutex::new(None)),
connection_options: connection_options.clone(),
state: Arc::new(Mutex::new(Some(State::Connecting))),
}
})?;
let inner_state = {
let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
unique_identifier,
connection_options,
delegate.clone(),
&mut cx,
)
.await?;
let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
unique_identifier,
connection_options,
delegate.clone(),
&mut cx,
)
.await?;
let multiplex_task = Self::multiplex(
this.downgrade(),
ssh_proxy_process,
proxy_incoming_tx,
proxy_outgoing_rx,
&mut cx,
);
let multiplex_task = Self::multiplex(
this.downgrade(),
ssh_proxy_process,
proxy_incoming_tx,
proxy_outgoing_rx,
&mut cx,
);
SshRemoteClientState {
if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
log::error!("failed to establish connection: {}", error);
delegate.set_error(error.to_string(), &mut cx);
return Err(error);
}
let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);
this.update(&mut cx, |this, _| {
*this.state.lock() = Some(State::Connected {
ssh_connection,
delegate,
forwarder: proxy,
multiplex_task,
heartbeat_task: Self::heartbeat(this.downgrade(), &mut cx),
}
};
this.update(&mut cx, |this, cx| {
this.inner_state.lock().replace(inner_state);
cx.notify();
heartbeat_task,
});
})?;
Ok(this)
@ -321,78 +463,192 @@ impl SshRemoteClient {
}
fn shutdown_processes(&self) {
let Some(mut state) = self.inner_state.lock().take() else {
let Some(state) = self.state.lock().take() else {
return;
};
log::info!("shutting down ssh processes");
// Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
// child of master_process.
let task = mem::replace(&mut state.multiplex_task, Task::ready(Ok(())));
drop(task);
// Now drop the rest of state, which kills master process.
drop(state);
}
fn reconnect(&self, cx: &ModelContext<Self>) -> Result<()> {
log::info!("Trying to reconnect to ssh server...");
let Some(state) = self.inner_state.lock().take() else {
return Err(anyhow!("reconnect is already in progress"));
};
let workspace_identifier = self.unique_identifier.clone();
let SshRemoteClientState {
mut ssh_connection,
delegate,
forwarder: proxy,
let State::Connected {
multiplex_task,
heartbeat_task,
} = state;
..
} = state
else {
return;
};
// Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
// child of master_process.
drop(multiplex_task);
// Now drop the rest of state, which kills master process.
drop(heartbeat_task);
}
cx.spawn(|this, mut cx| async move {
let (incoming_tx, outgoing_rx) = proxy.into_channels().await;
fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
let mut lock = self.state.lock();
ssh_connection.master_process.kill()?;
ssh_connection
let can_reconnect = lock
.as_ref()
.map(|state| state.can_reconnect())
.unwrap_or(false);
if !can_reconnect {
let error = if let Some(state) = lock.as_ref() {
format!("invalid state, cannot reconnect while in state {state}")
} else {
"no state set".to_string()
};
return Err(anyhow!(error));
}
let state = lock.take().unwrap();
let (attempts, mut ssh_connection, delegate, forwarder) = match state {
State::Connected {
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
}
| State::HeartbeatMissed {
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task,
..
} => {
drop(multiplex_task);
drop(heartbeat_task);
(0, ssh_connection, delegate, forwarder)
}
State::ReconnectFailed {
attempts,
ssh_connection,
delegate,
forwarder,
..
} => (attempts, ssh_connection, delegate, forwarder),
State::Connecting | State::Reconnecting | State::ReconnectExhausted => unreachable!(),
};
let attempts = attempts + 1;
if attempts > MAX_RECONNECT_ATTEMPTS {
log::error!(
"Failed to reconnect to after {} attempts, giving up",
MAX_RECONNECT_ATTEMPTS
);
*lock = Some(State::ReconnectExhausted);
return Ok(());
}
*lock = Some(State::Reconnecting);
drop(lock);
log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
let identifier = self.unique_identifier.clone();
let client = self.client.clone();
let reconnect_task = cx.spawn(|this, mut cx| async move {
macro_rules! failed {
($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
return State::ReconnectFailed {
error: anyhow!($error),
attempts: $attempts,
ssh_connection: $ssh_connection,
delegate: $delegate,
forwarder: $forwarder,
};
};
}
if let Err(error) = ssh_connection.master_process.kill() {
failed!(error, attempts, ssh_connection, delegate, forwarder);
};
if let Err(error) = ssh_connection
.master_process
.status()
.await
.context("Failed to kill ssh process")?;
.context("Failed to kill ssh process")
{
failed!(error, attempts, ssh_connection, delegate, forwarder);
}
let connection_options = ssh_connection.socket.connection_options.clone();
let (ssh_connection, ssh_process) = Self::establish_connection(
workspace_identifier,
let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
let (ssh_connection, ssh_process) = match Self::establish_connection(
identifier,
connection_options,
delegate.clone(),
&mut cx,
)
.await?;
let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
let inner_state = SshRemoteClientState {
ssh_connection,
delegate,
forwarder: proxy,
multiplex_task: Self::multiplex(
this.clone(),
ssh_process,
proxy_incoming_tx,
proxy_outgoing_rx,
&mut cx,
),
heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
.await
{
Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
Err(error) => {
failed!(error, attempts, ssh_connection, delegate, forwarder);
}
};
this.update(&mut cx, |this, _| {
this.inner_state.lock().replace(inner_state);
let multiplex_task = Self::multiplex(
this.clone(),
ssh_process,
proxy_incoming_tx,
proxy_outgoing_rx,
&mut cx,
);
if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
failed!(error, attempts, ssh_connection, delegate, forwarder);
};
State::Connected {
ssh_connection,
delegate,
forwarder,
multiplex_task,
heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
}
});
cx.spawn(|this, mut cx| async move {
let new_state = reconnect_task.await;
this.update(&mut cx, |this, cx| {
match &new_state {
State::Connecting
| State::Reconnecting { .. }
| State::HeartbeatMissed { .. } => {}
State::Connected { .. } => {
log::info!("Successfully reconnected");
}
State::ReconnectFailed {
error, attempts, ..
} => {
log::error!(
"Reconnect attempt {} failed: {:?}. Starting new attempt...",
attempts,
error
);
}
State::ReconnectExhausted => {
log::error!("Reconnect attempt failed and all attempts exhausted");
}
}
let reconnect_failed = matches!(new_state, State::ReconnectFailed { .. });
*this.state.lock() = Some(new_state);
cx.notify();
if reconnect_failed {
this.reconnect(cx)
} else {
Ok(())
}
})
})
.detach();
.detach_and_log_err(cx);
Ok(())
}
@ -403,10 +659,6 @@ impl SshRemoteClient {
cx.spawn(|mut cx| {
let this = this.clone();
async move {
const MAX_MISSED_HEARTBEATS: usize = 5;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
let mut missed_heartbeats = 0;
let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
@ -415,19 +667,7 @@ impl SshRemoteClient {
log::info!("Sending heartbeat to server...");
let result = smol::future::or(
async {
client.request(proto::Ping {}).await?;
Ok(())
},
async {
smol::Timer::after(HEARTBEAT_TIMEOUT).await;
Err(anyhow!("Timeout detected"))
},
)
.await;
let result = client.ping(HEARTBEAT_TIMEOUT).await;
if result.is_err() {
missed_heartbeats += 1;
log::warn!(
@ -440,17 +680,10 @@ impl SshRemoteClient {
missed_heartbeats = 0;
}
if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
log::error!(
"Missed last {} hearbeats. Reconnecting...",
missed_heartbeats
);
this.update(&mut cx, |this, cx| {
this.reconnect(cx)
.context("failed to reconnect after missing heartbeats")
})
.context("failed to update weak reference, SshRemoteClient lost?")??;
let result = this.update(&mut cx, |this, mut cx| {
this.handle_heartbeat_result(missed_heartbeats, &mut cx)
})?;
if result.is_break() {
return Ok(());
}
}
@ -458,6 +691,34 @@ impl SshRemoteClient {
})
}
fn handle_heartbeat_result(
&mut self,
missed_heartbeats: usize,
cx: &mut ModelContext<Self>,
) -> ControlFlow<()> {
let state = self.state.lock().take().unwrap();
self.state.lock().replace(if missed_heartbeats > 0 {
state.heartbeat_missed()
} else {
state.heartbeat_recovered()
});
cx.notify();
if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
log::error!(
"Missed last {} heartbeats. Reconnecting...",
missed_heartbeats
);
self.reconnect(cx)
.context("failed to start reconnect process after missing heartbeats")
.log_err();
ControlFlow::Break(())
} else {
ControlFlow::Continue(())
}
}
fn multiplex(
this: WeakModel<Self>,
mut ssh_proxy_process: Child,
@ -611,10 +872,11 @@ impl SshRemoteClient {
}
pub fn ssh_args(&self) -> Option<Vec<String>> {
let state = self.inner_state.lock();
state
self.state
.lock()
.as_ref()
.map(|state| state.ssh_connection.socket.ssh_args())
.and_then(|state| state.ssh_connection())
.map(|ssh_connection| ssh_connection.socket.ssh_args())
}
pub fn to_proto_client(&self) -> AnyProtoClient {
@ -625,8 +887,12 @@ impl SshRemoteClient {
self.connection_options.connection_string()
}
pub fn is_reconnect_underway(&self) -> bool {
maybe!({ Some(self.inner_state.try_lock()?.is_none()) }).unwrap_or_default()
pub fn connection_state(&self) -> ConnectionState {
self.state
.lock()
.as_ref()
.map(ConnectionState::from)
.unwrap_or(ConnectionState::Disconnected)
}
#[cfg(any(test, feature = "test-support"))]
@ -646,7 +912,7 @@ impl SshRemoteClient {
client,
unique_identifier: "fake".to_string(),
connection_options: SshConnectionOptions::default(),
inner_state: Arc::new(Mutex::new(None)),
state: Arc::new(Mutex::new(None)),
})
}),
server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
@ -1046,6 +1312,20 @@ impl ChannelClient {
}
}
pub async fn ping(&self, timeout: Duration) -> Result<()> {
smol::future::or(
async {
self.request(proto::Ping {}).await?;
Ok(())
},
async {
smol::Timer::after(timeout).await;
Err(anyhow!("Timeout detected"))
},
)
.await
}
pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
log::debug!("ssh send name:{}", T::NAME);
self.send_dynamic(payload.into_envelope(0, None, None))

View file

@ -41,6 +41,7 @@ gpui.workspace = true
notifications.workspace = true
project.workspace = true
recent_projects.workspace = true
remote.workspace = true
rpc.workspace = true
serde.workspace = true
smallvec.workspace = true

View file

@ -265,10 +265,12 @@ impl TitleBar {
fn render_ssh_project_host(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
let host = self.project.read(cx).ssh_connection_string(cx)?;
let meta = SharedString::from(format!("Connected to: {host}"));
let indicator_color = if self.project.read(cx).ssh_is_connected(cx)? {
Color::Success
} else {
Color::Warning
let indicator_color = match self.project.read(cx).ssh_connection_state(cx)? {
remote::ConnectionState::Connecting => Color::Info,
remote::ConnectionState::Connected => Color::Success,
remote::ConnectionState::HeartbeatMissed => Color::Warning,
remote::ConnectionState::Reconnecting => Color::Warning,
remote::ConnectionState::Disconnected => Color::Error,
};
let indicator = div()
.absolute()