Start separating authentication from connection to collab (#35471)

This pull request should be idempotent, but lays the groundwork for
avoiding to connect to collab in order to interact with AI features
provided by Zed.

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Antonio Scandurra 2025-08-01 19:37:38 +02:00 committed by GitHub
parent b01d1872cc
commit f888f3fc0b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
46 changed files with 653 additions and 855 deletions

View file

@ -17,7 +17,6 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
[dependencies]
anyhow.workspace = true
async-recursion = "0.3"
async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] }
base64.workspace = true
chrono = { workspace = true, features = ["serde"] }

View file

@ -1,14 +1,12 @@
#[cfg(any(test, feature = "test-support"))]
pub mod test;
mod cloud;
mod proxy;
pub mod telemetry;
pub mod user;
pub mod zed_urls;
use anyhow::{Context as _, Result, anyhow};
use async_recursion::async_recursion;
use async_tungstenite::tungstenite::{
client::IntoClientRequest,
error::Error as WebsocketError,
@ -52,7 +50,6 @@ use tokio::net::TcpStream;
use url::Url;
use util::{ConnectionResult, ResultExt};
pub use cloud::*;
pub use rpc::*;
pub use telemetry_events::Event;
pub use user::*;
@ -164,20 +161,8 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
let client = client.clone();
move |_: &SignIn, cx| {
if let Some(client) = client.upgrade() {
cx.spawn(
async move |cx| match client.authenticate_and_connect(true, &cx).await {
ConnectionResult::Timeout => {
log::error!("Initial authentication timed out");
}
ConnectionResult::ConnectionReset => {
log::error!("Initial authentication connection reset");
}
ConnectionResult::Result(r) => {
r.log_err();
}
},
)
.detach();
cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await)
.detach_and_log_err(cx);
}
}
});
@ -286,6 +271,8 @@ pub enum Status {
SignedOut,
UpgradeRequired,
Authenticating,
Authenticated,
AuthenticationError,
Connecting,
ConnectionError,
Connected {
@ -712,7 +699,7 @@ impl Client {
let mut delay = INITIAL_RECONNECTION_DELAY;
loop {
match client.authenticate_and_connect(true, &cx).await {
match client.connect(true, &cx).await {
ConnectionResult::Timeout => {
log::error!("client connect attempt timed out")
}
@ -882,17 +869,122 @@ impl Client {
.is_some()
}
#[async_recursion(?Send)]
pub async fn authenticate_and_connect(
pub async fn sign_in(
self: &Arc<Self>,
try_provider: bool,
cx: &AsyncApp,
) -> Result<Credentials> {
if self.status().borrow().is_signed_out() {
self.set_status(Status::Authenticating, cx);
} else {
self.set_status(Status::Reauthenticating, cx);
}
let mut credentials = None;
let old_credentials = self.state.read().credentials.clone();
if let Some(old_credentials) = old_credentials {
self.cloud_client.set_credentials(
old_credentials.user_id as u32,
old_credentials.access_token.clone(),
);
// Fetch the authenticated user with the old credentials, to ensure they are still valid.
if self.cloud_client.get_authenticated_user().await.is_ok() {
credentials = Some(old_credentials);
}
}
if credentials.is_none() && try_provider {
if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await {
self.cloud_client.set_credentials(
stored_credentials.user_id as u32,
stored_credentials.access_token.clone(),
);
// Fetch the authenticated user with the stored credentials, and
// clear them from the credentials provider if that fails.
if self.cloud_client.get_authenticated_user().await.is_ok() {
credentials = Some(stored_credentials);
} else {
self.credentials_provider
.delete_credentials(cx)
.await
.log_err();
}
}
}
if credentials.is_none() {
let mut status_rx = self.status();
let _ = status_rx.next().await;
futures::select_biased! {
authenticate = self.authenticate(cx).fuse() => {
match authenticate {
Ok(creds) => {
if IMPERSONATE_LOGIN.is_none() {
self.credentials_provider
.write_credentials(creds.user_id, creds.access_token.clone(), cx)
.await
.log_err();
}
credentials = Some(creds);
},
Err(err) => {
self.set_status(Status::AuthenticationError, cx);
return Err(err);
}
}
}
_ = status_rx.next().fuse() => {
return Err(anyhow!("authentication canceled"));
}
}
}
let credentials = credentials.unwrap();
self.set_id(credentials.user_id);
self.cloud_client
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
self.state.write().credentials = Some(credentials.clone());
self.set_status(Status::Authenticated, cx);
Ok(credentials)
}
/// Performs a sign-in and also connects to Collab.
///
/// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls
/// to `sign_in` when we're ready to remove auto-connection to Collab.
pub async fn sign_in_with_optional_connect(
self: &Arc<Self>,
try_provider: bool,
cx: &AsyncApp,
) -> Result<()> {
let credentials = self.sign_in(try_provider, cx).await?;
let connect_result = match self.connect_with_credentials(credentials, cx).await {
ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
ConnectionResult::Result(result) => result.context("client auth and connect"),
};
connect_result.log_err();
Ok(())
}
pub async fn connect(
self: &Arc<Self>,
try_provider: bool,
cx: &AsyncApp,
) -> ConnectionResult<()> {
let was_disconnected = match *self.status().borrow() {
Status::SignedOut => true,
Status::SignedOut | Status::Authenticated => true,
Status::ConnectionError
| Status::ConnectionLost
| Status::Authenticating { .. }
| Status::AuthenticationError
| Status::Reauthenticating { .. }
| Status::ReconnectionError { .. } => false,
Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
@ -905,41 +997,10 @@ impl Client {
);
}
};
if was_disconnected {
self.set_status(Status::Authenticating, cx);
} else {
self.set_status(Status::Reauthenticating, cx)
}
let mut read_from_provider = false;
let mut credentials = self.state.read().credentials.clone();
if credentials.is_none() && try_provider {
credentials = self.credentials_provider.read_credentials(cx).await;
read_from_provider = credentials.is_some();
}
if credentials.is_none() {
let mut status_rx = self.status();
let _ = status_rx.next().await;
futures::select_biased! {
authenticate = self.authenticate(cx).fuse() => {
match authenticate {
Ok(creds) => credentials = Some(creds),
Err(err) => {
self.set_status(Status::ConnectionError, cx);
return ConnectionResult::Result(Err(err));
}
}
}
_ = status_rx.next().fuse() => {
return ConnectionResult::Result(Err(anyhow!("authentication canceled")));
}
}
}
let credentials = credentials.unwrap();
self.set_id(credentials.user_id);
self.cloud_client
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
let credentials = match self.sign_in(try_provider, cx).await {
Ok(credentials) => credentials,
Err(err) => return ConnectionResult::Result(Err(err)),
};
if was_disconnected {
self.set_status(Status::Connecting, cx);
@ -947,17 +1008,20 @@ impl Client {
self.set_status(Status::Reconnecting, cx);
}
self.connect_with_credentials(credentials, cx).await
}
async fn connect_with_credentials(
self: &Arc<Self>,
credentials: Credentials,
cx: &AsyncApp,
) -> ConnectionResult<()> {
let mut timeout =
futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
futures::select_biased! {
connection = self.establish_connection(&credentials, cx).fuse() => {
match connection {
Ok(conn) => {
self.state.write().credentials = Some(credentials.clone());
if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err();
}
futures::select_biased! {
result = self.set_connection(conn, cx).fuse() => {
match result.context("client auth and connect") {
@ -975,15 +1039,8 @@ impl Client {
}
}
Err(EstablishConnectionError::Unauthorized) => {
self.state.write().credentials.take();
if read_from_provider {
self.credentials_provider.delete_credentials(cx).await.log_err();
self.set_status(Status::SignedOut, cx);
self.authenticate_and_connect(false, cx).await
} else {
self.set_status(Status::ConnectionError, cx);
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
}
self.set_status(Status::ConnectionError, cx);
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
}
Err(EstablishConnectionError::UpgradeRequired) => {
self.set_status(Status::UpgradeRequired, cx);
@ -1733,7 +1790,7 @@ mod tests {
});
let auth_and_connect = cx.spawn({
let client = client.clone();
|cx| async move { client.authenticate_and_connect(false, &cx).await }
|cx| async move { client.connect(false, &cx).await }
});
executor.run_until_parked();
assert!(matches!(status.next().await, Some(Status::Connecting)));
@ -1810,7 +1867,7 @@ mod tests {
let _authenticate = cx.spawn({
let client = client.clone();
move |cx| async move { client.authenticate_and_connect(false, &cx).await }
move |cx| async move { client.connect(false, &cx).await }
});
executor.run_until_parked();
assert_eq!(*auth_count.lock(), 1);
@ -1818,7 +1875,7 @@ mod tests {
let _authenticate = cx.spawn({
let client = client.clone();
|cx| async move { client.authenticate_and_connect(false, &cx).await }
|cx| async move { client.connect(false, &cx).await }
});
executor.run_until_parked();
assert_eq!(*auth_count.lock(), 2);

View file

@ -1,3 +0,0 @@
mod user_store;
pub use user_store::*;

View file

@ -1,211 +0,0 @@
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context as _;
use chrono::{DateTime, Utc};
use cloud_api_client::{AuthenticatedUser, CloudApiClient, GetAuthenticatedUserResponse, PlanInfo};
use cloud_llm_client::Plan;
use gpui::{Context, Entity, Subscription, Task};
use util::{ResultExt as _, maybe};
use crate::user::Event as RpcUserStoreEvent;
use crate::{EditPredictionUsage, ModelRequestUsage, RequestUsage, UserStore};
pub struct CloudUserStore {
cloud_client: Arc<CloudApiClient>,
authenticated_user: Option<Arc<AuthenticatedUser>>,
plan_info: Option<Arc<PlanInfo>>,
model_request_usage: Option<ModelRequestUsage>,
edit_prediction_usage: Option<EditPredictionUsage>,
_maintain_authenticated_user_task: Task<()>,
_rpc_plan_updated_subscription: Subscription,
}
impl CloudUserStore {
pub fn new(
cloud_client: Arc<CloudApiClient>,
rpc_user_store: Entity<UserStore>,
cx: &mut Context<Self>,
) -> Self {
let rpc_plan_updated_subscription =
cx.subscribe(&rpc_user_store, Self::handle_rpc_user_store_event);
Self {
cloud_client: cloud_client.clone(),
authenticated_user: None,
plan_info: None,
model_request_usage: None,
edit_prediction_usage: None,
_maintain_authenticated_user_task: cx.spawn(async move |this, cx| {
maybe!(async move {
loop {
let Some(this) = this.upgrade() else {
return anyhow::Ok(());
};
if cloud_client.has_credentials() {
let already_fetched_authenticated_user = this
.read_with(cx, |this, _cx| this.authenticated_user().is_some())
.unwrap_or(false);
if already_fetched_authenticated_user {
// We already fetched the authenticated user; nothing to do.
} else {
let authenticated_user_result = cloud_client
.get_authenticated_user()
.await
.context("failed to fetch authenticated user");
if let Some(response) = authenticated_user_result.log_err() {
this.update(cx, |this, _cx| {
this.update_authenticated_user(response);
})
.ok();
}
}
} else {
this.update(cx, |this, _cx| {
this.authenticated_user.take();
this.plan_info.take();
})
.ok();
}
cx.background_executor()
.timer(Duration::from_millis(100))
.await;
}
})
.await
.log_err();
}),
_rpc_plan_updated_subscription: rpc_plan_updated_subscription,
}
}
pub fn is_authenticated(&self) -> bool {
self.authenticated_user.is_some()
}
pub fn authenticated_user(&self) -> Option<Arc<AuthenticatedUser>> {
self.authenticated_user.clone()
}
pub fn plan(&self) -> Option<Plan> {
self.plan_info.as_ref().map(|plan| plan.plan)
}
pub fn subscription_period(&self) -> Option<(DateTime<Utc>, DateTime<Utc>)> {
self.plan_info
.as_ref()
.and_then(|plan| plan.subscription_period)
.map(|subscription_period| {
(
subscription_period.started_at.0,
subscription_period.ended_at.0,
)
})
}
pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
self.plan_info
.as_ref()
.and_then(|plan| plan.trial_started_at)
.map(|trial_started_at| trial_started_at.0)
}
pub fn has_accepted_tos(&self) -> bool {
self.authenticated_user
.as_ref()
.map(|user| user.accepted_tos_at.is_some())
.unwrap_or_default()
}
/// Returns whether the user's account is too new to use the service.
pub fn account_too_young(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.is_account_too_young)
.unwrap_or_default()
}
/// Returns whether the current user has overdue invoices and usage should be blocked.
pub fn has_overdue_invoices(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.has_overdue_invoices)
.unwrap_or_default()
}
pub fn is_usage_based_billing_enabled(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.is_usage_based_billing_enabled)
.unwrap_or_default()
}
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
self.model_request_usage
}
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
self.model_request_usage = Some(usage);
cx.notify();
}
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
self.edit_prediction_usage
}
pub fn update_edit_prediction_usage(
&mut self,
usage: EditPredictionUsage,
cx: &mut Context<Self>,
) {
self.edit_prediction_usage = Some(usage);
cx.notify();
}
fn update_authenticated_user(&mut self, response: GetAuthenticatedUserResponse) {
self.authenticated_user = Some(Arc::new(response.user));
self.model_request_usage = Some(ModelRequestUsage(RequestUsage {
limit: response.plan.usage.model_requests.limit,
amount: response.plan.usage.model_requests.used as i32,
}));
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
limit: response.plan.usage.edit_predictions.limit,
amount: response.plan.usage.edit_predictions.used as i32,
}));
self.plan_info = Some(Arc::new(response.plan));
}
fn handle_rpc_user_store_event(
&mut self,
_: Entity<UserStore>,
event: &RpcUserStoreEvent,
cx: &mut Context<Self>,
) {
match event {
RpcUserStoreEvent::PlanUpdated => {
cx.spawn(async move |this, cx| {
let cloud_client =
cx.update(|cx| this.read_with(cx, |this, _cx| this.cloud_client.clone()))??;
let response = cloud_client
.get_authenticated_user()
.await
.context("failed to fetch authenticated user")?;
cx.update(|cx| {
this.update(cx, |this, _cx| {
this.update_authenticated_user(response);
})
})??;
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
_ => {}
}
}
}

View file

@ -1,8 +1,11 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{Context as _, Result, anyhow};
use chrono::Duration;
use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
use futures::{StreamExt, stream::BoxStream};
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
use http_client::{AsyncBody, Method, Request, http};
use parking_lot::Mutex;
use rpc::{
ConnectionId, Peer, Receipt, TypedEnvelope,
@ -39,6 +42,44 @@ impl FakeServer {
executor: cx.executor(),
};
client.http_client().as_fake().replace_handler({
let state = server.state.clone();
move |old_handler, req| {
let state = state.clone();
let old_handler = old_handler.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::GET, "/client/users/me") => {
let credentials = parse_authorization_header(&req);
if credentials
!= Some(Credentials {
user_id: client_user_id,
access_token: state.lock().access_token.to_string(),
})
{
return Ok(http_client::Response::builder()
.status(401)
.body("Unauthorized".into())
.unwrap());
}
Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&make_get_authenticated_user_response(
client_user_id as i32,
format!("user-{client_user_id}"),
))
.unwrap()
.into(),
)
.unwrap())
}
_ => old_handler(req).await,
}
}
}
});
client
.override_authenticate({
let state = Arc::downgrade(&server.state);
@ -105,7 +146,7 @@ impl FakeServer {
});
client
.authenticate_and_connect(false, &cx.to_async())
.connect(false, &cx.to_async())
.await
.into_response()
.unwrap();
@ -223,3 +264,54 @@ impl Drop for FakeServer {
self.disconnect();
}
}
pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
let mut auth_header = req
.headers()
.get(http::header::AUTHORIZATION)?
.to_str()
.ok()?
.split_whitespace();
let user_id = auth_header.next()?.parse().ok()?;
let access_token = auth_header.next()?;
Some(Credentials {
user_id,
access_token: access_token.to_string(),
})
}
pub fn make_get_authenticated_user_response(
user_id: i32,
github_login: String,
) -> GetAuthenticatedUserResponse {
GetAuthenticatedUserResponse {
user: AuthenticatedUser {
id: user_id,
metrics_id: format!("metrics-id-{user_id}"),
avatar_url: "".to_string(),
github_login,
name: None,
is_staff: false,
accepted_tos_at: None,
},
feature_flags: vec![],
plan: PlanInfo {
plan: Plan::ZedPro,
subscription_period: None,
usage: CurrentUsage {
model_requests: UsageData {
used: 0,
limit: UsageLimit::Limited(500),
},
edit_predictions: UsageData {
used: 250,
limit: UsageLimit::Unlimited,
},
},
trial_started_at: None,
is_usage_based_billing_enabled: false,
is_account_too_young: false,
has_overdue_invoices: false,
},
}
}

View file

@ -1,6 +1,7 @@
use super::{Client, Status, TypedEnvelope, proto};
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo};
use cloud_llm_client::{
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
@ -20,7 +21,7 @@ use std::{
sync::{Arc, Weak},
};
use text::ReplicaId;
use util::TryFutureExt as _;
use util::{ResultExt, TryFutureExt as _};
pub type UserId = u64;
@ -110,12 +111,11 @@ pub struct UserStore {
by_github_login: HashMap<SharedString, u64>,
participant_indices: HashMap<u64, ParticipantIndex>,
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
current_plan: Option<proto::Plan>,
trial_started_at: Option<DateTime<Utc>>,
is_usage_based_billing_enabled: Option<bool>,
account_too_young: Option<bool>,
model_request_usage: Option<ModelRequestUsage>,
edit_prediction_usage: Option<EditPredictionUsage>,
plan_info: Option<PlanInfo>,
current_user: watch::Receiver<Option<Arc<User>>>,
accepted_tos_at: Option<Option<DateTime<Utc>>>,
accepted_tos_at: Option<Option<cloud_api_client::Timestamp>>,
contacts: Vec<Arc<Contact>>,
incoming_contact_requests: Vec<Arc<User>>,
outgoing_contact_requests: Vec<Arc<User>>,
@ -185,10 +185,9 @@ impl UserStore {
users: Default::default(),
by_github_login: Default::default(),
current_user: current_user_rx,
current_plan: None,
trial_started_at: None,
is_usage_based_billing_enabled: None,
account_too_young: None,
plan_info: None,
model_request_usage: None,
edit_prediction_usage: None,
accepted_tos_at: None,
contacts: Default::default(),
incoming_contact_requests: Default::default(),
@ -218,53 +217,30 @@ impl UserStore {
return Ok(());
};
match status {
Status::Connected { .. } => {
Status::Authenticated | Status::Connected { .. } => {
if let Some(user_id) = client.user_id() {
let fetch_user = if let Ok(fetch_user) =
this.update(cx, |this, cx| this.get_user(user_id, cx).log_err())
{
fetch_user
} else {
break;
};
let fetch_private_user_info =
client.request(proto::GetPrivateUserInfo {}).log_err();
let (user, info) =
futures::join!(fetch_user, fetch_private_user_info);
let response = client.cloud_client().get_authenticated_user().await;
let mut current_user = None;
cx.update(|cx| {
if let Some(info) = info {
let staff =
info.staff && !*feature_flags::ZED_DISABLE_STAFF;
cx.update_flags(staff, info.flags);
client.telemetry.set_authenticated_user_info(
Some(info.metrics_id.clone()),
staff,
);
if let Some(response) = response.log_err() {
let user = Arc::new(User {
id: user_id,
github_login: response.user.github_login.clone().into(),
avatar_uri: response.user.avatar_url.clone().into(),
name: response.user.name.clone(),
});
current_user = Some(user.clone());
this.update(cx, |this, cx| {
let accepted_tos_at = {
#[cfg(debug_assertions)]
if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok()
{
None
} else {
info.accepted_tos_at
}
#[cfg(not(debug_assertions))]
info.accepted_tos_at
};
this.set_current_user_accepted_tos_at(accepted_tos_at);
cx.emit(Event::PrivateUserInfoUpdated);
this.by_github_login
.insert(user.github_login.clone(), user_id);
this.users.insert(user_id, user);
this.update_authenticated_user(response, cx)
})
} else {
anyhow::Ok(())
}
})??;
current_user_tx.send(user).await.ok();
current_user_tx.send(current_user).await.ok();
this.update(cx, |_, cx| cx.notify())?;
}
@ -345,22 +321,22 @@ impl UserStore {
async fn handle_update_plan(
this: Entity<Self>,
message: TypedEnvelope<proto::UpdateUserPlan>,
_message: TypedEnvelope<proto::UpdateUserPlan>,
mut cx: AsyncApp,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
this.current_plan = Some(message.payload.plan());
this.trial_started_at = message
.payload
.trial_started_at
.and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0));
this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled;
this.account_too_young = message.payload.account_too_young;
let client = this
.read_with(&cx, |this, _| this.client.upgrade())?
.context("client was dropped")?;
cx.emit(Event::PlanUpdated);
cx.notify();
})?;
Ok(())
let response = client
.cloud_client()
.get_authenticated_user()
.await
.context("failed to fetch authenticated user")?;
this.update(&mut cx, |this, cx| {
this.update_authenticated_user(response, cx);
})
}
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
@ -719,42 +695,131 @@ impl UserStore {
self.current_user.borrow().clone()
}
pub fn current_plan(&self) -> Option<proto::Plan> {
pub fn plan(&self) -> Option<cloud_llm_client::Plan> {
#[cfg(debug_assertions)]
if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() {
return match plan.as_str() {
"free" => Some(proto::Plan::Free),
"trial" => Some(proto::Plan::ZedProTrial),
"pro" => Some(proto::Plan::ZedPro),
"free" => Some(cloud_llm_client::Plan::ZedFree),
"trial" => Some(cloud_llm_client::Plan::ZedProTrial),
"pro" => Some(cloud_llm_client::Plan::ZedPro),
_ => {
panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'");
}
};
}
self.current_plan
self.plan_info.as_ref().map(|info| info.plan)
}
pub fn subscription_period(&self) -> Option<(DateTime<Utc>, DateTime<Utc>)> {
self.plan_info
.as_ref()
.and_then(|plan| plan.subscription_period)
.map(|subscription_period| {
(
subscription_period.started_at.0,
subscription_period.ended_at.0,
)
})
}
pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
self.trial_started_at
self.plan_info
.as_ref()
.and_then(|plan| plan.trial_started_at)
.map(|trial_started_at| trial_started_at.0)
}
pub fn usage_based_billing_enabled(&self) -> Option<bool> {
self.is_usage_based_billing_enabled
/// Returns whether the user's account is too new to use the service.
pub fn account_too_young(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.is_account_too_young)
.unwrap_or_default()
}
/// Returns whether the current user has overdue invoices and usage should be blocked.
pub fn has_overdue_invoices(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.has_overdue_invoices)
.unwrap_or_default()
}
pub fn is_usage_based_billing_enabled(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.is_usage_based_billing_enabled)
.unwrap_or_default()
}
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
self.model_request_usage
}
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
self.model_request_usage = Some(usage);
cx.notify();
}
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
self.edit_prediction_usage
}
pub fn update_edit_prediction_usage(
&mut self,
usage: EditPredictionUsage,
cx: &mut Context<Self>,
) {
self.edit_prediction_usage = Some(usage);
cx.notify();
}
fn update_authenticated_user(
&mut self,
response: GetAuthenticatedUserResponse,
cx: &mut Context<Self>,
) {
let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF;
cx.update_flags(staff, response.feature_flags);
if let Some(client) = self.client.upgrade() {
client
.telemetry
.set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff);
}
let accepted_tos_at = {
#[cfg(debug_assertions)]
if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() {
None
} else {
response.user.accepted_tos_at
}
#[cfg(not(debug_assertions))]
response.user.accepted_tos_at
};
self.accepted_tos_at = Some(accepted_tos_at);
self.model_request_usage = Some(ModelRequestUsage(RequestUsage {
limit: response.plan.usage.model_requests.limit,
amount: response.plan.usage.model_requests.used as i32,
}));
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
limit: response.plan.usage.edit_predictions.limit,
amount: response.plan.usage.edit_predictions.used as i32,
}));
self.plan_info = Some(response.plan);
cx.emit(Event::PrivateUserInfoUpdated);
}
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
self.current_user.clone()
}
/// Returns whether the user's account is too new to use the service.
pub fn account_too_young(&self) -> bool {
self.account_too_young.unwrap_or(false)
}
pub fn current_user_has_accepted_terms(&self) -> Option<bool> {
pub fn has_accepted_terms_of_service(&self) -> bool {
self.accepted_tos_at
.map(|accepted_tos_at| accepted_tos_at.is_some())
.map_or(false, |accepted_tos_at| accepted_tos_at.is_some())
}
pub fn accept_terms_of_service(&self, cx: &Context<Self>) -> Task<Result<()>> {
@ -766,23 +831,18 @@ impl UserStore {
cx.spawn(async move |this, cx| -> anyhow::Result<()> {
let client = client.upgrade().context("client not found")?;
let response = client
.request(proto::AcceptTermsOfService {})
.cloud_client()
.accept_terms_of_service()
.await
.context("error accepting tos")?;
this.update(cx, |this, cx| {
this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at));
this.accepted_tos_at = Some(response.user.accepted_tos_at);
cx.emit(Event::PrivateUserInfoUpdated);
})?;
Ok(())
})
}
fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option<u64>) {
self.accepted_tos_at = Some(
accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)),
);
}
fn load_users(
&self,
request: impl RequestMessage<Response = UsersResponse>,