Better handle interrupted connections for shared SSH (#19925)

Co-Authored-By: Mikayla <mikayla@zed.dev>
This commit is contained in:
Conrad Irwin 2024-10-29 14:43:34 -06:00 committed by GitHub
parent 5b7fa05a87
commit fb97e462de
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1288,6 +1288,7 @@ impl SshRemoteConnection {
) -> Result<Self> {
use futures::AsyncWriteExt as _;
use futures::{io::BufReader, AsyncBufReadExt as _};
use smol::net::unix::UnixStream;
use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
use util::ResultExt as _;
@ -1304,6 +1305,9 @@ impl SshRemoteConnection {
let listener =
UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<UnixStream>();
let mut kill_tx = Some(askpass_kill_master_tx);
let askpass_task = cx.spawn({
let delegate = delegate.clone();
|mut cx| async move {
@ -1327,6 +1331,11 @@ impl SshRemoteConnection {
.log_err()
{
stream.write_all(password.as_bytes()).await.log_err();
} else {
if let Some(kill_tx) = kill_tx.take() {
kill_tx.send(stream).log_err();
break;
}
}
}
}
@ -1347,6 +1356,7 @@ impl SshRemoteConnection {
// the connection and keep it open, allowing other ssh commands to reuse it
// via a control socket.
let socket_path = temp_dir.path().join("ssh.sock");
let mut master_process = process::Command::new("ssh")
.stdin(Stdio::null())
.stdout(Stdio::piped())
@ -1369,22 +1379,30 @@ impl SshRemoteConnection {
// Wait for this ssh process to close its stdout, indicating that authentication
// has completed.
let stdout = master_process.stdout.as_mut().unwrap();
let mut stdout = master_process.stdout.take().unwrap();
let mut output = Vec::new();
let connection_timeout = Duration::from_secs(10);
let result = select_biased! {
_ = askpass_opened_rx.fuse() => {
select_biased! {
stream = askpass_kill_master_rx.fuse() => {
master_process.kill().ok();
drop(stream);
Err(anyhow!("SSH connection canceled"))
}
// If the askpass script has opened, that means the user is typing
// their password, in which case we don't want to timeout anymore,
// since we know a connection has been established.
stdout.read_to_end(&mut output).await?;
Ok(())
}
result = stdout.read_to_end(&mut output).fuse() => {
result?;
Ok(())
}
}
}
_ = stdout.read_to_end(&mut output).fuse() => {
Ok(())
}
_ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
}