#[cfg(any(test, feature = "test-support"))] pub mod test; mod socks; pub mod telemetry; pub mod user; pub mod zed_urls; use anyhow::{Context as _, Result, anyhow, bail}; use async_recursion::async_recursion; use async_tungstenite::tungstenite::{ client::IntoClientRequest, error::Error as WebsocketError, http::{HeaderValue, Request, StatusCode}, }; use chrono::{DateTime, Utc}; use clock::SystemClock; use credentials_provider::CredentialsProvider; use futures::{ AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, channel::oneshot, future::BoxFuture, }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; use http_client::{AsyncBody, HttpClient, HttpClientWithUrl}; use parking_lot::RwLock; use postage::watch; use rand::prelude::*; use release_channel::{AppVersion, ReleaseChannel}; use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; use socks::connect_socks_proxy_stream; use std::pin::Pin; use std::{ any::TypeId, convert::TryFrom, fmt::Write as _, future::Future, marker::PhantomData, path::PathBuf, sync::{ Arc, LazyLock, Weak, atomic::{AtomicU64, Ordering}, }, time::{Duration, Instant}, }; use telemetry::Telemetry; use thiserror::Error; use url::Url; use util::{ResultExt, TryFutureExt}; pub use rpc::*; pub use telemetry_events::Event; pub use user::*; static ZED_SERVER_URL: LazyLock> = LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok()); static ZED_RPC_URL: LazyLock> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok()); pub static IMPERSONATE_LOGIN: LazyLock> = LazyLock::new(|| { std::env::var("ZED_IMPERSONATE") .ok() .and_then(|s| if s.is_empty() { None } else { Some(s) }) }); pub static ADMIN_API_TOKEN: LazyLock> = LazyLock::new(|| { std::env::var("ZED_ADMIN_API_TOKEN") .ok() .and_then(|s| if s.is_empty() { None } else { Some(s) }) }); pub static ZED_APP_PATH: LazyLock> = LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from)); pub static ZED_ALWAYS_ACTIVE: LazyLock = LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty())); pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500); pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10); pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20); actions!(client, [SignIn, SignOut, Reconnect]); #[derive(Clone, Default, Serialize, Deserialize, JsonSchema)] pub struct ClientSettingsContent { server_url: Option, } #[derive(Deserialize)] pub struct ClientSettings { pub server_url: String, } impl Settings for ClientSettings { const KEY: Option<&'static str> = None; type FileContent = ClientSettingsContent; fn load(sources: SettingsSources, _: &mut App) -> Result { let mut result = sources.json_merge::()?; if let Some(server_url) = &*ZED_SERVER_URL { result.server_url.clone_from(server_url) } Ok(result) } fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} } #[derive(Default, Clone, Serialize, Deserialize, JsonSchema)] pub struct ProxySettingsContent { proxy: Option, } #[derive(Deserialize, Default)] pub struct ProxySettings { pub proxy: Option, } impl Settings for ProxySettings { const KEY: Option<&'static str> = None; type FileContent = ProxySettingsContent; fn load(sources: SettingsSources, _: &mut App) -> Result { Ok(Self { proxy: sources .user .or(sources.server) .and_then(|value| value.proxy.clone()) .or(sources.default.proxy.clone()), }) } fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) { vscode.string_setting("http.proxy", &mut current.proxy); } } pub fn init_settings(cx: &mut App) { TelemetrySettings::register(cx); ClientSettings::register(cx); ProxySettings::register(cx); } pub fn init(client: &Arc, cx: &mut App) { let client = Arc::downgrade(client); cx.on_action({ let client = client.clone(); move |_: &SignIn, cx| { if let Some(client) = client.upgrade() { cx.spawn(async move |cx| { client.authenticate_and_connect(true, &cx).log_err().await }) .detach(); } } }); cx.on_action({ let client = client.clone(); move |_: &SignOut, cx| { if let Some(client) = client.upgrade() { cx.spawn(async move |cx| { client.sign_out(&cx).await; }) .detach(); } } }); cx.on_action({ let client = client.clone(); move |_: &Reconnect, cx| { if let Some(client) = client.upgrade() { cx.spawn(async move |cx| { client.reconnect(&cx); }) .detach(); } } }); } struct GlobalClient(Arc); impl Global for GlobalClient {} pub struct Client { id: AtomicU64, peer: Arc, http: Arc, telemetry: Arc, credentials_provider: ClientCredentialsProvider, state: RwLock, handler_set: parking_lot::Mutex, #[allow(clippy::type_complexity)] #[cfg(any(test, feature = "test-support"))] authenticate: RwLock Task>>>>, #[allow(clippy::type_complexity)] #[cfg(any(test, feature = "test-support"))] establish_connection: RwLock< Option< Box< dyn 'static + Send + Sync + Fn( &Credentials, &AsyncApp, ) -> Task>, >, >, >, #[cfg(any(test, feature = "test-support"))] rpc_url: RwLock>, } #[derive(Error, Debug)] pub enum EstablishConnectionError { #[error("upgrade required")] UpgradeRequired, #[error("unauthorized")] Unauthorized, #[error("{0}")] Other(#[from] anyhow::Error), #[error("{0}")] InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue), #[error("{0}")] Io(#[from] std::io::Error), #[error("{0}")] Websocket(#[from] async_tungstenite::tungstenite::http::Error), } impl From for EstablishConnectionError { fn from(error: WebsocketError) -> Self { if let WebsocketError::Http(response) = &error { match response.status() { StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized, StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired, _ => {} } } EstablishConnectionError::Other(error.into()) } } impl EstablishConnectionError { pub fn other(error: impl Into + Send + Sync) -> Self { Self::Other(error.into()) } } #[derive(Copy, Clone, Debug, PartialEq)] pub enum Status { SignedOut, UpgradeRequired, Authenticating, Connecting, ConnectionError, Connected { peer_id: PeerId, connection_id: ConnectionId, }, ConnectionLost, Reauthenticating, Reconnecting, ReconnectionError { next_reconnection: Instant, }, } impl Status { pub fn is_connected(&self) -> bool { matches!(self, Self::Connected { .. }) } pub fn is_signed_out(&self) -> bool { matches!(self, Self::SignedOut | Self::UpgradeRequired) } } struct ClientState { credentials: Option, status: (watch::Sender, watch::Receiver), _reconnect_task: Option>, } #[derive(Clone, Debug, Eq, PartialEq)] pub struct Credentials { pub user_id: u64, pub access_token: String, } impl Credentials { pub fn authorization_header(&self) -> String { format!("{} {}", self.user_id, self.access_token) } } pub struct ClientCredentialsProvider { provider: Arc, } impl ClientCredentialsProvider { pub fn new(cx: &App) -> Self { Self { provider: ::global(cx), } } fn server_url(&self, cx: &AsyncApp) -> Result { cx.update(|cx| ClientSettings::get_global(cx).server_url.clone()) } /// Reads the credentials from the provider. fn read_credentials<'a>( &'a self, cx: &'a AsyncApp, ) -> Pin> + 'a>> { async move { if IMPERSONATE_LOGIN.is_some() { return None; } let server_url = self.server_url(cx).ok()?; let (user_id, access_token) = self .provider .read_credentials(&server_url, cx) .await .log_err() .flatten()?; Some(Credentials { user_id: user_id.parse().ok()?, access_token: String::from_utf8(access_token).ok()?, }) } .boxed_local() } /// Writes the credentials to the provider. fn write_credentials<'a>( &'a self, user_id: u64, access_token: String, cx: &'a AsyncApp, ) -> Pin> + 'a>> { async move { let server_url = self.server_url(cx)?; self.provider .write_credentials( &server_url, &user_id.to_string(), access_token.as_bytes(), cx, ) .await } .boxed_local() } /// Deletes the credentials from the provider. fn delete_credentials<'a>( &'a self, cx: &'a AsyncApp, ) -> Pin> + 'a>> { async move { let server_url = self.server_url(cx)?; self.provider.delete_credentials(&server_url, cx).await } .boxed_local() } } impl Default for ClientState { fn default() -> Self { Self { credentials: None, status: watch::channel_with(Status::SignedOut), _reconnect_task: None, } } } pub enum Subscription { Entity { client: Weak, id: (TypeId, u64), }, Message { client: Weak, id: TypeId, }, } impl Drop for Subscription { fn drop(&mut self) { match self { Subscription::Entity { client, id } => { if let Some(client) = client.upgrade() { let mut state = client.handler_set.lock(); let _ = state.entities_by_type_and_remote_id.remove(id); } } Subscription::Message { client, id } => { if let Some(client) = client.upgrade() { let mut state = client.handler_set.lock(); let _ = state.entity_types_by_message_type.remove(id); let _ = state.message_handlers.remove(id); } } } } } pub struct PendingEntitySubscription { client: Arc, remote_id: u64, _entity_type: PhantomData, consumed: bool, } impl PendingEntitySubscription { pub fn set_entity(mut self, entity: &Entity, cx: &AsyncApp) -> Subscription { self.consumed = true; let mut handlers = self.client.handler_set.lock(); let id = (TypeId::of::(), self.remote_id); let Some(EntityMessageSubscriber::Pending(messages)) = handlers.entities_by_type_and_remote_id.remove(&id) else { unreachable!() }; handlers.entities_by_type_and_remote_id.insert( id, EntityMessageSubscriber::Entity { handle: entity.downgrade().into(), }, ); drop(handlers); for message in messages { let client_id = self.client.id(); let type_name = message.payload_type_name(); let sender_id = message.original_sender_id(); log::debug!( "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}", client_id, sender_id, type_name ); self.client.handle_message(message, cx); } Subscription::Entity { client: Arc::downgrade(&self.client), id, } } } impl Drop for PendingEntitySubscription { fn drop(&mut self) { if !self.consumed { let mut state = self.client.handler_set.lock(); if let Some(EntityMessageSubscriber::Pending(messages)) = state .entities_by_type_and_remote_id .remove(&(TypeId::of::(), self.remote_id)) { for message in messages { log::info!("unhandled message {}", message.payload_type_name()); } } } } } #[derive(Copy, Clone)] pub struct TelemetrySettings { pub diagnostics: bool, pub metrics: bool, } /// Control what info is collected by Zed. #[derive(Default, Clone, Serialize, Deserialize, JsonSchema)] pub struct TelemetrySettingsContent { /// Send debug info like crash reports. /// /// Default: true pub diagnostics: Option, /// Send anonymized usage data like what languages you're using Zed with. /// /// Default: true pub metrics: Option, } impl settings::Settings for TelemetrySettings { const KEY: Option<&'static str> = Some("telemetry"); type FileContent = TelemetrySettingsContent; fn load(sources: SettingsSources, _: &mut App) -> Result { Ok(Self { diagnostics: sources .user .as_ref() .or(sources.server.as_ref()) .and_then(|v| v.diagnostics) .unwrap_or( sources .default .diagnostics .ok_or_else(Self::missing_default)?, ), metrics: sources .user .as_ref() .or(sources.server.as_ref()) .and_then(|v| v.metrics) .unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?), }) } fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) { vscode.enum_setting("telemetry.telemetryLevel", &mut current.metrics, |s| { Some(s == "all") }); vscode.enum_setting("telemetry.telemetryLevel", &mut current.diagnostics, |s| { Some(matches!(s, "all" | "error" | "crash")) }); // we could translate telemetry.telemetryLevel, but just because users didn't want // to send microsoft telemetry doesn't mean they don't want to send it to zed. their // all/error/crash/off correspond to combinations of our "diagnostics" and "metrics". } } impl Client { pub fn new( clock: Arc, http: Arc, cx: &mut App, ) -> Arc { Arc::new(Self { id: AtomicU64::new(0), peer: Peer::new(0), telemetry: Telemetry::new(clock, http.clone(), cx), http, credentials_provider: ClientCredentialsProvider::new(cx), state: Default::default(), handler_set: Default::default(), #[cfg(any(test, feature = "test-support"))] authenticate: Default::default(), #[cfg(any(test, feature = "test-support"))] establish_connection: Default::default(), #[cfg(any(test, feature = "test-support"))] rpc_url: RwLock::default(), }) } pub fn production(cx: &mut App) -> Arc { let clock = Arc::new(clock::RealSystemClock); let http = Arc::new(HttpClientWithUrl::new_uri( cx.http_client(), &ClientSettings::get_global(cx).server_url, cx.http_client().proxy().cloned(), )); Self::new(clock, http, cx) } pub fn id(&self) -> u64 { self.id.load(Ordering::SeqCst) } pub fn http_client(&self) -> Arc { self.http.clone() } pub fn set_id(&self, id: u64) -> &Self { self.id.store(id, Ordering::SeqCst); self } #[cfg(any(test, feature = "test-support"))] pub fn teardown(&self) { let mut state = self.state.write(); state._reconnect_task.take(); self.handler_set.lock().clear(); self.peer.teardown(); } #[cfg(any(test, feature = "test-support"))] pub fn override_authenticate(&self, authenticate: F) -> &Self where F: 'static + Send + Sync + Fn(&AsyncApp) -> Task>, { *self.authenticate.write() = Some(Box::new(authenticate)); self } #[cfg(any(test, feature = "test-support"))] pub fn override_establish_connection(&self, connect: F) -> &Self where F: 'static + Send + Sync + Fn(&Credentials, &AsyncApp) -> Task>, { *self.establish_connection.write() = Some(Box::new(connect)); self } #[cfg(any(test, feature = "test-support"))] pub fn override_rpc_url(&self, url: Url) -> &Self { *self.rpc_url.write() = Some(url); self } pub fn global(cx: &App) -> Arc { cx.global::().0.clone() } pub fn set_global(client: Arc, cx: &mut App) { cx.set_global(GlobalClient(client)) } pub fn user_id(&self) -> Option { self.state .read() .credentials .as_ref() .map(|credentials| credentials.user_id) } pub fn peer_id(&self) -> Option { if let Status::Connected { peer_id, .. } = &*self.status().borrow() { Some(*peer_id) } else { None } } pub fn status(&self) -> watch::Receiver { self.state.read().status.1.clone() } fn set_status(self: &Arc, status: Status, cx: &AsyncApp) { log::info!("set status on client {}: {:?}", self.id(), status); let mut state = self.state.write(); *state.status.0.borrow_mut() = status; match status { Status::Connected { .. } => { state._reconnect_task = None; } Status::ConnectionLost => { let this = self.clone(); state._reconnect_task = Some(cx.spawn(async move |cx| { #[cfg(any(test, feature = "test-support"))] let mut rng = StdRng::seed_from_u64(0); #[cfg(not(any(test, feature = "test-support")))] let mut rng = StdRng::from_entropy(); let mut delay = INITIAL_RECONNECTION_DELAY; while let Err(error) = this.authenticate_and_connect(true, &cx).await { log::error!("failed to connect {}", error); if matches!(*this.status().borrow(), Status::ConnectionError) { this.set_status( Status::ReconnectionError { next_reconnection: Instant::now() + delay, }, &cx, ); cx.background_executor().timer(delay).await; delay = delay .mul_f32(rng.gen_range(0.5..=2.5)) .max(INITIAL_RECONNECTION_DELAY) .min(MAX_RECONNECTION_DELAY); } else { break; } } })); } Status::SignedOut | Status::UpgradeRequired => { self.telemetry.set_authenticated_user_info(None, false); state._reconnect_task.take(); } _ => {} } } pub fn subscribe_to_entity( self: &Arc, remote_id: u64, ) -> Result> where T: 'static, { let id = (TypeId::of::(), remote_id); let mut state = self.handler_set.lock(); if state.entities_by_type_and_remote_id.contains_key(&id) { return Err(anyhow!("already subscribed to entity")); } state .entities_by_type_and_remote_id .insert(id, EntityMessageSubscriber::Pending(Default::default())); Ok(PendingEntitySubscription { client: self.clone(), remote_id, consumed: false, _entity_type: PhantomData, }) } #[track_caller] pub fn add_message_handler( self: &Arc, entity: WeakEntity, handler: H, ) -> Subscription where M: EnvelopedMessage, E: 'static, H: 'static + Sync + Fn(Entity, TypedEnvelope, AsyncApp) -> F + Send + Sync, F: 'static + Future>, { self.add_message_handler_impl(entity, move |entity, message, _, cx| { handler(entity, message, cx) }) } fn add_message_handler_impl( self: &Arc, entity: WeakEntity, handler: H, ) -> Subscription where M: EnvelopedMessage, E: 'static, H: 'static + Sync + Fn(Entity, TypedEnvelope, AnyProtoClient, AsyncApp) -> F + Send + Sync, F: 'static + Future>, { let message_type_id = TypeId::of::(); let mut state = self.handler_set.lock(); state .entities_by_message_type .insert(message_type_id, entity.into()); let prev_handler = state.message_handlers.insert( message_type_id, Arc::new(move |subscriber, envelope, client, cx| { let subscriber = subscriber.downcast::().unwrap(); let envelope = envelope.into_any().downcast::>().unwrap(); handler(subscriber, *envelope, client.clone(), cx).boxed_local() }), ); if prev_handler.is_some() { let location = std::panic::Location::caller(); panic!( "{}:{} registered handler for the same message {} twice", location.file(), location.line(), std::any::type_name::() ); } Subscription::Message { client: Arc::downgrade(self), id: message_type_id, } } pub fn add_request_handler( self: &Arc, entity: WeakEntity, handler: H, ) -> Subscription where M: RequestMessage, E: 'static, H: 'static + Sync + Fn(Entity, TypedEnvelope, AsyncApp) -> F + Send + Sync, F: 'static + Future>, { self.add_message_handler_impl(entity, move |handle, envelope, this, cx| { Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this) }) } async fn respond_to_request>>( receipt: Receipt, response: F, client: AnyProtoClient, ) -> Result<()> { match response.await { Ok(response) => { client.send_response(receipt.message_id, response)?; Ok(()) } Err(error) => { client.send_response(receipt.message_id, error.to_proto())?; Err(error) } } } pub async fn has_credentials(&self, cx: &AsyncApp) -> bool { self.credentials_provider .read_credentials(cx) .await .is_some() } #[async_recursion(?Send)] pub async fn authenticate_and_connect( self: &Arc, try_provider: bool, cx: &AsyncApp, ) -> anyhow::Result<()> { let was_disconnected = match *self.status().borrow() { Status::SignedOut => true, Status::ConnectionError | Status::ConnectionLost | Status::Authenticating { .. } | Status::Reauthenticating { .. } | Status::ReconnectionError { .. } => false, Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => { return Ok(()); } Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?, }; if was_disconnected { self.set_status(Status::Authenticating, cx); } else { self.set_status(Status::Reauthenticating, cx) } let mut read_from_provider = false; let mut credentials = self.state.read().credentials.clone(); if credentials.is_none() && try_provider { credentials = self.credentials_provider.read_credentials(cx).await; read_from_provider = credentials.is_some(); } if credentials.is_none() { let mut status_rx = self.status(); let _ = status_rx.next().await; futures::select_biased! { authenticate = self.authenticate(cx).fuse() => { match authenticate { Ok(creds) => credentials = Some(creds), Err(err) => { self.set_status(Status::ConnectionError, cx); return Err(err); } } } _ = status_rx.next().fuse() => { return Err(anyhow!("authentication canceled")); } } } let credentials = credentials.unwrap(); self.set_id(credentials.user_id); if was_disconnected { self.set_status(Status::Connecting, cx); } else { self.set_status(Status::Reconnecting, cx); } let mut timeout = futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT)); futures::select_biased! { connection = self.establish_connection(&credentials, cx).fuse() => { match connection { Ok(conn) => { self.state.write().credentials = Some(credentials.clone()); if !read_from_provider && IMPERSONATE_LOGIN.is_none() { self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err(); } futures::select_biased! { result = self.set_connection(conn, cx).fuse() => result, _ = timeout => { self.set_status(Status::ConnectionError, cx); Err(anyhow!("timed out waiting on hello message from server")) } } } Err(EstablishConnectionError::Unauthorized) => { self.state.write().credentials.take(); if read_from_provider { self.credentials_provider.delete_credentials(cx).await.log_err(); self.set_status(Status::SignedOut, cx); self.authenticate_and_connect(false, cx).await } else { self.set_status(Status::ConnectionError, cx); Err(EstablishConnectionError::Unauthorized)? } } Err(EstablishConnectionError::UpgradeRequired) => { self.set_status(Status::UpgradeRequired, cx); Err(EstablishConnectionError::UpgradeRequired)? } Err(error) => { self.set_status(Status::ConnectionError, cx); Err(error)? } } } _ = &mut timeout => { self.set_status(Status::ConnectionError, cx); Err(anyhow!("timed out trying to establish connection")) } } } async fn set_connection(self: &Arc, conn: Connection, cx: &AsyncApp) -> Result<()> { let executor = cx.background_executor(); log::debug!("add connection to peer"); let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, { let executor = executor.clone(); move |duration| executor.timer(duration) }); let handle_io = executor.spawn(handle_io); let peer_id = async { log::debug!("waiting for server hello"); let message = incoming .next() .await .ok_or_else(|| anyhow!("no hello message received"))?; log::debug!("got server hello"); let hello_message_type_name = message.payload_type_name().to_string(); let hello = message .into_any() .downcast::>() .map_err(|_| { anyhow!( "invalid hello message received: {:?}", hello_message_type_name ) })?; let peer_id = hello .payload .peer_id .ok_or_else(|| anyhow!("invalid peer id"))?; Ok(peer_id) }; let peer_id = match peer_id.await { Ok(peer_id) => peer_id, Err(error) => { self.peer.disconnect(connection_id); return Err(error); } }; log::debug!( "set status to connected (connection id: {:?}, peer id: {:?})", connection_id, peer_id ); self.set_status( Status::Connected { peer_id, connection_id, }, cx, ); cx.spawn({ let this = self.clone(); async move |cx| { while let Some(message) = incoming.next().await { this.handle_message(message, &cx); // Don't starve the main thread when receiving lots of messages at once. smol::future::yield_now().await; } } }) .detach(); cx.spawn({ let this = self.clone(); async move |cx| match handle_io.await { Ok(()) => { if *this.status().borrow() == (Status::Connected { connection_id, peer_id, }) { this.set_status(Status::SignedOut, &cx); } } Err(err) => { log::error!("connection error: {:?}", err); this.set_status(Status::ConnectionLost, &cx); } } }) .detach(); Ok(()) } fn authenticate(self: &Arc, cx: &AsyncApp) -> Task> { #[cfg(any(test, feature = "test-support"))] if let Some(callback) = self.authenticate.read().as_ref() { return callback(cx); } self.authenticate_with_browser(cx) } fn establish_connection( self: &Arc, credentials: &Credentials, cx: &AsyncApp, ) -> Task> { #[cfg(any(test, feature = "test-support"))] if let Some(callback) = self.establish_connection.read().as_ref() { return callback(credentials, cx); } self.establish_websocket_connection(credentials, cx) } fn rpc_url( &self, http: Arc, release_channel: Option, ) -> impl Future> + use<> { #[cfg(any(test, feature = "test-support"))] let url_override = self.rpc_url.read().clone(); async move { #[cfg(any(test, feature = "test-support"))] if let Some(url) = url_override { return Ok(url); } if let Some(url) = &*ZED_RPC_URL { return Url::parse(url).context("invalid rpc url"); } let mut url = http.build_url("/rpc"); if let Some(preview_param) = release_channel.and_then(|channel| channel.release_query_param()) { url += "?"; url += preview_param; } let response = http.get(&url, Default::default(), false).await?; let collab_url = if response.status().is_redirection() { response .headers() .get("Location") .ok_or_else(|| anyhow!("missing location header in /rpc response"))? .to_str() .map_err(EstablishConnectionError::other)? .to_string() } else { Err(anyhow!( "unexpected /rpc response status {}", response.status() ))? }; Url::parse(&collab_url).context("invalid rpc url") } } fn establish_websocket_connection( self: &Arc, credentials: &Credentials, cx: &AsyncApp, ) -> Task> { let release_channel = cx .update(|cx| ReleaseChannel::try_global(cx)) .ok() .flatten(); let app_version = cx .update(|cx| AppVersion::global(cx).to_string()) .ok() .unwrap_or_default(); let http = self.http.clone(); let proxy = http.proxy().cloned(); let credentials = credentials.clone(); let rpc_url = self.rpc_url(http, release_channel); let system_id = self.telemetry.system_id(); let metrics_id = self.telemetry.metrics_id(); cx.spawn(async move |cx| { use HttpOrHttps::*; #[derive(Debug)] enum HttpOrHttps { Http, Https, } let mut rpc_url = rpc_url.await?; let url_scheme = match rpc_url.scheme() { "https" => Https, "http" => Http, _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?, }; let rpc_host = rpc_url .host_str() .zip(rpc_url.port_or_known_default()) .ok_or_else(|| anyhow!("missing host in rpc url"))?; let stream = { let handle = cx.update(|cx| gpui_tokio::Tokio::handle(cx)).ok().unwrap(); let _guard = handle.enter(); connect_socks_proxy_stream(proxy.as_ref(), rpc_host).await? }; log::info!("connected to rpc endpoint {}", rpc_url); rpc_url .set_scheme(match url_scheme { Https => "wss", Http => "ws", }) .unwrap(); // We call `into_client_request` to let `tungstenite` construct the WebSocket request // for us from the RPC URL. // // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us. let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?; // We then modify the request to add our desired headers. let request_headers = request.headers_mut(); request_headers.insert( "Authorization", HeaderValue::from_str(&credentials.authorization_header())?, ); request_headers.insert( "x-zed-protocol-version", HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?, ); request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?); request_headers.insert( "x-zed-release-channel", HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?, ); if let Some(system_id) = system_id { request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?); } if let Some(metrics_id) = metrics_id { request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?); } let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config( request, stream, Some(Arc::new(http_client_tls::tls_config()).into()), None, ) .await?; Ok(Connection::new( stream .map_err(|error| anyhow!(error)) .sink_map_err(|error| anyhow!(error)), )) }) } pub fn authenticate_with_browser(self: &Arc, cx: &AsyncApp) -> Task> { let http = self.http.clone(); let this = self.clone(); cx.spawn(async move |cx| { let background = cx.background_executor().clone(); let (open_url_tx, open_url_rx) = oneshot::channel::(); cx.update(|cx| { cx.spawn(async move |cx| { let url = open_url_rx.await?; cx.update(|cx| cx.open_url(&url)) }) .detach_and_log_err(cx); }) .log_err(); let credentials = background .clone() .spawn(async move { // Generate a pair of asymmetric encryption keys. The public key will be used by the // zed server to encrypt the user's access token, so that it can'be intercepted by // any other app running on the user's device. let (public_key, private_key) = rpc::auth::keypair().expect("failed to generate keypair for auth"); let public_key_string = String::try_from(public_key) .expect("failed to serialize public key for auth"); if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) { eprintln!("authenticate as admin {login}, {token}"); return this .authenticate_as_admin(http, login.clone(), token.clone()) .await; } // Start an HTTP server to receive the redirect from Zed's sign-in page. let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port"); let port = server.server_addr().port(); // Open the Zed sign-in page in the user's browser, with query parameters that indicate // that the user is signing in from a Zed app running on the same device. let mut url = http.build_url(&format!( "/native_app_signin?native_app_port={}&native_app_public_key={}", port, public_key_string )); if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() { log::info!("impersonating user @{}", impersonate_login); write!(&mut url, "&impersonate={}", impersonate_login).unwrap(); } open_url_tx.send(url).log_err(); // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted // access token from the query params. // // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a // custom URL scheme instead of this local HTTP server. let (user_id, access_token) = background .spawn(async move { for _ in 0..100 { if let Some(req) = server.recv_timeout(Duration::from_secs(1))? { let path = req.url(); let mut user_id = None; let mut access_token = None; let url = Url::parse(&format!("http://example.com{}", path)) .context("failed to parse login notification url")?; for (key, value) in url.query_pairs() { if key == "access_token" { access_token = Some(value.to_string()); } else if key == "user_id" { user_id = Some(value.to_string()); } } let post_auth_url = http.build_url("/native_app_signin_succeeded"); req.respond( tiny_http::Response::empty(302).with_header( tiny_http::Header::from_bytes( &b"Location"[..], post_auth_url.as_bytes(), ) .unwrap(), ), ) .context("failed to respond to login http request")?; return Ok(( user_id .ok_or_else(|| anyhow!("missing user_id parameter"))?, access_token.ok_or_else(|| { anyhow!("missing access_token parameter") })?, )); } } Err(anyhow!("didn't receive login redirect")) }) .await?; let access_token = private_key .decrypt_string(&access_token) .context("failed to decrypt access token")?; Ok(Credentials { user_id: user_id.parse()?, access_token, }) }) .await?; cx.update(|cx| cx.activate(true))?; Ok(credentials) }) } async fn authenticate_as_admin( self: &Arc, http: Arc, login: String, mut api_token: String, ) -> Result { #[derive(Deserialize)] struct AuthenticatedUserResponse { user: User, } #[derive(Deserialize)] struct User { id: u64, } let github_user = { #[derive(Deserialize)] struct GithubUser { id: i32, login: String, created_at: DateTime, } let request = { let mut request_builder = Request::get(&format!("https://api.github.com/users/{login}")); if let Ok(github_token) = std::env::var("GITHUB_TOKEN") { request_builder = request_builder.header("Authorization", format!("Bearer {}", github_token)); } request_builder.body(AsyncBody::empty())? }; let mut response = http .send(request) .await .context("error fetching GitHub user")?; let mut body = Vec::new(); response .body_mut() .read_to_end(&mut body) .await .context("error reading GitHub user")?; if !response.status().is_success() { let text = String::from_utf8_lossy(body.as_slice()); bail!( "status error {}, response: {text:?}", response.status().as_u16() ); } serde_json::from_slice::(body.as_slice()).map_err(|err| { log::error!("Error deserializing: {:?}", err); log::error!( "GitHub API response text: {:?}", String::from_utf8_lossy(body.as_slice()) ); anyhow!("error deserializing GitHub user") })? }; let query_params = [ ("github_login", &github_user.login), ("github_user_id", &github_user.id.to_string()), ( "github_user_created_at", &github_user.created_at.to_rfc3339(), ), ]; // Use the collab server's admin API to retrieve the ID // of the impersonated user. let mut url = self.rpc_url(http.clone(), None).await?; url.set_path("/user"); url.set_query(Some( &query_params .iter() .map(|(key, value)| { format!( "{}={}", key, url::form_urlencoded::byte_serialize(value.as_bytes()).collect::() ) }) .collect::>() .join("&"), )); let request: http_client::Request = Request::get(url.as_str()) .header("Authorization", format!("token {api_token}")) .body("".into())?; let mut response = http.send(request).await?; let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; if !response.status().is_success() { Err(anyhow!( "admin user request failed {} - {}", response.status().as_u16(), body, ))?; } let response: AuthenticatedUserResponse = serde_json::from_str(&body)?; // Use the admin API token to authenticate as the impersonated user. api_token.insert_str(0, "ADMIN_TOKEN:"); Ok(Credentials { user_id: response.user.id, access_token: api_token, }) } pub async fn sign_out(self: &Arc, cx: &AsyncApp) { self.state.write().credentials = None; self.disconnect(cx); if self.has_credentials(cx).await { self.credentials_provider .delete_credentials(cx) .await .log_err(); } } pub fn disconnect(self: &Arc, cx: &AsyncApp) { self.peer.teardown(); self.set_status(Status::SignedOut, cx); } pub fn reconnect(self: &Arc, cx: &AsyncApp) { self.peer.teardown(); self.set_status(Status::ConnectionLost, cx); } fn connection_id(&self) -> Result { if let Status::Connected { connection_id, .. } = *self.status().borrow() { Ok(connection_id) } else { Err(anyhow!("not connected")) } } pub fn send(&self, message: T) -> Result<()> { log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME); self.peer.send(self.connection_id()?, message) } pub fn request( &self, request: T, ) -> impl Future> + use { self.request_envelope(request) .map_ok(|envelope| envelope.payload) } pub fn request_stream( &self, request: T, ) -> impl Future>>> { let client_id = self.id.load(Ordering::SeqCst); log::debug!( "rpc request start. client_id:{}. name:{}", client_id, T::NAME ); let response = self .connection_id() .map(|conn_id| self.peer.request_stream(conn_id, request)); async move { let response = response?.await; log::debug!( "rpc request finish. client_id:{}. name:{}", client_id, T::NAME ); response } } pub fn request_envelope( &self, request: T, ) -> impl Future>> + use { let client_id = self.id(); log::debug!( "rpc request start. client_id:{}. name:{}", client_id, T::NAME ); let response = self .connection_id() .map(|conn_id| self.peer.request_envelope(conn_id, request)); async move { let response = response?.await; log::debug!( "rpc request finish. client_id:{}. name:{}", client_id, T::NAME ); response } } pub fn request_dynamic( &self, envelope: proto::Envelope, request_type: &'static str, ) -> impl Future> + use<> { let client_id = self.id(); log::debug!( "rpc request start. client_id:{}. name:{}", client_id, request_type ); let response = self .connection_id() .map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type)); async move { let response = response?.await; log::debug!( "rpc request finish. client_id:{}. name:{}", client_id, request_type ); Ok(response?.0) } } fn handle_message(self: &Arc, message: Box, cx: &AsyncApp) { let sender_id = message.sender_id(); let request_id = message.message_id(); let type_name = message.payload_type_name(); let original_sender_id = message.original_sender_id(); if let Some(future) = ProtoMessageHandlerSet::handle_message( &self.handler_set, message, self.clone().into(), cx.clone(), ) { let client_id = self.id(); log::debug!( "rpc message received. client_id:{}, sender_id:{:?}, type:{}", client_id, original_sender_id, type_name ); cx.spawn(async move |_| match future.await { Ok(()) => { log::debug!( "rpc message handled. client_id:{}, sender_id:{:?}, type:{}", client_id, original_sender_id, type_name ); } Err(error) => { log::error!( "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}", client_id, original_sender_id, type_name, error ); } }) .detach(); } else { log::info!("unhandled message {}", type_name); self.peer .respond_with_unhandled_message(sender_id.into(), request_id, type_name) .log_err(); } } pub fn telemetry(&self) -> &Arc { &self.telemetry } } impl ProtoClient for Client { fn request( &self, envelope: proto::Envelope, request_type: &'static str, ) -> BoxFuture<'static, Result> { self.request_dynamic(envelope, request_type).boxed() } fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> { log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type); let connection_id = self.connection_id()?; self.peer.send_dynamic(connection_id, envelope) } fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> { log::debug!( "rpc respond. client_id:{}, name:{}", self.id(), message_type ); let connection_id = self.connection_id()?; self.peer.send_dynamic(connection_id, envelope) } fn message_handler_set(&self) -> &parking_lot::Mutex { &self.handler_set } fn is_via_collab(&self) -> bool { true } } /// prefix for the zed:// url scheme pub const ZED_URL_SCHEME: &str = "zed"; /// Parses the given link into a Zed link. /// /// Returns a [`Some`] containing the unprefixed link if the link is a Zed link. /// Returns [`None`] otherwise. pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> { let server_url = &ClientSettings::get_global(cx).server_url; if let Some(stripped) = link .strip_prefix(server_url) .and_then(|result| result.strip_prefix('/')) { return Some(stripped); } if let Some(stripped) = link .strip_prefix(ZED_URL_SCHEME) .and_then(|result| result.strip_prefix("://")) { return Some(stripped); } None } #[cfg(test)] mod tests { use super::*; use crate::test::FakeServer; use clock::FakeSystemClock; use gpui::{AppContext as _, BackgroundExecutor, TestAppContext}; use http_client::FakeHttpClient; use parking_lot::Mutex; use proto::TypedEnvelope; use settings::SettingsStore; use std::future; #[gpui::test(iterations = 10)] async fn test_reconnection(cx: &mut TestAppContext) { init_test(cx); let user_id = 5; let client = cx.update(|cx| { Client::new( Arc::new(FakeSystemClock::new()), FakeHttpClient::with_404_response(), cx, ) }); let server = FakeServer::for_client(user_id, &client, cx).await; let mut status = client.status(); assert!(matches!( status.next().await, Some(Status::Connected { .. }) )); assert_eq!(server.auth_count(), 1); server.forbid_connections(); server.disconnect(); while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} server.allow_connections(); cx.executor().advance_clock(Duration::from_secs(10)); while !matches!(status.next().await, Some(Status::Connected { .. })) {} assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting server.forbid_connections(); server.disconnect(); while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} // Clear cached credentials after authentication fails server.roll_access_token(); server.allow_connections(); cx.executor().run_until_parked(); cx.executor().advance_clock(Duration::from_secs(10)); while !matches!(status.next().await, Some(Status::Connected { .. })) {} assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token } #[gpui::test(iterations = 10)] async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) { init_test(cx); let user_id = 5; let client = cx.update(|cx| { Client::new( Arc::new(FakeSystemClock::new()), FakeHttpClient::with_404_response(), cx, ) }); let mut status = client.status(); // Time out when client tries to connect. client.override_authenticate(move |cx| { cx.background_spawn(async move { Ok(Credentials { user_id, access_token: "token".into(), }) }) }); client.override_establish_connection(|_, cx| { cx.background_spawn(async move { future::pending::<()>().await; unreachable!() }) }); let auth_and_connect = cx.spawn({ let client = client.clone(); |cx| async move { client.authenticate_and_connect(false, &cx).await } }); executor.run_until_parked(); assert!(matches!(status.next().await, Some(Status::Connecting))); executor.advance_clock(CONNECTION_TIMEOUT); assert!(matches!( status.next().await, Some(Status::ConnectionError { .. }) )); auth_and_connect.await.unwrap_err(); // Allow the connection to be established. let server = FakeServer::for_client(user_id, &client, cx).await; assert!(matches!( status.next().await, Some(Status::Connected { .. }) )); // Disconnect client. server.forbid_connections(); server.disconnect(); while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} // Time out when re-establishing the connection. server.allow_connections(); client.override_establish_connection(|_, cx| { cx.background_spawn(async move { future::pending::<()>().await; unreachable!() }) }); executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY); assert!(matches!( status.next().await, Some(Status::Reconnecting { .. }) )); executor.advance_clock(CONNECTION_TIMEOUT); assert!(matches!( status.next().await, Some(Status::ReconnectionError { .. }) )); } #[gpui::test(iterations = 10)] async fn test_authenticating_more_than_once( cx: &mut TestAppContext, executor: BackgroundExecutor, ) { init_test(cx); let auth_count = Arc::new(Mutex::new(0)); let dropped_auth_count = Arc::new(Mutex::new(0)); let client = cx.update(|cx| { Client::new( Arc::new(FakeSystemClock::new()), FakeHttpClient::with_404_response(), cx, ) }); client.override_authenticate({ let auth_count = auth_count.clone(); let dropped_auth_count = dropped_auth_count.clone(); move |cx| { let auth_count = auth_count.clone(); let dropped_auth_count = dropped_auth_count.clone(); cx.background_spawn(async move { *auth_count.lock() += 1; let _drop = util::defer(move || *dropped_auth_count.lock() += 1); future::pending::<()>().await; unreachable!() }) } }); let _authenticate = cx.spawn({ let client = client.clone(); move |cx| async move { client.authenticate_and_connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 1); assert_eq!(*dropped_auth_count.lock(), 0); let _authenticate = cx.spawn({ let client = client.clone(); |cx| async move { client.authenticate_and_connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 2); assert_eq!(*dropped_auth_count.lock(), 1); } #[gpui::test] async fn test_subscribing_to_entity(cx: &mut TestAppContext) { init_test(cx); let user_id = 5; let client = cx.update(|cx| { Client::new( Arc::new(FakeSystemClock::new()), FakeHttpClient::with_404_response(), cx, ) }); let server = FakeServer::for_client(user_id, &client, cx).await; let (done_tx1, done_rx1) = smol::channel::unbounded(); let (done_tx2, done_rx2) = smol::channel::unbounded(); AnyProtoClient::from(client.clone()).add_entity_message_handler( move |entity: Entity, _: TypedEnvelope, mut cx| { match entity.update(&mut cx, |entity, _| entity.id).unwrap() { 1 => done_tx1.try_send(()).unwrap(), 2 => done_tx2.try_send(()).unwrap(), _ => unreachable!(), } async { Ok(()) } }, ); let entity1 = cx.new(|_| TestEntity { id: 1, subscription: None, }); let entity2 = cx.new(|_| TestEntity { id: 2, subscription: None, }); let entity3 = cx.new(|_| TestEntity { id: 3, subscription: None, }); let _subscription1 = client .subscribe_to_entity(1) .unwrap() .set_entity(&entity1, &mut cx.to_async()); let _subscription2 = client .subscribe_to_entity(2) .unwrap() .set_entity(&entity2, &mut cx.to_async()); // Ensure dropping a subscription for the same entity type still allows receiving of // messages for other entity IDs of the same type. let subscription3 = client .subscribe_to_entity(3) .unwrap() .set_entity(&entity3, &mut cx.to_async()); drop(subscription3); server.send(proto::JoinProject { project_id: 1 }); server.send(proto::JoinProject { project_id: 2 }); done_rx1.recv().await.unwrap(); done_rx2.recv().await.unwrap(); } #[gpui::test] async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) { init_test(cx); let user_id = 5; let client = cx.update(|cx| { Client::new( Arc::new(FakeSystemClock::new()), FakeHttpClient::with_404_response(), cx, ) }); let server = FakeServer::for_client(user_id, &client, cx).await; let entity = cx.new(|_| TestEntity::default()); let (done_tx1, _done_rx1) = smol::channel::unbounded(); let (done_tx2, done_rx2) = smol::channel::unbounded(); let subscription1 = client.add_message_handler( entity.downgrade(), move |_, _: TypedEnvelope, _| { done_tx1.try_send(()).unwrap(); async { Ok(()) } }, ); drop(subscription1); let _subscription2 = client.add_message_handler( entity.downgrade(), move |_, _: TypedEnvelope, _| { done_tx2.try_send(()).unwrap(); async { Ok(()) } }, ); server.send(proto::Ping {}); done_rx2.recv().await.unwrap(); } #[gpui::test] async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) { init_test(cx); let user_id = 5; let client = cx.update(|cx| { Client::new( Arc::new(FakeSystemClock::new()), FakeHttpClient::with_404_response(), cx, ) }); let server = FakeServer::for_client(user_id, &client, cx).await; let entity = cx.new(|_| TestEntity::default()); let (done_tx, done_rx) = smol::channel::unbounded(); let subscription = client.add_message_handler( entity.clone().downgrade(), move |entity: Entity, _: TypedEnvelope, mut cx| { entity .update(&mut cx, |entity, _| entity.subscription.take()) .unwrap(); done_tx.try_send(()).unwrap(); async { Ok(()) } }, ); entity.update(cx, |entity, _| { entity.subscription = Some(subscription); }); server.send(proto::Ping {}); done_rx.recv().await.unwrap(); } #[derive(Default)] struct TestEntity { id: usize, subscription: Option, } fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); init_settings(cx); }); } }