diff --git a/Cargo.lock b/Cargo.lock index d3a623f7ef..bfa9b2396e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2821,6 +2821,7 @@ dependencies = [ "cocoa 0.26.0", "collections", "credentials_provider", + "derive_more", "feature_flags", "fs", "futures 0.3.31", @@ -2859,6 +2860,7 @@ dependencies = [ "windows 0.61.1", "workspace-hack", "worktree", + "zed_llm_client", ] [[package]] @@ -8159,6 +8161,7 @@ name = "inline_completion" version = "0.1.0" dependencies = [ "anyhow", + "client", "gpui", "language", "project", diff --git a/crates/agent/src/agent_panel.rs b/crates/agent/src/agent_panel.rs index ee76045d21..10c2db37c4 100644 --- a/crates/agent/src/agent_panel.rs +++ b/crates/agent/src/agent_panel.rs @@ -29,8 +29,7 @@ use gpui::{ }; use language::LanguageRegistry; use language_model::{ - ConfigurationError, LanguageModelProviderTosView, LanguageModelRegistry, RequestUsage, - ZED_CLOUD_PROVIDER_ID, + ConfigurationError, LanguageModelProviderTosView, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID, }; use project::{Project, ProjectPath, Worktree}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; @@ -45,7 +44,7 @@ use ui::{ Banner, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu, PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*, }; -use util::{ResultExt as _, maybe}; +use util::ResultExt as _; use workspace::dock::{DockPosition, Panel, PanelEvent}; use workspace::{ CollaboratorId, DraggedSelection, DraggedTab, ToggleZoom, ToolbarItemView, Workspace, @@ -1682,24 +1681,7 @@ impl AgentPanel { let thread_id = thread.id().clone(); let is_empty = active_thread.is_empty(); let editor_empty = self.message_editor.read(cx).is_editor_fully_empty(cx); - let last_usage = active_thread.thread().read(cx).last_usage().or_else(|| { - maybe!({ - let amount = user_store.model_request_usage_amount()?; - let limit = user_store.model_request_usage_limit()?.variant?; - - Some(RequestUsage { - amount: amount as i32, - limit: match limit { - proto::usage_limit::Variant::Limited(limited) => { - zed_llm_client::UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => { - zed_llm_client::UsageLimit::Unlimited - } - }, - }) - }) - }); + let usage = user_store.model_request_usage(); let account_url = zed_urls::account_url(cx); @@ -1820,7 +1802,7 @@ impl AgentPanel { .action("Add Custom Server…", Box::new(AddContextServer)) .separator(); - if let Some(usage) = last_usage { + if let Some(usage) = usage { menu = menu .header_with_link("Prompt Usage", "Manage", account_url.clone()) .custom_entry( diff --git a/crates/agent/src/debug.rs b/crates/agent/src/debug.rs index 7bd52e5a96..ff6538dc85 100644 --- a/crates/agent/src/debug.rs +++ b/crates/agent/src/debug.rs @@ -1,7 +1,7 @@ #![allow(unused, dead_code)] +use client::{ModelRequestUsage, RequestUsage}; use gpui::Global; -use language_model::RequestUsage; use std::ops::{Deref, DerefMut}; use ui::prelude::*; use zed_llm_client::{Plan, UsageLimit}; @@ -17,7 +17,7 @@ pub struct DebugAccountState { pub enabled: bool, pub trial_expired: bool, pub plan: Plan, - pub custom_prompt_usage: RequestUsage, + pub custom_prompt_usage: ModelRequestUsage, pub usage_based_billing_enabled: bool, pub monthly_spending_cap: i32, pub custom_edit_prediction_usage: UsageLimit, @@ -43,7 +43,7 @@ impl DebugAccountState { self } - pub fn set_custom_prompt_usage(&mut self, custom_prompt_usage: RequestUsage) -> &mut Self { + pub fn set_custom_prompt_usage(&mut self, custom_prompt_usage: ModelRequestUsage) -> &mut Self { self.custom_prompt_usage = custom_prompt_usage; self } @@ -76,10 +76,10 @@ impl Default for DebugAccountState { enabled: false, trial_expired: false, plan: Plan::ZedFree, - custom_prompt_usage: RequestUsage { + custom_prompt_usage: ModelRequestUsage(RequestUsage { limit: UsageLimit::Unlimited, amount: 0, - }, + }), usage_based_billing_enabled: false, // $50.00 monthly_spending_cap: 5000, diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index c8d127aa28..ec0a01e8af 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -29,8 +29,7 @@ use gpui::{ }; use language::{Buffer, Language, Point}; use language_model::{ - ConfiguredModel, LanguageModelRequestMessage, MessageContent, RequestUsage, - ZED_CLOUD_PROVIDER_ID, + ConfiguredModel, LanguageModelRequestMessage, MessageContent, ZED_CLOUD_PROVIDER_ID, }; use multi_buffer; use project::Project; @@ -42,7 +41,7 @@ use theme::ThemeSettings; use ui::{ Callout, Disclosure, Divider, DividerColor, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*, }; -use util::{ResultExt as _, maybe}; +use util::ResultExt as _; use workspace::{CollaboratorId, Workspace}; use zed_llm_client::CompletionIntent; @@ -1257,24 +1256,8 @@ impl MessageEditor { Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, }) .unwrap_or(zed_llm_client::Plan::ZedFree); - let usage = self.thread.read(cx).last_usage().or_else(|| { - maybe!({ - let amount = user_store.model_request_usage_amount()?; - let limit = user_store.model_request_usage_limit()?.variant?; - Some(RequestUsage { - amount: amount as i32, - limit: match limit { - proto::usage_limit::Variant::Limited(limited) => { - zed_llm_client::UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => { - zed_llm_client::UsageLimit::Unlimited - } - }, - }) - }) - })?; + let usage = user_store.model_request_usage()?; Some( div() diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index edc0ea1152..1a6b9604b5 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -7,6 +7,7 @@ use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; +use client::{ModelRequestUsage, RequestUsage}; use collections::HashMap; use editor::display_map::CreaseMetadata; use feature_flags::{self, FeatureFlagAppExt}; @@ -22,8 +23,8 @@ use language_model::{ LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent, - ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel, - StopReason, TokenUsage, + ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason, + TokenUsage, }; use postage::stream::Stream as _; use project::Project; @@ -38,7 +39,7 @@ use ui::Window; use util::{ResultExt as _, post_inc}; use uuid::Uuid; -use zed_llm_client::{CompletionIntent, CompletionRequestStatus}; +use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; use crate::ThreadStore; use crate::agent_profile::AgentProfile; @@ -350,7 +351,6 @@ pub struct Thread { request_token_usage: Vec, cumulative_token_usage: TokenUsage, exceeded_window_error: Option, - last_usage: Option, tool_use_limit_reached: bool, feedback: Option, message_feedback: HashMap, @@ -443,7 +443,6 @@ impl Thread { request_token_usage: Vec::new(), cumulative_token_usage: TokenUsage::default(), exceeded_window_error: None, - last_usage: None, tool_use_limit_reached: false, feedback: None, message_feedback: HashMap::default(), @@ -568,7 +567,6 @@ impl Thread { request_token_usage: serialized.request_token_usage, cumulative_token_usage: serialized.cumulative_token_usage, exceeded_window_error: None, - last_usage: None, tool_use_limit_reached: serialized.tool_use_limit_reached, feedback: None, message_feedback: HashMap::default(), @@ -875,10 +873,6 @@ impl Thread { .unwrap_or(false) } - pub fn last_usage(&self) -> Option { - self.last_usage - } - pub fn tool_use_limit_reached(&self) -> bool { self.tool_use_limit_reached } @@ -1658,9 +1652,7 @@ impl Thread { CompletionRequestStatus::UsageUpdated { amount, limit } => { - let usage = RequestUsage { limit, amount: amount as i32 }; - - thread.last_usage = Some(usage); + thread.update_model_request_usage(amount as u32, limit, cx); } CompletionRequestStatus::ToolUseLimitReached => { thread.tool_use_limit_reached = true; @@ -1871,11 +1863,8 @@ impl Thread { LanguageModelCompletionEvent::StatusUpdate( CompletionRequestStatus::UsageUpdated { amount, limit }, ) => { - this.update(cx, |thread, _cx| { - thread.last_usage = Some(RequestUsage { - limit, - amount: amount as i32, - }); + this.update(cx, |thread, cx| { + thread.update_model_request_usage(amount as u32, limit, cx); })?; continue; } @@ -2757,6 +2746,20 @@ impl Thread { } } + fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context) { + self.project.update(cx, |project, cx| { + project.user_store().update(cx, |user_store, cx| { + user_store.update_model_request_usage( + ModelRequestUsage(RequestUsage { + amount: amount as i32, + limit, + }), + cx, + ) + }) + }); + } + pub fn deny_tool_use( &mut self, tool_use_id: LanguageModelToolUseId, diff --git a/crates/agent/src/ui/preview/usage_callouts.rs b/crates/agent/src/ui/preview/usage_callouts.rs index 62e2909461..45af41395b 100644 --- a/crates/agent/src/ui/preview/usage_callouts.rs +++ b/crates/agent/src/ui/preview/usage_callouts.rs @@ -1,18 +1,17 @@ -use client::zed_urls; +use client::{ModelRequestUsage, RequestUsage, zed_urls}; use component::{empty_example, example_group_with_title, single_example}; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; -use language_model::RequestUsage; use ui::{Callout, prelude::*}; use zed_llm_client::{Plan, UsageLimit}; #[derive(IntoElement, RegisterComponent)] pub struct UsageCallout { plan: Plan, - usage: RequestUsage, + usage: ModelRequestUsage, } impl UsageCallout { - pub fn new(plan: Plan, usage: RequestUsage) -> Self { + pub fn new(plan: Plan, usage: ModelRequestUsage) -> Self { Self { plan, usage } } } @@ -128,10 +127,10 @@ impl Component for UsageCallout { "Approaching limit (90%)", UsageCallout::new( Plan::ZedFree, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(50), amount: 45, // 90% of limit - }, + }), ) .into_any_element(), ), @@ -139,10 +138,10 @@ impl Component for UsageCallout { "Limit reached (100%)", UsageCallout::new( Plan::ZedFree, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(50), amount: 50, // 100% of limit - }, + }), ) .into_any_element(), ), @@ -156,10 +155,10 @@ impl Component for UsageCallout { "Approaching limit (90%)", UsageCallout::new( Plan::ZedProTrial, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(150), amount: 135, // 90% of limit - }, + }), ) .into_any_element(), ), @@ -167,10 +166,10 @@ impl Component for UsageCallout { "Limit reached (100%)", UsageCallout::new( Plan::ZedProTrial, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(150), amount: 150, // 100% of limit - }, + }), ) .into_any_element(), ), @@ -184,10 +183,10 @@ impl Component for UsageCallout { "Limit reached (100%)", UsageCallout::new( Plan::ZedPro, - RequestUsage { + ModelRequestUsage(RequestUsage { limit: UsageLimit::Limited(500), amount: 500, // 100% of limit - }, + }), ) .into_any_element(), ), diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 0d65b7ef21..b741f515fd 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -24,6 +24,7 @@ chrono = { workspace = true, features = ["serde"] } clock.workspace = true collections.workspace = true credentials_provider.workspace = true +derive_more.workspace = true feature_flags.workspace = true futures.workspace = true gpui.workspace = true @@ -57,6 +58,7 @@ worktree.workspace = true telemetry.workspace = true tokio.workspace = true workspace-hack.workspace = true +zed_llm_client.workspace = true [dev-dependencies] clock = { workspace = true, features = ["test-support"] } diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 1be8d71e85..61e3064eb4 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -2,16 +2,25 @@ use super::{Client, Status, TypedEnvelope, proto}; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use collections::{HashMap, HashSet, hash_map::Entry}; +use derive_more::Deref; use feature_flags::FeatureFlagAppExt; use futures::{Future, StreamExt, channel::mpsc}; use gpui::{ App, AsyncApp, Context, Entity, EventEmitter, SharedString, SharedUri, Task, WeakEntity, }; +use http_client::http::{HeaderMap, HeaderValue}; use postage::{sink::Sink, watch}; use rpc::proto::{RequestMessage, UsersResponse}; -use std::sync::{Arc, Weak}; +use std::{ + str::FromStr as _, + sync::{Arc, Weak}, +}; use text::ReplicaId; use util::{TryFutureExt as _, maybe}; +use zed_llm_client::{ + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, + MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, +}; pub type UserId = u64; @@ -104,10 +113,8 @@ pub struct UserStore { current_plan: Option, subscription_period: Option<(DateTime, DateTime)>, trial_started_at: Option>, - model_request_usage_amount: Option, - model_request_usage_limit: Option, - edit_predictions_usage_amount: Option, - edit_predictions_usage_limit: Option, + model_request_usage: Option, + edit_prediction_usage: Option, is_usage_based_billing_enabled: Option, account_too_young: Option, has_overdue_invoices: Option, @@ -155,6 +162,18 @@ enum UpdateContacts { Clear(postage::barrier::Sender), } +#[derive(Debug, Clone, Copy, Deref)] +pub struct ModelRequestUsage(pub RequestUsage); + +#[derive(Debug, Clone, Copy, Deref)] +pub struct EditPredictionUsage(pub RequestUsage); + +#[derive(Debug, Clone, Copy)] +pub struct RequestUsage { + pub limit: UsageLimit, + pub amount: i32, +} + impl UserStore { pub fn new(client: Arc, cx: &Context) -> Self { let (mut current_user_tx, current_user_rx) = watch::channel(); @@ -172,10 +191,8 @@ impl UserStore { current_plan: None, subscription_period: None, trial_started_at: None, - model_request_usage_amount: None, - model_request_usage_limit: None, - edit_predictions_usage_amount: None, - edit_predictions_usage_limit: None, + model_request_usage: None, + edit_prediction_usage: None, is_usage_based_billing_enabled: None, account_too_young: None, has_overdue_invoices: None, @@ -356,10 +373,19 @@ impl UserStore { this.has_overdue_invoices = message.payload.has_overdue_invoices; if let Some(usage) = message.payload.usage { - this.model_request_usage_amount = Some(usage.model_requests_usage_amount); - this.model_request_usage_limit = usage.model_requests_usage_limit; - this.edit_predictions_usage_amount = Some(usage.edit_predictions_usage_amount); - this.edit_predictions_usage_limit = usage.edit_predictions_usage_limit; + // limits are always present even though they are wrapped in Option + this.model_request_usage = usage + .model_requests_usage_limit + .and_then(|limit| { + RequestUsage::from_proto(usage.model_requests_usage_amount, limit) + }) + .map(ModelRequestUsage); + this.edit_prediction_usage = usage + .edit_predictions_usage_limit + .and_then(|limit| { + RequestUsage::from_proto(usage.model_requests_usage_amount, limit) + }) + .map(EditPredictionUsage); } cx.notify(); @@ -367,6 +393,20 @@ impl UserStore { Ok(()) } + pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context) { + self.model_request_usage = Some(usage); + cx.notify(); + } + + pub fn update_edit_prediction_usage( + &mut self, + usage: EditPredictionUsage, + cx: &mut Context, + ) { + self.edit_prediction_usage = Some(usage); + cx.notify(); + } + fn update_contacts(&mut self, message: UpdateContacts, cx: &Context) -> Task> { match message { UpdateContacts::Wait(barrier) => { @@ -739,20 +779,12 @@ impl UserStore { self.is_usage_based_billing_enabled } - pub fn model_request_usage_amount(&self) -> Option { - self.model_request_usage_amount + pub fn model_request_usage(&self) -> Option { + self.model_request_usage } - pub fn model_request_usage_limit(&self) -> Option { - self.model_request_usage_limit.clone() - } - - pub fn edit_predictions_usage_amount(&self) -> Option { - self.edit_predictions_usage_amount - } - - pub fn edit_predictions_usage_limit(&self) -> Option { - self.edit_predictions_usage_limit.clone() + pub fn edit_prediction_usage(&self) -> Option { + self.edit_prediction_usage } pub fn watch_current_user(&self) -> watch::Receiver>> { @@ -917,3 +949,63 @@ impl Collaborator { }) } } + +impl RequestUsage { + pub fn over_limit(&self) -> bool { + match self.limit { + UsageLimit::Limited(limit) => self.amount >= limit, + UsageLimit::Unlimited => false, + } + } + + pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option { + let limit = match limit.variant? { + proto::usage_limit::Variant::Limited(limited) => { + UsageLimit::Limited(limited.limit as i32) + } + proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited, + }; + Some(RequestUsage { + limit, + amount: amount as i32, + }) + } + + fn from_headers( + limit_name: &str, + amount_name: &str, + headers: &HeaderMap, + ) -> Result { + let limit = headers + .get(limit_name) + .with_context(|| format!("missing {limit_name:?} header"))?; + let limit = UsageLimit::from_str(limit.to_str()?)?; + + let amount = headers + .get(amount_name) + .with_context(|| format!("missing {amount_name:?} header"))?; + let amount = amount.to_str()?.parse::()?; + + Ok(Self { limit, amount }) + } +} + +impl ModelRequestUsage { + pub fn from_headers(headers: &HeaderMap) -> Result { + Ok(Self(RequestUsage::from_headers( + MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, + MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, + headers, + )?)) + } +} + +impl EditPredictionUsage { + pub fn from_headers(headers: &HeaderMap) -> Result { + Ok(Self(RequestUsage::from_headers( + EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, + EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, + headers, + )?)) + } +} diff --git a/crates/inline_completion/Cargo.toml b/crates/inline_completion/Cargo.toml index 0094385e16..3a90875def 100644 --- a/crates/inline_completion/Cargo.toml +++ b/crates/inline_completion/Cargo.toml @@ -12,9 +12,8 @@ workspace = true path = "src/inline_completion.rs" [dependencies] -anyhow.workspace = true +client.workspace = true gpui.workspace = true language.workspace = true project.workspace = true workspace-hack.workspace = true -zed_llm_client.workspace = true diff --git a/crates/inline_completion/src/inline_completion.rs b/crates/inline_completion/src/inline_completion.rs index 7acfea72b2..c8f35bf16a 100644 --- a/crates/inline_completion/src/inline_completion.rs +++ b/crates/inline_completion/src/inline_completion.rs @@ -1,14 +1,9 @@ use std::ops::Range; -use std::str::FromStr as _; -use anyhow::{Context as _, Result}; -use gpui::http_client::http::{HeaderMap, HeaderValue}; +use client::EditPredictionUsage; use gpui::{App, Context, Entity, SharedString}; use language::Buffer; use project::Project; -use zed_llm_client::{ - EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, UsageLimit, -}; // TODO: Find a better home for `Direction`. // @@ -59,39 +54,6 @@ impl DataCollectionState { } } -#[derive(Debug, Clone, Copy)] -pub struct EditPredictionUsage { - pub limit: UsageLimit, - pub amount: i32, -} - -impl EditPredictionUsage { - pub fn from_headers(headers: &HeaderMap) -> Result { - let limit = headers - .get(EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME) - .with_context(|| { - format!("missing {EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME:?} header") - })?; - let limit = UsageLimit::from_str(limit.to_str()?)?; - - let amount = headers - .get(EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME) - .with_context(|| { - format!("missing {EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME:?} header") - })?; - let amount = amount.to_str()?.parse::()?; - - Ok(Self { limit, amount }) - } - - pub fn over_limit(&self) -> bool { - match self.limit { - UsageLimit::Limited(limit) => self.amount >= limit, - UsageLimit::Unlimited => false, - } - } -} - pub trait EditPredictionProvider: 'static + Sized { fn name() -> &'static str; fn display_name() -> &'static str; diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index c411593213..900d7f6f39 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -8,27 +8,22 @@ mod telemetry; #[cfg(any(test, feature = "test-support"))] pub mod fake_provider; -use anyhow::{Context as _, Result}; +use anyhow::Result; use client::Client; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; -use http_client::http::{HeaderMap, HeaderValue}; use icons::IconName; use parking_lot::Mutex; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::fmt; use std::ops::{Add, Sub}; -use std::str::FromStr as _; use std::sync::Arc; use std::time::Duration; use thiserror::Error; use util::serde::is_default; -use zed_llm_client::{ - CompletionRequestStatus, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, - MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit, -}; +use zed_llm_client::CompletionRequestStatus; pub use crate::model::*; pub use crate::rate_limiter::*; @@ -106,32 +101,6 @@ pub enum StopReason { Refusal, } -#[derive(Debug, Clone, Copy)] -pub struct RequestUsage { - pub limit: UsageLimit, - pub amount: i32, -} - -impl RequestUsage { - pub fn from_headers(headers: &HeaderMap) -> Result { - let limit = headers - .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME) - .with_context(|| { - format!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header") - })?; - let limit = UsageLimit::from_str(limit.to_str()?)?; - - let amount = headers - .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME) - .with_context(|| { - format!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header") - })?; - let amount = amount.to_str()?.parse::()?; - - Ok(Self { limit, amount }) - } -} - #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)] pub struct TokenUsage { #[serde(default, skip_serializing_if = "is_default")] diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 59a5537ae9..1062d732a4 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,6 +1,6 @@ use anthropic::{AnthropicModelMode, parse_prompt_too_long}; use anyhow::{Context as _, Result, anyhow}; -use client::{Client, UserStore, zed_urls}; +use client::{Client, ModelRequestUsage, UserStore, zed_urls}; use futures::{ AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream, }; @@ -14,7 +14,7 @@ use language_model::{ LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice, - LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage, + LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, ZED_CLOUD_PROVIDER_ID, }; use language_model::{ @@ -530,7 +530,7 @@ pub struct CloudLanguageModel { struct PerformLlmCompletionResponse { response: Response, - usage: Option, + usage: Option, tool_use_limit_reached: bool, includes_status_messages: bool, } @@ -581,7 +581,7 @@ impl CloudLanguageModel { let usage = if includes_status_messages { None } else { - RequestUsage::from_headers(response.headers()).ok() + ModelRequestUsage::from_headers(response.headers()).ok() }; return Ok(PerformLlmCompletionResponse { @@ -1002,7 +1002,7 @@ where } fn usage_updated_event( - usage: Option, + usage: Option, ) -> impl Stream>> { futures::stream::iter(usage.map(|usage| { Ok(CloudCompletionEvent::Status( diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 23ce320ee9..4d643c9db0 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -9,14 +9,14 @@ mod rate_completion_modal; pub(crate) use completion_diff_element::*; use db::kvp::KEY_VALUE_STORE; pub use init::*; -use inline_completion::{DataCollectionState, EditPredictionUsage}; +use inline_completion::DataCollectionState; use license_detection::LICENSE_FILES_TO_CHECK; pub use license_detection::is_license_eligible_for_data_collection; pub use rate_completion_modal::*; use anyhow::{Context as _, Result, anyhow}; use arrayvec::ArrayVec; -use client::{Client, UserStore}; +use client::{Client, EditPredictionUsage, UserStore}; use collections::{HashMap, HashSet, VecDeque}; use futures::AsyncReadExt; use gpui::{ @@ -48,7 +48,7 @@ use std::{ }; use telemetry_events::InlineCompletionRating; use thiserror::Error; -use util::{ResultExt, maybe}; +use util::ResultExt; use uuid::Uuid; use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId}; @@ -188,7 +188,6 @@ pub struct Zeta { data_collection_choice: Entity, llm_token: LlmApiToken, _llm_token_subscription: Subscription, - last_usage: Option, /// Whether the terms of service have been accepted. tos_accepted: bool, /// Whether an update to a newer version of Zed is required to continue using Zeta. @@ -234,25 +233,7 @@ impl Zeta { } pub fn usage(&self, cx: &App) -> Option { - self.last_usage.or_else(|| { - let user_store = self.user_store.read(cx); - maybe!({ - let amount = user_store.edit_predictions_usage_amount()?; - let limit = user_store.edit_predictions_usage_limit()?.variant?; - - Some(EditPredictionUsage { - amount: amount as i32, - limit: match limit { - proto::usage_limit::Variant::Limited(limited) => { - zed_llm_client::UsageLimit::Limited(limited.limit as i32) - } - proto::usage_limit::Variant::Unlimited(_) => { - zed_llm_client::UsageLimit::Unlimited - } - }, - }) - }) - }) + self.user_store.read(cx).edit_prediction_usage() } fn new( @@ -287,7 +268,6 @@ impl Zeta { .detach_and_log_err(cx); }, ), - last_usage: None, tos_accepted: user_store .read(cx) .current_user_has_accepted_terms() @@ -533,8 +513,10 @@ impl Zeta { log::debug!("completion response: {}", &response.output_excerpt); if let Some(usage) = usage { - this.update(cx, |this, _cx| { - this.last_usage = Some(usage); + this.update(cx, |this, cx| { + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); }) .ok(); } @@ -874,8 +856,9 @@ and then another if response.status().is_success() { if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() { this.update(cx, |this, cx| { - this.last_usage = Some(usage); - cx.notify(); + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); })?; }