SSH installation refactor (#19991)

This also cleans up logic for deciding how to do things.

Release Notes:

- Remoting: If downloading the binary on the remote fails, fall back to
uploading it.

---------

Co-authored-by: Mikayala <mikayla@zed.dev>
This commit is contained in:
Conrad Irwin 2024-10-30 17:20:11 -06:00 committed by GitHub
parent 6d5784daa6
commit 40802d91d4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 315 additions and 297 deletions

1
Cargo.lock generated
View file

@ -9538,6 +9538,7 @@ dependencies = [
"log", "log",
"parking_lot", "parking_lot",
"prost", "prost",
"release_channel",
"rpc", "rpc",
"serde", "serde",
"serde_json", "serde_json",

View file

@ -13,8 +13,7 @@ use gpui::{AppContext, Model};
use language::CursorShape; use language::CursorShape;
use markdown::{Markdown, MarkdownStyle}; use markdown::{Markdown, MarkdownStyle};
use release_channel::{AppVersion, ReleaseChannel}; use release_channel::ReleaseChannel;
use remote::ssh_session::{ServerBinary, ServerVersion};
use remote::{SshConnectionOptions, SshPlatform, SshRemoteClient}; use remote::{SshConnectionOptions, SshPlatform, SshRemoteClient};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -441,23 +440,66 @@ impl remote::SshClientDelegate for SshClientDelegate {
self.update_status(status, cx) self.update_status(status, cx)
} }
fn get_server_binary( fn download_server_binary_locally(
&self, &self,
platform: SshPlatform, platform: SshPlatform,
upload_binary_over_ssh: bool, release_channel: ReleaseChannel,
version: Option<SemanticVersion>,
cx: &mut AsyncAppContext, cx: &mut AsyncAppContext,
) -> oneshot::Receiver<Result<(ServerBinary, ServerVersion)>> { ) -> Task<anyhow::Result<PathBuf>> {
let (tx, rx) = oneshot::channel();
let this = self.clone();
cx.spawn(|mut cx| async move { cx.spawn(|mut cx| async move {
tx.send( let binary_path = AutoUpdater::download_remote_server_release(
this.get_server_binary_impl(platform, upload_binary_over_ssh, &mut cx) platform.os,
.await, platform.arch,
release_channel,
version,
&mut cx,
) )
.ok(); .await
.map_err(|e| {
anyhow!(
"Failed to download remote server binary (version: {}, os: {}, arch: {}): {}",
version
.map(|v| format!("{}", v))
.unwrap_or("unknown".to_string()),
platform.os,
platform.arch,
e
)
})?;
Ok(binary_path)
}) })
.detach(); }
rx
fn get_download_params(
&self,
platform: SshPlatform,
release_channel: ReleaseChannel,
version: Option<SemanticVersion>,
cx: &mut AsyncAppContext,
) -> Task<Result<(String, String)>> {
cx.spawn(|mut cx| async move {
let (release, request_body) = AutoUpdater::get_remote_server_release_url(
platform.os,
platform.arch,
release_channel,
version,
&mut cx,
)
.await
.map_err(|e| {
anyhow!(
"Failed to get remote server binary download url (version: {}, os: {}, arch: {}): {}",
version.map(|v| format!("{}", v)).unwrap_or("unknown".to_string()),
platform.os,
platform.arch,
e
)
})?;
Ok((release.url, request_body))
}
)
} }
fn remote_server_binary_path( fn remote_server_binary_path(
@ -485,208 +527,6 @@ impl SshClientDelegate {
}) })
.ok(); .ok();
} }
async fn get_server_binary_impl(
&self,
platform: SshPlatform,
upload_binary_via_ssh: bool,
cx: &mut AsyncAppContext,
) -> Result<(ServerBinary, ServerVersion)> {
let (version, release_channel) = cx.update(|cx| {
let version = AppVersion::global(cx);
let channel = ReleaseChannel::global(cx);
(version, channel)
})?;
// In dev mode, build the remote server binary from source
#[cfg(debug_assertions)]
if release_channel == ReleaseChannel::Dev {
let result = self.build_local(cx, platform, version).await?;
// Fall through to a remote binary if we're not able to compile a local binary
if let Some((path, version)) = result {
return Ok((
ServerBinary::LocalBinary(path),
ServerVersion::Semantic(version),
));
}
}
// For nightly channel, always get latest
let current_version = if release_channel == ReleaseChannel::Nightly {
None
} else {
Some(version)
};
self.update_status(
Some(&format!("Checking remote server release {}", version)),
cx,
);
if upload_binary_via_ssh {
let binary_path = AutoUpdater::download_remote_server_release(
platform.os,
platform.arch,
release_channel,
current_version,
cx,
)
.await
.map_err(|e| {
anyhow!(
"Failed to download remote server binary (version: {}, os: {}, arch: {}): {}",
version,
platform.os,
platform.arch,
e
)
})?;
Ok((
ServerBinary::LocalBinary(binary_path),
ServerVersion::Semantic(version),
))
} else {
let (release, request_body) = AutoUpdater::get_remote_server_release_url(
platform.os,
platform.arch,
release_channel,
current_version,
cx,
)
.await
.map_err(|e| {
anyhow!(
"Failed to get remote server binary download url (version: {}, os: {}, arch: {}): {}",
version,
platform.os,
platform.arch,
e
)
})?;
let version = release
.version
.parse::<SemanticVersion>()
.map(ServerVersion::Semantic)
.unwrap_or_else(|_| ServerVersion::Commit(release.version));
Ok((
ServerBinary::ReleaseUrl {
url: release.url,
body: request_body,
},
version,
))
}
}
#[cfg(debug_assertions)]
async fn build_local(
&self,
cx: &mut AsyncAppContext,
platform: SshPlatform,
version: gpui::SemanticVersion,
) -> Result<Option<(PathBuf, gpui::SemanticVersion)>> {
use smol::process::{Command, Stdio};
async fn run_cmd(command: &mut Command) -> Result<()> {
let output = command
.kill_on_drop(true)
.stderr(Stdio::inherit())
.output()
.await?;
if !output.status.success() {
Err(anyhow!("Failed to run command: {:?}", command))?;
}
Ok(())
}
if platform.arch == std::env::consts::ARCH && platform.os == std::env::consts::OS {
self.update_status(Some("Building remote server binary from source"), cx);
log::info!("building remote server binary from source");
run_cmd(Command::new("cargo").args([
"build",
"--package",
"remote_server",
"--features",
"debug-embed",
"--target-dir",
"target/remote_server",
]))
.await?;
self.update_status(Some("Compressing binary"), cx);
run_cmd(Command::new("gzip").args([
"-9",
"-f",
"target/remote_server/debug/remote_server",
]))
.await?;
let path = std::env::current_dir()?.join("target/remote_server/debug/remote_server.gz");
return Ok(Some((path, version)));
} else if let Some(triple) = platform.triple() {
smol::fs::create_dir_all("target/remote_server").await?;
self.update_status(Some("Installing cross.rs for cross-compilation"), cx);
log::info!("installing cross");
run_cmd(Command::new("cargo").args([
"install",
"cross",
"--git",
"https://github.com/cross-rs/cross",
]))
.await?;
self.update_status(
Some(&format!(
"Building remote server binary from source for {} with Docker",
&triple
)),
cx,
);
log::info!("building remote server binary from source for {}", &triple);
run_cmd(
Command::new("cross")
.args([
"build",
"--package",
"remote_server",
"--features",
"debug-embed",
"--target-dir",
"target/remote_server",
"--target",
&triple,
])
.env(
"CROSS_CONTAINER_OPTS",
"--mount type=bind,src=./target,dst=/app/target",
),
)
.await?;
self.update_status(Some("Compressing binary"), cx);
run_cmd(Command::new("gzip").args([
"-9",
"-f",
&format!("target/remote_server/{}/debug/remote_server", triple),
]))
.await?;
let path = std::env::current_dir()?.join(format!(
"target/remote_server/{}/debug/remote_server.gz",
triple
));
return Ok(Some((path, version)));
} else {
return Ok(None);
}
}
} }
pub fn is_connecting_over_ssh(workspace: &Workspace, cx: &AppContext) -> bool { pub fn is_connecting_over_ssh(workspace: &Workspace, cx: &AppContext) -> bool {

View file

@ -35,6 +35,7 @@ smol.workspace = true
tempfile.workspace = true tempfile.workspace = true
thiserror.workspace = true thiserror.workspace = true
util.workspace = true util.workspace = true
release_channel.workspace = true
[dev-dependencies] [dev-dependencies]
gpui = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] }

