ssh remoting: Treat other message as heartbeat (#19219)

This improves the heartbeat detection logic. We now treat any other
incoming message from the ssh remote server
as a heartbeat message, meaning that we can detect re-connects earlier.

It also changes the connection handling to await futures detached.

Co-Authored-by: Thorsten <thorsten@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
Co-authored-by: Antonio <antonio@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2024-10-15 16:37:56 +02:00 committed by GitHub
parent 4fa75a78b9
commit 0eb6bfd323
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -9,7 +9,7 @@ use anyhow::{anyhow, Context as _, Result};
use collections::HashMap; use collections::HashMap;
use futures::{ use futures::{
channel::{ channel::{
mpsc::{self, UnboundedReceiver, UnboundedSender}, mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
oneshot, oneshot,
}, },
future::BoxFuture, future::BoxFuture,
@ -28,7 +28,6 @@ use rpc::{
use smol::{ use smol::{
fs, fs,
process::{self, Child, Stdio}, process::{self, Child, Stdio},
Timer,
}; };
use std::{ use std::{
any::TypeId, any::TypeId,
@ -441,6 +440,7 @@ impl SshRemoteClient {
cx.spawn(|mut cx| async move { cx.spawn(|mut cx| async move {
let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>(); let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>(); let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?; let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
let this = cx.new_model(|_| Self { let this = cx.new_model(|_| Self {
@ -467,6 +467,7 @@ impl SshRemoteClient {
ssh_proxy_process, ssh_proxy_process,
proxy_incoming_tx, proxy_incoming_tx,
proxy_outgoing_rx, proxy_outgoing_rx,
connection_activity_tx,
&mut cx, &mut cx,
); );
@ -476,7 +477,7 @@ impl SshRemoteClient {
return Err(error); return Err(error);
} }
let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx); let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx);
this.update(&mut cx, |this, _| { this.update(&mut cx, |this, _| {
*this.state.lock() = Some(State::Connected { *this.state.lock() = Some(State::Connected {
@ -518,7 +519,7 @@ impl SshRemoteClient {
// We wait 50ms instead of waiting for a response, because // We wait 50ms instead of waiting for a response, because
// waiting for a response would require us to wait on the main thread // waiting for a response would require us to wait on the main thread
// which we want to avoid in an `on_app_quit` callback. // which we want to avoid in an `on_app_quit` callback.
Timer::after(Duration::from_millis(50)).await; smol::Timer::after(Duration::from_millis(50)).await;
} }
// Drop `multiplex_task` because it owns our ssh_proxy_process, which is a // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
@ -632,6 +633,7 @@ impl SshRemoteClient {
let (incoming_tx, outgoing_rx) = forwarder.into_channels().await; let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) = let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
let (ssh_connection, ssh_process) = match Self::establish_connection( let (ssh_connection, ssh_process) = match Self::establish_connection(
identifier, identifier,
@ -653,6 +655,7 @@ impl SshRemoteClient {
ssh_process, ssh_process,
proxy_incoming_tx, proxy_incoming_tx,
proxy_outgoing_rx, proxy_outgoing_rx,
connection_activity_tx,
&mut cx, &mut cx,
); );
@ -665,7 +668,7 @@ impl SshRemoteClient {
delegate, delegate,
forwarder, forwarder,
multiplex_task, multiplex_task,
heartbeat_task: Self::heartbeat(this.clone(), &mut cx), heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
} }
}); });
@ -717,22 +720,39 @@ impl SshRemoteClient {
Ok(()) Ok(())
} }
fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> { fn heartbeat(
this: WeakModel<Self>,
mut connection_activity_rx: mpsc::Receiver<()>,
cx: &mut AsyncAppContext,
) -> Task<Result<()>> {
let Ok(client) = this.update(cx, |this, _| this.client.clone()) else { let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
return Task::ready(Err(anyhow!("SshRemoteClient lost"))); return Task::ready(Err(anyhow!("SshRemoteClient lost")));
}; };
cx.spawn(|mut cx| { cx.spawn(|mut cx| {
let this = this.clone(); let this = this.clone();
async move { async move {
let mut missed_heartbeats = 0; let mut missed_heartbeats = 0;
let mut timer = Timer::interval(HEARTBEAT_INTERVAL); let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse();
loop { futures::pin_mut!(keepalive_timer);
timer.next().await;
loop {
select_biased! {
_ = connection_activity_rx.next().fuse() => {
keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse());
}
_ = keepalive_timer => {
log::debug!("Sending heartbeat to server..."); log::debug!("Sending heartbeat to server...");
let result = client.ping(HEARTBEAT_TIMEOUT).await; let result = select_biased! {
_ = connection_activity_rx.next().fuse() => {
Ok(())
}
ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => {
ping_result
}
};
if result.is_err() { if result.is_err() {
missed_heartbeats += 1; missed_heartbeats += 1;
log::warn!( log::warn!(
@ -755,6 +775,8 @@ impl SshRemoteClient {
} }
} }
} }
}
}
}) })
} }
@ -792,6 +814,7 @@ impl SshRemoteClient {
mut ssh_proxy_process: Child, mut ssh_proxy_process: Child,
incoming_tx: UnboundedSender<Envelope>, incoming_tx: UnboundedSender<Envelope>,
mut outgoing_rx: UnboundedReceiver<Envelope>, mut outgoing_rx: UnboundedReceiver<Envelope>,
mut connection_activity_tx: Sender<()>,
cx: &AsyncAppContext, cx: &AsyncAppContext,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
let mut child_stderr = ssh_proxy_process.stderr.take().unwrap(); let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
@ -833,6 +856,7 @@ impl SshRemoteClient {
let message_len = message_len_from_buffer(&stdout_buffer); let message_len = message_len_from_buffer(&stdout_buffer);
match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await { match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
Ok(envelope) => { Ok(envelope) => {
connection_activity_tx.try_send(()).ok();
incoming_tx.unbounded_send(envelope).ok(); incoming_tx.unbounded_send(envelope).ok();
} }
Err(error) => { Err(error) => {
@ -863,6 +887,8 @@ impl SshRemoteClient {
} }
stderr_buffer.drain(0..start_ix); stderr_buffer.drain(0..start_ix);
stderr_offset -= start_ix; stderr_offset -= start_ix;
connection_activity_tx.try_send(()).ok();
} }
Err(error) => { Err(error) => {
Err(anyhow!("error reading stderr: {error:?}"))?; Err(anyhow!("error reading stderr: {error:?}"))?;
@ -1392,6 +1418,7 @@ impl ChannelClient {
cx.clone(), cx.clone(),
) { ) {
log::debug!("ssh message received. name:{type_name}"); log::debug!("ssh message received. name:{type_name}");
cx.foreground_executor().spawn(async move {
match future.await { match future.await {
Ok(_) => { Ok(_) => {
log::debug!("ssh message handled. name:{type_name}"); log::debug!("ssh message handled. name:{type_name}");
@ -1402,6 +1429,8 @@ impl ChannelClient {
); );
} }
} }
}).detach();
} else { } else {
log::error!("unhandled ssh message name:{type_name}"); log::error!("unhandled ssh message name:{type_name}");
} }