ssh: Detect timeouts when server is unresponsive (#18808)

To detect connection timeouts we ping the remote server every X seconds
and attempt to reconnect if the server failed to respond.
Next up is showing some feedback in the UI to make this visible to the
user, and stop reconnecting after X amount of retries.

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2024-10-07 15:08:16 +02:00 committed by GitHub
parent 5aa165c530
commit 25a97a6a2b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 82 additions and 3 deletions

View file

@ -26,6 +26,7 @@ use rpc::{
use smol::{
fs,
process::{self, Child, Stdio},
Timer,
};
use std::{
any::TypeId,
@ -36,7 +37,7 @@ use std::{
atomic::{AtomicU32, Ordering::SeqCst},
Arc,
},
time::Instant,
time::{Duration, Instant},
};
use tempfile::TempDir;
use util::maybe;
@ -173,7 +174,7 @@ async fn run_cmd(command: &mut process::Command) -> Result<String> {
#[cfg(unix)]
async fn read_with_timeout(
stdout: &mut process::ChildStdout,
timeout: std::time::Duration,
timeout: Duration,
output: &mut Vec<u8>,
) -> Result<(), std::io::Error> {
smol::future::or(
@ -260,6 +261,7 @@ struct SshRemoteClientState {
delegate: Arc<dyn SshClientDelegate>,
forwarder: ChannelForwarder,
multiplex_task: Task<Result<()>>,
heartbeat_task: Task<Result<()>>,
}
pub struct SshRemoteClient {
@ -327,6 +329,7 @@ impl SshRemoteClient {
delegate,
forwarder: proxy,
multiplex_task,
heartbeat_task: Self::heartbeat(this.downgrade(), &mut cx),
}
};
@ -353,6 +356,7 @@ impl SshRemoteClient {
}
fn reconnect(&self, cx: &ModelContext<Self>) -> Result<()> {
log::info!("Trying to reconnect to ssh server...");
let Some(state) = self.inner_state.lock().take() else {
return Err(anyhow!("reconnect is already in progress"));
};
@ -364,8 +368,10 @@ impl SshRemoteClient {
delegate,
forwarder: proxy,
multiplex_task,
heartbeat_task,
} = state;
drop(multiplex_task);
drop(heartbeat_task);
cx.spawn(|this, mut cx| async move {
let (incoming_tx, outgoing_rx) = proxy.into_channels().await;
@ -401,6 +407,7 @@ impl SshRemoteClient {
proxy_outgoing_rx,
&mut cx,
),
heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
};
this.update(&mut cx, |this, _| {
@ -411,6 +418,68 @@ impl SshRemoteClient {
Ok(())
}
fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
return Task::ready(Err(anyhow!("SshRemoteClient lost")));
};
cx.spawn(|mut cx| {
let this = this.clone();
async move {
const MAX_MISSED_HEARTBEATS: usize = 5;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
let mut missed_heartbeats = 0;
let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
loop {
timer.next().await;
log::info!("Sending heartbeat to server...");
let result = smol::future::or(
async {
client.request(proto::Ping {}).await?;
Ok(())
},
async {
smol::Timer::after(HEARTBEAT_TIMEOUT).await;
Err(anyhow!("Timeout detected"))
},
)
.await;
if result.is_err() {
missed_heartbeats += 1;
log::warn!(
"No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
HEARTBEAT_TIMEOUT,
missed_heartbeats,
MAX_MISSED_HEARTBEATS
);
} else {
missed_heartbeats = 0;
}
if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
log::error!(
"Missed last {} hearbeats. Reconnecting...",
missed_heartbeats
);
this.update(&mut cx, |this, cx| {
this.reconnect(cx)
.context("failed to reconnect after missing heartbeats")
})
.context("failed to update weak reference, SshRemoteClient lost?")??;
return Ok(());
}
}
}
})
}
fn multiplex(
this: WeakModel<Self>,
mut ssh_proxy_process: Child,
@ -712,7 +781,7 @@ impl SshRemoteConnection {
// has completed.
let stdout = master_process.stdout.as_mut().unwrap();
let mut output = Vec::new();
let connection_timeout = std::time::Duration::from_secs(10);
let connection_timeout = Duration::from_secs(10);
let result = read_with_timeout(stdout, connection_timeout, &mut output).await;
if let Err(e) = result {
let error_message = if e.kind() == std::io::ErrorKind::TimedOut {

View file

@ -113,6 +113,7 @@ impl HeadlessProject {
client.add_request_handler(cx.weak_model(), Self::handle_list_remote_directory);
client.add_request_handler(cx.weak_model(), Self::handle_check_file_exists);
client.add_request_handler(cx.weak_model(), Self::handle_shutdown_remote_server);
client.add_request_handler(cx.weak_model(), Self::handle_ping);
client.add_model_request_handler(Self::handle_add_worktree);
client.add_model_request_handler(Self::handle_open_buffer_by_path);
@ -354,4 +355,13 @@ impl HeadlessProject {
Ok(proto::Ack {})
}
pub async fn handle_ping(
_this: Model<Self>,
_envelope: TypedEnvelope<proto::Ping>,
_cx: AsyncAppContext,
) -> Result<proto::Ack> {
log::debug!("Received ping from client");
Ok(proto::Ack {})
}
}