Start separating authentication from connection to collab (#35471)
This pull request should be idempotent, but lays the groundwork for avoiding to connect to collab in order to interact with AI features provided by Zed. Release Notes: - N/A --------- Co-authored-by: Marshall Bowers <git@maxdeviant.com> Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
parent
b01d1872cc
commit
f888f3fc0b
46 changed files with 653 additions and 855 deletions
|
@ -17,7 +17,6 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
|
|||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-recursion = "0.3"
|
||||
async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] }
|
||||
base64.workspace = true
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod test;
|
||||
|
||||
mod cloud;
|
||||
mod proxy;
|
||||
pub mod telemetry;
|
||||
pub mod user;
|
||||
pub mod zed_urls;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use async_recursion::async_recursion;
|
||||
use async_tungstenite::tungstenite::{
|
||||
client::IntoClientRequest,
|
||||
error::Error as WebsocketError,
|
||||
|
@ -52,7 +50,6 @@ use tokio::net::TcpStream;
|
|||
use url::Url;
|
||||
use util::{ConnectionResult, ResultExt};
|
||||
|
||||
pub use cloud::*;
|
||||
pub use rpc::*;
|
||||
pub use telemetry_events::Event;
|
||||
pub use user::*;
|
||||
|
@ -164,20 +161,8 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
|
|||
let client = client.clone();
|
||||
move |_: &SignIn, cx| {
|
||||
if let Some(client) = client.upgrade() {
|
||||
cx.spawn(
|
||||
async move |cx| match client.authenticate_and_connect(true, &cx).await {
|
||||
ConnectionResult::Timeout => {
|
||||
log::error!("Initial authentication timed out");
|
||||
}
|
||||
ConnectionResult::ConnectionReset => {
|
||||
log::error!("Initial authentication connection reset");
|
||||
}
|
||||
ConnectionResult::Result(r) => {
|
||||
r.log_err();
|
||||
}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
@ -286,6 +271,8 @@ pub enum Status {
|
|||
SignedOut,
|
||||
UpgradeRequired,
|
||||
Authenticating,
|
||||
Authenticated,
|
||||
AuthenticationError,
|
||||
Connecting,
|
||||
ConnectionError,
|
||||
Connected {
|
||||
|
@ -712,7 +699,7 @@ impl Client {
|
|||
|
||||
let mut delay = INITIAL_RECONNECTION_DELAY;
|
||||
loop {
|
||||
match client.authenticate_and_connect(true, &cx).await {
|
||||
match client.connect(true, &cx).await {
|
||||
ConnectionResult::Timeout => {
|
||||
log::error!("client connect attempt timed out")
|
||||
}
|
||||
|
@ -882,17 +869,122 @@ impl Client {
|
|||
.is_some()
|
||||
}
|
||||
|
||||
#[async_recursion(?Send)]
|
||||
pub async fn authenticate_and_connect(
|
||||
pub async fn sign_in(
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Credentials> {
|
||||
if self.status().borrow().is_signed_out() {
|
||||
self.set_status(Status::Authenticating, cx);
|
||||
} else {
|
||||
self.set_status(Status::Reauthenticating, cx);
|
||||
}
|
||||
|
||||
let mut credentials = None;
|
||||
|
||||
let old_credentials = self.state.read().credentials.clone();
|
||||
if let Some(old_credentials) = old_credentials {
|
||||
self.cloud_client.set_credentials(
|
||||
old_credentials.user_id as u32,
|
||||
old_credentials.access_token.clone(),
|
||||
);
|
||||
|
||||
// Fetch the authenticated user with the old credentials, to ensure they are still valid.
|
||||
if self.cloud_client.get_authenticated_user().await.is_ok() {
|
||||
credentials = Some(old_credentials);
|
||||
}
|
||||
}
|
||||
|
||||
if credentials.is_none() && try_provider {
|
||||
if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await {
|
||||
self.cloud_client.set_credentials(
|
||||
stored_credentials.user_id as u32,
|
||||
stored_credentials.access_token.clone(),
|
||||
);
|
||||
|
||||
// Fetch the authenticated user with the stored credentials, and
|
||||
// clear them from the credentials provider if that fails.
|
||||
if self.cloud_client.get_authenticated_user().await.is_ok() {
|
||||
credentials = Some(stored_credentials);
|
||||
} else {
|
||||
self.credentials_provider
|
||||
.delete_credentials(cx)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if credentials.is_none() {
|
||||
let mut status_rx = self.status();
|
||||
let _ = status_rx.next().await;
|
||||
futures::select_biased! {
|
||||
authenticate = self.authenticate(cx).fuse() => {
|
||||
match authenticate {
|
||||
Ok(creds) => {
|
||||
if IMPERSONATE_LOGIN.is_none() {
|
||||
self.credentials_provider
|
||||
.write_credentials(creds.user_id, creds.access_token.clone(), cx)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
|
||||
credentials = Some(creds);
|
||||
},
|
||||
Err(err) => {
|
||||
self.set_status(Status::AuthenticationError, cx);
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = status_rx.next().fuse() => {
|
||||
return Err(anyhow!("authentication canceled"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let credentials = credentials.unwrap();
|
||||
self.set_id(credentials.user_id);
|
||||
self.cloud_client
|
||||
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
|
||||
self.state.write().credentials = Some(credentials.clone());
|
||||
self.set_status(Status::Authenticated, cx);
|
||||
|
||||
Ok(credentials)
|
||||
}
|
||||
|
||||
/// Performs a sign-in and also connects to Collab.
|
||||
///
|
||||
/// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls
|
||||
/// to `sign_in` when we're ready to remove auto-connection to Collab.
|
||||
pub async fn sign_in_with_optional_connect(
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<()> {
|
||||
let credentials = self.sign_in(try_provider, cx).await?;
|
||||
|
||||
let connect_result = match self.connect_with_credentials(credentials, cx).await {
|
||||
ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
|
||||
ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
|
||||
ConnectionResult::Result(result) => result.context("client auth and connect"),
|
||||
};
|
||||
connect_result.log_err();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> ConnectionResult<()> {
|
||||
let was_disconnected = match *self.status().borrow() {
|
||||
Status::SignedOut => true,
|
||||
Status::SignedOut | Status::Authenticated => true,
|
||||
Status::ConnectionError
|
||||
| Status::ConnectionLost
|
||||
| Status::Authenticating { .. }
|
||||
| Status::AuthenticationError
|
||||
| Status::Reauthenticating { .. }
|
||||
| Status::ReconnectionError { .. } => false,
|
||||
Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
|
||||
|
@ -905,41 +997,10 @@ impl Client {
|
|||
);
|
||||
}
|
||||
};
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Authenticating, cx);
|
||||
} else {
|
||||
self.set_status(Status::Reauthenticating, cx)
|
||||
}
|
||||
|
||||
let mut read_from_provider = false;
|
||||
let mut credentials = self.state.read().credentials.clone();
|
||||
if credentials.is_none() && try_provider {
|
||||
credentials = self.credentials_provider.read_credentials(cx).await;
|
||||
read_from_provider = credentials.is_some();
|
||||
}
|
||||
|
||||
if credentials.is_none() {
|
||||
let mut status_rx = self.status();
|
||||
let _ = status_rx.next().await;
|
||||
futures::select_biased! {
|
||||
authenticate = self.authenticate(cx).fuse() => {
|
||||
match authenticate {
|
||||
Ok(creds) => credentials = Some(creds),
|
||||
Err(err) => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
return ConnectionResult::Result(Err(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = status_rx.next().fuse() => {
|
||||
return ConnectionResult::Result(Err(anyhow!("authentication canceled")));
|
||||
}
|
||||
}
|
||||
}
|
||||
let credentials = credentials.unwrap();
|
||||
self.set_id(credentials.user_id);
|
||||
self.cloud_client
|
||||
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
|
||||
let credentials = match self.sign_in(try_provider, cx).await {
|
||||
Ok(credentials) => credentials,
|
||||
Err(err) => return ConnectionResult::Result(Err(err)),
|
||||
};
|
||||
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Connecting, cx);
|
||||
|
@ -947,17 +1008,20 @@ impl Client {
|
|||
self.set_status(Status::Reconnecting, cx);
|
||||
}
|
||||
|
||||
self.connect_with_credentials(credentials, cx).await
|
||||
}
|
||||
|
||||
async fn connect_with_credentials(
|
||||
self: &Arc<Self>,
|
||||
credentials: Credentials,
|
||||
cx: &AsyncApp,
|
||||
) -> ConnectionResult<()> {
|
||||
let mut timeout =
|
||||
futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
|
||||
futures::select_biased! {
|
||||
connection = self.establish_connection(&credentials, cx).fuse() => {
|
||||
match connection {
|
||||
Ok(conn) => {
|
||||
self.state.write().credentials = Some(credentials.clone());
|
||||
if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
|
||||
self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err();
|
||||
}
|
||||
|
||||
futures::select_biased! {
|
||||
result = self.set_connection(conn, cx).fuse() => {
|
||||
match result.context("client auth and connect") {
|
||||
|
@ -975,15 +1039,8 @@ impl Client {
|
|||
}
|
||||
}
|
||||
Err(EstablishConnectionError::Unauthorized) => {
|
||||
self.state.write().credentials.take();
|
||||
if read_from_provider {
|
||||
self.credentials_provider.delete_credentials(cx).await.log_err();
|
||||
self.set_status(Status::SignedOut, cx);
|
||||
self.authenticate_and_connect(false, cx).await
|
||||
} else {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
|
||||
}
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
|
||||
}
|
||||
Err(EstablishConnectionError::UpgradeRequired) => {
|
||||
self.set_status(Status::UpgradeRequired, cx);
|
||||
|
@ -1733,7 +1790,7 @@ mod tests {
|
|||
});
|
||||
let auth_and_connect = cx.spawn({
|
||||
let client = client.clone();
|
||||
|cx| async move { client.authenticate_and_connect(false, &cx).await }
|
||||
|cx| async move { client.connect(false, &cx).await }
|
||||
});
|
||||
executor.run_until_parked();
|
||||
assert!(matches!(status.next().await, Some(Status::Connecting)));
|
||||
|
@ -1810,7 +1867,7 @@ mod tests {
|
|||
|
||||
let _authenticate = cx.spawn({
|
||||
let client = client.clone();
|
||||
move |cx| async move { client.authenticate_and_connect(false, &cx).await }
|
||||
move |cx| async move { client.connect(false, &cx).await }
|
||||
});
|
||||
executor.run_until_parked();
|
||||
assert_eq!(*auth_count.lock(), 1);
|
||||
|
@ -1818,7 +1875,7 @@ mod tests {
|
|||
|
||||
let _authenticate = cx.spawn({
|
||||
let client = client.clone();
|
||||
|cx| async move { client.authenticate_and_connect(false, &cx).await }
|
||||
|cx| async move { client.connect(false, &cx).await }
|
||||
});
|
||||
executor.run_until_parked();
|
||||
assert_eq!(*auth_count.lock(), 2);
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
mod user_store;
|
||||
|
||||
pub use user_store::*;
|
|
@ -1,211 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use chrono::{DateTime, Utc};
|
||||
use cloud_api_client::{AuthenticatedUser, CloudApiClient, GetAuthenticatedUserResponse, PlanInfo};
|
||||
use cloud_llm_client::Plan;
|
||||
use gpui::{Context, Entity, Subscription, Task};
|
||||
use util::{ResultExt as _, maybe};
|
||||
|
||||
use crate::user::Event as RpcUserStoreEvent;
|
||||
use crate::{EditPredictionUsage, ModelRequestUsage, RequestUsage, UserStore};
|
||||
|
||||
pub struct CloudUserStore {
|
||||
cloud_client: Arc<CloudApiClient>,
|
||||
authenticated_user: Option<Arc<AuthenticatedUser>>,
|
||||
plan_info: Option<Arc<PlanInfo>>,
|
||||
model_request_usage: Option<ModelRequestUsage>,
|
||||
edit_prediction_usage: Option<EditPredictionUsage>,
|
||||
_maintain_authenticated_user_task: Task<()>,
|
||||
_rpc_plan_updated_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl CloudUserStore {
|
||||
pub fn new(
|
||||
cloud_client: Arc<CloudApiClient>,
|
||||
rpc_user_store: Entity<UserStore>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let rpc_plan_updated_subscription =
|
||||
cx.subscribe(&rpc_user_store, Self::handle_rpc_user_store_event);
|
||||
|
||||
Self {
|
||||
cloud_client: cloud_client.clone(),
|
||||
authenticated_user: None,
|
||||
plan_info: None,
|
||||
model_request_usage: None,
|
||||
edit_prediction_usage: None,
|
||||
_maintain_authenticated_user_task: cx.spawn(async move |this, cx| {
|
||||
maybe!(async move {
|
||||
loop {
|
||||
let Some(this) = this.upgrade() else {
|
||||
return anyhow::Ok(());
|
||||
};
|
||||
|
||||
if cloud_client.has_credentials() {
|
||||
let already_fetched_authenticated_user = this
|
||||
.read_with(cx, |this, _cx| this.authenticated_user().is_some())
|
||||
.unwrap_or(false);
|
||||
|
||||
if already_fetched_authenticated_user {
|
||||
// We already fetched the authenticated user; nothing to do.
|
||||
} else {
|
||||
let authenticated_user_result = cloud_client
|
||||
.get_authenticated_user()
|
||||
.await
|
||||
.context("failed to fetch authenticated user");
|
||||
if let Some(response) = authenticated_user_result.log_err() {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.update_authenticated_user(response);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.authenticated_user.take();
|
||||
this.plan_info.take();
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
cx.background_executor()
|
||||
.timer(Duration::from_millis(100))
|
||||
.await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.log_err();
|
||||
}),
|
||||
_rpc_plan_updated_subscription: rpc_plan_updated_subscription,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.authenticated_user.is_some()
|
||||
}
|
||||
|
||||
pub fn authenticated_user(&self) -> Option<Arc<AuthenticatedUser>> {
|
||||
self.authenticated_user.clone()
|
||||
}
|
||||
|
||||
pub fn plan(&self) -> Option<Plan> {
|
||||
self.plan_info.as_ref().map(|plan| plan.plan)
|
||||
}
|
||||
|
||||
pub fn subscription_period(&self) -> Option<(DateTime<Utc>, DateTime<Utc>)> {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.and_then(|plan| plan.subscription_period)
|
||||
.map(|subscription_period| {
|
||||
(
|
||||
subscription_period.started_at.0,
|
||||
subscription_period.ended_at.0,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.and_then(|plan| plan.trial_started_at)
|
||||
.map(|trial_started_at| trial_started_at.0)
|
||||
}
|
||||
|
||||
pub fn has_accepted_tos(&self) -> bool {
|
||||
self.authenticated_user
|
||||
.as_ref()
|
||||
.map(|user| user.accepted_tos_at.is_some())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Returns whether the user's account is too new to use the service.
|
||||
pub fn account_too_young(&self) -> bool {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.map(|plan| plan.is_account_too_young)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Returns whether the current user has overdue invoices and usage should be blocked.
|
||||
pub fn has_overdue_invoices(&self) -> bool {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.map(|plan| plan.has_overdue_invoices)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn is_usage_based_billing_enabled(&self) -> bool {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.map(|plan| plan.is_usage_based_billing_enabled)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
|
||||
self.model_request_usage
|
||||
}
|
||||
|
||||
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
|
||||
self.model_request_usage = Some(usage);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
|
||||
self.edit_prediction_usage
|
||||
}
|
||||
|
||||
pub fn update_edit_prediction_usage(
|
||||
&mut self,
|
||||
usage: EditPredictionUsage,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.edit_prediction_usage = Some(usage);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn update_authenticated_user(&mut self, response: GetAuthenticatedUserResponse) {
|
||||
self.authenticated_user = Some(Arc::new(response.user));
|
||||
self.model_request_usage = Some(ModelRequestUsage(RequestUsage {
|
||||
limit: response.plan.usage.model_requests.limit,
|
||||
amount: response.plan.usage.model_requests.used as i32,
|
||||
}));
|
||||
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
|
||||
limit: response.plan.usage.edit_predictions.limit,
|
||||
amount: response.plan.usage.edit_predictions.used as i32,
|
||||
}));
|
||||
self.plan_info = Some(Arc::new(response.plan));
|
||||
}
|
||||
|
||||
fn handle_rpc_user_store_event(
|
||||
&mut self,
|
||||
_: Entity<UserStore>,
|
||||
event: &RpcUserStoreEvent,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
RpcUserStoreEvent::PlanUpdated => {
|
||||
cx.spawn(async move |this, cx| {
|
||||
let cloud_client =
|
||||
cx.update(|cx| this.read_with(cx, |this, _cx| this.cloud_client.clone()))??;
|
||||
|
||||
let response = cloud_client
|
||||
.get_authenticated_user()
|
||||
.await
|
||||
.context("failed to fetch authenticated user")?;
|
||||
|
||||
cx.update(|cx| {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.update_authenticated_user(response);
|
||||
})
|
||||
})??;
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,8 +1,11 @@
|
|||
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::Duration;
|
||||
use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
|
||||
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
|
||||
use futures::{StreamExt, stream::BoxStream};
|
||||
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
|
||||
use http_client::{AsyncBody, Method, Request, http};
|
||||
use parking_lot::Mutex;
|
||||
use rpc::{
|
||||
ConnectionId, Peer, Receipt, TypedEnvelope,
|
||||
|
@ -39,6 +42,44 @@ impl FakeServer {
|
|||
executor: cx.executor(),
|
||||
};
|
||||
|
||||
client.http_client().as_fake().replace_handler({
|
||||
let state = server.state.clone();
|
||||
move |old_handler, req| {
|
||||
let state = state.clone();
|
||||
let old_handler = old_handler.clone();
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::GET, "/client/users/me") => {
|
||||
let credentials = parse_authorization_header(&req);
|
||||
if credentials
|
||||
!= Some(Credentials {
|
||||
user_id: client_user_id,
|
||||
access_token: state.lock().access_token.to_string(),
|
||||
})
|
||||
{
|
||||
return Ok(http_client::Response::builder()
|
||||
.status(401)
|
||||
.body("Unauthorized".into())
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
Ok(http_client::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
serde_json::to_string(&make_get_authenticated_user_response(
|
||||
client_user_id as i32,
|
||||
format!("user-{client_user_id}"),
|
||||
))
|
||||
.unwrap()
|
||||
.into(),
|
||||
)
|
||||
.unwrap())
|
||||
}
|
||||
_ => old_handler(req).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
client
|
||||
.override_authenticate({
|
||||
let state = Arc::downgrade(&server.state);
|
||||
|
@ -105,7 +146,7 @@ impl FakeServer {
|
|||
});
|
||||
|
||||
client
|
||||
.authenticate_and_connect(false, &cx.to_async())
|
||||
.connect(false, &cx.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
|
@ -223,3 +264,54 @@ impl Drop for FakeServer {
|
|||
self.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
|
||||
let mut auth_header = req
|
||||
.headers()
|
||||
.get(http::header::AUTHORIZATION)?
|
||||
.to_str()
|
||||
.ok()?
|
||||
.split_whitespace();
|
||||
let user_id = auth_header.next()?.parse().ok()?;
|
||||
let access_token = auth_header.next()?;
|
||||
Some(Credentials {
|
||||
user_id,
|
||||
access_token: access_token.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn make_get_authenticated_user_response(
|
||||
user_id: i32,
|
||||
github_login: String,
|
||||
) -> GetAuthenticatedUserResponse {
|
||||
GetAuthenticatedUserResponse {
|
||||
user: AuthenticatedUser {
|
||||
id: user_id,
|
||||
metrics_id: format!("metrics-id-{user_id}"),
|
||||
avatar_url: "".to_string(),
|
||||
github_login,
|
||||
name: None,
|
||||
is_staff: false,
|
||||
accepted_tos_at: None,
|
||||
},
|
||||
feature_flags: vec![],
|
||||
plan: PlanInfo {
|
||||
plan: Plan::ZedPro,
|
||||
subscription_period: None,
|
||||
usage: CurrentUsage {
|
||||
model_requests: UsageData {
|
||||
used: 0,
|
||||
limit: UsageLimit::Limited(500),
|
||||
},
|
||||
edit_predictions: UsageData {
|
||||
used: 250,
|
||||
limit: UsageLimit::Unlimited,
|
||||
},
|
||||
},
|
||||
trial_started_at: None,
|
||||
is_usage_based_billing_enabled: false,
|
||||
is_account_too_young: false,
|
||||
has_overdue_invoices: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use super::{Client, Status, TypedEnvelope, proto};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo};
|
||||
use cloud_llm_client::{
|
||||
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
|
||||
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
|
||||
|
@ -20,7 +21,7 @@ use std::{
|
|||
sync::{Arc, Weak},
|
||||
};
|
||||
use text::ReplicaId;
|
||||
use util::TryFutureExt as _;
|
||||
use util::{ResultExt, TryFutureExt as _};
|
||||
|
||||
pub type UserId = u64;
|
||||
|
||||
|
@ -110,12 +111,11 @@ pub struct UserStore {
|
|||
by_github_login: HashMap<SharedString, u64>,
|
||||
participant_indices: HashMap<u64, ParticipantIndex>,
|
||||
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
|
||||
current_plan: Option<proto::Plan>,
|
||||
trial_started_at: Option<DateTime<Utc>>,
|
||||
is_usage_based_billing_enabled: Option<bool>,
|
||||
account_too_young: Option<bool>,
|
||||
model_request_usage: Option<ModelRequestUsage>,
|
||||
edit_prediction_usage: Option<EditPredictionUsage>,
|
||||
plan_info: Option<PlanInfo>,
|
||||
current_user: watch::Receiver<Option<Arc<User>>>,
|
||||
accepted_tos_at: Option<Option<DateTime<Utc>>>,
|
||||
accepted_tos_at: Option<Option<cloud_api_client::Timestamp>>,
|
||||
contacts: Vec<Arc<Contact>>,
|
||||
incoming_contact_requests: Vec<Arc<User>>,
|
||||
outgoing_contact_requests: Vec<Arc<User>>,
|
||||
|
@ -185,10 +185,9 @@ impl UserStore {
|
|||
users: Default::default(),
|
||||
by_github_login: Default::default(),
|
||||
current_user: current_user_rx,
|
||||
current_plan: None,
|
||||
trial_started_at: None,
|
||||
is_usage_based_billing_enabled: None,
|
||||
account_too_young: None,
|
||||
plan_info: None,
|
||||
model_request_usage: None,
|
||||
edit_prediction_usage: None,
|
||||
accepted_tos_at: None,
|
||||
contacts: Default::default(),
|
||||
incoming_contact_requests: Default::default(),
|
||||
|
@ -218,53 +217,30 @@ impl UserStore {
|
|||
return Ok(());
|
||||
};
|
||||
match status {
|
||||
Status::Connected { .. } => {
|
||||
Status::Authenticated | Status::Connected { .. } => {
|
||||
if let Some(user_id) = client.user_id() {
|
||||
let fetch_user = if let Ok(fetch_user) =
|
||||
this.update(cx, |this, cx| this.get_user(user_id, cx).log_err())
|
||||
{
|
||||
fetch_user
|
||||
} else {
|
||||
break;
|
||||
};
|
||||
let fetch_private_user_info =
|
||||
client.request(proto::GetPrivateUserInfo {}).log_err();
|
||||
let (user, info) =
|
||||
futures::join!(fetch_user, fetch_private_user_info);
|
||||
|
||||
let response = client.cloud_client().get_authenticated_user().await;
|
||||
let mut current_user = None;
|
||||
cx.update(|cx| {
|
||||
if let Some(info) = info {
|
||||
let staff =
|
||||
info.staff && !*feature_flags::ZED_DISABLE_STAFF;
|
||||
cx.update_flags(staff, info.flags);
|
||||
client.telemetry.set_authenticated_user_info(
|
||||
Some(info.metrics_id.clone()),
|
||||
staff,
|
||||
);
|
||||
|
||||
if let Some(response) = response.log_err() {
|
||||
let user = Arc::new(User {
|
||||
id: user_id,
|
||||
github_login: response.user.github_login.clone().into(),
|
||||
avatar_uri: response.user.avatar_url.clone().into(),
|
||||
name: response.user.name.clone(),
|
||||
});
|
||||
current_user = Some(user.clone());
|
||||
this.update(cx, |this, cx| {
|
||||
let accepted_tos_at = {
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok()
|
||||
{
|
||||
None
|
||||
} else {
|
||||
info.accepted_tos_at
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
info.accepted_tos_at
|
||||
};
|
||||
|
||||
this.set_current_user_accepted_tos_at(accepted_tos_at);
|
||||
cx.emit(Event::PrivateUserInfoUpdated);
|
||||
this.by_github_login
|
||||
.insert(user.github_login.clone(), user_id);
|
||||
this.users.insert(user_id, user);
|
||||
this.update_authenticated_user(response, cx)
|
||||
})
|
||||
} else {
|
||||
anyhow::Ok(())
|
||||
}
|
||||
})??;
|
||||
|
||||
current_user_tx.send(user).await.ok();
|
||||
current_user_tx.send(current_user).await.ok();
|
||||
|
||||
this.update(cx, |_, cx| cx.notify())?;
|
||||
}
|
||||
|
@ -345,22 +321,22 @@ impl UserStore {
|
|||
|
||||
async fn handle_update_plan(
|
||||
this: Entity<Self>,
|
||||
message: TypedEnvelope<proto::UpdateUserPlan>,
|
||||
_message: TypedEnvelope<proto::UpdateUserPlan>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<()> {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.current_plan = Some(message.payload.plan());
|
||||
this.trial_started_at = message
|
||||
.payload
|
||||
.trial_started_at
|
||||
.and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0));
|
||||
this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled;
|
||||
this.account_too_young = message.payload.account_too_young;
|
||||
let client = this
|
||||
.read_with(&cx, |this, _| this.client.upgrade())?
|
||||
.context("client was dropped")?;
|
||||
|
||||
cx.emit(Event::PlanUpdated);
|
||||
cx.notify();
|
||||
})?;
|
||||
Ok(())
|
||||
let response = client
|
||||
.cloud_client()
|
||||
.get_authenticated_user()
|
||||
.await
|
||||
.context("failed to fetch authenticated user")?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.update_authenticated_user(response, cx);
|
||||
})
|
||||
}
|
||||
|
||||
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
|
||||
|
@ -719,42 +695,131 @@ impl UserStore {
|
|||
self.current_user.borrow().clone()
|
||||
}
|
||||
|
||||
pub fn current_plan(&self) -> Option<proto::Plan> {
|
||||
pub fn plan(&self) -> Option<cloud_llm_client::Plan> {
|
||||
#[cfg(debug_assertions)]
|
||||
if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() {
|
||||
return match plan.as_str() {
|
||||
"free" => Some(proto::Plan::Free),
|
||||
"trial" => Some(proto::Plan::ZedProTrial),
|
||||
"pro" => Some(proto::Plan::ZedPro),
|
||||
"free" => Some(cloud_llm_client::Plan::ZedFree),
|
||||
"trial" => Some(cloud_llm_client::Plan::ZedProTrial),
|
||||
"pro" => Some(cloud_llm_client::Plan::ZedPro),
|
||||
_ => {
|
||||
panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
self.current_plan
|
||||
self.plan_info.as_ref().map(|info| info.plan)
|
||||
}
|
||||
|
||||
pub fn subscription_period(&self) -> Option<(DateTime<Utc>, DateTime<Utc>)> {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.and_then(|plan| plan.subscription_period)
|
||||
.map(|subscription_period| {
|
||||
(
|
||||
subscription_period.started_at.0,
|
||||
subscription_period.ended_at.0,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
|
||||
self.trial_started_at
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.and_then(|plan| plan.trial_started_at)
|
||||
.map(|trial_started_at| trial_started_at.0)
|
||||
}
|
||||
|
||||
pub fn usage_based_billing_enabled(&self) -> Option<bool> {
|
||||
self.is_usage_based_billing_enabled
|
||||
/// Returns whether the user's account is too new to use the service.
|
||||
pub fn account_too_young(&self) -> bool {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.map(|plan| plan.is_account_too_young)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Returns whether the current user has overdue invoices and usage should be blocked.
|
||||
pub fn has_overdue_invoices(&self) -> bool {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.map(|plan| plan.has_overdue_invoices)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn is_usage_based_billing_enabled(&self) -> bool {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.map(|plan| plan.is_usage_based_billing_enabled)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
|
||||
self.model_request_usage
|
||||
}
|
||||
|
||||
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
|
||||
self.model_request_usage = Some(usage);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
|
||||
self.edit_prediction_usage
|
||||
}
|
||||
|
||||
pub fn update_edit_prediction_usage(
|
||||
&mut self,
|
||||
usage: EditPredictionUsage,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.edit_prediction_usage = Some(usage);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn update_authenticated_user(
|
||||
&mut self,
|
||||
response: GetAuthenticatedUserResponse,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF;
|
||||
cx.update_flags(staff, response.feature_flags);
|
||||
if let Some(client) = self.client.upgrade() {
|
||||
client
|
||||
.telemetry
|
||||
.set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff);
|
||||
}
|
||||
|
||||
let accepted_tos_at = {
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() {
|
||||
None
|
||||
} else {
|
||||
response.user.accepted_tos_at
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
response.user.accepted_tos_at
|
||||
};
|
||||
|
||||
self.accepted_tos_at = Some(accepted_tos_at);
|
||||
self.model_request_usage = Some(ModelRequestUsage(RequestUsage {
|
||||
limit: response.plan.usage.model_requests.limit,
|
||||
amount: response.plan.usage.model_requests.used as i32,
|
||||
}));
|
||||
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
|
||||
limit: response.plan.usage.edit_predictions.limit,
|
||||
amount: response.plan.usage.edit_predictions.used as i32,
|
||||
}));
|
||||
self.plan_info = Some(response.plan);
|
||||
cx.emit(Event::PrivateUserInfoUpdated);
|
||||
}
|
||||
|
||||
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
|
||||
self.current_user.clone()
|
||||
}
|
||||
|
||||
/// Returns whether the user's account is too new to use the service.
|
||||
pub fn account_too_young(&self) -> bool {
|
||||
self.account_too_young.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub fn current_user_has_accepted_terms(&self) -> Option<bool> {
|
||||
pub fn has_accepted_terms_of_service(&self) -> bool {
|
||||
self.accepted_tos_at
|
||||
.map(|accepted_tos_at| accepted_tos_at.is_some())
|
||||
.map_or(false, |accepted_tos_at| accepted_tos_at.is_some())
|
||||
}
|
||||
|
||||
pub fn accept_terms_of_service(&self, cx: &Context<Self>) -> Task<Result<()>> {
|
||||
|
@ -766,23 +831,18 @@ impl UserStore {
|
|||
cx.spawn(async move |this, cx| -> anyhow::Result<()> {
|
||||
let client = client.upgrade().context("client not found")?;
|
||||
let response = client
|
||||
.request(proto::AcceptTermsOfService {})
|
||||
.cloud_client()
|
||||
.accept_terms_of_service()
|
||||
.await
|
||||
.context("error accepting tos")?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at));
|
||||
this.accepted_tos_at = Some(response.user.accepted_tos_at);
|
||||
cx.emit(Event::PrivateUserInfoUpdated);
|
||||
})?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option<u64>) {
|
||||
self.accepted_tos_at = Some(
|
||||
accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)),
|
||||
);
|
||||
}
|
||||
|
||||
fn load_users(
|
||||
&self,
|
||||
request: impl RequestMessage<Response = UsersResponse>,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue