diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 2b45efec85..abb54e1ec9 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -16,39 +16,38 @@ doctest = false test-support = ["clock/test-support", "collections/test-support", "gpui/test-support", "rpc/test-support"] [dependencies] -chrono = { workspace = true, features = ["serde"] } -clock.workspace = true -collections.workspace = true -gpui.workspace = true -util.workspace = true -release_channel.workspace = true -rpc.workspace = true -text.workspace = true -settings.workspace = true -feature_flags.workspace = true - anyhow.workspace = true async-recursion = "0.3" async-tungstenite = { version = "0.16", features = ["async-std", "async-native-tls"] } +chrono = { workspace = true, features = ["serde"] } +clock.workspace = true +collections.workspace = true +feature_flags.workspace = true futures.workspace = true +gpui.workspace = true lazy_static.workspace = true log.workspace = true once_cell = "1.19.0" parking_lot.workspace = true postage.workspace = true rand.workspace = true +release_channel.workspace = true +rpc.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true +settings.workspace = true sha2.workspace = true smol.workspace = true sysinfo.workspace = true telemetry_events.workspace = true tempfile.workspace = true +text.workspace = true thiserror.workspace = true time.workspace = true tiny_http = "0.8" url.workspace = true +util.workspace = true [dev-dependencies] clock = { workspace = true, features = ["test-support"] } diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 2144b1ed81..9e4c4a43ca 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -30,6 +30,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources, SettingsStore}; use std::fmt; +use std::pin::Pin; use std::{ any::TypeId, convert::TryFrom, @@ -65,6 +66,13 @@ impl fmt::Display for DevServerToken { lazy_static! { static ref ZED_SERVER_URL: Option = std::env::var("ZED_SERVER_URL").ok(); static ref ZED_RPC_URL: Option = std::env::var("ZED_RPC_URL").ok(); + /// An environment variable whose presence indicates that the development auth + /// provider should be used. + /// + /// Only works in development. Setting this environment variable in other release + /// channels is a no-op. + pub static ref ZED_DEVELOPMENT_AUTH: bool = + std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty()); pub static ref IMPERSONATE_LOGIN: Option = std::env::var("ZED_IMPERSONATE") .ok() .and_then(|s| if s.is_empty() { None } else { Some(s) }); @@ -161,6 +169,7 @@ pub struct Client { peer: Arc, http: Arc, telemetry: Arc, + credentials_provider: Arc, state: RwLock, #[allow(clippy::type_complexity)] @@ -298,6 +307,32 @@ impl Credentials { } } +/// A provider for [`Credentials`]. +/// +/// Used to abstract over reading and writing credentials to some form of +/// persistence (like the system keychain). +trait CredentialsProvider { + /// Reads the credentials from the provider. + fn read_credentials<'a>( + &'a self, + cx: &'a AsyncAppContext, + ) -> Pin> + 'a>>; + + /// Writes the credentials to the provider. + fn write_credentials<'a>( + &'a self, + user_id: u64, + access_token: String, + cx: &'a AsyncAppContext, + ) -> Pin> + 'a>>; + + /// Deletes the credentials from the provider. + fn delete_credentials<'a>( + &'a self, + cx: &'a AsyncAppContext, + ) -> Pin> + 'a>>; +} + impl Default for ClientState { fn default() -> Self { Self { @@ -443,11 +478,27 @@ impl Client { http: Arc, cx: &mut AppContext, ) -> Arc { + let use_zed_development_auth = match ReleaseChannel::try_global(cx) { + Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH, + Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable) + | None => false, + }; + + let credentials_provider: Arc = + if use_zed_development_auth { + Arc::new(DevelopmentCredentialsProvider { + path: util::paths::CONFIG_DIR.join("development_auth"), + }) + } else { + Arc::new(KeychainCredentialsProvider) + }; + Arc::new(Self { id: AtomicU64::new(0), peer: Peer::new(0), telemetry: Telemetry::new(clock, http.clone(), cx), http, + credentials_provider, state: Default::default(), #[cfg(any(test, feature = "test-support"))] @@ -763,8 +814,11 @@ impl Client { } } - pub async fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool { - read_credentials_from_keychain(cx).await.is_some() + pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool { + self.credentials_provider + .read_credentials(cx) + .await + .is_some() } pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self { @@ -775,7 +829,7 @@ impl Client { #[async_recursion(?Send)] pub async fn authenticate_and_connect( self: &Arc, - try_keychain: bool, + try_provider: bool, cx: &AsyncAppContext, ) -> anyhow::Result<()> { let was_disconnected = match *self.status().borrow() { @@ -796,12 +850,13 @@ impl Client { self.set_status(Status::Reauthenticating, cx) } - let mut read_from_keychain = false; + let mut read_from_provider = false; let mut credentials = self.state.read().credentials.clone(); - if credentials.is_none() && try_keychain { - credentials = read_credentials_from_keychain(cx).await; - read_from_keychain = credentials.is_some(); + if credentials.is_none() && try_provider { + credentials = self.credentials_provider.read_credentials(cx).await; + read_from_provider = credentials.is_some(); } + if credentials.is_none() { let mut status_rx = self.status(); let _ = status_rx.next().await; @@ -838,9 +893,9 @@ impl Client { match connection { Ok(conn) => { self.state.write().credentials = Some(credentials.clone()); - if !read_from_keychain && IMPERSONATE_LOGIN.is_none() { + if !read_from_provider && IMPERSONATE_LOGIN.is_none() { if let Credentials::User{user_id, access_token} = credentials { - write_credentials_to_keychain(user_id, access_token, cx).await.log_err(); + self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err(); } } @@ -854,8 +909,8 @@ impl Client { } Err(EstablishConnectionError::Unauthorized) => { self.state.write().credentials.take(); - if read_from_keychain { - delete_credentials_from_keychain(cx).await.log_err(); + if read_from_provider { + self.credentials_provider.delete_credentials(cx).await.log_err(); self.set_status(Status::SignedOut, cx); self.authenticate_and_connect(false, cx).await } else { @@ -1264,8 +1319,11 @@ impl Client { self.state.write().credentials = None; self.disconnect(&cx); - if self.has_keychain_credentials(cx).await { - delete_credentials_from_keychain(cx).await.log_err(); + if self.has_credentials(cx).await { + self.credentials_provider + .delete_credentials(cx) + .await + .log_err(); } } @@ -1465,41 +1523,128 @@ impl Client { } } -async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option { - if IMPERSONATE_LOGIN.is_some() { - return None; - } - - let (user_id, access_token) = cx - .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url)) - .log_err()? - .await - .log_err()??; - - Some(Credentials::User { - user_id: user_id.parse().ok()?, - access_token: String::from_utf8(access_token).ok()?, - }) -} - -async fn write_credentials_to_keychain( +#[derive(Serialize, Deserialize)] +struct DevelopmentCredentials { user_id: u64, access_token: String, - cx: &AsyncAppContext, -) -> Result<()> { - cx.update(move |cx| { - cx.write_credentials( - &ClientSettings::get_global(cx).server_url, - &user_id.to_string(), - access_token.as_bytes(), - ) - })? - .await } -async fn delete_credentials_from_keychain(cx: &AsyncAppContext) -> Result<()> { - cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))? - .await +/// A credentials provider that stores credentials in a local file. +/// +/// This MUST only be used in development, as this is not a secure way of storing +/// credentials on user machines. +/// +/// Its existence is purely to work around the annoyance of having to constantly +/// re-allow access to the system keychain when developing Zed. +struct DevelopmentCredentialsProvider { + path: PathBuf, +} + +impl CredentialsProvider for DevelopmentCredentialsProvider { + fn read_credentials<'a>( + &'a self, + _cx: &'a AsyncAppContext, + ) -> Pin> + 'a>> { + async move { + if IMPERSONATE_LOGIN.is_some() { + return None; + } + + let json = std::fs::read(&self.path).log_err()?; + + let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?; + + Some(Credentials::User { + user_id: credentials.user_id, + access_token: credentials.access_token, + }) + } + .boxed_local() + } + + fn write_credentials<'a>( + &'a self, + user_id: u64, + access_token: String, + _cx: &'a AsyncAppContext, + ) -> Pin> + 'a>> { + async move { + let json = serde_json::to_string(&DevelopmentCredentials { + user_id, + access_token, + })?; + + std::fs::write(&self.path, json)?; + + Ok(()) + } + .boxed_local() + } + + fn delete_credentials<'a>( + &'a self, + _cx: &'a AsyncAppContext, + ) -> Pin> + 'a>> { + async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local() + } +} + +/// A credentials provider that stores credentials in the system keychain. +struct KeychainCredentialsProvider; + +impl CredentialsProvider for KeychainCredentialsProvider { + fn read_credentials<'a>( + &'a self, + cx: &'a AsyncAppContext, + ) -> Pin> + 'a>> { + async move { + if IMPERSONATE_LOGIN.is_some() { + return None; + } + + let (user_id, access_token) = cx + .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url)) + .log_err()? + .await + .log_err()??; + + Some(Credentials::User { + user_id: user_id.parse().ok()?, + access_token: String::from_utf8(access_token).ok()?, + }) + } + .boxed_local() + } + + fn write_credentials<'a>( + &'a self, + user_id: u64, + access_token: String, + cx: &'a AsyncAppContext, + ) -> Pin> + 'a>> { + async move { + cx.update(move |cx| { + cx.write_credentials( + &ClientSettings::get_global(cx).server_url, + &user_id.to_string(), + access_token.as_bytes(), + ) + })? + .await + } + .boxed_local() + } + + fn delete_credentials<'a>( + &'a self, + cx: &'a AsyncAppContext, + ) -> Pin> + 'a>> { + async move { + cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))? + .await + } + .boxed_local() + } } /// prefix for the zed:// url scheme diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index ba7d879a1d..3bcc66dd93 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -542,10 +542,12 @@ fn handle_open_request( async fn authenticate(client: Arc, cx: &AsyncAppContext) -> Result<()> { if stdout_is_a_pty() { - if client::IMPERSONATE_LOGIN.is_some() { + if *client::ZED_DEVELOPMENT_AUTH { + client.authenticate_and_connect(true, &cx).await?; + } else if client::IMPERSONATE_LOGIN.is_some() { client.authenticate_and_connect(false, &cx).await?; } - } else if client.has_keychain_credentials(&cx).await { + } else if client.has_credentials(&cx).await { client.authenticate_and_connect(true, &cx).await?; } Ok::<_, anyhow::Error>(())