#[cfg(any(test, feature = "test-support"))]
pub mod test;
mod proxy;
pub mod telemetry;
pub mod user;
pub mod zed_urls;
use anyhow::{Context as _, Result, anyhow};
use async_tungstenite::tungstenite::{
client::IntoClientRequest,
error::Error as WebsocketError,
http::{HeaderValue, Request, StatusCode},
};
use clock::SystemClock;
use cloud_api_client::CloudApiClient;
use cloud_api_client::websocket_protocol::MessageToClient;
use credentials_provider::CredentialsProvider;
use feature_flags::FeatureFlagAppExt as _;
use futures::{
AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
channel::oneshot, future::BoxFuture,
};
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
use http_client::{HttpClient, HttpClientWithUrl, http};
use parking_lot::RwLock;
use postage::watch;
use proxy::connect_proxy_stream;
use rand::prelude::*;
use release_channel::{AppVersion, ReleaseChannel};
use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
use std::{
any::TypeId,
convert::TryFrom,
fmt::Write as _,
future::Future,
marker::PhantomData,
path::PathBuf,
sync::{
Arc, LazyLock, Weak,
atomic::{AtomicU64, Ordering},
},
time::{Duration, Instant},
};
use std::{cmp, pin::Pin};
use telemetry::Telemetry;
use thiserror::Error;
use tokio::net::TcpStream;
use url::Url;
use util::{ConnectionResult, ResultExt};
pub use rpc::*;
pub use telemetry_events::Event;
pub use user::*;
/// Server URL override read once from `ZED_SERVER_URL`; when set it replaces
/// `ClientSettings::server_url`.
static ZED_SERVER_URL: LazyLock<Option<String>> =
    LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
/// Collab RPC URL override read once from `ZED_RPC_URL`; when set, `Client::rpc_url`
/// uses it instead of asking the server's `/rpc` endpoint.
static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
/// Login to impersonate, read once from `ZED_IMPERSONATE`; an empty value counts as unset.
pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> =
    LazyLock::new(|| non_empty_env_var("ZED_IMPERSONATE"));
/// Admin API token, read once from `ZED_ADMIN_API_TOKEN`; an empty value counts as unset.
/// Together with `IMPERSONATE_LOGIN` it enables admin authentication.
pub static ADMIN_API_TOKEN: LazyLock<Option<String>> =
    LazyLock::new(|| non_empty_env_var("ZED_ADMIN_API_TOKEN"));
/// Path read once from `ZED_APP_PATH`, if set.
pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
    LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
/// `true` when `ZED_ALWAYS_ACTIVE` is set to a non-empty value.
pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
    LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").is_ok_and(|e| !e.is_empty()));

/// Reads the environment variable `name`, treating unset, non-UTF-8, and empty values
/// alike as `None`.
fn non_empty_env_var(name: &str) -> Option<String> {
    std::env::var(name).ok().filter(|value| !value.is_empty())
}
/// Delay before the first reconnection attempt; it doubles after each failed attempt.
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
/// Ceiling for the doubling reconnection delay.
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_millis(30_000);
/// Time allowed for a Collab connection attempt (including the server hello) before it
/// is reported as a timeout.
pub const CONNECTION_TIMEOUT: Duration = Duration::from_millis(20_000);
// Actions declared under `client`; `init` installs their handlers.
actions!(
    client,
    [
        /// Signs in to Zed account.
        SignIn,
        /// Signs out of Zed account.
        SignOut,
        /// Reconnects to the collaboration server.
        Reconnect
    ]
);
// Settings-file representation of `ClientSettings`.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct ClientSettingsContent {
    server_url: Option,
}

/// Resolved client settings.
#[derive(Deserialize)]
pub struct ClientSettings {
    /// Base URL of the Zed server. Credentials are stored under this URL, and the
    /// `ZED_SERVER_URL` environment variable overrides it.
    pub server_url: String,
}

impl Settings for ClientSettings {
    const KEY: Option<&'static str> = None;

    type FileContent = ClientSettingsContent;

    /// Merges all settings sources, then applies the `ZED_SERVER_URL` override if it is set.
    fn load(sources: SettingsSources, _: &mut App) -> Result {
        let mut result = sources.json_merge::()?;
        if let Some(server_url) = &*ZED_SERVER_URL {
            // `clone_from` reuses the existing allocation of `server_url`.
            result.server_url.clone_from(server_url)
        }
        Ok(result)
    }

    /// Nothing is imported from VS Code settings.
    fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}
// Settings-file representation of `ProxySettings`.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProxySettingsContent {
    proxy: Option,
}

/// Resolved proxy setting.
#[derive(Deserialize, Default)]
pub struct ProxySettings {
    /// Configured proxy, if any.
    pub proxy: Option,
}
impl Settings for ProxySettings {
    const KEY: Option<&'static str> = None;

    type FileContent = ProxySettingsContent;

    /// Resolves `proxy` with per-field precedence: user settings, then server settings,
    /// then the bundled defaults.
    ///
    /// Each layer is consulted only when the previous one leaves `proxy` unset. A user
    /// settings file that exists but does not mention `proxy` therefore no longer hides
    /// a proxy supplied by the server layer.
    fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
        let proxy = sources
            .user
            .and_then(|user| user.proxy.clone())
            .or_else(|| sources.server.and_then(|server| server.proxy.clone()))
            // Lazily clone the default only when no other layer provided a proxy.
            .or_else(|| sources.default.proxy.clone());
        Ok(Self { proxy })
    }

    /// Imports VS Code's `http.proxy` setting into `proxy`.
    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
        vscode.string_setting("http.proxy", &mut current.proxy);
    }
}
pub fn init_settings(cx: &mut App) {
TelemetrySettings::register(cx);
ClientSettings::register(cx);
ProxySettings::register(cx);
}
/// Installs handlers for the `SignIn`, `SignOut`, and `Reconnect` actions.
///
/// The handlers capture only a weak reference to `client`, so they do nothing once
/// the client has been dropped.
pub fn init(client: &Arc, cx: &mut App) {
    let client = Arc::downgrade(client);
    cx.on_action({
        let client = client.clone();
        move |_: &SignIn, cx| {
            if let Some(client) = client.upgrade() {
                // `true`: also try the stored credentials provider. Errors are logged.
                cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, cx).await)
                    .detach_and_log_err(cx);
            }
        }
    });
    cx.on_action({
        let client = client.clone();
        move |_: &SignOut, cx| {
            if let Some(client) = client.upgrade() {
                cx.spawn(async move |cx| {
                    client.sign_out(cx).await;
                })
                .detach();
            }
        }
    });
    cx.on_action({
        // Last use of the weak handle, so it is moved rather than cloned.
        let client = client;
        move |_: &Reconnect, cx| {
            if let Some(client) = client.upgrade() {
                cx.spawn(async move |cx| {
                    client.reconnect(cx);
                })
                .detach();
            }
        }
    });
}
/// Boxed handler for `MessageToClient` messages received from Cloud (presumably
/// dispatched by `handle_message_to_client`, which is defined outside this view).
pub type MessageToClientHandler = Box;

/// Newtype that stores the shared `Client` as a gpui global (see `Client::global`).
struct GlobalClient(Arc);

impl Global for GlobalClient {}
/// Client for Zed's servers: authentication, the Collab RPC connection, the Cloud API
/// client, and the registry of handlers for incoming messages.
pub struct Client {
    /// Client id; `sign_in` sets it to the signed-in user's id.
    id: AtomicU64,
    /// RPC peer that owns the Collab connection registered by `set_connection`.
    peer: Arc,
    /// HTTP client (in `production`, rooted at `ClientSettings::server_url`).
    http: Arc,
    /// Cloud API client; also used to validate credentials.
    cloud_client: Arc,
    /// Telemetry sink; receives user info updates from `set_status`.
    telemetry: Arc,
    /// Reads/writes/deletes persisted credentials, keyed by the server URL.
    credentials_provider: ClientCredentialsProvider,
    /// Credentials, status watch channel, and pending reconnection task.
    state: RwLock,
    /// Registered message handlers and entity subscriptions.
    handler_set: parking_lot::Mutex,
    /// Handlers for messages received over the Cloud WebSocket.
    message_to_client_handlers: parking_lot::Mutex>,
    /// Test-only replacement for the interactive authentication flow.
    // The boxed callback type is spelled out inline, hence the lint allowance.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    authenticate:
        RwLock Task>>>>,
    /// Test-only replacement for opening the Collab connection.
    // The boxed callback type is spelled out inline, hence the lint allowance.
    #[allow(clippy::type_complexity)]
    #[cfg(any(test, feature = "test-support"))]
    establish_connection: RwLock<
        Option<
            Box<
                dyn 'static
                    + Send
                    + Sync
                    + Fn(
                        &Credentials,
                        &AsyncApp,
                    ) -> Task>,
            >,
        >,
    >,
    /// Test-only override returned by `rpc_url` when set.
    #[cfg(any(test, feature = "test-support"))]
    rpc_url: RwLock>,
}
/// Reasons a Collab connection attempt can fail.
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
    /// The server answered the handshake with HTTP 426 (Upgrade Required).
    #[error("upgrade required")]
    UpgradeRequired,
    /// The server answered the handshake with HTTP 401 (Unauthorized).
    #[error("unauthorized")]
    Unauthorized,
    /// Any other failure.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
    /// A request header value could not be encoded.
    #[error("{0}")]
    InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
    /// I/O failure.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Wraps an `http::Error`.
    // NOTE(review): despite its name this variant holds `http::Error`, not a WebSocket
    // error; tungstenite errors are converted by the `From<WebsocketError>` impl instead.
    #[error("{0}")]
    Websocket(#[from] async_tungstenite::tungstenite::http::Error),
}
/// Maps WebSocket handshake failures to the dedicated variants: an HTTP 401 response
/// becomes `Unauthorized`, an HTTP 426 response becomes `UpgradeRequired`, and every
/// other error is wrapped in `Other`.
impl From for EstablishConnectionError {
    fn from(error: WebsocketError) -> Self {
        if let WebsocketError::Http(response) = &error {
            match response.status() {
                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
                _ => {}
            }
        }
        EstablishConnectionError::Other(error.into())
    }
}
impl EstablishConnectionError {
    /// Wraps `error` (converted into an `anyhow::Error`) in the `Other` variant.
    pub fn other(error: impl Into + Send + Sync) -> Self {
        Self::Other(error.into())
    }
}
/// Connection and authentication state of a [`Client`], published through its status
/// watch channel.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum Status {
    /// No user is signed in.
    SignedOut,
    /// The server answered with HTTP 426; `connect` refuses to retry from this state.
    UpgradeRequired,
    /// `sign_in` started from a signed-out state.
    Authenticating,
    /// `sign_in` obtained valid credentials.
    Authenticated,
    /// Authentication or credential validation failed.
    AuthenticationError,
    /// Connecting to Collab, starting from a disconnected state.
    Connecting,
    /// A Collab connection attempt failed or timed out.
    ConnectionError,
    /// Connected to Collab; the server hello has been received.
    Connected {
        peer_id: PeerId,
        connection_id: ConnectionId,
    },
    /// The Collab connection's I/O failed; a reconnection loop is started.
    ConnectionLost,
    /// `sign_in` started while not signed out.
    Reauthenticating,
    /// Connecting to Collab again, starting from a non-disconnected state.
    Reconnecting,
    /// A reconnection attempt failed; the next one is scheduled near `next_reconnection`.
    ReconnectionError {
        next_reconnection: Instant,
    },
}

