From ce216432be5a967feb0d30ee9878d0cf4fb07cb7 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 19 Aug 2025 17:33:56 -0700 Subject: [PATCH] Refactor ssh remoting - make ChannelClient type private (#36514) This PR is one step in a series of refactors to prepare for having "remote" projects that do not use SSH. The main use cases for this are WSL and dev containers. Release Notes: - N/A --- crates/editor/src/editor.rs | 5 +- crates/project/src/project.rs | 23 +-- crates/remote/src/ssh_session.rs | 146 +++++++++---------- crates/remote_server/src/headless_project.rs | 67 ++++----- crates/remote_server/src/unix.rs | 13 +- crates/rpc/src/proto_client.rs | 19 +++ crates/tasks_ui/src/tasks_ui.rs | 6 +- 7 files changed, 133 insertions(+), 146 deletions(-) diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 3805904243..f943e64923 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -14895,10 +14895,7 @@ impl Editor { }; let hide_runnables = project - .update(cx, |project, cx| { - // Do not display any test indicators in non-dev server remote projects. - project.is_via_collab() && project.ssh_connection_string(cx).is_none() - }) + .update(cx, |project, _| project.is_via_collab()) .unwrap_or(true); if hide_runnables { return; diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 6712b3fab0..f07ee13866 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -1346,14 +1346,13 @@ impl Project { }; // ssh -> local machine handlers - let ssh = ssh.read(cx); - ssh.subscribe_to_entity(SSH_PROJECT_ID, &cx.entity()); - ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.buffer_store); - ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.worktree_store); - ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.lsp_store); - ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.dap_store); - ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.settings_observer); - ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.git_store); + ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &cx.entity()); + ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &this.buffer_store); + ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &this.worktree_store); + ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &this.lsp_store); + ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &this.dap_store); + ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &this.settings_observer); + ssh_proto.subscribe_to_entity(SSH_PROJECT_ID, &this.git_store); ssh_proto.add_entity_message_handler(Self::handle_create_buffer_for_peer); ssh_proto.add_entity_message_handler(Self::handle_update_worktree); @@ -1900,14 +1899,6 @@ impl Project { false } - pub fn ssh_connection_string(&self, cx: &App) -> Option { - if let Some(ssh_state) = &self.ssh_client { - return Some(ssh_state.read(cx).connection_string().into()); - } - - None - } - pub fn ssh_connection_state(&self, cx: &App) -> Option { self.ssh_client .as_ref() diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index abde2d7568..ffd0cac310 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -26,8 +26,7 @@ use parking_lot::Mutex; use release_channel::{AppCommitSha, AppVersion, ReleaseChannel}; use rpc::{ - AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet, - RpcError, + AnyProtoClient, ErrorExt, ProtoClient, ProtoMessageHandlerSet, RpcError, proto::{self, Envelope, EnvelopedMessage, PeerId, RequestMessage, build_typed_envelope}, }; use schemars::JsonSchema; @@ -37,7 +36,6 @@ use smol::{ process::{self, Child, Stdio}, }; use std::{ - any::TypeId, collections::VecDeque, fmt, iter, ops::ControlFlow, @@ -664,6 +662,7 @@ impl ConnectionIdentifier { pub fn setup() -> Self { Self::Setup(NEXT_ID.fetch_add(1, SeqCst)) } + // This string gets used in a socket name, and so must be relatively short. // The total length of: // /home/{username}/.local/share/zed/server_state/{name}/stdout.sock @@ -760,6 +759,15 @@ impl SshRemoteClient { }) } + pub fn proto_client_from_channels( + incoming_rx: mpsc::UnboundedReceiver, + outgoing_tx: mpsc::UnboundedSender, + cx: &App, + name: &'static str, + ) -> AnyProtoClient { + ChannelClient::new(incoming_rx, outgoing_tx, cx, name).into() + } + pub fn shutdown_processes( &self, shutdown_request: Option, @@ -990,64 +998,63 @@ impl SshRemoteClient { }; cx.spawn(async move |cx| { - let mut missed_heartbeats = 0; + let mut missed_heartbeats = 0; - let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse(); - futures::pin_mut!(keepalive_timer); + let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse(); + futures::pin_mut!(keepalive_timer); - loop { - select_biased! { - result = connection_activity_rx.next().fuse() => { - if result.is_none() { - log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping."); - return Ok(()); - } - - if missed_heartbeats != 0 { - missed_heartbeats = 0; - let _ =this.update(cx, |this, cx| { - this.handle_heartbeat_result(missed_heartbeats, cx) - })?; - } + loop { + select_biased! { + result = connection_activity_rx.next().fuse() => { + if result.is_none() { + log::warn!("ssh heartbeat: connection activity channel has been dropped. stopping."); + return Ok(()); } - _ = keepalive_timer => { - log::debug!("Sending heartbeat to server..."); - let result = select_biased! { - _ = connection_activity_rx.next().fuse() => { - Ok(()) - } - ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => { - ping_result - } - }; - - if result.is_err() { - missed_heartbeats += 1; - log::warn!( - "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.", - HEARTBEAT_TIMEOUT, - missed_heartbeats, - MAX_MISSED_HEARTBEATS - ); - } else if missed_heartbeats != 0 { - missed_heartbeats = 0; - } else { - continue; - } - - let result = this.update(cx, |this, cx| { + if missed_heartbeats != 0 { + missed_heartbeats = 0; + let _ =this.update(cx, |this, cx| { this.handle_heartbeat_result(missed_heartbeats, cx) })?; - if result.is_break() { - return Ok(()); - } } } + _ = keepalive_timer => { + log::debug!("Sending heartbeat to server..."); - keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse()); + let result = select_biased! { + _ = connection_activity_rx.next().fuse() => { + Ok(()) + } + ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => { + ping_result + } + }; + + if result.is_err() { + missed_heartbeats += 1; + log::warn!( + "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.", + HEARTBEAT_TIMEOUT, + missed_heartbeats, + MAX_MISSED_HEARTBEATS + ); + } else if missed_heartbeats != 0 { + missed_heartbeats = 0; + } else { + continue; + } + + let result = this.update(cx, |this, cx| { + this.handle_heartbeat_result(missed_heartbeats, cx) + })?; + if result.is_break() { + return Ok(()); + } + } } + keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse()); + } }) } @@ -1145,10 +1152,6 @@ impl SshRemoteClient { cx.notify(); } - pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Entity) { - self.client.subscribe_to_entity(remote_id, entity); - } - pub fn ssh_info(&self) -> Option<(SshArgs, PathStyle)> { self.state .lock() @@ -1222,7 +1225,7 @@ impl SshRemoteClient { pub fn fake_server( client_cx: &mut gpui::TestAppContext, server_cx: &mut gpui::TestAppContext, - ) -> (SshConnectionOptions, Arc) { + ) -> (SshConnectionOptions, AnyProtoClient) { let port = client_cx .update(|cx| cx.default_global::().connections.len() as u16 + 1); let opts = SshConnectionOptions { @@ -1255,7 +1258,7 @@ impl SshRemoteClient { }) }); - (opts, server_client) + (opts, server_client.into()) } #[cfg(any(test, feature = "test-support"))] @@ -2269,7 +2272,7 @@ impl SshRemoteConnection { type ResponseChannels = Mutex)>>>; -pub struct ChannelClient { +struct ChannelClient { next_message_id: AtomicU32, outgoing_tx: Mutex>, buffer: Mutex>, @@ -2281,7 +2284,7 @@ pub struct ChannelClient { } impl ChannelClient { - pub fn new( + fn new( incoming_rx: mpsc::UnboundedReceiver, outgoing_tx: mpsc::UnboundedSender, cx: &App, @@ -2402,7 +2405,7 @@ impl ChannelClient { }) } - pub fn reconnect( + fn reconnect( self: &Arc, incoming_rx: UnboundedReceiver, outgoing_tx: UnboundedSender, @@ -2412,26 +2415,7 @@ impl ChannelClient { *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx); } - pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Entity) { - let id = (TypeId::of::(), remote_id); - - let mut message_handlers = self.message_handlers.lock(); - if message_handlers - .entities_by_type_and_remote_id - .contains_key(&id) - { - panic!("already subscribed to entity"); - } - - message_handlers.entities_by_type_and_remote_id.insert( - id, - EntityMessageSubscriber::Entity { - handle: entity.downgrade().into(), - }, - ); - } - - pub fn request( + fn request( &self, payload: T, ) -> impl 'static + Future> { @@ -2453,7 +2437,7 @@ impl ChannelClient { } } - pub async fn resync(&self, timeout: Duration) -> Result<()> { + async fn resync(&self, timeout: Duration) -> Result<()> { smol::future::or( async { self.request_internal(proto::FlushBufferedMessages {}, false) @@ -2475,7 +2459,7 @@ impl ChannelClient { .await } - pub async fn ping(&self, timeout: Duration) -> Result<()> { + async fn ping(&self, timeout: Duration) -> Result<()> { smol::future::or( async { self.request(proto::Ping {}).await?; diff --git a/crates/remote_server/src/headless_project.rs b/crates/remote_server/src/headless_project.rs index 6fc327ac1c..3bcdcbd73c 100644 --- a/crates/remote_server/src/headless_project.rs +++ b/crates/remote_server/src/headless_project.rs @@ -19,7 +19,6 @@ use project::{ task_store::TaskStore, worktree_store::WorktreeStore, }; -use remote::ssh_session::ChannelClient; use rpc::{ AnyProtoClient, TypedEnvelope, proto::{self, SSH_PEER_ID, SSH_PROJECT_ID}, @@ -50,7 +49,7 @@ pub struct HeadlessProject { } pub struct HeadlessAppState { - pub session: Arc, + pub session: AnyProtoClient, pub fs: Arc, pub http_client: Arc, pub node_runtime: NodeRuntime, @@ -81,7 +80,7 @@ impl HeadlessProject { let worktree_store = cx.new(|cx| { let mut store = WorktreeStore::local(true, fs.clone()); - store.shared(SSH_PROJECT_ID, session.clone().into(), cx); + store.shared(SSH_PROJECT_ID, session.clone(), cx); store }); @@ -99,7 +98,7 @@ impl HeadlessProject { let buffer_store = cx.new(|cx| { let mut buffer_store = BufferStore::local(worktree_store.clone(), cx); - buffer_store.shared(SSH_PROJECT_ID, session.clone().into(), cx); + buffer_store.shared(SSH_PROJECT_ID, session.clone(), cx); buffer_store }); @@ -117,7 +116,7 @@ impl HeadlessProject { breakpoint_store.clone(), cx, ); - dap_store.shared(SSH_PROJECT_ID, session.clone().into(), cx); + dap_store.shared(SSH_PROJECT_ID, session.clone(), cx); dap_store }); @@ -129,7 +128,7 @@ impl HeadlessProject { fs.clone(), cx, ); - store.shared(SSH_PROJECT_ID, session.clone().into(), cx); + store.shared(SSH_PROJECT_ID, session.clone(), cx); store }); @@ -152,7 +151,7 @@ impl HeadlessProject { environment.clone(), cx, ); - task_store.shared(SSH_PROJECT_ID, session.clone().into(), cx); + task_store.shared(SSH_PROJECT_ID, session.clone(), cx); task_store }); let settings_observer = cx.new(|cx| { @@ -162,7 +161,7 @@ impl HeadlessProject { task_store.clone(), cx, ); - observer.shared(SSH_PROJECT_ID, session.clone().into(), cx); + observer.shared(SSH_PROJECT_ID, session.clone(), cx); observer }); @@ -183,7 +182,7 @@ impl HeadlessProject { fs.clone(), cx, ); - lsp_store.shared(SSH_PROJECT_ID, session.clone().into(), cx); + lsp_store.shared(SSH_PROJECT_ID, session.clone(), cx); lsp_store }); @@ -210,8 +209,6 @@ impl HeadlessProject { cx, ); - let client: AnyProtoClient = session.clone().into(); - // local_machine -> ssh handlers session.subscribe_to_entity(SSH_PROJECT_ID, &worktree_store); session.subscribe_to_entity(SSH_PROJECT_ID, &buffer_store); @@ -223,44 +220,45 @@ impl HeadlessProject { session.subscribe_to_entity(SSH_PROJECT_ID, &settings_observer); session.subscribe_to_entity(SSH_PROJECT_ID, &git_store); - client.add_request_handler(cx.weak_entity(), Self::handle_list_remote_directory); - client.add_request_handler(cx.weak_entity(), Self::handle_get_path_metadata); - client.add_request_handler(cx.weak_entity(), Self::handle_shutdown_remote_server); - client.add_request_handler(cx.weak_entity(), Self::handle_ping); + session.add_request_handler(cx.weak_entity(), Self::handle_list_remote_directory); + session.add_request_handler(cx.weak_entity(), Self::handle_get_path_metadata); + session.add_request_handler(cx.weak_entity(), Self::handle_shutdown_remote_server); + session.add_request_handler(cx.weak_entity(), Self::handle_ping); - client.add_entity_request_handler(Self::handle_add_worktree); - client.add_request_handler(cx.weak_entity(), Self::handle_remove_worktree); + session.add_entity_request_handler(Self::handle_add_worktree); + session.add_request_handler(cx.weak_entity(), Self::handle_remove_worktree); - client.add_entity_request_handler(Self::handle_open_buffer_by_path); - client.add_entity_request_handler(Self::handle_open_new_buffer); - client.add_entity_request_handler(Self::handle_find_search_candidates); - client.add_entity_request_handler(Self::handle_open_server_settings); + session.add_entity_request_handler(Self::handle_open_buffer_by_path); + session.add_entity_request_handler(Self::handle_open_new_buffer); + session.add_entity_request_handler(Self::handle_find_search_candidates); + session.add_entity_request_handler(Self::handle_open_server_settings); - client.add_entity_request_handler(BufferStore::handle_update_buffer); - client.add_entity_message_handler(BufferStore::handle_close_buffer); + session.add_entity_request_handler(BufferStore::handle_update_buffer); + session.add_entity_message_handler(BufferStore::handle_close_buffer); - client.add_request_handler( + session.add_request_handler( extensions.clone().downgrade(), HeadlessExtensionStore::handle_sync_extensions, ); - client.add_request_handler( + session.add_request_handler( extensions.clone().downgrade(), HeadlessExtensionStore::handle_install_extension, ); - BufferStore::init(&client); - WorktreeStore::init(&client); - SettingsObserver::init(&client); - LspStore::init(&client); - TaskStore::init(Some(&client)); - ToolchainStore::init(&client); - DapStore::init(&client, cx); + BufferStore::init(&session); + WorktreeStore::init(&session); + SettingsObserver::init(&session); + LspStore::init(&session); + TaskStore::init(Some(&session)); + ToolchainStore::init(&session); + DapStore::init(&session, cx); // todo(debugger): Re init breakpoint store when we set it up for collab // BreakpointStore::init(&client); - GitStore::init(&client); + GitStore::init(&session); HeadlessProject { - session: client, + next_entry_id: Default::default(), + session, settings_observer, fs, worktree_store, @@ -268,7 +266,6 @@ impl HeadlessProject { lsp_store, task_store, dap_store, - next_entry_id: Default::default(), languages, extensions, git_store, diff --git a/crates/remote_server/src/unix.rs b/crates/remote_server/src/unix.rs index 15a465a880..3352b317cb 100644 --- a/crates/remote_server/src/unix.rs +++ b/crates/remote_server/src/unix.rs @@ -19,11 +19,11 @@ use project::project_settings::ProjectSettings; use proto::CrashReport; use release_channel::{AppVersion, RELEASE_CHANNEL, ReleaseChannel}; -use remote::proxy::ProxyLaunchError; -use remote::ssh_session::ChannelClient; +use remote::SshRemoteClient; use remote::{ json_log::LogRecord, protocol::{read_message, write_message}, + proxy::ProxyLaunchError, }; use reqwest_client::ReqwestClient; use rpc::proto::{self, Envelope, SSH_PROJECT_ID}; @@ -199,8 +199,7 @@ fn init_panic_hook(session_id: String) { })); } -fn handle_crash_files_requests(project: &Entity, client: &Arc) { - let client: AnyProtoClient = client.clone().into(); +fn handle_crash_files_requests(project: &Entity, client: &AnyProtoClient) { client.add_request_handler( project.downgrade(), |_, _: TypedEnvelope, _cx| async move { @@ -276,7 +275,7 @@ fn start_server( listeners: ServerListeners, log_rx: Receiver>, cx: &mut App, -) -> Arc { +) -> AnyProtoClient { // This is the server idle timeout. If no connection comes in this timeout, the server will shut down. const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10 * 60); @@ -395,7 +394,7 @@ fn start_server( }) .detach(); - ChannelClient::new(incoming_rx, outgoing_tx, cx, "server") + SshRemoteClient::proto_client_from_channels(incoming_rx, outgoing_tx, cx, "server") } fn init_paths() -> anyhow::Result<()> { @@ -792,7 +791,7 @@ async fn write_size_prefixed_buffer( } fn initialize_settings( - session: Arc, + session: AnyProtoClient, fs: Arc, cx: &mut App, ) -> watch::Receiver> { diff --git a/crates/rpc/src/proto_client.rs b/crates/rpc/src/proto_client.rs index eb570b96a3..05b6bd1439 100644 --- a/crates/rpc/src/proto_client.rs +++ b/crates/rpc/src/proto_client.rs @@ -315,4 +315,23 @@ impl AnyProtoClient { }), ); } + + pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Entity) { + let id = (TypeId::of::(), remote_id); + + let mut message_handlers = self.0.message_handler_set().lock(); + if message_handlers + .entities_by_type_and_remote_id + .contains_key(&id) + { + panic!("already subscribed to entity"); + } + + message_handlers.entities_by_type_and_remote_id.insert( + id, + EntityMessageSubscriber::Entity { + handle: entity.downgrade().into(), + }, + ); + } } diff --git a/crates/tasks_ui/src/tasks_ui.rs b/crates/tasks_ui/src/tasks_ui.rs index dae366a979..a4fdc24e17 100644 --- a/crates/tasks_ui/src/tasks_ui.rs +++ b/crates/tasks_ui/src/tasks_ui.rs @@ -148,9 +148,9 @@ pub fn toggle_modal( ) -> Task<()> { let task_store = workspace.project().read(cx).task_store().clone(); let workspace_handle = workspace.weak_handle(); - let can_open_modal = workspace.project().update(cx, |project, cx| { - project.is_local() || project.ssh_connection_string(cx).is_some() || project.is_via_ssh() - }); + let can_open_modal = workspace + .project() + .read_with(cx, |project, _| !project.is_via_collab()); if can_open_modal { let task_contexts = task_contexts(workspace, window, cx); cx.spawn_in(window, async move |workspace, cx| {