
Things this doesn't currently handle: - [x] ~testing~ - ~we really need an snapshot test that takes a vscode settings file with all options that we support, and verifies the zed settings file you get from importing it, both from an empty starting file or one with lots of conflicts. that way we can open said vscode settings file in vscode to ensure that those options all still exist in the future.~ - Discussed this, we don't think this will meaningfully protect us from future failures, and we will just do this as a manual validation step before merging this PR. Any imports that have meaningfully complex translation steps should still be tested. - [x] confirmation (right now it just clobbers your settings file silently) - it'd be really cool if we could show a diff multibuffer of your current settings with the result of the vscode import and let you pick "hunks" to keep, but that's probably too much effort for this feature, especially given that we expect most of the people using it to have an empty/barebones zed config when they run the import. - [x] ~UI in the "welcome" page~ - we're planning on redoing our welcome/walkthrough experience anyways, but in the meantime it'd be nice to conditionally show a button there if we see a user level vscode config - we'll add it to the UI when we land the new walkthrough experience, for now it'll be accessible through the action - [ ] project-specific settings - handling translation of `.vscode/settings.json` or `.code-workspace` settings to `.zed/settings.json` will come in a future PR, along with UI to prompt the user for those actions when opening a project with local vscode settings for the first time - [ ] extension settings - we probably want to do a best-effort pass of popular extensions like vim and git lens - it's also possible to look for installed/enabled extensions with `code --list-extensions`, but we'd have to maintain some sort of mapping of those to our settings and/or extensions - [ ] LSP settings - these are tricky without access to the json schemas for various language server extensions. we could probably manage to do translations for a couple popular languages and avoid solving it in the general case. - [ ] platform specific settings (`[macos].blah`) - this is blocked on #16392 which I'm hoping to address soon - [ ] language specific settings (`[rust].foo`) - totally doable, just haven't gotten to it yet ~We may want to put this behind some kind of flag and/or not land it until some of the above issues are addressed, given that we expect people to only run this importer once there's an incentive to get it right the first time. Maybe we land it alongside a keymap importer so you don't have to go through separate imports for those?~ We are gonna land this as-is, all these unchecked items at the bottom will be addressed in followup PRs, so maybe don't run the importer for now if you have a large and complex VsCode settings file you'd like to import. Release Notes: - Added a VSCode settings importer, available via a `zed::ImportVsCodeSettings` action --------- Co-authored-by: Mikayla Maki <mikayla@zed.dev> Co-authored-by: Kirill Bulatov <kirill@zed.dev> Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com> Co-authored-by: Marshall Bowers <git@maxdeviant.com>
1964 lines
65 KiB
Rust
1964 lines
65 KiB
Rust
#[cfg(any(test, feature = "test-support"))]
|
|
pub mod test;
|
|
|
|
mod socks;
|
|
pub mod telemetry;
|
|
pub mod user;
|
|
pub mod zed_urls;
|
|
|
|
use anyhow::{Context as _, Result, anyhow, bail};
|
|
use async_recursion::async_recursion;
|
|
use async_tungstenite::tungstenite::{
|
|
client::IntoClientRequest,
|
|
error::Error as WebsocketError,
|
|
http::{HeaderValue, Request, StatusCode},
|
|
};
|
|
use chrono::{DateTime, Utc};
|
|
use clock::SystemClock;
|
|
use credentials_provider::CredentialsProvider;
|
|
use futures::{
|
|
AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
|
|
channel::oneshot, future::BoxFuture,
|
|
};
|
|
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
|
|
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
|
|
use parking_lot::RwLock;
|
|
use postage::watch;
|
|
use rand::prelude::*;
|
|
use release_channel::{AppVersion, ReleaseChannel};
|
|
use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
|
|
use schemars::JsonSchema;
|
|
use serde::{Deserialize, Serialize};
|
|
use settings::{Settings, SettingsSources};
|
|
use socks::connect_socks_proxy_stream;
|
|
use std::pin::Pin;
|
|
use std::{
|
|
any::TypeId,
|
|
convert::TryFrom,
|
|
fmt::Write as _,
|
|
future::Future,
|
|
marker::PhantomData,
|
|
path::PathBuf,
|
|
sync::{
|
|
Arc, LazyLock, Weak,
|
|
atomic::{AtomicU64, Ordering},
|
|
},
|
|
time::{Duration, Instant},
|
|
};
|
|
use telemetry::Telemetry;
|
|
use thiserror::Error;
|
|
use url::Url;
|
|
use util::{ResultExt, TryFutureExt};
|
|
|
|
pub use rpc::*;
|
|
pub use telemetry_events::Event;
|
|
pub use user::*;
|
|
|
|
static ZED_SERVER_URL: LazyLock<Option<String>> =
|
|
LazyLock::new(|| std::env::var("ZED_SERVER_URL").ok());
|
|
static ZED_RPC_URL: LazyLock<Option<String>> = LazyLock::new(|| std::env::var("ZED_RPC_URL").ok());
|
|
|
|
pub static IMPERSONATE_LOGIN: LazyLock<Option<String>> = LazyLock::new(|| {
|
|
std::env::var("ZED_IMPERSONATE")
|
|
.ok()
|
|
.and_then(|s| if s.is_empty() { None } else { Some(s) })
|
|
});
|
|
|
|
pub static ADMIN_API_TOKEN: LazyLock<Option<String>> = LazyLock::new(|| {
|
|
std::env::var("ZED_ADMIN_API_TOKEN")
|
|
.ok()
|
|
.and_then(|s| if s.is_empty() { None } else { Some(s) })
|
|
});
|
|
|
|
pub static ZED_APP_PATH: LazyLock<Option<PathBuf>> =
|
|
LazyLock::new(|| std::env::var("ZED_APP_PATH").ok().map(PathBuf::from));
|
|
|
|
pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
|
|
LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty()));
|
|
|
|
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
|
|
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
|
|
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
|
|
|
|
actions!(client, [SignIn, SignOut, Reconnect]);
|
|
|
|
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
|
|
pub struct ClientSettingsContent {
|
|
server_url: Option<String>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
pub struct ClientSettings {
|
|
pub server_url: String,
|
|
}
|
|
|
|
impl Settings for ClientSettings {
|
|
const KEY: Option<&'static str> = None;
|
|
|
|
type FileContent = ClientSettingsContent;
|
|
|
|
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
|
let mut result = sources.json_merge::<Self>()?;
|
|
if let Some(server_url) = &*ZED_SERVER_URL {
|
|
result.server_url.clone_from(server_url)
|
|
}
|
|
Ok(result)
|
|
}
|
|
|
|
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
|
|
}
|
|
|
|
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
|
|
pub struct ProxySettingsContent {
|
|
proxy: Option<String>,
|
|
}
|
|
|
|
#[derive(Deserialize, Default)]
|
|
pub struct ProxySettings {
|
|
pub proxy: Option<String>,
|
|
}
|
|
|
|
impl Settings for ProxySettings {
|
|
const KEY: Option<&'static str> = None;
|
|
|
|
type FileContent = ProxySettingsContent;
|
|
|
|
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
|
Ok(Self {
|
|
proxy: sources
|
|
.user
|
|
.or(sources.server)
|
|
.and_then(|value| value.proxy.clone())
|
|
.or(sources.default.proxy.clone()),
|
|
})
|
|
}
|
|
|
|
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
|
|
vscode.string_setting("http.proxy", &mut current.proxy);
|
|
}
|
|
}
|
|
|
|
pub fn init_settings(cx: &mut App) {
|
|
TelemetrySettings::register(cx);
|
|
ClientSettings::register(cx);
|
|
ProxySettings::register(cx);
|
|
}
|
|
|
|
pub fn init(client: &Arc<Client>, cx: &mut App) {
|
|
let client = Arc::downgrade(client);
|
|
cx.on_action({
|
|
let client = client.clone();
|
|
move |_: &SignIn, cx| {
|
|
if let Some(client) = client.upgrade() {
|
|
cx.spawn(async move |cx| {
|
|
client.authenticate_and_connect(true, &cx).log_err().await
|
|
})
|
|
.detach();
|
|
}
|
|
}
|
|
});
|
|
|
|
cx.on_action({
|
|
let client = client.clone();
|
|
move |_: &SignOut, cx| {
|
|
if let Some(client) = client.upgrade() {
|
|
cx.spawn(async move |cx| {
|
|
client.sign_out(&cx).await;
|
|
})
|
|
.detach();
|
|
}
|
|
}
|
|
});
|
|
|
|
cx.on_action({
|
|
let client = client.clone();
|
|
move |_: &Reconnect, cx| {
|
|
if let Some(client) = client.upgrade() {
|
|
cx.spawn(async move |cx| {
|
|
client.reconnect(&cx);
|
|
})
|
|
.detach();
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
struct GlobalClient(Arc<Client>);
|
|
|
|
impl Global for GlobalClient {}
|
|
|
|
pub struct Client {
|
|
id: AtomicU64,
|
|
peer: Arc<Peer>,
|
|
http: Arc<HttpClientWithUrl>,
|
|
telemetry: Arc<Telemetry>,
|
|
credentials_provider: ClientCredentialsProvider,
|
|
state: RwLock<ClientState>,
|
|
handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
|
|
|
|
#[allow(clippy::type_complexity)]
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
authenticate:
|
|
RwLock<Option<Box<dyn 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>>>>,
|
|
|
|
#[allow(clippy::type_complexity)]
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
establish_connection: RwLock<
|
|
Option<
|
|
Box<
|
|
dyn 'static
|
|
+ Send
|
|
+ Sync
|
|
+ Fn(
|
|
&Credentials,
|
|
&AsyncApp,
|
|
) -> Task<Result<Connection, EstablishConnectionError>>,
|
|
>,
|
|
>,
|
|
>,
|
|
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
rpc_url: RwLock<Option<Url>>,
|
|
}
|
|
|
|
#[derive(Error, Debug)]
|
|
pub enum EstablishConnectionError {
|
|
#[error("upgrade required")]
|
|
UpgradeRequired,
|
|
#[error("unauthorized")]
|
|
Unauthorized,
|
|
#[error("{0}")]
|
|
Other(#[from] anyhow::Error),
|
|
#[error("{0}")]
|
|
InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
|
|
#[error("{0}")]
|
|
Io(#[from] std::io::Error),
|
|
#[error("{0}")]
|
|
Websocket(#[from] async_tungstenite::tungstenite::http::Error),
|
|
}
|
|
|
|
impl From<WebsocketError> for EstablishConnectionError {
|
|
fn from(error: WebsocketError) -> Self {
|
|
if let WebsocketError::Http(response) = &error {
|
|
match response.status() {
|
|
StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
|
|
StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
|
|
_ => {}
|
|
}
|
|
}
|
|
EstablishConnectionError::Other(error.into())
|
|
}
|
|
}
|
|
|
|
impl EstablishConnectionError {
|
|
pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
|
|
Self::Other(error.into())
|
|
}
|
|
}
|
|
|
|
#[derive(Copy, Clone, Debug, PartialEq)]
|
|
pub enum Status {
|
|
SignedOut,
|
|
UpgradeRequired,
|
|
Authenticating,
|
|
Connecting,
|
|
ConnectionError,
|
|
Connected {
|
|
peer_id: PeerId,
|
|
connection_id: ConnectionId,
|
|
},
|
|
ConnectionLost,
|
|
Reauthenticating,
|
|
Reconnecting,
|
|
ReconnectionError {
|
|
next_reconnection: Instant,
|
|
},
|
|
}
|
|
|
|
impl Status {
|
|
pub fn is_connected(&self) -> bool {
|
|
matches!(self, Self::Connected { .. })
|
|
}
|
|
|
|
pub fn is_signed_out(&self) -> bool {
|
|
matches!(self, Self::SignedOut | Self::UpgradeRequired)
|
|
}
|
|
}
|
|
|
|
struct ClientState {
|
|
credentials: Option<Credentials>,
|
|
status: (watch::Sender<Status>, watch::Receiver<Status>),
|
|
_reconnect_task: Option<Task<()>>,
|
|
}
|
|
|
|
#[derive(Clone, Debug, Eq, PartialEq)]
|
|
pub struct Credentials {
|
|
pub user_id: u64,
|
|
pub access_token: String,
|
|
}
|
|
|
|
impl Credentials {
|
|
pub fn authorization_header(&self) -> String {
|
|
format!("{} {}", self.user_id, self.access_token)
|
|
}
|
|
}
|
|
|
|
pub struct ClientCredentialsProvider {
|
|
provider: Arc<dyn CredentialsProvider>,
|
|
}
|
|
|
|
impl ClientCredentialsProvider {
|
|
pub fn new(cx: &App) -> Self {
|
|
Self {
|
|
provider: <dyn CredentialsProvider>::global(cx),
|
|
}
|
|
}
|
|
|
|
fn server_url(&self, cx: &AsyncApp) -> Result<String> {
|
|
cx.update(|cx| ClientSettings::get_global(cx).server_url.clone())
|
|
}
|
|
|
|
/// Reads the credentials from the provider.
|
|
fn read_credentials<'a>(
|
|
&'a self,
|
|
cx: &'a AsyncApp,
|
|
) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
|
|
async move {
|
|
if IMPERSONATE_LOGIN.is_some() {
|
|
return None;
|
|
}
|
|
|
|
let server_url = self.server_url(cx).ok()?;
|
|
let (user_id, access_token) = self
|
|
.provider
|
|
.read_credentials(&server_url, cx)
|
|
.await
|
|
.log_err()
|
|
.flatten()?;
|
|
|
|
Some(Credentials {
|
|
user_id: user_id.parse().ok()?,
|
|
access_token: String::from_utf8(access_token).ok()?,
|
|
})
|
|
}
|
|
.boxed_local()
|
|
}
|
|
|
|
/// Writes the credentials to the provider.
|
|
fn write_credentials<'a>(
|
|
&'a self,
|
|
user_id: u64,
|
|
access_token: String,
|
|
cx: &'a AsyncApp,
|
|
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
|
|
async move {
|
|
let server_url = self.server_url(cx)?;
|
|
self.provider
|
|
.write_credentials(
|
|
&server_url,
|
|
&user_id.to_string(),
|
|
access_token.as_bytes(),
|
|
cx,
|
|
)
|
|
.await
|
|
}
|
|
.boxed_local()
|
|
}
|
|
|
|
/// Deletes the credentials from the provider.
|
|
fn delete_credentials<'a>(
|
|
&'a self,
|
|
cx: &'a AsyncApp,
|
|
) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
|
|
async move {
|
|
let server_url = self.server_url(cx)?;
|
|
self.provider.delete_credentials(&server_url, cx).await
|
|
}
|
|
.boxed_local()
|
|
}
|
|
}
|
|
|
|
impl Default for ClientState {
|
|
fn default() -> Self {
|
|
Self {
|
|
credentials: None,
|
|
status: watch::channel_with(Status::SignedOut),
|
|
_reconnect_task: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub enum Subscription {
|
|
Entity {
|
|
client: Weak<Client>,
|
|
id: (TypeId, u64),
|
|
},
|
|
Message {
|
|
client: Weak<Client>,
|
|
id: TypeId,
|
|
},
|
|
}
|
|
|
|
impl Drop for Subscription {
|
|
fn drop(&mut self) {
|
|
match self {
|
|
Subscription::Entity { client, id } => {
|
|
if let Some(client) = client.upgrade() {
|
|
let mut state = client.handler_set.lock();
|
|
let _ = state.entities_by_type_and_remote_id.remove(id);
|
|
}
|
|
}
|
|
Subscription::Message { client, id } => {
|
|
if let Some(client) = client.upgrade() {
|
|
let mut state = client.handler_set.lock();
|
|
let _ = state.entity_types_by_message_type.remove(id);
|
|
let _ = state.message_handlers.remove(id);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct PendingEntitySubscription<T: 'static> {
|
|
client: Arc<Client>,
|
|
remote_id: u64,
|
|
_entity_type: PhantomData<T>,
|
|
consumed: bool,
|
|
}
|
|
|
|
impl<T: 'static> PendingEntitySubscription<T> {
|
|
pub fn set_entity(mut self, entity: &Entity<T>, cx: &AsyncApp) -> Subscription {
|
|
self.consumed = true;
|
|
let mut handlers = self.client.handler_set.lock();
|
|
let id = (TypeId::of::<T>(), self.remote_id);
|
|
let Some(EntityMessageSubscriber::Pending(messages)) =
|
|
handlers.entities_by_type_and_remote_id.remove(&id)
|
|
else {
|
|
unreachable!()
|
|
};
|
|
|
|
handlers.entities_by_type_and_remote_id.insert(
|
|
id,
|
|
EntityMessageSubscriber::Entity {
|
|
handle: entity.downgrade().into(),
|
|
},
|
|
);
|
|
drop(handlers);
|
|
for message in messages {
|
|
let client_id = self.client.id();
|
|
let type_name = message.payload_type_name();
|
|
let sender_id = message.original_sender_id();
|
|
log::debug!(
|
|
"handling queued rpc message. client_id:{}, sender_id:{:?}, type:{}",
|
|
client_id,
|
|
sender_id,
|
|
type_name
|
|
);
|
|
self.client.handle_message(message, cx);
|
|
}
|
|
Subscription::Entity {
|
|
client: Arc::downgrade(&self.client),
|
|
id,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T: 'static> Drop for PendingEntitySubscription<T> {
|
|
fn drop(&mut self) {
|
|
if !self.consumed {
|
|
let mut state = self.client.handler_set.lock();
|
|
if let Some(EntityMessageSubscriber::Pending(messages)) = state
|
|
.entities_by_type_and_remote_id
|
|
.remove(&(TypeId::of::<T>(), self.remote_id))
|
|
{
|
|
for message in messages {
|
|
log::info!("unhandled message {}", message.payload_type_name());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Copy, Clone)]
|
|
pub struct TelemetrySettings {
|
|
pub diagnostics: bool,
|
|
pub metrics: bool,
|
|
}
|
|
|
|
/// Control what info is collected by Zed.
|
|
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
|
|
pub struct TelemetrySettingsContent {
|
|
/// Send debug info like crash reports.
|
|
///
|
|
/// Default: true
|
|
pub diagnostics: Option<bool>,
|
|
/// Send anonymized usage data like what languages you're using Zed with.
|
|
///
|
|
/// Default: true
|
|
pub metrics: Option<bool>,
|
|
}
|
|
|
|
impl settings::Settings for TelemetrySettings {
|
|
const KEY: Option<&'static str> = Some("telemetry");
|
|
|
|
type FileContent = TelemetrySettingsContent;
|
|
|
|
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
|
Ok(Self {
|
|
diagnostics: sources
|
|
.user
|
|
.as_ref()
|
|
.or(sources.server.as_ref())
|
|
.and_then(|v| v.diagnostics)
|
|
.unwrap_or(
|
|
sources
|
|
.default
|
|
.diagnostics
|
|
.ok_or_else(Self::missing_default)?,
|
|
),
|
|
metrics: sources
|
|
.user
|
|
.as_ref()
|
|
.or(sources.server.as_ref())
|
|
.and_then(|v| v.metrics)
|
|
.unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
|
|
})
|
|
}
|
|
|
|
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
|
|
vscode.enum_setting("telemetry.telemetryLevel", &mut current.metrics, |s| {
|
|
Some(s == "all")
|
|
});
|
|
vscode.enum_setting("telemetry.telemetryLevel", &mut current.diagnostics, |s| {
|
|
Some(matches!(s, "all" | "error" | "crash"))
|
|
});
|
|
// we could translate telemetry.telemetryLevel, but just because users didn't want
|
|
// to send microsoft telemetry doesn't mean they don't want to send it to zed. their
|
|
// all/error/crash/off correspond to combinations of our "diagnostics" and "metrics".
|
|
}
|
|
}
|
|
|
|
impl Client {
|
|
pub fn new(
|
|
clock: Arc<dyn SystemClock>,
|
|
http: Arc<HttpClientWithUrl>,
|
|
cx: &mut App,
|
|
) -> Arc<Self> {
|
|
Arc::new(Self {
|
|
id: AtomicU64::new(0),
|
|
peer: Peer::new(0),
|
|
telemetry: Telemetry::new(clock, http.clone(), cx),
|
|
http,
|
|
credentials_provider: ClientCredentialsProvider::new(cx),
|
|
state: Default::default(),
|
|
handler_set: Default::default(),
|
|
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
authenticate: Default::default(),
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
establish_connection: Default::default(),
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
rpc_url: RwLock::default(),
|
|
})
|
|
}
|
|
|
|
pub fn production(cx: &mut App) -> Arc<Self> {
|
|
let clock = Arc::new(clock::RealSystemClock);
|
|
let http = Arc::new(HttpClientWithUrl::new_uri(
|
|
cx.http_client(),
|
|
&ClientSettings::get_global(cx).server_url,
|
|
cx.http_client().proxy().cloned(),
|
|
));
|
|
Self::new(clock, http, cx)
|
|
}
|
|
|
|
pub fn id(&self) -> u64 {
|
|
self.id.load(Ordering::SeqCst)
|
|
}
|
|
|
|
pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
|
|
self.http.clone()
|
|
}
|
|
|
|
pub fn set_id(&self, id: u64) -> &Self {
|
|
self.id.store(id, Ordering::SeqCst);
|
|
self
|
|
}
|
|
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
pub fn teardown(&self) {
|
|
let mut state = self.state.write();
|
|
state._reconnect_task.take();
|
|
self.handler_set.lock().clear();
|
|
self.peer.teardown();
|
|
}
|
|
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
pub fn override_authenticate<F>(&self, authenticate: F) -> &Self
|
|
where
|
|
F: 'static + Send + Sync + Fn(&AsyncApp) -> Task<Result<Credentials>>,
|
|
{
|
|
*self.authenticate.write() = Some(Box::new(authenticate));
|
|
self
|
|
}
|
|
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
pub fn override_establish_connection<F>(&self, connect: F) -> &Self
|
|
where
|
|
F: 'static
|
|
+ Send
|
|
+ Sync
|
|
+ Fn(&Credentials, &AsyncApp) -> Task<Result<Connection, EstablishConnectionError>>,
|
|
{
|
|
*self.establish_connection.write() = Some(Box::new(connect));
|
|
self
|
|
}
|
|
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
pub fn override_rpc_url(&self, url: Url) -> &Self {
|
|
*self.rpc_url.write() = Some(url);
|
|
self
|
|
}
|
|
|
|
pub fn global(cx: &App) -> Arc<Self> {
|
|
cx.global::<GlobalClient>().0.clone()
|
|
}
|
|
pub fn set_global(client: Arc<Client>, cx: &mut App) {
|
|
cx.set_global(GlobalClient(client))
|
|
}
|
|
|
|
pub fn user_id(&self) -> Option<u64> {
|
|
self.state
|
|
.read()
|
|
.credentials
|
|
.as_ref()
|
|
.map(|credentials| credentials.user_id)
|
|
}
|
|
|
|
pub fn peer_id(&self) -> Option<PeerId> {
|
|
if let Status::Connected { peer_id, .. } = &*self.status().borrow() {
|
|
Some(*peer_id)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
pub fn status(&self) -> watch::Receiver<Status> {
|
|
self.state.read().status.1.clone()
|
|
}
|
|
|
|
fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncApp) {
|
|
log::info!("set status on client {}: {:?}", self.id(), status);
|
|
let mut state = self.state.write();
|
|
*state.status.0.borrow_mut() = status;
|
|
|
|
match status {
|
|
Status::Connected { .. } => {
|
|
state._reconnect_task = None;
|
|
}
|
|
Status::ConnectionLost => {
|
|
let this = self.clone();
|
|
state._reconnect_task = Some(cx.spawn(async move |cx| {
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
let mut rng = StdRng::seed_from_u64(0);
|
|
#[cfg(not(any(test, feature = "test-support")))]
|
|
let mut rng = StdRng::from_entropy();
|
|
|
|
let mut delay = INITIAL_RECONNECTION_DELAY;
|
|
while let Err(error) = this.authenticate_and_connect(true, &cx).await {
|
|
log::error!("failed to connect {}", error);
|
|
if matches!(*this.status().borrow(), Status::ConnectionError) {
|
|
this.set_status(
|
|
Status::ReconnectionError {
|
|
next_reconnection: Instant::now() + delay,
|
|
},
|
|
&cx,
|
|
);
|
|
cx.background_executor().timer(delay).await;
|
|
delay = delay
|
|
.mul_f32(rng.gen_range(0.5..=2.5))
|
|
.max(INITIAL_RECONNECTION_DELAY)
|
|
.min(MAX_RECONNECTION_DELAY);
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
}));
|
|
}
|
|
Status::SignedOut | Status::UpgradeRequired => {
|
|
self.telemetry.set_authenticated_user_info(None, false);
|
|
state._reconnect_task.take();
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
pub fn subscribe_to_entity<T>(
|
|
self: &Arc<Self>,
|
|
remote_id: u64,
|
|
) -> Result<PendingEntitySubscription<T>>
|
|
where
|
|
T: 'static,
|
|
{
|
|
let id = (TypeId::of::<T>(), remote_id);
|
|
|
|
let mut state = self.handler_set.lock();
|
|
if state.entities_by_type_and_remote_id.contains_key(&id) {
|
|
return Err(anyhow!("already subscribed to entity"));
|
|
}
|
|
|
|
state
|
|
.entities_by_type_and_remote_id
|
|
.insert(id, EntityMessageSubscriber::Pending(Default::default()));
|
|
|
|
Ok(PendingEntitySubscription {
|
|
client: self.clone(),
|
|
remote_id,
|
|
consumed: false,
|
|
_entity_type: PhantomData,
|
|
})
|
|
}
|
|
|
|
#[track_caller]
|
|
pub fn add_message_handler<M, E, H, F>(
|
|
self: &Arc<Self>,
|
|
entity: WeakEntity<E>,
|
|
handler: H,
|
|
) -> Subscription
|
|
where
|
|
M: EnvelopedMessage,
|
|
E: 'static,
|
|
H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
|
|
F: 'static + Future<Output = Result<()>>,
|
|
{
|
|
self.add_message_handler_impl(entity, move |entity, message, _, cx| {
|
|
handler(entity, message, cx)
|
|
})
|
|
}
|
|
|
|
fn add_message_handler_impl<M, E, H, F>(
|
|
self: &Arc<Self>,
|
|
entity: WeakEntity<E>,
|
|
handler: H,
|
|
) -> Subscription
|
|
where
|
|
M: EnvelopedMessage,
|
|
E: 'static,
|
|
H: 'static
|
|
+ Sync
|
|
+ Fn(Entity<E>, TypedEnvelope<M>, AnyProtoClient, AsyncApp) -> F
|
|
+ Send
|
|
+ Sync,
|
|
F: 'static + Future<Output = Result<()>>,
|
|
{
|
|
let message_type_id = TypeId::of::<M>();
|
|
let mut state = self.handler_set.lock();
|
|
state
|
|
.entities_by_message_type
|
|
.insert(message_type_id, entity.into());
|
|
|
|
let prev_handler = state.message_handlers.insert(
|
|
message_type_id,
|
|
Arc::new(move |subscriber, envelope, client, cx| {
|
|
let subscriber = subscriber.downcast::<E>().unwrap();
|
|
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
|
|
handler(subscriber, *envelope, client.clone(), cx).boxed_local()
|
|
}),
|
|
);
|
|
if prev_handler.is_some() {
|
|
let location = std::panic::Location::caller();
|
|
panic!(
|
|
"{}:{} registered handler for the same message {} twice",
|
|
location.file(),
|
|
location.line(),
|
|
std::any::type_name::<M>()
|
|
);
|
|
}
|
|
|
|
Subscription::Message {
|
|
client: Arc::downgrade(self),
|
|
id: message_type_id,
|
|
}
|
|
}
|
|
|
|
pub fn add_request_handler<M, E, H, F>(
|
|
self: &Arc<Self>,
|
|
entity: WeakEntity<E>,
|
|
handler: H,
|
|
) -> Subscription
|
|
where
|
|
M: RequestMessage,
|
|
E: 'static,
|
|
H: 'static + Sync + Fn(Entity<E>, TypedEnvelope<M>, AsyncApp) -> F + Send + Sync,
|
|
F: 'static + Future<Output = Result<M::Response>>,
|
|
{
|
|
self.add_message_handler_impl(entity, move |handle, envelope, this, cx| {
|
|
Self::respond_to_request(envelope.receipt(), handler(handle, envelope, cx), this)
|
|
})
|
|
}
|
|
|
|
async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
|
|
receipt: Receipt<T>,
|
|
response: F,
|
|
client: AnyProtoClient,
|
|
) -> Result<()> {
|
|
match response.await {
|
|
Ok(response) => {
|
|
client.send_response(receipt.message_id, response)?;
|
|
Ok(())
|
|
}
|
|
Err(error) => {
|
|
client.send_response(receipt.message_id, error.to_proto())?;
|
|
Err(error)
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn has_credentials(&self, cx: &AsyncApp) -> bool {
|
|
self.credentials_provider
|
|
.read_credentials(cx)
|
|
.await
|
|
.is_some()
|
|
}
|
|
|
|
#[async_recursion(?Send)]
|
|
pub async fn authenticate_and_connect(
|
|
self: &Arc<Self>,
|
|
try_provider: bool,
|
|
cx: &AsyncApp,
|
|
) -> anyhow::Result<()> {
|
|
let was_disconnected = match *self.status().borrow() {
|
|
Status::SignedOut => true,
|
|
Status::ConnectionError
|
|
| Status::ConnectionLost
|
|
| Status::Authenticating { .. }
|
|
| Status::Reauthenticating { .. }
|
|
| Status::ReconnectionError { .. } => false,
|
|
Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
|
|
return Ok(());
|
|
}
|
|
Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
|
|
};
|
|
if was_disconnected {
|
|
self.set_status(Status::Authenticating, cx);
|
|
} else {
|
|
self.set_status(Status::Reauthenticating, cx)
|
|
}
|
|
|
|
let mut read_from_provider = false;
|
|
let mut credentials = self.state.read().credentials.clone();
|
|
if credentials.is_none() && try_provider {
|
|
credentials = self.credentials_provider.read_credentials(cx).await;
|
|
read_from_provider = credentials.is_some();
|
|
}
|
|
|
|
if credentials.is_none() {
|
|
let mut status_rx = self.status();
|
|
let _ = status_rx.next().await;
|
|
futures::select_biased! {
|
|
authenticate = self.authenticate(cx).fuse() => {
|
|
match authenticate {
|
|
Ok(creds) => credentials = Some(creds),
|
|
Err(err) => {
|
|
self.set_status(Status::ConnectionError, cx);
|
|
return Err(err);
|
|
}
|
|
}
|
|
}
|
|
_ = status_rx.next().fuse() => {
|
|
return Err(anyhow!("authentication canceled"));
|
|
}
|
|
}
|
|
}
|
|
let credentials = credentials.unwrap();
|
|
self.set_id(credentials.user_id);
|
|
|
|
if was_disconnected {
|
|
self.set_status(Status::Connecting, cx);
|
|
} else {
|
|
self.set_status(Status::Reconnecting, cx);
|
|
}
|
|
|
|
let mut timeout =
|
|
futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
|
|
futures::select_biased! {
|
|
connection = self.establish_connection(&credentials, cx).fuse() => {
|
|
match connection {
|
|
Ok(conn) => {
|
|
self.state.write().credentials = Some(credentials.clone());
|
|
if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
|
|
self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err();
|
|
}
|
|
|
|
futures::select_biased! {
|
|
result = self.set_connection(conn, cx).fuse() => result,
|
|
_ = timeout => {
|
|
self.set_status(Status::ConnectionError, cx);
|
|
Err(anyhow!("timed out waiting on hello message from server"))
|
|
}
|
|
}
|
|
}
|
|
Err(EstablishConnectionError::Unauthorized) => {
|
|
self.state.write().credentials.take();
|
|
if read_from_provider {
|
|
self.credentials_provider.delete_credentials(cx).await.log_err();
|
|
self.set_status(Status::SignedOut, cx);
|
|
self.authenticate_and_connect(false, cx).await
|
|
} else {
|
|
self.set_status(Status::ConnectionError, cx);
|
|
Err(EstablishConnectionError::Unauthorized)?
|
|
}
|
|
}
|
|
Err(EstablishConnectionError::UpgradeRequired) => {
|
|
self.set_status(Status::UpgradeRequired, cx);
|
|
Err(EstablishConnectionError::UpgradeRequired)?
|
|
}
|
|
Err(error) => {
|
|
self.set_status(Status::ConnectionError, cx);
|
|
Err(error)?
|
|
}
|
|
}
|
|
}
|
|
_ = &mut timeout => {
|
|
self.set_status(Status::ConnectionError, cx);
|
|
Err(anyhow!("timed out trying to establish connection"))
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncApp) -> Result<()> {
|
|
let executor = cx.background_executor();
|
|
log::debug!("add connection to peer");
|
|
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, {
|
|
let executor = executor.clone();
|
|
move |duration| executor.timer(duration)
|
|
});
|
|
let handle_io = executor.spawn(handle_io);
|
|
|
|
let peer_id = async {
|
|
log::debug!("waiting for server hello");
|
|
let message = incoming
|
|
.next()
|
|
.await
|
|
.ok_or_else(|| anyhow!("no hello message received"))?;
|
|
log::debug!("got server hello");
|
|
let hello_message_type_name = message.payload_type_name().to_string();
|
|
let hello = message
|
|
.into_any()
|
|
.downcast::<TypedEnvelope<proto::Hello>>()
|
|
.map_err(|_| {
|
|
anyhow!(
|
|
"invalid hello message received: {:?}",
|
|
hello_message_type_name
|
|
)
|
|
})?;
|
|
let peer_id = hello
|
|
.payload
|
|
.peer_id
|
|
.ok_or_else(|| anyhow!("invalid peer id"))?;
|
|
Ok(peer_id)
|
|
};
|
|
|
|
let peer_id = match peer_id.await {
|
|
Ok(peer_id) => peer_id,
|
|
Err(error) => {
|
|
self.peer.disconnect(connection_id);
|
|
return Err(error);
|
|
}
|
|
};
|
|
|
|
log::debug!(
|
|
"set status to connected (connection id: {:?}, peer id: {:?})",
|
|
connection_id,
|
|
peer_id
|
|
);
|
|
self.set_status(
|
|
Status::Connected {
|
|
peer_id,
|
|
connection_id,
|
|
},
|
|
cx,
|
|
);
|
|
|
|
cx.spawn({
|
|
let this = self.clone();
|
|
async move |cx| {
|
|
while let Some(message) = incoming.next().await {
|
|
this.handle_message(message, &cx);
|
|
// Don't starve the main thread when receiving lots of messages at once.
|
|
smol::future::yield_now().await;
|
|
}
|
|
}
|
|
})
|
|
.detach();
|
|
|
|
cx.spawn({
|
|
let this = self.clone();
|
|
async move |cx| match handle_io.await {
|
|
Ok(()) => {
|
|
if *this.status().borrow()
|
|
== (Status::Connected {
|
|
connection_id,
|
|
peer_id,
|
|
})
|
|
{
|
|
this.set_status(Status::SignedOut, &cx);
|
|
}
|
|
}
|
|
Err(err) => {
|
|
log::error!("connection error: {:?}", err);
|
|
this.set_status(Status::ConnectionLost, &cx);
|
|
}
|
|
}
|
|
})
|
|
.detach();
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn authenticate(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
if let Some(callback) = self.authenticate.read().as_ref() {
|
|
return callback(cx);
|
|
}
|
|
|
|
self.authenticate_with_browser(cx)
|
|
}
|
|
|
|
fn establish_connection(
|
|
self: &Arc<Self>,
|
|
credentials: &Credentials,
|
|
cx: &AsyncApp,
|
|
) -> Task<Result<Connection, EstablishConnectionError>> {
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
if let Some(callback) = self.establish_connection.read().as_ref() {
|
|
return callback(credentials, cx);
|
|
}
|
|
|
|
self.establish_websocket_connection(credentials, cx)
|
|
}
|
|
|
|
fn rpc_url(
|
|
&self,
|
|
http: Arc<HttpClientWithUrl>,
|
|
release_channel: Option<ReleaseChannel>,
|
|
) -> impl Future<Output = Result<url::Url>> + use<> {
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
let url_override = self.rpc_url.read().clone();
|
|
|
|
async move {
|
|
#[cfg(any(test, feature = "test-support"))]
|
|
if let Some(url) = url_override {
|
|
return Ok(url);
|
|
}
|
|
|
|
if let Some(url) = &*ZED_RPC_URL {
|
|
return Url::parse(url).context("invalid rpc url");
|
|
}
|
|
|
|
let mut url = http.build_url("/rpc");
|
|
if let Some(preview_param) =
|
|
release_channel.and_then(|channel| channel.release_query_param())
|
|
{
|
|
url += "?";
|
|
url += preview_param;
|
|
}
|
|
|
|
let response = http.get(&url, Default::default(), false).await?;
|
|
let collab_url = if response.status().is_redirection() {
|
|
response
|
|
.headers()
|
|
.get("Location")
|
|
.ok_or_else(|| anyhow!("missing location header in /rpc response"))?
|
|
.to_str()
|
|
.map_err(EstablishConnectionError::other)?
|
|
.to_string()
|
|
} else {
|
|
Err(anyhow!(
|
|
"unexpected /rpc response status {}",
|
|
response.status()
|
|
))?
|
|
};
|
|
|
|
Url::parse(&collab_url).context("invalid rpc url")
|
|
}
|
|
}
|
|
|
|
fn establish_websocket_connection(
|
|
self: &Arc<Self>,
|
|
credentials: &Credentials,
|
|
cx: &AsyncApp,
|
|
) -> Task<Result<Connection, EstablishConnectionError>> {
|
|
let release_channel = cx
|
|
.update(|cx| ReleaseChannel::try_global(cx))
|
|
.ok()
|
|
.flatten();
|
|
let app_version = cx
|
|
.update(|cx| AppVersion::global(cx).to_string())
|
|
.ok()
|
|
.unwrap_or_default();
|
|
|
|
let http = self.http.clone();
|
|
let proxy = http.proxy().cloned();
|
|
let credentials = credentials.clone();
|
|
let rpc_url = self.rpc_url(http, release_channel);
|
|
let system_id = self.telemetry.system_id();
|
|
let metrics_id = self.telemetry.metrics_id();
|
|
cx.spawn(async move |cx| {
|
|
use HttpOrHttps::*;
|
|
|
|
#[derive(Debug)]
|
|
enum HttpOrHttps {
|
|
Http,
|
|
Https,
|
|
}
|
|
|
|
let mut rpc_url = rpc_url.await?;
|
|
let url_scheme = match rpc_url.scheme() {
|
|
"https" => Https,
|
|
"http" => Http,
|
|
_ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
|
|
};
|
|
let rpc_host = rpc_url
|
|
.host_str()
|
|
.zip(rpc_url.port_or_known_default())
|
|
.ok_or_else(|| anyhow!("missing host in rpc url"))?;
|
|
|
|
let stream = {
|
|
let handle = cx.update(|cx| gpui_tokio::Tokio::handle(cx)).ok().unwrap();
|
|
let _guard = handle.enter();
|
|
connect_socks_proxy_stream(proxy.as_ref(), rpc_host).await?
|
|
};
|
|
|
|
log::info!("connected to rpc endpoint {}", rpc_url);
|
|
|
|
rpc_url
|
|
.set_scheme(match url_scheme {
|
|
Https => "wss",
|
|
Http => "ws",
|
|
})
|
|
.unwrap();
|
|
|
|
// We call `into_client_request` to let `tungstenite` construct the WebSocket request
|
|
// for us from the RPC URL.
|
|
//
|
|
// Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
|
|
let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;
|
|
|
|
// We then modify the request to add our desired headers.
|
|
let request_headers = request.headers_mut();
|
|
request_headers.insert(
|
|
"Authorization",
|
|
HeaderValue::from_str(&credentials.authorization_header())?,
|
|
);
|
|
request_headers.insert(
|
|
"x-zed-protocol-version",
|
|
HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
|
|
);
|
|
request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
|
|
request_headers.insert(
|
|
"x-zed-release-channel",
|
|
HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
|
|
);
|
|
if let Some(system_id) = system_id {
|
|
request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
|
|
}
|
|
if let Some(metrics_id) = metrics_id {
|
|
request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
|
|
}
|
|
|
|
let (stream, _) = async_tungstenite::tokio::client_async_tls_with_connector_and_config(
|
|
request,
|
|
stream,
|
|
Some(Arc::new(http_client_tls::tls_config()).into()),
|
|
None,
|
|
)
|
|
.await?;
|
|
|
|
Ok(Connection::new(
|
|
stream
|
|
.map_err(|error| anyhow!(error))
|
|
.sink_map_err(|error| anyhow!(error)),
|
|
))
|
|
})
|
|
}
|
|
|
|
pub fn authenticate_with_browser(self: &Arc<Self>, cx: &AsyncApp) -> Task<Result<Credentials>> {
|
|
let http = self.http.clone();
|
|
let this = self.clone();
|
|
cx.spawn(async move |cx| {
|
|
let background = cx.background_executor().clone();
|
|
|
|
let (open_url_tx, open_url_rx) = oneshot::channel::<String>();
|
|
cx.update(|cx| {
|
|
cx.spawn(async move |cx| {
|
|
let url = open_url_rx.await?;
|
|
cx.update(|cx| cx.open_url(&url))
|
|
})
|
|
.detach_and_log_err(cx);
|
|
})
|
|
.log_err();
|
|
|
|
let credentials = background
|
|
.clone()
|
|
.spawn(async move {
|
|
// Generate a pair of asymmetric encryption keys. The public key will be used by the
|
|
// zed server to encrypt the user's access token, so that it can'be intercepted by
|
|
// any other app running on the user's device.
|
|
let (public_key, private_key) =
|
|
rpc::auth::keypair().expect("failed to generate keypair for auth");
|
|
let public_key_string = String::try_from(public_key)
|
|
.expect("failed to serialize public key for auth");
|
|
|
|
if let Some((login, token)) =
|
|
IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref())
|
|
{
|
|
eprintln!("authenticate as admin {login}, {token}");
|
|
|
|
return this
|
|
.authenticate_as_admin(http, login.clone(), token.clone())
|
|
.await;
|
|
}
|
|
|
|
// Start an HTTP server to receive the redirect from Zed's sign-in page.
|
|
let server =
|
|
tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
|
|
let port = server.server_addr().port();
|
|
|
|
// Open the Zed sign-in page in the user's browser, with query parameters that indicate
|
|
// that the user is signing in from a Zed app running on the same device.
|
|
let mut url = http.build_url(&format!(
|
|
"/native_app_signin?native_app_port={}&native_app_public_key={}",
|
|
port, public_key_string
|
|
));
|
|
|
|
if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
|
|
log::info!("impersonating user @{}", impersonate_login);
|
|
write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
|
|
}
|
|
|
|
open_url_tx.send(url).log_err();
|
|
|
|
// Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
|
|
// access token from the query params.
|
|
//
|
|
// TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
|
|
// custom URL scheme instead of this local HTTP server.
|
|
let (user_id, access_token) = background
|
|
.spawn(async move {
|
|
for _ in 0..100 {
|
|
if let Some(req) = server.recv_timeout(Duration::from_secs(1))? {
|
|
let path = req.url();
|
|
let mut user_id = None;
|
|
let mut access_token = None;
|
|
let url = Url::parse(&format!("http://example.com{}", path))
|
|
.context("failed to parse login notification url")?;
|
|
for (key, value) in url.query_pairs() {
|
|
if key == "access_token" {
|
|
access_token = Some(value.to_string());
|
|
} else if key == "user_id" {
|
|
user_id = Some(value.to_string());
|
|
}
|
|
}
|
|
|
|
let post_auth_url =
|
|
http.build_url("/native_app_signin_succeeded");
|
|
req.respond(
|
|
tiny_http::Response::empty(302).with_header(
|
|
tiny_http::Header::from_bytes(
|
|
&b"Location"[..],
|
|
post_auth_url.as_bytes(),
|
|
)
|
|
.unwrap(),
|
|
),
|
|
)
|
|
.context("failed to respond to login http request")?;
|
|
return Ok((
|
|
user_id
|
|
.ok_or_else(|| anyhow!("missing user_id parameter"))?,
|
|
access_token.ok_or_else(|| {
|
|
anyhow!("missing access_token parameter")
|
|
})?,
|
|
));
|
|
}
|
|
}
|
|
|
|
Err(anyhow!("didn't receive login redirect"))
|
|
})
|
|
.await?;
|
|
|
|
let access_token = private_key
|
|
.decrypt_string(&access_token)
|
|
.context("failed to decrypt access token")?;
|
|
|
|
Ok(Credentials {
|
|
user_id: user_id.parse()?,
|
|
access_token,
|
|
})
|
|
})
|
|
.await?;
|
|
|
|
cx.update(|cx| cx.activate(true))?;
|
|
Ok(credentials)
|
|
})
|
|
}
|
|
|
|
async fn authenticate_as_admin(
|
|
self: &Arc<Self>,
|
|
http: Arc<HttpClientWithUrl>,
|
|
login: String,
|
|
mut api_token: String,
|
|
) -> Result<Credentials> {
|
|
#[derive(Deserialize)]
|
|
struct AuthenticatedUserResponse {
|
|
user: User,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct User {
|
|
id: u64,
|
|
}
|
|
|
|
let github_user = {
|
|
#[derive(Deserialize)]
|
|
struct GithubUser {
|
|
id: i32,
|
|
login: String,
|
|
created_at: DateTime<Utc>,
|
|
}
|
|
|
|
let request = {
|
|
let mut request_builder =
|
|
Request::get(&format!("https://api.github.com/users/{login}"));
|
|
if let Ok(github_token) = std::env::var("GITHUB_TOKEN") {
|
|
request_builder =
|
|
request_builder.header("Authorization", format!("Bearer {}", github_token));
|
|
}
|
|
|
|
request_builder.body(AsyncBody::empty())?
|
|
};
|
|
|
|
let mut response = http
|
|
.send(request)
|
|
.await
|
|
.context("error fetching GitHub user")?;
|
|
|
|
let mut body = Vec::new();
|
|
response
|
|
.body_mut()
|
|
.read_to_end(&mut body)
|
|
.await
|
|
.context("error reading GitHub user")?;
|
|
|
|
if !response.status().is_success() {
|
|
let text = String::from_utf8_lossy(body.as_slice());
|
|
bail!(
|
|
"status error {}, response: {text:?}",
|
|
response.status().as_u16()
|
|
);
|
|
}
|
|
|
|
serde_json::from_slice::<GithubUser>(body.as_slice()).map_err(|err| {
|
|
log::error!("Error deserializing: {:?}", err);
|
|
log::error!(
|
|
"GitHub API response text: {:?}",
|
|
String::from_utf8_lossy(body.as_slice())
|
|
);
|
|
anyhow!("error deserializing GitHub user")
|
|
})?
|
|
};
|
|
|
|
let query_params = [
|
|
("github_login", &github_user.login),
|
|
("github_user_id", &github_user.id.to_string()),
|
|
(
|
|
"github_user_created_at",
|
|
&github_user.created_at.to_rfc3339(),
|
|
),
|
|
];
|
|
|
|
// Use the collab server's admin API to retrieve the ID
|
|
// of the impersonated user.
|
|
let mut url = self.rpc_url(http.clone(), None).await?;
|
|
url.set_path("/user");
|
|
url.set_query(Some(
|
|
&query_params
|
|
.iter()
|
|
.map(|(key, value)| {
|
|
format!(
|
|
"{}={}",
|
|
key,
|
|
url::form_urlencoded::byte_serialize(value.as_bytes()).collect::<String>()
|
|
)
|
|
})
|
|
.collect::<Vec<String>>()
|
|
.join("&"),
|
|
));
|
|
let request: http_client::Request<AsyncBody> = Request::get(url.as_str())
|
|
.header("Authorization", format!("token {api_token}"))
|
|
.body("".into())?;
|
|
|
|
let mut response = http.send(request).await?;
|
|
let mut body = String::new();
|
|
response.body_mut().read_to_string(&mut body).await?;
|
|
if !response.status().is_success() {
|
|
Err(anyhow!(
|
|
"admin user request failed {} - {}",
|
|
response.status().as_u16(),
|
|
body,
|
|
))?;
|
|
}
|
|
let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
|
|
|
|
// Use the admin API token to authenticate as the impersonated user.
|
|
api_token.insert_str(0, "ADMIN_TOKEN:");
|
|
Ok(Credentials {
|
|
user_id: response.user.id,
|
|
access_token: api_token,
|
|
})
|
|
}
|
|
|
|
pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
|
|
self.state.write().credentials = None;
|
|
self.disconnect(cx);
|
|
|
|
if self.has_credentials(cx).await {
|
|
self.credentials_provider
|
|
.delete_credentials(cx)
|
|
.await
|
|
.log_err();
|
|
}
|
|
}
|
|
|
|
pub fn disconnect(self: &Arc<Self>, cx: &AsyncApp) {
|
|
self.peer.teardown();
|
|
self.set_status(Status::SignedOut, cx);
|
|
}
|
|
|
|
pub fn reconnect(self: &Arc<Self>, cx: &AsyncApp) {
|
|
self.peer.teardown();
|
|
self.set_status(Status::ConnectionLost, cx);
|
|
}
|
|
|
|
fn connection_id(&self) -> Result<ConnectionId> {
|
|
if let Status::Connected { connection_id, .. } = *self.status().borrow() {
|
|
Ok(connection_id)
|
|
} else {
|
|
Err(anyhow!("not connected"))
|
|
}
|
|
}
|
|
|
|
pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
|
|
log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
|
|
self.peer.send(self.connection_id()?, message)
|
|
}
|
|
|
|
pub fn request<T: RequestMessage>(
|
|
&self,
|
|
request: T,
|
|
) -> impl Future<Output = Result<T::Response>> + use<T> {
|
|
self.request_envelope(request)
|
|
.map_ok(|envelope| envelope.payload)
|
|
}
|
|
|
|
pub fn request_stream<T: RequestMessage>(
|
|
&self,
|
|
request: T,
|
|
) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
|
|
let client_id = self.id.load(Ordering::SeqCst);
|
|
log::debug!(
|
|
"rpc request start. client_id:{}. name:{}",
|
|
client_id,
|
|
T::NAME
|
|
);
|
|
let response = self
|
|
.connection_id()
|
|
.map(|conn_id| self.peer.request_stream(conn_id, request));
|
|
async move {
|
|
let response = response?.await;
|
|
log::debug!(
|
|
"rpc request finish. client_id:{}. name:{}",
|
|
client_id,
|
|
T::NAME
|
|
);
|
|
response
|
|
}
|
|
}
|
|
|
|
pub fn request_envelope<T: RequestMessage>(
|
|
&self,
|
|
request: T,
|
|
) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
|
|
let client_id = self.id();
|
|
log::debug!(
|
|
"rpc request start. client_id:{}. name:{}",
|
|
client_id,
|
|
T::NAME
|
|
);
|
|
let response = self
|
|
.connection_id()
|
|
.map(|conn_id| self.peer.request_envelope(conn_id, request));
|
|
async move {
|
|
let response = response?.await;
|
|
log::debug!(
|
|
"rpc request finish. client_id:{}. name:{}",
|
|
client_id,
|
|
T::NAME
|
|
);
|
|
response
|
|
}
|
|
}
|
|
|
|
pub fn request_dynamic(
|
|
&self,
|
|
envelope: proto::Envelope,
|
|
request_type: &'static str,
|
|
) -> impl Future<Output = Result<proto::Envelope>> + use<> {
|
|
let client_id = self.id();
|
|
log::debug!(
|
|
"rpc request start. client_id:{}. name:{}",
|
|
client_id,
|
|
request_type
|
|
);
|
|
let response = self
|
|
.connection_id()
|
|
.map(|conn_id| self.peer.request_dynamic(conn_id, envelope, request_type));
|
|
async move {
|
|
let response = response?.await;
|
|
log::debug!(
|
|
"rpc request finish. client_id:{}. name:{}",
|
|
client_id,
|
|
request_type
|
|
);
|
|
Ok(response?.0)
|
|
}
|
|
}
|
|
|
|
fn handle_message(self: &Arc<Client>, message: Box<dyn AnyTypedEnvelope>, cx: &AsyncApp) {
|
|
let sender_id = message.sender_id();
|
|
let request_id = message.message_id();
|
|
let type_name = message.payload_type_name();
|
|
let original_sender_id = message.original_sender_id();
|
|
|
|
if let Some(future) = ProtoMessageHandlerSet::handle_message(
|
|
&self.handler_set,
|
|
message,
|
|
self.clone().into(),
|
|
cx.clone(),
|
|
) {
|
|
let client_id = self.id();
|
|
log::debug!(
|
|
"rpc message received. client_id:{}, sender_id:{:?}, type:{}",
|
|
client_id,
|
|
original_sender_id,
|
|
type_name
|
|
);
|
|
cx.spawn(async move |_| match future.await {
|
|
Ok(()) => {
|
|
log::debug!(
|
|
"rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
|
|
client_id,
|
|
original_sender_id,
|
|
type_name
|
|
);
|
|
}
|
|
Err(error) => {
|
|
log::error!(
|
|
"error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
|
|
client_id,
|
|
original_sender_id,
|
|
type_name,
|
|
error
|
|
);
|
|
}
|
|
})
|
|
.detach();
|
|
} else {
|
|
log::info!("unhandled message {}", type_name);
|
|
self.peer
|
|
.respond_with_unhandled_message(sender_id.into(), request_id, type_name)
|
|
.log_err();
|
|
}
|
|
}
|
|
|
|
pub fn telemetry(&self) -> &Arc<Telemetry> {
|
|
&self.telemetry
|
|
}
|
|
}
|
|
|
|
impl ProtoClient for Client {
|
|
fn request(
|
|
&self,
|
|
envelope: proto::Envelope,
|
|
request_type: &'static str,
|
|
) -> BoxFuture<'static, Result<proto::Envelope>> {
|
|
self.request_dynamic(envelope, request_type).boxed()
|
|
}
|
|
|
|
fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
|
|
log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
|
|
let connection_id = self.connection_id()?;
|
|
self.peer.send_dynamic(connection_id, envelope)
|
|
}
|
|
|
|
fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
|
|
log::debug!(
|
|
"rpc respond. client_id:{}, name:{}",
|
|
self.id(),
|
|
message_type
|
|
);
|
|
let connection_id = self.connection_id()?;
|
|
self.peer.send_dynamic(connection_id, envelope)
|
|
}
|
|
|
|
fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
|
|
&self.handler_set
|
|
}
|
|
|
|
fn is_via_collab(&self) -> bool {
|
|
true
|
|
}
|
|
}
|
|
|
|
/// prefix for the zed:// url scheme
|
|
pub const ZED_URL_SCHEME: &str = "zed";
|
|
|
|
/// Parses the given link into a Zed link.
|
|
///
|
|
/// Returns a [`Some`] containing the unprefixed link if the link is a Zed link.
|
|
/// Returns [`None`] otherwise.
|
|
pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
|
|
let server_url = &ClientSettings::get_global(cx).server_url;
|
|
if let Some(stripped) = link
|
|
.strip_prefix(server_url)
|
|
.and_then(|result| result.strip_prefix('/'))
|
|
{
|
|
return Some(stripped);
|
|
}
|
|
if let Some(stripped) = link
|
|
.strip_prefix(ZED_URL_SCHEME)
|
|
.and_then(|result| result.strip_prefix("://"))
|
|
{
|
|
return Some(stripped);
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use crate::test::FakeServer;
|
|
|
|
use clock::FakeSystemClock;
|
|
use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
|
|
use http_client::FakeHttpClient;
|
|
use parking_lot::Mutex;
|
|
use proto::TypedEnvelope;
|
|
use settings::SettingsStore;
|
|
use std::future;
|
|
|
|
#[gpui::test(iterations = 10)]
|
|
async fn test_reconnection(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let user_id = 5;
|
|
let client = cx.update(|cx| {
|
|
Client::new(
|
|
Arc::new(FakeSystemClock::new()),
|
|
FakeHttpClient::with_404_response(),
|
|
cx,
|
|
)
|
|
});
|
|
let server = FakeServer::for_client(user_id, &client, cx).await;
|
|
let mut status = client.status();
|
|
assert!(matches!(
|
|
status.next().await,
|
|
Some(Status::Connected { .. })
|
|
));
|
|
assert_eq!(server.auth_count(), 1);
|
|
|
|
server.forbid_connections();
|
|
server.disconnect();
|
|
while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
|
|
|
|
server.allow_connections();
|
|
cx.executor().advance_clock(Duration::from_secs(10));
|
|
while !matches!(status.next().await, Some(Status::Connected { .. })) {}
|
|
assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
|
|
|
|
server.forbid_connections();
|
|
server.disconnect();
|
|
while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
|
|
|
|
// Clear cached credentials after authentication fails
|
|
server.roll_access_token();
|
|
server.allow_connections();
|
|
cx.executor().run_until_parked();
|
|
cx.executor().advance_clock(Duration::from_secs(10));
|
|
while !matches!(status.next().await, Some(Status::Connected { .. })) {}
|
|
assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
|
|
}
|
|
|
|
#[gpui::test(iterations = 10)]
|
|
async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let user_id = 5;
|
|
let client = cx.update(|cx| {
|
|
Client::new(
|
|
Arc::new(FakeSystemClock::new()),
|
|
FakeHttpClient::with_404_response(),
|
|
cx,
|
|
)
|
|
});
|
|
let mut status = client.status();
|
|
|
|
// Time out when client tries to connect.
|
|
client.override_authenticate(move |cx| {
|
|
cx.background_spawn(async move {
|
|
Ok(Credentials {
|
|
user_id,
|
|
access_token: "token".into(),
|
|
})
|
|
})
|
|
});
|
|
client.override_establish_connection(|_, cx| {
|
|
cx.background_spawn(async move {
|
|
future::pending::<()>().await;
|
|
unreachable!()
|
|
})
|
|
});
|
|
let auth_and_connect = cx.spawn({
|
|
let client = client.clone();
|
|
|cx| async move { client.authenticate_and_connect(false, &cx).await }
|
|
});
|
|
executor.run_until_parked();
|
|
assert!(matches!(status.next().await, Some(Status::Connecting)));
|
|
|
|
executor.advance_clock(CONNECTION_TIMEOUT);
|
|
assert!(matches!(
|
|
status.next().await,
|
|
Some(Status::ConnectionError { .. })
|
|
));
|
|
auth_and_connect.await.unwrap_err();
|
|
|
|
// Allow the connection to be established.
|
|
let server = FakeServer::for_client(user_id, &client, cx).await;
|
|
assert!(matches!(
|
|
status.next().await,
|
|
Some(Status::Connected { .. })
|
|
));
|
|
|
|
// Disconnect client.
|
|
server.forbid_connections();
|
|
server.disconnect();
|
|
while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {}
|
|
|
|
// Time out when re-establishing the connection.
|
|
server.allow_connections();
|
|
client.override_establish_connection(|_, cx| {
|
|
cx.background_spawn(async move {
|
|
future::pending::<()>().await;
|
|
unreachable!()
|
|
})
|
|
});
|
|
executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY);
|
|
assert!(matches!(
|
|
status.next().await,
|
|
Some(Status::Reconnecting { .. })
|
|
));
|
|
|
|
executor.advance_clock(CONNECTION_TIMEOUT);
|
|
assert!(matches!(
|
|
status.next().await,
|
|
Some(Status::ReconnectionError { .. })
|
|
));
|
|
}
|
|
|
|
#[gpui::test(iterations = 10)]
|
|
async fn test_authenticating_more_than_once(
|
|
cx: &mut TestAppContext,
|
|
executor: BackgroundExecutor,
|
|
) {
|
|
init_test(cx);
|
|
let auth_count = Arc::new(Mutex::new(0));
|
|
let dropped_auth_count = Arc::new(Mutex::new(0));
|
|
let client = cx.update(|cx| {
|
|
Client::new(
|
|
Arc::new(FakeSystemClock::new()),
|
|
FakeHttpClient::with_404_response(),
|
|
cx,
|
|
)
|
|
});
|
|
client.override_authenticate({
|
|
let auth_count = auth_count.clone();
|
|
let dropped_auth_count = dropped_auth_count.clone();
|
|
move |cx| {
|
|
let auth_count = auth_count.clone();
|
|
let dropped_auth_count = dropped_auth_count.clone();
|
|
cx.background_spawn(async move {
|
|
*auth_count.lock() += 1;
|
|
let _drop = util::defer(move || *dropped_auth_count.lock() += 1);
|
|
future::pending::<()>().await;
|
|
unreachable!()
|
|
})
|
|
}
|
|
});
|
|
|
|
let _authenticate = cx.spawn({
|
|
let client = client.clone();
|
|
move |cx| async move { client.authenticate_and_connect(false, &cx).await }
|
|
});
|
|
executor.run_until_parked();
|
|
assert_eq!(*auth_count.lock(), 1);
|
|
assert_eq!(*dropped_auth_count.lock(), 0);
|
|
|
|
let _authenticate = cx.spawn({
|
|
let client = client.clone();
|
|
|cx| async move { client.authenticate_and_connect(false, &cx).await }
|
|
});
|
|
executor.run_until_parked();
|
|
assert_eq!(*auth_count.lock(), 2);
|
|
assert_eq!(*dropped_auth_count.lock(), 1);
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_subscribing_to_entity(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let user_id = 5;
|
|
let client = cx.update(|cx| {
|
|
Client::new(
|
|
Arc::new(FakeSystemClock::new()),
|
|
FakeHttpClient::with_404_response(),
|
|
cx,
|
|
)
|
|
});
|
|
let server = FakeServer::for_client(user_id, &client, cx).await;
|
|
|
|
let (done_tx1, done_rx1) = smol::channel::unbounded();
|
|
let (done_tx2, done_rx2) = smol::channel::unbounded();
|
|
AnyProtoClient::from(client.clone()).add_entity_message_handler(
|
|
move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::JoinProject>, mut cx| {
|
|
match entity.update(&mut cx, |entity, _| entity.id).unwrap() {
|
|
1 => done_tx1.try_send(()).unwrap(),
|
|
2 => done_tx2.try_send(()).unwrap(),
|
|
_ => unreachable!(),
|
|
}
|
|
async { Ok(()) }
|
|
},
|
|
);
|
|
let entity1 = cx.new(|_| TestEntity {
|
|
id: 1,
|
|
subscription: None,
|
|
});
|
|
let entity2 = cx.new(|_| TestEntity {
|
|
id: 2,
|
|
subscription: None,
|
|
});
|
|
let entity3 = cx.new(|_| TestEntity {
|
|
id: 3,
|
|
subscription: None,
|
|
});
|
|
|
|
let _subscription1 = client
|
|
.subscribe_to_entity(1)
|
|
.unwrap()
|
|
.set_entity(&entity1, &mut cx.to_async());
|
|
let _subscription2 = client
|
|
.subscribe_to_entity(2)
|
|
.unwrap()
|
|
.set_entity(&entity2, &mut cx.to_async());
|
|
// Ensure dropping a subscription for the same entity type still allows receiving of
|
|
// messages for other entity IDs of the same type.
|
|
let subscription3 = client
|
|
.subscribe_to_entity(3)
|
|
.unwrap()
|
|
.set_entity(&entity3, &mut cx.to_async());
|
|
drop(subscription3);
|
|
|
|
server.send(proto::JoinProject { project_id: 1 });
|
|
server.send(proto::JoinProject { project_id: 2 });
|
|
done_rx1.recv().await.unwrap();
|
|
done_rx2.recv().await.unwrap();
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let user_id = 5;
|
|
let client = cx.update(|cx| {
|
|
Client::new(
|
|
Arc::new(FakeSystemClock::new()),
|
|
FakeHttpClient::with_404_response(),
|
|
cx,
|
|
)
|
|
});
|
|
let server = FakeServer::for_client(user_id, &client, cx).await;
|
|
|
|
let entity = cx.new(|_| TestEntity::default());
|
|
let (done_tx1, _done_rx1) = smol::channel::unbounded();
|
|
let (done_tx2, done_rx2) = smol::channel::unbounded();
|
|
let subscription1 = client.add_message_handler(
|
|
entity.downgrade(),
|
|
move |_, _: TypedEnvelope<proto::Ping>, _| {
|
|
done_tx1.try_send(()).unwrap();
|
|
async { Ok(()) }
|
|
},
|
|
);
|
|
drop(subscription1);
|
|
let _subscription2 = client.add_message_handler(
|
|
entity.downgrade(),
|
|
move |_, _: TypedEnvelope<proto::Ping>, _| {
|
|
done_tx2.try_send(()).unwrap();
|
|
async { Ok(()) }
|
|
},
|
|
);
|
|
server.send(proto::Ping {});
|
|
done_rx2.recv().await.unwrap();
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) {
|
|
init_test(cx);
|
|
let user_id = 5;
|
|
let client = cx.update(|cx| {
|
|
Client::new(
|
|
Arc::new(FakeSystemClock::new()),
|
|
FakeHttpClient::with_404_response(),
|
|
cx,
|
|
)
|
|
});
|
|
let server = FakeServer::for_client(user_id, &client, cx).await;
|
|
|
|
let entity = cx.new(|_| TestEntity::default());
|
|
let (done_tx, done_rx) = smol::channel::unbounded();
|
|
let subscription = client.add_message_handler(
|
|
entity.clone().downgrade(),
|
|
move |entity: Entity<TestEntity>, _: TypedEnvelope<proto::Ping>, mut cx| {
|
|
entity
|
|
.update(&mut cx, |entity, _| entity.subscription.take())
|
|
.unwrap();
|
|
done_tx.try_send(()).unwrap();
|
|
async { Ok(()) }
|
|
},
|
|
);
|
|
entity.update(cx, |entity, _| {
|
|
entity.subscription = Some(subscription);
|
|
});
|
|
server.send(proto::Ping {});
|
|
done_rx.recv().await.unwrap();
|
|
}
|
|
|
|
#[derive(Default)]
|
|
struct TestEntity {
|
|
id: usize,
|
|
subscription: Option<Subscription>,
|
|
}
|
|
|
|
fn init_test(cx: &mut TestAppContext) {
|
|
cx.update(|cx| {
|
|
let settings_store = SettingsStore::test(cx);
|
|
cx.set_global(settings_store);
|
|
init_settings(cx);
|
|
});
|
|
}
|
|
}
|