impl Status {
    /// Returns `true` only for [`Status::Connected`].
    pub fn is_connected(&self) -> bool {
        matches!(*self, Status::Connected { .. })
    }

    /// Returns `true` while authentication or a connection attempt is in progress.
    pub fn is_signing_in(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match *self {
            Status::Authenticating
            | Status::Reauthenticating
            | Status::Connecting
            | Status::Reconnecting => true,
            Status::SignedOut
            | Status::UpgradeRequired
            | Status::Authenticated
            | Status::AuthenticationError
            | Status::ConnectionError
            | Status::Connected { .. }
            | Status::ConnectionLost
            | Status::ReconnectionError { .. } => false,
        }
    }

    /// Returns `true` when no user is signed in, including after an upgrade was demanded.
    pub fn is_signed_out(&self) -> bool {
        matches!(*self, Status::UpgradeRequired | Status::SignedOut)
    }
}
/// Mutable client state guarded by `Client::state`.
struct ClientState {
    /// Credentials recorded by the last successful `sign_in`, if any.
    credentials: Option,
    /// Sender/receiver pair of the status watch channel; `Client::status` hands out
    /// clones of the receiver.
    status: (watch::Sender, watch::Receiver),
    /// Reconnection loop spawned on `ConnectionLost`; clearing or replacing it drops the task.
    _reconnect_task: Option>,
}
/// A user id together with the access token used to authenticate as that user.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}

impl Credentials {
    /// Builds the `Authorization` header value: the user id, a space, then the token.
    pub fn authorization_header(&self) -> String {
        let Self {
            user_id,
            access_token,
        } = self;
        format!("{} {}", user_id, access_token)
    }
}
/// Persists the user's Zed credentials through the global `CredentialsProvider`,
/// keyed by the configured server URL.
pub struct ClientCredentialsProvider {
    provider: Arc,
}

impl ClientCredentialsProvider {
    /// Captures the global credentials provider from `cx`.
    pub fn new(cx: &App) -> Self {
        Self {
            provider: ::global(cx),
        }
    }

    /// Returns the current `ClientSettings::server_url`, the key credentials are stored under.
    fn server_url(&self, cx: &AsyncApp) -> Result {
        cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
    }

    /// Reads the credentials from the provider.
    ///
    /// Yields `None` while impersonating (`ZED_IMPERSONATE` set), when the server URL
    /// cannot be read, when the provider fails (the error is logged) or has nothing
    /// stored, or when the stored user id / token cannot be parsed.
    fn read_credentials<'a>(
        &'a self,
        cx: &'a AsyncApp,
    ) -> Pin> + 'a>> {
        async move {
            if IMPERSONATE_LOGIN.is_some() {
                return None;
            }
            let server_url = self.server_url(cx).ok()?;
            let (user_id, access_token) = self
                .provider
                .read_credentials(&server_url, cx)
                .await
                .log_err()
                .flatten()?;
            Some(Credentials {
                user_id: user_id.parse().ok()?,
                access_token: String::from_utf8(access_token).ok()?,
            })
        }
        .boxed_local()
    }

    /// Writes the credentials to the provider.
    ///
    /// The user id is stored as its decimal string and the token as UTF-8 bytes, under
    /// the current server URL.
    fn write_credentials<'a>(
        &'a self,
        user_id: u64,
        access_token: String,
        cx: &'a AsyncApp,
    ) -> Pin> + 'a>> {
        async move {
            let server_url = self.server_url(cx)?;
            self.provider
                .write_credentials(
                    &server_url,
                    &user_id.to_string(),
                    access_token.as_bytes(),
                    cx,
                )
                .await
        }
        .boxed_local()
    }

    /// Deletes the credentials stored under the current server URL from the provider.
    fn delete_credentials<'a>(
        &'a self,
        cx: &'a AsyncApp,
    ) -> Pin> + 'a>> {
        async move {
            let server_url = self.server_url(cx)?;
            self.provider.delete_credentials(&server_url, cx).await
        }
        .boxed_local()
    }
}
impl Default for ClientState {
    /// A signed-out state: no credentials, no reconnection task, and a status channel
    /// that starts at [`Status::SignedOut`].
    fn default() -> Self {
        let status = watch::channel_with(Status::SignedOut);
        Self {
            status,
            credentials: None,
            _reconnect_task: None,
        }
    }
}
/// Handle for a registration on a [`Client`]; dropping it removes the registration
/// (see `impl Drop for Subscription`).
pub enum Subscription {
    /// Entity subscription, keyed by `(TypeId, remote id)`.
    Entity {
        client: Weak,
        id: (TypeId, u64),
    },
    /// Message-handler registration, keyed by the message type's `TypeId`.
    Message {
        client: Weak,
        id: TypeId,
    },
}
impl Drop for Subscription {
    /// Removes this registration from the client's handler set; does nothing if the
    /// client has already been dropped.
    fn drop(&mut self) {
        let weak_client = match self {
            Subscription::Entity { client, .. } | Subscription::Message { client, .. } => client,
        };
        let Some(client) = weak_client.upgrade() else {
            return;
        };
        let mut handlers = client.handler_set.lock();
        match self {
            Subscription::Entity { id, .. } => {
                handlers.entities_by_type_and_remote_id.remove(id);
            }
            Subscription::Message { id, .. } => {
                handlers.entity_types_by_message_type.remove(id);
                handlers.message_handlers.remove(id);
            }
        }
    }
}
/// Entity subscription reserved by `Client::subscribe_to_entity` but not yet bound to
/// an entity; complete it with `set_entity`.
pub struct PendingEntitySubscription {
    client: Arc,
    remote_id: u64,
    _entity_type: PhantomData,
    /// Set by `set_entity`; if still `false` on drop, the pending entry is removed.
    consumed: bool,
}
impl PendingEntitySubscription {
    /// Binds the pending subscription to `entity`.
    ///
    /// Replaces the pending entry with a weak handle to `entity`, then dispatches every
    /// message that was queued in the pending entry. The returned [`Subscription`]
    /// removes the entry when dropped.
    ///
    /// # Panics
    ///
    /// Panics (`unreachable!`) if the entry is missing or is no longer pending.
    pub fn set_entity(mut self, entity: &Entity, cx: &AsyncApp) -> Subscription {
        // Mark as consumed so `Drop` leaves the (now bound) entry alone.
        self.consumed = true;
        let mut handlers = self.client.handler_set.lock();
        let id = (TypeId::of::(), self.remote_id);
        let Some(EntityMessageSubscriber::Pending(messages)) =
            handlers.entities_by_type_and_remote_id.remove(&id)
        else {
            unreachable!()
        };
        handlers.entities_by_type_and_remote_id.insert(
            id,
            EntityMessageSubscriber::Entity {
                handle: entity.downgrade().into(),
            },
        );
        // Release the lock before dispatching; `handle_message` presumably needs it again.
        drop(handlers);
        for message in messages {
            let client_id = self.client.id();
            let type_name = message.payload_type_name();
            let sender_id = message.original_sender_id();
            log::debug!(
                "handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
                client_id,
                sender_id,
                type_name
            );
            self.client.handle_message(message, cx);
        }
        Subscription::Entity {
            client: Arc::downgrade(&self.client),
            id,
        }
    }
}
impl Drop for PendingEntitySubscription {
    /// If `set_entity` was never called, removes the pending entry and logs (at info
    /// level) each queued message, which is then discarded unhandled.
    fn drop(&mut self) {
        if !self.consumed {
            let mut state = self.client.handler_set.lock();
            if let Some(EntityMessageSubscriber::Pending(messages)) = state
                .entities_by_type_and_remote_id
                .remove(&(TypeId::of::(), self.remote_id))
            {
                for message in messages {
                    log::info!("unhandled message {}", message.payload_type_name());
                }
            }
        }
    }
}
/// Resolved telemetry preferences.
#[derive(Copy, Clone, Deserialize, Debug)]
pub struct TelemetrySettings {
    /// Whether debug info such as crash reports may be sent.
    pub diagnostics: bool,
    /// Whether anonymized usage data may be sent.
    pub metrics: bool,
}
/// Control what info is collected by Zed.
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct TelemetrySettingsContent {
    /// Send debug info like crash reports.
    ///
    /// Default: true
    pub diagnostics: Option,
    /// Send anonymized usage data like what languages you're using Zed with.
    ///
    /// Default: true
    pub metrics: Option,
}
impl settings::Settings for TelemetrySettings {
    const KEY: Option<&'static str> = Some("telemetry");

