Revert "SSH reconnect reliability (#19398)" (#19440)

This reverts commit 98ecb43b2d.

Tests fail on main?!

Closes #ISSUE

Release Notes:

- N/A
This commit is contained in:
Conrad Irwin 2024-10-18 16:08:56 -06:00 committed by GitHub
parent 47380001cc
commit a5492b3ea6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 80 additions and 423 deletions

1
Cargo.lock generated
View file

@ -9119,7 +9119,6 @@ name = "remote"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait",
"collections", "collections",
"fs", "fs",
"futures 0.3.30", "futures 0.3.30",

View file

@ -26,7 +26,7 @@ async fn test_sharing_an_ssh_remote_project(
.await; .await;
// Set up project on remote FS // Set up project on remote FS
let (forwarder, server_ssh) = SshRemoteClient::fake_server(server_cx); let (client_ssh, server_ssh) = SshRemoteClient::fake(cx_a, server_cx);
let remote_fs = FakeFs::new(server_cx.executor()); let remote_fs = FakeFs::new(server_cx.executor());
remote_fs remote_fs
.insert_tree( .insert_tree(
@ -67,7 +67,6 @@ async fn test_sharing_an_ssh_remote_project(
) )
}); });
let client_ssh = SshRemoteClient::fake_client(forwarder, cx_a).await;
let (project_a, worktree_id) = client_a let (project_a, worktree_id) = client_a
.build_ssh_project("/code/project1", client_ssh, cx_a) .build_ssh_project("/code/project1", client_ssh, cx_a)
.await; .await;

View file

@ -1243,10 +1243,6 @@ impl Project {
self.client.clone() self.client.clone()
} }
pub fn ssh_client(&self) -> Option<Model<SshRemoteClient>> {
self.ssh_client.clone()
}
pub fn user_store(&self) -> Model<UserStore> { pub fn user_store(&self) -> Model<UserStore> {
self.user_store.clone() self.user_store.clone()
} }

View file

@ -12,7 +12,6 @@ message Envelope {
uint32 id = 1; uint32 id = 1;
optional uint32 responding_to = 2; optional uint32 responding_to = 2;
optional PeerId original_sender_id = 3; optional PeerId original_sender_id = 3;
optional uint32 ack_id = 266;
oneof payload { oneof payload {
Hello hello = 4; Hello hello = 4;
@ -296,9 +295,7 @@ message Envelope {
OpenServerSettings open_server_settings = 263; OpenServerSettings open_server_settings = 263;
GetPermalinkToLine get_permalink_to_line = 264; GetPermalinkToLine get_permalink_to_line = 264;
GetPermalinkToLineResponse get_permalink_to_line_response = 265; GetPermalinkToLineResponse get_permalink_to_line_response = 265; // current max
FlushBufferedMessages flush_buffered_messages = 267;
} }
reserved 87 to 88; reserved 87 to 88;
@ -2524,6 +2521,3 @@ message GetPermalinkToLine {
message GetPermalinkToLineResponse { message GetPermalinkToLineResponse {
string permalink = 1; string permalink = 1;
} }
message FlushBufferedMessages {}
message FlushBufferedMessagesResponse {}

View file

@ -32,7 +32,6 @@ macro_rules! messages {
responding_to, responding_to,
original_sender_id, original_sender_id,
payload: Some(envelope::Payload::$name(self)), payload: Some(envelope::Payload::$name(self)),
ack_id: None,
} }
} }

View file