View file

@ -21,6 +21,7 @@ use gpui::{
ModelContext, SemanticVersion, Task, WeakModel, ModelContext, SemanticVersion, Task, WeakModel,
}; };
use parking_lot::Mutex; use parking_lot::Mutex;
use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
use rpc::{ use rpc::{
proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage}, proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet, AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
@ -227,10 +228,19 @@ pub enum ServerBinary {
ReleaseUrl { url: String, body: String }, ReleaseUrl { url: String, body: String },
} }
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ServerVersion { pub enum ServerVersion {
Semantic(SemanticVersion), Semantic(SemanticVersion),
Commit(String), Commit(String),
} }
impl ServerVersion {
pub fn semantic_version(&self) -> Option<SemanticVersion> {
match self {
Self::Semantic(version) => Some(*version),
_ => None,
}
}
}
impl std::fmt::Display for ServerVersion { impl std::fmt::Display for ServerVersion {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@ -252,12 +262,21 @@ pub trait SshClientDelegate: Send + Sync {
platform: SshPlatform, platform: SshPlatform,
cx: &mut AsyncAppContext, cx: &mut AsyncAppContext,
) -> Result<PathBuf>; ) -> Result<PathBuf>;
fn get_server_binary( fn get_download_params(
&self, &self,
platform: SshPlatform, platform: SshPlatform,
upload_binary_over_ssh: bool, release_channel: ReleaseChannel,
version: Option<SemanticVersion>,
cx: &mut AsyncAppContext, cx: &mut AsyncAppContext,
) -> oneshot::Receiver<Result<(ServerBinary, ServerVersion)>>; ) -> Task<Result<(String, String)>>;
fn download_server_binary_locally(
&self,
platform: SshPlatform,
release_channel: ReleaseChannel,
version: Option<SemanticVersion>,
cx: &mut AsyncAppContext,
) -> Task<Result<PathBuf>>;
fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext); fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
} }
@ -1727,86 +1746,123 @@ impl SshRemoteConnection {
platform: SshPlatform, platform: SshPlatform,
cx: &mut AsyncAppContext, cx: &mut AsyncAppContext,
) -> Result<()> { ) -> Result<()> {
if std::env::var("ZED_USE_CACHED_REMOTE_SERVER").is_ok() { let current_version = match run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
if let Ok(installed_version) = {
run_cmd(self.socket.ssh_command(dst_path).arg("version")).await Ok(version_output) => {
{
log::info!("using cached server binary version {}", installed_version);
return Ok(());
}
}
if cfg!(not(debug_assertions)) {
// When we're not in dev mode, we don't want to switch out the binary if it's
// still open.
// In dev mode, that's fine, since we often kill Zed processes with Ctrl-C and want
// to still replace the binary.
if self.is_binary_in_use(dst_path).await? {
log::info!("server binary is opened by another process. not updating");
delegate.set_status(
Some("Skipping update of remote development server, since it's still in use"),
cx,
);
return Ok(());
}
}
let upload_binary_over_ssh = self.socket.connection_options.upload_binary_over_ssh;
let (binary, new_server_version) = delegate
.get_server_binary(platform, upload_binary_over_ssh, cx)
.await??;
if cfg!(not(debug_assertions)) {
let installed_version = if let Ok(version_output) =
run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
{
if let Ok(version) = version_output.trim().parse::<SemanticVersion>() { if let Ok(version) = version_output.trim().parse::<SemanticVersion>() {
Some(ServerVersion::Semantic(version)) Some(ServerVersion::Semantic(version))
} else { } else {
Some(ServerVersion::Commit(version_output.trim().to_string())) Some(ServerVersion::Commit(version_output.trim().to_string()))
} }
} else { }
None Err(_) => None,
};
let (release_channel, wanted_version) = cx.update(|cx| {
let release_channel = ReleaseChannel::global(cx);
let wanted_version = match release_channel {
ReleaseChannel::Nightly => {
AppCommitSha::try_global(cx).map(|sha| ServerVersion::Commit(sha.0))
}
ReleaseChannel::Dev => None,
_ => Some(ServerVersion::Semantic(AppVersion::global(cx))),
}; };
(release_channel, wanted_version)
})?;
if let Some(installed_version) = installed_version { match (&current_version, &wanted_version) {
use ServerVersion::*; (Some(current), Some(wanted)) if current == wanted => {
match (installed_version, new_server_version) { log::info!("remote development server present and matching client version");
(Semantic(installed), Semantic(new)) if installed == new => { return Ok(());
log::info!("remote development server present and matching client version"); }
return Ok(()); (Some(ServerVersion::Semantic(current)), Some(ServerVersion::Semantic(wanted)))
} if current > wanted =>
(Semantic(installed), Semantic(new)) if installed > new => { {
let error = anyhow!("The version of the remote server ({}) is newer than the Zed version ({}). Please update Zed.", installed, new); anyhow::bail!("The version of the remote server ({}) is newer than the Zed version ({}). Please update Zed.", current, wanted);
return Err(error); }
} _ => {
(Commit(installed), Commit(new)) if installed == new => { log::info!("Installing remote development server");
log::info!( }
"remote development server present and matching client version {}", }
installed
); if self.is_binary_in_use(dst_path).await? {
return Ok(()); // When we're not in dev mode, we don't want to switch out the binary if it's
} // still open.
(installed, _) => { // In dev mode, that's fine, since we often kill Zed processes with Ctrl-C and want
log::info!( // to still replace the binary.
"remote development server has version: {}. updating...", if cfg!(not(debug_assertions)) {
installed anyhow::bail!("The remote server version ({:?}) does not match the wanted version ({:?}), but is in use by another Zed client so cannot be upgraded.", &current_version, &wanted_version)
); } else {
} log::info!("Binary is currently in use, ignoring because this is a dev build")
}
}
if wanted_version.is_none() {
if std::env::var("ZED_BUILD_REMOTE_SERVER").is_err() {
if let Some(current_version) = current_version {
log::warn!(
"In development, using cached remote server binary version ({})",
current_version
);
return Ok(());
} else {
anyhow::bail!(
"ZED_BUILD_REMOTE_SERVER is not set, but no remote server exists at ({:?})",
dst_path
)
}
}
#[cfg(debug_assertions)]
{
let src_path = self.build_local(platform, delegate, cx).await?;
return self
.upload_local_server_binary(&src_path, dst_path, delegate, cx)
.await;
}
#[cfg(not(debug_assertions))]
anyhow::bail!("Running development build in release mode, cannot cross compile (unset ZED_BUILD_REMOTE_SERVER)")
}
let upload_binary_over_ssh = self.socket.connection_options.upload_binary_over_ssh;
if !upload_binary_over_ssh {
let (url, body) = delegate
.get_download_params(
platform,
release_channel,
wanted_version.clone().and_then(|v| v.semantic_version()),
cx,
)
.await?;
match self
.download_binary_on_server(&url, &body, dst_path, delegate, cx)
.await
{
Ok(_) => return Ok(()),
Err(e) => {
log::error!(
"Failed to download binary on server, attempting to upload server: {}",
e
)
} }
} }
} }
match binary { let src_path = delegate
ServerBinary::LocalBinary(src_path) => { .download_server_binary_locally(
self.upload_local_server_binary(&src_path, dst_path, delegate, cx) platform,
.await release_channel,
} wanted_version.and_then(|v| v.semantic_version()),
ServerBinary::ReleaseUrl { url, body } => { cx,
self.download_binary_on_server(&url, &body, dst_path, delegate, cx) )
.await .await?;
}
} self.upload_local_server_binary(&src_path, dst_path, delegate, cx)
.await
} }
async fn is_binary_in_use(&self, binary_path: &Path) -> Result<bool> { async fn is_binary_in_use(&self, binary_path: &Path) -> Result<bool> {
@ -1973,6 +2029,113 @@ impl SshRemoteConnection {
)) ))
} }
} }
#[cfg(debug_assertions)]
async fn build_local(
&self,
platform: SshPlatform,
delegate: &Arc<dyn SshClientDelegate>,
cx: &mut AsyncAppContext,
) -> Result<PathBuf> {
use smol::process::{Command, Stdio};
async fn run_cmd(command: &mut Command) -> Result<()> {
let output = command
.kill_on_drop(true)
.stderr(Stdio::inherit())
.output()
.await?;
if !output.status.success() {
Err(anyhow!("Failed to run command: {:?}", command))?;
}
Ok(())
}
if platform.arch == std::env::consts::ARCH && platform.os == std::env::consts::OS {
delegate.set_status(Some("Building remote server binary from source"), cx);
log::info!("building remote server binary from source");
run_cmd(Command::new("cargo").args([
"build",
"--package",
"remote_server",
"--features",
"debug-embed",
"--target-dir",
"target/remote_server",
]))
.await?;
delegate.set_status(Some("Compressing binary"), cx);
run_cmd(Command::new("gzip").args([
"-9",
"-f",
"target/remote_server/debug/remote_server",
]))
.await?;
let path = std::env::current_dir()?.join("target/remote_server/debug/remote_server.gz");
return Ok(path);
}
let Some(triple) = platform.triple() else {
anyhow::bail!("can't cross compile for: {:?}", platform);
};
smol::fs::create_dir_all("target/remote_server").await?;
delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
log::info!("installing cross");
run_cmd(Command::new("cargo").args([
"install",
"cross",
"--git",
"https://github.com/cross-rs/cross",
]))
.await?;
delegate.set_status(
Some(&format!(
"Building remote server binary from source for {} with Docker",
&triple
)),
cx,
);
log::info!("building remote server binary from source for {}", &triple);
run_cmd(
Command::new("cross")
.args([
"build",
"--package",
"remote_server",
"--features",
"debug-embed",
"--target-dir",
"target/remote_server",
"--target",
&triple,
])
.env(
"CROSS_CONTAINER_OPTS",
"--mount type=bind,src=./target,dst=/app/target",
),
)
.await?;
delegate.set_status(Some("Compressing binary"), cx);
run_cmd(Command::new("gzip").args([
"-9",
"-f",
&format!("target/remote_server/{}/debug/remote_server", triple),
]))
.await?;
let path = std::env::current_dir()?.join(format!(
"target/remote_server/{}/debug/remote_server.gz",
triple
));
return Ok(path);
}
} }
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>; type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
@ -2294,12 +2457,12 @@ mod fake {
}, },
select_biased, FutureExt, SinkExt, StreamExt, select_biased, FutureExt, SinkExt, StreamExt,
}; };
use gpui::{AsyncAppContext, Task, TestAppContext}; use gpui::{AsyncAppContext, SemanticVersion, Task, TestAppContext};
use release_channel::ReleaseChannel;
use rpc::proto::Envelope; use rpc::proto::Envelope;
use super::{ use super::{
ChannelClient, RemoteConnection, ServerBinary, ServerVersion, SshClientDelegate, ChannelClient, RemoteConnection, SshClientDelegate, SshConnectionOptions, SshPlatform,
SshConnectionOptions, SshPlatform,
}; };
pub(super) struct FakeRemoteConnection { pub(super) struct FakeRemoteConnection {
@ -2411,23 +2574,36 @@ mod fake {
) -> oneshot::Receiver<Result<String>> { ) -> oneshot::Receiver<Result<String>> {
unreachable!() unreachable!()
} }
fn remote_server_binary_path(
fn download_server_binary_locally(
&self, &self,
_: SshPlatform, _: SshPlatform,
_: ReleaseChannel,
_: Option<SemanticVersion>,
_: &mut AsyncAppContext, _: &mut AsyncAppContext,
) -> Result<PathBuf> { ) -> Task<Result<PathBuf>> {
unreachable!() unreachable!()
} }
fn get_server_binary(
fn get_download_params(
&self, &self,
_: SshPlatform, _platform: SshPlatform,
_: bool, _release_channel: ReleaseChannel,
_: &mut AsyncAppContext, _version: Option<SemanticVersion>,
) -> oneshot::Receiver<Result<(ServerBinary, ServerVersion)>> { _cx: &mut AsyncAppContext,
) -> Task<Result<(String, String)>> {
unreachable!() unreachable!()
} }
fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {} fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {}
fn remote_server_binary_path(
&self,
_platform: SshPlatform,
_cx: &mut AsyncAppContext,
) -> Result<PathBuf> {
unreachable!()
}
} }
} }