diff --git a/Cargo.lock b/Cargo.lock index 63a66d7150..94ba0cf0ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,7 +114,6 @@ dependencies = [ "pretty_assertions", "project", "prompt_store", - "proto", "rand 0.8.5", "ref-cast", "rope", @@ -359,7 +358,6 @@ dependencies = [ "component", "gpui", "language_model", - "proto", "serde", "smallvec", "telemetry", @@ -1076,17 +1074,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "async-recursion" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7d78656ba01f1b93024b7c3a0467f1608e4be67d725749fdcd7d2c7678fd7a2" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "async-recursion" version = "1.1.1" @@ -2972,7 +2959,6 @@ name = "client" version = "0.1.0" dependencies = [ "anyhow", - "async-recursion 0.3.2", "async-tungstenite", "base64 0.22.1", "chrono", @@ -7814,6 +7800,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "log", + "parking_lot", "serde", "serde_json", "url", @@ -9085,7 +9072,6 @@ dependencies = [ "open_router", "partial-json-fixer", "project", - "proto", "release_channel", "schemars", "serde", @@ -9823,7 +9809,7 @@ name = "markdown_preview" version = "0.1.0" dependencies = [ "anyhow", - "async-recursion 1.1.1", + "async-recursion", "collections", "editor", "fs", @@ -16192,7 +16178,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assistant_slash_command", - "async-recursion 1.1.1", + "async-recursion", "breadcrumbs", "client", "collections", @@ -19617,7 +19603,7 @@ version = "0.1.0" dependencies = [ "any_vec", "anyhow", - "async-recursion 1.1.1", + "async-recursion", "bincode", "call", "client", @@ -20142,7 +20128,7 @@ dependencies = [ "async-io", "async-lock", "async-process", - "async-recursion 1.1.1", + "async-recursion", "async-task", "async-trait", "blocking", diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index c89a7f3303..7bc0e82cad 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -47,7 +47,6 @@ paths.workspace = true postage.workspace = true project.workspace = true prompt_store.workspace = true -proto.workspace = true ref-cast.workspace = true rope.workspace = true schemars.workspace = true diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index ee16f83dc4..8558dd528d 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -12,8 +12,8 @@ use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; -use client::{CloudUserStore, ModelRequestUsage, RequestUsage}; -use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; +use client::{ModelRequestUsage, RequestUsage}; +use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit}; use collections::HashMap; use feature_flags::{self, FeatureFlagAppExt}; use futures::{FutureExt, StreamExt as _, future::Shared}; @@ -37,7 +37,6 @@ use project::{ git_store::{GitStore, GitStoreCheckpoint, RepositoryState}, }; use prompt_store::{ModelContext, PromptBuilder}; -use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; @@ -374,7 +373,6 @@ pub struct Thread { completion_count: usize, pending_completions: Vec, project: Entity, - cloud_user_store: Entity, prompt_builder: Arc, tools: Entity, tool_use: ToolUseState, @@ -445,7 +443,6 @@ pub struct ExceededWindowError { impl Thread { pub fn new( project: Entity, - cloud_user_store: Entity, tools: Entity, prompt_builder: Arc, system_prompt: SharedProjectContext, @@ -472,7 +469,6 @@ impl Thread { completion_count: 0, pending_completions: Vec::new(), project: project.clone(), - cloud_user_store, prompt_builder, tools: tools.clone(), last_restore_checkpoint: None, @@ -506,7 +502,6 @@ impl Thread { id: ThreadId, serialized: SerializedThread, project: Entity, - cloud_user_store: Entity, tools: Entity, prompt_builder: Arc, project_context: SharedProjectContext, @@ -607,7 +602,6 @@ impl Thread { last_restore_checkpoint: None, pending_checkpoint: None, project: project.clone(), - cloud_user_store, prompt_builder, tools: tools.clone(), tool_use, @@ -3260,15 +3254,18 @@ impl Thread { } fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context) { - self.cloud_user_store.update(cx, |cloud_user_store, cx| { - cloud_user_store.update_model_request_usage( - ModelRequestUsage(RequestUsage { - amount: amount as i32, - limit, - }), - cx, - ) - }); + self.project + .read(cx) + .user_store() + .update(cx, |user_store, cx| { + user_store.update_model_request_usage( + ModelRequestUsage(RequestUsage { + amount: amount as i32, + limit, + }), + cx, + ) + }); } pub fn deny_tool_use( @@ -3886,7 +3883,6 @@ fn main() {{ thread.id.clone(), serialized, thread.project.clone(), - thread.cloud_user_store.clone(), thread.tools.clone(), thread.prompt_builder.clone(), thread.project_context.clone(), @@ -5483,16 +5479,10 @@ fn main() {{ let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - let (client, user_store) = - project.read_with(cx, |project, _cx| (project.client(), project.user_store())); - let cloud_user_store = - cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx)); - let thread_store = cx .update(|_, cx| { ThreadStore::load( project.clone(), - cloud_user_store, cx.new(|_| ToolWorkingSet::default()), None, Arc::new(PromptBuilder::new(None).unwrap()), diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 6efa56f233..cc7cb50c91 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -8,7 +8,6 @@ use agent_settings::{AgentProfileId, CompletionMode}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{Tool, ToolId, ToolWorkingSet}; use chrono::{DateTime, Utc}; -use client::CloudUserStore; use collections::HashMap; use context_server::ContextServerId; use futures::{ @@ -105,7 +104,6 @@ pub type TextThreadStore = assistant_context::ContextStore; pub struct ThreadStore { project: Entity, - cloud_user_store: Entity, tools: Entity, prompt_builder: Arc, prompt_store: Option>, @@ -126,7 +124,6 @@ impl EventEmitter for ThreadStore {} impl ThreadStore { pub fn load( project: Entity, - cloud_user_store: Entity, tools: Entity, prompt_store: Option>, prompt_builder: Arc, @@ -136,14 +133,8 @@ impl ThreadStore { let (thread_store, ready_rx) = cx.update(|cx| { let mut option_ready_rx = None; let thread_store = cx.new(|cx| { - let (thread_store, ready_rx) = Self::new( - project, - cloud_user_store, - tools, - prompt_builder, - prompt_store, - cx, - ); + let (thread_store, ready_rx) = + Self::new(project, tools, prompt_builder, prompt_store, cx); option_ready_rx = Some(ready_rx); thread_store }); @@ -156,7 +147,6 @@ impl ThreadStore { fn new( project: Entity, - cloud_user_store: Entity, tools: Entity, prompt_builder: Arc, prompt_store: Option>, @@ -200,7 +190,6 @@ impl ThreadStore { let this = Self { project, - cloud_user_store, tools, prompt_builder, prompt_store, @@ -418,7 +407,6 @@ impl ThreadStore { cx.new(|cx| { Thread::new( self.project.clone(), - self.cloud_user_store.clone(), self.tools.clone(), self.prompt_builder.clone(), self.project_context.clone(), @@ -437,7 +425,6 @@ impl ThreadStore { ThreadId::new(), serialized, self.project.clone(), - self.cloud_user_store.clone(), self.tools.clone(), self.prompt_builder.clone(), self.project_context.clone(), @@ -469,7 +456,6 @@ impl ThreadStore { id.clone(), thread, this.project.clone(), - this.cloud_user_store.clone(), this.tools.clone(), this.prompt_builder.clone(), this.project_context.clone(), diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index 1669c24a1b..04a093c7d0 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -3820,7 +3820,6 @@ mod tests { use super::*; use agent::{MessageSegment, context::ContextLoadResult, thread_store}; use assistant_tool::{ToolRegistry, ToolWorkingSet}; - use client::CloudUserStore; use editor::EditorSettings; use fs::FakeFs; use gpui::{AppContext, TestAppContext, VisualTestContext}; @@ -4117,16 +4116,10 @@ mod tests { let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - let (client, user_store) = - project.read_with(cx, |project, _cx| (project.client(), project.user_store())); - let cloud_user_store = - cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx)); - let thread_store = cx .update(|_, cx| { ThreadStore::load( project.clone(), - cloud_user_store, cx.new(|_| ToolWorkingSet::default()), None, Arc::new(PromptBuilder::new(None).unwrap()), diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index b88b85d85b..dad930be9e 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -7,6 +7,7 @@ use std::{sync::Arc, time::Duration}; use agent_settings::AgentSettings; use assistant_tool::{ToolSource, ToolWorkingSet}; +use cloud_llm_client::Plan; use collections::HashMap; use context_server::ContextServerId; use extension::ExtensionManifest; @@ -25,7 +26,6 @@ use project::{ context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore}, project_settings::{ContextServerSettings, ProjectSettings}, }; -use proto::Plan; use settings::{Settings, update_settings_file}; use ui::{ Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu, @@ -180,7 +180,7 @@ impl AgentConfiguration { let current_plan = if is_zed_provider { self.workspace .upgrade() - .and_then(|workspace| workspace.read(cx).user_store().read(cx).current_plan()) + .and_then(|workspace| workspace.read(cx).user_store().read(cx).plan()) } else { None }; @@ -508,7 +508,7 @@ impl AgentConfiguration { .blend(cx.theme().colors().text_accent.opacity(0.2)); let (plan_name, label_color, bg_color) = match plan { - Plan::Free => ("Free", Color::Default, free_chip_bg), + Plan::ZedFree => ("Free", Color::Default, free_chip_bg), Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg), Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg), }; diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 135f07a934..c4dc359093 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1896,7 +1896,6 @@ mod tests { use agent::thread_store::{self, ThreadStore}; use agent_settings::AgentSettings; use assistant_tool::ToolWorkingSet; - use client::CloudUserStore; use editor::EditorSettings; use gpui::{TestAppContext, UpdateGlobal, VisualTestContext}; use project::{FakeFs, Project}; @@ -1936,17 +1935,11 @@ mod tests { }) .unwrap(); - let (client, user_store) = - project.read_with(cx, |project, _cx| (project.client(), project.user_store())); - let cloud_user_store = - cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx)); - let prompt_store = None; let thread_store = cx .update(|cx| { ThreadStore::load( project.clone(), - cloud_user_store, cx.new(|_| ToolWorkingSet::default()), prompt_store, Arc::new(PromptBuilder::new(None).unwrap()), @@ -2108,17 +2101,11 @@ mod tests { }) .unwrap(); - let (client, user_store) = - project.read_with(cx, |project, _cx| (project.client(), project.user_store())); - let cloud_user_store = - cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx)); - let prompt_store = None; let thread_store = cx .update(|cx| { ThreadStore::load( project.clone(), - cloud_user_store, cx.new(|_| ToolWorkingSet::default()), prompt_store, Arc::new(PromptBuilder::new(None).unwrap()), diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 7e0d766f91..fcb8dfbac2 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -43,8 +43,8 @@ use anyhow::{Result, anyhow}; use assistant_context::{AssistantContext, ContextEvent, ContextSummary}; use assistant_slash_command::SlashCommandWorkingSet; use assistant_tool::ToolWorkingSet; -use client::{CloudUserStore, DisableAiSettings, UserStore, zed_urls}; -use cloud_llm_client::{CompletionIntent, UsageLimit}; +use client::{DisableAiSettings, UserStore, zed_urls}; +use cloud_llm_client::{CompletionIntent, Plan, UsageLimit}; use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer}; use feature_flags::{self, FeatureFlagAppExt}; use fs::Fs; @@ -60,7 +60,6 @@ use language_model::{ }; use project::{Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; -use proto::Plan; use rules_library::{RulesLibrary, open_rules_library}; use search::{BufferSearchBar, buffer_search}; use settings::{Settings, update_settings_file}; @@ -427,7 +426,6 @@ impl ActiveView { pub struct AgentPanel { workspace: WeakEntity, user_store: Entity, - cloud_user_store: Entity, project: Entity, fs: Arc, language_registry: Arc, @@ -487,7 +485,6 @@ impl AgentPanel { let project = workspace.project().clone(); ThreadStore::load( project, - workspace.app_state().cloud_user_store.clone(), tools.clone(), prompt_store.clone(), prompt_builder.clone(), @@ -555,7 +552,6 @@ impl AgentPanel { let thread = thread_store.update(cx, |this, cx| this.create_thread(cx)); let fs = workspace.app_state().fs.clone(); let user_store = workspace.app_state().user_store.clone(); - let cloud_user_store = workspace.app_state().cloud_user_store.clone(); let project = workspace.project(); let language_registry = project.read(cx).languages().clone(); let client = workspace.client().clone(); @@ -582,7 +578,6 @@ impl AgentPanel { MessageEditor::new( fs.clone(), workspace.clone(), - cloud_user_store.clone(), message_editor_context_store.clone(), prompt_store.clone(), thread_store.downgrade(), @@ -697,7 +692,6 @@ impl AgentPanel { let onboarding = cx.new(|cx| { AgentPanelOnboarding::new( user_store.clone(), - cloud_user_store.clone(), client, |_window, cx| { OnboardingUpsell::set_dismissed(true, cx); @@ -710,7 +704,6 @@ impl AgentPanel { active_view, workspace, user_store, - cloud_user_store, project: project.clone(), fs: fs.clone(), language_registry, @@ -853,7 +846,6 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - self.cloud_user_store.clone(), context_store.clone(), self.prompt_store.clone(), self.thread_store.downgrade(), @@ -1127,7 +1119,6 @@ impl AgentPanel { MessageEditor::new( self.fs.clone(), self.workspace.clone(), - self.cloud_user_store.clone(), context_store, self.prompt_store.clone(), self.thread_store.downgrade(), @@ -1826,8 +1817,8 @@ impl AgentPanel { } fn render_toolbar(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let cloud_user_store = self.cloud_user_store.read(cx); - let usage = cloud_user_store.model_request_usage(); + let user_store = self.user_store.read(cx); + let usage = user_store.model_request_usage(); let account_url = zed_urls::account_url(cx); @@ -2298,10 +2289,10 @@ impl AgentPanel { | ActiveView::Configuration => return false, } - let plan = self.user_store.read(cx).current_plan(); + let plan = self.user_store.read(cx).plan(); let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some(); - matches!(plan, Some(Plan::Free)) && has_previous_trial + matches!(plan, Some(Plan::ZedFree)) && has_previous_trial } fn should_render_onboarding(&self, cx: &mut Context) -> bool { @@ -2916,7 +2907,7 @@ impl AgentPanel { ) -> AnyElement { let error_message = match plan { Plan::ZedPro => "Upgrade to usage-based billing for more prompts.", - Plan::ZedProTrial | Plan::Free => "Upgrade to Zed Pro for more prompts.", + Plan::ZedProTrial | Plan::ZedFree => "Upgrade to Zed Pro for more prompts.", }; let icon = Icon::new(IconName::XCircle) diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index e00a0087eb..2185885347 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -17,7 +17,6 @@ use agent::{ use agent_settings::{AgentSettings, CompletionMode}; use ai_onboarding::ApiKeysWithProviders; use buffer_diff::BufferDiff; -use client::CloudUserStore; use cloud_llm_client::CompletionIntent; use collections::{HashMap, HashSet}; use editor::actions::{MoveUp, Paste}; @@ -78,7 +77,6 @@ pub struct MessageEditor { editor: Entity, workspace: WeakEntity, project: Entity, - cloud_user_store: Entity, context_store: Entity, prompt_store: Option>, history_store: Option>, @@ -158,7 +156,6 @@ impl MessageEditor { pub fn new( fs: Arc, workspace: WeakEntity, - cloud_user_store: Entity, context_store: Entity, prompt_store: Option>, thread_store: WeakEntity, @@ -230,7 +227,6 @@ impl MessageEditor { Self { editor: editor.clone(), project: thread.read(cx).project().clone(), - cloud_user_store, thread, incompatible_tools_state: incompatible_tools.clone(), workspace, @@ -1286,16 +1282,14 @@ impl MessageEditor { return None; } - let cloud_user_store = self.cloud_user_store.read(cx); - if cloud_user_store.is_usage_based_billing_enabled() { + let user_store = self.project.read(cx).user_store().read(cx); + if user_store.is_usage_based_billing_enabled() { return None; } - let plan = cloud_user_store - .plan() - .unwrap_or(cloud_llm_client::Plan::ZedFree); + let plan = user_store.plan().unwrap_or(cloud_llm_client::Plan::ZedFree); - let usage = cloud_user_store.model_request_usage()?; + let usage = user_store.model_request_usage()?; Some( div() @@ -1758,7 +1752,6 @@ impl AgentPreview for MessageEditor { ) -> Option { if let Some(workspace) = workspace.upgrade() { let fs = workspace.read(cx).app_state().fs.clone(); - let cloud_user_store = workspace.read(cx).app_state().cloud_user_store.clone(); let project = workspace.read(cx).project().clone(); let weak_project = project.downgrade(); let context_store = cx.new(|_cx| ContextStore::new(weak_project, None)); @@ -1771,7 +1764,6 @@ impl AgentPreview for MessageEditor { MessageEditor::new( fs, workspace.downgrade(), - cloud_user_store, context_store, None, thread_store.downgrade(), diff --git a/crates/ai_onboarding/Cargo.toml b/crates/ai_onboarding/Cargo.toml index 20fd54339e..95a45b1a6f 100644 --- a/crates/ai_onboarding/Cargo.toml +++ b/crates/ai_onboarding/Cargo.toml @@ -20,7 +20,6 @@ cloud_llm_client.workspace = true component.workspace = true gpui.workspace = true language_model.workspace = true -proto.workspace = true serde.workspace = true smallvec.workspace = true telemetry.workspace = true diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index 237b0ae046..f1629eeff8 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use client::{Client, CloudUserStore, UserStore}; +use client::{Client, UserStore}; use cloud_llm_client::Plan; use gpui::{Entity, IntoElement, ParentElement}; use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; @@ -10,7 +10,6 @@ use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding}; pub struct AgentPanelOnboarding { user_store: Entity, - cloud_user_store: Entity, client: Arc, configured_providers: Vec<(IconName, SharedString)>, continue_with_zed_ai: Arc, @@ -19,7 +18,6 @@ pub struct AgentPanelOnboarding { impl AgentPanelOnboarding { pub fn new( user_store: Entity, - cloud_user_store: Entity, client: Arc, continue_with_zed_ai: impl Fn(&mut Window, &mut App) + 'static, cx: &mut Context, @@ -39,7 +37,6 @@ impl AgentPanelOnboarding { Self { user_store, - cloud_user_store, client, configured_providers: Self::compute_available_providers(cx), continue_with_zed_ai: Arc::new(continue_with_zed_ai), @@ -60,8 +57,8 @@ impl AgentPanelOnboarding { impl Render for AgentPanelOnboarding { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let enrolled_in_trial = self.cloud_user_store.read(cx).plan() == Some(Plan::ZedProTrial); - let is_pro_user = self.cloud_user_store.read(cx).plan() == Some(Plan::ZedPro); + let enrolled_in_trial = self.user_store.read(cx).plan() == Some(Plan::ZedProTrial); + let is_pro_user = self.user_store.read(cx).plan() == Some(Plan::ZedPro); AgentPanelOnboardingCard::new() .child( diff --git a/crates/ai_onboarding/src/ai_onboarding.rs b/crates/ai_onboarding/src/ai_onboarding.rs index 3aec9c62cd..c252b65f20 100644 --- a/crates/ai_onboarding/src/ai_onboarding.rs +++ b/crates/ai_onboarding/src/ai_onboarding.rs @@ -9,6 +9,7 @@ pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProvider pub use agent_panel_onboarding_card::AgentPanelOnboardingCard; pub use agent_panel_onboarding_content::AgentPanelOnboarding; pub use ai_upsell_card::AiUpsellCard; +use cloud_llm_client::Plan; pub use edit_prediction_onboarding_content::EditPredictionOnboarding; pub use young_account_banner::YoungAccountBanner; @@ -79,7 +80,7 @@ impl From for SignInStatus { pub struct ZedAiOnboarding { pub sign_in_status: SignInStatus, pub has_accepted_terms_of_service: bool, - pub plan: Option, + pub plan: Option, pub account_too_young: bool, pub continue_with_zed_ai: Arc, pub sign_in: Arc, @@ -99,8 +100,8 @@ impl ZedAiOnboarding { Self { sign_in_status: status.into(), - has_accepted_terms_of_service: store.current_user_has_accepted_terms().unwrap_or(false), - plan: store.current_plan(), + has_accepted_terms_of_service: store.has_accepted_terms_of_service(), + plan: store.plan(), account_too_young: store.account_too_young(), continue_with_zed_ai, accept_terms_of_service: Arc::new({ @@ -113,11 +114,9 @@ impl ZedAiOnboarding { sign_in: Arc::new(move |_window, cx| { cx.spawn({ let client = client.clone(); - async move |cx| { - client.authenticate_and_connect(true, cx).await; - } + async move |cx| client.sign_in_with_optional_connect(true, cx).await }) - .detach(); + .detach_and_log_err(cx); }), dismiss_onboarding: None, } @@ -411,9 +410,9 @@ impl RenderOnce for ZedAiOnboarding { if matches!(self.sign_in_status, SignInStatus::SignedIn) { if self.has_accepted_terms_of_service { match self.plan { - None | Some(proto::Plan::Free) => self.render_free_plan_state(cx), - Some(proto::Plan::ZedProTrial) => self.render_trial_state(cx), - Some(proto::Plan::ZedPro) => self.render_pro_plan_state(cx), + None | Some(Plan::ZedFree) => self.render_free_plan_state(cx), + Some(Plan::ZedProTrial) => self.render_trial_state(cx), + Some(Plan::ZedPro) => self.render_pro_plan_state(cx), } } else { self.render_accept_terms_of_service() @@ -433,7 +432,7 @@ impl Component for ZedAiOnboarding { fn onboarding( sign_in_status: SignInStatus, has_accepted_terms_of_service: bool, - plan: Option, + plan: Option, account_too_young: bool, ) -> AnyElement { ZedAiOnboarding { @@ -468,25 +467,15 @@ impl Component for ZedAiOnboarding { ), single_example( "Free Plan", - onboarding(SignInStatus::SignedIn, true, Some(proto::Plan::Free), false), + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedFree), false), ), single_example( "Pro Trial", - onboarding( - SignInStatus::SignedIn, - true, - Some(proto::Plan::ZedProTrial), - false, - ), + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedProTrial), false), ), single_example( "Pro Plan", - onboarding( - SignInStatus::SignedIn, - true, - Some(proto::Plan::ZedPro), - false, - ), + onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedPro), false), ), ]) .into_any_element(), diff --git a/crates/ai_onboarding/src/ai_upsell_card.rs b/crates/ai_onboarding/src/ai_upsell_card.rs index 56eaca2392..2408b6aa37 100644 --- a/crates/ai_onboarding/src/ai_upsell_card.rs +++ b/crates/ai_onboarding/src/ai_upsell_card.rs @@ -24,11 +24,9 @@ impl AiUpsellCard { sign_in: Arc::new(move |_window, cx| { cx.spawn({ let client = client.clone(); - async move |cx| { - client.authenticate_and_connect(true, cx).await; - } + async move |cx| client.sign_in_with_optional_connect(true, cx).await }) - .detach(); + .detach_and_log_err(cx); }), } } diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/assistant_tools/src/edit_agent/evals.rs index 13619da25c..eda7eee0e3 100644 --- a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/assistant_tools/src/edit_agent/evals.rs @@ -7,7 +7,7 @@ use crate::{ }; use Role::*; use assistant_tool::ToolRegistry; -use client::{Client, CloudUserStore, UserStore}; +use client::{Client, UserStore}; use collections::HashMap; use fs::FakeFs; use futures::{FutureExt, future::LocalBoxFuture}; @@ -1470,14 +1470,12 @@ impl EditAgentTest { client::init_settings(cx); let client = Client::production(cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let cloud_user_store = - cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); settings::init(cx); Project::init_settings(cx); language::init(cx); language_model::init(client.clone(), cx); - language_models::init(user_store.clone(), cloud_user_store, client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); crate::init(client.http_client(), cx); }); diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs index f8f5de3c39..c92226eeeb 100644 --- a/crates/channel/src/channel_store_tests.rs +++ b/crates/channel/src/channel_store_tests.rs @@ -259,20 +259,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) { assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx); }); - let get_users = server.receive::().await.unwrap(); - assert_eq!(get_users.payload.user_ids, vec![5]); - server.respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 5, - github_login: "nathansobo".into(), - avatar_url: "http://avatar.com/nathansobo".into(), - name: None, - }], - }, - ); - // Join a channel and populate its existing messages. let channel = channel_store.update(cx, |store, cx| { let channel_id = store.ordered_channels().next().unwrap().1.id; @@ -334,7 +320,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), &[ - ("nathansobo".into(), "a".into()), + ("user-5".into(), "a".into()), ("maxbrunsfeld".into(), "b".into()) ] ); @@ -437,7 +423,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), &[ - ("nathansobo".into(), "y".into()), + ("user-5".into(), "y".into()), ("maxbrunsfeld".into(), "z".into()) ] ); diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 3ff03114ea..365625b445 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -17,7 +17,6 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup [dependencies] anyhow.workspace = true -async-recursion = "0.3" async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] } base64.workspace = true chrono = { workspace = true, features = ["serde"] } diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 230e1ce634..b9b20aa4f2 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -1,14 +1,12 @@ #[cfg(any(test, feature = "test-support"))] pub mod test; -mod cloud; mod proxy; pub mod telemetry; pub mod user; pub mod zed_urls; use anyhow::{Context as _, Result, anyhow}; -use async_recursion::async_recursion; use async_tungstenite::tungstenite::{ client::IntoClientRequest, error::Error as WebsocketError, @@ -52,7 +50,6 @@ use tokio::net::TcpStream; use url::Url; use util::{ConnectionResult, ResultExt}; -pub use cloud::*; pub use rpc::*; pub use telemetry_events::Event; pub use user::*; @@ -164,20 +161,8 @@ pub fn init(client: &Arc, cx: &mut App) { let client = client.clone(); move |_: &SignIn, cx| { if let Some(client) = client.upgrade() { - cx.spawn( - async move |cx| match client.authenticate_and_connect(true, &cx).await { - ConnectionResult::Timeout => { - log::error!("Initial authentication timed out"); - } - ConnectionResult::ConnectionReset => { - log::error!("Initial authentication connection reset"); - } - ConnectionResult::Result(r) => { - r.log_err(); - } - }, - ) - .detach(); + cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await) + .detach_and_log_err(cx); } } }); @@ -286,6 +271,8 @@ pub enum Status { SignedOut, UpgradeRequired, Authenticating, + Authenticated, + AuthenticationError, Connecting, ConnectionError, Connected { @@ -712,7 +699,7 @@ impl Client { let mut delay = INITIAL_RECONNECTION_DELAY; loop { - match client.authenticate_and_connect(true, &cx).await { + match client.connect(true, &cx).await { ConnectionResult::Timeout => { log::error!("client connect attempt timed out") } @@ -882,17 +869,122 @@ impl Client { .is_some() } - #[async_recursion(?Send)] - pub async fn authenticate_and_connect( + pub async fn sign_in( + self: &Arc, + try_provider: bool, + cx: &AsyncApp, + ) -> Result { + if self.status().borrow().is_signed_out() { + self.set_status(Status::Authenticating, cx); + } else { + self.set_status(Status::Reauthenticating, cx); + } + + let mut credentials = None; + + let old_credentials = self.state.read().credentials.clone(); + if let Some(old_credentials) = old_credentials { + self.cloud_client.set_credentials( + old_credentials.user_id as u32, + old_credentials.access_token.clone(), + ); + + // Fetch the authenticated user with the old credentials, to ensure they are still valid. + if self.cloud_client.get_authenticated_user().await.is_ok() { + credentials = Some(old_credentials); + } + } + + if credentials.is_none() && try_provider { + if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await { + self.cloud_client.set_credentials( + stored_credentials.user_id as u32, + stored_credentials.access_token.clone(), + ); + + // Fetch the authenticated user with the stored credentials, and + // clear them from the credentials provider if that fails. + if self.cloud_client.get_authenticated_user().await.is_ok() { + credentials = Some(stored_credentials); + } else { + self.credentials_provider + .delete_credentials(cx) + .await + .log_err(); + } + } + } + + if credentials.is_none() { + let mut status_rx = self.status(); + let _ = status_rx.next().await; + futures::select_biased! { + authenticate = self.authenticate(cx).fuse() => { + match authenticate { + Ok(creds) => { + if IMPERSONATE_LOGIN.is_none() { + self.credentials_provider + .write_credentials(creds.user_id, creds.access_token.clone(), cx) + .await + .log_err(); + } + + credentials = Some(creds); + }, + Err(err) => { + self.set_status(Status::AuthenticationError, cx); + return Err(err); + } + } + } + _ = status_rx.next().fuse() => { + return Err(anyhow!("authentication canceled")); + } + } + } + + let credentials = credentials.unwrap(); + self.set_id(credentials.user_id); + self.cloud_client + .set_credentials(credentials.user_id as u32, credentials.access_token.clone()); + self.state.write().credentials = Some(credentials.clone()); + self.set_status(Status::Authenticated, cx); + + Ok(credentials) + } + + /// Performs a sign-in and also connects to Collab. + /// + /// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls + /// to `sign_in` when we're ready to remove auto-connection to Collab. + pub async fn sign_in_with_optional_connect( + self: &Arc, + try_provider: bool, + cx: &AsyncApp, + ) -> Result<()> { + let credentials = self.sign_in(try_provider, cx).await?; + + let connect_result = match self.connect_with_credentials(credentials, cx).await { + ConnectionResult::Timeout => Err(anyhow!("connection timed out")), + ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")), + ConnectionResult::Result(result) => result.context("client auth and connect"), + }; + connect_result.log_err(); + + Ok(()) + } + + pub async fn connect( self: &Arc, try_provider: bool, cx: &AsyncApp, ) -> ConnectionResult<()> { let was_disconnected = match *self.status().borrow() { - Status::SignedOut => true, + Status::SignedOut | Status::Authenticated => true, Status::ConnectionError | Status::ConnectionLost | Status::Authenticating { .. } + | Status::AuthenticationError | Status::Reauthenticating { .. } | Status::ReconnectionError { .. } => false, Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => { @@ -905,41 +997,10 @@ impl Client { ); } }; - if was_disconnected { - self.set_status(Status::Authenticating, cx); - } else { - self.set_status(Status::Reauthenticating, cx) - } - - let mut read_from_provider = false; - let mut credentials = self.state.read().credentials.clone(); - if credentials.is_none() && try_provider { - credentials = self.credentials_provider.read_credentials(cx).await; - read_from_provider = credentials.is_some(); - } - - if credentials.is_none() { - let mut status_rx = self.status(); - let _ = status_rx.next().await; - futures::select_biased! { - authenticate = self.authenticate(cx).fuse() => { - match authenticate { - Ok(creds) => credentials = Some(creds), - Err(err) => { - self.set_status(Status::ConnectionError, cx); - return ConnectionResult::Result(Err(err)); - } - } - } - _ = status_rx.next().fuse() => { - return ConnectionResult::Result(Err(anyhow!("authentication canceled"))); - } - } - } - let credentials = credentials.unwrap(); - self.set_id(credentials.user_id); - self.cloud_client - .set_credentials(credentials.user_id as u32, credentials.access_token.clone()); + let credentials = match self.sign_in(try_provider, cx).await { + Ok(credentials) => credentials, + Err(err) => return ConnectionResult::Result(Err(err)), + }; if was_disconnected { self.set_status(Status::Connecting, cx); @@ -947,17 +1008,20 @@ impl Client { self.set_status(Status::Reconnecting, cx); } + self.connect_with_credentials(credentials, cx).await + } + + async fn connect_with_credentials( + self: &Arc, + credentials: Credentials, + cx: &AsyncApp, + ) -> ConnectionResult<()> { let mut timeout = futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT)); futures::select_biased! { connection = self.establish_connection(&credentials, cx).fuse() => { match connection { Ok(conn) => { - self.state.write().credentials = Some(credentials.clone()); - if !read_from_provider && IMPERSONATE_LOGIN.is_none() { - self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err(); - } - futures::select_biased! { result = self.set_connection(conn, cx).fuse() => { match result.context("client auth and connect") { @@ -975,15 +1039,8 @@ impl Client { } } Err(EstablishConnectionError::Unauthorized) => { - self.state.write().credentials.take(); - if read_from_provider { - self.credentials_provider.delete_credentials(cx).await.log_err(); - self.set_status(Status::SignedOut, cx); - self.authenticate_and_connect(false, cx).await - } else { - self.set_status(Status::ConnectionError, cx); - ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect")) - } + self.set_status(Status::ConnectionError, cx); + ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect")) } Err(EstablishConnectionError::UpgradeRequired) => { self.set_status(Status::UpgradeRequired, cx); @@ -1733,7 +1790,7 @@ mod tests { }); let auth_and_connect = cx.spawn({ let client = client.clone(); - |cx| async move { client.authenticate_and_connect(false, &cx).await } + |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert!(matches!(status.next().await, Some(Status::Connecting))); @@ -1810,7 +1867,7 @@ mod tests { let _authenticate = cx.spawn({ let client = client.clone(); - move |cx| async move { client.authenticate_and_connect(false, &cx).await } + move |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 1); @@ -1818,7 +1875,7 @@ mod tests { let _authenticate = cx.spawn({ let client = client.clone(); - |cx| async move { client.authenticate_and_connect(false, &cx).await } + |cx| async move { client.connect(false, &cx).await } }); executor.run_until_parked(); assert_eq!(*auth_count.lock(), 2); diff --git a/crates/client/src/cloud.rs b/crates/client/src/cloud.rs deleted file mode 100644 index 39c9d04887..0000000000 --- a/crates/client/src/cloud.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod user_store; - -pub use user_store::*; diff --git a/crates/client/src/cloud/user_store.rs b/crates/client/src/cloud/user_store.rs deleted file mode 100644 index 78444b3f95..0000000000 --- a/crates/client/src/cloud/user_store.rs +++ /dev/null @@ -1,211 +0,0 @@ -use std::sync::Arc; -use std::time::Duration; - -use anyhow::Context as _; -use chrono::{DateTime, Utc}; -use cloud_api_client::{AuthenticatedUser, CloudApiClient, GetAuthenticatedUserResponse, PlanInfo}; -use cloud_llm_client::Plan; -use gpui::{Context, Entity, Subscription, Task}; -use util::{ResultExt as _, maybe}; - -use crate::user::Event as RpcUserStoreEvent; -use crate::{EditPredictionUsage, ModelRequestUsage, RequestUsage, UserStore}; - -pub struct CloudUserStore { - cloud_client: Arc, - authenticated_user: Option>, - plan_info: Option>, - model_request_usage: Option, - edit_prediction_usage: Option, - _maintain_authenticated_user_task: Task<()>, - _rpc_plan_updated_subscription: Subscription, -} - -impl CloudUserStore { - pub fn new( - cloud_client: Arc, - rpc_user_store: Entity, - cx: &mut Context, - ) -> Self { - let rpc_plan_updated_subscription = - cx.subscribe(&rpc_user_store, Self::handle_rpc_user_store_event); - - Self { - cloud_client: cloud_client.clone(), - authenticated_user: None, - plan_info: None, - model_request_usage: None, - edit_prediction_usage: None, - _maintain_authenticated_user_task: cx.spawn(async move |this, cx| { - maybe!(async move { - loop { - let Some(this) = this.upgrade() else { - return anyhow::Ok(()); - }; - - if cloud_client.has_credentials() { - let already_fetched_authenticated_user = this - .read_with(cx, |this, _cx| this.authenticated_user().is_some()) - .unwrap_or(false); - - if already_fetched_authenticated_user { - // We already fetched the authenticated user; nothing to do. - } else { - let authenticated_user_result = cloud_client - .get_authenticated_user() - .await - .context("failed to fetch authenticated user"); - if let Some(response) = authenticated_user_result.log_err() { - this.update(cx, |this, _cx| { - this.update_authenticated_user(response); - }) - .ok(); - } - } - } else { - this.update(cx, |this, _cx| { - this.authenticated_user.take(); - this.plan_info.take(); - }) - .ok(); - } - - cx.background_executor() - .timer(Duration::from_millis(100)) - .await; - } - }) - .await - .log_err(); - }), - _rpc_plan_updated_subscription: rpc_plan_updated_subscription, - } - } - - pub fn is_authenticated(&self) -> bool { - self.authenticated_user.is_some() - } - - pub fn authenticated_user(&self) -> Option> { - self.authenticated_user.clone() - } - - pub fn plan(&self) -> Option { - self.plan_info.as_ref().map(|plan| plan.plan) - } - - pub fn subscription_period(&self) -> Option<(DateTime, DateTime)> { - self.plan_info - .as_ref() - .and_then(|plan| plan.subscription_period) - .map(|subscription_period| { - ( - subscription_period.started_at.0, - subscription_period.ended_at.0, - ) - }) - } - - pub fn trial_started_at(&self) -> Option> { - self.plan_info - .as_ref() - .and_then(|plan| plan.trial_started_at) - .map(|trial_started_at| trial_started_at.0) - } - - pub fn has_accepted_tos(&self) -> bool { - self.authenticated_user - .as_ref() - .map(|user| user.accepted_tos_at.is_some()) - .unwrap_or_default() - } - - /// Returns whether the user's account is too new to use the service. - pub fn account_too_young(&self) -> bool { - self.plan_info - .as_ref() - .map(|plan| plan.is_account_too_young) - .unwrap_or_default() - } - - /// Returns whether the current user has overdue invoices and usage should be blocked. - pub fn has_overdue_invoices(&self) -> bool { - self.plan_info - .as_ref() - .map(|plan| plan.has_overdue_invoices) - .unwrap_or_default() - } - - pub fn is_usage_based_billing_enabled(&self) -> bool { - self.plan_info - .as_ref() - .map(|plan| plan.is_usage_based_billing_enabled) - .unwrap_or_default() - } - - pub fn model_request_usage(&self) -> Option { - self.model_request_usage - } - - pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { - self.model_request_usage = Some(usage); - cx.notify(); - } - - pub fn edit_prediction_usage(&self) -> Option { - self.edit_prediction_usage - } - - pub fn update_edit_prediction_usage( - &mut self, - usage: EditPredictionUsage, - cx: &mut Context, - ) { - self.edit_prediction_usage = Some(usage); - cx.notify(); - } - - fn update_authenticated_user(&mut self, response: GetAuthenticatedUserResponse) { - self.authenticated_user = Some(Arc::new(response.user)); - self.model_request_usage = Some(ModelRequestUsage(RequestUsage { - limit: response.plan.usage.model_requests.limit, - amount: response.plan.usage.model_requests.used as i32, - })); - self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage { - limit: response.plan.usage.edit_predictions.limit, - amount: response.plan.usage.edit_predictions.used as i32, - })); - self.plan_info = Some(Arc::new(response.plan)); - } - - fn handle_rpc_user_store_event( - &mut self, - _: Entity, - event: &RpcUserStoreEvent, - cx: &mut Context, - ) { - match event { - RpcUserStoreEvent::PlanUpdated => { - cx.spawn(async move |this, cx| { - let cloud_client = - cx.update(|cx| this.read_with(cx, |this, _cx| this.cloud_client.clone()))??; - - let response = cloud_client - .get_authenticated_user() - .await - .context("failed to fetch authenticated user")?; - - cx.update(|cx| { - this.update(cx, |this, _cx| { - this.update_authenticated_user(response); - }) - })??; - - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } - _ => {} - } - } -} diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 6ce79fa9c5..439fb100d2 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,8 +1,11 @@ use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use anyhow::{Context as _, Result, anyhow}; use chrono::Duration; +use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo}; +use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; use futures::{StreamExt, stream::BoxStream}; use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext}; +use http_client::{AsyncBody, Method, Request, http}; use parking_lot::Mutex; use rpc::{ ConnectionId, Peer, Receipt, TypedEnvelope, @@ -39,6 +42,44 @@ impl FakeServer { executor: cx.executor(), }; + client.http_client().as_fake().replace_handler({ + let state = server.state.clone(); + move |old_handler, req| { + let state = state.clone(); + let old_handler = old_handler.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => { + let credentials = parse_authorization_header(&req); + if credentials + != Some(Credentials { + user_id: client_user_id, + access_token: state.lock().access_token.to_string(), + }) + { + return Ok(http_client::Response::builder() + .status(401) + .body("Unauthorized".into()) + .unwrap()); + } + + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response( + client_user_id as i32, + format!("user-{client_user_id}"), + )) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => old_handler(req).await, + } + } + } + }); client .override_authenticate({ let state = Arc::downgrade(&server.state); @@ -105,7 +146,7 @@ impl FakeServer { }); client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); @@ -223,3 +264,54 @@ impl Drop for FakeServer { self.disconnect(); } } + +pub fn parse_authorization_header(req: &Request) -> Option { + let mut auth_header = req + .headers() + .get(http::header::AUTHORIZATION)? + .to_str() + .ok()? + .split_whitespace(); + let user_id = auth_header.next()?.parse().ok()?; + let access_token = auth_header.next()?; + Some(Credentials { + user_id, + access_token: access_token.to_string(), + }) +} + +pub fn make_get_authenticated_user_response( + user_id: i32, + github_login: String, +) -> GetAuthenticatedUserResponse { + GetAuthenticatedUserResponse { + user: AuthenticatedUser { + id: user_id, + metrics_id: format!("metrics-id-{user_id}"), + avatar_url: "".to_string(), + github_login, + name: None, + is_staff: false, + accepted_tos_at: None, + }, + feature_flags: vec![], + plan: PlanInfo { + plan: Plan::ZedPro, + subscription_period: None, + usage: CurrentUsage { + model_requests: UsageData { + used: 0, + limit: UsageLimit::Limited(500), + }, + edit_predictions: UsageData { + used: 250, + limit: UsageLimit::Unlimited, + }, + }, + trial_started_at: None, + is_usage_based_billing_enabled: false, + is_account_too_young: false, + has_overdue_invoices: false, + }, + } +} diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index df5ce67be3..3c125a0882 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -1,6 +1,7 @@ use super::{Client, Status, TypedEnvelope, proto}; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; +use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo}; use cloud_llm_client::{ EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, @@ -20,7 +21,7 @@ use std::{ sync::{Arc, Weak}, }; use text::ReplicaId; -use util::TryFutureExt as _; +use util::{ResultExt, TryFutureExt as _}; pub type UserId = u64; @@ -110,12 +111,11 @@ pub struct UserStore { by_github_login: HashMap, participant_indices: HashMap, update_contacts_tx: mpsc::UnboundedSender, - current_plan: Option, - trial_started_at: Option>, - is_usage_based_billing_enabled: Option, - account_too_young: Option, + model_request_usage: Option, + edit_prediction_usage: Option, + plan_info: Option, current_user: watch::Receiver>>, - accepted_tos_at: Option>>, + accepted_tos_at: Option>, contacts: Vec>, incoming_contact_requests: Vec>, outgoing_contact_requests: Vec>, @@ -185,10 +185,9 @@ impl UserStore { users: Default::default(), by_github_login: Default::default(), current_user: current_user_rx, - current_plan: None, - trial_started_at: None, - is_usage_based_billing_enabled: None, - account_too_young: None, + plan_info: None, + model_request_usage: None, + edit_prediction_usage: None, accepted_tos_at: None, contacts: Default::default(), incoming_contact_requests: Default::default(), @@ -218,53 +217,30 @@ impl UserStore { return Ok(()); }; match status { - Status::Connected { .. } => { + Status::Authenticated | Status::Connected { .. } => { if let Some(user_id) = client.user_id() { - let fetch_user = if let Ok(fetch_user) = - this.update(cx, |this, cx| this.get_user(user_id, cx).log_err()) - { - fetch_user - } else { - break; - }; - let fetch_private_user_info = - client.request(proto::GetPrivateUserInfo {}).log_err(); - let (user, info) = - futures::join!(fetch_user, fetch_private_user_info); - + let response = client.cloud_client().get_authenticated_user().await; + let mut current_user = None; cx.update(|cx| { - if let Some(info) = info { - let staff = - info.staff && !*feature_flags::ZED_DISABLE_STAFF; - cx.update_flags(staff, info.flags); - client.telemetry.set_authenticated_user_info( - Some(info.metrics_id.clone()), - staff, - ); - + if let Some(response) = response.log_err() { + let user = Arc::new(User { + id: user_id, + github_login: response.user.github_login.clone().into(), + avatar_uri: response.user.avatar_url.clone().into(), + name: response.user.name.clone(), + }); + current_user = Some(user.clone()); this.update(cx, |this, cx| { - let accepted_tos_at = { - #[cfg(debug_assertions)] - if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() - { - None - } else { - info.accepted_tos_at - } - - #[cfg(not(debug_assertions))] - info.accepted_tos_at - }; - - this.set_current_user_accepted_tos_at(accepted_tos_at); - cx.emit(Event::PrivateUserInfoUpdated); + this.by_github_login + .insert(user.github_login.clone(), user_id); + this.users.insert(user_id, user); + this.update_authenticated_user(response, cx) }) } else { anyhow::Ok(()) } })??; - - current_user_tx.send(user).await.ok(); + current_user_tx.send(current_user).await.ok(); this.update(cx, |_, cx| cx.notify())?; } @@ -345,22 +321,22 @@ impl UserStore { async fn handle_update_plan( this: Entity, - message: TypedEnvelope, + _message: TypedEnvelope, mut cx: AsyncApp, ) -> Result<()> { - this.update(&mut cx, |this, cx| { - this.current_plan = Some(message.payload.plan()); - this.trial_started_at = message - .payload - .trial_started_at - .and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0)); - this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled; - this.account_too_young = message.payload.account_too_young; + let client = this + .read_with(&cx, |this, _| this.client.upgrade())? + .context("client was dropped")?; - cx.emit(Event::PlanUpdated); - cx.notify(); - })?; - Ok(()) + let response = client + .cloud_client() + .get_authenticated_user() + .await + .context("failed to fetch authenticated user")?; + + this.update(&mut cx, |this, cx| { + this.update_authenticated_user(response, cx); + }) } fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { @@ -719,42 +695,131 @@ impl UserStore { self.current_user.borrow().clone() } - pub fn current_plan(&self) -> Option { + pub fn plan(&self) -> Option { #[cfg(debug_assertions)] if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() { return match plan.as_str() { - "free" => Some(proto::Plan::Free), - "trial" => Some(proto::Plan::ZedProTrial), - "pro" => Some(proto::Plan::ZedPro), + "free" => Some(cloud_llm_client::Plan::ZedFree), + "trial" => Some(cloud_llm_client::Plan::ZedProTrial), + "pro" => Some(cloud_llm_client::Plan::ZedPro), _ => { panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'"); } }; } - self.current_plan + self.plan_info.as_ref().map(|info| info.plan) + } + + pub fn subscription_period(&self) -> Option<(DateTime, DateTime)> { + self.plan_info + .as_ref() + .and_then(|plan| plan.subscription_period) + .map(|subscription_period| { + ( + subscription_period.started_at.0, + subscription_period.ended_at.0, + ) + }) } pub fn trial_started_at(&self) -> Option> { - self.trial_started_at + self.plan_info + .as_ref() + .and_then(|plan| plan.trial_started_at) + .map(|trial_started_at| trial_started_at.0) } - pub fn usage_based_billing_enabled(&self) -> Option { - self.is_usage_based_billing_enabled + /// Returns whether the user's account is too new to use the service. + pub fn account_too_young(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_account_too_young) + .unwrap_or_default() + } + + /// Returns whether the current user has overdue invoices and usage should be blocked. + pub fn has_overdue_invoices(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.has_overdue_invoices) + .unwrap_or_default() + } + + pub fn is_usage_based_billing_enabled(&self) -> bool { + self.plan_info + .as_ref() + .map(|plan| plan.is_usage_based_billing_enabled) + .unwrap_or_default() + } + + pub fn model_request_usage(&self) -> Option { + self.model_request_usage + } + + pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { + self.model_request_usage = Some(usage); + cx.notify(); + } + + pub fn edit_prediction_usage(&self) -> Option { + self.edit_prediction_usage + } + + pub fn update_edit_prediction_usage( + &mut self, + usage: EditPredictionUsage, + cx: &mut Context, + ) { + self.edit_prediction_usage = Some(usage); + cx.notify(); + } + + fn update_authenticated_user( + &mut self, + response: GetAuthenticatedUserResponse, + cx: &mut Context, + ) { + let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF; + cx.update_flags(staff, response.feature_flags); + if let Some(client) = self.client.upgrade() { + client + .telemetry + .set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff); + } + + let accepted_tos_at = { + #[cfg(debug_assertions)] + if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() { + None + } else { + response.user.accepted_tos_at + } + + #[cfg(not(debug_assertions))] + response.user.accepted_tos_at + }; + + self.accepted_tos_at = Some(accepted_tos_at); + self.model_request_usage = Some(ModelRequestUsage(RequestUsage { + limit: response.plan.usage.model_requests.limit, + amount: response.plan.usage.model_requests.used as i32, + })); + self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage { + limit: response.plan.usage.edit_predictions.limit, + amount: response.plan.usage.edit_predictions.used as i32, + })); + self.plan_info = Some(response.plan); + cx.emit(Event::PrivateUserInfoUpdated); } pub fn watch_current_user(&self) -> watch::Receiver>> { self.current_user.clone() } - /// Returns whether the user's account is too new to use the service. - pub fn account_too_young(&self) -> bool { - self.account_too_young.unwrap_or(false) - } - - pub fn current_user_has_accepted_terms(&self) -> Option { + pub fn has_accepted_terms_of_service(&self) -> bool { self.accepted_tos_at - .map(|accepted_tos_at| accepted_tos_at.is_some()) + .map_or(false, |accepted_tos_at| accepted_tos_at.is_some()) } pub fn accept_terms_of_service(&self, cx: &Context) -> Task> { @@ -766,23 +831,18 @@ impl UserStore { cx.spawn(async move |this, cx| -> anyhow::Result<()> { let client = client.upgrade().context("client not found")?; let response = client - .request(proto::AcceptTermsOfService {}) + .cloud_client() + .accept_terms_of_service() .await .context("error accepting tos")?; this.update(cx, |this, cx| { - this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at)); + this.accepted_tos_at = Some(response.user.accepted_tos_at); cx.emit(Event::PrivateUserInfoUpdated); })?; Ok(()) }) } - fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option) { - self.accepted_tos_at = Some( - accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)), - ); - } - fn load_users( &self, request: impl RequestMessage, diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index 7aa41e0e7d..aea359d75b 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -1286,7 +1286,7 @@ async fn test_calls_on_multiple_connections( client_b1.disconnect(&cx_b1.to_async()); executor.advance_clock(RECEIVE_TIMEOUT); client_b1 - .authenticate_and_connect(false, &cx_b1.to_async()) + .connect(false, &cx_b1.to_async()) .await .into_response() .unwrap(); @@ -1667,7 +1667,7 @@ async fn test_project_reconnect( // Client A reconnects. Their project is re-shared, and client B re-joins it. server.allow_connections(); client_a - .authenticate_and_connect(false, &cx_a.to_async()) + .connect(false, &cx_a.to_async()) .await .into_response() .unwrap(); @@ -1796,7 +1796,7 @@ async fn test_project_reconnect( // Client B reconnects. They re-join the room and the remaining shared project. server.allow_connections(); client_b - .authenticate_and_connect(false, &cx_b.to_async()) + .connect(false, &cx_b.to_async()) .await .into_response() .unwrap(); @@ -5738,7 +5738,7 @@ async fn test_contacts( server.allow_connections(); client_c - .authenticate_and_connect(false, &cx_c.to_async()) + .connect(false, &cx_c.to_async()) .await .into_response() .unwrap(); @@ -6269,7 +6269,7 @@ async fn test_contact_requests( client.disconnect(&cx.to_async()); client.clear_contacts(cx).await; client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); diff --git a/crates/collab/src/tests/notification_tests.rs b/crates/collab/src/tests/notification_tests.rs index 4e64b5526b..9bf906694e 100644 --- a/crates/collab/src/tests/notification_tests.rs +++ b/crates/collab/src/tests/notification_tests.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use gpui::{BackgroundExecutor, TestAppContext}; use notifications::NotificationEvent; use parking_lot::Mutex; +use pretty_assertions::assert_eq; use rpc::{Notification, proto}; use crate::tests::TestServer; @@ -17,6 +18,9 @@ async fn test_notifications( let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; + // Wait for authentication/connection to Collab to be established. + executor.run_until_parked(); + let notification_events_a = Arc::new(Mutex::new(Vec::new())); let notification_events_b = Arc::new(Mutex::new(Vec::new())); client_a.notification_store().update(cx_a, |_, cx| { diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 3751d6918e..5fcc622fc1 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -8,7 +8,7 @@ use crate::{ use anyhow::anyhow; use call::ActiveCall; use channel::{ChannelBuffer, ChannelStore}; -use client::CloudUserStore; +use client::test::{make_get_authenticated_user_response, parse_authorization_header}; use client::{ self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore, proto::PeerId, @@ -21,7 +21,7 @@ use fs::FakeFs; use futures::{StreamExt as _, channel::oneshot}; use git::GitHostingProviderRegistry; use gpui::{AppContext as _, BackgroundExecutor, Entity, Task, TestAppContext, VisualTestContext}; -use http_client::FakeHttpClient; +use http_client::{FakeHttpClient, Method}; use language::LanguageRegistry; use node_runtime::NodeRuntime; use notifications::NotificationStore; @@ -162,6 +162,8 @@ impl TestServer { } pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { + const ACCESS_TOKEN: &str = "the-token"; + let fs = FakeFs::new(cx.executor()); cx.update(|cx| { @@ -176,7 +178,7 @@ impl TestServer { }); let clock = Arc::new(FakeSystemClock::new()); - let http = FakeHttpClient::with_404_response(); + let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await { user.id @@ -198,6 +200,47 @@ impl TestServer { .expect("creating user failed") .user_id }; + + let http = FakeHttpClient::create({ + let name = name.to_string(); + move |req| { + let name = name.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/client/users/me") => { + let credentials = parse_authorization_header(&req); + if credentials + != Some(Credentials { + user_id: user_id.to_proto(), + access_token: ACCESS_TOKEN.into(), + }) + { + return Ok(http_client::Response::builder() + .status(401) + .body("Unauthorized".into()) + .unwrap()); + } + + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&make_get_authenticated_user_response( + user_id.0, name, + )) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } + } + } + }); + let client_name = name.to_string(); let mut client = cx.update(|cx| Client::new(clock, http.clone(), cx)); let server = self.server.clone(); @@ -209,11 +252,10 @@ impl TestServer { .unwrap() .set_id(user_id.to_proto()) .override_authenticate(move |cx| { - let access_token = "the-token".to_string(); cx.spawn(async move |_| { Ok(Credentials { user_id: user_id.to_proto(), - access_token, + access_token: ACCESS_TOKEN.into(), }) }) }) @@ -222,7 +264,7 @@ impl TestServer { credentials, &Credentials { user_id: user_id.0 as u64, - access_token: "the-token".into() + access_token: ACCESS_TOKEN.into(), } ); @@ -282,15 +324,12 @@ impl TestServer { .register_hosting_provider(Arc::new(git_hosting_providers::Github::public_instance())); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let cloud_user_store = - cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); let language_registry = Arc::new(LanguageRegistry::test(cx.executor())); let session = cx.new(|cx| AppSession::new(Session::test(), cx)); let app_state = Arc::new(workspace::AppState { client: client.clone(), user_store: user_store.clone(), - cloud_user_store, workspace_store, languages: language_registry, fs: fs.clone(), @@ -323,7 +362,7 @@ impl TestServer { }); client - .authenticate_and_connect(false, &cx.to_async()) + .connect(false, &cx.to_async()) .await .into_response() .unwrap(); diff --git a/crates/collab_ui/src/collab_panel.rs b/crates/collab_ui/src/collab_panel.rs index f53b94c209..54077303a1 100644 --- a/crates/collab_ui/src/collab_panel.rs +++ b/crates/collab_ui/src/collab_panel.rs @@ -2331,7 +2331,7 @@ impl CollabPanel { let client = this.client.clone(); cx.spawn_in(window, async move |_, cx| { client - .authenticate_and_connect(true, &cx) + .connect(true, &cx) .await .into_response() .notify_async_err(cx); diff --git a/crates/collab_ui/src/notification_panel.rs b/crates/collab_ui/src/notification_panel.rs index fba8f66c2d..c3e834b645 100644 --- a/crates/collab_ui/src/notification_panel.rs +++ b/crates/collab_ui/src/notification_panel.rs @@ -634,13 +634,13 @@ impl Render for NotificationPanel { .child(Icon::new(IconName::Envelope)), ) .map(|this| { - if self.client.user_id().is_none() { + if !self.client.status().borrow().is_connected() { this.child( v_flex() .gap_2() .p_4() .child( - Button::new("sign_in_prompt_button", "Sign in") + Button::new("connect_prompt_button", "Connect") .icon_color(Color::Muted) .icon(IconName::Github) .icon_position(IconPosition::Start) @@ -652,10 +652,7 @@ impl Render for NotificationPanel { let client = client.clone(); window .spawn(cx, async move |cx| { - match client - .authenticate_and_connect(true, &cx) - .await - { + match client.connect(true, &cx).await { util::ConnectionResult::Timeout => { log::error!("Connection timeout"); } @@ -673,7 +670,7 @@ impl Render for NotificationPanel { ) .child( div().flex().w_full().items_center().child( - Label::new("Sign in to view notifications.") + Label::new("Connect to view notifications.") .color(Color::Muted) .size(LabelSize::Small), ), diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 8d257a37a7..a02b4a7f0b 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -13,7 +13,7 @@ pub(crate) use tool_metrics::*; use ::fs::RealFs; use clap::Parser; -use client::{Client, CloudUserStore, ProxySettings, UserStore}; +use client::{Client, ProxySettings, UserStore}; use collections::{HashMap, HashSet}; use extension::ExtensionHostProxy; use futures::future; @@ -329,7 +329,6 @@ pub struct AgentAppState { pub languages: Arc, pub client: Arc, pub user_store: Entity, - pub cloud_user_store: Entity, pub fs: Arc, pub node_runtime: NodeRuntime, @@ -384,8 +383,6 @@ pub fn init(cx: &mut App) -> Arc { let languages = Arc::new(languages); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let cloud_user_store = - cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); extension::init(cx); @@ -425,12 +422,7 @@ pub fn init(cx: &mut App) -> Arc { languages.clone(), ); language_model::init(client.clone(), cx); - language_models::init( - user_store.clone(), - cloud_user_store.clone(), - client.clone(), - cx, - ); + language_models::init(user_store.clone(), client.clone(), cx); languages::init(languages.clone(), node_runtime.clone(), cx); prompt_store::init(cx); terminal_view::init(cx); @@ -455,7 +447,6 @@ pub fn init(cx: &mut App) -> Arc { languages, client, user_store, - cloud_user_store, fs, node_runtime, prompt_builder, diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 54d864ea21..0f2b4c18ea 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -221,7 +221,6 @@ impl ExampleInstance { let prompt_store = None; let thread_store = ThreadStore::load( project.clone(), - app_state.cloud_user_store.clone(), tools, prompt_store, app_state.prompt_builder.clone(), diff --git a/crates/http_client/Cargo.toml b/crates/http_client/Cargo.toml index 2045708ff2..3f51cc5a23 100644 --- a/crates/http_client/Cargo.toml +++ b/crates/http_client/Cargo.toml @@ -23,6 +23,7 @@ futures.workspace = true http.workspace = true http-body.workspace = true log.workspace = true +parking_lot.workspace = true serde.workspace = true serde_json.workspace = true url.workspace = true diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index 06875718d9..d33bbefc06 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -9,12 +9,10 @@ pub use http::{self, Method, Request, Response, StatusCode, Uri}; use futures::future::BoxFuture; use http::request::Builder; +use parking_lot::Mutex; #[cfg(feature = "test-support")] use std::fmt; -use std::{ - any::type_name, - sync::{Arc, Mutex}, -}; +use std::{any::type_name, sync::Arc}; pub use url::Url; #[derive(Default, Debug, Clone, PartialEq, Eq, Hash)] @@ -86,6 +84,11 @@ pub trait HttpClient: 'static + Send + Sync { } fn proxy(&self) -> Option<&Url>; + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + panic!("called as_fake on {}", type_name::()) + } } /// An [`HttpClient`] that may have a proxy. @@ -132,6 +135,11 @@ impl HttpClient for HttpClientWithProxy { fn type_name(&self) -> &'static str { self.client.type_name() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() + } } impl HttpClient for Arc { @@ -153,6 +161,11 @@ impl HttpClient for Arc { fn type_name(&self) -> &'static str { self.client.type_name() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() + } } /// An [`HttpClient`] that has a base URL. @@ -199,20 +212,13 @@ impl HttpClientWithUrl { /// Returns the base URL. pub fn base_url(&self) -> String { - self.base_url - .lock() - .map_or_else(|_| Default::default(), |url| url.clone()) + self.base_url.lock().clone() } /// Sets the base URL. pub fn set_base_url(&self, base_url: impl Into) { let base_url = base_url.into(); - self.base_url - .lock() - .map(|mut url| { - *url = base_url; - }) - .ok(); + *self.base_url.lock() = base_url; } /// Builds a URL using the given path. @@ -288,6 +294,11 @@ impl HttpClient for Arc { fn type_name(&self) -> &'static str { self.client.type_name() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() + } } impl HttpClient for HttpClientWithUrl { @@ -309,6 +320,11 @@ impl HttpClient for HttpClientWithUrl { fn type_name(&self) -> &'static str { self.client.type_name() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + self.client.as_fake() + } } pub fn read_proxy_from_env() -> Option { @@ -360,10 +376,15 @@ impl HttpClient for BlockedHttpClient { fn type_name(&self) -> &'static str { type_name::() } + + #[cfg(feature = "test-support")] + fn as_fake(&self) -> &FakeHttpClient { + panic!("called as_fake on {}", type_name::()) + } } #[cfg(feature = "test-support")] -type FakeHttpHandler = Box< +type FakeHttpHandler = Arc< dyn Fn(Request) -> BoxFuture<'static, anyhow::Result>> + Send + Sync @@ -372,7 +393,7 @@ type FakeHttpHandler = Box< #[cfg(feature = "test-support")] pub struct FakeHttpClient { - handler: FakeHttpHandler, + handler: Mutex>, user_agent: HeaderValue, } @@ -387,7 +408,7 @@ impl FakeHttpClient { base_url: Mutex::new("http://test.example".into()), client: HttpClientWithProxy { client: Arc::new(Self { - handler: Box::new(move |req| Box::pin(handler(req))), + handler: Mutex::new(Some(Arc::new(move |req| Box::pin(handler(req))))), user_agent: HeaderValue::from_static(type_name::()), }), proxy: None, @@ -412,6 +433,18 @@ impl FakeHttpClient { .unwrap()) }) } + + pub fn replace_handler(&self, new_handler: F) + where + Fut: futures::Future>> + Send + 'static, + F: Fn(FakeHttpHandler, Request) -> Fut + Send + Sync + 'static, + { + let mut handler = self.handler.lock(); + let old_handler = handler.take().unwrap(); + *handler = Some(Arc::new(move |req| { + Box::pin(new_handler(old_handler.clone(), req)) + })); + } } #[cfg(feature = "test-support")] @@ -427,7 +460,7 @@ impl HttpClient for FakeHttpClient { &self, req: Request, ) -> BoxFuture<'static, anyhow::Result>> { - let future = (self.handler)(req); + let future = (self.handler.lock().as_ref().unwrap())(req); future } @@ -442,4 +475,8 @@ impl HttpClient for FakeHttpClient { fn type_name(&self) -> &'static str { type_name::() } + + fn as_fake(&self) -> &FakeHttpClient { + self + } } diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/inline_completion_button/src/inline_completion_button.rs index d402b87382..2d7f211942 100644 --- a/crates/inline_completion_button/src/inline_completion_button.rs +++ b/crates/inline_completion_button/src/inline_completion_button.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use client::{CloudUserStore, DisableAiSettings, zed_urls}; +use client::{DisableAiSettings, UserStore, zed_urls}; use cloud_llm_client::UsageLimit; use copilot::{Copilot, Status}; use editor::{ @@ -59,7 +59,7 @@ pub struct InlineCompletionButton { file: Option>, edit_prediction_provider: Option>, fs: Arc, - cloud_user_store: Entity, + user_store: Entity, popover_menu_handle: PopoverMenuHandle, } @@ -245,9 +245,9 @@ impl Render for InlineCompletionButton { IconName::ZedPredictDisabled }; - if zeta::should_show_upsell_modal(&self.cloud_user_store, cx) { - let tooltip_meta = if self.cloud_user_store.read(cx).is_authenticated() { - if self.cloud_user_store.read(cx).has_accepted_tos() { + if zeta::should_show_upsell_modal(&self.user_store, cx) { + let tooltip_meta = if self.user_store.read(cx).current_user().is_some() { + if self.user_store.read(cx).has_accepted_terms_of_service() { "Choose a Plan" } else { "Accept the Terms of Service" @@ -371,7 +371,7 @@ impl Render for InlineCompletionButton { impl InlineCompletionButton { pub fn new( fs: Arc, - cloud_user_store: Entity, + user_store: Entity, popover_menu_handle: PopoverMenuHandle, cx: &mut Context, ) -> Self { @@ -390,9 +390,9 @@ impl InlineCompletionButton { language: None, file: None, edit_prediction_provider: None, + user_store, popover_menu_handle, fs, - cloud_user_store, } } @@ -763,7 +763,7 @@ impl InlineCompletionButton { }) }) .separator(); - } else if self.cloud_user_store.read(cx).account_too_young() { + } else if self.user_store.read(cx).account_too_young() { menu = menu .custom_entry( |_window, _cx| { @@ -778,7 +778,7 @@ impl InlineCompletionButton { cx.open_url(&zed_urls::account_url(cx)) }) .separator(); - } else if self.cloud_user_store.read(cx).has_overdue_invoices() { + } else if self.user_store.read(cx).has_overdue_invoices() { menu = menu .custom_entry( |_window, _cx| { diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index a5d2ac34f5..8ae5893410 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -3,10 +3,11 @@ use std::sync::Arc; use anyhow::Result; use client::Client; +use cloud_llm_client::Plan; use gpui::{ App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, }; -use proto::{Plan, TypedEnvelope}; +use proto::TypedEnvelope; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -30,7 +31,7 @@ pub struct ModelRequestLimitReachedError { impl fmt::Display for ModelRequestLimitReachedError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let message = match self.plan { - Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.", + Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.", Plan::ZedPro => { "Model request limit reached. Upgrade to usage-based billing for more requests." } diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 208b0d99c9..b5bfb870f6 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -44,7 +44,6 @@ ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } open_router = { workspace = true, features = ["schemars"] } partial-json-fixer.workspace = true -proto.workspace = true release_channel.workspace = true schemars.workspace = true serde.workspace = true diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index a88f12283a..18e6f47ed0 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use ::settings::{Settings, SettingsStore}; -use client::{Client, CloudUserStore, UserStore}; +use client::{Client, UserStore}; use collections::HashSet; use gpui::{App, Context, Entity}; use language_model::{LanguageModelProviderId, LanguageModelRegistry}; @@ -26,22 +26,11 @@ use crate::provider::vercel::VercelLanguageModelProvider; use crate::provider::x_ai::XAiLanguageModelProvider; pub use crate::settings::*; -pub fn init( - user_store: Entity, - cloud_user_store: Entity, - client: Arc, - cx: &mut App, -) { +pub fn init(user_store: Entity, client: Arc, cx: &mut App) { crate::settings::init_settings(cx); let registry = LanguageModelRegistry::global(cx); registry.update(cx, |registry, cx| { - register_language_model_providers( - registry, - user_store, - cloud_user_store, - client.clone(), - cx, - ); + register_language_model_providers(registry, user_store, client.clone(), cx); }); let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx) @@ -111,17 +100,11 @@ fn register_openai_compatible_providers( fn register_language_model_providers( registry: &mut LanguageModelRegistry, user_store: Entity, - cloud_user_store: Entity, client: Arc, cx: &mut Context, ) { registry.register_provider( - CloudLanguageModelProvider::new( - user_store.clone(), - cloud_user_store.clone(), - client.clone(), - cx, - ), + CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx), cx, ); diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index a5de7f3442..2108547c4f 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -2,7 +2,7 @@ use ai_onboarding::YoungAccountBanner; use anthropic::AnthropicModelMode; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; -use client::{Client, CloudUserStore, ModelRequestUsage, UserStore, zed_urls}; +use client::{Client, ModelRequestUsage, UserStore, zed_urls}; use cloud_llm_client::{ CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse, @@ -117,7 +117,6 @@ pub struct State { client: Arc, llm_api_token: LlmApiToken, user_store: Entity, - cloud_user_store: Entity, status: client::Status, accept_terms_of_service_task: Option>>, models: Vec>, @@ -133,17 +132,14 @@ impl State { fn new( client: Arc, user_store: Entity, - cloud_user_store: Entity, status: client::Status, cx: &mut Context, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - Self { client: client.clone(), llm_api_token: LlmApiToken::default(), - user_store, - cloud_user_store, + user_store: user_store.clone(), status, accept_terms_of_service_task: None, models: Vec::new(), @@ -152,18 +148,12 @@ impl State { recommended_models: Vec::new(), _fetch_models_task: cx.spawn(async move |this, cx| { maybe!(async move { - let (client, cloud_user_store, llm_api_token) = - this.read_with(cx, |this, _cx| { - ( - client.clone(), - this.cloud_user_store.clone(), - this.llm_api_token.clone(), - ) - })?; + let (client, llm_api_token) = this + .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?; loop { - let is_authenticated = - cloud_user_store.read_with(cx, |this, _cx| this.is_authenticated())?; + let is_authenticated = user_store + .read_with(cx, |user_store, _cx| user_store.current_user().is_some())?; if is_authenticated { break; } @@ -204,22 +194,19 @@ impl State { } fn is_signed_out(&self, cx: &App) -> bool { - !self.cloud_user_store.read(cx).is_authenticated() + self.user_store.read(cx).current_user().is_none() } fn authenticate(&self, cx: &mut Context) -> Task> { let client = self.client.clone(); cx.spawn(async move |state, cx| { - client - .authenticate_and_connect(true, &cx) - .await - .into_response()?; + client.sign_in_with_optional_connect(true, &cx).await?; state.update(cx, |_, cx| cx.notify()) }) } fn has_accepted_terms_of_service(&self, cx: &App) -> bool { - self.cloud_user_store.read(cx).has_accepted_tos() + self.user_store.read(cx).has_accepted_terms_of_service() } fn accept_terms_of_service(&mut self, cx: &mut Context) { @@ -303,24 +290,11 @@ impl State { } impl CloudLanguageModelProvider { - pub fn new( - user_store: Entity, - cloud_user_store: Entity, - client: Arc, - cx: &mut App, - ) -> Self { + pub fn new(user_store: Entity, client: Arc, cx: &mut App) -> Self { let mut status_rx = client.status(); let status = *status_rx.borrow(); - let state = cx.new(|cx| { - State::new( - client.clone(), - user_store.clone(), - cloud_user_store.clone(), - status, - cx, - ) - }); + let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx)); let state_ref = state.downgrade(); let maintain_client_status = cx.spawn(async move |cx| { @@ -632,11 +606,6 @@ impl CloudLanguageModel { .and_then(|plan| plan.to_str().ok()) .and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok()) { - let plan = match plan { - cloud_llm_client::Plan::ZedFree => proto::Plan::Free, - cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro, - cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, - }; return Err(anyhow!(ModelRequestLimitReachedError { plan })); } } @@ -1281,15 +1250,15 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { let state = self.state.read(cx); - let cloud_user_store = state.cloud_user_store.read(cx); + let user_store = state.user_store.read(cx); ZedAiConfiguration { is_connected: !state.is_signed_out(cx), - plan: cloud_user_store.plan(), - subscription_period: cloud_user_store.subscription_period(), - eligible_for_trial: cloud_user_store.trial_started_at().is_none(), + plan: user_store.plan(), + subscription_period: user_store.subscription_period(), + eligible_for_trial: user_store.trial_started_at().is_none(), has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx), - account_too_young: cloud_user_store.account_too_young(), + account_too_young: user_store.account_too_young(), accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(), accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(), sign_in_callback: self.sign_in_callback.clone(), diff --git a/crates/onboarding/src/ai_setup_page.rs b/crates/onboarding/src/ai_setup_page.rs index a5b4b1d7be..c33dcb9ad1 100644 --- a/crates/onboarding/src/ai_setup_page.rs +++ b/crates/onboarding/src/ai_setup_page.rs @@ -278,7 +278,7 @@ pub(crate) fn render_ai_setup_page( .child(AiUpsellCard { sign_in_status: SignInStatus::SignedIn, sign_in: Arc::new(|_, _| {}), - user_plan: onboarding.cloud_user_store.read(cx).plan(), + user_plan: onboarding.user_store.read(cx).plan(), }) .child(render_llm_provider_section( onboarding, diff --git a/crates/onboarding/src/onboarding.rs b/crates/onboarding/src/onboarding.rs index bf60da4aab..2ae07b7cd5 100644 --- a/crates/onboarding/src/onboarding.rs +++ b/crates/onboarding/src/onboarding.rs @@ -1,5 +1,5 @@ use crate::welcome::{ShowWelcome, WelcomePage}; -use client::{Client, CloudUserStore, UserStore}; +use client::{Client, UserStore}; use command_palette_hooks::CommandPaletteFilter; use db::kvp::KEY_VALUE_STORE; use feature_flags::{FeatureFlag, FeatureFlagViewExt as _}; @@ -220,7 +220,6 @@ struct Onboarding { workspace: WeakEntity, focus_handle: FocusHandle, selected_page: SelectedPage, - cloud_user_store: Entity, user_store: Entity, _settings_subscription: Subscription, } @@ -231,7 +230,6 @@ impl Onboarding { workspace: workspace.weak_handle(), focus_handle: cx.focus_handle(), selected_page: SelectedPage::Basics, - cloud_user_store: workspace.app_state().cloud_user_store.clone(), user_store: workspace.user_store().clone(), _settings_subscription: cx.observe_global::(move |_, cx| cx.notify()), }) @@ -365,9 +363,8 @@ impl Onboarding { window .spawn(cx, async move |cx| { client - .authenticate_and_connect(true, &cx) + .sign_in_with_optional_connect(true, &cx) .await - .into_response() .notify_async_err(cx); }) .detach(); diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 13587b43e7..623f48d3c9 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -1362,10 +1362,7 @@ impl Project { fs: Arc, cx: AsyncApp, ) -> Result> { - client - .authenticate_and_connect(true, &cx) - .await - .into_response()?; + client.connect(true, &cx).await.into_response()?; let subscriptions = [ EntitySubscription::Project(client.subscribe_to_entity::(remote_id)?), diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index 552ef915cb..426d87ad13 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -20,7 +20,7 @@ use crate::application_menu::{ use auto_update::AutoUpdateStatus; use call::ActiveCall; -use client::{Client, CloudUserStore, UserStore, zed_urls}; +use client::{Client, UserStore, zed_urls}; use cloud_llm_client::Plan; use gpui::{ Action, AnyElement, App, Context, Corner, Element, Entity, Focusable, InteractiveElement, @@ -126,7 +126,6 @@ pub struct TitleBar { platform_titlebar: Entity, project: Entity, user_store: Entity, - cloud_user_store: Entity, client: Arc, workspace: WeakEntity, application_menu: Option>, @@ -180,11 +179,9 @@ impl Render for TitleBar { children.push(self.banner.clone().into_any_element()) } - let is_authenticated = self.cloud_user_store.read(cx).is_authenticated(); let status = self.client.status(); let status = &*status.borrow(); - - let show_sign_in = !is_authenticated || !matches!(status, client::Status::Connected { .. }); + let user = self.user_store.read(cx).current_user(); children.push( h_flex() @@ -194,10 +191,10 @@ impl Render for TitleBar { .children(self.render_call_controls(window, cx)) .children(self.render_connection_status(status, cx)) .when( - show_sign_in && TitleBarSettings::get_global(cx).show_sign_in, + user.is_none() && TitleBarSettings::get_global(cx).show_sign_in, |el| el.child(self.render_sign_in_button(cx)), ) - .when(is_authenticated, |parent| { + .when(user.is_some(), |parent| { parent.child(self.render_user_menu_button(cx)) }) .into_any_element(), @@ -248,7 +245,6 @@ impl TitleBar { ) -> Self { let project = workspace.project().clone(); let user_store = workspace.app_state().user_store.clone(); - let cloud_user_store = workspace.app_state().cloud_user_store.clone(); let client = workspace.app_state().client.clone(); let active_call = ActiveCall::global(cx); @@ -296,7 +292,6 @@ impl TitleBar { workspace: workspace.weak_handle(), project, user_store, - cloud_user_store, client, _subscriptions: subscriptions, banner, @@ -622,9 +617,8 @@ impl TitleBar { window .spawn(cx, async move |cx| { client - .authenticate_and_connect(true, &cx) + .sign_in_with_optional_connect(true, &cx) .await - .into_response() .notify_async_err(cx); }) .detach(); @@ -632,15 +626,15 @@ impl TitleBar { } pub fn render_user_menu_button(&mut self, cx: &mut Context) -> impl Element { - let cloud_user_store = self.cloud_user_store.read(cx); - if let Some(user) = cloud_user_store.authenticated_user() { - let has_subscription_period = cloud_user_store.subscription_period().is_some(); - let plan = cloud_user_store.plan().filter(|_| { + let user_store = self.user_store.read(cx); + if let Some(user) = user_store.current_user() { + let has_subscription_period = user_store.subscription_period().is_some(); + let plan = user_store.plan().filter(|_| { // Since the user might be on the legacy free plan we filter based on whether we have a subscription period. has_subscription_period }); - let user_avatar = user.avatar_url.clone(); + let user_avatar = user.avatar_uri.clone(); let free_chip_bg = cx .theme() .colors() diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index aad585e419..6f7db668dd 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -15,7 +15,6 @@ mod toast_layer; mod toolbar; mod workspace_settings; -use client::CloudUserStore; pub use toast_layer::{ToastAction, ToastLayer, ToastView}; use anyhow::{Context as _, Result, anyhow}; @@ -840,7 +839,6 @@ pub struct AppState { pub languages: Arc, pub client: Arc, pub user_store: Entity, - pub cloud_user_store: Entity, pub workspace_store: Entity, pub fs: Arc, pub build_window_options: fn(Option, &mut App) -> WindowOptions, @@ -913,8 +911,6 @@ impl AppState { let client = Client::new(clock, http_client.clone(), cx); let session = cx.new(|cx| AppSession::new(Session::test(), cx)); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let cloud_user_store = - cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); theme::init(theme::LoadThemes::JustBase, cx); @@ -926,7 +922,6 @@ impl AppState { fs, languages, user_store, - cloud_user_store, workspace_store, node_runtime: NodeRuntime::unavailable(), build_window_options: |_, _| Default::default(), @@ -5739,16 +5734,12 @@ impl Workspace { let client = project.read(cx).client(); let user_store = project.read(cx).user_store(); - let cloud_user_store = - cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); - let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); let session = cx.new(|cx| AppSession::new(Session::test(), cx)); window.activate_window(); let app_state = Arc::new(AppState { languages: project.read(cx).languages().clone(), workspace_store, - cloud_user_store, client, user_store, fs: project.read(cx).fs().clone(), @@ -6947,10 +6938,13 @@ async fn join_channel_internal( match status { Status::Connecting | Status::Authenticating + | Status::Authenticated | Status::Reconnecting | Status::Reauthenticating => continue, Status::Connected { .. } => break 'outer, - Status::SignedOut => return Err(ErrorCode::SignedOut.into()), + Status::SignedOut | Status::AuthenticationError => { + return Err(ErrorCode::SignedOut.into()); + } Status::UpgradeRequired => return Err(ErrorCode::UpgradeRequired.into()), Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => { return Err(ErrorCode::Disconnected.into()); diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 9859702bf8..c264135e5c 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -5,7 +5,7 @@ use agent_ui::AgentPanel; use anyhow::{Context as _, Result}; use clap::{Parser, command}; use cli::FORCE_CLI_MODE_ENV_VAR_NAME; -use client::{Client, CloudUserStore, ProxySettings, UserStore, parse_zed_link}; +use client::{Client, ProxySettings, UserStore, parse_zed_link}; use collab_ui::channel_view::ChannelView; use collections::HashMap; use db::kvp::{GLOBAL_KEY_VALUE_STORE, KEY_VALUE_STORE}; @@ -42,7 +42,7 @@ use theme::{ ActiveTheme, IconThemeNotFoundError, SystemAppearance, ThemeNotFoundError, ThemeRegistry, ThemeSettings, }; -use util::{ConnectionResult, ResultExt, TryFutureExt, maybe}; +use util::{ResultExt, TryFutureExt, maybe}; use uuid::Uuid; use welcome::{FIRST_OPEN, show_welcome_view}; use workspace::{ @@ -457,8 +457,6 @@ pub fn main() { language::init(cx); languages::init(languages.clone(), node_runtime.clone(), cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let cloud_user_store = - cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); language_extension::init( @@ -518,7 +516,6 @@ pub fn main() { languages: languages.clone(), client: client.clone(), user_store: user_store.clone(), - cloud_user_store, fs: fs.clone(), build_window_options, workspace_store, @@ -556,12 +553,7 @@ pub fn main() { ); supermaven::init(app_state.client.clone(), cx); language_model::init(app_state.client.clone(), cx); - language_models::init( - app_state.user_store.clone(), - app_state.cloud_user_store.clone(), - app_state.client.clone(), - cx, - ); + language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); agent_settings::init(cx); agent_servers::init(cx); web_search::init(cx); @@ -569,7 +561,7 @@ pub fn main() { snippet_provider::init(cx); inline_completion_registry::init( app_state.client.clone(), - app_state.cloud_user_store.clone(), + app_state.user_store.clone(), cx, ); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx); @@ -690,17 +682,9 @@ pub fn main() { cx.spawn({ let client = app_state.client.clone(); - async move |cx| match authenticate(client, &cx).await { - ConnectionResult::Timeout => log::error!("Timeout during initial auth"), - ConnectionResult::ConnectionReset => { - log::error!("Connection reset during initial auth") - } - ConnectionResult::Result(r) => { - r.log_err(); - } - } + async move |cx| authenticate(client, &cx).await }) - .detach(); + .detach_and_log_err(cx); let urls: Vec<_> = args .paths_or_urls @@ -850,15 +834,7 @@ fn handle_open_request(request: OpenRequest, app_state: Arc, cx: &mut let client = app_state.client.clone(); // we continue even if authentication fails as join_channel/ open channel notes will // show a visible error message. - match authenticate(client, &cx).await { - ConnectionResult::Timeout => { - log::error!("Timeout during open request handling") - } - ConnectionResult::ConnectionReset => { - log::error!("Connection reset during open request handling") - } - ConnectionResult::Result(r) => r?, - }; + authenticate(client, &cx).await.log_err(); if let Some(channel_id) = request.join_channel { cx.update(|cx| { @@ -908,18 +884,18 @@ fn handle_open_request(request: OpenRequest, app_state: Arc, cx: &mut } } -async fn authenticate(client: Arc, cx: &AsyncApp) -> ConnectionResult<()> { +async fn authenticate(client: Arc, cx: &AsyncApp) -> Result<()> { if stdout_is_a_pty() { if client::IMPERSONATE_LOGIN.is_some() { - return client.authenticate_and_connect(false, cx).await; + client.sign_in_with_optional_connect(false, cx).await?; } else if client.has_credentials(cx).await { - return client.authenticate_and_connect(true, cx).await; + client.sign_in_with_optional_connect(true, cx).await?; } } else if client.has_credentials(cx).await { - return client.authenticate_and_connect(true, cx).await; + client.sign_in_with_optional_connect(true, cx).await?; } - ConnectionResult::Result(Ok(())) + Ok(()) } async fn system_id() -> Result { diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 060efdf26a..8c6da335ab 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -336,7 +336,7 @@ pub fn initialize_workspace( let edit_prediction_button = cx.new(|cx| { inline_completion_button::InlineCompletionButton::new( app_state.fs.clone(), - app_state.cloud_user_store.clone(), + app_state.user_store.clone(), inline_completion_menu_handle.clone(), cx, ) @@ -4488,12 +4488,7 @@ mod tests { ); image_viewer::init(cx); language_model::init(app_state.client.clone(), cx); - language_models::init( - app_state.user_store.clone(), - app_state.cloud_user_store.clone(), - app_state.client.clone(), - cx, - ); + language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); web_search::init(cx); web_search_providers::init(app_state.client.clone(), cx); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx); diff --git a/crates/zed/src/zed/component_preview.rs b/crates/zed/src/zed/component_preview.rs index 2e57152c62..480505338b 100644 --- a/crates/zed/src/zed/component_preview.rs +++ b/crates/zed/src/zed/component_preview.rs @@ -139,8 +139,7 @@ impl ComponentPreview { let project_clone = project.clone(); cx.spawn_in(window, async move |entity, cx| { - let thread_store_future = - load_preview_thread_store(workspace_clone.clone(), project_clone.clone(), cx); + let thread_store_future = load_preview_thread_store(project_clone.clone(), cx); let text_thread_store_future = load_preview_text_thread_store(workspace_clone.clone(), project_clone.clone(), cx); diff --git a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs index 1076ee49ea..de98106fae 100644 --- a/crates/zed/src/zed/component_preview/preview_support/active_thread.rs +++ b/crates/zed/src/zed/component_preview/preview_support/active_thread.rs @@ -12,22 +12,19 @@ use ui::{App, Window}; use workspace::Workspace; pub fn load_preview_thread_store( - workspace: WeakEntity, project: Entity, cx: &mut AsyncApp, ) -> Task>> { - workspace - .update(cx, |workspace, cx| { - ThreadStore::load( - project.clone(), - workspace.app_state().cloud_user_store.clone(), - cx.new(|_| ToolWorkingSet::default()), - None, - Arc::new(PromptBuilder::new(None).unwrap()), - cx, - ) - }) - .unwrap_or(Task::ready(Err(anyhow!("workspace dropped")))) + cx.update(|cx| { + ThreadStore::load( + project.clone(), + cx.new(|_| ToolWorkingSet::default()), + None, + Arc::new(PromptBuilder::new(None).unwrap()), + cx, + ) + }) + .unwrap_or(Task::ready(Err(anyhow!("workspace dropped")))) } pub fn load_preview_text_thread_store( diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs index ba19457d39..55dbea4fe1 100644 --- a/crates/zed/src/zed/inline_completion_registry.rs +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -1,4 +1,4 @@ -use client::{Client, CloudUserStore, DisableAiSettings}; +use client::{Client, DisableAiSettings, UserStore}; use collections::HashMap; use copilot::{Copilot, CopilotCompletionProvider}; use editor::Editor; @@ -13,12 +13,12 @@ use util::ResultExt; use workspace::Workspace; use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider}; -pub fn init(client: Arc, cloud_user_store: Entity, cx: &mut App) { +pub fn init(client: Arc, user_store: Entity, cx: &mut App) { let editors: Rc, AnyWindowHandle>>> = Rc::default(); cx.observe_new({ let editors = editors.clone(); let client = client.clone(); - let cloud_user_store = cloud_user_store.clone(); + let user_store = user_store.clone(); move |editor: &mut Editor, window, cx: &mut Context| { if !editor.mode().is_full() { return; @@ -48,7 +48,7 @@ pub fn init(client: Arc, cloud_user_store: Entity, cx: & editor, provider, &client, - cloud_user_store.clone(), + user_store.clone(), window, cx, ); @@ -60,7 +60,7 @@ pub fn init(client: Arc, cloud_user_store: Entity, cx: & let mut provider = all_language_settings(None, cx).edit_predictions.provider; cx.spawn({ - let cloud_user_store = cloud_user_store.clone(); + let user_store = user_store.clone(); let editors = editors.clone(); let client = client.clone(); @@ -72,7 +72,7 @@ pub fn init(client: Arc, cloud_user_store: Entity, cx: & &editors, provider, &client, - cloud_user_store.clone(), + user_store.clone(), cx, ); }) @@ -85,12 +85,12 @@ pub fn init(client: Arc, cloud_user_store: Entity, cx: & cx.observe_global::({ let editors = editors.clone(); let client = client.clone(); - let cloud_user_store = cloud_user_store.clone(); + let user_store = user_store.clone(); move |cx| { let new_provider = all_language_settings(None, cx).edit_predictions.provider; if new_provider != provider { - let tos_accepted = cloud_user_store.read(cx).has_accepted_tos(); + let tos_accepted = user_store.read(cx).has_accepted_terms_of_service(); telemetry::event!( "Edit Prediction Provider Changed", @@ -104,7 +104,7 @@ pub fn init(client: Arc, cloud_user_store: Entity, cx: & &editors, provider, &client, - cloud_user_store.clone(), + user_store.clone(), cx, ); @@ -145,7 +145,7 @@ fn assign_edit_prediction_providers( editors: &Rc, AnyWindowHandle>>>, provider: EditPredictionProvider, client: &Arc, - cloud_user_store: Entity, + user_store: Entity, cx: &mut App, ) { for (editor, window) in editors.borrow().iter() { @@ -155,7 +155,7 @@ fn assign_edit_prediction_providers( editor, provider, &client, - cloud_user_store.clone(), + user_store.clone(), window, cx, ); @@ -210,7 +210,7 @@ fn assign_edit_prediction_provider( editor: &mut Editor, provider: EditPredictionProvider, client: &Arc, - cloud_user_store: Entity, + user_store: Entity, window: &mut Window, cx: &mut Context, ) { @@ -241,7 +241,7 @@ fn assign_edit_prediction_provider( } } EditPredictionProvider::Zed => { - if cloud_user_store.read(cx).is_authenticated() { + if user_store.read(cx).current_user().is_some() { let mut worktree = None; if let Some(buffer) = &singleton_buffer { @@ -263,7 +263,7 @@ fn assign_edit_prediction_provider( .map(|workspace| workspace.downgrade()); let zeta = - zeta::Zeta::register(workspace, worktree, client.clone(), cloud_user_store, cx); + zeta::Zeta::register(workspace, worktree, client.clone(), user_store, cx); if let Some(buffer) = &singleton_buffer { if buffer.read(cx).file().is_some() { diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 0ef6bef59d..18b9217b95 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -16,7 +16,7 @@ pub use rate_completion_modal::*; use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; -use client::{Client, CloudUserStore, EditPredictionUsage}; +use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME, @@ -120,8 +120,8 @@ impl Dismissable for ZedPredictUpsell { } } -pub fn should_show_upsell_modal(cloud_user_store: &Entity, cx: &App) -> bool { - if cloud_user_store.read(cx).has_accepted_tos() { +pub fn should_show_upsell_modal(user_store: &Entity, cx: &App) -> bool { + if user_store.read(cx).has_accepted_terms_of_service() { !ZedPredictUpsell::dismissed() } else { true @@ -229,7 +229,7 @@ pub struct Zeta { _llm_token_subscription: Subscription, /// Whether an update to a newer version of Zed is required to continue using Zeta. update_required: bool, - cloud_user_store: Entity, + user_store: Entity, license_detection_watchers: HashMap>, } @@ -242,11 +242,11 @@ impl Zeta { workspace: Option>, worktree: Option>, client: Arc, - cloud_user_store: Entity, + user_store: Entity, cx: &mut App, ) -> Entity { let this = Self::global(cx).unwrap_or_else(|| { - let entity = cx.new(|cx| Self::new(workspace, client, cloud_user_store, cx)); + let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx)); cx.set_global(ZetaGlobal(entity.clone())); entity }); @@ -269,13 +269,13 @@ impl Zeta { } pub fn usage(&self, cx: &App) -> Option { - self.cloud_user_store.read(cx).edit_prediction_usage() + self.user_store.read(cx).edit_prediction_usage() } fn new( workspace: Option>, client: Arc, - cloud_user_store: Entity, + user_store: Entity, cx: &mut Context, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); @@ -306,7 +306,7 @@ impl Zeta { ), update_required: false, license_detection_watchers: HashMap::default(), - cloud_user_store, + user_store, } } @@ -535,8 +535,8 @@ impl Zeta { if let Some(usage) = usage { this.update(cx, |this, cx| { - this.cloud_user_store.update(cx, |cloud_user_store, cx| { - cloud_user_store.update_edit_prediction_usage(usage, cx); + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); }); }) .ok(); @@ -877,8 +877,8 @@ and then another if response.status().is_success() { if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() { this.update(cx, |this, cx| { - this.cloud_user_store.update(cx, |cloud_user_store, cx| { - cloud_user_store.update_edit_prediction_usage(usage, cx); + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); }); })?; } @@ -1559,9 +1559,9 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider !self .zeta .read(cx) - .cloud_user_store + .user_store .read(cx) - .has_accepted_tos() + .has_accepted_terms_of_service() } fn is_refreshing(&self) -> bool { @@ -1587,7 +1587,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider if self .zeta .read(cx) - .cloud_user_store + .user_store .read_with(cx, |cloud_user_store, _cx| { cloud_user_store.account_too_young() || cloud_user_store.has_overdue_invoices() }) @@ -1808,10 +1808,7 @@ mod tests { use client::UserStore; use client::test::FakeServer; use clock::FakeSystemClock; - use cloud_api_types::{ - AuthenticatedUser, CreateLlmTokenResponse, GetAuthenticatedUserResponse, LlmToken, PlanInfo, - }; - use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit}; + use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; use gpui::TestAppContext; use http_client::FakeHttpClient; use indoc::indoc; @@ -1820,39 +1817,6 @@ mod tests { use super::*; - fn make_get_authenticated_user_response() -> GetAuthenticatedUserResponse { - GetAuthenticatedUserResponse { - user: AuthenticatedUser { - id: 1, - metrics_id: "metrics-id-1".to_string(), - avatar_url: "".to_string(), - github_login: "".to_string(), - name: None, - is_staff: false, - accepted_tos_at: None, - }, - feature_flags: vec![], - plan: PlanInfo { - plan: Plan::ZedPro, - subscription_period: None, - usage: CurrentUsage { - model_requests: UsageData { - used: 0, - limit: UsageLimit::Limited(500), - }, - edit_predictions: UsageData { - used: 250, - limit: UsageLimit::Unlimited, - }, - }, - trial_started_at: None, - is_usage_based_billing_enabled: false, - is_account_too_young: false, - has_overdue_invoices: false, - }, - } - } - #[gpui::test] async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) { let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); @@ -2054,14 +2018,6 @@ mod tests { let http_client = FakeHttpClient::create(move |req| async move { match (req.method(), req.uri().path()) { - (&Method::GET, "/client/users/me") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&make_get_authenticated_user_response()) - .unwrap() - .into(), - ) - .unwrap()), (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() .status(200) .body( @@ -2098,9 +2054,7 @@ mod tests { // Construct the fake server to authenticate. let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let cloud_user_store = - cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); @@ -2128,14 +2082,6 @@ mod tests { let completion = completion_response.clone(); async move { match (req.method(), req.uri().path()) { - (&Method::GET, "/client/users/me") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&make_get_authenticated_user_response()) - .unwrap() - .into(), - ) - .unwrap()), (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() .status(200) .body( @@ -2172,9 +2118,7 @@ mod tests { // Construct the fake server to authenticate. let _server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let cloud_user_store = - cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());