diff --git a/Cargo.lock b/Cargo.lock index cc588fcb22..426d52b8d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9089,6 +9089,7 @@ dependencies = [ "serde_json", "smol", "tempfile", + "thiserror", "util", ] diff --git a/crates/remote/Cargo.toml b/crates/remote/Cargo.toml index 5c291b59d3..63c12cf117 100644 --- a/crates/remote/Cargo.toml +++ b/crates/remote/Cargo.toml @@ -31,6 +31,7 @@ serde.workspace = true serde_json.workspace = true smol.workspace = true tempfile.workspace = true +thiserror.workspace = true util.workspace = true [dev-dependencies] diff --git a/crates/remote/src/proxy.rs b/crates/remote/src/proxy.rs new file mode 100644 index 0000000000..d715d5ecf6 --- /dev/null +++ b/crates/remote/src/proxy.rs @@ -0,0 +1,25 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ProxyLaunchError { + #[error("Attempted reconnect, but server not running.")] + ServerNotRunning, +} + +impl ProxyLaunchError { + pub fn to_exit_code(&self) -> i32 { + match self { + // We're using 90 as the exit code, because 0-78 are often taken + // by shells and other conventions and >128 also has certain meanings + // in certain contexts. + Self::ServerNotRunning => 90, + } + } + + pub fn from_exit_code(exit_code: i32) -> Option { + match exit_code { + 90 => Some(Self::ServerNotRunning), + _ => None, + } + } +} diff --git a/crates/remote/src/remote.rs b/crates/remote/src/remote.rs index 3cfaf48a88..382adf0dac 100644 --- a/crates/remote/src/remote.rs +++ b/crates/remote/src/remote.rs @@ -1,5 +1,6 @@ pub mod json_log; pub mod protocol; +pub mod proxy; pub mod ssh_session; pub use ssh_session::{ diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index cf2a702231..5aeda7ade3 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -3,6 +3,7 @@ use crate::{ protocol::{ message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE, }, + proxy::ProxyLaunchError, }; use anyhow::{anyhow, Context as _, Result}; use collections::HashMap; @@ -271,6 +272,7 @@ enum State { attempts: usize, }, ReconnectExhausted, + ServerNotRunning, } impl fmt::Display for State { @@ -282,6 +284,7 @@ impl fmt::Display for State { Self::ReconnectFailed { .. } => write!(f, "reconnect failed"), Self::ReconnectExhausted => write!(f, "reconnect exhausted"), Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"), + Self::ServerNotRunning { .. } => write!(f, "server not running"), } } } @@ -297,10 +300,23 @@ impl State { } fn can_reconnect(&self) -> bool { - matches!( - self, - Self::Connected { .. } | Self::HeartbeatMissed { .. } | Self::ReconnectFailed { .. } - ) + match self { + Self::Connected { .. } + | Self::HeartbeatMissed { .. } + | Self::ReconnectFailed { .. } => true, + State::Connecting + | State::Reconnecting + | State::ReconnectExhausted + | State::ServerNotRunning => false, + } + } + + fn is_reconnect_failed(&self) -> bool { + matches!(self, Self::ReconnectFailed { .. }) + } + + fn is_reconnecting(&self) -> bool { + matches!(self, Self::Reconnecting { .. }) } fn heartbeat_recovered(self) -> Self { @@ -377,6 +393,7 @@ impl From<&State> for ConnectionState { State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting, State::HeartbeatMissed { .. } => Self::HeartbeatMissed, State::ReconnectExhausted => Self::Disconnected, + State::ServerNotRunning => Self::Disconnected, } } } @@ -426,6 +443,7 @@ impl SshRemoteClient { let (ssh_connection, ssh_proxy_process) = Self::establish_connection( unique_identifier, + false, connection_options, delegate.clone(), &mut cx, @@ -496,6 +514,7 @@ impl SshRemoteClient { } else { "no state set".to_string() }; + log::info!("aborting reconnect, because not in state that allows reconnecting"); return Err(anyhow!(error)); } @@ -527,7 +546,10 @@ impl SshRemoteClient { forwarder, .. } => (attempts, ssh_connection, delegate, forwarder), - State::Connecting | State::Reconnecting | State::ReconnectExhausted => unreachable!(), + State::Connecting + | State::Reconnecting + | State::ReconnectExhausted + | State::ServerNotRunning => unreachable!(), }; let attempts = attempts + 1; @@ -536,11 +558,12 @@ impl SshRemoteClient { "Failed to reconnect to after {} attempts, giving up", MAX_RECONNECT_ATTEMPTS ); - *lock = Some(State::ReconnectExhausted); + drop(lock); + self.set_state(State::ReconnectExhausted, cx); return Ok(()); } - *lock = Some(State::Reconnecting); drop(lock); + self.set_state(State::Reconnecting, cx); log::info!("Trying to reconnect to ssh server... Attempt {}", attempts); @@ -580,6 +603,7 @@ impl SshRemoteClient { let (ssh_connection, ssh_process) = match Self::establish_connection( identifier, + true, connection_options, delegate.clone(), &mut cx, @@ -616,33 +640,39 @@ impl SshRemoteClient { cx.spawn(|this, mut cx| async move { let new_state = reconnect_task.await; this.update(&mut cx, |this, cx| { - match &new_state { - State::Connecting - | State::Reconnecting { .. } - | State::HeartbeatMissed { .. } => {} - State::Connected { .. } => { - log::info!("Successfully reconnected"); + this.try_set_state(cx, |old_state| { + if old_state.is_reconnecting() { + match &new_state { + State::Connecting + | State::Reconnecting { .. } + | State::HeartbeatMissed { .. } + | State::ServerNotRunning => {} + State::Connected { .. } => { + log::info!("Successfully reconnected"); + } + State::ReconnectFailed { + error, attempts, .. + } => { + log::error!( + "Reconnect attempt {} failed: {:?}. Starting new attempt...", + attempts, + error + ); + } + State::ReconnectExhausted => { + log::error!("Reconnect attempt failed and all attempts exhausted"); + } + } + Some(new_state) + } else { + None } - State::ReconnectFailed { - error, attempts, .. - } => { - log::error!( - "Reconnect attempt {} failed: {:?}. Starting new attempt...", - attempts, - error - ); - } - State::ReconnectExhausted => { - log::error!("Reconnect attempt failed and all attempts exhausted"); - } - } + }); - let reconnect_failed = matches!(new_state, State::ReconnectFailed { .. }); - *this.state.lock() = Some(new_state); - cx.notify(); - if reconnect_failed { + if this.state_is(State::is_reconnect_failed) { this.reconnect(cx) } else { + log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state."); Ok(()) } }) @@ -676,8 +706,10 @@ impl SshRemoteClient { missed_heartbeats, MAX_MISSED_HEARTBEATS ); - } else { + } else if missed_heartbeats != 0 { missed_heartbeats = 0; + } else { + continue; } let result = this.update(&mut cx, |this, mut cx| { @@ -697,12 +729,12 @@ impl SshRemoteClient { cx: &mut ModelContext, ) -> ControlFlow<()> { let state = self.state.lock().take().unwrap(); - self.state.lock().replace(if missed_heartbeats > 0 { + let next_state = if missed_heartbeats > 0 { state.heartbeat_missed() } else { state.heartbeat_recovered() - }); - cx.notify(); + }; + self.set_state(next_state, cx); if missed_heartbeats >= MAX_MISSED_HEARTBEATS { log::error!( @@ -743,7 +775,7 @@ impl SshRemoteClient { select_biased! { outgoing = outgoing_rx.next().fuse() => { let Some(outgoing) = outgoing else { - return anyhow::Ok(()); + return anyhow::Ok(None); }; write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?; @@ -755,11 +787,7 @@ impl SshRemoteClient { child_stdin.close().await?; outgoing_rx.close(); let status = ssh_proxy_process.status().await?; - if !status.success() { - log::error!("ssh process exited with status: {status:?}"); - return Err(anyhow!("ssh process exited with non-zero status code: {:?}", status.code())); - } - return Ok(()); + return Ok(status.code()); } Ok(len) => { if len < stdout_buffer.len() { @@ -813,19 +841,56 @@ impl SshRemoteClient { cx.spawn(|mut cx| async move { let result = io_task.await; - if let Err(error) = result { - log::warn!("ssh io task died with error: {:?}. reconnecting...", error); - this.update(&mut cx, |this, cx| { - this.reconnect(cx).ok(); - })?; + match result { + Ok(Some(exit_code)) => { + if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) { + match error { + ProxyLaunchError::ServerNotRunning => { + log::error!("failed to reconnect because server is not running"); + this.update(&mut cx, |this, cx| { + this.set_state(State::ServerNotRunning, cx); + })?; + } + } + } else if exit_code > 0 { + log::error!("proxy process terminated unexpectedly"); + } + } + Ok(None) => {} + Err(error) => { + log::warn!("ssh io task died with error: {:?}. reconnecting...", error); + this.update(&mut cx, |this, cx| { + this.reconnect(cx).ok(); + })?; + } } - Ok(()) }) } + fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool { + self.state.lock().as_ref().map_or(false, check) + } + + fn try_set_state( + &self, + cx: &mut ModelContext, + map: impl FnOnce(&State) -> Option, + ) { + if let Some(new_state) = self.state.lock().as_ref().and_then(map) { + self.set_state(new_state, cx); + } + } + + fn set_state(&self, state: State, cx: &mut ModelContext) { + log::info!("setting state to '{}'", &state); + self.state.lock().replace(state); + cx.notify(); + } + async fn establish_connection( unique_identifier: String, + reconnect: bool, connection_options: SshConnectionOptions, delegate: Arc, cx: &mut AsyncAppContext, @@ -851,14 +916,19 @@ impl SshRemoteClient { delegate.set_status(Some("Starting proxy"), cx); + let mut start_proxy_command = format!( + "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}", + std::env::var("RUST_LOG").unwrap_or_default(), + std::env::var("RUST_BACKTRACE").unwrap_or_default(), + remote_binary_path, + unique_identifier, + ); + if reconnect { + start_proxy_command.push_str(" --reconnect"); + } + let ssh_proxy_process = socket - .ssh_command(format!( - "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}", - std::env::var("RUST_LOG").unwrap_or_default(), - std::env::var("RUST_BACKTRACE").unwrap_or_default(), - remote_binary_path, - unique_identifier, - )) + .ssh_command(start_proxy_command) // IMPORTANT: we kill this process when we drop the task that uses it. .kill_on_drop(true) .spawn() diff --git a/crates/remote_server/src/main.rs b/crates/remote_server/src/main.rs index 872af42596..50f34fbd39 100644 --- a/crates/remote_server/src/main.rs +++ b/crates/remote_server/src/main.rs @@ -24,6 +24,8 @@ enum Commands { stdout_socket: PathBuf, }, Proxy { + #[arg(long)] + reconnect: bool, #[arg(long)] identifier: String, }, @@ -37,6 +39,7 @@ fn main() { #[cfg(not(windows))] fn main() -> Result<()> { + use remote::proxy::ProxyLaunchError; use remote_server::unix::{execute_proxy, execute_run, init}; let cli = Cli::parse(); @@ -51,9 +54,20 @@ fn main() -> Result<()> { init(Some(log_file))?; execute_run(pid_file, stdin_socket, stdout_socket) } - Some(Commands::Proxy { identifier }) => { + Some(Commands::Proxy { + identifier, + reconnect, + }) => { init(None)?; - execute_proxy(identifier) + match execute_proxy(identifier, reconnect) { + Ok(_) => Ok(()), + Err(err) => { + if let Some(err) = err.downcast_ref::() { + std::process::exit(err.to_exit_code()); + } + Err(err) + } + } } Some(Commands::Version) => { eprintln!("{}", env!("ZED_PKG_VERSION")); diff --git a/crates/remote_server/src/unix.rs b/crates/remote_server/src/unix.rs index 2e03887ae9..7a05268cc8 100644 --- a/crates/remote_server/src/unix.rs +++ b/crates/remote_server/src/unix.rs @@ -4,6 +4,7 @@ use fs::RealFs; use futures::channel::mpsc; use futures::{select, select_biased, AsyncRead, AsyncWrite, FutureExt, SinkExt}; use gpui::{AppContext, Context as _}; +use remote::proxy::ProxyLaunchError; use remote::ssh_session::ChannelClient; use remote::{ json_log::LogRecord, @@ -227,30 +228,62 @@ pub fn execute_run(pid_file: PathBuf, stdin_socket: PathBuf, stdout_socket: Path Ok(()) } -pub fn execute_proxy(identifier: String) -> Result<()> { +#[derive(Clone)] +struct ServerPaths { + log_file: PathBuf, + pid_file: PathBuf, + stdin_socket: PathBuf, + stdout_socket: PathBuf, +} + +impl ServerPaths { + fn new(identifier: &str) -> Result { + let project_dir = create_state_directory(identifier)?; + + let pid_file = project_dir.join("server.pid"); + let stdin_socket = project_dir.join("stdin.sock"); + let stdout_socket = project_dir.join("stdout.sock"); + let log_file = project_dir.join("server.log"); + + Ok(Self { + pid_file, + stdin_socket, + stdout_socket, + log_file, + }) + } +} + +pub fn execute_proxy(identifier: String, is_reconnecting: bool) -> Result<()> { log::debug!("proxy: starting up. PID: {}", std::process::id()); - let project_dir = ensure_project_dir(&identifier)?; + let server_paths = ServerPaths::new(&identifier)?; - let pid_file = project_dir.join("server.pid"); - let stdin_socket = project_dir.join("stdin.sock"); - let stdout_socket = project_dir.join("stdout.sock"); - let log_file = project_dir.join("server.log"); + let server_pid = check_pid_file(&server_paths.pid_file)?; + let server_running = server_pid.is_some(); + if is_reconnecting { + if !server_running { + log::error!("proxy: attempted to reconnect, but no server running"); + return Err(anyhow!(ProxyLaunchError::ServerNotRunning)); + } + } else { + if let Some(pid) = server_pid { + log::debug!("proxy: found server already running with PID {}. Killing process and cleaning up files...", pid); + kill_running_server(pid, &server_paths)?; + } - let server_running = check_pid_file(&pid_file)?; - if !server_running { - spawn_server(&log_file, &pid_file, &stdin_socket, &stdout_socket)?; - }; + spawn_server(&server_paths)?; + } let stdin_task = smol::spawn(async move { let stdin = Async::new(std::io::stdin())?; - let stream = smol::net::unix::UnixStream::connect(stdin_socket).await?; + let stream = smol::net::unix::UnixStream::connect(&server_paths.stdin_socket).await?; handle_io(stdin, stream, "stdin").await }); let stdout_task: smol::Task> = smol::spawn(async move { let stdout = Async::new(std::io::stdout())?; - let stream = smol::net::unix::UnixStream::connect(stdout_socket).await?; + let stream = smol::net::unix::UnixStream::connect(&server_paths.stdout_socket).await?; handle_io(stream, stdout, "stdout").await }); @@ -267,50 +300,62 @@ pub fn execute_proxy(identifier: String) -> Result<()> { Ok(()) } -fn ensure_project_dir(identifier: &str) -> Result { - let project_dir = env::var("HOME").unwrap_or_else(|_| ".".to_string()); - let project_dir = PathBuf::from(project_dir) +fn create_state_directory(identifier: &str) -> Result { + let home_dir = env::var("HOME").unwrap_or_else(|_| ".".to_string()); + let server_dir = PathBuf::from(home_dir) .join(".local") .join("state") .join("zed-remote-server") .join(identifier); - std::fs::create_dir_all(&project_dir)?; + std::fs::create_dir_all(&server_dir)?; - Ok(project_dir) + Ok(server_dir) } -fn spawn_server( - log_file: &Path, - pid_file: &Path, - stdin_socket: &Path, - stdout_socket: &Path, -) -> Result<()> { - if stdin_socket.exists() { - std::fs::remove_file(&stdin_socket)?; +fn kill_running_server(pid: u32, paths: &ServerPaths) -> Result<()> { + log::info!("proxy: killing existing server with PID {}", pid); + std::process::Command::new("kill") + .arg(pid.to_string()) + .output() + .context("proxy: failed to kill existing server")?; + + for file in [&paths.pid_file, &paths.stdin_socket, &paths.stdout_socket] { + log::debug!( + "proxy: cleaning up file {:?} before starting new server", + file + ); + std::fs::remove_file(file).ok(); } - if stdout_socket.exists() { - std::fs::remove_file(&stdout_socket)?; + Ok(()) +} + +fn spawn_server(paths: &ServerPaths) -> Result<()> { + if paths.stdin_socket.exists() { + std::fs::remove_file(&paths.stdin_socket)?; + } + if paths.stdout_socket.exists() { + std::fs::remove_file(&paths.stdout_socket)?; } let binary_name = std::env::current_exe()?; let server_process = std::process::Command::new(binary_name) .arg("run") .arg("--log-file") - .arg(log_file) + .arg(&paths.log_file) .arg("--pid-file") - .arg(pid_file) + .arg(&paths.pid_file) .arg("--stdin-socket") - .arg(stdin_socket) + .arg(&paths.stdin_socket) .arg("--stdout-socket") - .arg(stdout_socket) + .arg(&paths.stdout_socket) .spawn()?; log::debug!("proxy: server started. PID: {:?}", server_process.id()); let mut total_time_waited = std::time::Duration::from_secs(0); let wait_duration = std::time::Duration::from_millis(20); - while !stdout_socket.exists() || !stdin_socket.exists() { + while !paths.stdout_socket.exists() || !paths.stdin_socket.exists() { log::debug!("proxy: waiting for server to be ready to accept connections..."); std::thread::sleep(wait_duration); total_time_waited += wait_duration; @@ -323,12 +368,12 @@ fn spawn_server( Ok(()) } -fn check_pid_file(path: &Path) -> Result { +fn check_pid_file(path: &Path) -> Result> { let Some(pid) = std::fs::read_to_string(&path) .ok() .and_then(|contents| contents.parse::().ok()) else { - return Ok(false); + return Ok(None); }; log::debug!("proxy: Checking if process with PID {} exists...", pid); @@ -339,12 +384,12 @@ fn check_pid_file(path: &Path) -> Result { { Ok(output) if output.status.success() => { log::debug!("proxy: Process with PID {} exists. NOT spawning new server, but attaching to existing one.", pid); - Ok(true) + Ok(Some(pid)) } _ => { log::debug!("proxy: Found PID file, but process with that PID does not exist. Removing PID file."); std::fs::remove_file(&path).context("proxy: Failed to remove PID file")?; - Ok(false) + Ok(None) } } }