ssh remoting: Enable reconnecting after connection losses (#18586)

Release Notes:

- N/A

---------

Co-authored-by: Bennet <bennet@zed.dev>
This commit is contained in:
Thorsten Ball 2024-10-07 11:40:59 +02:00 committed by GitHub
parent 67fbdbbed6
commit c03b8d6c48
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 727 additions and 240 deletions

View file

@ -15,7 +15,9 @@ use futures::{
select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
StreamExt as _,
};
use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, Task};
use gpui::{
AppContext, AsyncAppContext, Context, Model, ModelContext, SemanticVersion, Task, WeakModel,
};
use parking_lot::Mutex;
use rpc::{
proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
@ -28,10 +30,11 @@ use smol::{
use std::{
any::TypeId,
ffi::OsStr,
mem,
path::{Path, PathBuf},
sync::{
atomic::{AtomicU32, Ordering::SeqCst},
Arc, Weak,
Arc,
},
time::Instant,
};
@ -92,6 +95,17 @@ impl SshConnectionOptions {
host
}
}
// Uniquely identifies dev server projects on a remote host. Needs to be
// stable for the same dev server project.
pub fn dev_server_identifier(&self) -> String {
let mut identifier = format!("dev-server-{:?}", self.host);
if let Some(username) = self.username.as_ref() {
identifier.push('-');
identifier.push_str(&username);
}
identifier
}
}
#[derive(Copy, Clone, Debug)]
@ -250,59 +264,101 @@ struct SshRemoteClientState {
pub struct SshRemoteClient {
client: Arc<ChannelClient>,
inner_state: Mutex<Option<SshRemoteClientState>>,
unique_identifier: String,
connection_options: SshConnectionOptions,
inner_state: Arc<Mutex<Option<SshRemoteClientState>>>,
}
impl Drop for SshRemoteClient {
fn drop(&mut self) {
self.shutdown_processes();
}
}
impl SshRemoteClient {
pub async fn new(
pub fn new(
unique_identifier: String,
connection_options: SshConnectionOptions,
delegate: Arc<dyn SshClientDelegate>,
cx: &mut AsyncAppContext,
) -> Result<Arc<Self>> {
let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
cx: &AppContext,
) -> Task<Result<Model<Self>>> {
cx.spawn(|mut cx| async move {
let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
let this = Arc::new(Self {
client,
inner_state: Mutex::new(None),
connection_options: connection_options.clone(),
});
let this = cx.new_model(|cx| {
cx.on_app_quit(|this: &mut Self, _| {
this.shutdown_processes();
futures::future::ready(())
})
.detach();
let inner_state = {
let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, cx);
let client = ChannelClient::new(incoming_rx, outgoing_tx, cx);
Self {
client,
unique_identifier: unique_identifier.clone(),
connection_options: SshConnectionOptions::default(),
inner_state: Arc::new(Mutex::new(None)),
}
})?;
let (ssh_connection, ssh_process) =
Self::establish_connection(connection_options, delegate.clone(), cx).await?;
let inner_state = {
let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
let multiplex_task = Self::multiplex(
Arc::downgrade(&this),
ssh_process,
proxy_incoming_tx,
proxy_outgoing_rx,
cx,
);
let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
unique_identifier,
connection_options,
delegate.clone(),
&mut cx,
)
.await?;
SshRemoteClientState {
ssh_connection,
delegate,
forwarder: proxy,
multiplex_task,
}
};
let multiplex_task = Self::multiplex(
this.downgrade(),
ssh_proxy_process,
proxy_incoming_tx,
proxy_outgoing_rx,
&mut cx,
);
this.inner_state.lock().replace(inner_state);
SshRemoteClientState {
ssh_connection,
delegate,
forwarder: proxy,
multiplex_task,
}
};
Ok(this)
this.update(&mut cx, |this, cx| {
this.inner_state.lock().replace(inner_state);
cx.notify();
})?;
Ok(this)
})
}
fn reconnect(this: Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
let Some(state) = this.inner_state.lock().take() else {
fn shutdown_processes(&self) {
let Some(mut state) = self.inner_state.lock().take() else {
return;
};
log::info!("shutting down ssh processes");
// Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
// child of master_process.
let task = mem::replace(&mut state.multiplex_task, Task::ready(Ok(())));
drop(task);
// Now drop the rest of state, which kills master process.
drop(state);
}
fn reconnect(&self, cx: &ModelContext<Self>) -> Result<()> {
let Some(state) = self.inner_state.lock().take() else {
return Err(anyhow!("reconnect is already in progress"));
};
let workspace_identifier = self.unique_identifier.clone();
let SshRemoteClientState {
mut ssh_connection,
delegate,
@ -311,7 +367,7 @@ impl SshRemoteClient {
} = state;
drop(multiplex_task);
cx.spawn(|mut cx| async move {
cx.spawn(|this, mut cx| async move {
let (incoming_tx, outgoing_rx) = proxy.into_channels().await;
ssh_connection.master_process.kill()?;
@ -323,8 +379,13 @@ impl SshRemoteClient {
let connection_options = ssh_connection.socket.connection_options.clone();
let (ssh_connection, ssh_process) =
Self::establish_connection(connection_options, delegate.clone(), &mut cx).await?;
let (ssh_connection, ssh_process) = Self::establish_connection(
workspace_identifier,
connection_options,
delegate.clone(),
&mut cx,
)
.await?;
let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
@ -334,32 +395,32 @@ impl SshRemoteClient {
delegate,
forwarder: proxy,
multiplex_task: Self::multiplex(
Arc::downgrade(&this),
this.clone(),
ssh_process,
proxy_incoming_tx,
proxy_outgoing_rx,
&mut cx,
),
};
this.inner_state.lock().replace(inner_state);
anyhow::Ok(())
this.update(&mut cx, |this, _| {
this.inner_state.lock().replace(inner_state);
})
})
.detach();
anyhow::Ok(())
Ok(())
}
fn multiplex(
this: Weak<Self>,
mut ssh_process: Child,
this: WeakModel<Self>,
mut ssh_proxy_process: Child,
incoming_tx: UnboundedSender<Envelope>,
mut outgoing_rx: UnboundedReceiver<Envelope>,
cx: &AsyncAppContext,
) -> Task<Result<()>> {
let mut child_stderr = ssh_process.stderr.take().unwrap();
let mut child_stdout = ssh_process.stdout.take().unwrap();
let mut child_stdin = ssh_process.stdin.take().unwrap();
let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
let io_task = cx.background_executor().spawn(async move {
let mut stdin_buffer = Vec::new();
@ -385,7 +446,7 @@ impl SshRemoteClient {
Ok(0) => {
child_stdin.close().await?;
outgoing_rx.close();
let status = ssh_process.status().await?;
let status = ssh_proxy_process.status().await?;
if !status.success() {
log::error!("ssh process exited with status: {status:?}");
return Err(anyhow!("ssh process exited with non-zero status code: {:?}", status.code()));
@ -446,9 +507,9 @@ impl SshRemoteClient {
if let Err(error) = result {
log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
if let Some(this) = this.upgrade() {
Self::reconnect(this, &mut cx).ok();
}
this.update(&mut cx, |this, cx| {
this.reconnect(cx).ok();
})?;
}
Ok(())
@ -456,6 +517,7 @@ impl SshRemoteClient {
}
async fn establish_connection(
unique_identifier: String,
connection_options: SshConnectionOptions,
delegate: Arc<dyn SshClientDelegate>,
cx: &mut AsyncAppContext,
@ -479,17 +541,22 @@ impl SshRemoteClient {
let socket = ssh_connection.socket.clone();
run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
let ssh_process = socket
delegate.set_status(Some("Starting proxy"), cx);
let ssh_proxy_process = socket
.ssh_command(format!(
"RUST_LOG={} RUST_BACKTRACE={} {:?} run",
"RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
std::env::var("RUST_LOG").unwrap_or_default(),
std::env::var("RUST_BACKTRACE").unwrap_or_default(),
remote_binary_path,
unique_identifier,
))
// IMPORTANT: we kill this process when we drop the task that uses it.
.kill_on_drop(true)
.spawn()
.context("failed to spawn remote server")?;
Ok((ssh_connection, ssh_process))
Ok((ssh_connection, ssh_proxy_process))
}
pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
@ -514,21 +581,25 @@ impl SshRemoteClient {
pub fn is_reconnect_underway(&self) -> bool {
maybe!({ Some(self.inner_state.try_lock()?.is_none()) }).unwrap_or_default()
}
#[cfg(any(test, feature = "test-support"))]
pub fn fake(
client_cx: &mut gpui::TestAppContext,
server_cx: &mut gpui::TestAppContext,
) -> (Arc<Self>, Arc<ChannelClient>) {
) -> (Model<Self>, Arc<ChannelClient>) {
use gpui::Context;
let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
(
client_cx.update(|cx| {
let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
Arc::new(Self {
cx.new_model(|_| Self {
client,
inner_state: Mutex::new(None),
unique_identifier: "fake".to_string(),
connection_options: SshConnectionOptions::default(),
inner_state: Arc::new(Mutex::new(None)),
})
}),
server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),