From 7fdbfc9e8da3b9d1194ebf8d80339ac328a40c4a Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 31 Jul 2025 18:12:04 -0400 Subject: [PATCH] Acquire LLM token from Cloud instead of Collab for Edit Predictions (#35431) This PR updates the Zed Edit Prediction provider to acquire the LLM token from Cloud instead of Collab to allow using Edit Predictions even when disconnected from or unable to connect to the Collab server. Release Notes: - N/A --------- Co-authored-by: Richard Feldman --- Cargo.lock | 2 +- crates/client/src/cloud/user_store.rs | 44 +++- crates/client/src/user.rs | 21 -- .../language_model/src/model/cloud_model.rs | 11 +- crates/zed/src/main.rs | 2 +- .../zed/src/zed/inline_completion_registry.rs | 31 ++- crates/zeta/Cargo.toml | 6 +- crates/zeta/src/zeta.rs | 219 +++++++++++------- 8 files changed, 211 insertions(+), 125 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7ac29aeac0..5ef5a8afd7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20599,6 +20599,7 @@ dependencies = [ "call", "client", "clock", + "cloud_api_types", "cloud_llm_client", "collections", "command_palette_hooks", @@ -20619,7 +20620,6 @@ dependencies = [ "menu", "postage", "project", - "proto", "regex", "release_channel", "reqwest_client", diff --git a/crates/client/src/cloud/user_store.rs b/crates/client/src/cloud/user_store.rs index a9b13ca23c..ea432f71ed 100644 --- a/crates/client/src/cloud/user_store.rs +++ b/crates/client/src/cloud/user_store.rs @@ -8,13 +8,14 @@ use cloud_llm_client::Plan; use gpui::{Context, Entity, Subscription, Task}; use util::{ResultExt as _, maybe}; -use crate::UserStore; use crate::user::Event as RpcUserStoreEvent; +use crate::{EditPredictionUsage, RequestUsage, UserStore}; pub struct CloudUserStore { cloud_client: Arc, authenticated_user: Option>, plan_info: Option>, + edit_prediction_usage: Option, _maintain_authenticated_user_task: Task<()>, _rpc_plan_updated_subscription: Subscription, } @@ -32,6 +33,7 @@ impl CloudUserStore { cloud_client: cloud_client.clone(), authenticated_user: None, plan_info: None, + edit_prediction_usage: None, _maintain_authenticated_user_task: cx.spawn(async move |this, cx| { maybe!(async move { loop { @@ -102,8 +104,48 @@ impl CloudUserStore { }) } + pub fn has_accepted_tos(&self) -> bool { + self.authenticated_user + .as_ref() + .map(|user| user.accepted_tos_at.is_some()) + .unwrap_or_default() + } + + /// Returns whether the user's account is too new to use the service. + pub fn account_too_young(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_account_too_young) + .unwrap_or_default() + } + + /// Returns whether the current user has overdue invoices and usage should be blocked. + pub fn has_overdue_invoices(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.has_overdue_invoices) + .unwrap_or_default() + } + + pub fn edit_prediction_usage(&self) -> Option { + self.edit_prediction_usage + } + + pub fn update_edit_prediction_usage( + &mut self, + usage: EditPredictionUsage, + cx: &mut Context, + ) { + self.edit_prediction_usage = Some(usage); + cx.notify(); + } + fn update_authenticated_user(&mut self, response: GetAuthenticatedUserResponse) { self.authenticated_user = Some(Arc::new(response.user)); + self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage { + limit: response.plan.usage.edit_predictions.limit, + amount: response.plan.usage.edit_predictions.used as i32, + })); self.plan_info = Some(Arc::new(response.plan)); } diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index e025ec0523..84f30f3530 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -114,7 +114,6 @@ pub struct UserStore { subscription_period: Option<(DateTime, DateTime)>, trial_started_at: Option>, model_request_usage: Option, - edit_prediction_usage: Option, is_usage_based_billing_enabled: Option, account_too_young: Option, has_overdue_invoices: Option, @@ -193,7 +192,6 @@ impl UserStore { subscription_period: None, trial_started_at: None, model_request_usage: None, - edit_prediction_usage: None, is_usage_based_billing_enabled: None, account_too_young: None, has_overdue_invoices: None, @@ -381,12 +379,6 @@ impl UserStore { RequestUsage::from_proto(usage.model_requests_usage_amount, limit) }) .map(ModelRequestUsage); - this.edit_prediction_usage = usage - .edit_predictions_usage_limit - .and_then(|limit| { - RequestUsage::from_proto(usage.model_requests_usage_amount, limit) - }) - .map(EditPredictionUsage); } cx.emit(Event::PlanUpdated); @@ -400,15 +392,6 @@ impl UserStore { cx.notify(); } - pub fn update_edit_prediction_usage( - &mut self, - usage: EditPredictionUsage, - cx: &mut Context, - ) { - self.edit_prediction_usage = Some(usage); - cx.notify(); - } - fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { match message { UpdateContacts::Wait(barrier) => { @@ -797,10 +780,6 @@ impl UserStore { self.model_request_usage } - pub fn edit_prediction_usage(&self) -> Option { - self.edit_prediction_usage - } - pub fn watch_current_user(&self) -> watch::Receiver>> { self.current_user.clone() } diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 72b7132c60..a5d2ac34f5 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -64,9 +64,14 @@ impl LlmApiToken { mut lock: RwLockWriteGuard<'_, Option>, client: &Arc, ) -> Result { - let response = client.request(proto::GetLlmToken {}).await?; - *lock = Some(response.token.clone()); - Ok(response.token.clone()) + let system_id = client + .telemetry() + .system_id() + .map(|system_id| system_id.to_string()); + + let response = client.cloud_client().create_llm_token(system_id).await?; + *lock = Some(response.token.0.clone()); + Ok(response.token.0.clone()) } } diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index e62c15ae10..219dc1e7ae 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -564,7 +564,7 @@ pub fn main() { snippet_provider::init(cx); inline_completion_registry::init( app_state.client.clone(), - app_state.user_store.clone(), + app_state.cloud_user_store.clone(), cx, ); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx); diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs index a8037f0f90..89d6ff054b 100644 --- a/crates/zed/src/zed/inline_completion_registry.rs +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -1,4 +1,4 @@ -use client::{Client, UserStore}; +use client::{Client, CloudUserStore}; use collections::HashMap; use copilot::{Copilot, CopilotCompletionProvider}; use editor::Editor; @@ -14,12 +14,12 @@ use util::ResultExt; use workspace::Workspace; use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider}; -pub fn init(client: Arc, user_store: Entity, cx: &mut App) { +pub fn init(client: Arc, cloud_user_store: Entity, cx: &mut App) { let editors: Rc, AnyWindowHandle>>> = Rc::default(); cx.observe_new({ let editors = editors.clone(); let client = client.clone(); - let user_store = user_store.clone(); + let cloud_user_store = cloud_user_store.clone(); move |editor: &mut Editor, window, cx: &mut Context| { if !editor.mode().is_full() { return; @@ -49,7 +49,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { editor, provider, &client, - user_store.clone(), + cloud_user_store.clone(), window, cx, ); @@ -61,7 +61,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { let mut provider = all_language_settings(None, cx).edit_predictions.provider; cx.spawn({ - let user_store = user_store.clone(); + let cloud_user_store = cloud_user_store.clone(); let editors = editors.clone(); let client = client.clone(); @@ -73,7 +73,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { &editors, provider, &client, - user_store.clone(), + cloud_user_store.clone(), cx, ); }) @@ -86,15 +86,12 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { cx.observe_global::({ let editors = editors.clone(); let client = client.clone(); - let user_store = user_store.clone(); + let cloud_user_store = cloud_user_store.clone(); move |cx| { let new_provider = all_language_settings(None, cx).edit_predictions.provider; if new_provider != provider { - let tos_accepted = user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false); + let tos_accepted = cloud_user_store.read(cx).has_accepted_tos(); telemetry::event!( "Edit Prediction Provider Changed", @@ -108,7 +105,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { &editors, provider, &client, - user_store.clone(), + cloud_user_store.clone(), cx, ); @@ -149,7 +146,7 @@ fn assign_edit_prediction_providers( editors: &Rc, AnyWindowHandle>>>, provider: EditPredictionProvider, client: &Arc, - user_store: Entity, + cloud_user_store: Entity, cx: &mut App, ) { for (editor, window) in editors.borrow().iter() { @@ -159,7 +156,7 @@ fn assign_edit_prediction_providers( editor, provider, &client, - user_store.clone(), + cloud_user_store.clone(), window, cx, ); @@ -214,7 +211,7 @@ fn assign_edit_prediction_provider( editor: &mut Editor, provider: EditPredictionProvider, client: &Arc, - user_store: Entity, + cloud_user_store: Entity, window: &mut Window, cx: &mut Context, ) { @@ -245,7 +242,7 @@ fn assign_edit_prediction_provider( } } EditPredictionProvider::Zed => { - if client.status().borrow().is_connected() { + if cloud_user_store.read(cx).is_authenticated() { let mut worktree = None; if let Some(buffer) = &singleton_buffer { @@ -267,7 +264,7 @@ fn assign_edit_prediction_provider( .map(|workspace| workspace.downgrade()); let zeta = - zeta::Zeta::register(workspace, worktree, client.clone(), user_store, cx); + zeta::Zeta::register(workspace, worktree, client.clone(), cloud_user_store, cx); if let Some(buffer) = &singleton_buffer { if buffer.read(cx).file().is_some() { diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index 294d95aefd..26eeda3f22 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -40,7 +40,6 @@ log.workspace = true menu.workspace = true postage.workspace = true project.workspace = true -proto.workspace = true regex.workspace = true release_channel.workspace = true serde.workspace = true @@ -59,9 +58,11 @@ worktree.workspace = true zed_actions.workspace = true [dev-dependencies] -collections = { workspace = true, features = ["test-support"] } +call = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } +cloud_api_types.workspace = true +collections = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } @@ -77,5 +78,4 @@ tree-sitter-rust.workspace = true unindent.workspace = true workspace = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } -call = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index d5c6be278b..d295b7d17c 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -16,7 +16,7 @@ pub use rate_completion_modal::*; use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; -use client::{Client, EditPredictionUsage, UserStore}; +use client::{Client, CloudUserStore, EditPredictionUsage, UserStore}; use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME, @@ -226,12 +226,9 @@ pub struct Zeta { data_collection_choice: Entity, llm_token: LlmApiToken, _llm_token_subscription: Subscription, - /// Whether the terms of service have been accepted. - tos_accepted: bool, /// Whether an update to a newer version of Zed is required to continue using Zeta. update_required: bool, - user_store: Entity, - _user_store_subscription: Subscription, + cloud_user_store: Entity, license_detection_watchers: HashMap>, } @@ -244,11 +241,11 @@ impl Zeta { workspace: Option>, worktree: Option>, client: Arc, - user_store: Entity, + cloud_user_store: Entity, cx: &mut App, ) -> Entity { let this = Self::global(cx).unwrap_or_else(|| { - let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx)); + let entity = cx.new(|cx| Self::new(workspace, client, cloud_user_store, cx)); cx.set_global(ZetaGlobal(entity.clone())); entity }); @@ -271,13 +268,13 @@ impl Zeta { } pub fn usage(&self, cx: &App) -> Option { - self.user_store.read(cx).edit_prediction_usage() + self.cloud_user_store.read(cx).edit_prediction_usage() } fn new( workspace: Option>, client: Arc, - user_store: Entity, + cloud_user_store: Entity, cx: &mut Context, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); @@ -306,24 +303,9 @@ impl Zeta { .detach_and_log_err(cx); }, ), - tos_accepted: user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false), update_required: false, - _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| { - match event { - client::user::Event::PrivateUserInfoUpdated => { - this.tos_accepted = user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false); - } - _ => {} - } - }), license_detection_watchers: HashMap::default(), - user_store, + cloud_user_store, } } @@ -552,8 +534,8 @@ impl Zeta { if let Some(usage) = usage { this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); + this.cloud_user_store.update(cx, |cloud_user_store, cx| { + cloud_user_store.update_edit_prediction_usage(usage, cx); }); }) .ok(); @@ -894,8 +876,8 @@ and then another if response.status().is_success() { if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() { this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); + this.cloud_user_store.update(cx, |cloud_user_store, cx| { + cloud_user_store.update_edit_prediction_usage(usage, cx); }); })?; } @@ -1573,7 +1555,12 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider } fn needs_terms_acceptance(&self, cx: &App) -> bool { - !self.zeta.read(cx).tos_accepted + !self + .zeta + .read(cx) + .cloud_user_store + .read(cx) + .has_accepted_tos() } fn is_refreshing(&self) -> bool { @@ -1588,7 +1575,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider _debounce: bool, cx: &mut Context, ) { - if !self.zeta.read(cx).tos_accepted { + if self.needs_terms_acceptance(cx) { return; } @@ -1599,9 +1586,9 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider if self .zeta .read(cx) - .user_store - .read_with(cx, |user_store, _| { - user_store.account_too_young() || user_store.has_overdue_invoices() + .cloud_user_store + .read_with(cx, |cloud_user_store, _cx| { + cloud_user_store.account_too_young() || cloud_user_store.has_overdue_invoices() }) { return; @@ -1819,15 +1806,51 @@ fn tokens_for_bytes(bytes: usize) -> usize { mod tests { use client::test::FakeServer; use clock::FakeSystemClock; + use cloud_api_types::{ + AuthenticatedUser, CreateLlmTokenResponse, GetAuthenticatedUserResponse, LlmToken, PlanInfo, + }; + use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; use gpui::TestAppContext; use http_client::FakeHttpClient; use indoc::indoc; use language::Point; - use rpc::proto; use settings::SettingsStore; use super::*; + fn make_get_authenticated_user_response() -> GetAuthenticatedUserResponse { + GetAuthenticatedUserResponse { + user: AuthenticatedUser { + id: 1, + metrics_id: "metrics-id-1".to_string(), + avatar_url: "".to_string(), + github_login: "".to_string(), + name: None, + is_staff: false, + accepted_tos_at: None, + }, + feature_flags: vec![], + plan: PlanInfo { + plan: Plan::ZedPro, + subscription_period: None, + usage: CurrentUsage { + model_requests: UsageData { + used: 0, + limit: UsageLimit::Limited(500), + }, + edit_predictions: UsageData { + used: 250, + limit: UsageLimit::Unlimited, + }, + }, + trial_started_at: None, + is_usage_based_billing_enabled: false, + is_account_too_young: false, + has_overdue_invoices: false, + }, + } + } + #[gpui::test] async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) { let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); @@ -2027,28 +2050,55 @@ mod tests { <|editable_region_end|> ```"}; - let http_client = FakeHttpClient::create(move |_| async move { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") - .unwrap(), - output_excerpt: completion_response.to_string(), - }) - .unwrap() - .into(), - ) - .unwrap()) + let http_client = FakeHttpClient::create(move |req| async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response()) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") + .unwrap(), + output_excerpt: completion_response.to_string(), + }) + .unwrap() + .into(), + ) + .unwrap()), + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } }); let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - let server = FakeServer::for_client(42, &client, cx).await; + // Construct the fake server to authenticate. + let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); + let cloud_user_store = + cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); @@ -2056,13 +2106,6 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); - server.receive::().await.unwrap(); - let token_request = server.receive::().await.unwrap(); - server.respond( - token_request.receipt(), - proto::GetLlmTokenResponse { token: "".into() }, - ); - let completion = completion_task.await.unwrap().unwrap(); buffer.update(cx, |buffer, cx| { buffer.edit(completion.edits.iter().cloned(), None, cx) @@ -2079,20 +2122,44 @@ mod tests { cx: &mut TestAppContext, ) -> Vec<(Range, String)> { let completion_response = completion_response.to_string(); - let http_client = FakeHttpClient::create(move |_| { + let http_client = FakeHttpClient::create(move |req| { let completion = completion_response.clone(); async move { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::new_v4(), - output_excerpt: completion, - }) - .unwrap() - .into(), - ) - .unwrap()) + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response()) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::new_v4(), + output_excerpt: completion, + }) + .unwrap() + .into(), + ) + .unwrap()), + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } } }); @@ -2100,9 +2167,12 @@ mod tests { cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - let server = FakeServer::for_client(42, &client, cx).await; + // Construct the fake server to authenticate. + let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); + let cloud_user_store = + cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); @@ -2111,13 +2181,6 @@ mod tests { zeta.request_completion(None, &buffer, cursor, false, cx) }); - server.receive::().await.unwrap(); - let token_request = server.receive::().await.unwrap(); - server.respond( - token_request.receipt(), - proto::GetLlmTokenResponse { token: "".into() }, - ); - let completion = completion_task.await.unwrap().unwrap(); completion .edits