    type FileContent = TelemetrySettingsContent;

    /// Merges all settings sources into the resolved telemetry settings.
    fn load(sources: SettingsSources, _: &mut App) -> Result {
        sources.json_merge()
    }

    /// Translates VS Code's `telemetry.telemetryLevel` into our two flags.
    ///
    /// VS Code's levels map onto combinations of `metrics` and `diagnostics`:
    /// `metrics` is enabled only for `"all"`, while `diagnostics` is enabled for `"all"`,
    /// `"error"`, and `"crash"`. Any other level (e.g. `"off"`) disables both.
    fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
        vscode.enum_setting("telemetry.telemetryLevel", &mut current.metrics, |s| {
            Some(s == "all")
        });
        vscode.enum_setting("telemetry.telemetryLevel", &mut current.diagnostics, |s| {
            Some(matches!(s, "all" | "error" | "crash"))
        });
    }
}
impl Client {
/// Creates a signed-out client with id 0 and no registered handlers.
///
/// `http` is shared with the telemetry sink and the Cloud API client.
pub fn new(
    clock: Arc,
    http: Arc,
    cx: &mut App,
) -> Arc {
    Arc::new(Self {
        id: AtomicU64::new(0),
        peer: Peer::new(0),
        // `http` is cloned for these two before it is moved into the struct below.
        telemetry: Telemetry::new(clock, http.clone(), cx),
        cloud_client: Arc::new(CloudApiClient::new(http.clone())),
        http,
        credentials_provider: ClientCredentialsProvider::new(cx),
        state: Default::default(),
        handler_set: Default::default(),
        message_to_client_handlers: parking_lot::Mutex::new(Vec::new()),
        #[cfg(any(test, feature = "test-support"))]
        authenticate: Default::default(),
        #[cfg(any(test, feature = "test-support"))]
        establish_connection: Default::default(),
        #[cfg(any(test, feature = "test-support"))]
        rpc_url: RwLock::default(),
    })
}
/// Creates a client that uses the real system clock and an HTTP client rooted at
/// `ClientSettings::server_url`. The HTTP client reuses the app's HTTP client and its proxy.
pub fn production(cx: &mut App) -> Arc {
    let clock = Arc::new(clock::RealSystemClock);
    let http = Arc::new(HttpClientWithUrl::new_url(
        cx.http_client(),
        &ClientSettings::get_global(cx).server_url,
        cx.http_client().proxy().cloned(),
    ));
    Self::new(clock, http, cx)
}
pub fn id(&self) -> u64 {
self.id.load(Ordering::SeqCst)
}
/// Returns a shared handle to the HTTP client.
pub fn http_client(&self) -> Arc {
    self.http.clone()
}

/// Returns a shared handle to the Cloud API client.
pub fn cloud_client(&self) -> Arc {
    self.cloud_client.clone()
}
pub fn set_id(&self, id: u64) -> &Self {
self.id.store(id, Ordering::SeqCst);
self
}
/// Test helper: drops any pending reconnection task, clears every registered handler,
/// and tears down the RPC peer, all while holding the state write lock.
#[cfg(any(test, feature = "test-support"))]
pub fn teardown(&self) {
    let mut state = self.state.write();
    state._reconnect_task = None;
    self.handler_set.lock().clear();
    self.peer.teardown();
}
/// Test helper: makes `authenticate` call `authenticate` instead of the browser flow.
#[cfg(any(test, feature = "test-support"))]
pub fn override_authenticate(&self, authenticate: F) -> &Self
where
    F: 'static + Send + Sync + Fn(&AsyncApp) -> Task>,
{
    *self.authenticate.write() = Some(Box::new(authenticate));
    self
}

/// Test helper: makes `establish_connection` call `connect` instead of opening a WebSocket.
#[cfg(any(test, feature = "test-support"))]
pub fn override_establish_connection(&self, connect: F) -> &Self
where
    F: 'static
        + Send
        + Sync
        + Fn(&Credentials, &AsyncApp) -> Task>,
{
    *self.establish_connection.write() = Some(Box::new(connect));
    self
}
/// Test helper: makes `rpc_url` resolve to `url`, bypassing the environment and the server.
#[cfg(any(test, feature = "test-support"))]
pub fn override_rpc_url(&self, url: Url) -> &Self {
    {
        let mut rpc_url_override = self.rpc_url.write();
        *rpc_url_override = Some(url);
    }
    self
}
/// Returns the client stored by `set_global`.
// NOTE(review): presumably panics if `set_global` was never called (gpui `cx.global`) — confirm.
pub fn global(cx: &App) -> Arc {
    cx.global::().0.clone()
}

/// Stores `client` as the gpui global returned by `global`.
pub fn set_global(client: Arc, cx: &mut App) {
    cx.set_global(GlobalClient(client))
}
/// Returns the user id from the recorded credentials, if any.
pub fn user_id(&self) -> Option {
    self.state
        .read()
        .credentials
        .as_ref()
        .map(|credentials| credentials.user_id)
}

/// Returns the peer id from the server hello while `Connected`, otherwise `None`.
pub fn peer_id(&self) -> Option {
    if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
        Some(*peer_id)
    } else {
        None
    }
}

