Update Agent panel to work with CloudUserStore (#35436)

This PR updates the Agent panel to work with the `CloudUserStore`
instead of the `UserStore`, reducing its reliance on being connected to
Collab to function.

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Marshall Bowers 2025-07-31 21:44:43 -04:00 committed by GitHub
parent 09b93caa9b
commit 72d354de6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 212 additions and 108 deletions

View file

@ -12,7 +12,7 @@ use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use client::{CloudUserStore, ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
use collections::HashMap;
use feature_flags::{self, FeatureFlagAppExt};
@ -374,6 +374,7 @@ pub struct Thread {
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
prompt_builder: Arc<PromptBuilder>,
tools: Entity<ToolWorkingSet>,
tool_use: ToolUseState,
@ -444,6 +445,7 @@ pub struct ExceededWindowError {
impl Thread {
pub fn new(
project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
system_prompt: SharedProjectContext,
@ -470,6 +472,7 @@ impl Thread {
completion_count: 0,
pending_completions: Vec::new(),
project: project.clone(),
cloud_user_store,
prompt_builder,
tools: tools.clone(),
last_restore_checkpoint: None,
@ -503,6 +506,7 @@ impl Thread {
id: ThreadId,
serialized: SerializedThread,
project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
project_context: SharedProjectContext,
@ -603,6 +607,7 @@ impl Thread {
last_restore_checkpoint: None,
pending_checkpoint: None,
project: project.clone(),
cloud_user_store,
prompt_builder,
tools: tools.clone(),
tool_use,
@ -3255,16 +3260,14 @@ impl Thread {
}
fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
self.project.update(cx, |project, cx| {
project.user_store().update(cx, |user_store, cx| {
user_store.update_model_request_usage(
ModelRequestUsage(RequestUsage {
amount: amount as i32,
limit,
}),
cx,
)
})
self.cloud_user_store.update(cx, |cloud_user_store, cx| {
cloud_user_store.update_model_request_usage(
ModelRequestUsage(RequestUsage {
amount: amount as i32,
limit,
}),
cx,
)
});
}
@ -3883,6 +3886,7 @@ fn main() {{
thread.id.clone(),
serialized,
thread.project.clone(),
thread.cloud_user_store.clone(),
thread.tools.clone(),
thread.prompt_builder.clone(),
thread.project_context.clone(),
@ -5479,10 +5483,16 @@ fn main() {{
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let (client, user_store) =
project.read_with(cx, |project, _cx| (project.client(), project.user_store()));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx));
let thread_store = cx
.update(|_, cx| {
ThreadStore::load(
project.clone(),
cloud_user_store,
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),

View file

@ -8,6 +8,7 @@ use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{Tool, ToolId, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::CloudUserStore;
use collections::HashMap;
use context_server::ContextServerId;
use futures::{
@ -104,6 +105,7 @@ pub type TextThreadStore = assistant_context::ContextStore;
pub struct ThreadStore {
project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>,
@ -124,6 +126,7 @@ impl EventEmitter<RulesLoadingError> for ThreadStore {}
impl ThreadStore {
pub fn load(
project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>,
prompt_store: Option<Entity<PromptStore>>,
prompt_builder: Arc<PromptBuilder>,
@ -133,8 +136,14 @@ impl ThreadStore {
let (thread_store, ready_rx) = cx.update(|cx| {
let mut option_ready_rx = None;
let thread_store = cx.new(|cx| {
let (thread_store, ready_rx) =
Self::new(project, tools, prompt_builder, prompt_store, cx);
let (thread_store, ready_rx) = Self::new(
project,
cloud_user_store,
tools,
prompt_builder,
prompt_store,
cx,
);
option_ready_rx = Some(ready_rx);
thread_store
});
@ -147,6 +156,7 @@ impl ThreadStore {
fn new(
project: Entity<Project>,
cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>,
@ -190,6 +200,7 @@ impl ThreadStore {
let this = Self {
project,
cloud_user_store,
tools,
prompt_builder,
prompt_store,
@ -407,6 +418,7 @@ impl ThreadStore {
cx.new(|cx| {
Thread::new(
self.project.clone(),
self.cloud_user_store.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
self.project_context.clone(),
@ -425,6 +437,7 @@ impl ThreadStore {
ThreadId::new(),
serialized,
self.project.clone(),
self.cloud_user_store.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
self.project_context.clone(),
@ -456,6 +469,7 @@ impl ThreadStore {
id.clone(),
thread,
this.project.clone(),
this.cloud_user_store.clone(),
this.tools.clone(),
this.prompt_builder.clone(),
this.project_context.clone(),

View file

@ -3820,6 +3820,7 @@ mod tests {
use super::*;
use agent::{MessageSegment, context::ContextLoadResult, thread_store};
use assistant_tool::{ToolRegistry, ToolWorkingSet};
use client::CloudUserStore;
use editor::EditorSettings;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext, VisualTestContext};
@ -4116,10 +4117,16 @@ mod tests {
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let (client, user_store) =
project.read_with(cx, |project, _cx| (project.client(), project.user_store()));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx));
let thread_store = cx
.update(|_, cx| {
ThreadStore::load(
project.clone(),
cloud_user_store,
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),

View file

@ -1893,6 +1893,7 @@ mod tests {
use agent::thread_store::{self, ThreadStore};
use agent_settings::AgentSettings;
use assistant_tool::ToolWorkingSet;
use client::CloudUserStore;
use editor::EditorSettings;
use gpui::{TestAppContext, UpdateGlobal, VisualTestContext};
use project::{FakeFs, Project};
@ -1932,11 +1933,17 @@ mod tests {
})
.unwrap();
let (client, user_store) =
project.read_with(cx, |project, _cx| (project.client(), project.user_store()));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx));
let prompt_store = None;
let thread_store = cx
.update(|cx| {
ThreadStore::load(
project.clone(),
cloud_user_store,
cx.new(|_| ToolWorkingSet::default()),
prompt_store,
Arc::new(PromptBuilder::new(None).unwrap()),
@ -2098,11 +2105,17 @@ mod tests {
})
.unwrap();
let (client, user_store) =
project.read_with(cx, |project, _cx| (project.client(), project.user_store()));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx));
let prompt_store = None;
let thread_store = cx
.update(|cx| {
ThreadStore::load(
project.clone(),
cloud_user_store,
cx.new(|_| ToolWorkingSet::default()),
prompt_store,
Arc::new(PromptBuilder::new(None).unwrap()),

View file

@ -43,7 +43,7 @@ use anyhow::{Result, anyhow};
use assistant_context::{AssistantContext, ContextEvent, ContextSummary};
use assistant_slash_command::SlashCommandWorkingSet;
use assistant_tool::ToolWorkingSet;
use client::{DisableAiSettings, UserStore, zed_urls};
use client::{CloudUserStore, DisableAiSettings, UserStore, zed_urls};
use cloud_llm_client::{CompletionIntent, UsageLimit};
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use feature_flags::{self, FeatureFlagAppExt};
@ -427,6 +427,7 @@ impl ActiveView {
pub struct AgentPanel {
workspace: WeakEntity<Workspace>,
user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
project: Entity<Project>,
fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>,
@ -486,6 +487,7 @@ impl AgentPanel {
let project = workspace.project().clone();
ThreadStore::load(
project,
workspace.app_state().cloud_user_store.clone(),
tools.clone(),
prompt_store.clone(),
prompt_builder.clone(),
@ -553,6 +555,7 @@ impl AgentPanel {
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
let fs = workspace.app_state().fs.clone();
let user_store = workspace.app_state().user_store.clone();
let cloud_user_store = workspace.app_state().cloud_user_store.clone();
let project = workspace.project();
let language_registry = project.read(cx).languages().clone();
let client = workspace.client().clone();
@ -579,7 +582,7 @@ impl AgentPanel {
MessageEditor::new(
fs.clone(),
workspace.clone(),
user_store.clone(),
cloud_user_store.clone(),
message_editor_context_store.clone(),
prompt_store.clone(),
thread_store.downgrade(),
@ -706,6 +709,7 @@ impl AgentPanel {
active_view,
workspace,
user_store,
cloud_user_store,
project: project.clone(),
fs: fs.clone(),
language_registry,
@ -848,7 +852,7 @@ impl AgentPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
self.user_store.clone(),
self.cloud_user_store.clone(),
context_store.clone(),
self.prompt_store.clone(),
self.thread_store.downgrade(),
@ -1122,7 +1126,7 @@ impl AgentPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
self.user_store.clone(),
self.cloud_user_store.clone(),
context_store,
self.prompt_store.clone(),
self.thread_store.downgrade(),
@ -1821,8 +1825,8 @@ impl AgentPanel {
}
fn render_toolbar(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let user_store = self.user_store.read(cx);
let usage = user_store.model_request_usage();
let cloud_user_store = self.cloud_user_store.read(cx);
let usage = cloud_user_store.model_request_usage();
let account_url = zed_urls::account_url(cx);

View file

@ -17,7 +17,7 @@ use agent::{
use agent_settings::{AgentSettings, CompletionMode};
use ai_onboarding::ApiKeysWithProviders;
use buffer_diff::BufferDiff;
use client::UserStore;
use client::CloudUserStore;
use cloud_llm_client::CompletionIntent;
use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste};
@ -43,7 +43,6 @@ use language_model::{
use multi_buffer;
use project::Project;
use prompt_store::PromptStore;
use proto::Plan;
use settings::Settings;
use std::time::Duration;
use theme::ThemeSettings;
@ -79,7 +78,7 @@ pub struct MessageEditor {
editor: Entity<Editor>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
history_store: Option<WeakEntity<HistoryStore>>,
@ -159,7 +158,7 @@ impl MessageEditor {
pub fn new(
fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>,
user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: WeakEntity<ThreadStore>,
@ -231,7 +230,7 @@ impl MessageEditor {
Self {
editor: editor.clone(),
project: thread.read(cx).project().clone(),
user_store,
cloud_user_store,
thread,
incompatible_tools_state: incompatible_tools.clone(),
workspace,
@ -1287,26 +1286,16 @@ impl MessageEditor {
return None;
}
let user_store = self.user_store.read(cx);
let ubb_enable = user_store
.usage_based_billing_enabled()
.map_or(false, |enabled| enabled);
if ubb_enable {
let cloud_user_store = self.cloud_user_store.read(cx);
if cloud_user_store.is_usage_based_billing_enabled() {
return None;
}
let plan = user_store
.current_plan()
.map(|plan| match plan {
Plan::Free => cloud_llm_client::Plan::ZedFree,
Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
})
let plan = cloud_user_store
.plan()
.unwrap_or(cloud_llm_client::Plan::ZedFree);
let usage = user_store.model_request_usage()?;
let usage = cloud_user_store.model_request_usage()?;
Some(
div()
@ -1769,7 +1758,7 @@ impl AgentPreview for MessageEditor {
) -> Option<AnyElement> {
if let Some(workspace) = workspace.upgrade() {
let fs = workspace.read(cx).app_state().fs.clone();
let user_store = workspace.read(cx).app_state().user_store.clone();
let cloud_user_store = workspace.read(cx).app_state().cloud_user_store.clone();
let project = workspace.read(cx).project().clone();
let weak_project = project.downgrade();
let context_store = cx.new(|_cx| ContextStore::new(weak_project, None));
@ -1782,7 +1771,7 @@ impl AgentPreview for MessageEditor {
MessageEditor::new(
fs,
workspace.downgrade(),
user_store,
cloud_user_store,
context_store,
None,
thread_store.downgrade(),

View file

@ -7,7 +7,7 @@ use crate::{
};
use Role::*;
use assistant_tool::ToolRegistry;
use client::{Client, UserStore};
use client::{Client, CloudUserStore, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
@ -1470,12 +1470,14 @@ impl EditAgentTest {
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
settings::init(cx);
Project::init_settings(cx);
language::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store.clone(), cloud_user_store, client.clone(), cx);
crate::init(client.http_client(), cx);
});

View file

@ -9,12 +9,13 @@ use gpui::{Context, Entity, Subscription, Task};
use util::{ResultExt as _, maybe};
use crate::user::Event as RpcUserStoreEvent;
use crate::{EditPredictionUsage, RequestUsage, UserStore};
use crate::{EditPredictionUsage, ModelRequestUsage, RequestUsage, UserStore};
pub struct CloudUserStore {
cloud_client: Arc<CloudApiClient>,
authenticated_user: Option<Arc<AuthenticatedUser>>,
plan_info: Option<Arc<PlanInfo>>,
model_request_usage: Option<ModelRequestUsage>,
edit_prediction_usage: Option<EditPredictionUsage>,
_maintain_authenticated_user_task: Task<()>,
_rpc_plan_updated_subscription: Subscription,
@ -33,6 +34,7 @@ impl CloudUserStore {
cloud_client: cloud_client.clone(),
authenticated_user: None,
plan_info: None,
model_request_usage: None,
edit_prediction_usage: None,
_maintain_authenticated_user_task: cx.spawn(async move |this, cx| {
maybe!(async move {
@ -104,6 +106,13 @@ impl CloudUserStore {
})
}
pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
self.plan_info
.as_ref()
.and_then(|plan| plan.trial_started_at)
.map(|trial_started_at| trial_started_at.0)
}
pub fn has_accepted_tos(&self) -> bool {
self.authenticated_user
.as_ref()
@ -127,6 +136,22 @@ impl CloudUserStore {
.unwrap_or_default()
}
pub fn is_usage_based_billing_enabled(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.is_usage_based_billing_enabled)
.unwrap_or_default()
}
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
self.model_request_usage
}
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
self.model_request_usage = Some(usage);
cx.notify();
}
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
self.edit_prediction_usage
}
@ -142,6 +167,10 @@ impl CloudUserStore {
fn update_authenticated_user(&mut self, response: GetAuthenticatedUserResponse) {
self.authenticated_user = Some(Arc::new(response.user));
self.model_request_usage = Some(ModelRequestUsage(RequestUsage {
limit: response.plan.usage.model_requests.limit,
amount: response.plan.usage.model_requests.used as i32,
}));
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
limit: response.plan.usage.edit_predictions.limit,
amount: response.plan.usage.edit_predictions.used as i32,

View file

@ -113,7 +113,6 @@ pub struct UserStore {
current_plan: Option<proto::Plan>,
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
trial_started_at: Option<DateTime<Utc>>,
model_request_usage: Option<ModelRequestUsage>,
is_usage_based_billing_enabled: Option<bool>,
account_too_young: Option<bool>,
has_overdue_invoices: Option<bool>,
@ -191,7 +190,6 @@ impl UserStore {
current_plan: None,
subscription_period: None,
trial_started_at: None,
model_request_usage: None,
is_usage_based_billing_enabled: None,
account_too_young: None,
has_overdue_invoices: None,
@ -371,27 +369,12 @@ impl UserStore {
this.account_too_young = message.payload.account_too_young;
this.has_overdue_invoices = message.payload.has_overdue_invoices;
if let Some(usage) = message.payload.usage {
// limits are always present even though they are wrapped in Option
this.model_request_usage = usage
.model_requests_usage_limit
.and_then(|limit| {
RequestUsage::from_proto(usage.model_requests_usage_amount, limit)
})
.map(ModelRequestUsage);
}
cx.emit(Event::PlanUpdated);
cx.notify();
})?;
Ok(())
}
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
self.model_request_usage = Some(usage);
cx.notify();
}
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
match message {
UpdateContacts::Wait(barrier) => {
@ -776,10 +759,6 @@ impl UserStore {
self.is_usage_based_billing_enabled
}
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
self.model_request_usage
}
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
self.current_user.clone()
}

View file

@ -13,7 +13,7 @@ pub(crate) use tool_metrics::*;
use ::fs::RealFs;
use clap::Parser;
use client::{Client, ProxySettings, UserStore};
use client::{Client, CloudUserStore, ProxySettings, UserStore};
use collections::{HashMap, HashSet};
use extension::ExtensionHostProxy;
use futures::future;
@ -329,6 +329,7 @@ pub struct AgentAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
pub cloud_user_store: Entity<CloudUserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
@ -383,6 +384,8 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
let languages = Arc::new(languages);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let cloud_user_store =
cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
extension::init(cx);
@ -422,7 +425,12 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
languages.clone(),
);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
language_models::init(
user_store.clone(),
cloud_user_store.clone(),
client.clone(),
cx,
);
languages::init(languages.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);
terminal_view::init(cx);
@ -447,6 +455,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
languages,
client,
user_store,
cloud_user_store,
fs,
node_runtime,
prompt_builder,

View file

@ -221,6 +221,7 @@ impl ExampleInstance {
let prompt_store = None;
let thread_store = ThreadStore::load(
project.clone(),
app_state.cloud_user_store.clone(),
tools,
prompt_store,
app_state.prompt_builder.clone(),

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use ::settings::{Settings, SettingsStore};
use client::{Client, UserStore};
use client::{Client, CloudUserStore, UserStore};
use collections::HashSet;
use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
@ -26,11 +26,22 @@ use crate::provider::vercel::VercelLanguageModelProvider;
use crate::provider::x_ai::XAiLanguageModelProvider;
pub use crate::settings::*;
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
pub fn init(
user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
client: Arc<Client>,
cx: &mut App,
) {
crate::settings::init_settings(cx);
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| {
register_language_model_providers(registry, user_store, client.clone(), cx);
register_language_model_providers(
registry,
user_store,
cloud_user_store,
client.clone(),
cx,
);
});
let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
@ -100,11 +111,17 @@ fn register_openai_compatible_providers(
fn register_language_model_providers(
registry: &mut LanguageModelRegistry,
user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
client: Arc<Client>,
cx: &mut Context<LanguageModelRegistry>,
) {
registry.register_provider(
CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
CloudLanguageModelProvider::new(
user_store.clone(),
cloud_user_store.clone(),
client.clone(),
cx,
),
cx,
);

View file

@ -2,11 +2,11 @@ use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use client::{Client, CloudUserStore, ModelRequestUsage, UserStore, zed_urls};
use cloud_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan,
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
@ -27,7 +27,6 @@ use language_model::{
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
@ -118,6 +117,7 @@ pub struct State {
client: Arc<Client>,
llm_api_token: LlmApiToken,
user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
status: client::Status,
accept_terms_of_service_task: Option<Task<Result<()>>>,
models: Vec<Arc<cloud_llm_client::LanguageModel>>,
@ -133,6 +133,7 @@ impl State {
fn new(
client: Arc<Client>,
user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
status: client::Status,
cx: &mut Context<Self>,
) -> Self {
@ -142,6 +143,7 @@ impl State {
client: client.clone(),
llm_api_token: LlmApiToken::default(),
user_store,
cloud_user_store,
status,
accept_terms_of_service_task: None,
models: Vec::new(),
@ -150,12 +152,19 @@ impl State {
recommended_models: Vec::new(),
_fetch_models_task: cx.spawn(async move |this, cx| {
maybe!(async move {
let (client, llm_api_token) = this
.read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
let (client, cloud_user_store, llm_api_token) =
this.read_with(cx, |this, _cx| {
(
client.clone(),
this.cloud_user_store.clone(),
this.llm_api_token.clone(),
)
})?;
loop {
let status = this.read_with(cx, |this, _cx| this.status)?;
if matches!(status, client::Status::Connected { .. }) {
let is_authenticated =
cloud_user_store.read_with(cx, |this, _cx| this.is_authenticated())?;
if is_authenticated {
break;
}
@ -194,8 +203,8 @@ impl State {
}
}
fn is_signed_out(&self) -> bool {
self.status.is_signed_out()
fn is_signed_out(&self, cx: &App) -> bool {
!self.cloud_user_store.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
@ -210,10 +219,7 @@ impl State {
}
fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
self.user_store
.read(cx)
.current_user_has_accepted_terms()
.unwrap_or(false)
self.cloud_user_store.read(cx).has_accepted_tos()
}
fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
@ -297,11 +303,24 @@ impl State {
}
impl CloudLanguageModelProvider {
pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
pub fn new(
user_store: Entity<UserStore>,
cloud_user_store: Entity<CloudUserStore>,
client: Arc<Client>,
cx: &mut App,
) -> Self {
let mut status_rx = client.status();
let status = *status_rx.borrow();
let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
let state = cx.new(|cx| {
State::new(
client.clone(),
user_store.clone(),
cloud_user_store.clone(),
status,
cx,
)
});
let state_ref = state.downgrade();
let maintain_client_status = cx.spawn(async move |cx| {
@ -398,7 +417,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
fn is_authenticated(&self, cx: &App) -> bool {
let state = self.state.read(cx);
!state.is_signed_out() && state.has_accepted_terms_of_service(cx)
!state.is_signed_out(cx) && state.has_accepted_terms_of_service(cx)
}
fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
@ -614,9 +633,9 @@ impl CloudLanguageModel {
.and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok())
{
let plan = match plan {
cloud_llm_client::Plan::ZedFree => Plan::Free,
cloud_llm_client::Plan::ZedPro => Plan::ZedPro,
cloud_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
};
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
}
@ -1118,7 +1137,7 @@ fn response_lines<T: DeserializeOwned>(
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
is_connected: bool,
plan: Option<proto::Plan>,
plan: Option<Plan>,
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
eligible_for_trial: bool,
has_accepted_terms_of_service: bool,
@ -1132,15 +1151,15 @@ impl RenderOnce for ZedAiConfiguration {
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
let young_account_banner = YoungAccountBanner;
let is_pro = self.plan == Some(proto::Plan::ZedPro);
let is_pro = self.plan == Some(Plan::ZedPro);
let subscription_text = match (self.plan, self.subscription_period) {
(Some(proto::Plan::ZedPro), Some(_)) => {
(Some(Plan::ZedPro), Some(_)) => {
"You have access to Zed's hosted models through your Pro subscription."
}
(Some(proto::Plan::ZedProTrial), Some(_)) => {
(Some(Plan::ZedProTrial), Some(_)) => {
"You have access to Zed's hosted models through your Pro trial."
}
(Some(proto::Plan::Free), Some(_)) => {
(Some(Plan::ZedFree), Some(_)) => {
"You have basic access to Zed's hosted models through the Free plan."
}
_ => {
@ -1262,15 +1281,15 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let state = self.state.read(cx);
let user_store = state.user_store.read(cx);
let cloud_user_store = state.cloud_user_store.read(cx);
ZedAiConfiguration {
is_connected: !state.is_signed_out(),
plan: user_store.current_plan(),
subscription_period: user_store.subscription_period(),
eligible_for_trial: user_store.trial_started_at().is_none(),
is_connected: !state.is_signed_out(cx),
plan: cloud_user_store.plan(),
subscription_period: cloud_user_store.subscription_period(),
eligible_for_trial: cloud_user_store.trial_started_at().is_none(),
has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
account_too_young: user_store.account_too_young(),
account_too_young: cloud_user_store.account_too_young(),
accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
sign_in_callback: self.sign_in_callback.clone(),
@ -1286,7 +1305,7 @@ impl Component for ZedAiConfiguration {
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
fn configuration(
is_connected: bool,
plan: Option<proto::Plan>,
plan: Option<Plan>,
eligible_for_trial: bool,
account_too_young: bool,
has_accepted_terms_of_service: bool,
@ -1330,15 +1349,15 @@ impl Component for ZedAiConfiguration {
),
single_example(
"Free Plan",
configuration(true, Some(proto::Plan::Free), true, false, true),
configuration(true, Some(Plan::ZedFree), true, false, true),
),
single_example(
"Zed Pro Trial Plan",
configuration(true, Some(proto::Plan::ZedProTrial), true, false, true),
configuration(true, Some(Plan::ZedProTrial), true, false, true),
),
single_example(
"Zed Pro Plan",
configuration(true, Some(proto::Plan::ZedPro), true, false, true),
configuration(true, Some(Plan::ZedPro), true, false, true),
),
])
.into_any_element(),

View file

@ -556,7 +556,12 @@ pub fn main() {
);
supermaven::init(app_state.client.clone(), cx);
language_model::init(app_state.client.clone(), cx);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
language_models::init(
app_state.user_store.clone(),
app_state.cloud_user_store.clone(),
app_state.client.clone(),
cx,
);
agent_settings::init(cx);
agent_servers::init(cx);
web_search::init(cx);

View file

@ -4488,7 +4488,12 @@ mod tests {
);
image_viewer::init(cx);
language_model::init(app_state.client.clone(), cx);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
language_models::init(
app_state.user_store.clone(),
app_state.cloud_user_store.clone(),
app_state.client.clone(),
cx,
);
web_search::init(cx);
web_search_providers::init(app_state.client.clone(), cx);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);

View file

@ -17,9 +17,10 @@ pub fn load_preview_thread_store(
cx: &mut AsyncApp,
) -> Task<Result<Entity<ThreadStore>>> {
workspace
.update(cx, |_, cx| {
.update(cx, |workspace, cx| {
ThreadStore::load(
project.clone(),
workspace.app_state().cloud_user_store.clone(),
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),