diff --git a/crates/collab/src/tests/remote_editing_collaboration_tests.rs b/crates/collab/src/tests/remote_editing_collaboration_tests.rs index a9cc32c1dd..7de50511ea 100644 --- a/crates/collab/src/tests/remote_editing_collaboration_tests.rs +++ b/crates/collab/src/tests/remote_editing_collaboration_tests.rs @@ -4,7 +4,7 @@ use fs::{FakeFs, Fs as _}; use gpui::{Context as _, TestAppContext}; use language::language_settings::all_language_settings; use project::ProjectPath; -use remote::SshSession; +use remote::SshRemoteClient; use remote_server::HeadlessProject; use serde_json::json; use std::{path::Path, sync::Arc}; @@ -24,7 +24,7 @@ async fn test_sharing_an_ssh_remote_project( .await; // Set up project on remote FS - let (client_ssh, server_ssh) = SshSession::fake(cx_a, server_cx); + let (client_ssh, server_ssh) = SshRemoteClient::fake(cx_a, server_cx); let remote_fs = FakeFs::new(server_cx.executor()); remote_fs .insert_tree( diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 5ff4a72074..5e7d935c36 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -25,7 +25,7 @@ use node_runtime::NodeRuntime; use notifications::NotificationStore; use parking_lot::Mutex; use project::{Project, WorktreeId}; -use remote::SshSession; +use remote::SshRemoteClient; use rpc::{ proto::{self, ChannelRole}, RECEIVE_TIMEOUT, @@ -835,7 +835,7 @@ impl TestClient { pub async fn build_ssh_project( &self, root_path: impl AsRef, - ssh: Arc, + ssh: Arc, cx: &mut TestAppContext, ) -> (Model, WorktreeId) { let project = cx.update(|cx| { diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index b91250e6b2..dadbd394bb 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -54,7 +54,7 @@ use parking_lot::{Mutex, RwLock}; use paths::{local_tasks_file_relative_path, local_vscode_tasks_file_relative_path}; pub use prettier_store::PrettierStore; use project_settings::{ProjectSettings, SettingsObserver, SettingsObserverEvent}; -use remote::SshSession; +use remote::SshRemoteClient; use rpc::{proto::SSH_PROJECT_ID, AnyProtoClient, ErrorCode}; use search::{SearchInputKind, SearchQuery, SearchResult}; use search_history::SearchHistory; @@ -138,7 +138,7 @@ pub struct Project { join_project_response_message_id: u32, user_store: Model, fs: Arc, - ssh_session: Option>, + ssh_client: Option>, client_state: ProjectClientState, collaborators: HashMap, client_subscriptions: Vec, @@ -643,7 +643,7 @@ impl Project { user_store, settings_observer, fs, - ssh_session: None, + ssh_client: None, buffers_needing_diff: Default::default(), git_diff_debouncer: DebouncedDelay::new(), terminals: Terminals { @@ -664,7 +664,7 @@ impl Project { } pub fn ssh( - ssh: Arc, + ssh: Arc, client: Arc, node: NodeRuntime, user_store: Model, @@ -682,14 +682,14 @@ impl Project { SnippetProvider::new(fs.clone(), BTreeSet::from_iter([global_snippets_dir]), cx); let worktree_store = - cx.new_model(|_| WorktreeStore::remote(false, ssh.clone().into(), 0, None)); + cx.new_model(|_| WorktreeStore::remote(false, ssh.to_proto_client(), 0, None)); cx.subscribe(&worktree_store, Self::on_worktree_store_event) .detach(); let buffer_store = cx.new_model(|cx| { BufferStore::remote( worktree_store.clone(), - ssh.clone().into(), + ssh.to_proto_client(), SSH_PROJECT_ID, cx, ) @@ -698,7 +698,7 @@ impl Project { .detach(); let settings_observer = cx.new_model(|cx| { - SettingsObserver::new_ssh(ssh.clone().into(), worktree_store.clone(), cx) + SettingsObserver::new_ssh(ssh.to_proto_client(), worktree_store.clone(), cx) }); cx.subscribe(&settings_observer, Self::on_settings_observer_event) .detach(); @@ -709,7 +709,7 @@ impl Project { buffer_store.clone(), worktree_store.clone(), languages.clone(), - ssh.clone().into(), + ssh.to_proto_client(), SSH_PROJECT_ID, cx, ) @@ -733,7 +733,7 @@ impl Project { user_store, settings_observer, fs, - ssh_session: Some(ssh.clone()), + ssh_client: Some(ssh.clone()), buffers_needing_diff: Default::default(), git_diff_debouncer: DebouncedDelay::new(), terminals: Terminals { @@ -751,7 +751,7 @@ impl Project { search_excluded_history: Self::new_search_history(), }; - let client: AnyProtoClient = ssh.clone().into(); + let client: AnyProtoClient = ssh.to_proto_client(); ssh.subscribe_to_entity(SSH_PROJECT_ID, &cx.handle()); ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.buffer_store); @@ -907,7 +907,7 @@ impl Project { user_store: user_store.clone(), snippets, fs, - ssh_session: None, + ssh_client: None, settings_observer: settings_observer.clone(), client_subscriptions: Default::default(), _subscriptions: vec![cx.on_release(Self::release)], @@ -1230,7 +1230,7 @@ impl Project { match self.client_state { ProjectClientState::Remote { replica_id, .. } => replica_id, _ => { - if self.ssh_session.is_some() { + if self.ssh_client.is_some() { 1 } else { 0 @@ -1638,7 +1638,7 @@ impl Project { pub fn is_local(&self) -> bool { match &self.client_state { ProjectClientState::Local | ProjectClientState::Shared { .. } => { - self.ssh_session.is_none() + self.ssh_client.is_none() } ProjectClientState::Remote { .. } => false, } @@ -1647,7 +1647,7 @@ impl Project { pub fn is_via_ssh(&self) -> bool { match &self.client_state { ProjectClientState::Local | ProjectClientState::Shared { .. } => { - self.ssh_session.is_some() + self.ssh_client.is_some() } ProjectClientState::Remote { .. } => false, } @@ -1933,8 +1933,9 @@ impl Project { } BufferStoreEvent::BufferChangedFilePath { .. } => {} BufferStoreEvent::BufferDropped(buffer_id) => { - if let Some(ref ssh_session) = self.ssh_session { - ssh_session + if let Some(ref ssh_client) = self.ssh_client { + ssh_client + .to_proto_client() .send(proto::CloseBuffer { project_id: 0, buffer_id: buffer_id.to_proto(), @@ -2139,13 +2140,14 @@ impl Project { } => { let operation = language::proto::serialize_operation(operation); - if let Some(ssh) = &self.ssh_session { - ssh.send(proto::UpdateBuffer { - project_id: 0, - buffer_id: buffer_id.to_proto(), - operations: vec![operation.clone()], - }) - .ok(); + if let Some(ssh) = &self.ssh_client { + ssh.to_proto_client() + .send(proto::UpdateBuffer { + project_id: 0, + buffer_id: buffer_id.to_proto(), + operations: vec![operation.clone()], + }) + .ok(); } self.enqueue_buffer_ordered_message(BufferOrderedMessage::Operation { @@ -2825,14 +2827,13 @@ impl Project { ) -> Receiver> { let (tx, rx) = smol::channel::unbounded(); - let (client, remote_id): (AnyProtoClient, _) = - if let Some(ssh_session) = self.ssh_session.clone() { - (ssh_session.into(), 0) - } else if let Some(remote_id) = self.remote_id() { - (self.client.clone().into(), remote_id) - } else { - return rx; - }; + let (client, remote_id): (AnyProtoClient, _) = if let Some(ssh_client) = &self.ssh_client { + (ssh_client.to_proto_client(), 0) + } else if let Some(remote_id) = self.remote_id() { + (self.client.clone().into(), remote_id) + } else { + return rx; + }; let request = client.request(proto::FindSearchCandidates { project_id: remote_id, @@ -2961,11 +2962,13 @@ impl Project { exists.then(|| ResolvedPath::AbsPath(expanded)) }) - } else if let Some(ssh_session) = self.ssh_session.as_ref() { - let request = ssh_session.request(proto::CheckFileExists { - project_id: SSH_PROJECT_ID, - path: path.to_string(), - }); + } else if let Some(ssh_client) = self.ssh_client.as_ref() { + let request = ssh_client + .to_proto_client() + .request(proto::CheckFileExists { + project_id: SSH_PROJECT_ID, + path: path.to_string(), + }); cx.background_executor().spawn(async move { let response = request.await.log_err()?; if response.exists { @@ -3035,13 +3038,13 @@ impl Project { ) -> Task>> { if self.is_local() { DirectoryLister::Local(self.fs.clone()).list_directory(query, cx) - } else if let Some(session) = self.ssh_session.as_ref() { + } else if let Some(session) = self.ssh_client.as_ref() { let request = proto::ListRemoteDirectory { dev_server_id: SSH_PROJECT_ID, path: query, }; - let response = session.request(request); + let response = session.to_proto_client().request(request); cx.background_executor().spawn(async move { let response = response.await?; Ok(response.entries.into_iter().map(PathBuf::from).collect()) @@ -3465,11 +3468,11 @@ impl Project { cx: AsyncAppContext, ) -> Result { let buffer_store = this.read_with(&cx, |this, cx| { - if let Some(ssh) = &this.ssh_session { + if let Some(ssh) = &this.ssh_client { let mut payload = envelope.payload.clone(); payload.project_id = 0; cx.background_executor() - .spawn(ssh.request(payload)) + .spawn(ssh.to_proto_client().request(payload)) .detach_and_log_err(cx); } this.buffer_store.clone() diff --git a/crates/project/src/terminals.rs b/crates/project/src/terminals.rs index 54dd48cf43..7175b75e22 100644 --- a/crates/project/src/terminals.rs +++ b/crates/project/src/terminals.rs @@ -67,8 +67,12 @@ impl Project { } fn ssh_command(&self, cx: &AppContext) -> Option { - if let Some(ssh_session) = self.ssh_session.as_ref() { - return Some(SshCommand::Direct(ssh_session.ssh_args())); + if let Some(args) = self + .ssh_client + .as_ref() + .and_then(|session| session.ssh_args()) + { + return Some(SshCommand::Direct(args)); } let dev_server_project_id = self.dev_server_project_id()?; diff --git a/crates/recent_projects/src/ssh_connections.rs b/crates/recent_projects/src/ssh_connections.rs index dd30f15f26..d0fffc031f 100644 --- a/crates/recent_projects/src/ssh_connections.rs +++ b/crates/recent_projects/src/ssh_connections.rs @@ -11,7 +11,7 @@ use gpui::{ Transformation, View, }; use release_channel::{AppVersion, ReleaseChannel}; -use remote::{SshConnectionOptions, SshPlatform, SshSession}; +use remote::{SshConnectionOptions, SshPlatform, SshRemoteClient}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; @@ -376,12 +376,12 @@ pub fn connect_over_ssh( connection_options: SshConnectionOptions, ui: View, cx: &mut WindowContext, -) -> Task>> { +) -> Task>> { let window = cx.window_handle(); let known_password = connection_options.password.clone(); cx.spawn(|mut cx| async move { - remote::SshSession::client( + remote::SshRemoteClient::new( connection_options, Arc::new(SshClientDelegate { window, diff --git a/crates/remote/src/remote.rs b/crates/remote/src/remote.rs index 23f798c191..c3d9e8f9cc 100644 --- a/crates/remote/src/remote.rs +++ b/crates/remote/src/remote.rs @@ -2,4 +2,4 @@ pub mod json_log; pub mod protocol; pub mod ssh_session; -pub use ssh_session::{SshClientDelegate, SshConnectionOptions, SshPlatform, SshSession}; +pub use ssh_session::{SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteClient}; diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index 915595fd9d..fe1e42fe96 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -7,19 +7,23 @@ use crate::{ use anyhow::{anyhow, Context as _, Result}; use collections::HashMap; use futures::{ - channel::{mpsc, oneshot}, + channel::{ + mpsc::{self, UnboundedReceiver, UnboundedSender}, + oneshot, + }, future::BoxFuture, - select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _, + select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt, + StreamExt as _, }; use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, Task}; use parking_lot::Mutex; use rpc::{ proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage}, - EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError, + AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError, }; use smol::{ fs, - process::{self, Stdio}, + process::{self, Child, Stdio}, }; use std::{ any::TypeId, @@ -44,22 +48,6 @@ pub struct SshSocket { socket_path: PathBuf, } -pub struct SshSession { - next_message_id: AtomicU32, - response_channels: ResponseChannels, // Lock - outgoing_tx: mpsc::UnboundedSender, - spawn_process_tx: mpsc::UnboundedSender, - client_socket: Option, - state: Mutex, // Lock - _io_task: Option>>, -} - -struct SshClientState { - socket: SshSocket, - master_process: process::Child, - _temp_dir: TempDir, -} - #[derive(Debug, Clone, PartialEq, Eq)] pub struct SshConnectionOptions { pub host: String, @@ -105,18 +93,13 @@ impl SshConnectionOptions { } } -struct SpawnRequest { - command: String, - process_tx: oneshot::Sender, -} - #[derive(Copy, Clone, Debug)] pub struct SshPlatform { pub os: &'static str, pub arch: &'static str, } -pub trait SshClientDelegate { +pub trait SshClientDelegate: Send + Sync { fn ask_password( &self, prompt: String, @@ -132,48 +115,249 @@ pub trait SshClientDelegate { fn set_error(&self, error_message: String, cx: &mut AsyncAppContext); } -type ResponseChannels = Mutex)>>>; +impl SshSocket { + fn ssh_command>(&self, program: S) -> process::Command { + let mut command = process::Command::new("ssh"); + self.ssh_options(&mut command) + .arg(self.connection_options.ssh_url()) + .arg(program); + command + } -impl SshSession { - pub async fn client( + fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command { + command + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .args(["-o", "ControlMaster=no", "-o"]) + .arg(format!("ControlPath={}", self.socket_path.display())) + } + + fn ssh_args(&self) -> Vec { + vec![ + "-o".to_string(), + "ControlMaster=no".to_string(), + "-o".to_string(), + format!("ControlPath={}", self.socket_path.display()), + self.connection_options.ssh_url(), + ] + } +} + +async fn run_cmd(command: &mut process::Command) -> Result { + let output = command.output().await?; + if output.status.success() { + Ok(String::from_utf8_lossy(&output.stdout).to_string()) + } else { + Err(anyhow!( + "failed to run command: {}", + String::from_utf8_lossy(&output.stderr) + )) + } +} +#[cfg(unix)] +async fn read_with_timeout( + stdout: &mut process::ChildStdout, + timeout: std::time::Duration, + output: &mut Vec, +) -> Result<(), std::io::Error> { + smol::future::or( + async { + stdout.read_to_end(output).await?; + Ok::<_, std::io::Error>(()) + }, + async { + smol::Timer::after(timeout).await; + + Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "Read operation timed out", + )) + }, + ) + .await +} + +struct ChannelForwarder { + quit_tx: UnboundedSender<()>, + forwarding_task: Task<(UnboundedSender, UnboundedReceiver)>, +} + +impl ChannelForwarder { + fn new( + mut incoming_tx: UnboundedSender, + mut outgoing_rx: UnboundedReceiver, + cx: &mut AsyncAppContext, + ) -> (Self, UnboundedSender, UnboundedReceiver) { + let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>(); + + let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::(); + let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::(); + + let forwarding_task = cx.background_executor().spawn(async move { + loop { + select_biased! { + _ = quit_rx.next().fuse() => { + break; + }, + incoming_envelope = proxy_incoming_rx.next().fuse() => { + if let Some(envelope) = incoming_envelope { + if incoming_tx.send(envelope).await.is_err() { + break; + } + } else { + break; + } + } + outgoing_envelope = outgoing_rx.next().fuse() => { + if let Some(envelope) = outgoing_envelope { + if proxy_outgoing_tx.send(envelope).await.is_err() { + break; + } + } else { + break; + } + } + } + } + + (incoming_tx, outgoing_rx) + }); + + ( + Self { + forwarding_task, + quit_tx, + }, + proxy_incoming_tx, + proxy_outgoing_rx, + ) + } + + async fn into_channels(mut self) -> (UnboundedSender, UnboundedReceiver) { + let _ = self.quit_tx.send(()).await; + self.forwarding_task.await + } +} + +struct SshRemoteClientState { + ssh_connection: SshRemoteConnection, + delegate: Arc, + forwarder: ChannelForwarder, + _multiplex_task: Task>, +} + +pub struct SshRemoteClient { + client: Arc, + inner_state: Arc>>, +} + +impl SshRemoteClient { + pub async fn new( connection_options: SshConnectionOptions, delegate: Arc, cx: &mut AsyncAppContext, ) -> Result> { - let client_state = SshClientState::new(connection_options, delegate.clone(), cx).await?; - - let platform = client_state.query_platform().await?; - let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??; - let remote_binary_path = delegate.remote_server_binary_path(cx)?; - client_state - .ensure_server_binary( - &delegate, - &local_binary_path, - &remote_binary_path, - version, - cx, - ) - .await?; - - let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::(); - let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::(); + let (outgoing_tx, outgoing_rx) = mpsc::unbounded::(); let (incoming_tx, incoming_rx) = mpsc::unbounded::(); - let socket = client_state.socket.clone(); - run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?; + let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?; + let this = Arc::new(Self { + client, + inner_state: Arc::new(Mutex::new(None)), + }); - let mut remote_server_child = socket - .ssh_command(format!( - "RUST_LOG={} RUST_BACKTRACE={} {:?} run", - std::env::var("RUST_LOG").unwrap_or_default(), - std::env::var("RUST_BACKTRACE").unwrap_or_default(), - remote_binary_path, - )) - .spawn() - .context("failed to spawn remote server")?; - let mut child_stderr = remote_server_child.stderr.take().unwrap(); - let mut child_stdout = remote_server_child.stdout.take().unwrap(); - let mut child_stdin = remote_server_child.stdin.take().unwrap(); + let inner_state = { + let (proxy, proxy_incoming_tx, proxy_outgoing_rx) = + ChannelForwarder::new(incoming_tx, outgoing_rx, cx); + + let (ssh_connection, ssh_process) = + Self::establish_connection(connection_options.clone(), delegate.clone(), cx) + .await?; + + let multiplex_task = Self::multiplex( + this.clone(), + ssh_process, + proxy_incoming_tx, + proxy_outgoing_rx, + cx, + ); + + SshRemoteClientState { + ssh_connection, + delegate, + forwarder: proxy, + _multiplex_task: multiplex_task, + } + }; + + this.inner_state.lock().replace(inner_state); + + Ok(this) + } + + fn reconnect(this: Arc, cx: &mut AsyncAppContext) -> Result<()> { + let Some(state) = this.inner_state.lock().take() else { + return Err(anyhow!("reconnect is already in progress")); + }; + + let SshRemoteClientState { + mut ssh_connection, + delegate, + forwarder: proxy, + _multiplex_task, + } = state; + drop(_multiplex_task); + + cx.spawn(|mut cx| async move { + let (incoming_tx, outgoing_rx) = proxy.into_channels().await; + + ssh_connection.master_process.kill()?; + ssh_connection + .master_process + .status() + .await + .context("Failed to kill ssh process")?; + + let connection_options = ssh_connection.socket.connection_options.clone(); + + let (ssh_connection, ssh_process) = + Self::establish_connection(connection_options, delegate.clone(), &mut cx).await?; + + let (proxy, proxy_incoming_tx, proxy_outgoing_rx) = + ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); + + let inner_state = SshRemoteClientState { + ssh_connection, + delegate, + forwarder: proxy, + _multiplex_task: Self::multiplex( + this.clone(), + ssh_process, + proxy_incoming_tx, + proxy_outgoing_rx, + &mut cx, + ), + }; + this.inner_state.lock().replace(inner_state); + + anyhow::Ok(()) + }) + .detach(); + + anyhow::Ok(()) + } + + fn multiplex( + this: Arc, + mut ssh_process: Child, + incoming_tx: UnboundedSender, + mut outgoing_rx: UnboundedReceiver, + cx: &mut AsyncAppContext, + ) -> Task> { + let mut child_stderr = ssh_process.stderr.take().unwrap(); + let mut child_stdout = ssh_process.stdout.take().unwrap(); + let mut child_stdin = ssh_process.stdin.take().unwrap(); let io_task = cx.background_executor().spawn(async move { let mut stdin_buffer = Vec::new(); @@ -194,27 +378,15 @@ impl SshSession { write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?; } - request = spawn_process_rx.next().fuse() => { - let Some(request) = request else { - return Ok(()); - }; - - log::info!("spawn process: {:?}", request.command); - let child = client_state.socket - .ssh_command(&request.command) - .spawn() - .context("failed to create channel")?; - request.process_tx.send(child).ok(); - } - result = child_stdout.read(&mut stdout_buffer).fuse() => { match result { Ok(0) => { child_stdin.close().await?; outgoing_rx.close(); - let status = remote_server_child.status().await?; + let status = ssh_process.status().await?; if !status.success() { - log::error!("channel exited with status: {status:?}"); + log::error!("ssh process exited with status: {status:?}"); + return Err(anyhow!("ssh process exited with non-zero status code: {:?}", status.code())); } return Ok(()); } @@ -267,239 +439,112 @@ impl SshSession { } }); - cx.update(|cx| { - Self::new( - incoming_rx, - outgoing_tx, - spawn_process_tx, - Some(socket), - Some(io_task), - cx, - ) + cx.spawn(|mut cx| async move { + let result = io_task.await; + + if let Err(error) = result { + log::warn!("ssh io task died with error: {:?}. reconnecting...", error); + Self::reconnect(this, &mut cx).ok(); + } + + Ok(()) }) } - pub fn server( - incoming_rx: mpsc::UnboundedReceiver, - outgoing_tx: mpsc::UnboundedSender, - cx: &AppContext, - ) -> Arc { - let (tx, _rx) = mpsc::unbounded(); - Self::new(incoming_rx, outgoing_tx, tx, None, None, cx) + async fn establish_connection( + connection_options: SshConnectionOptions, + delegate: Arc, + cx: &mut AsyncAppContext, + ) -> Result<(SshRemoteConnection, Child)> { + let ssh_connection = + SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?; + + let platform = ssh_connection.query_platform().await?; + let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??; + let remote_binary_path = delegate.remote_server_binary_path(cx)?; + ssh_connection + .ensure_server_binary( + &delegate, + &local_binary_path, + &remote_binary_path, + version, + cx, + ) + .await?; + + let socket = ssh_connection.socket.clone(); + run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?; + + let ssh_process = socket + .ssh_command(format!( + "RUST_LOG={} RUST_BACKTRACE={} {:?} run", + std::env::var("RUST_LOG").unwrap_or_default(), + std::env::var("RUST_BACKTRACE").unwrap_or_default(), + remote_binary_path, + )) + .spawn() + .context("failed to spawn remote server")?; + + Ok((ssh_connection, ssh_process)) + } + + pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Model) { + self.client.subscribe_to_entity(remote_id, entity); + } + + pub fn ssh_args(&self) -> Option> { + let state = self.inner_state.lock(); + state + .as_ref() + .map(|state| state.ssh_connection.socket.ssh_args()) + } + + pub fn to_proto_client(&self) -> AnyProtoClient { + self.client.clone().into() } #[cfg(any(test, feature = "test-support"))] pub fn fake( client_cx: &mut gpui::TestAppContext, server_cx: &mut gpui::TestAppContext, - ) -> (Arc, Arc) { + ) -> (Arc, Arc) { let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded(); let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded(); - let (tx, _rx) = mpsc::unbounded(); + ( client_cx.update(|cx| { - Self::new( - server_to_client_rx, - client_to_server_tx, - tx.clone(), - None, // todo() - None, - cx, - ) - }), - server_cx.update(|cx| { - Self::new( - client_to_server_rx, - server_to_client_tx, - tx.clone(), - None, - None, - cx, - ) + let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx); + Arc::new(Self { + client, + inner_state: Arc::new(Mutex::new(None)), + }) }), + server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)), ) } +} - fn new( - mut incoming_rx: mpsc::UnboundedReceiver, - outgoing_tx: mpsc::UnboundedSender, - spawn_process_tx: mpsc::UnboundedSender, - client_socket: Option, - io_task: Option>>, - cx: &AppContext, - ) -> Arc { - let this = Arc::new(Self { - next_message_id: AtomicU32::new(0), - response_channels: ResponseChannels::default(), - outgoing_tx, - spawn_process_tx, - client_socket, - state: Default::default(), - _io_task: io_task, - }); - - cx.spawn(|cx| { - let this = Arc::downgrade(&this); - async move { - let peer_id = PeerId { owner_id: 0, id: 0 }; - while let Some(incoming) = incoming_rx.next().await { - let Some(this) = this.upgrade() else { - return anyhow::Ok(()); - }; - - if let Some(request_id) = incoming.responding_to { - let request_id = MessageId(request_id); - let sender = this.response_channels.lock().remove(&request_id); - if let Some(sender) = sender { - let (tx, rx) = oneshot::channel(); - if incoming.payload.is_some() { - sender.send((incoming, tx)).ok(); - } - rx.await.ok(); - } - } else if let Some(envelope) = - build_typed_envelope(peer_id, Instant::now(), incoming) - { - let type_name = envelope.payload_type_name(); - if let Some(future) = ProtoMessageHandlerSet::handle_message( - &this.state, - envelope, - this.clone().into(), - cx.clone(), - ) { - log::debug!("ssh message received. name:{type_name}"); - match future.await { - Ok(_) => { - log::debug!("ssh message handled. name:{type_name}"); - } - Err(error) => { - log::error!( - "error handling message. type:{type_name}, error:{error}", - ); - } - } - } else { - log::error!("unhandled ssh message name:{type_name}"); - } - } - } - anyhow::Ok(()) - } - }) - .detach(); - - this - } - - pub fn request( - &self, - payload: T, - ) -> impl 'static + Future> { - log::debug!("ssh request start. name:{}", T::NAME); - let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME); - async move { - let response = response.await?; - log::debug!("ssh request finish. name:{}", T::NAME); - T::Response::from_envelope(response) - .ok_or_else(|| anyhow!("received a response of the wrong type")) - } - } - - pub fn send(&self, payload: T) -> Result<()> { - log::debug!("ssh send name:{}", T::NAME); - self.send_dynamic(payload.into_envelope(0, None, None)) - } - - pub fn request_dynamic( - &self, - mut envelope: proto::Envelope, - type_name: &'static str, - ) -> impl 'static + Future> { - envelope.id = self.next_message_id.fetch_add(1, SeqCst); - let (tx, rx) = oneshot::channel(); - let mut response_channels_lock = self.response_channels.lock(); - response_channels_lock.insert(MessageId(envelope.id), tx); - drop(response_channels_lock); - let result = self.outgoing_tx.unbounded_send(envelope); - async move { - if let Err(error) = &result { - log::error!("failed to send message: {}", error); - return Err(anyhow!("failed to send message: {}", error)); - } - - let response = rx.await.context("connection lost")?.0; - if let Some(proto::envelope::Payload::Error(error)) = &response.payload { - return Err(RpcError::from_proto(error, type_name)); - } - Ok(response) - } - } - - pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> { - envelope.id = self.next_message_id.fetch_add(1, SeqCst); - self.outgoing_tx.unbounded_send(envelope)?; - Ok(()) - } - - pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Model) { - let id = (TypeId::of::(), remote_id); - - let mut state = self.state.lock(); - if state.entities_by_type_and_remote_id.contains_key(&id) { - panic!("already subscribed to entity"); - } - - state.entities_by_type_and_remote_id.insert( - id, - EntityMessageSubscriber::Entity { - handle: entity.downgrade().into(), - }, - ); - } - - pub async fn spawn_process(&self, command: String) -> process::Child { - let (process_tx, process_rx) = oneshot::channel(); - self.spawn_process_tx - .unbounded_send(SpawnRequest { - command, - process_tx, - }) - .ok(); - process_rx.await.unwrap() - } - - pub fn ssh_args(&self) -> Vec { - self.client_socket.as_ref().unwrap().ssh_args() +impl From for AnyProtoClient { + fn from(client: SshRemoteClient) -> Self { + AnyProtoClient::new(client.client.clone()) } } -impl ProtoClient for SshSession { - fn request( - &self, - envelope: proto::Envelope, - request_type: &'static str, - ) -> BoxFuture<'static, Result> { - self.request_dynamic(envelope, request_type).boxed() - } +struct SshRemoteConnection { + socket: SshSocket, + master_process: process::Child, + _temp_dir: TempDir, +} - fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> { - self.send_dynamic(envelope) - } - - fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> { - self.send_dynamic(envelope) - } - - fn message_handler_set(&self) -> &Mutex { - &self.state - } - - fn is_via_collab(&self) -> bool { - false +impl Drop for SshRemoteConnection { + fn drop(&mut self) { + if let Err(error) = self.master_process.kill() { + log::error!("failed to kill SSH master process: {}", error); + } } } -impl SshClientState { +impl SshRemoteConnection { #[cfg(not(unix))] async fn new( _connection_options: SshConnectionOptions, @@ -740,74 +785,181 @@ impl SshClientState { } } -#[cfg(unix)] -async fn read_with_timeout( - stdout: &mut process::ChildStdout, - timeout: std::time::Duration, - output: &mut Vec, -) -> Result<(), std::io::Error> { - smol::future::or( - async { - stdout.read_to_end(output).await?; - Ok::<_, std::io::Error>(()) - }, - async { - smol::Timer::after(timeout).await; +type ResponseChannels = Mutex)>>>; - Err(std::io::Error::new( - std::io::ErrorKind::TimedOut, - "Read operation timed out", - )) - }, - ) - .await +pub struct ChannelClient { + next_message_id: AtomicU32, + outgoing_tx: mpsc::UnboundedSender, + response_channels: ResponseChannels, // Lock + message_handlers: Mutex, // Lock } -impl Drop for SshClientState { - fn drop(&mut self) { - if let Err(error) = self.master_process.kill() { - log::error!("failed to kill SSH master process: {}", error); +impl ChannelClient { + pub fn new( + incoming_rx: mpsc::UnboundedReceiver, + outgoing_tx: mpsc::UnboundedSender, + cx: &AppContext, + ) -> Arc { + let this = Arc::new(Self { + outgoing_tx, + next_message_id: AtomicU32::new(0), + response_channels: ResponseChannels::default(), + message_handlers: Default::default(), + }); + + Self::start_handling_messages(this.clone(), incoming_rx, cx); + + this + } + + fn start_handling_messages( + this: Arc, + mut incoming_rx: mpsc::UnboundedReceiver, + cx: &AppContext, + ) { + cx.spawn(|cx| { + let this = Arc::downgrade(&this); + async move { + let peer_id = PeerId { owner_id: 0, id: 0 }; + while let Some(incoming) = incoming_rx.next().await { + let Some(this) = this.upgrade() else { + return anyhow::Ok(()); + }; + + if let Some(request_id) = incoming.responding_to { + let request_id = MessageId(request_id); + let sender = this.response_channels.lock().remove(&request_id); + if let Some(sender) = sender { + let (tx, rx) = oneshot::channel(); + if incoming.payload.is_some() { + sender.send((incoming, tx)).ok(); + } + rx.await.ok(); + } + } else if let Some(envelope) = + build_typed_envelope(peer_id, Instant::now(), incoming) + { + let type_name = envelope.payload_type_name(); + if let Some(future) = ProtoMessageHandlerSet::handle_message( + &this.message_handlers, + envelope, + this.clone().into(), + cx.clone(), + ) { + log::debug!("ssh message received. name:{type_name}"); + match future.await { + Ok(_) => { + log::debug!("ssh message handled. name:{type_name}"); + } + Err(error) => { + log::error!( + "error handling message. type:{type_name}, error:{error}", + ); + } + } + } else { + log::error!("unhandled ssh message name:{type_name}"); + } + } + } + anyhow::Ok(()) + } + }) + .detach(); + } + + pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Model) { + let id = (TypeId::of::(), remote_id); + + let mut message_handlers = self.message_handlers.lock(); + if message_handlers + .entities_by_type_and_remote_id + .contains_key(&id) + { + panic!("already subscribed to entity"); + } + + message_handlers.entities_by_type_and_remote_id.insert( + id, + EntityMessageSubscriber::Entity { + handle: entity.downgrade().into(), + }, + ); + } + + pub fn request( + &self, + payload: T, + ) -> impl 'static + Future> { + log::debug!("ssh request start. name:{}", T::NAME); + let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME); + async move { + let response = response.await?; + log::debug!("ssh request finish. name:{}", T::NAME); + T::Response::from_envelope(response) + .ok_or_else(|| anyhow!("received a response of the wrong type")) } } -} -impl SshSocket { - fn ssh_command>(&self, program: S) -> process::Command { - let mut command = process::Command::new("ssh"); - self.ssh_options(&mut command) - .arg(self.connection_options.ssh_url()) - .arg(program); - command + pub fn send(&self, payload: T) -> Result<()> { + log::debug!("ssh send name:{}", T::NAME); + self.send_dynamic(payload.into_envelope(0, None, None)) } - fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command { - command - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .args(["-o", "ControlMaster=no", "-o"]) - .arg(format!("ControlPath={}", self.socket_path.display())) + pub fn request_dynamic( + &self, + mut envelope: proto::Envelope, + type_name: &'static str, + ) -> impl 'static + Future> { + envelope.id = self.next_message_id.fetch_add(1, SeqCst); + let (tx, rx) = oneshot::channel(); + let mut response_channels_lock = self.response_channels.lock(); + response_channels_lock.insert(MessageId(envelope.id), tx); + drop(response_channels_lock); + let result = self.outgoing_tx.unbounded_send(envelope); + async move { + if let Err(error) = &result { + log::error!("failed to send message: {}", error); + return Err(anyhow!("failed to send message: {}", error)); + } + + let response = rx.await.context("connection lost")?.0; + if let Some(proto::envelope::Payload::Error(error)) = &response.payload { + return Err(RpcError::from_proto(error, type_name)); + } + Ok(response) + } } - fn ssh_args(&self) -> Vec { - vec![ - "-o".to_string(), - "ControlMaster=no".to_string(), - "-o".to_string(), - format!("ControlPath={}", self.socket_path.display()), - self.connection_options.ssh_url(), - ] + pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> { + envelope.id = self.next_message_id.fetch_add(1, SeqCst); + self.outgoing_tx.unbounded_send(envelope)?; + Ok(()) } } -async fn run_cmd(command: &mut process::Command) -> Result { - let output = command.output().await?; - if output.status.success() { - Ok(String::from_utf8_lossy(&output.stdout).to_string()) - } else { - Err(anyhow!( - "failed to run command: {}", - String::from_utf8_lossy(&output.stderr) - )) +impl ProtoClient for ChannelClient { + fn request( + &self, + envelope: proto::Envelope, + request_type: &'static str, + ) -> BoxFuture<'static, Result> { + self.request_dynamic(envelope, request_type).boxed() + } + + fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> { + self.send_dynamic(envelope) + } + + fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> { + self.send_dynamic(envelope) + } + + fn message_handler_set(&self) -> &Mutex { + &self.message_handlers + } + + fn is_via_collab(&self) -> bool { + false } } diff --git a/crates/remote_server/src/headless_project.rs b/crates/remote_server/src/headless_project.rs index 4b13938d8c..39540b04e0 100644 --- a/crates/remote_server/src/headless_project.rs +++ b/crates/remote_server/src/headless_project.rs @@ -10,7 +10,7 @@ use project::{ worktree_store::WorktreeStore, LspStore, LspStoreEvent, PrettierStore, ProjectPath, WorktreeId, }; -use remote::SshSession; +use remote::ssh_session::ChannelClient; use rpc::{ proto::{self, SSH_PEER_ID, SSH_PROJECT_ID}, AnyProtoClient, TypedEnvelope, @@ -41,7 +41,7 @@ impl HeadlessProject { project::Project::init_settings(cx); } - pub fn new(session: Arc, fs: Arc, cx: &mut ModelContext) -> Self { + pub fn new(session: Arc, fs: Arc, cx: &mut ModelContext) -> Self { let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone())); let node_runtime = NodeRuntime::unavailable(); diff --git a/crates/remote_server/src/main.rs b/crates/remote_server/src/main.rs index 908a0a89b6..73b8a91da1 100644 --- a/crates/remote_server/src/main.rs +++ b/crates/remote_server/src/main.rs @@ -6,7 +6,6 @@ use gpui::Context as _; use remote::{ json_log::LogRecord, protocol::{read_message, write_message}, - SshSession, }; use remote_server::HeadlessProject; use smol::{io::AsyncWriteExt, stream::StreamExt as _, Async}; @@ -24,6 +23,8 @@ fn main() { #[cfg(not(windows))] fn main() { + use remote::ssh_session::ChannelClient; + env_logger::builder() .format(|buf, record| { serde_json::to_writer(&mut *buf, &LogRecord::new(record))?; @@ -55,7 +56,7 @@ fn main() { let mut stdin = Async::new(io::stdin()).unwrap(); let mut stdout = Async::new(io::stdout()).unwrap(); - let session = SshSession::server(incoming_rx, outgoing_tx, cx); + let session = ChannelClient::new(incoming_rx, outgoing_tx, cx); let project = cx.new_model(|cx| { HeadlessProject::new( session.clone(), diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 8920639427..960b7c248c 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -15,7 +15,7 @@ use project::{ search::{SearchQuery, SearchResult}, Project, ProjectPath, }; -use remote::SshSession; +use remote::SshRemoteClient; use serde_json::json; use settings::{Settings, SettingsLocation, SettingsStore}; use smol::stream::StreamExt; @@ -616,7 +616,7 @@ async fn init_test( cx: &mut TestAppContext, server_cx: &mut TestAppContext, ) -> (Model, Model, Arc) { - let (client_ssh, server_ssh) = SshSession::fake(cx, server_cx); + let (ssh_remote_client, ssh_server_client) = SshRemoteClient::fake(cx, server_cx); init_logger(); let fs = FakeFs::new(server_cx.executor()); @@ -642,8 +642,9 @@ async fn init_test( ); server_cx.update(HeadlessProject::init); - let headless = server_cx.new_model(|cx| HeadlessProject::new(server_ssh, fs.clone(), cx)); - let project = build_project(client_ssh, cx); + let headless = + server_cx.new_model(|cx| HeadlessProject::new(ssh_server_client, fs.clone(), cx)); + let project = build_project(ssh_remote_client, cx); project .update(cx, { @@ -654,7 +655,7 @@ async fn init_test( (project, headless, fs) } -fn build_project(ssh: Arc, cx: &mut TestAppContext) -> Model { +fn build_project(ssh: Arc, cx: &mut TestAppContext) -> Model { cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index cec913851f..b668a5802c 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -61,7 +61,7 @@ use postage::stream::Stream; use project::{ DirectoryLister, Project, ProjectEntryId, ProjectPath, ResolvedPath, Worktree, WorktreeId, }; -use remote::{SshConnectionOptions, SshSession}; +use remote::{SshConnectionOptions, SshRemoteClient}; use serde::Deserialize; use session::AppSession; use settings::{InvalidSettingsError, Settings}; @@ -5514,7 +5514,7 @@ pub fn join_hosted_project( pub fn open_ssh_project( window: WindowHandle, connection_options: SshConnectionOptions, - session: Arc, + session: Arc, app_state: Arc, paths: Vec, cx: &mut AppContext,