/// Returns a new receiver for the status watch channel.
pub fn status(&self) -> watch::Receiver {
    self.state.read().status.1.clone()
}
/// Publishes `status` on the status watch channel and applies its side effects.
///
/// - `Connected`: drops any pending reconnection task.
/// - `ConnectionLost`: spawns a reconnection loop that retries `connect`. The delay
///   starts at `INITIAL_RECONNECTION_DELAY` and doubles up to `MAX_RECONNECTION_DELAY`,
///   plus random jitter below the current delay. The loop keeps going only while a
///   failed attempt leaves the status at `AuthenticationError` or `ConnectionError`.
/// - `SignedOut` / `UpgradeRequired`: calls
///   `telemetry.set_authenticated_user_info(None, false)` and drops any pending
///   reconnection task.
fn set_status(self: &Arc, status: Status, cx: &AsyncApp) {
    log::info!("set status on client {}: {:?}", self.id(), status);
    let mut state = self.state.write();
    *state.status.0.borrow_mut() = status;
    match status {
        Status::Connected { .. } => {
            state._reconnect_task = None;
        }
        Status::ConnectionLost => {
            let client = self.clone();
            state._reconnect_task = Some(cx.spawn(async move |cx| {
                // Fixed seed in test builds keeps the jitter deterministic.
                #[cfg(any(test, feature = "test-support"))]
                let mut rng = StdRng::seed_from_u64(0);
                #[cfg(not(any(test, feature = "test-support")))]
                let mut rng = StdRng::from_entropy();
                let mut delay = INITIAL_RECONNECTION_DELAY;
                loop {
                    match client.connect(true, cx).await {
                        ConnectionResult::Timeout => {
                            log::error!("client connect attempt timed out")
                        }
                        ConnectionResult::ConnectionReset => {
                            log::error!("client connect attempt reset")
                        }
                        ConnectionResult::Result(r) => {
                            if let Err(error) = r {
                                log::error!("failed to connect: {error}");
                            } else {
                                break;
                            }
                        }
                    }
                    // Retry only if the failed attempt left the client in an error state.
                    if matches!(
                        *client.status().borrow(),
                        Status::AuthenticationError | Status::ConnectionError
                    ) {
                        // NOTE: `next_reconnection` does not include the jitter added below.
                        client.set_status(
                            Status::ReconnectionError {
                                next_reconnection: Instant::now() + delay,
                            },
                            cx,
                        );
                        let jitter =
                            Duration::from_millis(rng.gen_range(0..delay.as_millis() as u64));
                        cx.background_executor().timer(delay + jitter).await;
                        delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
                    } else {
                        break;
                    }
                }
            }));
        }
        Status::SignedOut | Status::UpgradeRequired => {
            self.telemetry.set_authenticated_user_info(None, false);
            state._reconnect_task.take();
        }
        _ => {}
    }
}
/// Reserves the `(entity type, remote_id)` slot with an empty pending entry, so that
/// messages for that entity can be held until `PendingEntitySubscription::set_entity`.
///
/// Returns an error if a subscription for the same entity type and `remote_id` already exists.
pub fn subscribe_to_entity(
    self: &Arc,
    remote_id: u64,
) -> Result>
where
    T: 'static,
{
    let id = (TypeId::of::(), remote_id);
    let mut state = self.handler_set.lock();
    anyhow::ensure!(
        !state.entities_by_type_and_remote_id.contains_key(&id),
        "already subscribed to entity"
    );
    state
        .entities_by_type_and_remote_id
        .insert(id, EntityMessageSubscriber::Pending(Default::default()));
    Ok(PendingEntitySubscription {
        client: self.clone(),
        remote_id,
        consumed: false,
        _entity_type: PhantomData,
    })
}
#[track_caller]
pub fn add_message_handler(
self: &Arc,
entity: WeakEntity,
handler: H,
) -> Subscription
where
M: EnvelopedMessage,
E: 'static,
H: 'static + Sync + Fn(Entity, TypedEnvelope, AsyncApp) -> F + Send + Sync,
F: 'static + Future>,
{
self.add_message_handler_impl(entity, move |entity, message, _, cx| {
handler(entity, message, cx)
})
}
fn add_message_handler_impl(
self: &Arc,
entity: WeakEntity,
handler: H,
) -> Subscription
where
M: EnvelopedMessage,
E: 'static,
H: 'static
+ Sync
+ Fn(Entity, TypedEnvelope, AnyProtoClient, AsyncApp) -> F
+ Send
+ Sync,
F: 'static + Future>,
{
let message_type_id = TypeId::of::();
let mut state = self.handler_set.lock();
state
.entities_by_message_type
.insert(message_type_id, entity.into());
let prev_handler = state.message_handlers.insert(
message_type_id,
Arc::new(move |subscriber, envelope, client, cx| {
let subscriber = subscriber.downcast::().unwrap();
let envelope = envelope.into_any().downcast::>().unwrap();
handler(subscriber, *envelope, client, cx).boxed_local()
}),
);
if prev_handler.is_some() {
let location = std::panic::Location::caller();
panic!(
"{}:{} registered handler for the same message {} twice",
location.file(),
location.line(),
std::any::type_name::()
);
}
Subscription::Message {
client: Arc::downgrade(self),
id: message_type_id,
}
}
pub fn add_request_handler(
self: &Arc,
entity: WeakEntity,
handler: H,
) -> Subscription
where
M: RequestMessage,
E: 'static,
H: 'static + Sync + Fn(Entity, TypedEnvelope, AsyncApp) -> F + Send + Sync,
F: 'static + Future>,
{
self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
})
}
/// Awaits `response` and sends its outcome to the requester identified by `receipt`.
///
/// On success the payload is sent and `Ok(())` returned. On failure the error is sent
/// (via `to_proto()`) and then also returned. A failure to send is propagated with `?`.
async fn respond_to_request>>(
    receipt: Receipt,
    response: F,
    client: AnyProtoClient,
) -> Result<()> {
    match response.await {
        Ok(response) => {
            client.send_response(receipt.message_id, response)?;
            Ok(())
        }
        Err(error) => {
            client.send_response(receipt.message_id, error.to_proto())?;
            Err(error)
        }
    }
}
/// Returns `true` if the credentials provider yields stored credentials for the current
/// server URL. This is always `false` while `ZED_IMPERSONATE` is set.
pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
    let stored = self.credentials_provider.read_credentials(cx).await;
    stored.is_some()
}
/// Obtains valid credentials and records them on the client.
///
/// Sources are tried in this order:
/// 1. The credentials already held by the client.
/// 2. If `try_provider`, the credentials stored in the provider. These are deleted if
///    they fail validation.
/// 3. A fresh `authenticate` flow. Its result is written back to the provider unless
///    impersonating.
///
/// The interactive flow is abandoned with an "authentication canceled" error if the
/// status changes while it runs. Validation errors are returned immediately.
///
/// On success, this updates the client id, the Cloud API client's credentials, and the
/// recorded credentials, then sets the status to `Authenticated`.
pub async fn sign_in(
    self: &Arc,
    try_provider: bool,
    cx: &AsyncApp,
) -> Result {
    if self.status().borrow().is_signed_out() {
        self.set_status(Status::Authenticating, cx);
    } else {
        self.set_status(Status::Reauthenticating, cx);
    }
    let mut credentials = None;
    let old_credentials = self.state.read().credentials.clone();
    if let Some(old_credentials) = old_credentials
        && self.validate_credentials(&old_credentials, cx).await?
    {
        credentials = Some(old_credentials);
    }
    if credentials.is_none()
        && try_provider
        && let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await
    {
        if self.validate_credentials(&stored_credentials, cx).await? {
            credentials = Some(stored_credentials);
        } else {
            // Stale stored credentials: remove them (best effort) so they are not retried.
            self.credentials_provider
                .delete_credentials(cx)
                .await
                .log_err();
        }
    }
    if credentials.is_none() {
        let mut status_rx = self.status();
        // Presumably the receiver first yields the current status; consuming it means the
        // `select` below reacts only to a subsequent status change.
        let _ = status_rx.next().await;
        futures::select_biased! {
            authenticate = self.authenticate(cx).fuse() => {
                match authenticate {
                    Ok(creds) => {
                        if IMPERSONATE_LOGIN.is_none() {
                            self.credentials_provider
                                .write_credentials(creds.user_id, creds.access_token.clone(), cx)
                                .await
                                .log_err();
                        }
                        credentials = Some(creds);
                    },
                    Err(err) => {
                        self.set_status(Status::AuthenticationError, cx);
                        return Err(err);
                    }
                }
            }
            _ = status_rx.next().fuse() => {
                return Err(anyhow!("authentication canceled"));
            }
        }
    }
    // Every path above either set `credentials` or returned early.
    let credentials = credentials.unwrap();
    self.set_id(credentials.user_id);
    // NOTE(review): `as u32` silently truncates user ids above `u32::MAX`.
    self.cloud_client
        .set_credentials(credentials.user_id as u32, credentials.access_token.clone());
    self.state.write().credentials = Some(credentials.clone());
    self.set_status(Status::Authenticated, cx);
    Ok(credentials)
}
/// Asks Cloud whether `credentials` are valid and returns its answer.
///
/// If the request itself fails, the status is set to `AuthenticationError` and an
/// error is returned.
async fn validate_credentials(
    self: &Arc,
    credentials: &Credentials,
    cx: &AsyncApp,
) -> Result {
    // NOTE(review): `as u32` silently truncates user ids above `u32::MAX`.
    match self
        .cloud_client
        .validate_credentials(credentials.user_id as u32, &credentials.access_token)
        .await
    {
        Ok(valid) => Ok(valid),
        Err(err) => {
            self.set_status(Status::AuthenticationError, cx);
            Err(anyhow!("failed to validate credentials: {}", err))
        }
    }
}
/// Establishes a WebSocket connection with Cloud for receiving updates from the server.
///
/// The connection's driver task and the message loop are both detached. Each received
/// message is passed to `handle_message_to_client`; stream errors are logged and skipped.
async fn connect_to_cloud(self: &Arc, cx: &AsyncApp) -> Result<()> {
    // The first `?` handles `cx.update` failing, the second the `connect` call itself.
    let connect_task = cx.update({
        let cloud_client = self.cloud_client.clone();
        move |cx| cloud_client.connect(cx)
    })??;
    let connection = connect_task.await?;
    let (mut messages, task) = cx.update(|cx| connection.spawn(cx))?;
    task.detach();
    cx.spawn({
        let this = self.clone();
        async move |cx| {
            while let Some(message) = messages.next().await {
                if let Some(message) = message.log_err() {
                    this.handle_message_to_client(message, cx);
                }
            }
        }
    })
    .detach();
    Ok(())
}
/// Performs a sign-in and also (optionally) connects to Collab.
///
/// Only Zed staff automatically connect to Collab.
///
/// Returns `Ok(())` immediately if already connected. After signing in, a Cloud
/// connection is attempted; failures there are only logged. A detached task then
/// waits for the feature flags to report whether the user is staff, and connects to
/// Collab if so. That task's errors are logged, not returned.
pub async fn sign_in_with_optional_connect(
    self: &Arc,
    try_provider: bool,
    cx: &AsyncApp,
) -> Result<()> {
    // Don't try to sign in again if we're already connected to Collab, as it will temporarily disconnect us.
    if self.status().borrow().is_connected() {
        return Ok(());
    }
    let (is_staff_tx, is_staff_rx) = oneshot::channel::();
    // Wrapped in `Option` so the flags callback can send at most once.
    let mut is_staff_tx = Some(is_staff_tx);
    cx.update(|cx| {
        cx.on_flags_ready(move |state, _cx| {
            if let Some(is_staff_tx) = is_staff_tx.take() {
                is_staff_tx.send(state.is_staff).log_err();
            }
        })
        .detach();
    })
    .log_err();
    let credentials = self.sign_in(try_provider, cx).await?;
    self.connect_to_cloud(cx).await.log_err();
    cx.update(move |cx| {
        cx.spawn({
            let client = self.clone();
            async move |cx| {
                let is_staff = is_staff_rx.await?;
                if is_staff {
                    match client.connect_with_credentials(credentials, cx).await {
                        ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
                        ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
                        ConnectionResult::Result(result) => {
                            result.context("client auth and connect")
                        }
                    }
                } else {
                    Ok(())
                }
            }
        })
        .detach_and_log_err(cx);
    })
    .log_err();
    Ok(())
}
/// Signs in and connects to Collab.
///
/// Returns `Ok(())` immediately when already connected or when a connection attempt is
/// in progress. Returns an `UpgradeRequired` error if the server previously demanded an
/// upgrade. The status becomes `Connecting` when starting from `SignedOut` or
/// `Authenticated`, and `Reconnecting` otherwise.
pub async fn connect(
    self: &Arc,
    try_provider: bool,
    cx: &AsyncApp,
) -> ConnectionResult<()> {
    // Decided before `sign_in`, which itself changes the status.
    let was_disconnected = match *self.status().borrow() {
        Status::SignedOut | Status::Authenticated => true,
        Status::ConnectionError
        | Status::ConnectionLost
        | Status::Authenticating
        | Status::AuthenticationError
        | Status::Reauthenticating
        | Status::ReconnectionError { .. } => false,
        Status::Connected { .. } | Status::Connecting | Status::Reconnecting => {
            return ConnectionResult::Result(Ok(()));
        }
        Status::UpgradeRequired => {
            return ConnectionResult::Result(
                Err(EstablishConnectionError::UpgradeRequired)
                    .context("client auth and connect"),
            );
        }
    };
    let credentials = match self.sign_in(try_provider, cx).await {
        Ok(credentials) => credentials,
        Err(err) => return ConnectionResult::Result(Err(err)),
    };
    if was_disconnected {
        self.set_status(Status::Connecting, cx);
    } else {
        self.set_status(Status::Reconnecting, cx);
    }
    self.connect_with_credentials(credentials, cx).await
}
/// Opens the Collab connection with `credentials` and completes the hello handshake.
///
/// Opening and the handshake share one `CONNECTION_TIMEOUT` timer. Any failure sets
/// the status to `ConnectionError`, except an upgrade-required response, which sets
/// `UpgradeRequired`.
async fn connect_with_credentials(
    self: &Arc,
    credentials: Credentials,
    cx: &AsyncApp,
) -> ConnectionResult<()> {
    // A single fused timer, polled by both `select`s, bounds the whole attempt.
    let mut timeout =
        futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
    futures::select_biased! {
        connection = self.establish_connection(&credentials, cx).fuse() => {
            match connection {
                Ok(conn) => {
                    futures::select_biased! {
                        result = self.set_connection(conn, cx).fuse() => {
                            match result.context("client auth and connect") {
                                Ok(()) => ConnectionResult::Result(Ok(())),
                                Err(err) => {
                                    self.set_status(Status::ConnectionError, cx);
                                    ConnectionResult::Result(Err(err))
                                },
                            }
                        },
                        _ = timeout => {
                            self.set_status(Status::ConnectionError, cx);
                            ConnectionResult::Timeout
                        }
                    }
                }
                Err(EstablishConnectionError::Unauthorized) => {
                    self.set_status(Status::ConnectionError, cx);
                    ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
                }
                Err(EstablishConnectionError::UpgradeRequired) => {
                    self.set_status(Status::UpgradeRequired, cx);
                    ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
                }
                Err(error) => {
                    self.set_status(Status::ConnectionError, cx);
                    ConnectionResult::Result(Err(error).context("client auth and connect"))
                }
            }
        }
        _ = &mut timeout => {
            self.set_status(Status::ConnectionError, cx);
            ConnectionResult::Timeout
        }
    }
}
/// Adds `conn` to the RPC peer, waits for the server's hello, and marks the client
/// `Connected`.
///
/// If the hello is missing or malformed, the connection is disconnected from the peer
/// and an error is returned. Otherwise two detached tasks are spawned:
/// - one passes incoming messages to `handle_message`;
/// - one awaits the I/O task. If I/O ends cleanly while this connection is still the
///   current one, it sets `SignedOut`; if I/O fails, it sets `ConnectionLost`.
async fn set_connection(self: &Arc, conn: Connection, cx: &AsyncApp) -> Result<()> {
    let executor = cx.background_executor();
    log::debug!("add connection to peer");
    let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
        let executor = executor.clone();
        move |duration| executor.timer(duration)
    });
    let handle_io = executor.spawn(handle_io);
    // The first incoming message must be the server hello carrying our peer id.
    let peer_id = async {
        log::debug!("waiting for server hello");
        let message = incoming.next().await.context("no hello message received")?;
        log::debug!("got server hello");
        let hello_message_type_name = message.payload_type_name().to_string();
        let hello = message
            .into_any()
            .downcast::>()
            .map_err(|_| {
                anyhow!(
                    "invalid hello message received: {:?}",
                    hello_message_type_name
                )
            })?;
        let peer_id = hello.payload.peer_id.context("invalid peer id")?;
        Ok(peer_id)
    };
    let peer_id = match peer_id.await {
        Ok(peer_id) => peer_id,
        Err(error) => {
            self.peer.disconnect(connection_id);
            return Err(error);
        }
    };
    log::debug!(
        "set status to connected (connection id: {:?}, peer id: {:?})",
        connection_id,
        peer_id
    );
    self.set_status(
        Status::Connected {
            peer_id,
            connection_id,
        },
        cx,
    );
    cx.spawn({
        let this = self.clone();
        async move |cx| {
            while let Some(message) = incoming.next().await {
                this.handle_message(message, cx);
                // Don't starve the main thread when receiving lots of messages at once.
                smol::future::yield_now().await;
            }
        }
    })
    .detach();
    cx.spawn({
        let this = self.clone();
        async move |cx| match handle_io.await {
            Ok(()) => {
                // Only sign out if no newer connection has replaced this one.
                if *this.status().borrow()
                    == (Status::Connected {
                        connection_id,
                        peer_id,
                    })
                {
                    this.set_status(Status::SignedOut, cx);
                }
            }
            Err(err) => {
                log::error!("connection error: {:?}", err);
                this.set_status(Status::ConnectionLost, cx);
            }
        }
    })
    .detach();
    Ok(())
}
/// Obtains fresh credentials: uses the test override when one is installed, otherwise
/// the browser sign-in flow.
fn authenticate(self: &Arc, cx: &AsyncApp) -> Task> {
    #[cfg(any(test, feature = "test-support"))]
    if let Some(callback) = self.authenticate.read().as_ref() {
        return callback(cx);
    }
    self.authenticate_with_browser(cx)
}