@ -372,7 +372,6 @@ messages!(
(OpenServerSettings, Foreground), (OpenServerSettings, Foreground),
(GetPermalinkToLine, Foreground), (GetPermalinkToLine, Foreground),
(GetPermalinkToLineResponse, Foreground), (GetPermalinkToLineResponse, Foreground),
(FlushBufferedMessages, Foreground),
); );
request_messages!( request_messages!(
@ -499,7 +498,6 @@ request_messages!(
(RemoveWorktree, Ack), (RemoveWorktree, Ack),
(OpenServerSettings, OpenBufferResponse), (OpenServerSettings, OpenBufferResponse),
(GetPermalinkToLine, GetPermalinkToLineResponse), (GetPermalinkToLine, GetPermalinkToLineResponse),
(FlushBufferedMessages, Ack),
); );
entity_messages!( entity_messages!(

View file

@ -19,7 +19,6 @@ test-support = ["fs/test-support"]
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
async-trait.workspace = true
collections.workspace = true collections.workspace = true
fs.workspace = true fs.workspace = true
futures.workspace = true futures.workspace = true

View file

@ -6,7 +6,6 @@ use crate::{
proxy::ProxyLaunchError, proxy::ProxyLaunchError,
}; };
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
use async_trait::async_trait;
use collections::HashMap; use collections::HashMap;
use futures::{ use futures::{
channel::{ channel::{
@ -32,7 +31,6 @@ use smol::{
}; };
use std::{ use std::{
any::TypeId, any::TypeId,
collections::VecDeque,
ffi::OsStr, ffi::OsStr,
fmt, fmt,
ops::ControlFlow, ops::ControlFlow,
@ -278,7 +276,7 @@ async fn run_cmd(command: &mut process::Command) -> Result<String> {
} }
} }
pub struct ChannelForwarder { struct ChannelForwarder {
quit_tx: UnboundedSender<()>, quit_tx: UnboundedSender<()>,
forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>, forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
} }
@ -349,7 +347,7 @@ const MAX_RECONNECT_ATTEMPTS: usize = 3;
enum State { enum State {
Connecting, Connecting,
Connected { Connected {
ssh_connection: Box<dyn SshRemoteProcess>, ssh_connection: SshRemoteConnection,
delegate: Arc<dyn SshClientDelegate>, delegate: Arc<dyn SshClientDelegate>,
forwarder: ChannelForwarder, forwarder: ChannelForwarder,
@ -359,7 +357,7 @@ enum State {
HeartbeatMissed { HeartbeatMissed {
missed_heartbeats: usize, missed_heartbeats: usize,
ssh_connection: Box<dyn SshRemoteProcess>, ssh_connection: SshRemoteConnection,
delegate: Arc<dyn SshClientDelegate>, delegate: Arc<dyn SshClientDelegate>,
forwarder: ChannelForwarder, forwarder: ChannelForwarder,
@ -368,7 +366,7 @@ enum State {
}, },
Reconnecting, Reconnecting,
ReconnectFailed { ReconnectFailed {
ssh_connection: Box<dyn SshRemoteProcess>, ssh_connection: SshRemoteConnection,
delegate: Arc<dyn SshClientDelegate>, delegate: Arc<dyn SshClientDelegate>,
forwarder: ChannelForwarder, forwarder: ChannelForwarder,
@ -394,11 +392,11 @@ impl fmt::Display for State {
} }
impl State { impl State {
fn ssh_connection(&self) -> Option<&dyn SshRemoteProcess> { fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
match self { match self {
Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()), Self::Connected { ssh_connection, .. } => Some(ssh_connection),
Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()), Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()), Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
_ => None, _ => None,
} }
} }
@ -543,19 +541,23 @@ impl SshRemoteClient {
let (proxy, proxy_incoming_tx, proxy_outgoing_rx) = let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
let (ssh_connection, io_task) = Self::establish_connection( let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
unique_identifier, unique_identifier,
false, false,
connection_options, connection_options,
proxy_incoming_tx,
proxy_outgoing_rx,
connection_activity_tx,
delegate.clone(), delegate.clone(),
&mut cx, &mut cx,
) )
.await?; .await?;
let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx); let multiplex_task = Self::multiplex(
this.downgrade(),
ssh_proxy_process,
proxy_incoming_tx,
proxy_outgoing_rx,
connection_activity_tx,
&mut cx,
);
if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await { if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
log::error!("failed to establish connection: {}", error); log::error!("failed to establish connection: {}", error);
@ -701,24 +703,30 @@ impl SshRemoteClient {
}; };
} }
if let Err(error) = ssh_connection.kill().await.context("Failed to kill ssh process") { if let Err(error) = ssh_connection.master_process.kill() {
failed!(error, attempts, ssh_connection, delegate, forwarder); failed!(error, attempts, ssh_connection, delegate, forwarder);
}; };
let connection_options = ssh_connection.connection_options(); if let Err(error) = ssh_connection
.master_process
.status()
.await
.context("Failed to kill ssh process")
{
failed!(error, attempts, ssh_connection, delegate, forwarder);
}
let connection_options = ssh_connection.socket.connection_options.clone();
let (incoming_tx, outgoing_rx) = forwarder.into_channels().await; let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) = let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1); let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
let (ssh_connection, io_task) = match Self::establish_connection( let (ssh_connection, ssh_process) = match Self::establish_connection(
identifier, identifier,
true, true,
connection_options, connection_options,
proxy_incoming_tx,
proxy_outgoing_rx,
connection_activity_tx,
delegate.clone(), delegate.clone(),
&mut cx, &mut cx,
) )
@ -730,9 +738,16 @@ impl SshRemoteClient {
} }
}; };
let multiplex_task = Self::monitor(this.clone(), io_task, &cx); let multiplex_task = Self::multiplex(
this.clone(),
ssh_process,
proxy_incoming_tx,
proxy_outgoing_rx,
connection_activity_tx,
&mut cx,
);
if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await { if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
failed!(error, attempts, ssh_connection, delegate, forwarder); failed!(error, attempts, ssh_connection, delegate, forwarder);
}; };
@ -896,17 +911,18 @@ impl SshRemoteClient {
} }
fn multiplex( fn multiplex(
this: WeakModel<Self>,
mut ssh_proxy_process: Child, mut ssh_proxy_process: Child,
incoming_tx: UnboundedSender<Envelope>, incoming_tx: UnboundedSender<Envelope>,
mut outgoing_rx: UnboundedReceiver<Envelope>, mut outgoing_rx: UnboundedReceiver<Envelope>,
mut connection_activity_tx: Sender<()>, mut connection_activity_tx: Sender<()>,
cx: &AsyncAppContext, cx: &AsyncAppContext,
) -> Task<Result<Option<i32>>> { ) -> Task<Result<()>> {
let mut child_stderr = ssh_proxy_process.stderr.take().unwrap(); let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
let mut child_stdout = ssh_proxy_process.stdout.take().unwrap(); let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
let mut child_stdin = ssh_proxy_process.stdin.take().unwrap(); let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
cx.background_executor().spawn(async move { let io_task = cx.background_executor().spawn(async move {
let mut stdin_buffer = Vec::new(); let mut stdin_buffer = Vec::new();
let mut stdout_buffer = Vec::new(); let mut stdout_buffer = Vec::new();
let mut stderr_buffer = Vec::new(); let mut stderr_buffer = Vec::new();
@ -985,14 +1001,8 @@ impl SshRemoteClient {
} }
} }
} }
}) });
}
fn monitor(
this: WeakModel<Self>,
io_task: Task<Result<Option<i32>>>,
cx: &AsyncAppContext,
) -> Task<Result<()>> {
cx.spawn(|mut cx| async move { cx.spawn(|mut cx| async move {
let result = io_task.await; let result = io_task.await;
@ -1051,40 +1061,21 @@ impl SshRemoteClient {
cx.notify(); cx.notify();
} }
#[allow(clippy::too_many_arguments)]
async fn establish_connection( async fn establish_connection(
unique_identifier: String, unique_identifier: String,
reconnect: bool, reconnect: bool,
connection_options: SshConnectionOptions, connection_options: SshConnectionOptions,
proxy_incoming_tx: UnboundedSender<Envelope>,
proxy_outgoing_rx: UnboundedReceiver<Envelope>,
connection_activity_tx: Sender<()>,
delegate: Arc<dyn SshClientDelegate>, delegate: Arc<dyn SshClientDelegate>,
cx: &mut AsyncAppContext, cx: &mut AsyncAppContext,
) -> Result<(Box<dyn SshRemoteProcess>, Task<Result<Option<i32>>>)> { ) -> Result<(SshRemoteConnection, Child)> {
#[cfg(any(test, feature = "test-support"))]
if let Some(fake) = fake::SshRemoteConnection::new(&connection_options) {
let io_task = fake::SshRemoteConnection::multiplex(
fake.connection_options(),
proxy_incoming_tx,
proxy_outgoing_rx,
connection_activity_tx,
cx,
)
.await;
return Ok((fake, io_task));
}
let ssh_connection = let ssh_connection =
SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?; SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
let platform = ssh_connection.query_platform().await?; let platform = ssh_connection.query_platform().await?;
let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?; let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
if !reconnect { ssh_connection
ssh_connection .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
.ensure_server_binary(&delegate, &remote_binary_path, platform, cx) .await?;
.await?;
}
let socket = ssh_connection.socket.clone(); let socket = ssh_connection.socket.clone();
run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?; run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
@ -1109,15 +1100,7 @@ impl SshRemoteClient {
.spawn() .spawn()
.context("failed to spawn remote server")?; .context("failed to spawn remote server")?;
let io_task = Self::multiplex( Ok((ssh_connection, ssh_proxy_process))
ssh_proxy_process,
proxy_incoming_tx,
proxy_outgoing_rx,
connection_activity_tx,
&cx,
);
Ok((Box::new(ssh_connection), io_task))
} }
pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) { pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
@ -1129,7 +1112,7 @@ impl SshRemoteClient {
.lock() .lock()
.as_ref() .as_ref()
.and_then(|state| state.ssh_connection()) .and_then(|state| state.ssh_connection())
.map(|ssh_connection| ssh_connection.ssh_args()) .map(|ssh_connection| ssh_connection.socket.ssh_args())
} }
pub fn proto_client(&self) -> AnyProtoClient { pub fn proto_client(&self) -> AnyProtoClient {
@ -1144,6 +1127,7 @@ impl SshRemoteClient {
self.connection_options.clone() self.connection_options.clone()
} }
#[cfg(not(any(test, feature = "test-support")))]
pub fn connection_state(&self) -> ConnectionState { pub fn connection_state(&self) -> ConnectionState {
self.state self.state
.lock() .lock()
@ -1152,69 +1136,37 @@ impl SshRemoteClient {
.unwrap_or(ConnectionState::Disconnected) .unwrap_or(ConnectionState::Disconnected)
} }
#[cfg(any(test, feature = "test-support"))]
pub fn connection_state(&self) -> ConnectionState {
ConnectionState::Connected
}
pub fn is_disconnected(&self) -> bool { pub fn is_disconnected(&self) -> bool {
self.connection_state() == ConnectionState::Disconnected self.connection_state() == ConnectionState::Disconnected
} }
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
pub fn simulate_disconnect(&self, cx: &mut AppContext) -> Task<()> { pub fn fake(
use gpui::BorrowAppContext;
let port = self.connection_options().port.unwrap();
let disconnect =
cx.update_global(|c: &mut fake::GlobalConnections, _cx| c.take(port).into_channels());
cx.spawn(|mut cx| async move {
let (input_rx, output_tx) = disconnect.await;
let (forwarder, _, _) = ChannelForwarder::new(input_rx, output_tx, &mut cx);
cx.update_global(|c: &mut fake::GlobalConnections, _cx| c.replace(port, forwarder))
.unwrap()
})
}
#[cfg(any(test, feature = "test-support"))]
pub fn fake_server(
server_cx: &mut gpui::TestAppContext,
) -> (ChannelForwarder, Arc<ChannelClient>) {
server_cx.update(|cx| {
let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
// We use the forwarder on the server side (in production we only use one on the client side)
// the idea is that we can simulate a disconnect/reconnect by just messing with the forwarder.
let (forwarder, _, _) =
ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx.to_async());
let client = ChannelClient::new(incoming_rx, outgoing_tx, cx);
(forwarder, client)
})
}
#[cfg(any(test, feature = "test-support"))]
pub async fn fake_client(
forwarder: ChannelForwarder,
client_cx: &mut gpui::TestAppContext, client_cx: &mut gpui::TestAppContext,
) -> Model<Self> { server_cx: &mut gpui::TestAppContext,
use gpui::BorrowAppContext; ) -> (Model<Self>, Arc<ChannelClient>) {
client_cx use gpui::Context;
.update(|cx| {
let port = cx.update_default_global(|c: &mut fake::GlobalConnections, _cx| {
c.push(forwarder)
});
Self::new( let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
"fake".to_string(), let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
SshConnectionOptions {
host: "<fake>".to_string(), (
port: Some(port), client_cx.update(|cx| {
..Default::default() let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
}, cx.new_model(|_| Self {
Arc::new(fake::Delegate), client,
cx, unique_identifier: "fake".to_string(),
) connection_options: SshConnectionOptions::default(),
}) state: Arc::new(Mutex::new(None)),
.await })
.unwrap() }),
server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
)
} }
} }
@ -1224,13 +1176,6 @@ impl From<SshRemoteClient> for AnyProtoClient {
} }
} }
#[async_trait]
trait SshRemoteProcess: Send + Sync {
async fn kill(&mut self) -> Result<()>;
fn ssh_args(&self) -> Vec<String>;
fn connection_options(&self) -> SshConnectionOptions;
}
struct SshRemoteConnection { struct SshRemoteConnection {
socket: SshSocket, socket: SshSocket,
master_process: process::Child, master_process: process::Child,
@ -1245,25 +1190,6 @@ impl Drop for SshRemoteConnection {
} }
} }
#[async_trait]
impl SshRemoteProcess for SshRemoteConnection {
async fn kill(&mut self) -> Result<()> {
self.master_process.kill()?;
self.master_process.status().await?;
Ok(())
}
fn ssh_args(&self) -> Vec<String> {
self.socket.ssh_args()
}
fn connection_options(&self) -> SshConnectionOptions {
self.socket.connection_options.clone()
}
}
impl SshRemoteConnection { impl SshRemoteConnection {
#[cfg(not(unix))] #[cfg(not(unix))]
async fn new( async fn new(
@ -1546,10 +1472,8 @@ type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, ones
pub struct ChannelClient { pub struct ChannelClient {
next_message_id: AtomicU32, next_message_id: AtomicU32,
outgoing_tx: mpsc::UnboundedSender<Envelope>, outgoing_tx: mpsc::UnboundedSender<Envelope>,
buffer: Mutex<VecDeque<Envelope>>, response_channels: ResponseChannels, // Lock
response_channels: ResponseChannels, message_handlers: Mutex<ProtoMessageHandlerSet>, // Lock
message_handlers: Mutex<ProtoMessageHandlerSet>,
max_received: AtomicU32,
} }
impl ChannelClient { impl ChannelClient {
@ -1561,10 +1485,8 @@ impl ChannelClient {
let this = Arc::new(Self { let this = Arc::new(Self {
outgoing_tx, outgoing_tx,
next_message_id: AtomicU32::new(0), next_message_id: AtomicU32::new(0),
max_received: AtomicU32::new(0),
response_channels: ResponseChannels::default(), response_channels: ResponseChannels::default(),
message_handlers: Default::default(), message_handlers: Default::default(),
buffer: Mutex::new(VecDeque::new()),
}); });
Self::start_handling_messages(this.clone(), incoming_rx, cx); Self::start_handling_messages(this.clone(), incoming_rx, cx);
@ -1585,27 +1507,6 @@ impl ChannelClient {
let Some(this) = this.upgrade() else { let Some(this) = this.upgrade() else {
return anyhow::Ok(()); return anyhow::Ok(());
}; };
if let Some(ack_id) = incoming.ack_id {
let mut buffer = this.buffer.lock();
while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
buffer.pop_front();
}
}
if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) =
&incoming.payload
{
{
let buffer = this.buffer.lock();
for envelope in buffer.iter() {
this.outgoing_tx.unbounded_send(envelope.clone()).ok();
}
}
let response = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
this.send_dynamic(response).ok();
continue;
}
this.max_received.store(incoming.id, SeqCst);
if let Some(request_id) = incoming.responding_to { if let Some(request_id) = incoming.responding_to {
let request_id = MessageId(request_id); let request_id = MessageId(request_id);
@ -1682,23 +1583,6 @@ impl ChannelClient {
} }
} }
pub async fn resync(&self, timeout: Duration) -> Result<()> {
smol::future::or(
async {
self.request(proto::FlushBufferedMessages {}).await?;
for envelope in self.buffer.lock().iter() {
self.outgoing_tx.unbounded_send(envelope.clone()).ok();
}
Ok(())
},
async {
smol::Timer::after(timeout).await;
Err(anyhow!("Timeout detected"))
},
)
.await
}
pub async fn ping(&self, timeout: Duration) -> Result<()> { pub async fn ping(&self, timeout: Duration) -> Result<()> {
smol::future::or( smol::future::or(
async { async {
@ -1728,8 +1612,7 @@ impl ChannelClient {
let mut response_channels_lock = self.response_channels.lock(); let mut response_channels_lock = self.response_channels.lock();
response_channels_lock.insert(MessageId(envelope.id), tx); response_channels_lock.insert(MessageId(envelope.id), tx);
drop(response_channels_lock); drop(response_channels_lock);
let result = self.outgoing_tx.unbounded_send(envelope);
let result = self.send_buffered(envelope);
async move { async move {
if let Err(error) = &result { if let Err(error) = &result {
log::error!("failed to send message: {}", error); log::error!("failed to send message: {}", error);
@ -1746,12 +1629,6 @@ impl ChannelClient {
pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> { pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
envelope.id = self.next_message_id.fetch_add(1, SeqCst); envelope.id = self.next_message_id.fetch_add(1, SeqCst);
self.send_buffered(envelope)
}
pub fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
envelope.ack_id = Some(self.max_received.load(SeqCst));
self.buffer.lock().push_back(envelope.clone());
self.outgoing_tx.unbounded_send(envelope)?; self.outgoing_tx.unbounded_send(envelope)?;
Ok(()) Ok(())
} }
@ -1782,165 +1659,3 @@ impl ProtoClient for ChannelClient {
false false
} }
} }
#[cfg(any(test, feature = "test-support"))]
mod fake {
use std::path::PathBuf;
use anyhow::Result;
use async_trait::async_trait;
use futures::{
channel::{
mpsc::{self, Sender},
oneshot,
},
select_biased, FutureExt, SinkExt, StreamExt,
};
use gpui::{AsyncAppContext, BorrowAppContext, Global, SemanticVersion, Task};
use rpc::proto::Envelope;
use super::{
ChannelForwarder, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteProcess,
};
pub(super) struct SshRemoteConnection {
connection_options: SshConnectionOptions,
}
impl SshRemoteConnection {
pub(super) fn new(
connection_options: &SshConnectionOptions,
) -> Option<Box<dyn SshRemoteProcess>> {
if connection_options.host == "<fake>" {
return Some(Box::new(Self {
connection_options: connection_options.clone(),
}));
}
return None;
}
pub(super) async fn multiplex(
connection_options: SshConnectionOptions,
mut client_tx: mpsc::UnboundedSender<Envelope>,
mut client_rx: mpsc::UnboundedReceiver<Envelope>,
mut connection_activity_tx: Sender<()>,
cx: &mut AsyncAppContext,
) -> Task<Result<Option<i32>>> {
let (server_tx, server_rx) = cx
.update(|cx| {
cx.update_global(|conns: &mut GlobalConnections, _| {
conns.take(connection_options.port.unwrap())
})
})
.unwrap()
.into_channels()
.await;
let (forwarder, mut proxy_tx, mut proxy_rx) =
ChannelForwarder::new(server_tx, server_rx, cx);
cx.update(|cx| {
cx.update_global(|conns: &mut GlobalConnections, _| {
conns.replace(connection_options.port.unwrap(), forwarder)
})
})
.unwrap();
cx.background_executor().spawn(async move {
loop {
select_biased! {
server_to_client = proxy_rx.next().fuse() => {
let Some(server_to_client) = server_to_client else {
return Ok(Some(1))
};
connection_activity_tx.try_send(()).ok();
client_tx.send(server_to_client).await.ok();
}
client_to_server = client_rx.next().fuse() => {
let Some(client_to_server) = client_to_server else {
return Ok(None)
};
proxy_tx.send(client_to_server).await.ok();
}
}
}
})
}
}
#[async_trait]
impl SshRemoteProcess for SshRemoteConnection {
async fn kill(&mut self) -> Result<()> {
Ok(())
}
fn ssh_args(&self) -> Vec<String> {
Vec::new()
}
fn connection_options(&self) -> SshConnectionOptions {
self.connection_options.clone()
}
}
#[derive(Default)]
pub(super) struct GlobalConnections(Vec<Option<ChannelForwarder>>);
impl Global for GlobalConnections {}
impl GlobalConnections {
pub(super) fn push(&mut self, forwarder: ChannelForwarder) -> u16 {
self.0.push(Some(forwarder));
self.0.len() as u16 - 1
}
pub(super) fn take(&mut self, port: u16) -> ChannelForwarder {
self.0
.get_mut(port as usize)
.expect("no fake server for port")
.take()
.expect("fake server is already borrowed")
}
pub(super) fn replace(&mut self, port: u16, forwarder: ChannelForwarder) {
let ret = self
.0
.get_mut(port as usize)
.expect("no fake server for port")
.replace(forwarder);
if ret.is_some() {
panic!("fake server is already replaced");
}
}
}
pub(super) struct Delegate;
impl SshClientDelegate for Delegate {
fn ask_password(
&self,
_: String,
_: &mut AsyncAppContext,
) -> oneshot::Receiver<Result<String>> {
unreachable!()
}
fn remote_server_binary_path(
&self,
_: SshPlatform,
_: &mut AsyncAppContext,
) -> Result<PathBuf> {
unreachable!()
}
fn get_server_binary(
&self,
_: SshPlatform,
_: &mut AsyncAppContext,
) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>> {
unreachable!()
}
fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {
unreachable!()
}
fn set_error(&self, _: String, _: &mut AsyncAppContext) {
unreachable!()
}
}
}

View file

@ -641,47 +641,6 @@ async fn test_open_server_settings(cx: &mut TestAppContext, server_cx: &mut Test
}) })
} }
#[gpui::test(iterations = 10)]
async fn test_reconnect(cx: &mut TestAppContext, server_cx: &mut TestAppContext) {
let (project, _headless, fs) = init_test(cx, server_cx).await;
let (worktree, _) = project
.update(cx, |project, cx| {
project.find_or_create_worktree("/code/project1", true, cx)
})
.await
.unwrap();
let worktree_id = worktree.read_with(cx, |worktree, _| worktree.id());
let buffer = project
.update(cx, |project, cx| {
project.open_buffer((worktree_id, Path::new("src/lib.rs")), cx)
})
.await
.unwrap();
buffer.update(cx, |buffer, cx| {
assert_eq!(buffer.text(), "fn one() -> usize { 1 }");
let ix = buffer.text().find('1').unwrap();
buffer.edit([(ix..ix + 1, "100")], None, cx);
});
let client = cx.read(|cx| project.read(cx).ssh_client().unwrap());
client
.update(cx, |client, cx| client.simulate_disconnect(cx))
.detach();
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.await
.unwrap();
assert_eq!(
fs.load("/code/project1/src/lib.rs".as_ref()).await.unwrap(),
"fn one() -> usize { 100 }"
);
}
fn init_logger() { fn init_logger() {
if std::env::var("RUST_LOG").is_ok() { if std::env::var("RUST_LOG").is_ok() {
env_logger::try_init().ok(); env_logger::try_init().ok();
@ -692,9 +651,9 @@ async fn init_test(
cx: &mut TestAppContext, cx: &mut TestAppContext,
server_cx: &mut TestAppContext, server_cx: &mut TestAppContext,
) -> (Model<Project>, Model<HeadlessProject>, Arc<FakeFs>) { ) -> (Model<Project>, Model<HeadlessProject>, Arc<FakeFs>) {
let (ssh_remote_client, ssh_server_client) = SshRemoteClient::fake(cx, server_cx);
init_logger(); init_logger();
let (forwarder, ssh_server_client) = SshRemoteClient::fake_server(server_cx);
let fs = FakeFs::new(server_cx.executor()); let fs = FakeFs::new(server_cx.executor());
fs.insert_tree( fs.insert_tree(
"/code", "/code",
@ -735,9 +694,8 @@ async fn init_test(
cx, cx,
) )
}); });
let project = build_project(ssh_remote_client, cx);
let ssh = SshRemoteClient::fake_client(forwarder, cx).await;
let project = build_project(ssh, cx);
project project
.update(cx, { .update(cx, {
let headless = headless.clone(); let headless = headless.clone();