diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 0e5da2d43b..ee16f83dc4 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -12,7 +12,7 @@ use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; -use client::{ModelRequestUsage, RequestUsage}; +use client::{CloudUserStore, ModelRequestUsage, RequestUsage}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; use collections::HashMap; use feature_flags::{self, FeatureFlagAppExt}; @@ -374,6 +374,7 @@ pub struct Thread { completion_count: usize, pending_completions: Vec, project: Entity, + cloud_user_store: Entity, prompt_builder: Arc, tools: Entity, tool_use: ToolUseState, @@ -444,6 +445,7 @@ pub struct ExceededWindowError { impl Thread { pub fn new( project: Entity, + cloud_user_store: Entity, tools: Entity, prompt_builder: Arc, system_prompt: SharedProjectContext, @@ -470,6 +472,7 @@ impl Thread { completion_count: 0, pending_completions: Vec::new(), project: project.clone(), + cloud_user_store, prompt_builder, tools: tools.clone(), last_restore_checkpoint: None, @@ -503,6 +506,7 @@ impl Thread { id: ThreadId, serialized: SerializedThread, project: Entity, + cloud_user_store: Entity, tools: Entity, prompt_builder: Arc, project_context: SharedProjectContext, @@ -603,6 +607,7 @@ impl Thread { last_restore_checkpoint: None, pending_checkpoint: None, project: project.clone(), + cloud_user_store, prompt_builder, tools: tools.clone(), tool_use, @@ -3255,16 +3260,14 @@ impl Thread { } fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context) { - self.project.update(cx, |project, cx| { - project.user_store().update(cx, |user_store, cx| { - user_store.update_model_request_usage( - ModelRequestUsage(RequestUsage { - amount: amount as i32, - limit, - }), - cx, - ) - }) + self.cloud_user_store.update(cx, |cloud_user_store, cx| { + cloud_user_store.update_model_request_usage( + ModelRequestUsage(RequestUsage { + amount: amount as i32, + limit, + }), + cx, + ) }); } @@ -3883,6 +3886,7 @@ fn main() {{ thread.id.clone(), serialized, thread.project.clone(), + thread.cloud_user_store.clone(), thread.tools.clone(), thread.prompt_builder.clone(), thread.project_context.clone(), @@ -5479,10 +5483,16 @@ fn main() {{ let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let (client, user_store) = + project.read_with(cx, |project, _cx| (project.client(), project.user_store())); + let cloud_user_store = + cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx)); + let thread_store = cx .update(|_, cx| { ThreadStore::load( project.clone(), + cloud_user_store, cx.new(|_| ToolWorkingSet::default()), None, Arc::new(PromptBuilder::new(None).unwrap()), diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index cc7cb50c91..6efa56f233 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -8,6 +8,7 @@ use agent_settings::{AgentProfileId, CompletionMode}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{Tool, ToolId, ToolWorkingSet}; use chrono::{DateTime, Utc}; +use client::CloudUserStore; use collections::HashMap; use context_server::ContextServerId; use futures::{ @@ -104,6 +105,7 @@ pub type TextThreadStore = assistant_context::ContextStore; pub struct ThreadStore { project: Entity, + cloud_user_store: Entity, tools: Entity, prompt_builder: Arc, prompt_store: Option>, @@ -124,6 +126,7 @@ impl EventEmitter for ThreadStore {} impl ThreadStore { pub fn load( project: Entity, + cloud_user_store: Entity, tools: Entity, prompt_store: Option>, prompt_builder: Arc, @@ -133,8 +136,14 @@ impl ThreadStore { let (thread_store, ready_rx) = cx.update(|cx| { let mut option_ready_rx = None; let thread_store = cx.new(|cx| { - let (thread_store, ready_rx) = - Self::new(project, tools, prompt_builder, prompt_store, cx); + let (thread_store, ready_rx) = Self::new( + project, + cloud_user_store, + tools, + prompt_builder, + prompt_store, + cx, + ); option_ready_rx = Some(ready_rx); thread_store }); @@ -147,6 +156,7 @@ impl ThreadStore { fn new( project: Entity, + cloud_user_store: Entity, tools: Entity, prompt_builder: Arc, prompt_store: Option>, @@ -190,6 +200,7 @@ impl ThreadStore { let this = Self { project, + cloud_user_store, tools, prompt_builder, prompt_store, @@ -407,6 +418,7 @@ impl ThreadStore { cx.new(|cx| { Thread::new( self.project.clone(), + self.cloud_user_store.clone(), self.tools.clone(), self.prompt_builder.clone(), self.project_context.clone(), @@ -425,6 +437,7 @@ impl ThreadStore { ThreadId::new(), serialized, self.project.clone(), + self.cloud_user_store.clone(), self.tools.clone(), self.prompt_builder.clone(), self.project_context.clone(), @@ -456,6 +469,7 @@ impl ThreadStore { id.clone(), thread, this.project.clone(), + this.cloud_user_store.clone(), this.tools.clone(), this.prompt_builder.clone(), this.project_context.clone(), diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index 04a093c7d0..1669c24a1b 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -3820,6 +3820,7 @@ mod tests { use super::*; use agent::{MessageSegment, context::ContextLoadResult, thread_store}; use assistant_tool::{ToolRegistry, ToolWorkingSet}; + use client::CloudUserStore; use editor::EditorSettings; use fs::FakeFs; use gpui::{AppContext, TestAppContext, VisualTestContext}; @@ -4116,10 +4117,16 @@ mod tests { let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let (client, user_store) = + project.read_with(cx, |project, _cx| (project.client(), project.user_store())); + let cloud_user_store = + cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx)); + let thread_store = cx .update(|_, cx| { ThreadStore::load( project.clone(), + cloud_user_store, cx.new(|_| ToolWorkingSet::default()), None, Arc::new(PromptBuilder::new(None).unwrap()), diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index ec0a11f86b..5c8011cb18 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1893,6 +1893,7 @@ mod tests { use agent::thread_store::{self, ThreadStore}; use agent_settings::AgentSettings; use assistant_tool::ToolWorkingSet; + use client::CloudUserStore; use editor::EditorSettings; use gpui::{TestAppContext, UpdateGlobal, VisualTestContext}; use project::{FakeFs, Project}; @@ -1932,11 +1933,17 @@ mod tests { }) .unwrap(); + let (client, user_store) = + project.read_with(cx, |project, _cx| (project.client(), project.user_store())); + let cloud_user_store = + cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx)); + let prompt_store = None; let thread_store = cx .update(|cx| { ThreadStore::load( project.clone(), + cloud_user_store, cx.new(|_| ToolWorkingSet::default()), prompt_store, Arc::new(PromptBuilder::new(None).unwrap()), @@ -2098,11 +2105,17 @@ mod tests { }) .unwrap(); + let (client, user_store) = + project.read_with(cx, |project, _cx| (project.client(), project.user_store())); + let cloud_user_store = + cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx)); + let prompt_store = None; let thread_store = cx .update(|cx| { ThreadStore::load( project.clone(), + cloud_user_store, cx.new(|_| ToolWorkingSet::default()), prompt_store, Arc::new(PromptBuilder::new(None).unwrap()), diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index e7b1943561..a39e022df4 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -43,7 +43,7 @@ use anyhow::{Result, anyhow}; use assistant_context::{AssistantContext, ContextEvent, ContextSummary}; use assistant_slash_command::SlashCommandWorkingSet; use assistant_tool::ToolWorkingSet; -use client::{DisableAiSettings, UserStore, zed_urls}; +use client::{CloudUserStore, DisableAiSettings, UserStore, zed_urls}; use cloud_llm_client::{CompletionIntent, UsageLimit}; use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer}; use feature_flags::{self, FeatureFlagAppExt}; @@ -427,6 +427,7 @@ impl ActiveView { pub struct AgentPanel { workspace: WeakEntity, user_store: Entity, + cloud_user_store: Entity, project: Entity, fs: Arc, language_registry: Arc, @@ -486,6 +487,7 @@ impl AgentPanel { let project = workspace.project().clone(); ThreadStore::load( project, + workspace.app_state().cloud_user_store.clone(), tools.clone(), prompt_store.clone(), prompt_builder.clone(), @@ -553,6 +555,7 @@ impl AgentPanel { let thread = thread_store.update(cx, |this, cx| this.create_thread(cx)); let fs = workspace.app_state().fs.clone(); let user_store = workspace.app_state().user_store.clone(); + let cloud_user_store = workspace.app_state().cloud_user_store.clone(); let project = workspace.project(); let language_registry = project.read(cx).languages().clone(); let client = workspace.client().clone(); @@ -579,7 +582,7 @@ impl AgentPanel { MessageEditor::new( fs.clone(), workspace.clone(), - user_store.clone(), + cloud_user_store.clone(), message_editor_context_store.clone(), prompt_store.clone(), thread_store.downgrade(), @@ -706,6 +709,7 @@ impl AgentPanel { active_view, workspace, user_store, + cloud_user_store, project: project.clone(), fs: fs.clone(), language_registry, @@ -848,7 +852,7 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - self.user_store.clone(), + self.cloud_user_store.clone(), context_store.clone(), self.prompt_store.clone(), self.thread_store.downgrade(), @@ -1122,7 +1126,7 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - self.user_store.clone(), + self.cloud_user_store.clone(), context_store, self.prompt_store.clone(), self.thread_store.downgrade(), @@ -1821,8 +1825,8 @@ impl AgentPanel { } fn render_toolbar(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let user_store = self.user_store.read(cx); - let usage = user_store.model_request_usage(); + let cloud_user_store = self.cloud_user_store.read(cx); + let usage = cloud_user_store.model_request_usage(); let account_url = zed_urls::account_url(cx); diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index 082d1dfb51..e00a0087eb 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -17,7 +17,7 @@ use agent::{ use agent_settings::{AgentSettings, CompletionMode}; use ai_onboarding::ApiKeysWithProviders; use buffer_diff::BufferDiff; -use client::UserStore; +use client::CloudUserStore; use cloud_llm_client::CompletionIntent; use collections::{HashMap, HashSet}; use editor::actions::{MoveUp, Paste}; @@ -43,7 +43,6 @@ use language_model::{ use multi_buffer; use project::Project; use prompt_store::PromptStore; -use proto::Plan; use settings::Settings; use std::time::Duration; use theme::ThemeSettings; @@ -79,7 +78,7 @@ pub struct MessageEditor { editor: Entity, workspace: WeakEntity, project: Entity, - user_store: Entity, + cloud_user_store: Entity, context_store: Entity, prompt_store: Option>, history_store: Option>, @@ -159,7 +158,7 @@ impl MessageEditor { pub fn new( fs: Arc, workspace: WeakEntity, - user_store: Entity, + cloud_user_store: Entity, context_store: Entity, prompt_store: Option>, thread_store: WeakEntity, @@ -231,7 +230,7 @@ impl MessageEditor { Self { editor: editor.clone(), project: thread.read(cx).project().clone(), - user_store, + cloud_user_store, thread, incompatible_tools_state: incompatible_tools.clone(), workspace, @@ -1287,26 +1286,16 @@ impl MessageEditor { return None; } - let user_store = self.user_store.read(cx); - - let ubb_enable = user_store - .usage_based_billing_enabled() - .map_or(false, |enabled| enabled); - - if ubb_enable { + let cloud_user_store = self.cloud_user_store.read(cx); + if cloud_user_store.is_usage_based_billing_enabled() { return None; } - let plan = user_store - .current_plan() - .map(|plan| match plan { - Plan::Free => cloud_llm_client::Plan::ZedFree, - Plan::ZedPro => cloud_llm_client::Plan::ZedPro, - Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial, - }) + let plan = cloud_user_store + .plan() .unwrap_or(cloud_llm_client::Plan::ZedFree); - let usage = user_store.model_request_usage()?; + let usage = cloud_user_store.model_request_usage()?; Some( div() @@ -1769,7 +1758,7 @@ impl AgentPreview for MessageEditor { ) -> Option { if let Some(workspace) = workspace.upgrade() { let fs = workspace.read(cx).app_state().fs.clone(); - let user_store = workspace.read(cx).app_state().user_store.clone(); + let cloud_user_store = workspace.read(cx).app_state().cloud_user_store.clone(); let project = workspace.read(cx).project().clone(); let weak_project = project.downgrade(); let context_store = cx.new(|_cx| ContextStore::new(weak_project, None)); @@ -1782,7 +1771,7 @@ impl AgentPreview for MessageEditor { MessageEditor::new( fs, workspace.downgrade(), - user_store, + cloud_user_store, context_store, None, thread_store.downgrade(), diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs index eda7eee0e3..13619da25c 100644 --- a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/assistant_tools/src/edit_agent/evals.rs @@ -7,7 +7,7 @@ use crate::{ }; use Role::*; use assistant_tool::ToolRegistry; -use client::{Client, UserStore}; +use client::{Client, CloudUserStore, UserStore}; use collections::HashMap; use fs::FakeFs; use futures::{FutureExt, future::LocalBoxFuture}; @@ -1470,12 +1470,14 @@ impl EditAgentTest { client::init_settings(cx); let client = Client::production(cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + let cloud_user_store = + cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); settings::init(cx); Project::init_settings(cx); language::init(cx); language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), client.clone(), cx); + language_models::init(user_store.clone(), cloud_user_store, client.clone(), cx); crate::init(client.http_client(), cx); }); diff --git a/crates/client/src/cloud/user_store.rs b/crates/client/src/cloud/user_store.rs index ea432f71ed..78444b3f95 100644 --- a/crates/client/src/cloud/user_store.rs +++ b/crates/client/src/cloud/user_store.rs @@ -9,12 +9,13 @@ use gpui::{Context, Entity, Subscription, Task}; use util::{ResultExt as _, maybe}; use crate::user::Event as RpcUserStoreEvent; -use crate::{EditPredictionUsage, RequestUsage, UserStore}; +use crate::{EditPredictionUsage, ModelRequestUsage, RequestUsage, UserStore}; pub struct CloudUserStore { cloud_client: Arc, authenticated_user: Option>, plan_info: Option>, + model_request_usage: Option, edit_prediction_usage: Option, _maintain_authenticated_user_task: Task<()>, _rpc_plan_updated_subscription: Subscription, @@ -33,6 +34,7 @@ impl CloudUserStore { cloud_client: cloud_client.clone(), authenticated_user: None, plan_info: None, + model_request_usage: None, edit_prediction_usage: None, _maintain_authenticated_user_task: cx.spawn(async move |this, cx| { maybe!(async move { @@ -104,6 +106,13 @@ impl CloudUserStore { }) } + pub fn trial_started_at(&self) -> Option> { + self.plan_info + .as_ref() + .and_then(|plan| plan.trial_started_at) + .map(|trial_started_at| trial_started_at.0) + } + pub fn has_accepted_tos(&self) -> bool { self.authenticated_user .as_ref() @@ -127,6 +136,22 @@ impl CloudUserStore { .unwrap_or_default() } + pub fn is_usage_based_billing_enabled(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_usage_based_billing_enabled) + .unwrap_or_default() + } + + pub fn model_request_usage(&self) -> Option { + self.model_request_usage + } + + pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { + self.model_request_usage = Some(usage); + cx.notify(); + } + pub fn edit_prediction_usage(&self) -> Option { self.edit_prediction_usage } @@ -142,6 +167,10 @@ impl CloudUserStore { fn update_authenticated_user(&mut self, response: GetAuthenticatedUserResponse) { self.authenticated_user = Some(Arc::new(response.user)); + self.model_request_usage = Some(ModelRequestUsage(RequestUsage { + limit: response.plan.usage.model_requests.limit, + amount: response.plan.usage.model_requests.used as i32, + })); self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage { limit: response.plan.usage.edit_predictions.limit, amount: response.plan.usage.edit_predictions.used as i32, diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 0ba7d1472b..dc762efa5d 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -113,7 +113,6 @@ pub struct UserStore { current_plan: Option, subscription_period: Option<(DateTime, DateTime)>, trial_started_at: Option>, - model_request_usage: Option, is_usage_based_billing_enabled: Option, account_too_young: Option, has_overdue_invoices: Option, @@ -191,7 +190,6 @@ impl UserStore { current_plan: None, subscription_period: None, trial_started_at: None, - model_request_usage: None, is_usage_based_billing_enabled: None, account_too_young: None, has_overdue_invoices: None, @@ -371,27 +369,12 @@ impl UserStore { this.account_too_young = message.payload.account_too_young; this.has_overdue_invoices = message.payload.has_overdue_invoices; - if let Some(usage) = message.payload.usage { - // limits are always present even though they are wrapped in Option - this.model_request_usage = usage - .model_requests_usage_limit - .and_then(|limit| { - RequestUsage::from_proto(usage.model_requests_usage_amount, limit) - }) - .map(ModelRequestUsage); - } - cx.emit(Event::PlanUpdated); cx.notify(); })?; Ok(()) } - pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { - self.model_request_usage = Some(usage); - cx.notify(); - } - fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { match message { UpdateContacts::Wait(barrier) => { @@ -776,10 +759,6 @@ impl UserStore { self.is_usage_based_billing_enabled } - pub fn model_request_usage(&self) -> Option { - self.model_request_usage - } - pub fn watch_current_user(&self) -> watch::Receiver>> { self.current_user.clone() } diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index a02b4a7f0b..8d257a37a7 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -13,7 +13,7 @@ pub(crate) use tool_metrics::*; use ::fs::RealFs; use clap::Parser; -use client::{Client, ProxySettings, UserStore}; +use client::{Client, CloudUserStore, ProxySettings, UserStore}; use collections::{HashMap, HashSet}; use extension::ExtensionHostProxy; use futures::future; @@ -329,6 +329,7 @@ pub struct AgentAppState { pub languages: Arc, pub client: Arc, pub user_store: Entity, + pub cloud_user_store: Entity, pub fs: Arc, pub node_runtime: NodeRuntime, @@ -383,6 +384,8 @@ pub fn init(cx: &mut App) -> Arc { let languages = Arc::new(languages); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + let cloud_user_store = + cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); extension::init(cx); @@ -422,7 +425,12 @@ pub fn init(cx: &mut App) -> Arc { languages.clone(), ); language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), client.clone(), cx); + language_models::init( + user_store.clone(), + cloud_user_store.clone(), + client.clone(), + cx, + ); languages::init(languages.clone(), node_runtime.clone(), cx); prompt_store::init(cx); terminal_view::init(cx); @@ -447,6 +455,7 @@ pub fn init(cx: &mut App) -> Arc { languages, client, user_store, + cloud_user_store, fs, node_runtime, prompt_builder, diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 0f2b4c18ea..54d864ea21 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -221,6 +221,7 @@ impl ExampleInstance { let prompt_store = None; let thread_store = ThreadStore::load( project.clone(), + app_state.cloud_user_store.clone(), tools, prompt_store, app_state.prompt_builder.clone(), diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 18e6f47ed0..a88f12283a 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use ::settings::{Settings, SettingsStore}; -use client::{Client, UserStore}; +use client::{Client, CloudUserStore, UserStore}; use collections::HashSet; use gpui::{App, Context, Entity}; use language_model::{LanguageModelProviderId, LanguageModelRegistry}; @@ -26,11 +26,22 @@ use crate::provider::vercel::VercelLanguageModelProvider; use crate::provider::x_ai::XAiLanguageModelProvider; pub use crate::settings::*; -pub fn init(user_store: Entity, client: Arc, cx: &mut App) { +pub fn init( + user_store: Entity, + cloud_user_store: Entity, + client: Arc, + cx: &mut App, +) { crate::settings::init_settings(cx); let registry = LanguageModelRegistry::global(cx); registry.update(cx, |registry, cx| { - register_language_model_providers(registry, user_store, client.clone(), cx); + register_language_model_providers( + registry, + user_store, + cloud_user_store, + client.clone(), + cx, + ); }); let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx) @@ -100,11 +111,17 @@ fn register_openai_compatible_providers( fn register_language_model_providers( registry: &mut LanguageModelRegistry, user_store: Entity, + cloud_user_store: Entity, client: Arc, cx: &mut Context, ) { registry.register_provider( - CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx), + CloudLanguageModelProvider::new( + user_store.clone(), + cloud_user_store.clone(), + client.clone(), + cx, + ), cx, ); diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 3de135c5a2..a5de7f3442 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -2,11 +2,11 @@ use ai_onboarding::YoungAccountBanner; use anthropic::AnthropicModelMode; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; -use client::{Client, ModelRequestUsage, UserStore, zed_urls}; +use client::{Client, CloudUserStore, ModelRequestUsage, UserStore, zed_urls}; use cloud_llm_client::{ CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse, - EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, + EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME, }; @@ -27,7 +27,6 @@ use language_model::{ LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, }; -use proto::Plan; use release_channel::AppVersion; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -118,6 +117,7 @@ pub struct State { client: Arc, llm_api_token: LlmApiToken, user_store: Entity, + cloud_user_store: Entity, status: client::Status, accept_terms_of_service_task: Option>>, models: Vec>, @@ -133,6 +133,7 @@ impl State { fn new( client: Arc, user_store: Entity, + cloud_user_store: Entity, status: client::Status, cx: &mut Context, ) -> Self { @@ -142,6 +143,7 @@ impl State { client: client.clone(), llm_api_token: LlmApiToken::default(), user_store, + cloud_user_store, status, accept_terms_of_service_task: None, models: Vec::new(), @@ -150,12 +152,19 @@ impl State { recommended_models: Vec::new(), _fetch_models_task: cx.spawn(async move |this, cx| { maybe!(async move { - let (client, llm_api_token) = this - .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?; + let (client, cloud_user_store, llm_api_token) = + this.read_with(cx, |this, _cx| { + ( + client.clone(), + this.cloud_user_store.clone(), + this.llm_api_token.clone(), + ) + })?; loop { - let status = this.read_with(cx, |this, _cx| this.status)?; - if matches!(status, client::Status::Connected { .. }) { + let is_authenticated = + cloud_user_store.read_with(cx, |this, _cx| this.is_authenticated())?; + if is_authenticated { break; } @@ -194,8 +203,8 @@ impl State { } } - fn is_signed_out(&self) -> bool { - self.status.is_signed_out() + fn is_signed_out(&self, cx: &App) -> bool { + !self.cloud_user_store.read(cx).is_authenticated() } fn authenticate(&self, cx: &mut Context) -> Task> { @@ -210,10 +219,7 @@ impl State { } fn has_accepted_terms_of_service(&self, cx: &App) -> bool { - self.user_store - .read(cx) - .current_user_has_accepted_terms() - .unwrap_or(false) + self.cloud_user_store.read(cx).has_accepted_tos() } fn accept_terms_of_service(&mut self, cx: &mut Context) { @@ -297,11 +303,24 @@ impl State { } impl CloudLanguageModelProvider { - pub fn new(user_store: Entity, client: Arc, cx: &mut App) -> Self { + pub fn new( + user_store: Entity, + cloud_user_store: Entity, + client: Arc, + cx: &mut App, + ) -> Self { let mut status_rx = client.status(); let status = *status_rx.borrow(); - let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx)); + let state = cx.new(|cx| { + State::new( + client.clone(), + user_store.clone(), + cloud_user_store.clone(), + status, + cx, + ) + }); let state_ref = state.downgrade(); let maintain_client_status = cx.spawn(async move |cx| { @@ -398,7 +417,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn is_authenticated(&self, cx: &App) -> bool { let state = self.state.read(cx); - !state.is_signed_out() && state.has_accepted_terms_of_service(cx) + !state.is_signed_out(cx) && state.has_accepted_terms_of_service(cx) } fn authenticate(&self, _cx: &mut App) -> Task> { @@ -614,9 +633,9 @@ impl CloudLanguageModel { .and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok()) { let plan = match plan { - cloud_llm_client::Plan::ZedFree => Plan::Free, - cloud_llm_client::Plan::ZedPro => Plan::ZedPro, - cloud_llm_client::Plan::ZedProTrial => Plan::ZedProTrial, + cloud_llm_client::Plan::ZedFree => proto::Plan::Free, + cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro, + cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, }; return Err(anyhow!(ModelRequestLimitReachedError { plan })); } @@ -1118,7 +1137,7 @@ fn response_lines( #[derive(IntoElement, RegisterComponent)] struct ZedAiConfiguration { is_connected: bool, - plan: Option, + plan: Option, subscription_period: Option<(DateTime, DateTime)>, eligible_for_trial: bool, has_accepted_terms_of_service: bool, @@ -1132,15 +1151,15 @@ impl RenderOnce for ZedAiConfiguration { fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement { let young_account_banner = YoungAccountBanner; - let is_pro = self.plan == Some(proto::Plan::ZedPro); + let is_pro = self.plan == Some(Plan::ZedPro); let subscription_text = match (self.plan, self.subscription_period) { - (Some(proto::Plan::ZedPro), Some(_)) => { + (Some(Plan::ZedPro), Some(_)) => { "You have access to Zed's hosted models through your Pro subscription." } - (Some(proto::Plan::ZedProTrial), Some(_)) => { + (Some(Plan::ZedProTrial), Some(_)) => { "You have access to Zed's hosted models through your Pro trial." } - (Some(proto::Plan::Free), Some(_)) => { + (Some(Plan::ZedFree), Some(_)) => { "You have basic access to Zed's hosted models through the Free plan." } _ => { @@ -1262,15 +1281,15 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { let state = self.state.read(cx); - let user_store = state.user_store.read(cx); + let cloud_user_store = state.cloud_user_store.read(cx); ZedAiConfiguration { - is_connected: !state.is_signed_out(), - plan: user_store.current_plan(), - subscription_period: user_store.subscription_period(), - eligible_for_trial: user_store.trial_started_at().is_none(), + is_connected: !state.is_signed_out(cx), + plan: cloud_user_store.plan(), + subscription_period: cloud_user_store.subscription_period(), + eligible_for_trial: cloud_user_store.trial_started_at().is_none(), has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx), - account_too_young: user_store.account_too_young(), + account_too_young: cloud_user_store.account_too_young(), accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(), accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(), sign_in_callback: self.sign_in_callback.clone(), @@ -1286,7 +1305,7 @@ impl Component for ZedAiConfiguration { fn preview(_window: &mut Window, _cx: &mut App) -> Option { fn configuration( is_connected: bool, - plan: Option, + plan: Option, eligible_for_trial: bool, account_too_young: bool, has_accepted_terms_of_service: bool, @@ -1330,15 +1349,15 @@ impl Component for ZedAiConfiguration { ), single_example( "Free Plan", - configuration(true, Some(proto::Plan::Free), true, false, true), + configuration(true, Some(Plan::ZedFree), true, false, true), ), single_example( "Zed Pro Trial Plan", - configuration(true, Some(proto::Plan::ZedProTrial), true, false, true), + configuration(true, Some(Plan::ZedProTrial), true, false, true), ), single_example( "Zed Pro Plan", - configuration(true, Some(proto::Plan::ZedPro), true, false, true), + configuration(true, Some(Plan::ZedPro), true, false, true), ), ]) .into_any_element(), diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index a18c112c7e..9859702bf8 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -556,7 +556,12 @@ pub fn main() { ); supermaven::init(app_state.client.clone(), cx); language_model::init(app_state.client.clone(), cx); - language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); + language_models::init( + app_state.user_store.clone(), + app_state.cloud_user_store.clone(), + app_state.client.clone(), + cx, + ); agent_settings::init(cx); agent_servers::init(cx); web_search::init(cx); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 8c6da335ab..0a43ec0bbe 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -4488,7 +4488,12 @@ mod tests { ); image_viewer::init(cx); language_model::init(app_state.client.clone(), cx); - language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); + language_models::init( + app_state.user_store.clone(), + app_state.cloud_user_store.clone(), + app_state.client.clone(), + cx, + ); web_search::init(cx); web_search_providers::init(app_state.client.clone(), cx); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx); diff --git a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs index 825744572d..1076ee49ea 100644 --- a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs +++ b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs @@ -17,9 +17,10 @@ pub fn load_preview_thread_store( cx: &mut AsyncApp, ) -> Task>> { workspace - .update(cx, |_, cx| { + .update(cx, |workspace, cx| { ThreadStore::load( project.clone(), + workspace.app_state().cloud_user_store.clone(), cx.new(|_| ToolWorkingSet::default()), None, Arc::new(PromptBuilder::new(None).unwrap()),