Start separating authentication from connection to collab (#35471)

This pull request should be idempotent, but lays the groundwork for
avoiding to connect to collab in order to interact with AI features
provided by Zed.

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Antonio Scandurra 2025-08-01 19:37:38 +02:00 committed by GitHub
parent b01d1872cc
commit f888f3fc0b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
46 changed files with 653 additions and 855 deletions

24
Cargo.lock generated
View file

@ -114,7 +114,6 @@ dependencies = [
"pretty_assertions", "pretty_assertions",
"project", "project",
"prompt_store", "prompt_store",
"proto",
"rand 0.8.5", "rand 0.8.5",
"ref-cast", "ref-cast",
"rope", "rope",
@ -359,7 +358,6 @@ dependencies = [
"component", "component",
"gpui", "gpui",
"language_model", "language_model",
"proto",
"serde", "serde",
"smallvec", "smallvec",
"telemetry", "telemetry",
@ -1076,17 +1074,6 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "async-recursion"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7d78656ba01f1b93024b7c3a0467f1608e4be67d725749fdcd7d2c7678fd7a2"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "async-recursion" name = "async-recursion"
version = "1.1.1" version = "1.1.1"
@ -2972,7 +2959,6 @@ name = "client"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-recursion 0.3.2",
"async-tungstenite", "async-tungstenite",
"base64 0.22.1", "base64 0.22.1",
"chrono", "chrono",
@ -7814,6 +7800,7 @@ dependencies = [
"http 1.3.1", "http 1.3.1",
"http-body 1.0.1", "http-body 1.0.1",
"log", "log",
"parking_lot",
"serde", "serde",
"serde_json", "serde_json",
"url", "url",
@ -9085,7 +9072,6 @@ dependencies = [
"open_router", "open_router",
"partial-json-fixer", "partial-json-fixer",
"project", "project",
"proto",
"release_channel", "release_channel",
"schemars", "schemars",
"serde", "serde",
@ -9823,7 +9809,7 @@ name = "markdown_preview"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-recursion 1.1.1", "async-recursion",
"collections", "collections",
"editor", "editor",
"fs", "fs",
@ -16192,7 +16178,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"assistant_slash_command", "assistant_slash_command",
"async-recursion 1.1.1", "async-recursion",
"breadcrumbs", "breadcrumbs",
"client", "client",
"collections", "collections",
@ -19617,7 +19603,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"any_vec", "any_vec",
"anyhow", "anyhow",
"async-recursion 1.1.1", "async-recursion",
"bincode", "bincode",
"call", "call",
"client", "client",
@ -20142,7 +20128,7 @@ dependencies = [
"async-io", "async-io",
"async-lock", "async-lock",
"async-process", "async-process",
"async-recursion 1.1.1", "async-recursion",
"async-task", "async-task",
"async-trait", "async-trait",
"blocking", "blocking",

View file

@ -47,7 +47,6 @@ paths.workspace = true
postage.workspace = true postage.workspace = true
project.workspace = true project.workspace = true
prompt_store.workspace = true prompt_store.workspace = true
proto.workspace = true
ref-cast.workspace = true ref-cast.workspace = true
rope.workspace = true rope.workspace = true
schemars.workspace = true schemars.workspace = true

View file

@ -12,8 +12,8 @@ use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use client::{CloudUserStore, ModelRequestUsage, RequestUsage}; use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
use collections::HashMap; use collections::HashMap;
use feature_flags::{self, FeatureFlagAppExt}; use feature_flags::{self, FeatureFlagAppExt};
use futures::{FutureExt, StreamExt as _, future::Shared}; use futures::{FutureExt, StreamExt as _, future::Shared};
@ -37,7 +37,6 @@ use project::{
git_store::{GitStore, GitStoreCheckpoint, RepositoryState}, git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
}; };
use prompt_store::{ModelContext, PromptBuilder}; use prompt_store::{ModelContext, PromptBuilder};
use proto::Plan;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::Settings; use settings::Settings;
@ -374,7 +373,6 @@ pub struct Thread {
completion_count: usize, completion_count: usize,
pending_completions: Vec<PendingCompletion>, pending_completions: Vec<PendingCompletion>,
project: Entity<Project>, project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
tools: Entity<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
tool_use: ToolUseState, tool_use: ToolUseState,
@ -445,7 +443,6 @@ pub struct ExceededWindowError {
impl Thread { impl Thread {
pub fn new( pub fn new(
project: Entity<Project>, project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
system_prompt: SharedProjectContext, system_prompt: SharedProjectContext,
@ -472,7 +469,6 @@ impl Thread {
completion_count: 0, completion_count: 0,
pending_completions: Vec::new(), pending_completions: Vec::new(),
project: project.clone(), project: project.clone(),
cloud_user_store,
prompt_builder, prompt_builder,
tools: tools.clone(), tools: tools.clone(),
last_restore_checkpoint: None, last_restore_checkpoint: None,
@ -506,7 +502,6 @@ impl Thread {
id: ThreadId, id: ThreadId,
serialized: SerializedThread, serialized: SerializedThread,
project: Entity<Project>, project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
project_context: SharedProjectContext, project_context: SharedProjectContext,
@ -607,7 +602,6 @@ impl Thread {
last_restore_checkpoint: None, last_restore_checkpoint: None,
pending_checkpoint: None, pending_checkpoint: None,
project: project.clone(), project: project.clone(),
cloud_user_store,
prompt_builder, prompt_builder,
tools: tools.clone(), tools: tools.clone(),
tool_use, tool_use,
@ -3260,15 +3254,18 @@ impl Thread {
} }
fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) { fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
self.cloud_user_store.update(cx, |cloud_user_store, cx| { self.project
cloud_user_store.update_model_request_usage( .read(cx)
ModelRequestUsage(RequestUsage { .user_store()
amount: amount as i32, .update(cx, |user_store, cx| {
limit, user_store.update_model_request_usage(
}), ModelRequestUsage(RequestUsage {
cx, amount: amount as i32,
) limit,
}); }),
cx,
)
});
} }
pub fn deny_tool_use( pub fn deny_tool_use(
@ -3886,7 +3883,6 @@ fn main() {{
thread.id.clone(), thread.id.clone(),
serialized, serialized,
thread.project.clone(), thread.project.clone(),
thread.cloud_user_store.clone(),
thread.tools.clone(), thread.tools.clone(),
thread.prompt_builder.clone(), thread.prompt_builder.clone(),
thread.project_context.clone(), thread.project_context.clone(),
@ -5483,16 +5479,10 @@ fn main() {{
let (workspace, cx) = let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let (client, user_store) =
project.read_with(cx, |project, _cx| (project.client(), project.user_store()));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx));
let thread_store = cx let thread_store = cx
.update(|_, cx| { .update(|_, cx| {
ThreadStore::load( ThreadStore::load(
project.clone(), project.clone(),
cloud_user_store,
cx.new(|_| ToolWorkingSet::default()), cx.new(|_| ToolWorkingSet::default()),
None, None,
Arc::new(PromptBuilder::new(None).unwrap()), Arc::new(PromptBuilder::new(None).unwrap()),

View file

@ -8,7 +8,6 @@ use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{Tool, ToolId, ToolWorkingSet}; use assistant_tool::{Tool, ToolId, ToolWorkingSet};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use client::CloudUserStore;
use collections::HashMap; use collections::HashMap;
use context_server::ContextServerId; use context_server::ContextServerId;
use futures::{ use futures::{
@ -105,7 +104,6 @@ pub type TextThreadStore = assistant_context::ContextStore;
pub struct ThreadStore { pub struct ThreadStore {
project: Entity<Project>, project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>, prompt_store: Option<Entity<PromptStore>>,
@ -126,7 +124,6 @@ impl EventEmitter<RulesLoadingError> for ThreadStore {}
impl ThreadStore { impl ThreadStore {
pub fn load( pub fn load(
project: Entity<Project>, project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_store: Option<Entity<PromptStore>>, prompt_store: Option<Entity<PromptStore>>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
@ -136,14 +133,8 @@ impl ThreadStore {
let (thread_store, ready_rx) = cx.update(|cx| { let (thread_store, ready_rx) = cx.update(|cx| {
let mut option_ready_rx = None; let mut option_ready_rx = None;
let thread_store = cx.new(|cx| { let thread_store = cx.new(|cx| {
let (thread_store, ready_rx) = Self::new( let (thread_store, ready_rx) =
project, Self::new(project, tools, prompt_builder, prompt_store, cx);
cloud_user_store,
tools,
prompt_builder,
prompt_store,
cx,
);
option_ready_rx = Some(ready_rx); option_ready_rx = Some(ready_rx);
thread_store thread_store
}); });
@ -156,7 +147,6 @@ impl ThreadStore {
fn new( fn new(
project: Entity<Project>, project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>, tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>, prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>, prompt_store: Option<Entity<PromptStore>>,
@ -200,7 +190,6 @@ impl ThreadStore {
let this = Self { let this = Self {
project, project,
cloud_user_store,
tools, tools,
prompt_builder, prompt_builder,
prompt_store, prompt_store,
@ -418,7 +407,6 @@ impl ThreadStore {
cx.new(|cx| { cx.new(|cx| {
Thread::new( Thread::new(
self.project.clone(), self.project.clone(),
self.cloud_user_store.clone(),
self.tools.clone(), self.tools.clone(),
self.prompt_builder.clone(), self.prompt_builder.clone(),
self.project_context.clone(), self.project_context.clone(),
@ -437,7 +425,6 @@ impl ThreadStore {
ThreadId::new(), ThreadId::new(),
serialized, serialized,
self.project.clone(), self.project.clone(),
self.cloud_user_store.clone(),
self.tools.clone(), self.tools.clone(),
self.prompt_builder.clone(), self.prompt_builder.clone(),
self.project_context.clone(), self.project_context.clone(),
@ -469,7 +456,6 @@ impl ThreadStore {
id.clone(), id.clone(),
thread, thread,
this.project.clone(), this.project.clone(),
this.cloud_user_store.clone(),
this.tools.clone(), this.tools.clone(),
this.prompt_builder.clone(), this.prompt_builder.clone(),
this.project_context.clone(), this.project_context.clone(),

View file

@ -3820,7 +3820,6 @@ mod tests {
use super::*; use super::*;
use agent::{MessageSegment, context::ContextLoadResult, thread_store}; use agent::{MessageSegment, context::ContextLoadResult, thread_store};
use assistant_tool::{ToolRegistry, ToolWorkingSet}; use assistant_tool::{ToolRegistry, ToolWorkingSet};
use client::CloudUserStore;
use editor::EditorSettings; use editor::EditorSettings;
use fs::FakeFs; use fs::FakeFs;
use gpui::{AppContext, TestAppContext, VisualTestContext}; use gpui::{AppContext, TestAppContext, VisualTestContext};
@ -4117,16 +4116,10 @@ mod tests {
let (workspace, cx) = let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let (client, user_store) =
project.read_with(cx, |project, _cx| (project.client(), project.user_store()));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx));
let thread_store = cx let thread_store = cx
.update(|_, cx| { .update(|_, cx| {
ThreadStore::load( ThreadStore::load(
project.clone(), project.clone(),
cloud_user_store,
cx.new(|_| ToolWorkingSet::default()), cx.new(|_| ToolWorkingSet::default()),
None, None,
Arc::new(PromptBuilder::new(None).unwrap()), Arc::new(PromptBuilder::new(None).unwrap()),

View file

@ -7,6 +7,7 @@ use std::{sync::Arc, time::Duration};
use agent_settings::AgentSettings; use agent_settings::AgentSettings;
use assistant_tool::{ToolSource, ToolWorkingSet}; use assistant_tool::{ToolSource, ToolWorkingSet};
use cloud_llm_client::Plan;
use collections::HashMap; use collections::HashMap;
use context_server::ContextServerId; use context_server::ContextServerId;
use extension::ExtensionManifest; use extension::ExtensionManifest;
@ -25,7 +26,6 @@ use project::{
context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore}, context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore},
project_settings::{ContextServerSettings, ProjectSettings}, project_settings::{ContextServerSettings, ProjectSettings},
}; };
use proto::Plan;
use settings::{Settings, update_settings_file}; use settings::{Settings, update_settings_file};
use ui::{ use ui::{
Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu, Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu,
@ -180,7 +180,7 @@ impl AgentConfiguration {
let current_plan = if is_zed_provider { let current_plan = if is_zed_provider {
self.workspace self.workspace
.upgrade() .upgrade()
.and_then(|workspace| workspace.read(cx).user_store().read(cx).current_plan()) .and_then(|workspace| workspace.read(cx).user_store().read(cx).plan())
} else { } else {
None None
}; };
@ -508,7 +508,7 @@ impl AgentConfiguration {
.blend(cx.theme().colors().text_accent.opacity(0.2)); .blend(cx.theme().colors().text_accent.opacity(0.2));
let (plan_name, label_color, bg_color) = match plan { let (plan_name, label_color, bg_color) = match plan {
Plan::Free => ("Free", Color::Default, free_chip_bg), Plan::ZedFree => ("Free", Color::Default, free_chip_bg),
Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg), Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg),
Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg), Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg),
}; };

View file

@ -1896,7 +1896,6 @@ mod tests {
use agent::thread_store::{self, ThreadStore}; use agent::thread_store::{self, ThreadStore};
use agent_settings::AgentSettings; use agent_settings::AgentSettings;
use assistant_tool::ToolWorkingSet; use assistant_tool::ToolWorkingSet;
use client::CloudUserStore;
use editor::EditorSettings; use editor::EditorSettings;
use gpui::{TestAppContext, UpdateGlobal, VisualTestContext}; use gpui::{TestAppContext, UpdateGlobal, VisualTestContext};
use project::{FakeFs, Project}; use project::{FakeFs, Project};
@ -1936,17 +1935,11 @@ mod tests {
}) })
.unwrap(); .unwrap();
let (client, user_store) =
project.read_with(cx, |project, _cx| (project.client(), project.user_store()));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx));
let prompt_store = None; let prompt_store = None;
let thread_store = cx let thread_store = cx
.update(|cx| { .update(|cx| {
ThreadStore::load( ThreadStore::load(
project.clone(), project.clone(),
cloud_user_store,
cx.new(|_| ToolWorkingSet::default()), cx.new(|_| ToolWorkingSet::default()),
prompt_store, prompt_store,
Arc::new(PromptBuilder::new(None).unwrap()), Arc::new(PromptBuilder::new(None).unwrap()),
@ -2108,17 +2101,11 @@ mod tests {
}) })
.unwrap(); .unwrap();
let (client, user_store) =
project.read_with(cx, |project, _cx| (project.client(), project.user_store()));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx));
let prompt_store = None; let prompt_store = None;
let thread_store = cx let thread_store = cx
.update(|cx| { .update(|cx| {
ThreadStore::load( ThreadStore::load(
project.clone(), project.clone(),
cloud_user_store,
cx.new(|_| ToolWorkingSet::default()), cx.new(|_| ToolWorkingSet::default()),
prompt_store, prompt_store,
Arc::new(PromptBuilder::new(None).unwrap()), Arc::new(PromptBuilder::new(None).unwrap()),

View file

@ -43,8 +43,8 @@ use anyhow::{Result, anyhow};
use assistant_context::{AssistantContext, ContextEvent, ContextSummary}; use assistant_context::{AssistantContext, ContextEvent, ContextSummary};
use assistant_slash_command::SlashCommandWorkingSet; use assistant_slash_command::SlashCommandWorkingSet;
use assistant_tool::ToolWorkingSet; use assistant_tool::ToolWorkingSet;
use client::{CloudUserStore, DisableAiSettings, UserStore, zed_urls}; use client::{DisableAiSettings, UserStore, zed_urls};
use cloud_llm_client::{CompletionIntent, UsageLimit}; use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer}; use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use feature_flags::{self, FeatureFlagAppExt}; use feature_flags::{self, FeatureFlagAppExt};
use fs::Fs; use fs::Fs;
@ -60,7 +60,6 @@ use language_model::{
}; };
use project::{Project, ProjectPath, Worktree}; use project::{Project, ProjectPath, Worktree};
use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
use proto::Plan;
use rules_library::{RulesLibrary, open_rules_library}; use rules_library::{RulesLibrary, open_rules_library};
use search::{BufferSearchBar, buffer_search}; use search::{BufferSearchBar, buffer_search};
use settings::{Settings, update_settings_file}; use settings::{Settings, update_settings_file};
@ -427,7 +426,6 @@ impl ActiveView {
pub struct AgentPanel { pub struct AgentPanel {
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
user_store: Entity<UserStore>, user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
project: Entity<Project>, project: Entity<Project>,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
@ -487,7 +485,6 @@ impl AgentPanel {
let project = workspace.project().clone(); let project = workspace.project().clone();
ThreadStore::load( ThreadStore::load(
project, project,
workspace.app_state().cloud_user_store.clone(),
tools.clone(), tools.clone(),
prompt_store.clone(), prompt_store.clone(),
prompt_builder.clone(), prompt_builder.clone(),
@ -555,7 +552,6 @@ impl AgentPanel {
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx)); let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
let fs = workspace.app_state().fs.clone(); let fs = workspace.app_state().fs.clone();
let user_store = workspace.app_state().user_store.clone(); let user_store = workspace.app_state().user_store.clone();
let cloud_user_store = workspace.app_state().cloud_user_store.clone();
let project = workspace.project(); let project = workspace.project();
let language_registry = project.read(cx).languages().clone(); let language_registry = project.read(cx).languages().clone();
let client = workspace.client().clone(); let client = workspace.client().clone();
@ -582,7 +578,6 @@ impl AgentPanel {
MessageEditor::new( MessageEditor::new(
fs.clone(), fs.clone(),
workspace.clone(), workspace.clone(),
cloud_user_store.clone(),
message_editor_context_store.clone(), message_editor_context_store.clone(),
prompt_store.clone(), prompt_store.clone(),
thread_store.downgrade(), thread_store.downgrade(),
@ -697,7 +692,6 @@ impl AgentPanel {
let onboarding = cx.new(|cx| { let onboarding = cx.new(|cx| {
AgentPanelOnboarding::new( AgentPanelOnboarding::new(
user_store.clone(), user_store.clone(),
cloud_user_store.clone(),
client, client,
|_window, cx| { |_window, cx| {
OnboardingUpsell::set_dismissed(true, cx); OnboardingUpsell::set_dismissed(true, cx);
@ -710,7 +704,6 @@ impl AgentPanel {
active_view, active_view,
workspace, workspace,
user_store, user_store,
cloud_user_store,
project: project.clone(), project: project.clone(),
fs: fs.clone(), fs: fs.clone(),
language_registry, language_registry,
@ -853,7 +846,6 @@ impl AgentPanel {
MessageEditor::new( MessageEditor::new(
self.fs.clone(), self.fs.clone(),
self.workspace.clone(), self.workspace.clone(),
self.cloud_user_store.clone(),
context_store.clone(), context_store.clone(),
self.prompt_store.clone(), self.prompt_store.clone(),
self.thread_store.downgrade(), self.thread_store.downgrade(),
@ -1127,7 +1119,6 @@ impl AgentPanel {
MessageEditor::new( MessageEditor::new(
self.fs.clone(), self.fs.clone(),
self.workspace.clone(), self.workspace.clone(),
self.cloud_user_store.clone(),
context_store, context_store,
self.prompt_store.clone(), self.prompt_store.clone(),
self.thread_store.downgrade(), self.thread_store.downgrade(),
@ -1826,8 +1817,8 @@ impl AgentPanel {
} }
fn render_toolbar(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { fn render_toolbar(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let cloud_user_store = self.cloud_user_store.read(cx); let user_store = self.user_store.read(cx);
let usage = cloud_user_store.model_request_usage(); let usage = user_store.model_request_usage();
let account_url = zed_urls::account_url(cx); let account_url = zed_urls::account_url(cx);
@ -2298,10 +2289,10 @@ impl AgentPanel {
| ActiveView::Configuration => return false, | ActiveView::Configuration => return false,
} }
let plan = self.user_store.read(cx).current_plan(); let plan = self.user_store.read(cx).plan();
let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some(); let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some();
matches!(plan, Some(Plan::Free)) && has_previous_trial matches!(plan, Some(Plan::ZedFree)) && has_previous_trial
} }
fn should_render_onboarding(&self, cx: &mut Context<Self>) -> bool { fn should_render_onboarding(&self, cx: &mut Context<Self>) -> bool {
@ -2916,7 +2907,7 @@ impl AgentPanel {
) -> AnyElement { ) -> AnyElement {
let error_message = match plan { let error_message = match plan {
Plan::ZedPro => "Upgrade to usage-based billing for more prompts.", Plan::ZedPro => "Upgrade to usage-based billing for more prompts.",
Plan::ZedProTrial | Plan::Free => "Upgrade to Zed Pro for more prompts.", Plan::ZedProTrial | Plan::ZedFree => "Upgrade to Zed Pro for more prompts.",
}; };
let icon = Icon::new(IconName::XCircle) let icon = Icon::new(IconName::XCircle)

View file

@ -17,7 +17,6 @@ use agent::{
use agent_settings::{AgentSettings, CompletionMode}; use agent_settings::{AgentSettings, CompletionMode};
use ai_onboarding::ApiKeysWithProviders; use ai_onboarding::ApiKeysWithProviders;
use buffer_diff::BufferDiff; use buffer_diff::BufferDiff;
use client::CloudUserStore;
use cloud_llm_client::CompletionIntent; use cloud_llm_client::CompletionIntent;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste}; use editor::actions::{MoveUp, Paste};
@ -78,7 +77,6 @@ pub struct MessageEditor {
editor: Entity<Editor>, editor: Entity<Editor>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
project: Entity<Project>, project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
context_store: Entity<ContextStore>, context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>, prompt_store: Option<Entity<PromptStore>>,
history_store: Option<WeakEntity<HistoryStore>>, history_store: Option<WeakEntity<HistoryStore>>,
@ -158,7 +156,6 @@ impl MessageEditor {
pub fn new( pub fn new(
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
cloud_user_store: Entity<CloudUserStore>,
context_store: Entity<ContextStore>, context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>, prompt_store: Option<Entity<PromptStore>>,
thread_store: WeakEntity<ThreadStore>, thread_store: WeakEntity<ThreadStore>,
@ -230,7 +227,6 @@ impl MessageEditor {
Self { Self {
editor: editor.clone(), editor: editor.clone(),
project: thread.read(cx).project().clone(), project: thread.read(cx).project().clone(),
cloud_user_store,
thread, thread,
incompatible_tools_state: incompatible_tools.clone(), incompatible_tools_state: incompatible_tools.clone(),
workspace, workspace,
@ -1286,16 +1282,14 @@ impl MessageEditor {
return None; return None;
} }
let cloud_user_store = self.cloud_user_store.read(cx); let user_store = self.project.read(cx).user_store().read(cx);
if cloud_user_store.is_usage_based_billing_enabled() { if user_store.is_usage_based_billing_enabled() {
return None; return None;
} }
let plan = cloud_user_store let plan = user_store.plan().unwrap_or(cloud_llm_client::Plan::ZedFree);
.plan()
.unwrap_or(cloud_llm_client::Plan::ZedFree);
let usage = cloud_user_store.model_request_usage()?; let usage = user_store.model_request_usage()?;
Some( Some(
div() div()
@ -1758,7 +1752,6 @@ impl AgentPreview for MessageEditor {
) -> Option<AnyElement> { ) -> Option<AnyElement> {
if let Some(workspace) = workspace.upgrade() { if let Some(workspace) = workspace.upgrade() {
let fs = workspace.read(cx).app_state().fs.clone(); let fs = workspace.read(cx).app_state().fs.clone();
let cloud_user_store = workspace.read(cx).app_state().cloud_user_store.clone();
let project = workspace.read(cx).project().clone(); let project = workspace.read(cx).project().clone();
let weak_project = project.downgrade(); let weak_project = project.downgrade();
let context_store = cx.new(|_cx| ContextStore::new(weak_project, None)); let context_store = cx.new(|_cx| ContextStore::new(weak_project, None));
@ -1771,7 +1764,6 @@ impl AgentPreview for MessageEditor {
MessageEditor::new( MessageEditor::new(
fs, fs,
workspace.downgrade(), workspace.downgrade(),
cloud_user_store,
context_store, context_store,
None, None,
thread_store.downgrade(), thread_store.downgrade(),

View file

@ -20,7 +20,6 @@ cloud_llm_client.workspace = true
component.workspace = true component.workspace = true
gpui.workspace = true gpui.workspace = true
language_model.workspace = true language_model.workspace = true
proto.workspace = true
serde.workspace = true serde.workspace = true
smallvec.workspace = true smallvec.workspace = true
telemetry.workspace = true telemetry.workspace = true

View file

@ -1,6 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use client::{Client, CloudUserStore, UserStore}; use client::{Client, UserStore};
use cloud_llm_client::Plan; use cloud_llm_client::Plan;
use gpui::{Entity, IntoElement, ParentElement}; use gpui::{Entity, IntoElement, ParentElement};
use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
@ -10,7 +10,6 @@ use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding};
pub struct AgentPanelOnboarding { pub struct AgentPanelOnboarding {
user_store: Entity<UserStore>, user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
client: Arc<Client>, client: Arc<Client>,
configured_providers: Vec<(IconName, SharedString)>, configured_providers: Vec<(IconName, SharedString)>,
continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>, continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>,
@ -19,7 +18,6 @@ pub struct AgentPanelOnboarding {
impl AgentPanelOnboarding { impl AgentPanelOnboarding {
pub fn new( pub fn new(
user_store: Entity<UserStore>, user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
client: Arc<Client>, client: Arc<Client>,
continue_with_zed_ai: impl Fn(&mut Window, &mut App) + 'static, continue_with_zed_ai: impl Fn(&mut Window, &mut App) + 'static,
cx: &mut Context<Self>, cx: &mut Context<Self>,
@ -39,7 +37,6 @@ impl AgentPanelOnboarding {
Self { Self {
user_store, user_store,
cloud_user_store,
client, client,
configured_providers: Self::compute_available_providers(cx), configured_providers: Self::compute_available_providers(cx),
continue_with_zed_ai: Arc::new(continue_with_zed_ai), continue_with_zed_ai: Arc::new(continue_with_zed_ai),
@ -60,8 +57,8 @@ impl AgentPanelOnboarding {
impl Render for AgentPanelOnboarding { impl Render for AgentPanelOnboarding {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let enrolled_in_trial = self.cloud_user_store.read(cx).plan() == Some(Plan::ZedProTrial); let enrolled_in_trial = self.user_store.read(cx).plan() == Some(Plan::ZedProTrial);
let is_pro_user = self.cloud_user_store.read(cx).plan() == Some(Plan::ZedPro); let is_pro_user = self.user_store.read(cx).plan() == Some(Plan::ZedPro);
AgentPanelOnboardingCard::new() AgentPanelOnboardingCard::new()
.child( .child(

View file

@ -9,6 +9,7 @@ pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProvider
pub use agent_panel_onboarding_card::AgentPanelOnboardingCard; pub use agent_panel_onboarding_card::AgentPanelOnboardingCard;
pub use agent_panel_onboarding_content::AgentPanelOnboarding; pub use agent_panel_onboarding_content::AgentPanelOnboarding;
pub use ai_upsell_card::AiUpsellCard; pub use ai_upsell_card::AiUpsellCard;
use cloud_llm_client::Plan;
pub use edit_prediction_onboarding_content::EditPredictionOnboarding; pub use edit_prediction_onboarding_content::EditPredictionOnboarding;
pub use young_account_banner::YoungAccountBanner; pub use young_account_banner::YoungAccountBanner;
@ -79,7 +80,7 @@ impl From<client::Status> for SignInStatus {
pub struct ZedAiOnboarding { pub struct ZedAiOnboarding {
pub sign_in_status: SignInStatus, pub sign_in_status: SignInStatus,
pub has_accepted_terms_of_service: bool, pub has_accepted_terms_of_service: bool,
pub plan: Option<proto::Plan>, pub plan: Option<Plan>,
pub account_too_young: bool, pub account_too_young: bool,
pub continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>, pub continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>,
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>, pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
@ -99,8 +100,8 @@ impl ZedAiOnboarding {
Self { Self {
sign_in_status: status.into(), sign_in_status: status.into(),
has_accepted_terms_of_service: store.current_user_has_accepted_terms().unwrap_or(false), has_accepted_terms_of_service: store.has_accepted_terms_of_service(),
plan: store.current_plan(), plan: store.plan(),
account_too_young: store.account_too_young(), account_too_young: store.account_too_young(),
continue_with_zed_ai, continue_with_zed_ai,
accept_terms_of_service: Arc::new({ accept_terms_of_service: Arc::new({
@ -113,11 +114,9 @@ impl ZedAiOnboarding {
sign_in: Arc::new(move |_window, cx| { sign_in: Arc::new(move |_window, cx| {
cx.spawn({ cx.spawn({
let client = client.clone(); let client = client.clone();
async move |cx| { async move |cx| client.sign_in_with_optional_connect(true, cx).await
client.authenticate_and_connect(true, cx).await;
}
}) })
.detach(); .detach_and_log_err(cx);
}), }),
dismiss_onboarding: None, dismiss_onboarding: None,
} }
@ -411,9 +410,9 @@ impl RenderOnce for ZedAiOnboarding {
if matches!(self.sign_in_status, SignInStatus::SignedIn) { if matches!(self.sign_in_status, SignInStatus::SignedIn) {
if self.has_accepted_terms_of_service { if self.has_accepted_terms_of_service {
match self.plan { match self.plan {
None | Some(proto::Plan::Free) => self.render_free_plan_state(cx), None | Some(Plan::ZedFree) => self.render_free_plan_state(cx),
Some(proto::Plan::ZedProTrial) => self.render_trial_state(cx), Some(Plan::ZedProTrial) => self.render_trial_state(cx),
Some(proto::Plan::ZedPro) => self.render_pro_plan_state(cx), Some(Plan::ZedPro) => self.render_pro_plan_state(cx),
} }
} else { } else {
self.render_accept_terms_of_service() self.render_accept_terms_of_service()
@ -433,7 +432,7 @@ impl Component for ZedAiOnboarding {
fn onboarding( fn onboarding(
sign_in_status: SignInStatus, sign_in_status: SignInStatus,
has_accepted_terms_of_service: bool, has_accepted_terms_of_service: bool,
plan: Option<proto::Plan>, plan: Option<Plan>,
account_too_young: bool, account_too_young: bool,
) -> AnyElement { ) -> AnyElement {
ZedAiOnboarding { ZedAiOnboarding {
@ -468,25 +467,15 @@ impl Component for ZedAiOnboarding {
), ),
single_example( single_example(
"Free Plan", "Free Plan",
onboarding(SignInStatus::SignedIn, true, Some(proto::Plan::Free), false), onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedFree), false),
), ),
single_example( single_example(
"Pro Trial", "Pro Trial",
onboarding( onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedProTrial), false),
SignInStatus::SignedIn,
true,
Some(proto::Plan::ZedProTrial),
false,
),
), ),
single_example( single_example(
"Pro Plan", "Pro Plan",
onboarding( onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedPro), false),
SignInStatus::SignedIn,
true,
Some(proto::Plan::ZedPro),
false,
),
), ),
]) ])
.into_any_element(), .into_any_element(),

View file

@ -24,11 +24,9 @@ impl AiUpsellCard {
sign_in: Arc::new(move |_window, cx| { sign_in: Arc::new(move |_window, cx| {
cx.spawn({ cx.spawn({
let client = client.clone(); let client = client.clone();
async move |cx| { async move |cx| client.sign_in_with_optional_connect(true, cx).await
client.authenticate_and_connect(true, cx).await;
}
}) })
.detach(); .detach_and_log_err(cx);
}), }),
} }
} }

View file

@ -7,7 +7,7 @@ use crate::{
}; };
use Role::*; use Role::*;
use assistant_tool::ToolRegistry; use assistant_tool::ToolRegistry;
use client::{Client, CloudUserStore, UserStore}; use client::{Client, UserStore};
use collections::HashMap; use collections::HashMap;
use fs::FakeFs; use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture}; use futures::{FutureExt, future::LocalBoxFuture};
@ -1470,14 +1470,12 @@ impl EditAgentTest {
client::init_settings(cx); client::init_settings(cx);
let client = Client::production(cx); let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
settings::init(cx); settings::init(cx);
Project::init_settings(cx); Project::init_settings(cx);
language::init(cx); language::init(cx);
language_model::init(client.clone(), cx); language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), cloud_user_store, client.clone(), cx); language_models::init(user_store.clone(), client.clone(), cx);
crate::init(client.http_client(), cx); crate::init(client.http_client(), cx);
}); });

View file

@ -259,20 +259,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx); assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx);
}); });
let get_users = server.receive::<proto::GetUsers>().await.unwrap();
assert_eq!(get_users.payload.user_ids, vec![5]);
server.respond(
get_users.receipt(),
proto::UsersResponse {
users: vec![proto::User {
id: 5,
github_login: "nathansobo".into(),
avatar_url: "http://avatar.com/nathansobo".into(),
name: None,
}],
},
);
// Join a channel and populate its existing messages. // Join a channel and populate its existing messages.
let channel = channel_store.update(cx, |store, cx| { let channel = channel_store.update(cx, |store, cx| {
let channel_id = store.ordered_channels().next().unwrap().1.id; let channel_id = store.ordered_channels().next().unwrap().1.id;
@ -334,7 +320,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
.map(|message| (message.sender.github_login.clone(), message.body.clone())) .map(|message| (message.sender.github_login.clone(), message.body.clone()))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
&[ &[
("nathansobo".into(), "a".into()), ("user-5".into(), "a".into()),
("maxbrunsfeld".into(), "b".into()) ("maxbrunsfeld".into(), "b".into())
] ]
); );
@ -437,7 +423,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
.map(|message| (message.sender.github_login.clone(), message.body.clone())) .map(|message| (message.sender.github_login.clone(), message.body.clone()))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
&[ &[
("nathansobo".into(), "y".into()), ("user-5".into(), "y".into()),
("maxbrunsfeld".into(), "z".into()) ("maxbrunsfeld".into(), "z".into())
] ]
); );

View file

@ -17,7 +17,6 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
async-recursion = "0.3"
async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] } async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] }
base64.workspace = true base64.workspace = true
chrono = { workspace = true, features = ["serde"] } chrono = { workspace = true, features = ["serde"] }

View file

@ -1,14 +1,12 @@
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
pub mod test; pub mod test;
mod cloud;
mod proxy; mod proxy;
pub mod telemetry; pub mod telemetry;
pub mod user; pub mod user;
pub mod zed_urls; pub mod zed_urls;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use async_recursion::async_recursion;
use async_tungstenite::tungstenite::{ use async_tungstenite::tungstenite::{
client::IntoClientRequest, client::IntoClientRequest,
error::Error as WebsocketError, error::Error as WebsocketError,
@ -52,7 +50,6 @@ use tokio::net::TcpStream;
use url::Url; use url::Url;
use util::{ConnectionResult, ResultExt}; use util::{ConnectionResult, ResultExt};
pub use cloud::*;
pub use rpc::*; pub use rpc::*;
pub use telemetry_events::Event; pub use telemetry_events::Event;
pub use user::*; pub use user::*;
@ -164,20 +161,8 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
let client = client.clone(); let client = client.clone();
move |_: &SignIn, cx| { move |_: &SignIn, cx| {
if let Some(client) = client.upgrade() { if let Some(client) = client.upgrade() {
cx.spawn( cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await)
async move |cx| match client.authenticate_and_connect(true, &cx).await { .detach_and_log_err(cx);
ConnectionResult::Timeout => {
log::error!("Initial authentication timed out");
}
ConnectionResult::ConnectionReset => {
log::error!("Initial authentication connection reset");
}
ConnectionResult::Result(r) => {
r.log_err();
}
},
)
.detach();
} }
} }
}); });
@ -286,6 +271,8 @@ pub enum Status {
SignedOut, SignedOut,
UpgradeRequired, UpgradeRequired,
Authenticating, Authenticating,
Authenticated,
AuthenticationError,
Connecting, Connecting,
ConnectionError, ConnectionError,
Connected { Connected {
@ -712,7 +699,7 @@ impl Client {
let mut delay = INITIAL_RECONNECTION_DELAY; let mut delay = INITIAL_RECONNECTION_DELAY;
loop { loop {
match client.authenticate_and_connect(true, &cx).await { match client.connect(true, &cx).await {
ConnectionResult::Timeout => { ConnectionResult::Timeout => {
log::error!("client connect attempt timed out") log::error!("client connect attempt timed out")
} }
@ -882,17 +869,122 @@ impl Client {
.is_some() .is_some()
} }
#[async_recursion(?Send)] pub async fn sign_in(
pub async fn authenticate_and_connect( self: &Arc<Self>,
try_provider: bool,
cx: &AsyncApp,
) -> Result<Credentials> {
if self.status().borrow().is_signed_out() {
self.set_status(Status::Authenticating, cx);
} else {
self.set_status(Status::Reauthenticating, cx);
}
let mut credentials = None;
let old_credentials = self.state.read().credentials.clone();
if let Some(old_credentials) = old_credentials {
self.cloud_client.set_credentials(
old_credentials.user_id as u32,
old_credentials.access_token.clone(),
);
// Fetch the authenticated user with the old credentials, to ensure they are still valid.
if self.cloud_client.get_authenticated_user().await.is_ok() {
credentials = Some(old_credentials);
}
}
if credentials.is_none() && try_provider {
if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await {
self.cloud_client.set_credentials(
stored_credentials.user_id as u32,
stored_credentials.access_token.clone(),
);
// Fetch the authenticated user with the stored credentials, and
// clear them from the credentials provider if that fails.
if self.cloud_client.get_authenticated_user().await.is_ok() {
credentials = Some(stored_credentials);
} else {
self.credentials_provider
.delete_credentials(cx)
.await
.log_err();
}
}
}
if credentials.is_none() {
let mut status_rx = self.status();
let _ = status_rx.next().await;
futures::select_biased! {
authenticate = self.authenticate(cx).fuse() => {
match authenticate {
Ok(creds) => {
if IMPERSONATE_LOGIN.is_none() {
self.credentials_provider
.write_credentials(creds.user_id, creds.access_token.clone(), cx)
.await
.log_err();
}
credentials = Some(creds);
},
Err(err) => {
self.set_status(Status::AuthenticationError, cx);
return Err(err);
}
}
}
_ = status_rx.next().fuse() => {
return Err(anyhow!("authentication canceled"));
}
}
}
let credentials = credentials.unwrap();
self.set_id(credentials.user_id);
self.cloud_client
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
self.state.write().credentials = Some(credentials.clone());
self.set_status(Status::Authenticated, cx);
Ok(credentials)
}
/// Performs a sign-in and also connects to Collab.
///
/// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls
/// to `sign_in` when we're ready to remove auto-connection to Collab.
pub async fn sign_in_with_optional_connect(
self: &Arc<Self>,
try_provider: bool,
cx: &AsyncApp,
) -> Result<()> {
let credentials = self.sign_in(try_provider, cx).await?;
let connect_result = match self.connect_with_credentials(credentials, cx).await {
ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
ConnectionResult::Result(result) => result.context("client auth and connect"),
};
connect_result.log_err();
Ok(())
}
pub async fn connect(
self: &Arc<Self>, self: &Arc<Self>,
try_provider: bool, try_provider: bool,
cx: &AsyncApp, cx: &AsyncApp,
) -> ConnectionResult<()> { ) -> ConnectionResult<()> {
let was_disconnected = match *self.status().borrow() { let was_disconnected = match *self.status().borrow() {
Status::SignedOut => true, Status::SignedOut | Status::Authenticated => true,
Status::ConnectionError Status::ConnectionError
| Status::ConnectionLost | Status::ConnectionLost
| Status::Authenticating { .. } | Status::Authenticating { .. }
| Status::AuthenticationError
| Status::Reauthenticating { .. } | Status::Reauthenticating { .. }
| Status::ReconnectionError { .. } => false, | Status::ReconnectionError { .. } => false,
Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => { Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
@ -905,41 +997,10 @@ impl Client {
); );
} }
}; };
if was_disconnected { let credentials = match self.sign_in(try_provider, cx).await {
self.set_status(Status::Authenticating, cx); Ok(credentials) => credentials,
} else { Err(err) => return ConnectionResult::Result(Err(err)),
self.set_status(Status::Reauthenticating, cx) };
}
let mut read_from_provider = false;
let mut credentials = self.state.read().credentials.clone();
if credentials.is_none() && try_provider {
credentials = self.credentials_provider.read_credentials(cx).await;
read_from_provider = credentials.is_some();
}
if credentials.is_none() {
let mut status_rx = self.status();
let _ = status_rx.next().await;
futures::select_biased! {
authenticate = self.authenticate(cx).fuse() => {
match authenticate {
Ok(creds) => credentials = Some(creds),
Err(err) => {
self.set_status(Status::ConnectionError, cx);
return ConnectionResult::Result(Err(err));
}
}
}
_ = status_rx.next().fuse() => {
return ConnectionResult::Result(Err(anyhow!("authentication canceled")));
}
}
}
let credentials = credentials.unwrap();
self.set_id(credentials.user_id);
self.cloud_client
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
if was_disconnected { if was_disconnected {
self.set_status(Status::Connecting, cx); self.set_status(Status::Connecting, cx);
@ -947,17 +1008,20 @@ impl Client {
self.set_status(Status::Reconnecting, cx); self.set_status(Status::Reconnecting, cx);
} }
self.connect_with_credentials(credentials, cx).await
}
async fn connect_with_credentials(
self: &Arc<Self>,
credentials: Credentials,
cx: &AsyncApp,
) -> ConnectionResult<()> {
let mut timeout = let mut timeout =
futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT)); futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
futures::select_biased! { futures::select_biased! {
connection = self.establish_connection(&credentials, cx).fuse() => { connection = self.establish_connection(&credentials, cx).fuse() => {
match connection { match connection {
Ok(conn) => { Ok(conn) => {
self.state.write().credentials = Some(credentials.clone());
if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err();
}
futures::select_biased! { futures::select_biased! {
result = self.set_connection(conn, cx).fuse() => { result = self.set_connection(conn, cx).fuse() => {
match result.context("client auth and connect") { match result.context("client auth and connect") {
@ -975,15 +1039,8 @@ impl Client {
} }
} }
Err(EstablishConnectionError::Unauthorized) => { Err(EstablishConnectionError::Unauthorized) => {
self.state.write().credentials.take(); self.set_status(Status::ConnectionError, cx);
if read_from_provider { ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
self.credentials_provider.delete_credentials(cx).await.log_err();
self.set_status(Status::SignedOut, cx);
self.authenticate_and_connect(false, cx).await
} else {
self.set_status(Status::ConnectionError, cx);
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
}
} }
Err(EstablishConnectionError::UpgradeRequired) => { Err(EstablishConnectionError::UpgradeRequired) => {
self.set_status(Status::UpgradeRequired, cx); self.set_status(Status::UpgradeRequired, cx);
@ -1733,7 +1790,7 @@ mod tests {
}); });
let auth_and_connect = cx.spawn({ let auth_and_connect = cx.spawn({
let client = client.clone(); let client = client.clone();
|cx| async move { client.authenticate_and_connect(false, &cx).await } |cx| async move { client.connect(false, &cx).await }
}); });
executor.run_until_parked(); executor.run_until_parked();
assert!(matches!(status.next().await, Some(Status::Connecting))); assert!(matches!(status.next().await, Some(Status::Connecting)));
@ -1810,7 +1867,7 @@ mod tests {
let _authenticate = cx.spawn({ let _authenticate = cx.spawn({
let client = client.clone(); let client = client.clone();
move |cx| async move { client.authenticate_and_connect(false, &cx).await } move |cx| async move { client.connect(false, &cx).await }
}); });
executor.run_until_parked(); executor.run_until_parked();
assert_eq!(*auth_count.lock(), 1); assert_eq!(*auth_count.lock(), 1);
@ -1818,7 +1875,7 @@ mod tests {
let _authenticate = cx.spawn({ let _authenticate = cx.spawn({
let client = client.clone(); let client = client.clone();
|cx| async move { client.authenticate_and_connect(false, &cx).await } |cx| async move { client.connect(false, &cx).await }
}); });
executor.run_until_parked(); executor.run_until_parked();
assert_eq!(*auth_count.lock(), 2); assert_eq!(*auth_count.lock(), 2);

View file

@ -1,3 +0,0 @@
mod user_store;
pub use user_store::*;

View file

@ -1,211 +0,0 @@
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context as _;
use chrono::{DateTime, Utc};
use cloud_api_client::{AuthenticatedUser, CloudApiClient, GetAuthenticatedUserResponse, PlanInfo};
use cloud_llm_client::Plan;
use gpui::{Context, Entity, Subscription, Task};
use util::{ResultExt as _, maybe};
use crate::user::Event as RpcUserStoreEvent;
use crate::{EditPredictionUsage, ModelRequestUsage, RequestUsage, UserStore};
pub struct CloudUserStore {
cloud_client: Arc<CloudApiClient>,
authenticated_user: Option<Arc<AuthenticatedUser>>,
plan_info: Option<Arc<PlanInfo>>,
model_request_usage: Option<ModelRequestUsage>,
edit_prediction_usage: Option<EditPredictionUsage>,
_maintain_authenticated_user_task: Task<()>,
_rpc_plan_updated_subscription: Subscription,
}
impl CloudUserStore {
pub fn new(
cloud_client: Arc<CloudApiClient>,
rpc_user_store: Entity<UserStore>,
cx: &mut Context<Self>,
) -> Self {
let rpc_plan_updated_subscription =
cx.subscribe(&rpc_user_store, Self::handle_rpc_user_store_event);
Self {
cloud_client: cloud_client.clone(),
authenticated_user: None,
plan_info: None,
model_request_usage: None,
edit_prediction_usage: None,
_maintain_authenticated_user_task: cx.spawn(async move |this, cx| {
maybe!(async move {
loop {
let Some(this) = this.upgrade() else {
return anyhow::Ok(());
};
if cloud_client.has_credentials() {
let already_fetched_authenticated_user = this
.read_with(cx, |this, _cx| this.authenticated_user().is_some())
.unwrap_or(false);
if already_fetched_authenticated_user {
// We already fetched the authenticated user; nothing to do.
} else {
let authenticated_user_result = cloud_client
.get_authenticated_user()
.await
.context("failed to fetch authenticated user");
if let Some(response) = authenticated_user_result.log_err() {
this.update(cx, |this, _cx| {
this.update_authenticated_user(response);
})
.ok();
}
}
} else {
this.update(cx, |this, _cx| {
this.authenticated_user.take();
this.plan_info.take();
})
.ok();
}
cx.background_executor()
.timer(Duration::from_millis(100))
.await;
}
})
.await
.log_err();
}),
_rpc_plan_updated_subscription: rpc_plan_updated_subscription,
}
}
pub fn is_authenticated(&self) -> bool {
self.authenticated_user.is_some()
}
pub fn authenticated_user(&self) -> Option<Arc<AuthenticatedUser>> {
self.authenticated_user.clone()
}
pub fn plan(&self) -> Option<Plan> {
self.plan_info.as_ref().map(|plan| plan.plan)
}
pub fn subscription_period(&self) -> Option<(DateTime<Utc>, DateTime<Utc>)> {
self.plan_info
.as_ref()
.and_then(|plan| plan.subscription_period)
.map(|subscription_period| {
(
subscription_period.started_at.0,
subscription_period.ended_at.0,
)
})
}
pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
self.plan_info
.as_ref()
.and_then(|plan| plan.trial_started_at)
.map(|trial_started_at| trial_started_at.0)
}
pub fn has_accepted_tos(&self) -> bool {
self.authenticated_user
.as_ref()
.map(|user| user.accepted_tos_at.is_some())
.unwrap_or_default()
}
/// Returns whether the user's account is too new to use the service.
pub fn account_too_young(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.is_account_too_young)
.unwrap_or_default()
}
/// Returns whether the current user has overdue invoices and usage should be blocked.
pub fn has_overdue_invoices(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.has_overdue_invoices)
.unwrap_or_default()
}
pub fn is_usage_based_billing_enabled(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.is_usage_based_billing_enabled)
.unwrap_or_default()
}
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
self.model_request_usage
}
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
self.model_request_usage = Some(usage);
cx.notify();
}
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
self.edit_prediction_usage
}
pub fn update_edit_prediction_usage(
&mut self,
usage: EditPredictionUsage,
cx: &mut Context<Self>,
) {
self.edit_prediction_usage = Some(usage);
cx.notify();
}
fn update_authenticated_user(&mut self, response: GetAuthenticatedUserResponse) {
self.authenticated_user = Some(Arc::new(response.user));
self.model_request_usage = Some(ModelRequestUsage(RequestUsage {
limit: response.plan.usage.model_requests.limit,
amount: response.plan.usage.model_requests.used as i32,
}));
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
limit: response.plan.usage.edit_predictions.limit,
amount: response.plan.usage.edit_predictions.used as i32,
}));
self.plan_info = Some(Arc::new(response.plan));
}
fn handle_rpc_user_store_event(
&mut self,
_: Entity<UserStore>,
event: &RpcUserStoreEvent,
cx: &mut Context<Self>,
) {
match event {
RpcUserStoreEvent::PlanUpdated => {
cx.spawn(async move |this, cx| {
let cloud_client =
cx.update(|cx| this.read_with(cx, |this, _cx| this.cloud_client.clone()))??;
let response = cloud_client
.get_authenticated_user()
.await
.context("failed to fetch authenticated user")?;
cx.update(|cx| {
this.update(cx, |this, _cx| {
this.update_authenticated_user(response);
})
})??;
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
_ => {}
}
}
}

View file

@ -1,8 +1,11 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use chrono::Duration; use chrono::Duration;
use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
use futures::{StreamExt, stream::BoxStream}; use futures::{StreamExt, stream::BoxStream};
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext}; use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
use http_client::{AsyncBody, Method, Request, http};
use parking_lot::Mutex; use parking_lot::Mutex;
use rpc::{ use rpc::{
ConnectionId, Peer, Receipt, TypedEnvelope, ConnectionId, Peer, Receipt, TypedEnvelope,
@ -39,6 +42,44 @@ impl FakeServer {
executor: cx.executor(), executor: cx.executor(),
}; };
client.http_client().as_fake().replace_handler({
let state = server.state.clone();
move |old_handler, req| {
let state = state.clone();
let old_handler = old_handler.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::GET, "/client/users/me") => {
let credentials = parse_authorization_header(&req);
if credentials
!= Some(Credentials {
user_id: client_user_id,
access_token: state.lock().access_token.to_string(),
})
{
return Ok(http_client::Response::builder()
.status(401)
.body("Unauthorized".into())
.unwrap());
}
Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&make_get_authenticated_user_response(
client_user_id as i32,
format!("user-{client_user_id}"),
))
.unwrap()
.into(),
)
.unwrap())
}
_ => old_handler(req).await,
}
}
}
});
client client
.override_authenticate({ .override_authenticate({
let state = Arc::downgrade(&server.state); let state = Arc::downgrade(&server.state);
@ -105,7 +146,7 @@ impl FakeServer {
}); });
client client
.authenticate_and_connect(false, &cx.to_async()) .connect(false, &cx.to_async())
.await .await
.into_response() .into_response()
.unwrap(); .unwrap();
@ -223,3 +264,54 @@ impl Drop for FakeServer {
self.disconnect(); self.disconnect();
} }
} }
pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
let mut auth_header = req
.headers()
.get(http::header::AUTHORIZATION)?
.to_str()
.ok()?
.split_whitespace();
let user_id = auth_header.next()?.parse().ok()?;
let access_token = auth_header.next()?;
Some(Credentials {
user_id,
access_token: access_token.to_string(),
})
}
pub fn make_get_authenticated_user_response(
user_id: i32,
github_login: String,
) -> GetAuthenticatedUserResponse {
GetAuthenticatedUserResponse {
user: AuthenticatedUser {
id: user_id,
metrics_id: format!("metrics-id-{user_id}"),
avatar_url: "".to_string(),
github_login,
name: None,
is_staff: false,
accepted_tos_at: None,
},
feature_flags: vec![],
plan: PlanInfo {
plan: Plan::ZedPro,
subscription_period: None,
usage: CurrentUsage {
model_requests: UsageData {
used: 0,
limit: UsageLimit::Limited(500),
},
edit_predictions: UsageData {
used: 250,
limit: UsageLimit::Unlimited,
},
},
trial_started_at: None,
is_usage_based_billing_enabled: false,
is_account_too_young: false,
has_overdue_invoices: false,
},
}
}

View file

@ -1,6 +1,7 @@
use super::{Client, Status, TypedEnvelope, proto}; use super::{Client, Status, TypedEnvelope, proto};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo};
use cloud_llm_client::{ use cloud_llm_client::{
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
@ -20,7 +21,7 @@ use std::{
sync::{Arc, Weak}, sync::{Arc, Weak},
}; };
use text::ReplicaId; use text::ReplicaId;
use util::TryFutureExt as _; use util::{ResultExt, TryFutureExt as _};
pub type UserId = u64; pub type UserId = u64;
@ -110,12 +111,11 @@ pub struct UserStore {
by_github_login: HashMap<SharedString, u64>, by_github_login: HashMap<SharedString, u64>,
participant_indices: HashMap<u64, ParticipantIndex>, participant_indices: HashMap<u64, ParticipantIndex>,
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>, update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
current_plan: Option<proto::Plan>, model_request_usage: Option<ModelRequestUsage>,
trial_started_at: Option<DateTime<Utc>>, edit_prediction_usage: Option<EditPredictionUsage>,
is_usage_based_billing_enabled: Option<bool>, plan_info: Option<PlanInfo>,
account_too_young: Option<bool>,
current_user: watch::Receiver<Option<Arc<User>>>, current_user: watch::Receiver<Option<Arc<User>>>,
accepted_tos_at: Option<Option<DateTime<Utc>>>, accepted_tos_at: Option<Option<cloud_api_client::Timestamp>>,
contacts: Vec<Arc<Contact>>, contacts: Vec<Arc<Contact>>,
incoming_contact_requests: Vec<Arc<User>>, incoming_contact_requests: Vec<Arc<User>>,
outgoing_contact_requests: Vec<Arc<User>>, outgoing_contact_requests: Vec<Arc<User>>,
@ -185,10 +185,9 @@ impl UserStore {
users: Default::default(), users: Default::default(),
by_github_login: Default::default(), by_github_login: Default::default(),
current_user: current_user_rx, current_user: current_user_rx,
current_plan: None, plan_info: None,
trial_started_at: None, model_request_usage: None,
is_usage_based_billing_enabled: None, edit_prediction_usage: None,
account_too_young: None,
accepted_tos_at: None, accepted_tos_at: None,
contacts: Default::default(), contacts: Default::default(),
incoming_contact_requests: Default::default(), incoming_contact_requests: Default::default(),
@ -218,53 +217,30 @@ impl UserStore {
return Ok(()); return Ok(());
}; };
match status { match status {
Status::Connected { .. } => { Status::Authenticated | Status::Connected { .. } => {
if let Some(user_id) = client.user_id() { if let Some(user_id) = client.user_id() {
let fetch_user = if let Ok(fetch_user) = let response = client.cloud_client().get_authenticated_user().await;
this.update(cx, |this, cx| this.get_user(user_id, cx).log_err()) let mut current_user = None;
{
fetch_user
} else {
break;
};
let fetch_private_user_info =
client.request(proto::GetPrivateUserInfo {}).log_err();
let (user, info) =
futures::join!(fetch_user, fetch_private_user_info);
cx.update(|cx| { cx.update(|cx| {
if let Some(info) = info { if let Some(response) = response.log_err() {
let staff = let user = Arc::new(User {
info.staff && !*feature_flags::ZED_DISABLE_STAFF; id: user_id,
cx.update_flags(staff, info.flags); github_login: response.user.github_login.clone().into(),
client.telemetry.set_authenticated_user_info( avatar_uri: response.user.avatar_url.clone().into(),
Some(info.metrics_id.clone()), name: response.user.name.clone(),
staff, });
); current_user = Some(user.clone());
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
let accepted_tos_at = { this.by_github_login
#[cfg(debug_assertions)] .insert(user.github_login.clone(), user_id);
if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() this.users.insert(user_id, user);
{ this.update_authenticated_user(response, cx)
None
} else {
info.accepted_tos_at
}
#[cfg(not(debug_assertions))]
info.accepted_tos_at
};
this.set_current_user_accepted_tos_at(accepted_tos_at);
cx.emit(Event::PrivateUserInfoUpdated);
}) })
} else { } else {
anyhow::Ok(()) anyhow::Ok(())
} }
})??; })??;
current_user_tx.send(current_user).await.ok();
current_user_tx.send(user).await.ok();
this.update(cx, |_, cx| cx.notify())?; this.update(cx, |_, cx| cx.notify())?;
} }
@ -345,22 +321,22 @@ impl UserStore {
async fn handle_update_plan( async fn handle_update_plan(
this: Entity<Self>, this: Entity<Self>,
message: TypedEnvelope<proto::UpdateUserPlan>, _message: TypedEnvelope<proto::UpdateUserPlan>,
mut cx: AsyncApp, mut cx: AsyncApp,
) -> Result<()> { ) -> Result<()> {
this.update(&mut cx, |this, cx| { let client = this
this.current_plan = Some(message.payload.plan()); .read_with(&cx, |this, _| this.client.upgrade())?
this.trial_started_at = message .context("client was dropped")?;
.payload
.trial_started_at
.and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0));
this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled;
this.account_too_young = message.payload.account_too_young;
cx.emit(Event::PlanUpdated); let response = client
cx.notify(); .cloud_client()
})?; .get_authenticated_user()
Ok(()) .await
.context("failed to fetch authenticated user")?;
this.update(&mut cx, |this, cx| {
this.update_authenticated_user(response, cx);
})
} }
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> { fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
@ -719,42 +695,131 @@ impl UserStore {
self.current_user.borrow().clone() self.current_user.borrow().clone()
} }
pub fn current_plan(&self) -> Option<proto::Plan> { pub fn plan(&self) -> Option<cloud_llm_client::Plan> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() { if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() {
return match plan.as_str() { return match plan.as_str() {
"free" => Some(proto::Plan::Free), "free" => Some(cloud_llm_client::Plan::ZedFree),
"trial" => Some(proto::Plan::ZedProTrial), "trial" => Some(cloud_llm_client::Plan::ZedProTrial),
"pro" => Some(proto::Plan::ZedPro), "pro" => Some(cloud_llm_client::Plan::ZedPro),
_ => { _ => {
panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'"); panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'");
} }
}; };
} }
self.current_plan self.plan_info.as_ref().map(|info| info.plan)
}
pub fn subscription_period(&self) -> Option<(DateTime<Utc>, DateTime<Utc>)> {
self.plan_info
.as_ref()
.and_then(|plan| plan.subscription_period)
.map(|subscription_period| {
(
subscription_period.started_at.0,
subscription_period.ended_at.0,
)
})
} }
pub fn trial_started_at(&self) -> Option<DateTime<Utc>> { pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
self.trial_started_at self.plan_info
.as_ref()
.and_then(|plan| plan.trial_started_at)
.map(|trial_started_at| trial_started_at.0)
} }
pub fn usage_based_billing_enabled(&self) -> Option<bool> { /// Returns whether the user's account is too new to use the service.
self.is_usage_based_billing_enabled pub fn account_too_young(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.is_account_too_young)
.unwrap_or_default()
}
/// Returns whether the current user has overdue invoices and usage should be blocked.
pub fn has_overdue_invoices(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.has_overdue_invoices)
.unwrap_or_default()
}
pub fn is_usage_based_billing_enabled(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.is_usage_based_billing_enabled)
.unwrap_or_default()
}
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
self.model_request_usage
}
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
self.model_request_usage = Some(usage);
cx.notify();
}
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
self.edit_prediction_usage
}
pub fn update_edit_prediction_usage(
&mut self,
usage: EditPredictionUsage,
cx: &mut Context<Self>,
) {
self.edit_prediction_usage = Some(usage);
cx.notify();
}
fn update_authenticated_user(
&mut self,
response: GetAuthenticatedUserResponse,
cx: &mut Context<Self>,
) {
let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF;
cx.update_flags(staff, response.feature_flags);
if let Some(client) = self.client.upgrade() {
client
.telemetry
.set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff);
}
let accepted_tos_at = {
#[cfg(debug_assertions)]
if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() {
None
} else {
response.user.accepted_tos_at
}
#[cfg(not(debug_assertions))]
response.user.accepted_tos_at
};
self.accepted_tos_at = Some(accepted_tos_at);
self.model_request_usage = Some(ModelRequestUsage(RequestUsage {
limit: response.plan.usage.model_requests.limit,
amount: response.plan.usage.model_requests.used as i32,
}));
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
limit: response.plan.usage.edit_predictions.limit,
amount: response.plan.usage.edit_predictions.used as i32,
}));
self.plan_info = Some(response.plan);
cx.emit(Event::PrivateUserInfoUpdated);
} }
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> { pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
self.current_user.clone() self.current_user.clone()
} }
/// Returns whether the user's account is too new to use the service. pub fn has_accepted_terms_of_service(&self) -> bool {
pub fn account_too_young(&self) -> bool {
self.account_too_young.unwrap_or(false)
}
pub fn current_user_has_accepted_terms(&self) -> Option<bool> {
self.accepted_tos_at self.accepted_tos_at
.map(|accepted_tos_at| accepted_tos_at.is_some()) .map_or(false, |accepted_tos_at| accepted_tos_at.is_some())
} }
pub fn accept_terms_of_service(&self, cx: &Context<Self>) -> Task<Result<()>> { pub fn accept_terms_of_service(&self, cx: &Context<Self>) -> Task<Result<()>> {
@ -766,23 +831,18 @@ impl UserStore {
cx.spawn(async move |this, cx| -> anyhow::Result<()> { cx.spawn(async move |this, cx| -> anyhow::Result<()> {
let client = client.upgrade().context("client not found")?; let client = client.upgrade().context("client not found")?;
let response = client let response = client
.request(proto::AcceptTermsOfService {}) .cloud_client()
.accept_terms_of_service()
.await .await
.context("error accepting tos")?; .context("error accepting tos")?;
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at)); this.accepted_tos_at = Some(response.user.accepted_tos_at);
cx.emit(Event::PrivateUserInfoUpdated); cx.emit(Event::PrivateUserInfoUpdated);
})?; })?;
Ok(()) Ok(())
}) })
} }
fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option<u64>) {
self.accepted_tos_at = Some(
accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)),
);
}
fn load_users( fn load_users(
&self, &self,
request: impl RequestMessage<Response = UsersResponse>, request: impl RequestMessage<Response = UsersResponse>,

View file

@ -1286,7 +1286,7 @@ async fn test_calls_on_multiple_connections(
client_b1.disconnect(&cx_b1.to_async()); client_b1.disconnect(&cx_b1.to_async());
executor.advance_clock(RECEIVE_TIMEOUT); executor.advance_clock(RECEIVE_TIMEOUT);
client_b1 client_b1
.authenticate_and_connect(false, &cx_b1.to_async()) .connect(false, &cx_b1.to_async())
.await .await
.into_response() .into_response()
.unwrap(); .unwrap();
@ -1667,7 +1667,7 @@ async fn test_project_reconnect(
// Client A reconnects. Their project is re-shared, and client B re-joins it. // Client A reconnects. Their project is re-shared, and client B re-joins it.
server.allow_connections(); server.allow_connections();
client_a client_a
.authenticate_and_connect(false, &cx_a.to_async()) .connect(false, &cx_a.to_async())
.await .await
.into_response() .into_response()
.unwrap(); .unwrap();
@ -1796,7 +1796,7 @@ async fn test_project_reconnect(
// Client B reconnects. They re-join the room and the remaining shared project. // Client B reconnects. They re-join the room and the remaining shared project.
server.allow_connections(); server.allow_connections();
client_b client_b
.authenticate_and_connect(false, &cx_b.to_async()) .connect(false, &cx_b.to_async())
.await .await
.into_response() .into_response()
.unwrap(); .unwrap();
@ -5738,7 +5738,7 @@ async fn test_contacts(
server.allow_connections(); server.allow_connections();
client_c client_c
.authenticate_and_connect(false, &cx_c.to_async()) .connect(false, &cx_c.to_async())
.await .await
.into_response() .into_response()
.unwrap(); .unwrap();
@ -6269,7 +6269,7 @@ async fn test_contact_requests(
client.disconnect(&cx.to_async()); client.disconnect(&cx.to_async());
client.clear_contacts(cx).await; client.clear_contacts(cx).await;
client client
.authenticate_and_connect(false, &cx.to_async()) .connect(false, &cx.to_async())
.await .await
.into_response() .into_response()
.unwrap(); .unwrap();

View file

@ -3,6 +3,7 @@ use std::sync::Arc;
use gpui::{BackgroundExecutor, TestAppContext}; use gpui::{BackgroundExecutor, TestAppContext};
use notifications::NotificationEvent; use notifications::NotificationEvent;
use parking_lot::Mutex; use parking_lot::Mutex;
use pretty_assertions::assert_eq;
use rpc::{Notification, proto}; use rpc::{Notification, proto};
use crate::tests::TestServer; use crate::tests::TestServer;
@ -17,6 +18,9 @@ async fn test_notifications(
let client_a = server.create_client(cx_a, "user_a").await; let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await; let client_b = server.create_client(cx_b, "user_b").await;
// Wait for authentication/connection to Collab to be established.
executor.run_until_parked();
let notification_events_a = Arc::new(Mutex::new(Vec::new())); let notification_events_a = Arc::new(Mutex::new(Vec::new()));
let notification_events_b = Arc::new(Mutex::new(Vec::new())); let notification_events_b = Arc::new(Mutex::new(Vec::new()));
client_a.notification_store().update(cx_a, |_, cx| { client_a.notification_store().update(cx_a, |_, cx| {

View file

@ -8,7 +8,7 @@ use crate::{
use anyhow::anyhow; use anyhow::anyhow;
use call::ActiveCall; use call::ActiveCall;
use channel::{ChannelBuffer, ChannelStore}; use channel::{ChannelBuffer, ChannelStore};
use client::CloudUserStore; use client::test::{make_get_authenticated_user_response, parse_authorization_header};
use client::{ use client::{
self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore, self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
proto::PeerId, proto::PeerId,
@ -21,7 +21,7 @@ use fs::FakeFs;
use futures::{StreamExt as _, channel::oneshot}; use futures::{StreamExt as _, channel::oneshot};
use git::GitHostingProviderRegistry; use git::GitHostingProviderRegistry;
use gpui::{AppContext as _, BackgroundExecutor, Entity, Task, TestAppContext, VisualTestContext}; use gpui::{AppContext as _, BackgroundExecutor, Entity, Task, TestAppContext, VisualTestContext};
use http_client::FakeHttpClient; use http_client::{FakeHttpClient, Method};
use language::LanguageRegistry; use language::LanguageRegistry;
use node_runtime::NodeRuntime; use node_runtime::NodeRuntime;
use notifications::NotificationStore; use notifications::NotificationStore;
@ -162,6 +162,8 @@ impl TestServer {
} }
pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient { pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
const ACCESS_TOKEN: &str = "the-token";
let fs = FakeFs::new(cx.executor()); let fs = FakeFs::new(cx.executor());
cx.update(|cx| { cx.update(|cx| {
@ -176,7 +178,7 @@ impl TestServer {
}); });
let clock = Arc::new(FakeSystemClock::new()); let clock = Arc::new(FakeSystemClock::new());
let http = FakeHttpClient::with_404_response();
let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
{ {
user.id user.id
@ -198,6 +200,47 @@ impl TestServer {
.expect("creating user failed") .expect("creating user failed")
.user_id .user_id
}; };
let http = FakeHttpClient::create({
let name = name.to_string();
move |req| {
let name = name.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::GET, "/client/users/me") => {
let credentials = parse_authorization_header(&req);
if credentials
!= Some(Credentials {
user_id: user_id.to_proto(),
access_token: ACCESS_TOKEN.into(),
})
{
return Ok(http_client::Response::builder()
.status(401)
.body("Unauthorized".into())
.unwrap());
}
Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&make_get_authenticated_user_response(
user_id.0, name,
))
.unwrap()
.into(),
)
.unwrap())
}
_ => Ok(http_client::Response::builder()
.status(404)
.body("Not Found".into())
.unwrap()),
}
}
}
});
let client_name = name.to_string(); let client_name = name.to_string();
let mut client = cx.update(|cx| Client::new(clock, http.clone(), cx)); let mut client = cx.update(|cx| Client::new(clock, http.clone(), cx));
let server = self.server.clone(); let server = self.server.clone();
@ -209,11 +252,10 @@ impl TestServer {
.unwrap() .unwrap()
.set_id(user_id.to_proto()) .set_id(user_id.to_proto())
.override_authenticate(move |cx| { .override_authenticate(move |cx| {
let access_token = "the-token".to_string();
cx.spawn(async move |_| { cx.spawn(async move |_| {
Ok(Credentials { Ok(Credentials {
user_id: user_id.to_proto(), user_id: user_id.to_proto(),
access_token, access_token: ACCESS_TOKEN.into(),
}) })
}) })
}) })
@ -222,7 +264,7 @@ impl TestServer {
credentials, credentials,
&Credentials { &Credentials {
user_id: user_id.0 as u64, user_id: user_id.0 as u64,
access_token: "the-token".into() access_token: ACCESS_TOKEN.into(),
} }
); );
@ -282,15 +324,12 @@ impl TestServer {
.register_hosting_provider(Arc::new(git_hosting_providers::Github::public_instance())); .register_hosting_provider(Arc::new(git_hosting_providers::Github::public_instance()));
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx));
let language_registry = Arc::new(LanguageRegistry::test(cx.executor())); let language_registry = Arc::new(LanguageRegistry::test(cx.executor()));
let session = cx.new(|cx| AppSession::new(Session::test(), cx)); let session = cx.new(|cx| AppSession::new(Session::test(), cx));
let app_state = Arc::new(workspace::AppState { let app_state = Arc::new(workspace::AppState {
client: client.clone(), client: client.clone(),
user_store: user_store.clone(), user_store: user_store.clone(),
cloud_user_store,
workspace_store, workspace_store,
languages: language_registry, languages: language_registry,
fs: fs.clone(), fs: fs.clone(),
@ -323,7 +362,7 @@ impl TestServer {
}); });
client client
.authenticate_and_connect(false, &cx.to_async()) .connect(false, &cx.to_async())
.await .await
.into_response() .into_response()
.unwrap(); .unwrap();

View file

@ -2331,7 +2331,7 @@ impl CollabPanel {
let client = this.client.clone(); let client = this.client.clone();
cx.spawn_in(window, async move |_, cx| { cx.spawn_in(window, async move |_, cx| {
client client
.authenticate_and_connect(true, &cx) .connect(true, &cx)
.await .await
.into_response() .into_response()
.notify_async_err(cx); .notify_async_err(cx);

View file

@ -634,13 +634,13 @@ impl Render for NotificationPanel {
.child(Icon::new(IconName::Envelope)), .child(Icon::new(IconName::Envelope)),
) )
.map(|this| { .map(|this| {
if self.client.user_id().is_none() { if !self.client.status().borrow().is_connected() {
this.child( this.child(
v_flex() v_flex()
.gap_2() .gap_2()
.p_4() .p_4()
.child( .child(
Button::new("sign_in_prompt_button", "Sign in") Button::new("connect_prompt_button", "Connect")
.icon_color(Color::Muted) .icon_color(Color::Muted)
.icon(IconName::Github) .icon(IconName::Github)
.icon_position(IconPosition::Start) .icon_position(IconPosition::Start)
@ -652,10 +652,7 @@ impl Render for NotificationPanel {
let client = client.clone(); let client = client.clone();
window window
.spawn(cx, async move |cx| { .spawn(cx, async move |cx| {
match client match client.connect(true, &cx).await {
.authenticate_and_connect(true, &cx)
.await
{
util::ConnectionResult::Timeout => { util::ConnectionResult::Timeout => {
log::error!("Connection timeout"); log::error!("Connection timeout");
} }
@ -673,7 +670,7 @@ impl Render for NotificationPanel {
) )
.child( .child(
div().flex().w_full().items_center().child( div().flex().w_full().items_center().child(
Label::new("Sign in to view notifications.") Label::new("Connect to view notifications.")
.color(Color::Muted) .color(Color::Muted)
.size(LabelSize::Small), .size(LabelSize::Small),
), ),

View file

@ -13,7 +13,7 @@ pub(crate) use tool_metrics::*;
use ::fs::RealFs; use ::fs::RealFs;
use clap::Parser; use clap::Parser;
use client::{Client, CloudUserStore, ProxySettings, UserStore}; use client::{Client, ProxySettings, UserStore};
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use extension::ExtensionHostProxy; use extension::ExtensionHostProxy;
use futures::future; use futures::future;
@ -329,7 +329,6 @@ pub struct AgentAppState {
pub languages: Arc<LanguageRegistry>, pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>, pub client: Arc<Client>,
pub user_store: Entity<UserStore>, pub user_store: Entity<UserStore>,
pub cloud_user_store: Entity<CloudUserStore>,
pub fs: Arc<dyn fs::Fs>, pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime, pub node_runtime: NodeRuntime,
@ -384,8 +383,6 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
let languages = Arc::new(languages); let languages = Arc::new(languages);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
extension::init(cx); extension::init(cx);
@ -425,12 +422,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
languages.clone(), languages.clone(),
); );
language_model::init(client.clone(), cx); language_model::init(client.clone(), cx);
language_models::init( language_models::init(user_store.clone(), client.clone(), cx);
user_store.clone(),
cloud_user_store.clone(),
client.clone(),
cx,
);
languages::init(languages.clone(), node_runtime.clone(), cx); languages::init(languages.clone(), node_runtime.clone(), cx);
prompt_store::init(cx); prompt_store::init(cx);
terminal_view::init(cx); terminal_view::init(cx);
@ -455,7 +447,6 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
languages, languages,
client, client,
user_store, user_store,
cloud_user_store,
fs, fs,
node_runtime, node_runtime,
prompt_builder, prompt_builder,

View file

@ -221,7 +221,6 @@ impl ExampleInstance {
let prompt_store = None; let prompt_store = None;
let thread_store = ThreadStore::load( let thread_store = ThreadStore::load(
project.clone(), project.clone(),
app_state.cloud_user_store.clone(),
tools, tools,
prompt_store, prompt_store,
app_state.prompt_builder.clone(), app_state.prompt_builder.clone(),

View file

@ -23,6 +23,7 @@ futures.workspace = true
http.workspace = true http.workspace = true
http-body.workspace = true http-body.workspace = true
log.workspace = true log.workspace = true
parking_lot.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
url.workspace = true url.workspace = true

View file

@ -9,12 +9,10 @@ pub use http::{self, Method, Request, Response, StatusCode, Uri};
use futures::future::BoxFuture; use futures::future::BoxFuture;
use http::request::Builder; use http::request::Builder;
use parking_lot::Mutex;
#[cfg(feature = "test-support")] #[cfg(feature = "test-support")]
use std::fmt; use std::fmt;
use std::{ use std::{any::type_name, sync::Arc};
any::type_name,
sync::{Arc, Mutex},
};
pub use url::Url; pub use url::Url;
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)] #[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
@ -86,6 +84,11 @@ pub trait HttpClient: 'static + Send + Sync {
} }
fn proxy(&self) -> Option<&Url>; fn proxy(&self) -> Option<&Url>;
#[cfg(feature = "test-support")]
fn as_fake(&self) -> &FakeHttpClient {
panic!("called as_fake on {}", type_name::<Self>())
}
} }
/// An [`HttpClient`] that may have a proxy. /// An [`HttpClient`] that may have a proxy.
@ -132,6 +135,11 @@ impl HttpClient for HttpClientWithProxy {
fn type_name(&self) -> &'static str { fn type_name(&self) -> &'static str {
self.client.type_name() self.client.type_name()
} }
#[cfg(feature = "test-support")]
fn as_fake(&self) -> &FakeHttpClient {
self.client.as_fake()
}
} }
impl HttpClient for Arc<HttpClientWithProxy> { impl HttpClient for Arc<HttpClientWithProxy> {
@ -153,6 +161,11 @@ impl HttpClient for Arc<HttpClientWithProxy> {
fn type_name(&self) -> &'static str { fn type_name(&self) -> &'static str {
self.client.type_name() self.client.type_name()
} }
#[cfg(feature = "test-support")]
fn as_fake(&self) -> &FakeHttpClient {
self.client.as_fake()
}
} }
/// An [`HttpClient`] that has a base URL. /// An [`HttpClient`] that has a base URL.
@ -199,20 +212,13 @@ impl HttpClientWithUrl {
/// Returns the base URL. /// Returns the base URL.
pub fn base_url(&self) -> String { pub fn base_url(&self) -> String {
self.base_url self.base_url.lock().clone()
.lock()
.map_or_else(|_| Default::default(), |url| url.clone())
} }
/// Sets the base URL. /// Sets the base URL.
pub fn set_base_url(&self, base_url: impl Into<String>) { pub fn set_base_url(&self, base_url: impl Into<String>) {
let base_url = base_url.into(); let base_url = base_url.into();
self.base_url *self.base_url.lock() = base_url;
.lock()
.map(|mut url| {
*url = base_url;
})
.ok();
} }
/// Builds a URL using the given path. /// Builds a URL using the given path.
@ -288,6 +294,11 @@ impl HttpClient for Arc<HttpClientWithUrl> {
fn type_name(&self) -> &'static str { fn type_name(&self) -> &'static str {
self.client.type_name() self.client.type_name()
} }
#[cfg(feature = "test-support")]
fn as_fake(&self) -> &FakeHttpClient {
self.client.as_fake()
}
} }
impl HttpClient for HttpClientWithUrl { impl HttpClient for HttpClientWithUrl {
@ -309,6 +320,11 @@ impl HttpClient for HttpClientWithUrl {
fn type_name(&self) -> &'static str { fn type_name(&self) -> &'static str {
self.client.type_name() self.client.type_name()
} }
#[cfg(feature = "test-support")]
fn as_fake(&self) -> &FakeHttpClient {
self.client.as_fake()
}
} }
pub fn read_proxy_from_env() -> Option<Url> { pub fn read_proxy_from_env() -> Option<Url> {
@ -360,10 +376,15 @@ impl HttpClient for BlockedHttpClient {
fn type_name(&self) -> &'static str { fn type_name(&self) -> &'static str {
type_name::<Self>() type_name::<Self>()
} }
#[cfg(feature = "test-support")]
fn as_fake(&self) -> &FakeHttpClient {
panic!("called as_fake on {}", type_name::<Self>())
}
} }
#[cfg(feature = "test-support")] #[cfg(feature = "test-support")]
type FakeHttpHandler = Box< type FakeHttpHandler = Arc<
dyn Fn(Request<AsyncBody>) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> dyn Fn(Request<AsyncBody>) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>>
+ Send + Send
+ Sync + Sync
@ -372,7 +393,7 @@ type FakeHttpHandler = Box<
#[cfg(feature = "test-support")] #[cfg(feature = "test-support")]
pub struct FakeHttpClient { pub struct FakeHttpClient {
handler: FakeHttpHandler, handler: Mutex<Option<FakeHttpHandler>>,
user_agent: HeaderValue, user_agent: HeaderValue,
} }
@ -387,7 +408,7 @@ impl FakeHttpClient {
base_url: Mutex::new("http://test.example".into()), base_url: Mutex::new("http://test.example".into()),
client: HttpClientWithProxy { client: HttpClientWithProxy {
client: Arc::new(Self { client: Arc::new(Self {
handler: Box::new(move |req| Box::pin(handler(req))), handler: Mutex::new(Some(Arc::new(move |req| Box::pin(handler(req))))),
user_agent: HeaderValue::from_static(type_name::<Self>()), user_agent: HeaderValue::from_static(type_name::<Self>()),
}), }),
proxy: None, proxy: None,
@ -412,6 +433,18 @@ impl FakeHttpClient {
.unwrap()) .unwrap())
}) })
} }
pub fn replace_handler<Fut, F>(&self, new_handler: F)
where
Fut: futures::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send + 'static,
F: Fn(FakeHttpHandler, Request<AsyncBody>) -> Fut + Send + Sync + 'static,
{
let mut handler = self.handler.lock();
let old_handler = handler.take().unwrap();
*handler = Some(Arc::new(move |req| {
Box::pin(new_handler(old_handler.clone(), req))
}));
}
} }
#[cfg(feature = "test-support")] #[cfg(feature = "test-support")]
@ -427,7 +460,7 @@ impl HttpClient for FakeHttpClient {
&self, &self,
req: Request<AsyncBody>, req: Request<AsyncBody>,
) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> { ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> {
let future = (self.handler)(req); let future = (self.handler.lock().as_ref().unwrap())(req);
future future
} }
@ -442,4 +475,8 @@ impl HttpClient for FakeHttpClient {
fn type_name(&self) -> &'static str { fn type_name(&self) -> &'static str {
type_name::<Self>() type_name::<Self>()
} }
fn as_fake(&self) -> &FakeHttpClient {
self
}
} }

View file

@ -1,5 +1,5 @@
use anyhow::Result; use anyhow::Result;
use client::{CloudUserStore, DisableAiSettings, zed_urls}; use client::{DisableAiSettings, UserStore, zed_urls};
use cloud_llm_client::UsageLimit; use cloud_llm_client::UsageLimit;
use copilot::{Copilot, Status}; use copilot::{Copilot, Status};
use editor::{ use editor::{
@ -59,7 +59,7 @@ pub struct InlineCompletionButton {
file: Option<Arc<dyn File>>, file: Option<Arc<dyn File>>,
edit_prediction_provider: Option<Arc<dyn inline_completion::InlineCompletionProviderHandle>>, edit_prediction_provider: Option<Arc<dyn inline_completion::InlineCompletionProviderHandle>>,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
cloud_user_store: Entity<CloudUserStore>, user_store: Entity<UserStore>,
popover_menu_handle: PopoverMenuHandle<ContextMenu>, popover_menu_handle: PopoverMenuHandle<ContextMenu>,
} }
@ -245,9 +245,9 @@ impl Render for InlineCompletionButton {
IconName::ZedPredictDisabled IconName::ZedPredictDisabled
}; };
if zeta::should_show_upsell_modal(&self.cloud_user_store, cx) { if zeta::should_show_upsell_modal(&self.user_store, cx) {
let tooltip_meta = if self.cloud_user_store.read(cx).is_authenticated() { let tooltip_meta = if self.user_store.read(cx).current_user().is_some() {
if self.cloud_user_store.read(cx).has_accepted_tos() { if self.user_store.read(cx).has_accepted_terms_of_service() {
"Choose a Plan" "Choose a Plan"
} else { } else {
"Accept the Terms of Service" "Accept the Terms of Service"
@ -371,7 +371,7 @@ impl Render for InlineCompletionButton {
impl InlineCompletionButton { impl InlineCompletionButton {
pub fn new( pub fn new(
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
cloud_user_store: Entity<CloudUserStore>, user_store: Entity<UserStore>,
popover_menu_handle: PopoverMenuHandle<ContextMenu>, popover_menu_handle: PopoverMenuHandle<ContextMenu>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
@ -390,9 +390,9 @@ impl InlineCompletionButton {
language: None, language: None,
file: None, file: None,
edit_prediction_provider: None, edit_prediction_provider: None,
user_store,
popover_menu_handle, popover_menu_handle,
fs, fs,
cloud_user_store,
} }
} }
@ -763,7 +763,7 @@ impl InlineCompletionButton {
}) })
}) })
.separator(); .separator();
} else if self.cloud_user_store.read(cx).account_too_young() { } else if self.user_store.read(cx).account_too_young() {
menu = menu menu = menu
.custom_entry( .custom_entry(
|_window, _cx| { |_window, _cx| {
@ -778,7 +778,7 @@ impl InlineCompletionButton {
cx.open_url(&zed_urls::account_url(cx)) cx.open_url(&zed_urls::account_url(cx))
}) })
.separator(); .separator();
} else if self.cloud_user_store.read(cx).has_overdue_invoices() { } else if self.user_store.read(cx).has_overdue_invoices() {
menu = menu menu = menu
.custom_entry( .custom_entry(
|_window, _cx| { |_window, _cx| {

View file

@ -3,10 +3,11 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use client::Client; use client::Client;
use cloud_llm_client::Plan;
use gpui::{ use gpui::{
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
}; };
use proto::{Plan, TypedEnvelope}; use proto::TypedEnvelope;
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use thiserror::Error; use thiserror::Error;
@ -30,7 +31,7 @@ pub struct ModelRequestLimitReachedError {
impl fmt::Display for ModelRequestLimitReachedError { impl fmt::Display for ModelRequestLimitReachedError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let message = match self.plan { let message = match self.plan {
Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.", Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.",
Plan::ZedPro => { Plan::ZedPro => {
"Model request limit reached. Upgrade to usage-based billing for more requests." "Model request limit reached. Upgrade to usage-based billing for more requests."
} }

View file

@ -44,7 +44,6 @@ ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] } open_router = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true partial-json-fixer.workspace = true
proto.workspace = true
release_channel.workspace = true release_channel.workspace = true
schemars.workspace = true schemars.workspace = true
serde.workspace = true serde.workspace = true

View file

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use ::settings::{Settings, SettingsStore}; use ::settings::{Settings, SettingsStore};
use client::{Client, CloudUserStore, UserStore}; use client::{Client, UserStore};
use collections::HashSet; use collections::HashSet;
use gpui::{App, Context, Entity}; use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry}; use language_model::{LanguageModelProviderId, LanguageModelRegistry};
@ -26,22 +26,11 @@ use crate::provider::vercel::VercelLanguageModelProvider;
use crate::provider::x_ai::XAiLanguageModelProvider; use crate::provider::x_ai::XAiLanguageModelProvider;
pub use crate::settings::*; pub use crate::settings::*;
pub fn init( pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
client: Arc<Client>,
cx: &mut App,
) {
crate::settings::init_settings(cx); crate::settings::init_settings(cx);
let registry = LanguageModelRegistry::global(cx); let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| { registry.update(cx, |registry, cx| {
register_language_model_providers( register_language_model_providers(registry, user_store, client.clone(), cx);
registry,
user_store,
cloud_user_store,
client.clone(),
cx,
);
}); });
let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx) let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
@ -111,17 +100,11 @@ fn register_openai_compatible_providers(
fn register_language_model_providers( fn register_language_model_providers(
registry: &mut LanguageModelRegistry, registry: &mut LanguageModelRegistry,
user_store: Entity<UserStore>, user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
client: Arc<Client>, client: Arc<Client>,
cx: &mut Context<LanguageModelRegistry>, cx: &mut Context<LanguageModelRegistry>,
) { ) {
registry.register_provider( registry.register_provider(
CloudLanguageModelProvider::new( CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
user_store.clone(),
cloud_user_store.clone(),
client.clone(),
cx,
),
cx, cx,
); );

View file

@ -2,7 +2,7 @@ use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode; use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use client::{Client, CloudUserStore, ModelRequestUsage, UserStore, zed_urls}; use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use cloud_llm_client::{ use cloud_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody, CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse, CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
@ -117,7 +117,6 @@ pub struct State {
client: Arc<Client>, client: Arc<Client>,
llm_api_token: LlmApiToken, llm_api_token: LlmApiToken,
user_store: Entity<UserStore>, user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
status: client::Status, status: client::Status,
accept_terms_of_service_task: Option<Task<Result<()>>>, accept_terms_of_service_task: Option<Task<Result<()>>>,
models: Vec<Arc<cloud_llm_client::LanguageModel>>, models: Vec<Arc<cloud_llm_client::LanguageModel>>,
@ -133,17 +132,14 @@ impl State {
fn new( fn new(
client: Arc<Client>, client: Arc<Client>,
user_store: Entity<UserStore>, user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
status: client::Status, status: client::Status,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
Self { Self {
client: client.clone(), client: client.clone(),
llm_api_token: LlmApiToken::default(), llm_api_token: LlmApiToken::default(),
user_store, user_store: user_store.clone(),
cloud_user_store,
status, status,
accept_terms_of_service_task: None, accept_terms_of_service_task: None,
models: Vec::new(), models: Vec::new(),
@ -152,18 +148,12 @@ impl State {
recommended_models: Vec::new(), recommended_models: Vec::new(),
_fetch_models_task: cx.spawn(async move |this, cx| { _fetch_models_task: cx.spawn(async move |this, cx| {
maybe!(async move { maybe!(async move {
let (client, cloud_user_store, llm_api_token) = let (client, llm_api_token) = this
this.read_with(cx, |this, _cx| { .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
(
client.clone(),
this.cloud_user_store.clone(),
this.llm_api_token.clone(),
)
})?;
loop { loop {
let is_authenticated = let is_authenticated = user_store
cloud_user_store.read_with(cx, |this, _cx| this.is_authenticated())?; .read_with(cx, |user_store, _cx| user_store.current_user().is_some())?;
if is_authenticated { if is_authenticated {
break; break;
} }
@ -204,22 +194,19 @@ impl State {
} }
fn is_signed_out(&self, cx: &App) -> bool { fn is_signed_out(&self, cx: &App) -> bool {
!self.cloud_user_store.read(cx).is_authenticated() self.user_store.read(cx).current_user().is_none()
} }
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> { fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let client = self.client.clone(); let client = self.client.clone();
cx.spawn(async move |state, cx| { cx.spawn(async move |state, cx| {
client client.sign_in_with_optional_connect(true, &cx).await?;
.authenticate_and_connect(true, &cx)
.await
.into_response()?;
state.update(cx, |_, cx| cx.notify()) state.update(cx, |_, cx| cx.notify())
}) })
} }
fn has_accepted_terms_of_service(&self, cx: &App) -> bool { fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
self.cloud_user_store.read(cx).has_accepted_tos() self.user_store.read(cx).has_accepted_terms_of_service()
} }
fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) { fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
@ -303,24 +290,11 @@ impl State {
} }
impl CloudLanguageModelProvider { impl CloudLanguageModelProvider {
pub fn new( pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
client: Arc<Client>,
cx: &mut App,
) -> Self {
let mut status_rx = client.status(); let mut status_rx = client.status();
let status = *status_rx.borrow(); let status = *status_rx.borrow();
let state = cx.new(|cx| { let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
State::new(
client.clone(),
user_store.clone(),
cloud_user_store.clone(),
status,
cx,
)
});
let state_ref = state.downgrade(); let state_ref = state.downgrade();
let maintain_client_status = cx.spawn(async move |cx| { let maintain_client_status = cx.spawn(async move |cx| {
@ -632,11 +606,6 @@ impl CloudLanguageModel {
.and_then(|plan| plan.to_str().ok()) .and_then(|plan| plan.to_str().ok())
.and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok()) .and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok())
{ {
let plan = match plan {
cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
};
return Err(anyhow!(ModelRequestLimitReachedError { plan })); return Err(anyhow!(ModelRequestLimitReachedError { plan }));
} }
} }
@ -1281,15 +1250,15 @@ impl ConfigurationView {
impl Render for ConfigurationView { impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let state = self.state.read(cx); let state = self.state.read(cx);
let cloud_user_store = state.cloud_user_store.read(cx); let user_store = state.user_store.read(cx);
ZedAiConfiguration { ZedAiConfiguration {
is_connected: !state.is_signed_out(cx), is_connected: !state.is_signed_out(cx),
plan: cloud_user_store.plan(), plan: user_store.plan(),
subscription_period: cloud_user_store.subscription_period(), subscription_period: user_store.subscription_period(),
eligible_for_trial: cloud_user_store.trial_started_at().is_none(), eligible_for_trial: user_store.trial_started_at().is_none(),
has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx), has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
account_too_young: cloud_user_store.account_too_young(), account_too_young: user_store.account_too_young(),
accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(), accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(), accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
sign_in_callback: self.sign_in_callback.clone(), sign_in_callback: self.sign_in_callback.clone(),

View file

@ -278,7 +278,7 @@ pub(crate) fn render_ai_setup_page(
.child(AiUpsellCard { .child(AiUpsellCard {
sign_in_status: SignInStatus::SignedIn, sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}), sign_in: Arc::new(|_, _| {}),
user_plan: onboarding.cloud_user_store.read(cx).plan(), user_plan: onboarding.user_store.read(cx).plan(),
}) })
.child(render_llm_provider_section( .child(render_llm_provider_section(
onboarding, onboarding,

View file

@ -1,5 +1,5 @@
use crate::welcome::{ShowWelcome, WelcomePage}; use crate::welcome::{ShowWelcome, WelcomePage};
use client::{Client, CloudUserStore, UserStore}; use client::{Client, UserStore};
use command_palette_hooks::CommandPaletteFilter; use command_palette_hooks::CommandPaletteFilter;
use db::kvp::KEY_VALUE_STORE; use db::kvp::KEY_VALUE_STORE;
use feature_flags::{FeatureFlag, FeatureFlagViewExt as _}; use feature_flags::{FeatureFlag, FeatureFlagViewExt as _};
@ -220,7 +220,6 @@ struct Onboarding {
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
focus_handle: FocusHandle, focus_handle: FocusHandle,
selected_page: SelectedPage, selected_page: SelectedPage,
cloud_user_store: Entity<CloudUserStore>,
user_store: Entity<UserStore>, user_store: Entity<UserStore>,
_settings_subscription: Subscription, _settings_subscription: Subscription,
} }
@ -231,7 +230,6 @@ impl Onboarding {
workspace: workspace.weak_handle(), workspace: workspace.weak_handle(),
focus_handle: cx.focus_handle(), focus_handle: cx.focus_handle(),
selected_page: SelectedPage::Basics, selected_page: SelectedPage::Basics,
cloud_user_store: workspace.app_state().cloud_user_store.clone(),
user_store: workspace.user_store().clone(), user_store: workspace.user_store().clone(),
_settings_subscription: cx.observe_global::<SettingsStore>(move |_, cx| cx.notify()), _settings_subscription: cx.observe_global::<SettingsStore>(move |_, cx| cx.notify()),
}) })
@ -365,9 +363,8 @@ impl Onboarding {
window window
.spawn(cx, async move |cx| { .spawn(cx, async move |cx| {
client client
.authenticate_and_connect(true, &cx) .sign_in_with_optional_connect(true, &cx)
.await .await
.into_response()
.notify_async_err(cx); .notify_async_err(cx);
}) })
.detach(); .detach();

View file

@ -1362,10 +1362,7 @@ impl Project {
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
cx: AsyncApp, cx: AsyncApp,
) -> Result<Entity<Self>> { ) -> Result<Entity<Self>> {
client client.connect(true, &cx).await.into_response()?;
.authenticate_and_connect(true, &cx)
.await
.into_response()?;
let subscriptions = [ let subscriptions = [
EntitySubscription::Project(client.subscribe_to_entity::<Self>(remote_id)?), EntitySubscription::Project(client.subscribe_to_entity::<Self>(remote_id)?),

View file

@ -20,7 +20,7 @@ use crate::application_menu::{
use auto_update::AutoUpdateStatus; use auto_update::AutoUpdateStatus;
use call::ActiveCall; use call::ActiveCall;
use client::{Client, CloudUserStore, UserStore, zed_urls}; use client::{Client, UserStore, zed_urls};
use cloud_llm_client::Plan; use cloud_llm_client::Plan;
use gpui::{ use gpui::{
Action, AnyElement, App, Context, Corner, Element, Entity, Focusable, InteractiveElement, Action, AnyElement, App, Context, Corner, Element, Entity, Focusable, InteractiveElement,
@ -126,7 +126,6 @@ pub struct TitleBar {
platform_titlebar: Entity<PlatformTitleBar>, platform_titlebar: Entity<PlatformTitleBar>,
project: Entity<Project>, project: Entity<Project>,
user_store: Entity<UserStore>, user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
client: Arc<Client>, client: Arc<Client>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
application_menu: Option<Entity<ApplicationMenu>>, application_menu: Option<Entity<ApplicationMenu>>,
@ -180,11 +179,9 @@ impl Render for TitleBar {
children.push(self.banner.clone().into_any_element()) children.push(self.banner.clone().into_any_element())
} }
let is_authenticated = self.cloud_user_store.read(cx).is_authenticated();
let status = self.client.status(); let status = self.client.status();
let status = &*status.borrow(); let status = &*status.borrow();
let user = self.user_store.read(cx).current_user();
let show_sign_in = !is_authenticated || !matches!(status, client::Status::Connected { .. });
children.push( children.push(
h_flex() h_flex()
@ -194,10 +191,10 @@ impl Render for TitleBar {
.children(self.render_call_controls(window, cx)) .children(self.render_call_controls(window, cx))
.children(self.render_connection_status(status, cx)) .children(self.render_connection_status(status, cx))
.when( .when(
show_sign_in && TitleBarSettings::get_global(cx).show_sign_in, user.is_none() && TitleBarSettings::get_global(cx).show_sign_in,
|el| el.child(self.render_sign_in_button(cx)), |el| el.child(self.render_sign_in_button(cx)),
) )
.when(is_authenticated, |parent| { .when(user.is_some(), |parent| {
parent.child(self.render_user_menu_button(cx)) parent.child(self.render_user_menu_button(cx))
}) })
.into_any_element(), .into_any_element(),
@ -248,7 +245,6 @@ impl TitleBar {
) -> Self { ) -> Self {
let project = workspace.project().clone(); let project = workspace.project().clone();
let user_store = workspace.app_state().user_store.clone(); let user_store = workspace.app_state().user_store.clone();
let cloud_user_store = workspace.app_state().cloud_user_store.clone();
let client = workspace.app_state().client.clone(); let client = workspace.app_state().client.clone();
let active_call = ActiveCall::global(cx); let active_call = ActiveCall::global(cx);
@ -296,7 +292,6 @@ impl TitleBar {
workspace: workspace.weak_handle(), workspace: workspace.weak_handle(),
project, project,
user_store, user_store,
cloud_user_store,
client, client,
_subscriptions: subscriptions, _subscriptions: subscriptions,
banner, banner,
@ -622,9 +617,8 @@ impl TitleBar {
window window
.spawn(cx, async move |cx| { .spawn(cx, async move |cx| {
client client
.authenticate_and_connect(true, &cx) .sign_in_with_optional_connect(true, &cx)
.await .await
.into_response()
.notify_async_err(cx); .notify_async_err(cx);
}) })
.detach(); .detach();
@ -632,15 +626,15 @@ impl TitleBar {
} }
pub fn render_user_menu_button(&mut self, cx: &mut Context<Self>) -> impl Element { pub fn render_user_menu_button(&mut self, cx: &mut Context<Self>) -> impl Element {
let cloud_user_store = self.cloud_user_store.read(cx); let user_store = self.user_store.read(cx);
if let Some(user) = cloud_user_store.authenticated_user() { if let Some(user) = user_store.current_user() {
let has_subscription_period = cloud_user_store.subscription_period().is_some(); let has_subscription_period = user_store.subscription_period().is_some();
let plan = cloud_user_store.plan().filter(|_| { let plan = user_store.plan().filter(|_| {
// Since the user might be on the legacy free plan we filter based on whether we have a subscription period. // Since the user might be on the legacy free plan we filter based on whether we have a subscription period.
has_subscription_period has_subscription_period
}); });
let user_avatar = user.avatar_url.clone(); let user_avatar = user.avatar_uri.clone();
let free_chip_bg = cx let free_chip_bg = cx
.theme() .theme()
.colors() .colors()

View file

@ -15,7 +15,6 @@ mod toast_layer;
mod toolbar; mod toolbar;
mod workspace_settings; mod workspace_settings;
use client::CloudUserStore;
pub use toast_layer::{ToastAction, ToastLayer, ToastView}; pub use toast_layer::{ToastAction, ToastLayer, ToastView};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
@ -840,7 +839,6 @@ pub struct AppState {
pub languages: Arc<LanguageRegistry>, pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>, pub client: Arc<Client>,
pub user_store: Entity<UserStore>, pub user_store: Entity<UserStore>,
pub cloud_user_store: Entity<CloudUserStore>,
pub workspace_store: Entity<WorkspaceStore>, pub workspace_store: Entity<WorkspaceStore>,
pub fs: Arc<dyn fs::Fs>, pub fs: Arc<dyn fs::Fs>,
pub build_window_options: fn(Option<Uuid>, &mut App) -> WindowOptions, pub build_window_options: fn(Option<Uuid>, &mut App) -> WindowOptions,
@ -913,8 +911,6 @@ impl AppState {
let client = Client::new(clock, http_client.clone(), cx); let client = Client::new(clock, http_client.clone(), cx);
let session = cx.new(|cx| AppSession::new(Session::test(), cx)); let session = cx.new(|cx| AppSession::new(Session::test(), cx));
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx));
theme::init(theme::LoadThemes::JustBase, cx); theme::init(theme::LoadThemes::JustBase, cx);
@ -926,7 +922,6 @@ impl AppState {
fs, fs,
languages, languages,
user_store, user_store,
cloud_user_store,
workspace_store, workspace_store,
node_runtime: NodeRuntime::unavailable(), node_runtime: NodeRuntime::unavailable(),
build_window_options: |_, _| Default::default(), build_window_options: |_, _| Default::default(),
@ -5739,16 +5734,12 @@ impl Workspace {
let client = project.read(cx).client(); let client = project.read(cx).client();
let user_store = project.read(cx).user_store(); let user_store = project.read(cx).user_store();
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx));
let session = cx.new(|cx| AppSession::new(Session::test(), cx)); let session = cx.new(|cx| AppSession::new(Session::test(), cx));
window.activate_window(); window.activate_window();
let app_state = Arc::new(AppState { let app_state = Arc::new(AppState {
languages: project.read(cx).languages().clone(), languages: project.read(cx).languages().clone(),
workspace_store, workspace_store,
cloud_user_store,
client, client,
user_store, user_store,
fs: project.read(cx).fs().clone(), fs: project.read(cx).fs().clone(),
@ -6947,10 +6938,13 @@ async fn join_channel_internal(
match status { match status {
Status::Connecting Status::Connecting
| Status::Authenticating | Status::Authenticating
| Status::Authenticated
| Status::Reconnecting | Status::Reconnecting
| Status::Reauthenticating => continue, | Status::Reauthenticating => continue,
Status::Connected { .. } => break 'outer, Status::Connected { .. } => break 'outer,
Status::SignedOut => return Err(ErrorCode::SignedOut.into()), Status::SignedOut | Status::AuthenticationError => {
return Err(ErrorCode::SignedOut.into());
}
Status::UpgradeRequired => return Err(ErrorCode::UpgradeRequired.into()), Status::UpgradeRequired => return Err(ErrorCode::UpgradeRequired.into()),
Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => { Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
return Err(ErrorCode::Disconnected.into()); return Err(ErrorCode::Disconnected.into());

View file

@ -5,7 +5,7 @@ use agent_ui::AgentPanel;
use anyhow::{Context as _, Result}; use anyhow::{Context as _, Result};
use clap::{Parser, command}; use clap::{Parser, command};
use cli::FORCE_CLI_MODE_ENV_VAR_NAME; use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
use client::{Client, CloudUserStore, ProxySettings, UserStore, parse_zed_link}; use client::{Client, ProxySettings, UserStore, parse_zed_link};
use collab_ui::channel_view::ChannelView; use collab_ui::channel_view::ChannelView;
use collections::HashMap; use collections::HashMap;
use db::kvp::{GLOBAL_KEY_VALUE_STORE, KEY_VALUE_STORE}; use db::kvp::{GLOBAL_KEY_VALUE_STORE, KEY_VALUE_STORE};
@ -42,7 +42,7 @@ use theme::{
ActiveTheme, IconThemeNotFoundError, SystemAppearance, ThemeNotFoundError, ThemeRegistry, ActiveTheme, IconThemeNotFoundError, SystemAppearance, ThemeNotFoundError, ThemeRegistry,
ThemeSettings, ThemeSettings,
}; };
use util::{ConnectionResult, ResultExt, TryFutureExt, maybe}; use util::{ResultExt, TryFutureExt, maybe};
use uuid::Uuid; use uuid::Uuid;
use welcome::{FIRST_OPEN, show_welcome_view}; use welcome::{FIRST_OPEN, show_welcome_view};
use workspace::{ use workspace::{
@ -457,8 +457,6 @@ pub fn main() {
language::init(cx); language::init(cx);
languages::init(languages.clone(), node_runtime.clone(), cx); languages::init(languages.clone(), node_runtime.clone(), cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx)); let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx));
language_extension::init( language_extension::init(
@ -518,7 +516,6 @@ pub fn main() {
languages: languages.clone(), languages: languages.clone(),
client: client.clone(), client: client.clone(),
user_store: user_store.clone(), user_store: user_store.clone(),
cloud_user_store,
fs: fs.clone(), fs: fs.clone(),
build_window_options, build_window_options,
workspace_store, workspace_store,
@ -556,12 +553,7 @@ pub fn main() {
); );
supermaven::init(app_state.client.clone(), cx); supermaven::init(app_state.client.clone(), cx);
language_model::init(app_state.client.clone(), cx); language_model::init(app_state.client.clone(), cx);
language_models::init( language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
app_state.user_store.clone(),
app_state.cloud_user_store.clone(),
app_state.client.clone(),
cx,
);
agent_settings::init(cx); agent_settings::init(cx);
agent_servers::init(cx); agent_servers::init(cx);
web_search::init(cx); web_search::init(cx);
@ -569,7 +561,7 @@ pub fn main() {
snippet_provider::init(cx); snippet_provider::init(cx);
inline_completion_registry::init( inline_completion_registry::init(
app_state.client.clone(), app_state.client.clone(),
app_state.cloud_user_store.clone(), app_state.user_store.clone(),
cx, cx,
); );
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx);
@ -690,17 +682,9 @@ pub fn main() {
cx.spawn({ cx.spawn({
let client = app_state.client.clone(); let client = app_state.client.clone();
async move |cx| match authenticate(client, &cx).await { async move |cx| authenticate(client, &cx).await
ConnectionResult::Timeout => log::error!("Timeout during initial auth"),
ConnectionResult::ConnectionReset => {
log::error!("Connection reset during initial auth")
}
ConnectionResult::Result(r) => {
r.log_err();
}
}
}) })
.detach(); .detach_and_log_err(cx);
let urls: Vec<_> = args let urls: Vec<_> = args
.paths_or_urls .paths_or_urls
@ -850,15 +834,7 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut
let client = app_state.client.clone(); let client = app_state.client.clone();
// we continue even if authentication fails as join_channel/ open channel notes will // we continue even if authentication fails as join_channel/ open channel notes will
// show a visible error message. // show a visible error message.
match authenticate(client, &cx).await { authenticate(client, &cx).await.log_err();
ConnectionResult::Timeout => {
log::error!("Timeout during open request handling")
}
ConnectionResult::ConnectionReset => {
log::error!("Connection reset during open request handling")
}
ConnectionResult::Result(r) => r?,
};
if let Some(channel_id) = request.join_channel { if let Some(channel_id) = request.join_channel {
cx.update(|cx| { cx.update(|cx| {
@ -908,18 +884,18 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut
} }
} }
async fn authenticate(client: Arc<Client>, cx: &AsyncApp) -> ConnectionResult<()> { async fn authenticate(client: Arc<Client>, cx: &AsyncApp) -> Result<()> {
if stdout_is_a_pty() { if stdout_is_a_pty() {
if client::IMPERSONATE_LOGIN.is_some() { if client::IMPERSONATE_LOGIN.is_some() {
return client.authenticate_and_connect(false, cx).await; client.sign_in_with_optional_connect(false, cx).await?;
} else if client.has_credentials(cx).await { } else if client.has_credentials(cx).await {
return client.authenticate_and_connect(true, cx).await; client.sign_in_with_optional_connect(true, cx).await?;
} }
} else if client.has_credentials(cx).await { } else if client.has_credentials(cx).await {
return client.authenticate_and_connect(true, cx).await; client.sign_in_with_optional_connect(true, cx).await?;
} }
ConnectionResult::Result(Ok(())) Ok(())
} }
async fn system_id() -> Result<IdType> { async fn system_id() -> Result<IdType> {

View file

@ -336,7 +336,7 @@ pub fn initialize_workspace(
let edit_prediction_button = cx.new(|cx| { let edit_prediction_button = cx.new(|cx| {
inline_completion_button::InlineCompletionButton::new( inline_completion_button::InlineCompletionButton::new(
app_state.fs.clone(), app_state.fs.clone(),
app_state.cloud_user_store.clone(), app_state.user_store.clone(),
inline_completion_menu_handle.clone(), inline_completion_menu_handle.clone(),
cx, cx,
) )
@ -4488,12 +4488,7 @@ mod tests {
); );
image_viewer::init(cx); image_viewer::init(cx);
language_model::init(app_state.client.clone(), cx); language_model::init(app_state.client.clone(), cx);
language_models::init( language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
app_state.user_store.clone(),
app_state.cloud_user_store.clone(),
app_state.client.clone(),
cx,
);
web_search::init(cx); web_search::init(cx);
web_search_providers::init(app_state.client.clone(), cx); web_search_providers::init(app_state.client.clone(), cx);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx); let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);

View file

@ -139,8 +139,7 @@ impl ComponentPreview {
let project_clone = project.clone(); let project_clone = project.clone();
cx.spawn_in(window, async move |entity, cx| { cx.spawn_in(window, async move |entity, cx| {
let thread_store_future = let thread_store_future = load_preview_thread_store(project_clone.clone(), cx);
load_preview_thread_store(workspace_clone.clone(), project_clone.clone(), cx);
let text_thread_store_future = let text_thread_store_future =
load_preview_text_thread_store(workspace_clone.clone(), project_clone.clone(), cx); load_preview_text_thread_store(workspace_clone.clone(), project_clone.clone(), cx);

View file

@ -12,22 +12,19 @@ use ui::{App, Window};
use workspace::Workspace; use workspace::Workspace;
pub fn load_preview_thread_store( pub fn load_preview_thread_store(
workspace: WeakEntity<Workspace>,
project: Entity<Project>, project: Entity<Project>,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> Task<Result<Entity<ThreadStore>>> { ) -> Task<Result<Entity<ThreadStore>>> {
workspace cx.update(|cx| {
.update(cx, |workspace, cx| { ThreadStore::load(
ThreadStore::load( project.clone(),
project.clone(), cx.new(|_| ToolWorkingSet::default()),
workspace.app_state().cloud_user_store.clone(), None,
cx.new(|_| ToolWorkingSet::default()), Arc::new(PromptBuilder::new(None).unwrap()),
None, cx,
Arc::new(PromptBuilder::new(None).unwrap()), )
cx, })
) .unwrap_or(Task::ready(Err(anyhow!("workspace dropped"))))
})
.unwrap_or(Task::ready(Err(anyhow!("workspace dropped"))))
} }
pub fn load_preview_text_thread_store( pub fn load_preview_text_thread_store(

View file

@ -1,4 +1,4 @@
use client::{Client, CloudUserStore, DisableAiSettings}; use client::{Client, DisableAiSettings, UserStore};
use collections::HashMap; use collections::HashMap;
use copilot::{Copilot, CopilotCompletionProvider}; use copilot::{Copilot, CopilotCompletionProvider};
use editor::Editor; use editor::Editor;
@ -13,12 +13,12 @@ use util::ResultExt;
use workspace::Workspace; use workspace::Workspace;
use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider}; use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
pub fn init(client: Arc<Client>, cloud_user_store: Entity<CloudUserStore>, cx: &mut App) { pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default(); let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
cx.observe_new({ cx.observe_new({
let editors = editors.clone(); let editors = editors.clone();
let client = client.clone(); let client = client.clone();
let cloud_user_store = cloud_user_store.clone(); let user_store = user_store.clone();
move |editor: &mut Editor, window, cx: &mut Context<Editor>| { move |editor: &mut Editor, window, cx: &mut Context<Editor>| {
if !editor.mode().is_full() { if !editor.mode().is_full() {
return; return;
@ -48,7 +48,7 @@ pub fn init(client: Arc<Client>, cloud_user_store: Entity<CloudUserStore>, cx: &
editor, editor,
provider, provider,
&client, &client,
cloud_user_store.clone(), user_store.clone(),
window, window,
cx, cx,
); );
@ -60,7 +60,7 @@ pub fn init(client: Arc<Client>, cloud_user_store: Entity<CloudUserStore>, cx: &
let mut provider = all_language_settings(None, cx).edit_predictions.provider; let mut provider = all_language_settings(None, cx).edit_predictions.provider;
cx.spawn({ cx.spawn({
let cloud_user_store = cloud_user_store.clone(); let user_store = user_store.clone();
let editors = editors.clone(); let editors = editors.clone();
let client = client.clone(); let client = client.clone();
@ -72,7 +72,7 @@ pub fn init(client: Arc<Client>, cloud_user_store: Entity<CloudUserStore>, cx: &
&editors, &editors,
provider, provider,
&client, &client,
cloud_user_store.clone(), user_store.clone(),
cx, cx,
); );
}) })
@ -85,12 +85,12 @@ pub fn init(client: Arc<Client>, cloud_user_store: Entity<CloudUserStore>, cx: &
cx.observe_global::<SettingsStore>({ cx.observe_global::<SettingsStore>({
let editors = editors.clone(); let editors = editors.clone();
let client = client.clone(); let client = client.clone();
let cloud_user_store = cloud_user_store.clone(); let user_store = user_store.clone();
move |cx| { move |cx| {
let new_provider = all_language_settings(None, cx).edit_predictions.provider; let new_provider = all_language_settings(None, cx).edit_predictions.provider;
if new_provider != provider { if new_provider != provider {
let tos_accepted = cloud_user_store.read(cx).has_accepted_tos(); let tos_accepted = user_store.read(cx).has_accepted_terms_of_service();
telemetry::event!( telemetry::event!(
"Edit Prediction Provider Changed", "Edit Prediction Provider Changed",
@ -104,7 +104,7 @@ pub fn init(client: Arc<Client>, cloud_user_store: Entity<CloudUserStore>, cx: &
&editors, &editors,
provider, provider,
&client, &client,
cloud_user_store.clone(), user_store.clone(),
cx, cx,
); );
@ -145,7 +145,7 @@ fn assign_edit_prediction_providers(
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>, editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
provider: EditPredictionProvider, provider: EditPredictionProvider,
client: &Arc<Client>, client: &Arc<Client>,
cloud_user_store: Entity<CloudUserStore>, user_store: Entity<UserStore>,
cx: &mut App, cx: &mut App,
) { ) {
for (editor, window) in editors.borrow().iter() { for (editor, window) in editors.borrow().iter() {
@ -155,7 +155,7 @@ fn assign_edit_prediction_providers(
editor, editor,
provider, provider,
&client, &client,
cloud_user_store.clone(), user_store.clone(),
window, window,
cx, cx,
); );
@ -210,7 +210,7 @@ fn assign_edit_prediction_provider(
editor: &mut Editor, editor: &mut Editor,
provider: EditPredictionProvider, provider: EditPredictionProvider,
client: &Arc<Client>, client: &Arc<Client>,
cloud_user_store: Entity<CloudUserStore>, user_store: Entity<UserStore>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Editor>, cx: &mut Context<Editor>,
) { ) {
@ -241,7 +241,7 @@ fn assign_edit_prediction_provider(
} }
} }
EditPredictionProvider::Zed => { EditPredictionProvider::Zed => {
if cloud_user_store.read(cx).is_authenticated() { if user_store.read(cx).current_user().is_some() {
let mut worktree = None; let mut worktree = None;
if let Some(buffer) = &singleton_buffer { if let Some(buffer) = &singleton_buffer {
@ -263,7 +263,7 @@ fn assign_edit_prediction_provider(
.map(|workspace| workspace.downgrade()); .map(|workspace| workspace.downgrade());
let zeta = let zeta =
zeta::Zeta::register(workspace, worktree, client.clone(), cloud_user_store, cx); zeta::Zeta::register(workspace, worktree, client.clone(), user_store, cx);
if let Some(buffer) = &singleton_buffer { if let Some(buffer) = &singleton_buffer {
if buffer.read(cx).file().is_some() { if buffer.read(cx).file().is_some() {

View file

@ -16,7 +16,7 @@ pub use rate_completion_modal::*;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use arrayvec::ArrayVec; use arrayvec::ArrayVec;
use client::{Client, CloudUserStore, EditPredictionUsage}; use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::{ use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME, PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
@ -120,8 +120,8 @@ impl Dismissable for ZedPredictUpsell {
} }
} }
pub fn should_show_upsell_modal(cloud_user_store: &Entity<CloudUserStore>, cx: &App) -> bool { pub fn should_show_upsell_modal(user_store: &Entity<UserStore>, cx: &App) -> bool {
if cloud_user_store.read(cx).has_accepted_tos() { if user_store.read(cx).has_accepted_terms_of_service() {
!ZedPredictUpsell::dismissed() !ZedPredictUpsell::dismissed()
} else { } else {
true true
@ -229,7 +229,7 @@ pub struct Zeta {
_llm_token_subscription: Subscription, _llm_token_subscription: Subscription,
/// Whether an update to a newer version of Zed is required to continue using Zeta. /// Whether an update to a newer version of Zed is required to continue using Zeta.
update_required: bool, update_required: bool,
cloud_user_store: Entity<CloudUserStore>, user_store: Entity<UserStore>,
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>, license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
} }
@ -242,11 +242,11 @@ impl Zeta {
workspace: Option<WeakEntity<Workspace>>, workspace: Option<WeakEntity<Workspace>>,
worktree: Option<Entity<Worktree>>, worktree: Option<Entity<Worktree>>,
client: Arc<Client>, client: Arc<Client>,
cloud_user_store: Entity<CloudUserStore>, user_store: Entity<UserStore>,
cx: &mut App, cx: &mut App,
) -> Entity<Self> { ) -> Entity<Self> {
let this = Self::global(cx).unwrap_or_else(|| { let this = Self::global(cx).unwrap_or_else(|| {
let entity = cx.new(|cx| Self::new(workspace, client, cloud_user_store, cx)); let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
cx.set_global(ZetaGlobal(entity.clone())); cx.set_global(ZetaGlobal(entity.clone()));
entity entity
}); });
@ -269,13 +269,13 @@ impl Zeta {
} }
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> { pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
self.cloud_user_store.read(cx).edit_prediction_usage() self.user_store.read(cx).edit_prediction_usage()
} }
fn new( fn new(
workspace: Option<WeakEntity<Workspace>>, workspace: Option<WeakEntity<Workspace>>,
client: Arc<Client>, client: Arc<Client>,
cloud_user_store: Entity<CloudUserStore>, user_store: Entity<UserStore>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
@ -306,7 +306,7 @@ impl Zeta {
), ),
update_required: false, update_required: false,
license_detection_watchers: HashMap::default(), license_detection_watchers: HashMap::default(),
cloud_user_store, user_store,
} }
} }
@ -535,8 +535,8 @@ impl Zeta {
if let Some(usage) = usage { if let Some(usage) = usage {
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
this.cloud_user_store.update(cx, |cloud_user_store, cx| { this.user_store.update(cx, |user_store, cx| {
cloud_user_store.update_edit_prediction_usage(usage, cx); user_store.update_edit_prediction_usage(usage, cx);
}); });
}) })
.ok(); .ok();
@ -877,8 +877,8 @@ and then another
if response.status().is_success() { if response.status().is_success() {
if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() { if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
this.cloud_user_store.update(cx, |cloud_user_store, cx| { this.user_store.update(cx, |user_store, cx| {
cloud_user_store.update_edit_prediction_usage(usage, cx); user_store.update_edit_prediction_usage(usage, cx);
}); });
})?; })?;
} }
@ -1559,9 +1559,9 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
!self !self
.zeta .zeta
.read(cx) .read(cx)
.cloud_user_store .user_store
.read(cx) .read(cx)
.has_accepted_tos() .has_accepted_terms_of_service()
} }
fn is_refreshing(&self) -> bool { fn is_refreshing(&self) -> bool {
@ -1587,7 +1587,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
if self if self
.zeta .zeta
.read(cx) .read(cx)
.cloud_user_store .user_store
.read_with(cx, |cloud_user_store, _cx| { .read_with(cx, |cloud_user_store, _cx| {
cloud_user_store.account_too_young() || cloud_user_store.has_overdue_invoices() cloud_user_store.account_too_young() || cloud_user_store.has_overdue_invoices()
}) })
@ -1808,10 +1808,7 @@ mod tests {
use client::UserStore; use client::UserStore;
use client::test::FakeServer; use client::test::FakeServer;
use clock::FakeSystemClock; use clock::FakeSystemClock;
use cloud_api_types::{ use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
AuthenticatedUser, CreateLlmTokenResponse, GetAuthenticatedUserResponse, LlmToken, PlanInfo,
};
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
use gpui::TestAppContext; use gpui::TestAppContext;
use http_client::FakeHttpClient; use http_client::FakeHttpClient;
use indoc::indoc; use indoc::indoc;
@ -1820,39 +1817,6 @@ mod tests {
use super::*; use super::*;
fn make_get_authenticated_user_response() -> GetAuthenticatedUserResponse {
GetAuthenticatedUserResponse {
user: AuthenticatedUser {
id: 1,
metrics_id: "metrics-id-1".to_string(),
avatar_url: "".to_string(),
github_login: "".to_string(),
name: None,
is_staff: false,
accepted_tos_at: None,
},
feature_flags: vec![],
plan: PlanInfo {
plan: Plan::ZedPro,
subscription_period: None,
usage: CurrentUsage {
model_requests: UsageData {
used: 0,
limit: UsageLimit::Limited(500),
},
edit_predictions: UsageData {
used: 250,
limit: UsageLimit::Unlimited,
},
},
trial_started_at: None,
is_usage_based_billing_enabled: false,
is_account_too_young: false,
has_overdue_invoices: false,
},
}
}
#[gpui::test] #[gpui::test]
async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) { async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
@ -2054,14 +2018,6 @@ mod tests {
let http_client = FakeHttpClient::create(move |req| async move { let http_client = FakeHttpClient::create(move |req| async move {
match (req.method(), req.uri().path()) { match (req.method(), req.uri().path()) {
(&Method::GET, "/client/users/me") => Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&make_get_authenticated_user_response())
.unwrap()
.into(),
)
.unwrap()),
(&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
.status(200) .status(200)
.body( .body(
@ -2098,9 +2054,7 @@ mod tests {
// Construct the fake server to authenticate. // Construct the fake server to authenticate.
let _server = FakeServer::for_client(42, &client, cx).await; let _server = FakeServer::for_client(42, &client, cx).await;
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let cloud_user_store = let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx));
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
@ -2128,14 +2082,6 @@ mod tests {
let completion = completion_response.clone(); let completion = completion_response.clone();
async move { async move {
match (req.method(), req.uri().path()) { match (req.method(), req.uri().path()) {
(&Method::GET, "/client/users/me") => Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&make_get_authenticated_user_response())
.unwrap()
.into(),
)
.unwrap()),
(&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
.status(200) .status(200)
.body( .body(
@ -2172,9 +2118,7 @@ mod tests {
// Construct the fake server to authenticate. // Construct the fake server to authenticate.
let _server = FakeServer::for_client(42, &client, cx).await; let _server = FakeServer::for_client(42, &client, cx).await;
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let cloud_user_store = let zeta = cx.new(|cx| Zeta::new(None, client, user_store.clone(), cx));
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx));
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());