From 7e801dccb0e7d2296120dbf277f63ecfecd223d3 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Fri, 20 Jun 2025 15:28:48 -0600 Subject: [PATCH] agent: Fix issues with usage display sometimes showing initially fetched usage (#33125) Having `Thread::last_usage` as an override of the initially fetched usage could cause the initial usage to be displayed when the current thread is empty or in text threads. Fix is to just store last usage info in `UserStore` and not have these overrides Release Notes: - Agent: Fixed request usage display to always include the most recently known usage - there were some cases where it would show the initially requested usage. --- Cargo.lock | 3 + crates/agent/src/agent_panel.rs | 26 +--- crates/agent/src/debug.rs | 10 +- crates/agent/src/message_editor.rs | 23 +-- crates/agent/src/thread.rs | 39 ++--- crates/agent/src/ui/preview/usage_callouts.rs | 27 ++-- crates/client/Cargo.toml | 2 + crates/client/src/user.rs | 142 +++++++++++++++--- crates/inline_completion/Cargo.toml | 3 +- .../src/inline_completion.rs | 40 +---- crates/language_model/src/language_model.rs | 35 +---- crates/language_models/src/provider/cloud.rs | 10 +- crates/zeta/src/zeta.rs | 39 ++--- 13 files changed, 188 insertions(+), 211 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d3a623f7ef..bfa9b2396e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2821,6 +2821,7 @@ dependencies = [ "cocoa 0.26.0", "collections", "credentials_provider", + "derive_more", "feature_flags", "fs", "futures 0.3.31", @@ -2859,6 +2860,7 @@ dependencies = [ "windows 0.61.1", "workspace-hack", "worktree", + "zed_llm_client", ] [[package]] @@ -8159,6 +8161,7 @@ name = "inline_completion" version = "0.1.0" dependencies = [ "anyhow", + "client", "gpui", "language", "project", diff --git a/crates/agent/src/agent_panel.rs b/crates/agent/src/agent_panel.rs index ee76045d21..10c2db37c4 100644 --- a/crates/agent/src/agent_panel.rs +++ b/crates/agent/src/agent_panel.rs @@ -29,8 +29,7 @@ use gpui::{ }; use language::LanguageRegistry; use language_model::{ - ConfigurationError, LanguageModelProviderTosView, LanguageModelRegistry, RequestUsage, - ZED_CLOUD_PROVIDER_ID, + ConfigurationError, LanguageModelProviderTosView, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID, }; use project::{Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; @@ -45,7 +44,7 @@ use ui::{ Banner, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu, PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*, }; -use util::{ResultExt as _, maybe}; +use util::ResultExt as _; use workspace::dock::{DockPosition, Panel, PanelEvent}; use workspace::{ CollaboratorId, DraggedSelection, DraggedTab, ToggleZoom, ToolbarItemView, Workspace, @@ -1682,24 +1681,7 @@ impl AgentPanel { let thread_id = thread.id().clone(); let is_empty = active_thread.is_empty(); let editor_empty = self.message_editor.read(cx).is_editor_fully_empty(cx); - let last_usage = active_thread.thread().read(cx).last_usage().or_else(|| { - maybe!({ - let amount = user_store.model_request_usage_amount()?; - let limit = user_store.model_request_usage_limit()?.variant?; - - Some(RequestUsage { - amount: amount as i32, - limit: match limit { - proto::usage_limit::Variant::Limited(limited) => { - zed_llm_client::UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => { - zed_llm_client::UsageLimit::Unlimited - } - }, - }) - }) - }); + let usage = user_store.model_request_usage(); let account_url = zed_urls::account_url(cx); @@ -1820,7 +1802,7 @@ impl AgentPanel { .action("Add Custom Server…", Box::new(AddContextServer)) .separator(); - if let Some(usage) = last_usage { + if let Some(usage) = usage { menu = menu .header_with_link("Prompt Usage", "Manage", account_url.clone()) .custom_entry( diff --git a/crates/agent/src/debug.rs b/crates/agent/src/debug.rs index 7bd52e5a96..ff6538dc85 100644 --- a/crates/agent/src/debug.rs +++ b/crates/agent/src/debug.rs @@ -1,7 +1,7 @@ #![allow(unused, dead_code)] +use client::{ModelRequestUsage, RequestUsage}; use gpui::Global; -use language_model::RequestUsage; use std::ops::{Deref, DerefMut}; use ui::prelude::*; use zed_llm_client::{Plan, UsageLimit}; @@ -17,7 +17,7 @@ pub struct DebugAccountState { pub enabled: bool, pub trial_expired: bool, pub plan: Plan, - pub custom_prompt_usage: RequestUsage, + pub custom_prompt_usage: ModelRequestUsage, pub usage_based_billing_enabled: bool, pub monthly_spending_cap: i32, pub custom_edit_prediction_usage: UsageLimit, @@ -43,7 +43,7 @@ impl DebugAccountState { self } - pub fn set_custom_prompt_usage(&mut self, custom_prompt_usage: RequestUsage) -> &mut Self { + pub fn set_custom_prompt_usage(&mut self, custom_prompt_usage: ModelRequestUsage) -> &mut Self { self.custom_prompt_usage = custom_prompt_usage; self } @@ -76,10 +76,10 @@ impl Default for DebugAccountState { enabled: false, trial_expired: false, plan: Plan::ZedFree, - custom_prompt_usage: RequestUsage { + custom_prompt_usage: ModelRequestUsage(RequestUsage { limit: UsageLimit::Unlimited, amount: 0, - }, + }), usage_based_billing_enabled: false, // $50.00 monthly_spending_cap: 5000, diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index c8d127aa28..ec0a01e8af 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -29,8 +29,7 @@ use gpui::{ }; use language::{Buffer, Language, Point}; use language_model::{ - ConfiguredModel, LanguageModelRequestMessage, MessageContent, RequestUsage, - ZED_CLOUD_PROVIDER_ID, + ConfiguredModel, LanguageModelRequestMessage, MessageContent, ZED_CLOUD_PROVIDER_ID, }; use multi_buffer; use project::Project; @@ -42,7 +41,7 @@ use theme::ThemeSettings; use ui::{ Callout, Disclosure, Divider, DividerColor, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*, }; -use util::{ResultExt as _, maybe}; +use util::ResultExt as _; use workspace::{CollaboratorId, Workspace}; use zed_llm_client::CompletionIntent; @@ -1257,24 +1256,8 @@ impl MessageEditor { Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, }) .unwrap_or(zed_llm_client::Plan::ZedFree); - let usage = self.thread.read(cx).last_usage().or_else(|| { - maybe!({ - let amount = user_store.model_request_usage_amount()?; - let limit = user_store.model_request_usage_limit()?.variant?; - Some(RequestUsage { - amount: amount as i32, - limit: match limit { - proto::usage_limit::Variant::Limited(limited) => { - zed_llm_client::UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => { - zed_llm_client::UsageLimit::Unlimited - } - }, - }) - }) - })?; + let usage = user_store.model_request_usage()?; Some( div() diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index edc0ea1152..1a6b9604b5 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -7,6 +7,7 @@ use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; +use client::{ModelRequestUsage, RequestUsage}; use collections::HashMap; use editor::display_map::CreaseMetadata; use feature_flags::{self, FeatureFlagAppExt}; @@ -22,8 +23,8 @@ use language_model::{ LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent, - ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel, - StopReason, TokenUsage, + ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason, + TokenUsage, }; use postage::stream::Stream as _; use project::Project; @@ -38,7 +39,7 @@ use ui::Window; use util::{ResultExt as _, post_inc}; use uuid::Uuid; -use zed_llm_client::{CompletionIntent, CompletionRequestStatus}; +use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; use crate::ThreadStore; use crate::agent_profile::AgentProfile; @@ -350,7 +351,6 @@ pub struct Thread { request_token_usage: Vec, cumulative_token_usage: TokenUsage, exceeded_window_error: Option, - last_usage: Option, tool_use_limit_reached: bool, feedback: Option, message_feedback: HashMap, @@ -443,7 +443,6 @@ impl Thread { request_token_usage: Vec::new(), cumulative_token_usage: TokenUsage::default(), exceeded_window_error: None, - last_usage: None, tool_use_limit_reached: false, feedback: None, message_feedback: HashMap::default(), @@ -568,7 +567,6 @@ impl Thread { request_token_usage: serialized.request_token_usage, cumulative_token_usage: serialized.cumulative_token_usage, exceeded_window_error: None, - last_usage: None, tool_use_limit_reached: serialized.tool_use_limit_reached, feedback: None, message_feedback: HashMap::default(), @@ -875,10 +873,6 @@ impl Thread { .unwrap_or(false) } - pub fn last_usage(&self) -> Option { - self.last_usage - } - pub fn tool_use_limit_reached(&self) -> bool { self.tool_use_limit_reached } @@ -1658,9 +1652,7 @@ impl Thread { CompletionRequestStatus::UsageUpdated { amount, limit } => { - let usage = RequestUsage { limit, amount: amount as i32 }; - - thread.last_usage = Some(usage); + thread.update_model_request_usage(amount as u32, limit, cx); } CompletionRequestStatus::ToolUseLimitReached => { thread.tool_use_limit_reached = true; @@ -1871,11 +1863,8 @@ impl Thread { LanguageModelCompletionEvent::StatusUpdate( CompletionRequestStatus::UsageUpdated { amount, limit }, ) => { - this.update(cx, |thread, _cx| { - thread.last_usage = Some(RequestUsage { - limit, - amount: amount as i32, - }); + this.update(cx, |thread, cx| { + thread.update_model_request_usage(amount as u32, limit, cx); })?; continue; } @@ -2757,6 +2746,20 @@ impl Thread { } } + fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context) { + self.project.update(cx, |project, cx| { + project.user_store().update(cx, |user_store, cx| { + user_store.update_model_request_usage( + ModelRequestUsage(RequestUsage { + amount: amount as i32, + limit, + }), + cx, + ) + }) + }); + } + pub fn deny_tool_use( &mut self, tool_use_id: LanguageModelToolUseId, diff --git a/crates/agent/src/ui/preview/usage_callouts.rs b/crates/agent/src/ui/preview/usage_callouts.rs index 62e2909461..45af41395b 100644 --- a/crates/agent/src/ui/preview/usage_callouts.rs +++ b/crates/agent/src/ui/preview/usage_callouts.rs @@ -1,18 +1,17 @@ -use client::zed_urls; +use client::{ModelRequestUsage, RequestUsage, zed_urls}; use component::{empty_example, example_group_with_title, single_example}; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; -use language_model::RequestUsage; use ui::{Callout, prelude::*}; use zed_llm_client::{Plan, UsageLimit}; #[derive(IntoElement, RegisterComponent)] pub struct UsageCallout { plan: Plan, - usage: RequestUsage, + usage: ModelRequestUsage, } impl UsageCallout { - pub fn new(plan: Plan, usage: RequestUsage) -> Self { + pub fn new(plan: Plan, usage: ModelRequestUsage) -> Self { Self { plan, usage } } } @@ -128,10 +127,10 @@ impl Component for UsageCallout { "Approaching limit (90%)", UsageCallout::new( Plan::ZedFree, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(50), amount: 45, // 90% of limit - }, + }), ) .into_any_element(), ), @@ -139,10 +138,10 @@ impl Component for UsageCallout { "Limit reached (100%)", UsageCallout::new( Plan::ZedFree, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(50), amount: 50, // 100% of limit - }, + }), ) .into_any_element(), ), @@ -156,10 +155,10 @@ impl Component for UsageCallout { "Approaching limit (90%)", UsageCallout::new( Plan::ZedProTrial, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(150), amount: 135, // 90% of limit - }, + }), ) .into_any_element(), ), @@ -167,10 +166,10 @@ impl Component for UsageCallout { "Limit reached (100%)", UsageCallout::new( Plan::ZedProTrial, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(150), amount: 150, // 100% of limit - }, + }), ) .into_any_element(), ), @@ -184,10 +183,10 @@ impl Component for UsageCallout { "Limit reached (100%)", UsageCallout::new( Plan::ZedPro, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(500), amount: 500, // 100% of limit - }, + }), ) .into_any_element(), ), diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 0d65b7ef21..b741f515fd 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -24,6 +24,7 @@ chrono = { workspace = true, features = ["serde"] } clock.workspace = true collections.workspace = true credentials_provider.workspace = true +derive_more.workspace = true feature_flags.workspace = true futures.workspace = true gpui.workspace = true @@ -57,6 +58,7 @@ worktree.workspace = true telemetry.workspace = true tokio.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true [dev-dependencies] clock = { workspace = true, features = ["test-support"] } diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 1be8d71e85..61e3064eb4 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -2,16 +2,25 @@ use super::{Client, Status, TypedEnvelope, proto}; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use collections::{HashMap, HashSet, hash_map::Entry}; +use derive_more::Deref; use feature_flags::FeatureFlagAppExt; use futures::{Future, StreamExt, channel::mpsc}; use gpui::{ App, AsyncApp, Context, Entity, EventEmitter, SharedString, SharedUri, Task, WeakEntity, }; +use http_client::http::{HeaderMap, HeaderValue}; use postage::{sink::Sink, watch}; use rpc::proto::{RequestMessage, UsersResponse}; -use std::sync::{Arc, Weak}; +use std::{ + str::FromStr as _, + sync::{Arc, Weak}, +}; use text::ReplicaId; use util::{TryFutureExt as _, maybe}; +use zed_llm_client::{ + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, + MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, +}; pub type UserId = u64; @@ -104,10 +113,8 @@ pub struct UserStore { current_plan: Option, subscription_period: Option<(DateTime, DateTime)>, trial_started_at: Option>, - model_request_usage_amount: Option, - model_request_usage_limit: Option, - edit_predictions_usage_amount: Option, - edit_predictions_usage_limit: Option, + model_request_usage: Option, + edit_prediction_usage: Option, is_usage_based_billing_enabled: Option, account_too_young: Option, has_overdue_invoices: Option, @@ -155,6 +162,18 @@ enum UpdateContacts { Clear(postage::barrier::Sender), } +#[derive(Debug, Clone, Copy, Deref)] +pub struct ModelRequestUsage(pub RequestUsage); + +#[derive(Debug, Clone, Copy, Deref)] +pub struct EditPredictionUsage(pub RequestUsage); + +#[derive(Debug, Clone, Copy)] +pub struct RequestUsage { + pub limit: UsageLimit, + pub amount: i32, +} + impl UserStore { pub fn new(client: Arc, cx: &Context) -> Self { let (mut current_user_tx, current_user_rx) = watch::channel(); @@ -172,10 +191,8 @@ impl UserStore { current_plan: None, subscription_period: None, trial_started_at: None, - model_request_usage_amount: None, - model_request_usage_limit: None, - edit_predictions_usage_amount: None, - edit_predictions_usage_limit: None, + model_request_usage: None, + edit_prediction_usage: None, is_usage_based_billing_enabled: None, account_too_young: None, has_overdue_invoices: None, @@ -356,10 +373,19 @@ impl UserStore { this.has_overdue_invoices = message.payload.has_overdue_invoices; if let Some(usage) = message.payload.usage { - this.model_request_usage_amount = Some(usage.model_requests_usage_amount); - this.model_request_usage_limit = usage.model_requests_usage_limit; - this.edit_predictions_usage_amount = Some(usage.edit_predictions_usage_amount); - this.edit_predictions_usage_limit = usage.edit_predictions_usage_limit; + // limits are always present even though they are wrapped in Option + this.model_request_usage = usage + .model_requests_usage_limit + .and_then(|limit| { + RequestUsage::from_proto(usage.model_requests_usage_amount, limit) + }) + .map(ModelRequestUsage); + this.edit_prediction_usage = usage + .edit_predictions_usage_limit + .and_then(|limit| { + RequestUsage::from_proto(usage.model_requests_usage_amount, limit) + }) + .map(EditPredictionUsage); } cx.notify(); @@ -367,6 +393,20 @@ impl UserStore { Ok(()) } + pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { + self.model_request_usage = Some(usage); + cx.notify(); + } + + pub fn update_edit_prediction_usage( + &mut self, + usage: EditPredictionUsage, + cx: &mut Context, + ) { + self.edit_prediction_usage = Some(usage); + cx.notify(); + } + fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { match message { UpdateContacts::Wait(barrier) => { @@ -739,20 +779,12 @@ impl UserStore { self.is_usage_based_billing_enabled } - pub fn model_request_usage_amount(&self) -> Option { - self.model_request_usage_amount + pub fn model_request_usage(&self) -> Option { + self.model_request_usage } - pub fn model_request_usage_limit(&self) -> Option { - self.model_request_usage_limit.clone() - } - - pub fn edit_predictions_usage_amount(&self) -> Option { - self.edit_predictions_usage_amount - } - - pub fn edit_predictions_usage_limit(&self) -> Option { - self.edit_predictions_usage_limit.clone() + pub fn edit_prediction_usage(&self) -> Option { + self.edit_prediction_usage } pub fn watch_current_user(&self) -> watch::Receiver>> { @@ -917,3 +949,63 @@ impl Collaborator { }) } } + +impl RequestUsage { + pub fn over_limit(&self) -> bool { + match self.limit { + UsageLimit::Limited(limit) => self.amount >= limit, + UsageLimit::Unlimited => false, + } + } + + pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option { + let limit = match limit.variant? { + proto::usage_limit::Variant::Limited(limited) => { + UsageLimit::Limited(limited.limit as i32) + } + proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited, + }; + Some(RequestUsage { + limit, + amount: amount as i32, + }) + } + + fn from_headers( + limit_name: &str, + amount_name: &str, + headers: &HeaderMap, + ) -> Result { + let limit = headers + .get(limit_name) + .with_context(|| format!("missing {limit_name:?} header"))?; + let limit = UsageLimit::from_str(limit.to_str()?)?; + + let amount = headers + .get(amount_name) + .with_context(|| format!("missing {amount_name:?} header"))?; + let amount = amount.to_str()?.parse::()?; + + Ok(Self { limit, amount }) + } +} + +impl ModelRequestUsage { + pub fn from_headers(headers: &HeaderMap) -> Result { + Ok(Self(RequestUsage::from_headers( + MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, + MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, + headers, + )?)) + } +} + +impl EditPredictionUsage { + pub fn from_headers(headers: &HeaderMap) -> Result { + Ok(Self(RequestUsage::from_headers( + EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, + headers, + )?)) + } +} diff --git a/crates/inline_completion/Cargo.toml b/crates/inline_completion/Cargo.toml index 0094385e16..3a90875def 100644 --- a/crates/inline_completion/Cargo.toml +++ b/crates/inline_completion/Cargo.toml @@ -12,9 +12,8 @@ workspace = true path = "src/inline_completion.rs" [dependencies] -anyhow.workspace = true +client.workspace = true gpui.workspace = true language.workspace = true project.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true diff --git a/crates/inline_completion/src/inline_completion.rs b/crates/inline_completion/src/inline_completion.rs index 7acfea72b2..c8f35bf16a 100644 --- a/crates/inline_completion/src/inline_completion.rs +++ b/crates/inline_completion/src/inline_completion.rs @@ -1,14 +1,9 @@ use std::ops::Range; -use std::str::FromStr as _; -use anyhow::{Context as _, Result}; -use gpui::http_client::http::{HeaderMap, HeaderValue}; +use client::EditPredictionUsage; use gpui::{App, Context, Entity, SharedString}; use language::Buffer; use project::Project; -use zed_llm_client::{ - EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, UsageLimit, -}; // TODO: Find a better home for `Direction`. // @@ -59,39 +54,6 @@ impl DataCollectionState { } } -#[derive(Debug, Clone, Copy)] -pub struct EditPredictionUsage { - pub limit: UsageLimit, - pub amount: i32, -} - -impl EditPredictionUsage { - pub fn from_headers(headers: &HeaderMap) -> Result { - let limit = headers - .get(EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME) - .with_context(|| { - format!("missing {EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME:?} header") - })?; - let limit = UsageLimit::from_str(limit.to_str()?)?; - - let amount = headers - .get(EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME) - .with_context(|| { - format!("missing {EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME:?} header") - })?; - let amount = amount.to_str()?.parse::()?; - - Ok(Self { limit, amount }) - } - - pub fn over_limit(&self) -> bool { - match self.limit { - UsageLimit::Limited(limit) => self.amount >= limit, - UsageLimit::Unlimited => false, - } - } -} - pub trait EditPredictionProvider: 'static + Sized { fn name() -> &'static str; fn display_name() -> &'static str; diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index c411593213..900d7f6f39 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -8,27 +8,22 @@ mod telemetry; #[cfg(any(test, feature = "test-support"))] pub mod fake_provider; -use anyhow::{Context as _, Result}; +use anyhow::Result; use client::Client; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; -use http_client::http::{HeaderMap, HeaderValue}; use icons::IconName; use parking_lot::Mutex; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::fmt; use std::ops::{Add, Sub}; -use std::str::FromStr as _; use std::sync::Arc; use std::time::Duration; use thiserror::Error; use util::serde::is_default; -use zed_llm_client::{ - CompletionRequestStatus, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, - MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, -}; +use zed_llm_client::CompletionRequestStatus; pub use crate::model::*; pub use crate::rate_limiter::*; @@ -106,32 +101,6 @@ pub enum StopReason { Refusal, } -#[derive(Debug, Clone, Copy)] -pub struct RequestUsage { - pub limit: UsageLimit, - pub amount: i32, -} - -impl RequestUsage { - pub fn from_headers(headers: &HeaderMap) -> Result { - let limit = headers - .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME) - .with_context(|| { - format!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header") - })?; - let limit = UsageLimit::from_str(limit.to_str()?)?; - - let amount = headers - .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME) - .with_context(|| { - format!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header") - })?; - let amount = amount.to_str()?.parse::()?; - - Ok(Self { limit, amount }) - } -} - #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)] pub struct TokenUsage { #[serde(default, skip_serializing_if = "is_default")] diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 59a5537ae9..1062d732a4 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,6 +1,6 @@ use anthropic::{AnthropicModelMode, parse_prompt_too_long}; use anyhow::{Context as _, Result, anyhow}; -use client::{Client, UserStore, zed_urls}; +use client::{Client, ModelRequestUsage, UserStore, zed_urls}; use futures::{ AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream, }; @@ -14,7 +14,7 @@ use language_model::{ LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice, - LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage, + LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, ZED_CLOUD_PROVIDER_ID, }; use language_model::{ @@ -530,7 +530,7 @@ pub struct CloudLanguageModel { struct PerformLlmCompletionResponse { response: Response, - usage: Option, + usage: Option, tool_use_limit_reached: bool, includes_status_messages: bool, } @@ -581,7 +581,7 @@ impl CloudLanguageModel { let usage = if includes_status_messages { None } else { - RequestUsage::from_headers(response.headers()).ok() + ModelRequestUsage::from_headers(response.headers()).ok() }; return Ok(PerformLlmCompletionResponse { @@ -1002,7 +1002,7 @@ where } fn usage_updated_event( - usage: Option, + usage: Option, ) -> impl Stream>> { futures::stream::iter(usage.map(|usage| { Ok(CloudCompletionEvent::Status( diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 23ce320ee9..4d643c9db0 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -9,14 +9,14 @@ mod rate_completion_modal; pub(crate) use completion_diff_element::*; use db::kvp::KEY_VALUE_STORE; pub use init::*; -use inline_completion::{DataCollectionState, EditPredictionUsage}; +use inline_completion::DataCollectionState; use license_detection::LICENSE_FILES_TO_CHECK; pub use license_detection::is_license_eligible_for_data_collection; pub use rate_completion_modal::*; use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; -use client::{Client, UserStore}; +use client::{Client, EditPredictionUsage, UserStore}; use collections::{HashMap, HashSet, VecDeque}; use futures::AsyncReadExt; use gpui::{ @@ -48,7 +48,7 @@ use std::{ }; use telemetry_events::InlineCompletionRating; use thiserror::Error; -use util::{ResultExt, maybe}; +use util::ResultExt; use uuid::Uuid; use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId}; @@ -188,7 +188,6 @@ pub struct Zeta { data_collection_choice: Entity, llm_token: LlmApiToken, _llm_token_subscription: Subscription, - last_usage: Option, /// Whether the terms of service have been accepted. tos_accepted: bool, /// Whether an update to a newer version of Zed is required to continue using Zeta. @@ -234,25 +233,7 @@ impl Zeta { } pub fn usage(&self, cx: &App) -> Option { - self.last_usage.or_else(|| { - let user_store = self.user_store.read(cx); - maybe!({ - let amount = user_store.edit_predictions_usage_amount()?; - let limit = user_store.edit_predictions_usage_limit()?.variant?; - - Some(EditPredictionUsage { - amount: amount as i32, - limit: match limit { - proto::usage_limit::Variant::Limited(limited) => { - zed_llm_client::UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => { - zed_llm_client::UsageLimit::Unlimited - } - }, - }) - }) - }) + self.user_store.read(cx).edit_prediction_usage() } fn new( @@ -287,7 +268,6 @@ impl Zeta { .detach_and_log_err(cx); }, ), - last_usage: None, tos_accepted: user_store .read(cx) .current_user_has_accepted_terms() @@ -533,8 +513,10 @@ impl Zeta { log::debug!("completion response: {}", &response.output_excerpt); if let Some(usage) = usage { - this.update(cx, |this, _cx| { - this.last_usage = Some(usage); + this.update(cx, |this, cx| { + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); }) .ok(); } @@ -874,8 +856,9 @@ and then another if response.status().is_success() { if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() { this.update(cx, |this, cx| { - this.last_usage = Some(usage); - cx.notify(); + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); })?; }