/// Opens the Collab connection: uses the test override when one is installed, otherwise
/// a WebSocket connection.
fn establish_connection(
    self: &Arc,
    credentials: &Credentials,
    cx: &AsyncApp,
) -> Task> {
    #[cfg(any(test, feature = "test-support"))]
    if let Some(callback) = self.establish_connection.read().as_ref() {
        return callback(credentials, cx);
    }
    self.establish_websocket_connection(credentials, cx)
}
/// Resolves the Collab RPC URL.
///
/// Sources, in order of precedence:
/// 1. The test override (test builds only).
/// 2. `ZED_RPC_URL`.
/// 3. The `Location` header of the redirect returned by the server's `/rpc` endpoint,
///    with the release channel's query parameter appended if there is one.
///
/// The returned future captures nothing from `self` (`use<>`), so it can be awaited
/// inside a spawned task.
fn rpc_url(
    &self,
    http: Arc,
    release_channel: Option,
) -> impl Future> + use<> {
    // Read eagerly so the returned future does not borrow `self`.
    #[cfg(any(test, feature = "test-support"))]
    let url_override = self.rpc_url.read().clone();
    async move {
        #[cfg(any(test, feature = "test-support"))]
        if let Some(url) = url_override {
            return Ok(url);
        }
        if let Some(url) = &*ZED_RPC_URL {
            return Url::parse(url).context("invalid rpc url");
        }
        let mut url = http.build_url("/rpc");
        if let Some(preview_param) =
            release_channel.and_then(|channel| channel.release_query_param())
        {
            url += "?";
            url += preview_param;
        }
        // The redirect itself is inspected (presumably `false` disables following it),
        // because the Collab URL is read from its `Location` header.
        let response = http.get(&url, Default::default(), false).await?;
        anyhow::ensure!(
            response.status().is_redirection(),
            "unexpected /rpc response status {}",
            response.status()
        );
        let collab_url = response
            .headers()
            .get("Location")
            .context("missing location header in /rpc response")?
            .to_str()
            .map_err(EstablishConnectionError::other)?
            .to_string();
        Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
    }
}
/// Connects to the Collab RPC endpoint over TCP (through the configured proxy, if any)
/// and performs the WebSocket handshake.
///
/// The handshake request carries these headers:
/// - authorization;
/// - protocol version;
/// - app version;
/// - release channel;
/// - user agent, system id, and metrics id, when known.
///
/// Only `http` and `https` RPC URLs are accepted; they are switched to `ws` and `wss`.
fn establish_websocket_connection(
    self: &Arc,
    credentials: &Credentials,
    cx: &AsyncApp,
) -> Task> {
    let release_channel = cx
        .update(|cx| ReleaseChannel::try_global(cx))
        .ok()
        .flatten();
    let app_version = cx
        .update(|cx| AppVersion::global(cx).to_string())
        .ok()
        .unwrap_or_default();
    let http = self.http.clone();
    let proxy = http.proxy().cloned();
    let user_agent = http.user_agent().cloned();
    let credentials = credentials.clone();
    let rpc_url = self.rpc_url(http, release_channel);
    let system_id = self.telemetry.system_id();
    let metrics_id = self.telemetry.metrics_id();
    cx.spawn(async move |cx| {
        use HttpOrHttps::*;
        #[derive(Debug)]
        enum HttpOrHttps {
            Http,
            Https,
        }
        let mut rpc_url = rpc_url.await?;
        let url_scheme = match rpc_url.scheme() {
            "https" => Https,
            "http" => Http,
            _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
        };
        // The TCP (or proxied) stream is opened on the Tokio runtime.
        let stream = gpui_tokio::Tokio::spawn_result(cx, {
            let rpc_url = rpc_url.clone();
            async move {
                let rpc_host = rpc_url
                    .host_str()
                    .zip(rpc_url.port_or_known_default())
                    .context("missing host in rpc url")?;
                Ok(match proxy {
                    Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
                    None => Box::new(TcpStream::connect(rpc_host).await?),
                })
            }
        })?
        .await?;
        log::info!("connected to rpc endpoint {}", rpc_url);
        // http(s) -> ws(s) is a switch between special schemes, which `Url::set_scheme`
        // permits, so this unwrap is not expected to fail.
        rpc_url
            .set_scheme(match url_scheme {
                Https => "wss",
                Http => "ws",
            })
            .unwrap();
        // We call `into_client_request` to let `tungstenite` construct the WebSocket request
        // for us from the RPC URL.
        //
        // Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
        let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;
        // We then modify the request to add our desired headers.
        let request_headers = request.headers_mut();
        request_headers.insert(
            http::header::AUTHORIZATION,
            HeaderValue::from_str(&credentials.authorization_header())?,
        );
        request_headers.insert(
            "x-zed-protocol-version",
            HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
        );
        request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
        request_headers.insert(
            "x-zed-release-channel",
            HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
        );
        if let Some(user_agent) = user_agent {
            request_headers.insert(http::header::USER_AGENT, user_agent);
        }
        if let Some(system_id) = system_id {
            request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
        }
        if let Some(metrics_id) = metrics_id {
            request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
        }
        let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
            request,
            stream,
            Some(Arc::new(http_client_tls::tls_config()).into()),
            None,
        )
        .await?;
        Ok(Connection::new(
            stream
                .map_err(|error| anyhow!(error))
                .sink_map_err(|error| anyhow!(error)),
        ))
    })
}
/// Signs the user in through their web browser and returns the resulting [`Credentials`].
///
/// Flow:
/// 1. Generate an asymmetric key pair. The public key is sent to the server so that it can
///    encrypt the access token for this app only.
/// 2. Bind a local HTTP server on an ephemeral `127.0.0.1` port and open
///    `/native_app_signin` in the browser, passing that port and the public key as query
///    parameters (plus `impersonate` when `ZED_IMPERSONATE` is set).
/// 3. Poll the local server (100 polls of one second each) for the browser redirect that
///    carries the `user_id` and encrypted `access_token` query parameters. Answer it with a
///    302 to `/native_app_signin_succeeded` and decrypt the token with the private key.
///
/// When both `ZED_IMPERSONATE` and `ZED_ADMIN_API_TOKEN` are set (non-empty), the browser
/// flow is skipped and credentials come from [`Self::authenticate_as_admin`] instead.
///
/// On success, `cx.activate(true)` is called before the credentials are returned.
///
/// # Errors
/// Fails if any of the following happens:
/// - the local server cannot be started;
/// - no callback arrives within the polling window;
/// - the callback URL or query cannot be parsed, or the redirect response cannot be sent;
/// - the token cannot be decrypted;
/// - the user id is not a valid integer.
///
/// # Panics
/// Panics if key-pair generation or public-key serialization fails (both use `expect`).
pub fn authenticate_with_browser(self: &Arc, cx: &AsyncApp) -> Task> {
    let http = self.http.clone();
    let this = self.clone();
    cx.spawn(async move |cx| {
        let background = cx.background_executor().clone();
        // The sign-in URL is only known once the local callback server has bound a port.
        // A foreground task waits on this channel and opens the URL when it arrives.
        let (open_url_tx, open_url_rx) = oneshot::channel::();
        cx.update(|cx| {
            cx.spawn(async move |cx| {
                // If the sender is dropped without sending (e.g. the admin-impersonation path
                // below returns early), there is nothing to open. Don't report that
                // cancellation as an error.
                let Ok(url) = open_url_rx.await else {
                    return Ok(());
                };
                cx.update(|cx| cx.open_url(&url))
            })
            .detach_and_log_err(cx);
        })
        .log_err();
        let credentials = background
            .clone()
            .spawn(async move {
                // Generate a pair of asymmetric encryption keys. The public key will be used by
                // the zed server to encrypt the user's access token, so that it can't be
                // intercepted by any other app running on the user's device.
                let (public_key, private_key) =
                    rpc::auth::keypair().expect("failed to generate keypair for auth");
                let public_key_string = String::try_from(public_key)
                    .expect("failed to serialize public key for auth");

                if let Some((login, token)) =
                    IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
                {
                    // Never print the admin API token itself: it is a secret, and stderr
                    // output routinely ends up in terminals, logs and bug reports.
                    eprintln!("authenticate as admin {login}");
                    return this
                        .authenticate_as_admin(http, login.clone(), token.clone())
                        .await;
                }

                // Start an HTTP server to receive the redirect from Zed's sign-in page.
                // Binding can fail at runtime. Surface that as an error instead of
                // panicking the background task.
                let server = tiny_http::Server::http("127.0.0.1:0")
                    .map_err(|error| anyhow!("failed to find open port: {error}"))?;
                let port = server.server_addr().port();

                // Open the Zed sign-in page in the user's browser, with query parameters that
                // indicate that the user is signing in from a Zed app running on the same device.
                let mut url = http.build_url(&format!(
                    "/native_app_signin?native_app_port={}&native_app_public_key={}",
                    port, public_key_string
                ));
                if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
                    log::info!("impersonating user @{}", impersonate_login);
                    // Writing into a `String` cannot fail.
                    write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
                }
                open_url_tx.send(url).log_err();

                #[derive(Deserialize)]
                struct CallbackParams {
                    pub user_id: String,
                    pub access_token: String,
                }

                // Receive the HTTP request from the user's browser. Retrieve the user id and
                // encrypted access token from the query params.
                //
                // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
                // custom URL scheme instead of this local HTTP server.
                let (user_id, access_token) = background
                    .spawn(async move {
                        // Poll in one-second slices, 100 times, before giving up.
                        for _ in 0..100 {
                            if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
                                let path = req.url();
                                // `req.url()` is only the path and query, so give it a dummy
                                // origin to make it parseable.
                                let url = Url::parse(&format!("http://example.com{}", path))
                                    .context("failed to parse login notification url")?;
                                let callback_params: CallbackParams =
                                    serde_urlencoded::from_str(url.query().unwrap_or_default())
                                        .context(
                                            "failed to parse sign-in callback query parameters",
                                        )?;

                                let post_auth_url =
                                    http.build_url("/native_app_signin_succeeded");
                                req.respond(
                                    tiny_http::Response::empty(302).with_header(
                                        tiny_http::Header::from_bytes(
                                            &b"Location"[..],
                                            post_auth_url.as_bytes(),
                                        )
                                        .unwrap(),
                                    ),
                                )
                                .context("failed to respond to login http request")?;
                                return Ok((
                                    callback_params.user_id,
                                    callback_params.access_token,
                                ));
                            }
                        }

                        anyhow::bail!("didn't receive login redirect");
                    })
                    .await?;

                let access_token = private_key
                    .decrypt_string(&access_token)
                    .context("failed to decrypt access token")?;

                Ok(Credentials {
                    user_id: user_id.parse()?,
                    access_token,
                })
            })
            .await?;

        cx.update(|cx| cx.activate(true))?;
        Ok(credentials)
    })
}
/// Obtains credentials for `login` from the internal impersonation endpoint, authorized
/// with an admin API token, without going through the browser sign-in flow.
///
/// The request URL (`/internal/users/impersonate`) is built from `self.http`. The request
/// itself is sent through `http` with a JSON `{"github_login": ...}` body and a
/// `Bearer <api_token>` `Authorization` header.
///
/// # Errors
/// Fails in any of these cases:
/// - the URL or request cannot be built;
/// - the request or the body read fails;
/// - the response status is not a success (the error message includes the status code and
///   the response body);
/// - the body does not deserialize into `{ user_id, access_token }`.
async fn authenticate_as_admin(
    self: &Arc,
    http: Arc,
    login: String,
    api_token: String,
) -> Result {
    #[derive(Serialize)]
    struct ImpersonateUserBody {
        github_login: String,
    }

    #[derive(Deserialize)]
    struct ImpersonateUserResponse {
        user_id: u64,
        access_token: String,
    }

    let url = self
        .http
        .build_zed_cloud_url("/internal/users/impersonate", &[])?;
    let payload = serde_json::to_string(&ImpersonateUserBody {
        github_login: login,
    })?;
    let request = Request::post(url.as_str())
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {api_token}"))
        .body(payload.into())?;

    let mut response = http.send(request).await?;
    // Read the body before checking the status so a failure message can include it.
    let mut response_text = String::new();
    response
        .body_mut()
        .read_to_string(&mut response_text)
        .await?;

    let status = response.status();
    anyhow::ensure!(
        status.is_success(),
        "admin user request failed {} - {}",
        status.as_u16(),
        response_text,
    );

    let parsed: ImpersonateUserResponse = serde_json::from_str(&response_text)?;
    Ok(Credentials {
        user_id: parsed.user_id,
        access_token: parsed.access_token,
    })
}
/// Signs the user out.
///
/// Forgets the in-memory credentials, both in this client's state and in the cloud API
/// client, then disconnects. If the credentials provider still reports stored credentials,
/// it is asked to delete them; a failure there is logged rather than returned.
pub async fn sign_out(self: &Arc, cx: &AsyncApp) {
    self.state.write().credentials = None;
    self.cloud_client.clear_credentials();
    self.disconnect(cx);

    if !self.has_credentials(cx).await {
        return;
    }

    self.credentials_provider
        .delete_credentials(cx)
        .await
        .log_err();
}
/// Tears down the peer connection and sets the status to [`Status::SignedOut`].
///
/// This method itself does not touch credentials. `sign_out` clears them before calling it.
pub fn disconnect(self: &Arc, cx: &AsyncApp) {
self.peer.teardown();
self.set_status(Status::SignedOut, cx);
}
/// Tears down the current peer connection and sets the status to
/// [`Status::ConnectionLost`].
///
/// NOTE(review): judging by the name, the status change is presumably what drives the
/// reconnection logic elsewhere in this file — confirm.
pub fn reconnect(self: &Arc, cx: &AsyncApp) {
self.peer.teardown();
self.set_status(Status::ConnectionLost, cx);
}
/// Returns the id of the current connection.
///
/// # Errors
/// Fails with "not connected" unless the current status is [`Status::Connected`].
fn connection_id(&self) -> Result {
    match *self.status().borrow() {
        Status::Connected { connection_id, .. } => Ok(connection_id),
        _ => anyhow::bail!("not connected"),
    }
}
/// Sends a one-way message over the current connection.
///
/// # Errors
/// Fails if the client is not connected, or if the peer fails to send the message.
pub fn send(&self, message: T) -> Result<()> {
    log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
    let connection_id = self.connection_id()?;
    self.peer.send(connection_id, message)
}
/// Sends `request` and resolves to just the response payload.
///
/// This is a thin wrapper over [`Self::request_envelope`] that discards the envelope and
/// keeps its `payload`.
pub fn request(
    &self,
    request: T,
) -> impl Future> + use {
    let envelope_future = self.request_envelope(request);
    envelope_future.map_ok(|typed_envelope| typed_envelope.payload)
}
/// Sends `request` and resolves to a stream of responses from the peer.
///
/// The "start" debug log and the connection-id lookup happen eagerly, when this method is
/// called. If the client is not connected, the returned future resolves to that error
/// without logging "finish".
pub fn request_stream(
    &self,
    request: T,
) -> impl Future>>> {
    let client_id = self.id.load(Ordering::SeqCst);
    log::debug!("rpc request start. client_id:{}. name:{}", client_id, T::NAME);

    let stream_future = self
        .connection_id()
        .map(|connection_id| self.peer.request_stream(connection_id, request));

    async move {
        let outcome = stream_future?.await;
        log::debug!("rpc request finish. client_id:{}. name:{}", client_id, T::NAME);
        outcome
    }
}
/// Sends `request` and resolves to the full typed response envelope.
///
/// The "start" debug log and the connection-id lookup happen eagerly, when this method is
/// called. If the client is not connected, the returned future resolves to the
/// "not connected" error without logging "finish".
pub fn request_envelope(
    &self,
    request: T,
) -> impl Future>> + use {
    let client_id = self.id();
    log::debug!("rpc request start. client_id:{}. name:{}", client_id, T::NAME);

    let pending = self
        .connection_id()
        .map(|connection_id| self.peer.request_envelope(connection_id, request));

    async move {
        let result = pending?.await;
        log::debug!("rpc request finish. client_id:{}. name:{}", client_id, T::NAME);
        result
    }
}
/// Sends an untyped request `envelope`, identified by `request_type` for logging, and
/// resolves to the response.
///
/// The returned future captures nothing (`use<>`), so it does not borrow `self`. The
/// connection-id lookup happens eagerly. If the client is not connected, the future
/// resolves to that error without logging "finish".
pub fn request_dynamic(
    &self,
    envelope: proto::Envelope,
    request_type: &'static str,
) -> impl Future> + use<> {
    let client_id = self.id();
    log::debug!("rpc request start. client_id:{}. name:{}", client_id, request_type);

    let pending = self
        .connection_id()
        .map(|connection_id| self.peer.request_dynamic(connection_id, envelope, request_type));

    async move {
        let outcome = pending?.await;
        log::debug!("rpc request finish. client_id:{}. name:{}", client_id, request_type);
        // Only the first element of the peer's result is handed back to the caller.
        let peer_result = outcome?;
        Ok(peer_result.0)
    }
}
/// Routes an incoming RPC message to the registered proto message handlers.
///
/// If a handler accepts the message, its future is spawned and detached, and the outcome is
/// logged: debug level on success, error level on failure. If no handler accepts it, the
/// message is logged at info level and the peer is notified through
/// `respond_with_unhandled_message`.
fn handle_message(self: &Arc, message: Box, cx: &AsyncApp) {
    // Read everything needed for logging and responding before `message` is moved into
    // the handler set.
    let sender_id = message.sender_id();
    let request_id = message.message_id();
    let type_name = message.payload_type_name();
    let original_sender_id = message.original_sender_id();

    let dispatched = ProtoMessageHandlerSet::handle_message(
        &self.handler_set,
        message,
        self.clone().into(),
        cx.clone(),
    );
    let Some(handler_future) = dispatched else {
        log::info!("unhandled message {}", type_name);
        self.peer
            .respond_with_unhandled_message(sender_id.into(), request_id, type_name)
            .log_err();
        return;
    };

    let client_id = self.id();
    log::debug!(
        "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
        client_id,
        original_sender_id,
        type_name
    );
    cx.spawn(async move |_| {
        if let Err(error) = handler_future.await {
            log::error!(
                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
                client_id,
                original_sender_id,
                type_name,
                error
            );
        } else {
            log::debug!(
                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
                client_id,
                original_sender_id,
                type_name
            );
        }
    })
    .detach();
}
/// Registers `handler` to be invoked for every [`MessageToClient`] this client receives.
///
/// Handlers are appended, so they run in registration order.
pub fn add_message_to_client_handler(
    self: &Arc,
    handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
) {
    let boxed = Box::new(handler);
    self.message_to_client_handlers.lock().push(boxed);
}
/// Delivers `message` (from the cloud API websocket protocol) to every handler registered
/// through `add_message_to_client_handler`, in registration order.
///
/// Any error from `cx.update` is discarded.
///
/// NOTE(review): the handler list stays locked while the handlers run. A handler that
/// registers another handler would try to re-lock it, which presumably deadlocks with a
/// non-reentrant mutex — confirm before relying on that pattern.
fn handle_message_to_client(self: &Arc, message: MessageToClient, cx: &AsyncApp) {
    let _ = cx.update(|cx| {
        let handlers = self.message_to_client_handlers.lock();
        handlers.iter().for_each(|handler| handler(&message, cx));
    });
}
/// Returns the telemetry handle held by this client.
pub fn telemetry(&self) -> &Arc {
&self.telemetry
}
}
/// Exposes [`Client`] through the untyped [`ProtoClient`] interface, delegating to the
/// client's peer and its current connection.
impl ProtoClient for Client {
    /// Boxes the future returned by [`Client::request_dynamic`].
    fn request(
        &self,
        envelope: proto::Envelope,
        request_type: &'static str,
    ) -> BoxFuture<'static, Result> {
        let response = self.request_dynamic(envelope, request_type);
        response.boxed()
    }

    /// Sends a one-way envelope. Fails if the client is not connected.
    fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
        log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
        self.peer.send_dynamic(self.connection_id()?, envelope)
    }

    /// Sends a response envelope. Fails if the client is not connected.
    fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
        log::debug!("rpc respond. client_id:{}, name:{}", self.id(), message_type);
        self.peer.send_dynamic(self.connection_id()?, envelope)
    }

    fn message_handler_set(&self) -> &parking_lot::Mutex {
        &self.handler_set
    }

    /// This client always reports that it is connected via collab.
    fn is_via_collab(&self) -> bool {
        true
    }
}
/// The URL scheme of Zed links: the `zed` in `zed://...`.
pub const ZED_URL_SCHEME: &str = "zed";
/// Parses the given link into a Zed link.
///
/// A link counts as a Zed link if it has either of these forms:
/// - the configured `server_url` from [`ClientSettings`], immediately followed by `/`;
/// - the [`ZED_URL_SCHEME`] followed by `://`.
///
/// The server-URL form is checked first.
///
/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
/// Returns [`None`] otherwise.
pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
    let server_url = &ClientSettings::get_global(cx).server_url;
    link.strip_prefix(server_url)
        .and_then(|rest| rest.strip_prefix('/'))
        .or_else(|| {
            link.strip_prefix(ZED_URL_SCHEME)
                .and_then(|rest| rest.strip_prefix("://"))
        })
}
#[cfg(test)]
mod tests {
// Tests for connection management, authentication, and message-handler subscriptions,
// driven by `FakeServer` and `FakeHttpClient`.
use super::*;
use crate::test::{FakeServer, parse_authorization_header};
use clock::FakeSystemClock;
use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
use http_client::FakeHttpClient;
use parking_lot::Mutex;
use proto::TypedEnvelope;
use settings::SettingsStore;
use std::future;
// After a server-side disconnect the client keeps retrying. Once connections are allowed
// again it reconnects with its cached credentials. It re-authenticates only after the
// server has rolled the access token.
#[gpui::test(iterations = 10)]
async fn test_reconnection(cx: &mut TestAppContext) {
init_test(cx);
let user_id = 5;
let client = cx.update(|cx| {
Client::new(
Arc::new(FakeSystemClock::new()),
FakeHttpClient::with_404_response(),
cx,
)
});
let server = FakeServer::for_client(user_id, &client, cx).await;
let mut status = client.status();
assert!(matches!(
status.next().await,
Some(Status::Connected { .. })
));
assert_eq!(server.auth_count(), 1);
server.forbid_connections();
server.disconnect();
while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
server.allow_connections();
cx.executor().advance_clock(Duration::from_secs(10));
while !matches!(status.next().await, Some(Status::Connected { .. })) {}
assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
server.forbid_connections();
server.disconnect();
while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
// Clear cached credentials after authentication fails
server.roll_access_token();
server.allow_connections();
cx.executor().run_until_parked();
cx.executor().advance_clock(Duration::from_secs(10));
while !matches!(status.next().await, Some(Status::Connected { .. })) {}
assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
}
// While the fake HTTP client answers 503, reconnection ends in `ReconnectionError`. Once it
// answers 200 again, the client reconnects without re-authenticating.
#[gpui::test(iterations = 10)]
async fn test_auth_failure_during_reconnection(cx: &mut TestAppContext) {
init_test(cx);
let http_client = FakeHttpClient::with_200_response();
let client =
cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
let server = FakeServer::for_client(42, &client, cx).await;
let mut status = client.status();
assert!(matches!(
status.next().await,
Some(Status::Connected { .. })
));
assert_eq!(server.auth_count(), 1);
// Simulate an auth failure during reconnection.
http_client
.as_fake()
.replace_handler(|_, _request| async move {
Ok(http_client::Response::builder()
.status(503)
.body("".into())
.unwrap())
});
server.disconnect();
while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
// Restore the ability to authenticate.
http_client
.as_fake()
.replace_handler(|_, _request| async move {
Ok(http_client::Response::builder()
.status(200)
.body("".into())
.unwrap())
});
cx.executor().advance_clock(Duration::from_secs(10));
while !matches!(status.next().await, Some(Status::Connected { .. })) {}
assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
}
// Connection attempts that never complete time out after `CONNECTION_TIMEOUT`. The initial
// connect reports `ConnectionError`; a reconnect reports `ReconnectionError`.
#[gpui::test(iterations = 10)]
async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
init_test(cx);
let user_id = 5;
let client = cx.update(|cx| {
Client::new(
Arc::new(FakeSystemClock::new()),
FakeHttpClient::with_404_response(),
cx,
)
});
let mut status = client.status();
// Time out when client tries to connect.
client.override_authenticate(move |cx| {
cx.background_spawn(async move {
Ok(Credentials {
user_id,
access_token: "token".into(),
})
})
});
client.override_establish_connection(|_, cx| {
cx.background_spawn(async move {
future::pending::<()>().await;
unreachable!()
})
});
let auth_and_connect = cx.spawn({
let client = client.clone();
|cx| async move { client.connect(false, &cx).await }
});
executor.run_until_parked();
assert!(matches!(status.next().await, Some(Status::Connecting)));
executor.advance_clock(CONNECTION_TIMEOUT);
assert!(matches!(status.next().await, Some(Status::ConnectionError)));
auth_and_connect.await.into_response().unwrap_err();
// Allow the connection to be established.
let server = FakeServer::for_client(user_id, &client, cx).await;
assert!(matches!(
status.next().await,
Some(Status::Connected { .. })
));
// Disconnect client.
server.forbid_connections();
server.disconnect();
while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
// Time out when re-establishing the connection.
server.allow_connections();
client.override_establish_connection(|_, cx| {
cx.background_spawn(async move {
future::pending::<()>().await;
unreachable!()
})
});
executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
assert!(matches!(status.next().await, Some(Status::Reconnecting)));
executor.advance_clock(CONNECTION_TIMEOUT);
assert!(matches!(
status.next().await,
Some(Status::ReconnectionError { .. })
));
}
// `sign_in` re-runs authentication only when the server rejects the cached credentials
// with 401. Valid credentials, or a 503 from the server, leave the auth count unchanged.
#[gpui::test(iterations = 10)]
async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
init_test(cx);
let auth_count = Arc::new(Mutex::new(0));
let http_client = FakeHttpClient::create(|_request| async move {
Ok(http_client::Response::builder()
.status(200)
.body("".into())
.unwrap())
});
let client =
cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
client.override_authenticate({
let auth_count = auth_count.clone();
move |cx| {
let auth_count = auth_count.clone();
cx.background_spawn(async move {
*auth_count.lock() += 1;
Ok(Credentials {
user_id: 1,
access_token: auth_count.lock().to_string(),
})
})
}
});
let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
assert_eq!(*auth_count.lock(), 1);
assert_eq!(credentials.access_token, "1");
// If credentials are still valid, signing in doesn't trigger authentication.
let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
assert_eq!(*auth_count.lock(), 1);
assert_eq!(credentials.access_token, "1");
// If the server is unavailable, signing in doesn't trigger authentication.
http_client
.as_fake()
.replace_handler(|_, _request| async move {
Ok(http_client::Response::builder()
.status(503)
.body("".into())
.unwrap())
});
client.sign_in(false, &cx.to_async()).await.unwrap_err();
assert_eq!(*auth_count.lock(), 1);
// If credentials became invalid, signing in triggers authentication.
http_client
.as_fake()
.replace_handler(|_, request| async move {
let credentials = parse_authorization_header(&request).unwrap();
if credentials.access_token == "2" {
Ok(http_client::Response::builder()
.status(200)
.body("".into())
.unwrap())
} else {
Ok(http_client::Response::builder()
.status(401)
.body("".into())
.unwrap())
}
});
let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
assert_eq!(*auth_count.lock(), 2);
assert_eq!(credentials.access_token, "2");
}
// Starting a second `connect` while the first is still authenticating drops the first
// authentication future. The `defer` guard, which increments `dropped_auth_count`, makes
// that observable.
#[gpui::test(iterations = 10)]
async fn test_authenticating_more_than_once(
cx: &mut TestAppContext,
executor: BackgroundExecutor,
) {
init_test(cx);
let auth_count = Arc::new(Mutex::new(0));
let dropped_auth_count = Arc::new(Mutex::new(0));
let client = cx.update(|cx| {
Client::new(
Arc::new(FakeSystemClock::new()),
FakeHttpClient::with_404_response(),
cx,
)
});
client.override_authenticate({
let auth_count = auth_count.clone();
let dropped_auth_count = dropped_auth_count.clone();
move |cx| {
let auth_count = auth_count.clone();
let dropped_auth_count = dropped_auth_count.clone();
cx.background_spawn(async move {
*auth_count.lock() += 1;
let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
future::pending::<()>().await;
unreachable!()
})
}
});
let _authenticate = cx.spawn({
let client = client.clone();
move |cx| async move { client.connect(false, &cx).await }
});
executor.run_until_parked();
assert_eq!(*auth_count.lock(), 1);
assert_eq!(*dropped_auth_count.lock(), 0);
let _authenticate = cx.spawn(|cx| async move { client.connect(false, &cx).await });
executor.run_until_parked();
assert_eq!(*auth_count.lock(), 2);
assert_eq!(*dropped_auth_count.lock(), 1);
}
// Entity messages reach the entity subscribed under the matching remote id (here the
// message's `project_id`). Dropping the subscription for id 3 does not stop delivery for
// other ids of the same entity type.
#[gpui::test]
async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
init_test(cx);
let user_id = 5;
let client = cx.update(|cx| {
Client::new(
Arc::new(FakeSystemClock::new()),
FakeHttpClient::with_404_response(),
cx,
)
});
let server = FakeServer::for_client(user_id, &client, cx).await;
let (done_tx1, done_rx1) = smol::channel::unbounded();
let (done_tx2, done_rx2) = smol::channel::unbounded();
AnyProtoClient::from(client.clone()).add_entity_message_handler(
move |entity: Entity, _: TypedEnvelope, cx| {
match entity.read_with(&cx, |entity, _| entity.id).unwrap() {
1 => done_tx1.try_send(()).unwrap(),
2 => done_tx2.try_send(()).unwrap(),
_ => unreachable!(),
}
async { Ok(()) }
},
);
let entity1 = cx.new(|_| TestEntity {
id: 1,
subscription: None,
});
let entity2 = cx.new(|_| TestEntity {
id: 2,
subscription: None,
});
let entity3 = cx.new(|_| TestEntity {
id: 3,
subscription: None,
});
let _subscription1 = client
.subscribe_to_entity(1)
.unwrap()
.set_entity(&entity1, &cx.to_async());
let _subscription2 = client
.subscribe_to_entity(2)
.unwrap()
.set_entity(&entity2, &cx.to_async());
// Ensure dropping a subscription for the same entity type still allows receiving of
// messages for other entity IDs of the same type.
let subscription3 = client
.subscribe_to_entity(3)
.unwrap()
.set_entity(&entity3, &cx.to_async());
drop(subscription3);
server.send(proto::JoinProject {
project_id: 1,
committer_name: None,
committer_email: None,
});
server.send(proto::JoinProject {
project_id: 2,
committer_name: None,
committer_email: None,
});
done_rx1.recv().await.unwrap();
done_rx2.recv().await.unwrap();
}
// A message handler registered after an earlier handler was dropped still receives
// messages.
#[gpui::test]
async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
init_test(cx);
let user_id = 5;
let client = cx.update(|cx| {
Client::new(
Arc::new(FakeSystemClock::new()),
FakeHttpClient::with_404_response(),
cx,
)
});
let server = FakeServer::for_client(user_id, &client, cx).await;
let entity = cx.new(|_| TestEntity::default());
let (done_tx1, _done_rx1) = smol::channel::unbounded();
let (done_tx2, done_rx2) = smol::channel::unbounded();
let subscription1 = client.add_message_handler(
entity.downgrade(),
move |_, _: TypedEnvelope, _| {
done_tx1.try_send(()).unwrap();
async { Ok(()) }
},
);
drop(subscription1);
let _subscription2 = client.add_message_handler(
entity.downgrade(),
move |_, _: TypedEnvelope, _| {
done_tx2.try_send(()).unwrap();
async { Ok(()) }
},
);
server.send(proto::Ping {});
done_rx2.recv().await.unwrap();
}
// A handler may drop its own subscription, which is stored on the entity, while it runs.
// The test passes once the handler signals completion.
#[gpui::test]
async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
init_test(cx);
let user_id = 5;
let client = cx.update(|cx| {
Client::new(
Arc::new(FakeSystemClock::new()),
FakeHttpClient::with_404_response(),
cx,
)
});
let server = FakeServer::for_client(user_id, &client, cx).await;
let entity = cx.new(|_| TestEntity::default());
let (done_tx, done_rx) = smol::channel::unbounded();
let subscription = client.add_message_handler(
entity.clone().downgrade(),
move |entity: Entity, _: TypedEnvelope, mut cx| {
entity
.update(&mut cx, |entity, _| entity.subscription.take())
.unwrap();
done_tx.try_send(()).unwrap();
async { Ok(()) }
},
);
entity.update(cx, |entity, _| {
entity.subscription = Some(subscription);
});
server.send(proto::Ping {});
done_rx.recv().await.unwrap();
}
// Minimal entity used as a message-handler target. `subscription` lets a test keep (and
// later drop) a handler subscription on the entity itself.
#[derive(Default)]
struct TestEntity {
id: usize,
subscription: Option,
}
// Installs a test `SettingsStore` as a global, then calls `init_settings`.
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
init_settings(cx);
});
}
}