From c03b8d6c48b9f961ed30cb4d0ef97169fb30c1c4 Mon Sep 17 00:00:00 2001 From: Thorsten Ball Date: Mon, 7 Oct 2024 11:40:59 +0200 Subject: [PATCH] ssh remoting: Enable reconnecting after connection losses (#18586) Release Notes: - N/A --------- Co-authored-by: Bennet --- Cargo.lock | 2 + crates/collab/src/tests/test_server.rs | 2 +- crates/project/src/project.rs | 72 ++-- crates/project/src/terminals.rs | 2 +- crates/proto/proto/zed.proto | 6 +- crates/proto/src/proto.rs | 6 +- crates/recent_projects/src/dev_servers.rs | 22 +- crates/recent_projects/src/ssh_connections.rs | 58 +-- crates/remote/src/protocol.rs | 14 + crates/remote/src/ssh_session.rs | 191 ++++++---- crates/remote_server/Cargo.toml | 11 +- crates/remote_server/src/headless_project.rs | 19 + crates/remote_server/src/main.rs | 134 +++---- .../remote_server/src/remote_editing_tests.rs | 2 +- crates/remote_server/src/remote_server.rs | 3 + crates/remote_server/src/unix.rs | 336 ++++++++++++++++++ crates/title_bar/src/title_bar.rs | 2 +- crates/workspace/Cargo.toml | 1 + crates/workspace/src/workspace.rs | 84 +++-- 19 files changed, 727 insertions(+), 240 deletions(-) create mode 100644 crates/remote_server/src/unix.rs diff --git a/Cargo.lock b/Cargo.lock index 120f51ec93..1cac85e0c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9165,6 +9165,7 @@ version = "0.1.0" dependencies = [ "anyhow", "cargo_toml", + "clap", "client", "clock", "env_logger", @@ -14324,6 +14325,7 @@ dependencies = [ "parking_lot", "postage", "project", + "release_channel", "remote", "schemars", "serde", diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 5e7d935c36..8d2396eef0 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -835,7 +835,7 @@ impl TestClient { pub async fn build_ssh_project( &self, root_path: impl AsRef, - ssh: Arc, + ssh: Model, cx: &mut TestAppContext, ) -> (Model, WorktreeId) { let project = cx.update(|cx| { diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 4b8c050964..f2a8d59c6f 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -141,7 +141,7 @@ pub struct Project { join_project_response_message_id: u32, user_store: Model, fs: Arc, - ssh_client: Option>, + ssh_client: Option>, client_state: ProjectClientState, collaborators: HashMap, client_subscriptions: Vec, @@ -667,7 +667,7 @@ impl Project { } pub fn ssh( - ssh: Arc, + ssh: Model, client: Arc, node: NodeRuntime, user_store: Model, @@ -684,15 +684,16 @@ impl Project { let snippets = SnippetProvider::new(fs.clone(), BTreeSet::from_iter([global_snippets_dir]), cx); + let ssh_proto = ssh.read(cx).to_proto_client(); let worktree_store = - cx.new_model(|_| WorktreeStore::remote(false, ssh.to_proto_client(), 0, None)); + cx.new_model(|_| WorktreeStore::remote(false, ssh_proto.clone(), 0, None)); cx.subscribe(&worktree_store, Self::on_worktree_store_event) .detach(); let buffer_store = cx.new_model(|cx| { BufferStore::remote( worktree_store.clone(), - ssh.to_proto_client(), + ssh.read(cx).to_proto_client(), SSH_PROJECT_ID, cx, ) @@ -701,7 +702,7 @@ impl Project { .detach(); let settings_observer = cx.new_model(|cx| { - SettingsObserver::new_ssh(ssh.to_proto_client(), worktree_store.clone(), cx) + SettingsObserver::new_ssh(ssh_proto.clone(), worktree_store.clone(), cx) }); cx.subscribe(&settings_observer, Self::on_settings_observer_event) .detach(); @@ -712,13 +713,24 @@ impl Project { buffer_store.clone(), worktree_store.clone(), languages.clone(), - ssh.to_proto_client(), + ssh_proto.clone(), SSH_PROJECT_ID, cx, ) }); cx.subscribe(&lsp_store, Self::on_lsp_store_event).detach(); + cx.on_release(|this, cx| { + if let Some(ssh_client) = this.ssh_client.as_ref() { + ssh_client + .read(cx) + .to_proto_client() + .send(proto::ShutdownRemoteServer {}) + .log_err(); + } + }) + .detach(); + let this = Self { buffer_ordered_messages_tx: tx, collaborators: Default::default(), @@ -754,20 +766,20 @@ impl Project { search_excluded_history: Self::new_search_history(), }; - let client: AnyProtoClient = ssh.to_proto_client(); - + let ssh = ssh.read(cx); ssh.subscribe_to_entity(SSH_PROJECT_ID, &cx.handle()); ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.buffer_store); ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.worktree_store); ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.lsp_store); ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.settings_observer); - client.add_model_message_handler(Self::handle_create_buffer_for_peer); - client.add_model_message_handler(Self::handle_update_worktree); - client.add_model_message_handler(Self::handle_update_project); - client.add_model_request_handler(BufferStore::handle_update_buffer); - BufferStore::init(&client); - LspStore::init(&client); - SettingsObserver::init(&client); + + ssh_proto.add_model_message_handler(Self::handle_create_buffer_for_peer); + ssh_proto.add_model_message_handler(Self::handle_update_worktree); + ssh_proto.add_model_message_handler(Self::handle_update_project); + ssh_proto.add_model_request_handler(BufferStore::handle_update_buffer); + BufferStore::init(&ssh_proto); + LspStore::init(&ssh_proto); + SettingsObserver::init(&ssh_proto); this }) @@ -1222,7 +1234,7 @@ impl Project { pub fn ssh_connection_string(&self, cx: &AppContext) -> Option { if let Some(ssh_state) = &self.ssh_client { - return Some(ssh_state.connection_string().into()); + return Some(ssh_state.read(cx).connection_string().into()); } let dev_server_id = self.dev_server_project_id()?; dev_server_projects::Store::global(cx) @@ -1232,8 +1244,8 @@ impl Project { .clone() } - pub fn ssh_is_connected(&self) -> Option { - Some(!self.ssh_client.as_ref()?.is_reconnect_underway()) + pub fn ssh_is_connected(&self, cx: &AppContext) -> Option { + Some(!self.ssh_client.as_ref()?.read(cx).is_reconnect_underway()) } pub fn replica_id(&self) -> ReplicaId { @@ -1945,6 +1957,7 @@ impl Project { BufferStoreEvent::BufferDropped(buffer_id) => { if let Some(ref ssh_client) = self.ssh_client { ssh_client + .read(cx) .to_proto_client() .send(proto::CloseBuffer { project_id: 0, @@ -2151,7 +2164,8 @@ impl Project { let operation = language::proto::serialize_operation(operation); if let Some(ssh) = &self.ssh_client { - ssh.to_proto_client() + ssh.read(cx) + .to_proto_client() .send(proto::UpdateBuffer { project_id: 0, buffer_id: buffer_id.to_proto(), @@ -2838,7 +2852,7 @@ impl Project { let (tx, rx) = smol::channel::unbounded(); let (client, remote_id): (AnyProtoClient, _) = if let Some(ssh_client) = &self.ssh_client { - (ssh_client.to_proto_client(), 0) + (ssh_client.read(cx).to_proto_client(), 0) } else if let Some(remote_id) = self.remote_id() { (self.client.clone().into(), remote_id) } else { @@ -2973,12 +2987,14 @@ impl Project { exists.then(|| ResolvedPath::AbsPath(expanded)) }) } else if let Some(ssh_client) = self.ssh_client.as_ref() { - let request = ssh_client - .to_proto_client() - .request(proto::CheckFileExists { - project_id: SSH_PROJECT_ID, - path: path.to_string(), - }); + let request = + ssh_client + .read(cx) + .to_proto_client() + .request(proto::CheckFileExists { + project_id: SSH_PROJECT_ID, + path: path.to_string(), + }); cx.background_executor().spawn(async move { let response = request.await.log_err()?; if response.exists { @@ -3054,7 +3070,7 @@ impl Project { path: query, }; - let response = session.to_proto_client().request(request); + let response = session.read(cx).to_proto_client().request(request); cx.background_executor().spawn(async move { let response = response.await?; Ok(response.entries.into_iter().map(PathBuf::from).collect()) @@ -3482,7 +3498,7 @@ impl Project { let mut payload = envelope.payload.clone(); payload.project_id = 0; cx.background_executor() - .spawn(ssh.to_proto_client().request(payload)) + .spawn(ssh.read(cx).to_proto_client().request(payload)) .detach_and_log_err(cx); } this.buffer_store.clone() diff --git a/crates/project/src/terminals.rs b/crates/project/src/terminals.rs index 7175b75e22..ecac58fb85 100644 --- a/crates/project/src/terminals.rs +++ b/crates/project/src/terminals.rs @@ -70,7 +70,7 @@ impl Project { if let Some(args) = self .ssh_client .as_ref() - .and_then(|session| session.ssh_args()) + .and_then(|session| session.read(cx).ssh_args()) { return Some(SshCommand::Direct(args)); } diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index f6e9645e9c..4e101f4305 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -282,7 +282,9 @@ message Envelope { UpdateUserSettings update_user_settings = 246; CheckFileExists check_file_exists = 255; - CheckFileExistsResponse check_file_exists_response = 256; // current max + CheckFileExistsResponse check_file_exists_response = 256; + + ShutdownRemoteServer shutdown_remote_server = 257; // current max } reserved 87 to 88; @@ -2511,3 +2513,5 @@ message CheckFileExistsResponse { bool exists = 1; string path = 2; } + +message ShutdownRemoteServer {} diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index fe1725e0d1..48733c449c 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -364,7 +364,8 @@ messages!( (CloseBuffer, Foreground), (UpdateUserSettings, Foreground), (CheckFileExists, Background), - (CheckFileExistsResponse, Background) + (CheckFileExistsResponse, Background), + (ShutdownRemoteServer, Foreground), ); request_messages!( @@ -487,7 +488,8 @@ request_messages!( (SynchronizeContexts, SynchronizeContextsResponse), (LspExtSwitchSourceHeader, LspExtSwitchSourceHeaderResponse), (AddWorktree, AddWorktreeResponse), - (CheckFileExists, CheckFileExistsResponse) + (CheckFileExists, CheckFileExistsResponse), + (ShutdownRemoteServer, Ack) ); entity_messages!( diff --git a/crates/recent_projects/src/dev_servers.rs b/crates/recent_projects/src/dev_servers.rs index 2038d069b4..722743e0ff 100644 --- a/crates/recent_projects/src/dev_servers.rs +++ b/crates/recent_projects/src/dev_servers.rs @@ -305,13 +305,19 @@ impl DevServerProjects { let connection_options = remote::SshConnectionOptions { host: host.to_string(), - username, + username: username.clone(), port, password: None, }; let ssh_prompt = cx.new_view(|cx| SshPrompt::new(&connection_options, cx)); - let connection = connect_over_ssh(connection_options.clone(), ssh_prompt.clone(), cx) - .prompt_err("Failed to connect", cx, |_, _| None); + + let connection = connect_over_ssh( + connection_options.dev_server_identifier(), + connection_options.clone(), + ssh_prompt.clone(), + cx, + ) + .prompt_err("Failed to connect", cx, |_, _| None); let creating = cx.spawn(move |this, mut cx| async move { match connection.await { @@ -363,11 +369,13 @@ impl DevServerProjects { .prompt .clone(); - let connect = connect_over_ssh(connection_options, prompt, cx).prompt_err( - "Failed to connect", + let connect = connect_over_ssh( + connection_options.dev_server_identifier(), + connection_options, + prompt, cx, - |_, _| None, - ); + ) + .prompt_err("Failed to connect", cx, |_, _| None); cx.spawn(|workspace, mut cx| async move { let Some(session) = connect.await else { workspace diff --git a/crates/recent_projects/src/ssh_connections.rs b/crates/recent_projects/src/ssh_connections.rs index d0fffc031f..9e50523773 100644 --- a/crates/recent_projects/src/ssh_connections.rs +++ b/crates/recent_projects/src/ssh_connections.rs @@ -4,12 +4,12 @@ use anyhow::Result; use auto_update::AutoUpdater; use editor::Editor; use futures::channel::oneshot; -use gpui::AppContext; use gpui::{ percentage, px, Animation, AnimationExt, AnyWindowHandle, AsyncAppContext, DismissEvent, EventEmitter, FocusableView, ParentElement as _, Render, SemanticVersion, SharedString, Task, Transformation, View, }; +use gpui::{AppContext, Model}; use release_channel::{AppVersion, ReleaseChannel}; use remote::{SshConnectionOptions, SshPlatform, SshRemoteClient}; use schemars::JsonSchema; @@ -373,25 +373,24 @@ impl SshClientDelegate { } pub fn connect_over_ssh( + unique_identifier: String, connection_options: SshConnectionOptions, ui: View, cx: &mut WindowContext, -) -> Task>> { +) -> Task>> { let window = cx.window_handle(); let known_password = connection_options.password.clone(); - cx.spawn(|mut cx| async move { - remote::SshRemoteClient::new( - connection_options, - Arc::new(SshClientDelegate { - window, - ui, - known_password, - }), - &mut cx, - ) - .await - }) + remote::SshRemoteClient::new( + unique_identifier, + connection_options, + Arc::new(SshClientDelegate { + window, + ui, + known_password, + }), + cx, + ) } pub async fn open_ssh_project( @@ -420,22 +419,25 @@ pub async fn open_ssh_project( })? }; - let session = window - .update(cx, |workspace, cx| { - cx.activate_window(); - workspace.toggle_modal(cx, |cx| SshConnectionModal::new(&connection_options, cx)); - let ui = workspace - .active_modal::(cx) - .unwrap() - .read(cx) - .prompt - .clone(); - connect_over_ssh(connection_options.clone(), ui, cx) - })? - .await?; + let delegate = window.update(cx, |workspace, cx| { + cx.activate_window(); + workspace.toggle_modal(cx, |cx| SshConnectionModal::new(&connection_options, cx)); + let ui = workspace + .active_modal::(cx) + .unwrap() + .read(cx) + .prompt + .clone(); + + Arc::new(SshClientDelegate { + window: cx.window_handle(), + ui, + known_password: connection_options.password.clone(), + }) + })?; cx.update(|cx| { - workspace::open_ssh_project(window, connection_options, session, app_state, paths, cx) + workspace::open_ssh_project(window, connection_options, delegate, app_state, paths, cx) })? .await } diff --git a/crates/remote/src/protocol.rs b/crates/remote/src/protocol.rs index bc495be4e7..311385f73b 100644 --- a/crates/remote/src/protocol.rs +++ b/crates/remote/src/protocol.rs @@ -49,3 +49,17 @@ pub async fn write_message( stream.write_all(buffer).await?; Ok(()) } + +pub async fn read_message_raw( + stream: &mut S, + buffer: &mut Vec, +) -> Result<()> { + buffer.resize(MESSAGE_LEN_SIZE, 0); + stream.read_exact(buffer).await?; + + let message_len = message_len_from_buffer(buffer); + buffer.resize(message_len as usize, 0); + stream.read_exact(buffer).await?; + + Ok(()) +} diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index 32d5536b32..05208dabd7 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -15,7 +15,9 @@ use futures::{ select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt, StreamExt as _, }; -use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, Task}; +use gpui::{ + AppContext, AsyncAppContext, Context, Model, ModelContext, SemanticVersion, Task, WeakModel, +}; use parking_lot::Mutex; use rpc::{ proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage}, @@ -28,10 +30,11 @@ use smol::{ use std::{ any::TypeId, ffi::OsStr, + mem, path::{Path, PathBuf}, sync::{ atomic::{AtomicU32, Ordering::SeqCst}, - Arc, Weak, + Arc, }, time::Instant, }; @@ -92,6 +95,17 @@ impl SshConnectionOptions { host } } + + // Uniquely identifies dev server projects on a remote host. Needs to be + // stable for the same dev server project. + pub fn dev_server_identifier(&self) -> String { + let mut identifier = format!("dev-server-{:?}", self.host); + if let Some(username) = self.username.as_ref() { + identifier.push('-'); + identifier.push_str(&username); + } + identifier + } } #[derive(Copy, Clone, Debug)] @@ -250,59 +264,101 @@ struct SshRemoteClientState { pub struct SshRemoteClient { client: Arc, - inner_state: Mutex>, + unique_identifier: String, connection_options: SshConnectionOptions, + inner_state: Arc>>, +} + +impl Drop for SshRemoteClient { + fn drop(&mut self) { + self.shutdown_processes(); + } } impl SshRemoteClient { - pub async fn new( + pub fn new( + unique_identifier: String, connection_options: SshConnectionOptions, delegate: Arc, - cx: &mut AsyncAppContext, - ) -> Result> { - let (outgoing_tx, outgoing_rx) = mpsc::unbounded::(); - let (incoming_tx, incoming_rx) = mpsc::unbounded::(); + cx: &AppContext, + ) -> Task>> { + cx.spawn(|mut cx| async move { + let (outgoing_tx, outgoing_rx) = mpsc::unbounded::(); + let (incoming_tx, incoming_rx) = mpsc::unbounded::(); - let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?; - let this = Arc::new(Self { - client, - inner_state: Mutex::new(None), - connection_options: connection_options.clone(), - }); + let this = cx.new_model(|cx| { + cx.on_app_quit(|this: &mut Self, _| { + this.shutdown_processes(); + futures::future::ready(()) + }) + .detach(); - let inner_state = { - let (proxy, proxy_incoming_tx, proxy_outgoing_rx) = - ChannelForwarder::new(incoming_tx, outgoing_rx, cx); + let client = ChannelClient::new(incoming_rx, outgoing_tx, cx); + Self { + client, + unique_identifier: unique_identifier.clone(), + connection_options: SshConnectionOptions::default(), + inner_state: Arc::new(Mutex::new(None)), + } + })?; - let (ssh_connection, ssh_process) = - Self::establish_connection(connection_options, delegate.clone(), cx).await?; + let inner_state = { + let (proxy, proxy_incoming_tx, proxy_outgoing_rx) = + ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); - let multiplex_task = Self::multiplex( - Arc::downgrade(&this), - ssh_process, - proxy_incoming_tx, - proxy_outgoing_rx, - cx, - ); + let (ssh_connection, ssh_proxy_process) = Self::establish_connection( + unique_identifier, + connection_options, + delegate.clone(), + &mut cx, + ) + .await?; - SshRemoteClientState { - ssh_connection, - delegate, - forwarder: proxy, - multiplex_task, - } - }; + let multiplex_task = Self::multiplex( + this.downgrade(), + ssh_proxy_process, + proxy_incoming_tx, + proxy_outgoing_rx, + &mut cx, + ); - this.inner_state.lock().replace(inner_state); + SshRemoteClientState { + ssh_connection, + delegate, + forwarder: proxy, + multiplex_task, + } + }; - Ok(this) + this.update(&mut cx, |this, cx| { + this.inner_state.lock().replace(inner_state); + cx.notify(); + })?; + + Ok(this) + }) } - fn reconnect(this: Arc, cx: &AsyncAppContext) -> Result<()> { - let Some(state) = this.inner_state.lock().take() else { + fn shutdown_processes(&self) { + let Some(mut state) = self.inner_state.lock().take() else { + return; + }; + log::info!("shutting down ssh processes"); + // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a + // child of master_process. + let task = mem::replace(&mut state.multiplex_task, Task::ready(Ok(()))); + drop(task); + // Now drop the rest of state, which kills master process. + drop(state); + } + + fn reconnect(&self, cx: &ModelContext) -> Result<()> { + let Some(state) = self.inner_state.lock().take() else { return Err(anyhow!("reconnect is already in progress")); }; + let workspace_identifier = self.unique_identifier.clone(); + let SshRemoteClientState { mut ssh_connection, delegate, @@ -311,7 +367,7 @@ impl SshRemoteClient { } = state; drop(multiplex_task); - cx.spawn(|mut cx| async move { + cx.spawn(|this, mut cx| async move { let (incoming_tx, outgoing_rx) = proxy.into_channels().await; ssh_connection.master_process.kill()?; @@ -323,8 +379,13 @@ impl SshRemoteClient { let connection_options = ssh_connection.socket.connection_options.clone(); - let (ssh_connection, ssh_process) = - Self::establish_connection(connection_options, delegate.clone(), &mut cx).await?; + let (ssh_connection, ssh_process) = Self::establish_connection( + workspace_identifier, + connection_options, + delegate.clone(), + &mut cx, + ) + .await?; let (proxy, proxy_incoming_tx, proxy_outgoing_rx) = ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); @@ -334,32 +395,32 @@ impl SshRemoteClient { delegate, forwarder: proxy, multiplex_task: Self::multiplex( - Arc::downgrade(&this), + this.clone(), ssh_process, proxy_incoming_tx, proxy_outgoing_rx, &mut cx, ), }; - this.inner_state.lock().replace(inner_state); - anyhow::Ok(()) + this.update(&mut cx, |this, _| { + this.inner_state.lock().replace(inner_state); + }) }) .detach(); - - anyhow::Ok(()) + Ok(()) } fn multiplex( - this: Weak, - mut ssh_process: Child, + this: WeakModel, + mut ssh_proxy_process: Child, incoming_tx: UnboundedSender, mut outgoing_rx: UnboundedReceiver, cx: &AsyncAppContext, ) -> Task> { - let mut child_stderr = ssh_process.stderr.take().unwrap(); - let mut child_stdout = ssh_process.stdout.take().unwrap(); - let mut child_stdin = ssh_process.stdin.take().unwrap(); + let mut child_stderr = ssh_proxy_process.stderr.take().unwrap(); + let mut child_stdout = ssh_proxy_process.stdout.take().unwrap(); + let mut child_stdin = ssh_proxy_process.stdin.take().unwrap(); let io_task = cx.background_executor().spawn(async move { let mut stdin_buffer = Vec::new(); @@ -385,7 +446,7 @@ impl SshRemoteClient { Ok(0) => { child_stdin.close().await?; outgoing_rx.close(); - let status = ssh_process.status().await?; + let status = ssh_proxy_process.status().await?; if !status.success() { log::error!("ssh process exited with status: {status:?}"); return Err(anyhow!("ssh process exited with non-zero status code: {:?}", status.code())); @@ -446,9 +507,9 @@ impl SshRemoteClient { if let Err(error) = result { log::warn!("ssh io task died with error: {:?}. reconnecting...", error); - if let Some(this) = this.upgrade() { - Self::reconnect(this, &mut cx).ok(); - } + this.update(&mut cx, |this, cx| { + this.reconnect(cx).ok(); + })?; } Ok(()) @@ -456,6 +517,7 @@ impl SshRemoteClient { } async fn establish_connection( + unique_identifier: String, connection_options: SshConnectionOptions, delegate: Arc, cx: &mut AsyncAppContext, @@ -479,17 +541,22 @@ impl SshRemoteClient { let socket = ssh_connection.socket.clone(); run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?; - let ssh_process = socket + delegate.set_status(Some("Starting proxy"), cx); + + let ssh_proxy_process = socket .ssh_command(format!( - "RUST_LOG={} RUST_BACKTRACE={} {:?} run", + "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}", std::env::var("RUST_LOG").unwrap_or_default(), std::env::var("RUST_BACKTRACE").unwrap_or_default(), remote_binary_path, + unique_identifier, )) + // IMPORTANT: we kill this process when we drop the task that uses it. + .kill_on_drop(true) .spawn() .context("failed to spawn remote server")?; - Ok((ssh_connection, ssh_process)) + Ok((ssh_connection, ssh_proxy_process)) } pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Model) { @@ -514,21 +581,25 @@ impl SshRemoteClient { pub fn is_reconnect_underway(&self) -> bool { maybe!({ Some(self.inner_state.try_lock()?.is_none()) }).unwrap_or_default() } + #[cfg(any(test, feature = "test-support"))] pub fn fake( client_cx: &mut gpui::TestAppContext, server_cx: &mut gpui::TestAppContext, - ) -> (Arc, Arc) { + ) -> (Model, Arc) { + use gpui::Context; + let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded(); let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded(); ( client_cx.update(|cx| { let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx); - Arc::new(Self { + cx.new_model(|_| Self { client, - inner_state: Mutex::new(None), + unique_identifier: "fake".to_string(), connection_options: SshConnectionOptions::default(), + inner_state: Arc::new(Mutex::new(None)), }) }), server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)), diff --git a/crates/remote_server/Cargo.toml b/crates/remote_server/Cargo.toml index b15970042d..211b76e091 100644 --- a/crates/remote_server/Cargo.toml +++ b/crates/remote_server/Cargo.toml @@ -22,25 +22,26 @@ test-support = ["fs/test-support"] [dependencies] anyhow.workspace = true +clap.workspace = true client.workspace = true env_logger.workspace = true fs.workspace = true futures.workspace = true gpui.workspace = true -node_runtime.workspace = true +language.workspace = true +languages.workspace = true log.workspace = true +node_runtime.workspace = true project.workspace = true remote.workspace = true rpc.workspace = true -settings.workspace = true serde.workspace = true serde_json.workspace = true +settings.workspace = true shellexpand.workspace = true smol.workspace = true -worktree.workspace = true -language.workspace = true -languages.workspace = true util.workspace = true +worktree.workspace = true [dev-dependencies] client = { workspace = true, features = ["test-support"] } diff --git a/crates/remote_server/src/headless_project.rs b/crates/remote_server/src/headless_project.rs index 39540b04e0..66f9ca7ed5 100644 --- a/crates/remote_server/src/headless_project.rs +++ b/crates/remote_server/src/headless_project.rs @@ -112,6 +112,7 @@ impl HeadlessProject { client.add_request_handler(cx.weak_model(), Self::handle_list_remote_directory); client.add_request_handler(cx.weak_model(), Self::handle_check_file_exists); + client.add_request_handler(cx.weak_model(), Self::handle_shutdown_remote_server); client.add_model_request_handler(Self::handle_add_worktree); client.add_model_request_handler(Self::handle_open_buffer_by_path); @@ -335,4 +336,22 @@ impl HeadlessProject { path: expanded, }) } + + pub async fn handle_shutdown_remote_server( + _this: Model, + _envelope: TypedEnvelope, + cx: AsyncAppContext, + ) -> Result { + cx.spawn(|cx| async move { + cx.update(|cx| { + // TODO: This is a hack, because in a headless project, shutdown isn't executed + // when calling quit, but it should be. + cx.shutdown(); + cx.quit(); + }) + }) + .detach(); + + Ok(proto::Ack {}) + } } diff --git a/crates/remote_server/src/main.rs b/crates/remote_server/src/main.rs index 73b8a91da1..e5582d9b1f 100644 --- a/crates/remote_server/src/main.rs +++ b/crates/remote_server/src/main.rs @@ -1,20 +1,34 @@ #![cfg_attr(target_os = "windows", allow(unused, dead_code))] -use fs::RealFs; -use futures::channel::mpsc; -use gpui::Context as _; -use remote::{ - json_log::LogRecord, - protocol::{read_message, write_message}, -}; -use remote_server::HeadlessProject; -use smol::{io::AsyncWriteExt, stream::StreamExt as _, Async}; -use std::{ - env, - io::{self, Write}, - mem, process, - sync::Arc, -}; +use anyhow::Result; +use clap::{Parser, Subcommand}; +use std::path::PathBuf; + +#[derive(Parser)] +#[command(disable_version_flag = true)] +struct Cli { + #[command(subcommand)] + command: Option, +} + +#[derive(Subcommand)] +enum Commands { + Run { + #[arg(long)] + log_file: PathBuf, + #[arg(long)] + pid_file: PathBuf, + #[arg(long)] + stdin_socket: PathBuf, + #[arg(long)] + stdout_socket: PathBuf, + }, + Proxy { + #[arg(long)] + identifier: String, + }, + Version, +} #[cfg(windows)] fn main() { @@ -22,76 +36,32 @@ fn main() { } #[cfg(not(windows))] -fn main() { - use remote::ssh_session::ChannelClient; +fn main() -> Result<()> { + use remote_server::unix::{execute_proxy, execute_run, init_logging}; - env_logger::builder() - .format(|buf, record| { - serde_json::to_writer(&mut *buf, &LogRecord::new(record))?; - buf.write_all(b"\n")?; - Ok(()) - }) - .init(); + let cli = Cli::parse(); - let subcommand = std::env::args().nth(1); - match subcommand.as_deref() { - Some("run") => {} - Some("version") => { - println!("{}", env!("ZED_PKG_VERSION")); - return; + match cli.command { + Some(Commands::Run { + log_file, + pid_file, + stdin_socket, + stdout_socket, + }) => { + init_logging(Some(log_file))?; + execute_run(pid_file, stdin_socket, stdout_socket) } - _ => { - eprintln!("usage: remote "); - process::exit(1); + Some(Commands::Proxy { identifier }) => { + init_logging(None)?; + execute_proxy(identifier) + } + Some(Commands::Version) => { + eprintln!("{}", env!("ZED_PKG_VERSION")); + Ok(()) + } + None => { + eprintln!("usage: remote "); + std::process::exit(1); } } - - gpui::App::headless().run(move |cx| { - settings::init(cx); - HeadlessProject::init(cx); - - let (incoming_tx, incoming_rx) = mpsc::unbounded(); - let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded(); - - let mut stdin = Async::new(io::stdin()).unwrap(); - let mut stdout = Async::new(io::stdout()).unwrap(); - - let session = ChannelClient::new(incoming_rx, outgoing_tx, cx); - let project = cx.new_model(|cx| { - HeadlessProject::new( - session.clone(), - Arc::new(RealFs::new(Default::default(), None)), - cx, - ) - }); - - cx.background_executor() - .spawn(async move { - let mut output_buffer = Vec::new(); - while let Some(message) = outgoing_rx.next().await { - write_message(&mut stdout, &mut output_buffer, message).await?; - stdout.flush().await?; - } - anyhow::Ok(()) - }) - .detach(); - - cx.background_executor() - .spawn(async move { - let mut input_buffer = Vec::new(); - loop { - let message = match read_message(&mut stdin, &mut input_buffer).await { - Ok(message) => message, - Err(error) => { - log::warn!("error reading message: {:?}", error); - process::exit(0); - } - }; - incoming_tx.unbounded_send(message).ok(); - } - }) - .detach(); - - mem::forget(project); - }); } diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 960b7c248c..6e962a134a 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -655,7 +655,7 @@ async fn init_test( (project, headless, fs) } -fn build_project(ssh: Arc, cx: &mut TestAppContext) -> Model { +fn build_project(ssh: Model, cx: &mut TestAppContext) -> Model { cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); diff --git a/crates/remote_server/src/remote_server.rs b/crates/remote_server/src/remote_server.rs index 0aa36b0cd1..2321ee1c6e 100644 --- a/crates/remote_server/src/remote_server.rs +++ b/crates/remote_server/src/remote_server.rs @@ -1,5 +1,8 @@ mod headless_project; +#[cfg(not(windows))] +pub mod unix; + #[cfg(test)] mod remote_editing_tests; diff --git a/crates/remote_server/src/unix.rs b/crates/remote_server/src/unix.rs new file mode 100644 index 0000000000..74b71a2277 --- /dev/null +++ b/crates/remote_server/src/unix.rs @@ -0,0 +1,336 @@ +use crate::HeadlessProject; +use anyhow::{anyhow, Context, Result}; +use fs::RealFs; +use futures::channel::mpsc; +use futures::{select, select_biased, AsyncRead, AsyncWrite, FutureExt, SinkExt}; +use gpui::{AppContext, Context as _}; +use remote::ssh_session::ChannelClient; +use remote::{ + json_log::LogRecord, + protocol::{read_message, write_message}, +}; +use rpc::proto::Envelope; +use smol::Async; +use smol::{io::AsyncWriteExt, net::unix::UnixListener, stream::StreamExt as _}; +use std::{ + env, + io::Write, + mem, + path::{Path, PathBuf}, + sync::Arc, +}; + +pub fn init_logging(log_file: Option) -> Result<()> { + if let Some(log_file) = log_file { + let target = Box::new(if log_file.exists() { + std::fs::OpenOptions::new() + .append(true) + .open(&log_file) + .context("Failed to open log file in append mode")? + } else { + std::fs::File::create(&log_file).context("Failed to create log file")? + }); + + env_logger::Builder::from_default_env() + .target(env_logger::Target::Pipe(target)) + .init(); + } else { + env_logger::builder() + .format(|buf, record| { + serde_json::to_writer(&mut *buf, &LogRecord::new(record))?; + buf.write_all(b"\n")?; + Ok(()) + }) + .init(); + } + Ok(()) +} + +fn start_server( + stdin_listener: UnixListener, + stdout_listener: UnixListener, + cx: &mut AppContext, +) -> Arc { + // This is the server idle timeout. If no connection comes in in this timeout, the server will shut down. + const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10 * 60); + + let (incoming_tx, incoming_rx) = mpsc::unbounded::(); + let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::(); + let (app_quit_tx, mut app_quit_rx) = mpsc::unbounded::<()>(); + + cx.on_app_quit(move |_| { + let mut app_quit_tx = app_quit_tx.clone(); + async move { + app_quit_tx.send(()).await.ok(); + } + }) + .detach(); + + cx.spawn(|cx| async move { + let mut stdin_incoming = stdin_listener.incoming(); + let mut stdout_incoming = stdout_listener.incoming(); + + loop { + let streams = futures::future::join(stdin_incoming.next(), stdout_incoming.next()); + + log::info!("server: accepting new connections"); + let result = select! { + streams = streams.fuse() => { + let (Some(Ok(stdin_stream)), Some(Ok(stdout_stream))) = streams else { + break; + }; + anyhow::Ok((stdin_stream, stdout_stream)) + } + _ = futures::FutureExt::fuse(smol::Timer::after(IDLE_TIMEOUT)) => { + log::warn!("server: timed out waiting for new connections after {:?}. exiting.", IDLE_TIMEOUT); + cx.update(|cx| { + // TODO: This is a hack, because in a headless project, shutdown isn't executed + // when calling quit, but it should be. + cx.shutdown(); + cx.quit(); + })?; + break; + } + _ = app_quit_rx.next().fuse() => { + break; + } + }; + + let Ok((mut stdin_stream, mut stdout_stream)) = result else { + break; + }; + + let mut input_buffer = Vec::new(); + let mut output_buffer = Vec::new(); + loop { + select_biased! { + _ = app_quit_rx.next().fuse() => { + return anyhow::Ok(()); + } + + stdin_message = read_message(&mut stdin_stream, &mut input_buffer).fuse() => { + let message = match stdin_message { + Ok(message) => message, + Err(error) => { + log::warn!("server: error reading message on stdin: {}. exiting.", error); + break; + } + }; + if let Err(error) = incoming_tx.unbounded_send(message) { + log::error!("server: failed to send message to application: {:?}. exiting.", error); + return Err(anyhow!(error)); + } + } + + outgoing_message = outgoing_rx.next().fuse() => { + let Some(message) = outgoing_message else { + log::error!("server: stdout handler, no message"); + break; + }; + + if let Err(error) = + write_message(&mut stdout_stream, &mut output_buffer, message).await + { + log::error!("server: failed to write stdout message: {:?}", error); + break; + } + if let Err(error) = stdout_stream.flush().await { + log::error!("server: failed to flush stdout message: {:?}", error); + break; + } + } + } + } + } + anyhow::Ok(()) + }) + .detach(); + + ChannelClient::new(incoming_rx, outgoing_tx, cx) +} + +pub fn execute_run(pid_file: PathBuf, stdin_socket: PathBuf, stdout_socket: PathBuf) -> Result<()> { + write_pid_file(&pid_file) + .with_context(|| format!("failed to write pid file: {:?}", &pid_file))?; + + let stdin_listener = UnixListener::bind(stdin_socket).context("failed to bind stdin socket")?; + let stdout_listener = + UnixListener::bind(stdout_socket).context("failed to bind stdout socket")?; + + gpui::App::headless().run(move |cx| { + settings::init(cx); + HeadlessProject::init(cx); + + let session = start_server(stdin_listener, stdout_listener, cx); + let project = cx.new_model(|cx| { + HeadlessProject::new(session, Arc::new(RealFs::new(Default::default(), None)), cx) + }); + + mem::forget(project); + }); + log::info!("server: gpui app is shut down. quitting."); + Ok(()) +} + +pub fn execute_proxy(identifier: String) -> Result<()> { + log::debug!("proxy: starting up. PID: {}", std::process::id()); + + let project_dir = ensure_project_dir(&identifier)?; + + let pid_file = project_dir.join("server.pid"); + let stdin_socket = project_dir.join("stdin.sock"); + let stdout_socket = project_dir.join("stdout.sock"); + let log_file = project_dir.join("server.log"); + + let server_running = check_pid_file(&pid_file)?; + if !server_running { + spawn_server(&log_file, &pid_file, &stdin_socket, &stdout_socket)?; + }; + + let stdin_task = smol::spawn(async move { + let stdin = Async::new(std::io::stdin())?; + let stream = smol::net::unix::UnixStream::connect(stdin_socket).await?; + handle_io(stdin, stream, "stdin").await + }); + + let stdout_task: smol::Task> = smol::spawn(async move { + let stdout = Async::new(std::io::stdout())?; + let stream = smol::net::unix::UnixStream::connect(stdout_socket).await?; + handle_io(stream, stdout, "stdout").await + }); + + if let Err(forwarding_result) = + smol::block_on(async move { smol::future::race(stdin_task, stdout_task).await }) + { + log::error!( + "proxy: failed to forward messages: {:?}, terminating...", + forwarding_result + ); + return Err(forwarding_result); + } + + Ok(()) +} + +fn ensure_project_dir(identifier: &str) -> Result { + let project_dir = env::var("HOME").unwrap_or_else(|_| ".".to_string()); + let project_dir = PathBuf::from(project_dir) + .join(".local") + .join("state") + .join("zed-remote-server") + .join(identifier); + + std::fs::create_dir_all(&project_dir)?; + + Ok(project_dir) +} + +fn spawn_server( + log_file: &Path, + pid_file: &Path, + stdin_socket: &Path, + stdout_socket: &Path, +) -> Result<()> { + if stdin_socket.exists() { + std::fs::remove_file(&stdin_socket)?; + } + if stdout_socket.exists() { + std::fs::remove_file(&stdout_socket)?; + } + + let binary_name = std::env::current_exe()?; + let server_process = std::process::Command::new(binary_name) + .arg("run") + .arg("--log-file") + .arg(log_file) + .arg("--pid-file") + .arg(pid_file) + .arg("--stdin-socket") + .arg(stdin_socket) + .arg("--stdout-socket") + .arg(stdout_socket) + .spawn()?; + + log::debug!("proxy: server started. PID: {:?}", server_process.id()); + + let mut total_time_waited = std::time::Duration::from_secs(0); + let wait_duration = std::time::Duration::from_millis(20); + while !stdout_socket.exists() || !stdin_socket.exists() { + log::debug!("proxy: waiting for server to be ready to accept connections..."); + std::thread::sleep(wait_duration); + total_time_waited += wait_duration; + } + + log::info!( + "proxy: server ready to accept connections. total time waited: {:?}", + total_time_waited + ); + Ok(()) +} + +fn check_pid_file(path: &Path) -> Result { + let Some(pid) = std::fs::read_to_string(&path) + .ok() + .and_then(|contents| contents.parse::().ok()) + else { + return Ok(false); + }; + + log::debug!("proxy: Checking if process with PID {} exists...", pid); + match std::process::Command::new("kill") + .arg("-0") + .arg(pid.to_string()) + .output() + { + Ok(output) if output.status.success() => { + log::debug!("proxy: Process with PID {} exists. NOT spawning new server, but attaching to existing one.", pid); + Ok(true) + } + _ => { + log::debug!("proxy: Found PID file, but process with that PID does not exist. Removing PID file."); + std::fs::remove_file(&path).context("proxy: Failed to remove PID file")?; + Ok(false) + } + } +} + +fn write_pid_file(path: &Path) -> Result<()> { + if path.exists() { + std::fs::remove_file(path)?; + } + + std::fs::write(path, std::process::id().to_string()).context("Failed to write PID file") +} + +async fn handle_io(mut reader: R, mut writer: W, socket_name: &str) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + use remote::protocol::read_message_raw; + + let mut buffer = Vec::new(); + loop { + read_message_raw(&mut reader, &mut buffer) + .await + .with_context(|| format!("proxy: failed to read message from {}", socket_name))?; + + write_size_prefixed_buffer(&mut writer, &mut buffer) + .await + .with_context(|| format!("proxy: failed to write message to {}", socket_name))?; + + writer.flush().await?; + + buffer.clear(); + } +} + +async fn write_size_prefixed_buffer( + stream: &mut S, + buffer: &mut Vec, +) -> Result<()> { + let len = buffer.len() as u32; + stream.write_all(len.to_le_bytes().as_slice()).await?; + stream.write_all(buffer).await?; + Ok(()) +} diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index 81f908ce79..52dab68a2a 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -265,7 +265,7 @@ impl TitleBar { fn render_ssh_project_host(&self, cx: &mut ViewContext) -> Option { let host = self.project.read(cx).ssh_connection_string(cx)?; let meta = SharedString::from(format!("Connected to: {host}")); - let indicator_color = if self.project.read(cx).ssh_is_connected()? { + let indicator_color = if self.project.read(cx).ssh_is_connected(cx)? { Color::Success } else { Color::Warning diff --git a/crates/workspace/Cargo.toml b/crates/workspace/Cargo.toml index 1b998eeabe..47f6c138c8 100644 --- a/crates/workspace/Cargo.toml +++ b/crates/workspace/Cargo.toml @@ -51,6 +51,7 @@ postage.workspace = true project.workspace = true dev_server_projects.workspace = true task.workspace = true +release_channel.workspace = true remote.workspace = true schemars.workspace = true serde.workspace = true diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index b668a5802c..d2ccd9cd4a 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -61,7 +61,8 @@ use postage::stream::Stream; use project::{ DirectoryLister, Project, ProjectEntryId, ProjectPath, ResolvedPath, Worktree, WorktreeId, }; -use remote::{SshConnectionOptions, SshRemoteClient}; +use release_channel::ReleaseChannel; +use remote::{SshClientDelegate, SshConnectionOptions}; use serde::Deserialize; use session::AppSession; use settings::{InvalidSettingsError, Settings}; @@ -5514,22 +5515,31 @@ pub fn join_hosted_project( pub fn open_ssh_project( window: WindowHandle, connection_options: SshConnectionOptions, - session: Arc, + delegate: Arc, app_state: Arc, paths: Vec, cx: &mut AppContext, ) -> Task> { + let release_channel = ReleaseChannel::global(cx); + cx.spawn(|mut cx| async move { - let serialized_ssh_project = persistence::DB - .get_or_create_ssh_project( - connection_options.host.clone(), - connection_options.port, - paths - .iter() - .map(|path| path.to_string_lossy().to_string()) - .collect::>(), - connection_options.username.clone(), - ) + let (serialized_ssh_project, workspace_id, serialized_workspace) = + serialize_ssh_project(connection_options.clone(), paths.clone(), &cx).await?; + + let identifier_prefix = match release_channel { + ReleaseChannel::Stable => None, + _ => Some(format!("{}-", release_channel.dev_name())), + }; + let unique_identifier = format!( + "{}workspace-{}", + identifier_prefix.unwrap_or_default(), + workspace_id.0 + ); + + let session = cx + .update(|cx| { + remote::SshRemoteClient::new(unique_identifier, connection_options, delegate, cx) + })? .await?; let project = cx.update(|cx| { @@ -5561,17 +5571,6 @@ pub fn open_ssh_project( }; } - let serialized_workspace = - persistence::DB.workspace_for_ssh_project(&serialized_ssh_project); - - let workspace_id = if let Some(workspace_id) = - serialized_workspace.as_ref().map(|workspace| workspace.id) - { - workspace_id - } else { - persistence::DB.next_id().await? - }; - cx.update_window(window.into(), |_, cx| { cx.replace_root_view(|cx| { let mut workspace = @@ -5603,6 +5602,45 @@ pub fn open_ssh_project( }) } +fn serialize_ssh_project( + connection_options: SshConnectionOptions, + paths: Vec, + cx: &AsyncAppContext, +) -> Task< + Result<( + SerializedSshProject, + WorkspaceId, + Option, + )>, +> { + cx.background_executor().spawn(async move { + let serialized_ssh_project = persistence::DB + .get_or_create_ssh_project( + connection_options.host.clone(), + connection_options.port, + paths + .iter() + .map(|path| path.to_string_lossy().to_string()) + .collect::>(), + connection_options.username.clone(), + ) + .await?; + + let serialized_workspace = + persistence::DB.workspace_for_ssh_project(&serialized_ssh_project); + + let workspace_id = if let Some(workspace_id) = + serialized_workspace.as_ref().map(|workspace| workspace.id) + { + workspace_id + } else { + persistence::DB.next_id().await? + }; + + Ok((serialized_ssh_project, workspace_id, serialized_workspace)) + }) +} + pub fn join_dev_server_project( dev_server_project_id: DevServerProjectId, project_id: ProjectId,