ssh remoting: Treat other message as heartbeat (#19219)

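This improves the heartbeat detection logic. Any incoming message from
the SSH remote server is now treated as a heartbeat: connection activity
resets the keepalive timer, and activity that arrives while a ping is
pending counts as a successful heartbeat. This means that we can detect
re-connects earlier.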

It also changes the message handling so that each handler future is spawned
as a detached task instead of being awaited inline.

Co-Authored-by: Thorsten <thorsten@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
Co-authored-by: Antonio <antonio@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2024-10-15 16:37:56 +02:00 committed by GitHub
parent 4fa75a78b9
commit 0eb6bfd323
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -9,7 +9,7 @@ use anyhow::{anyhow, Context as _, Result};
use collections::HashMap;
use futures::{
channel::{
mpsc::{self, UnboundedReceiver, UnboundedSender},
mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
oneshot,
},
future::BoxFuture,
@ -28,7 +28,6 @@ use rpc::{
use smol::{
fs,
process::{self, Child, Stdio},
Timer,
};
use std::{
any::TypeId,
@ -441,6 +440,7 @@ impl SshRemoteClient {
cx.spawn(|mut cx| async move {
let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
let this = cx.new_model(|_| Self {
@ -467,6 +467,7 @@ impl SshRemoteClient {
ssh_proxy_process,
proxy_incoming_tx,
proxy_outgoing_rx,
connection_activity_tx,
&mut cx,
);
@ -476,7 +477,7 @@ impl SshRemoteClient {
return Err(error);
}
let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);
let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
this.update(&mut cx, |this, _| {
*this.state.lock() = Some(State::Connected {
@ -518,7 +519,7 @@ impl SshRemoteClient {
// We wait 50ms instead of waiting for a response, because
// waiting for a response would require us to wait on the main thread
// which we want to avoid in an `on_app_quit` callback.
Timer::after(Duration::from_millis(50)).await;
smol::Timer::after(Duration::from_millis(50)).await;
}
// Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
@ -632,6 +633,7 @@ impl SshRemoteClient {
let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
let (ssh_connection, ssh_process) = match Self::establish_connection(
identifier,
@ -653,6 +655,7 @@ impl SshRemoteClient {
ssh_process,
proxy_incoming_tx,
proxy_outgoing_rx,
connection_activity_tx,
&mut cx,
);
@ -665,7 +668,7 @@ impl SshRemoteClient {
delegate,
forwarder,
multiplex_task,
heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
}
});
@ -717,41 +720,60 @@ impl SshRemoteClient {
Ok(())
}
fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
fn heartbeat(
this: WeakModel<Self>,
mut connection_activity_rx: mpsc::Receiver<()>,
cx: &mut AsyncAppContext,
) -> Task<Result<()>> {
let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
return Task::ready(Err(anyhow!("SshRemoteClient lost")));
};
cx.spawn(|mut cx| {
let this = this.clone();
async move {
let mut missed_heartbeats = 0;
let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
futures::pin_mut!(keepalive_timer);
loop {
timer.next().await;
select_biased! {
_ = connection_activity_rx.next().fuse() => {
keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
}
_ = keepalive_timer => {
log::debug!("Sending heartbeat to server...");
log::debug!("Sending heartbeat to server...");
let result = select_biased! {
_ = connection_activity_rx.next().fuse() => {
Ok(())
}
ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
ping_result
}
};
if result.is_err() {
missed_heartbeats += 1;
log::warn!(
"No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
HEARTBEAT_TIMEOUT,
missed_heartbeats,
MAX_MISSED_HEARTBEATS
);
} else if missed_heartbeats != 0 {
missed_heartbeats = 0;
} else {
continue;
}
let result = client.ping(HEARTBEAT_TIMEOUT).await;
if result.is_err() {
missed_heartbeats += 1;
log::warn!(
"No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
HEARTBEAT_TIMEOUT,
missed_heartbeats,
MAX_MISSED_HEARTBEATS
);
} else if missed_heartbeats != 0 {
missed_heartbeats = 0;
} else {
continue;
}
let result = this.update(&mut cx, |this, mut cx| {
this.handle_heartbeat_result(missed_heartbeats, &mut cx)
})?;
if result.is_break() {
return Ok(());
let result = this.update(&mut cx, |this, mut cx| {
this.handle_heartbeat_result(missed_heartbeats, &mut cx)
})?;
if result.is_break() {
return Ok(());
}
}
}
}
}
@ -792,6 +814,7 @@ impl SshRemoteClient {
mut ssh_proxy_process: Child,
incoming_tx: UnboundedSender<Envelope>,
mut outgoing_rx: UnboundedReceiver<Envelope>,
mut connection_activity_tx: Sender<()>,
cx: &AsyncAppContext,
) -> Task<Result<()>> {
let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
@ -833,6 +856,7 @@ impl SshRemoteClient {
let message_len = message_len_from_buffer(&stdout_buffer);
match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
Ok(envelope) => {
connection_activity_tx.try_send(()).ok();
incoming_tx.unbounded_send(envelope).ok();
}
Err(error) => {
@ -863,6 +887,8 @@ impl SshRemoteClient {
}
stderr_buffer.drain(0..start_ix);
stderr_offset -= start_ix;
connection_activity_tx.try_send(()).ok();
}
Err(error) => {
Err(anyhow!("error reading stderr: {error:?}"))?;
@ -1392,16 +1418,19 @@ impl ChannelClient {
cx.clone(),
) {
log::debug!("ssh message received. name:{type_name}");
match future.await {
Ok(_) => {
log::debug!("ssh message handled. name:{type_name}");
cx.foreground_executor().spawn(async move {
match future.await {
Ok(_) => {
log::debug!("ssh message handled. name:{type_name}");
}
Err(error) => {
log::error!(
"error handling message. type:{type_name}, error:{error}",
);
}
}
Err(error) => {
log::error!(
"error handling message. type:{type_name}, error:{error}",
);
}
}
}).detach();
} else {
log::error!("unhandled ssh message name:{type_name}");
}