diff --git a/Cargo.lock b/Cargo.lock index 7ac29aeac0..5ef5a8afd7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20599,6 +20599,7 @@ dependencies = [ "call", "client", "clock", + "cloud_api_types", "cloud_llm_client", "collections", "command_palette_hooks", @@ -20619,7 +20620,6 @@ dependencies = [ "menu", "postage", "project", - "proto", "regex", "release_channel", "reqwest_client", diff --git a/crates/client/src/cloud/user_store.rs b/crates/client/src/cloud/user_store.rs index a9b13ca23c..ea432f71ed 100644 --- a/crates/client/src/cloud/user_store.rs +++ b/crates/client/src/cloud/user_store.rs @@ -8,13 +8,14 @@ use cloud_llm_client::Plan; use gpui::{Context, Entity, Subscription, Task}; use util::{ResultExt as _, maybe}; -use crate::UserStore; use crate::user::Event as RpcUserStoreEvent; +use crate::{EditPredictionUsage, RequestUsage, UserStore}; pub struct CloudUserStore { cloud_client: Arc, authenticated_user: Option>, plan_info: Option>, + edit_prediction_usage: Option, _maintain_authenticated_user_task: Task<()>, _rpc_plan_updated_subscription: Subscription, } @@ -32,6 +33,7 @@ impl CloudUserStore { cloud_client: cloud_client.clone(), authenticated_user: None, plan_info: None, + edit_prediction_usage: None, _maintain_authenticated_user_task: cx.spawn(async move |this, cx| { maybe!(async move { loop { @@ -102,8 +104,48 @@ impl CloudUserStore { }) } + pub fn has_accepted_tos(&self) -> bool { + self.authenticated_user + .as_ref() + .map(|user| user.accepted_tos_at.is_some()) + .unwrap_or_default() + } + + /// Returns whether the user's account is too new to use the service. + pub fn account_too_young(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_account_too_young) + .unwrap_or_default() + } + + /// Returns whether the current user has overdue invoices and usage should be blocked. + pub fn has_overdue_invoices(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.has_overdue_invoices) + .unwrap_or_default() + } + + pub fn edit_prediction_usage(&self) -> Option { + self.edit_prediction_usage + } + + pub fn update_edit_prediction_usage( + &mut self, + usage: EditPredictionUsage, + cx: &mut Context, + ) { + self.edit_prediction_usage = Some(usage); + cx.notify(); + } + fn update_authenticated_user(&mut self, response: GetAuthenticatedUserResponse) { self.authenticated_user = Some(Arc::new(response.user)); + self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage { + limit: response.plan.usage.edit_predictions.limit, + amount: response.plan.usage.edit_predictions.used as i32, + })); self.plan_info = Some(Arc::new(response.plan)); } diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index e025ec0523..84f30f3530 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -114,7 +114,6 @@ pub struct UserStore { subscription_period: Option<(DateTime, DateTime)>, trial_started_at: Option>, model_request_usage: Option, - edit_prediction_usage: Option, is_usage_based_billing_enabled: Option, account_too_young: Option, has_overdue_invoices: Option, @@ -193,7 +192,6 @@ impl UserStore { subscription_period: None, trial_started_at: None, model_request_usage: None, - edit_prediction_usage: None, is_usage_based_billing_enabled: None, account_too_young: None, has_overdue_invoices: None, @@ -381,12 +379,6 @@ impl UserStore { RequestUsage::from_proto(usage.model_requests_usage_amount, limit) }) .map(ModelRequestUsage); - this.edit_prediction_usage = usage - .edit_predictions_usage_limit - .and_then(|limit| { - RequestUsage::from_proto(usage.model_requests_usage_amount, limit) - }) - .map(EditPredictionUsage); } cx.emit(Event::PlanUpdated); @@ -400,15 +392,6 @@ impl UserStore { cx.notify(); } - pub fn update_edit_prediction_usage( - &mut self, - usage: EditPredictionUsage, - cx: &mut Context, - ) { - self.edit_prediction_usage = Some(usage); - cx.notify(); - } - fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { match message { UpdateContacts::Wait(barrier) => { @@ -797,10 +780,6 @@ impl UserStore { self.model_request_usage } - pub fn edit_prediction_usage(&self) -> Option { - self.edit_prediction_usage - } - pub fn watch_current_user(&self) -> watch::Receiver>> { self.current_user.clone() } diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 72b7132c60..a5d2ac34f5 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -64,9 +64,14 @@ impl LlmApiToken { mut lock: RwLockWriteGuard<'_, Option>, client: &Arc, ) -> Result { - let response = client.request(proto::GetLlmToken {}).await?; - *lock = Some(response.token.clone()); - Ok(response.token.clone()) + let system_id = client + .telemetry() + .system_id() + .map(|system_id| system_id.to_string()); + + let response = client.cloud_client().create_llm_token(system_id).await?; + *lock = Some(response.token.0.clone()); + Ok(response.token.0.clone()) } } diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index e62c15ae10..219dc1e7ae 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -564,7 +564,7 @@ pub fn main() { snippet_provider::init(cx); inline_completion_registry::init( app_state.client.clone(), - app_state.user_store.clone(), + app_state.cloud_user_store.clone(), cx, ); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx); diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs index a8037f0f90..89d6ff054b 100644 --- a/crates/zed/src/zed/inline_completion_registry.rs +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -1,4 +1,4 @@ -use client::{Client, UserStore}; +use client::{Client, CloudUserStore}; use collections::HashMap; use copilot::{Copilot, CopilotCompletionProvider}; use editor::Editor; @@ -14,12 +14,12 @@ use util::ResultExt; use workspace::Workspace; use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider}; -pub fn init(client: Arc, user_store: Entity, cx: &mut App) { +pub fn init(client: Arc, cloud_user_store: Entity, cx: &mut App) { let editors: Rc, AnyWindowHandle>>> = Rc::default(); cx.observe_new({ let editors = editors.clone(); let client = client.clone(); - let user_store = user_store.clone(); + let cloud_user_store = cloud_user_store.clone(); move |editor: &mut Editor, window, cx: &mut Context| { if !editor.mode().is_full() { return; @@ -49,7 +49,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { editor, provider, &client, - user_store.clone(), + cloud_user_store.clone(), window, cx, ); @@ -61,7 +61,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { let mut provider = all_language_settings(None, cx).edit_predictions.provider; cx.spawn({ - let user_store = user_store.clone(); + let cloud_user_store = cloud_user_store.clone(); let editors = editors.clone(); let client = client.clone(); @@ -73,7 +73,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { &editors, provider, &client, - user_store.clone(), + cloud_user_store.clone(), cx, ); }) @@ -86,15 +86,12 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { cx.observe_global::({ let editors = editors.clone(); let client = client.clone(); - let user_store = user_store.clone(); + let cloud_user_store = cloud_user_store.clone(); move |cx| { let new_provider = all_language_settings(None, cx).edit_predictions.provider; if new_provider != provider { - let tos_accepted = user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false); + let tos_accepted = cloud_user_store.read(cx).has_accepted_tos(); telemetry::event!( "Edit Prediction Provider Changed", @@ -108,7 +105,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { &editors, provider, &client, - user_store.clone(), + cloud_user_store.clone(), cx, ); @@ -149,7 +146,7 @@ fn assign_edit_prediction_providers( editors: &Rc, AnyWindowHandle>>>, provider: EditPredictionProvider, client: &Arc, - user_store: Entity, + cloud_user_store: Entity, cx: &mut App, ) { for (editor, window) in editors.borrow().iter() { @@ -159,7 +156,7 @@ fn assign_edit_prediction_providers( editor, provider, &client, - user_store.clone(), + cloud_user_store.clone(), window, cx, ); @@ -214,7 +211,7 @@ fn assign_edit_prediction_provider( editor: &mut Editor, provider: EditPredictionProvider, client: &Arc, - user_store: Entity, + cloud_user_store: Entity, window: &mut Window, cx: &mut Context, ) { @@ -245,7 +242,7 @@ fn assign_edit_prediction_provider( } } EditPredictionProvider::Zed => { - if client.status().borrow().is_connected() { + if cloud_user_store.read(cx).is_authenticated() { let mut worktree = None; if let Some(buffer) = &singleton_buffer { @@ -267,7 +264,7 @@ fn assign_edit_prediction_provider( .map(|workspace| workspace.downgrade()); let zeta = - zeta::Zeta::register(workspace, worktree, client.clone(), user_store, cx); + zeta::Zeta::register(workspace, worktree, client.clone(), cloud_user_store, cx); if let Some(buffer) = &singleton_buffer { if buffer.read(cx).file().is_some() { diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index 294d95aefd..26eeda3f22 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -40,7 +40,6 @@ log.workspace = true menu.workspace = true postage.workspace = true project.workspace = true -proto.workspace = true regex.workspace = true release_channel.workspace = true serde.workspace = true @@ -59,9 +58,11 @@ worktree.workspace = true zed_actions.workspace = true [dev-dependencies] -collections = { workspace = true, features = ["test-support"] } +call = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } +cloud_api_types.workspace = true +collections = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } @@ -77,5 +78,4 @@ tree-sitter-rust.workspace = true unindent.workspace = true workspace = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } -call = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index d5c6be278b..d295b7d17c 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -16,7 +16,7 @@ pub use rate_completion_modal::*; use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; -use client::{Client, EditPredictionUsage, UserStore}; +use client::{Client, CloudUserStore, EditPredictionUsage, UserStore}; use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME, @@ -226,12 +226,9 @@ pub struct Zeta { data_collection_choice: Entity, llm_token: LlmApiToken, _llm_token_subscription: Subscription, - /// Whether the terms of service have been accepted. - tos_accepted: bool, /// Whether an update to a newer version of Zed is required to continue using Zeta. update_required: bool, - user_store: Entity, - _user_store_subscription: Subscription, + cloud_user_store: Entity, license_detection_watchers: HashMap>, } @@ -244,11 +241,11 @@ impl Zeta { workspace: Option>, worktree: Option>, client: Arc, - user_store: Entity, + cloud_user_store: Entity, cx: &mut App, ) -> Entity { let this = Self::global(cx).unwrap_or_else(|| { - let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx)); + let entity = cx.new(|cx| Self::new(workspace, client, cloud_user_store, cx)); cx.set_global(ZetaGlobal(entity.clone())); entity }); @@ -271,13 +268,13 @@ impl Zeta { } pub fn usage(&self, cx: &App) -> Option { - self.user_store.read(cx).edit_prediction_usage() + self.cloud_user_store.read(cx).edit_prediction_usage() } fn new( workspace: Option>, client: Arc, - user_store: Entity, + cloud_user_store: Entity, cx: &mut Context, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); @@ -306,24 +303,9 @@ impl Zeta { .detach_and_log_err(cx); }, ), - tos_accepted: user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false), update_required: false, - _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| { - match event { - client::user::Event::PrivateUserInfoUpdated => { - this.tos_accepted = user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false); - } - _ => {} - } - }), license_detection_watchers: HashMap::default(), - user_store, + cloud_user_store, } } @@ -552,8 +534,8 @@ impl Zeta { if let Some(usage) = usage { this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); + this.cloud_user_store.update(cx, |cloud_user_store, cx| { + cloud_user_store.update_edit_prediction_usage(usage, cx); }); }) .ok(); @@ -894,8 +876,8 @@ and then another if response.status().is_success() { if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() { this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); + this.cloud_user_store.update(cx, |cloud_user_store, cx| { + cloud_user_store.update_edit_prediction_usage(usage, cx); }); })?; } @@ -1573,7 +1555,12 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider } fn needs_terms_acceptance(&self, cx: &App) -> bool { - !self.zeta.read(cx).tos_accepted + !self + .zeta + .read(cx) + .cloud_user_store + .read(cx) + .has_accepted_tos() } fn is_refreshing(&self) -> bool { @@ -1588,7 +1575,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider _debounce: bool, cx: &mut Context, ) { - if !self.zeta.read(cx).tos_accepted { + if self.needs_terms_acceptance(cx) { return; } @@ -1599,9 +1586,9 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider if self .zeta .read(cx) - .user_store - .read_with(cx, |user_store, _| { - user_store.account_too_young() || user_store.has_overdue_invoices() + .cloud_user_store + .read_with(cx, |cloud_user_store, _cx| { + cloud_user_store.account_too_young() || cloud_user_store.has_overdue_invoices() }) { return; @@ -1819,15 +1806,51 @@ fn tokens_for_bytes(bytes: usize) -> usize { mod tests { use client::test::FakeServer; use clock::FakeSystemClock; + use cloud_api_types::{ + AuthenticatedUser, CreateLlmTokenResponse, GetAuthenticatedUserResponse, LlmToken, PlanInfo, + }; + use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; use gpui::TestAppContext; use http_client::FakeHttpClient; use indoc::indoc; use language::Point; - use rpc::proto; use settings::SettingsStore; use super::*; + fn make_get_authenticated_user_response() -> GetAuthenticatedUserResponse { + GetAuthenticatedUserResponse { + user: AuthenticatedUser { + id: 1, + metrics_id: "metrics-id-1".to_string(), + avatar_url: "".to_string(), + github_login: "".to_string(), + name: None, + is_staff: false, + accepted_tos_at: None, + }, + feature_flags: vec![], + plan: PlanInfo { + plan: Plan::ZedPro, + subscription_period: None, + usage: CurrentUsage { + model_requests: UsageData { + used: 0, + limit: UsageLimit::Limited(500), + }, + edit_predictions: UsageData { + used: 250, + limit: UsageLimit::Unlimited, + }, + }, + trial_started_at: None, + is_usage_based_billing_enabled: false, + is_account_too_young: false, + has_overdue_invoices: false, + }, + } + } + #[gpui::test] async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) { let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); @@ -2027,28 +2050,55 @@ mod tests { <|editable_region_end|> ```"}; - let http_client = FakeHttpClient::create(move |_| async move { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") - .unwrap(), - output_excerpt: completion_response.to_string(), - }) - .unwrap() - .into(), - ) - .unwrap()) + let http_client = FakeHttpClient::create(move |req| async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response()) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") + .unwrap(), + output_excerpt: completion_response.to_string(), + }) + .unwrap() + .into(), + ) + .unwrap()), + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } }); let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - let server = FakeServer::for_client(42, &client, cx).await; + // Construct the fake server to authenticate. + let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); + let cloud_user_store = + cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); @@ -2056,13 +2106,6 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); - server.receive::().await.unwrap(); - let token_request = server.receive::().await.unwrap(); - server.respond( - token_request.receipt(), - proto::GetLlmTokenResponse { token: "".into() }, - ); - let completion = completion_task.await.unwrap().unwrap(); buffer.update(cx, |buffer, cx| { buffer.edit(completion.edits.iter().cloned(), None, cx) @@ -2079,20 +2122,44 @@ mod tests { cx: &mut TestAppContext, ) -> Vec<(Range, String)> { let completion_response = completion_response.to_string(); - let http_client = FakeHttpClient::create(move |_| { + let http_client = FakeHttpClient::create(move |req| { let completion = completion_response.clone(); async move { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::new_v4(), - output_excerpt: completion, - }) - .unwrap() - .into(), - ) - .unwrap()) + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response()) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::new_v4(), + output_excerpt: completion, + }) + .unwrap() + .into(), + ) + .unwrap()), + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } } }); @@ -2100,9 +2167,12 @@ mod tests { cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - let server = FakeServer::for_client(42, &client, cx).await; + // Construct the fake server to authenticate. + let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); + let cloud_user_store = + cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); @@ -2111,13 +2181,6 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); - server.receive::().await.unwrap(); - let token_request = server.receive::().await.unwrap(); - server.respond( - token_request.receipt(), - proto::GetLlmTokenResponse { token: "".into() }, - ); - let completion = completion_task.await.unwrap().unwrap(); completion .edits