agent: Fix issues with usage display sometimes showing initially fetched usage (#33125)

Having `Thread::last_usage` as an override of the initially fetched
usage could cause the initial usage to be displayed when the current
thread is empty or in text threads. Fix is to just store last usage info
in `UserStore` and not have these overrides

Release Notes:

- Agent: Fixed request usage display to always include the most recently
known usage - there were some cases where it would show the initially
fetched usage.
This commit is contained in:
Michael Sloan 2025-06-20 15:28:48 -06:00 committed by GitHub
parent e0c0b6f95d
commit 7e801dccb0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 188 additions and 211 deletions

3
Cargo.lock generated
View file

@ -2821,6 +2821,7 @@ dependencies = [
"cocoa 0.26.0", "cocoa 0.26.0",
"collections", "collections",
"credentials_provider", "credentials_provider",
"derive_more",
"feature_flags", "feature_flags",
"fs", "fs",
"futures 0.3.31", "futures 0.3.31",
@ -2859,6 +2860,7 @@ dependencies = [
"windows 0.61.1", "windows 0.61.1",
"workspace-hack", "workspace-hack",
"worktree", "worktree",
"zed_llm_client",
] ]
[[package]] [[package]]
@ -8159,6 +8161,7 @@ name = "inline_completion"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"client",
"gpui", "gpui",
"language", "language",
"project", "project",

View file

@ -29,8 +29,7 @@ use gpui::{
}; };
use language::LanguageRegistry; use language::LanguageRegistry;
use language_model::{ use language_model::{
ConfigurationError, LanguageModelProviderTosView, LanguageModelRegistry, RequestUsage, ConfigurationError, LanguageModelProviderTosView, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
ZED_CLOUD_PROVIDER_ID,
}; };
use project::{Project, ProjectPath, Worktree}; use project::{Project, ProjectPath, Worktree};
use prompt_store::{PromptBuilder, PromptStore, UserPromptId}; use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
@ -45,7 +44,7 @@ use ui::{
Banner, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu, Banner, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu,
PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*, PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*,
}; };
use util::{ResultExt as _, maybe}; use util::ResultExt as _;
use workspace::dock::{DockPosition, Panel, PanelEvent}; use workspace::dock::{DockPosition, Panel, PanelEvent};
use workspace::{ use workspace::{
CollaboratorId, DraggedSelection, DraggedTab, ToggleZoom, ToolbarItemView, Workspace, CollaboratorId, DraggedSelection, DraggedTab, ToggleZoom, ToolbarItemView, Workspace,
@ -1682,24 +1681,7 @@ impl AgentPanel {
let thread_id = thread.id().clone(); let thread_id = thread.id().clone();
let is_empty = active_thread.is_empty(); let is_empty = active_thread.is_empty();
let editor_empty = self.message_editor.read(cx).is_editor_fully_empty(cx); let editor_empty = self.message_editor.read(cx).is_editor_fully_empty(cx);
let last_usage = active_thread.thread().read(cx).last_usage().or_else(|| { let usage = user_store.model_request_usage();
maybe!({
let amount = user_store.model_request_usage_amount()?;
let limit = user_store.model_request_usage_limit()?.variant?;
Some(RequestUsage {
amount: amount as i32,
limit: match limit {
proto::usage_limit::Variant::Limited(limited) => {
zed_llm_client::UsageLimit::Limited(limited.limit as i32)
}
proto::usage_limit::Variant::Unlimited(_) => {
zed_llm_client::UsageLimit::Unlimited
}
},
})
})
});
let account_url = zed_urls::account_url(cx); let account_url = zed_urls::account_url(cx);
@ -1820,7 +1802,7 @@ impl AgentPanel {
.action("Add Custom Server…", Box::new(AddContextServer)) .action("Add Custom Server…", Box::new(AddContextServer))
.separator(); .separator();
if let Some(usage) = last_usage { if let Some(usage) = usage {
menu = menu menu = menu
.header_with_link("Prompt Usage", "Manage", account_url.clone()) .header_with_link("Prompt Usage", "Manage", account_url.clone())
.custom_entry( .custom_entry(

View file

@ -1,7 +1,7 @@
#![allow(unused, dead_code)] #![allow(unused, dead_code)]
use client::{ModelRequestUsage, RequestUsage};
use gpui::Global; use gpui::Global;
use language_model::RequestUsage;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use ui::prelude::*; use ui::prelude::*;
use zed_llm_client::{Plan, UsageLimit}; use zed_llm_client::{Plan, UsageLimit};
@ -17,7 +17,7 @@ pub struct DebugAccountState {
pub enabled: bool, pub enabled: bool,
pub trial_expired: bool, pub trial_expired: bool,
pub plan: Plan, pub plan: Plan,
pub custom_prompt_usage: RequestUsage, pub custom_prompt_usage: ModelRequestUsage,
pub usage_based_billing_enabled: bool, pub usage_based_billing_enabled: bool,
pub monthly_spending_cap: i32, pub monthly_spending_cap: i32,
pub custom_edit_prediction_usage: UsageLimit, pub custom_edit_prediction_usage: UsageLimit,
@ -43,7 +43,7 @@ impl DebugAccountState {
self self
} }
pub fn set_custom_prompt_usage(&mut self, custom_prompt_usage: RequestUsage) -> &mut Self { pub fn set_custom_prompt_usage(&mut self, custom_prompt_usage: ModelRequestUsage) -> &mut Self {
self.custom_prompt_usage = custom_prompt_usage; self.custom_prompt_usage = custom_prompt_usage;
self self
} }
@ -76,10 +76,10 @@ impl Default for DebugAccountState {
enabled: false, enabled: false,
trial_expired: false, trial_expired: false,
plan: Plan::ZedFree, plan: Plan::ZedFree,
custom_prompt_usage: RequestUsage { custom_prompt_usage: ModelRequestUsage(RequestUsage {
limit: UsageLimit::Unlimited, limit: UsageLimit::Unlimited,
amount: 0, amount: 0,
}, }),
usage_based_billing_enabled: false, usage_based_billing_enabled: false,
// $50.00 // $50.00
monthly_spending_cap: 5000, monthly_spending_cap: 5000,

View file

@ -29,8 +29,7 @@ use gpui::{
}; };
use language::{Buffer, Language, Point}; use language::{Buffer, Language, Point};
use language_model::{ use language_model::{
ConfiguredModel, LanguageModelRequestMessage, MessageContent, RequestUsage, ConfiguredModel, LanguageModelRequestMessage, MessageContent, ZED_CLOUD_PROVIDER_ID,
ZED_CLOUD_PROVIDER_ID,
}; };
use multi_buffer; use multi_buffer;
use project::Project; use project::Project;
@ -42,7 +41,7 @@ use theme::ThemeSettings;
use ui::{ use ui::{
Callout, Disclosure, Divider, DividerColor, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*, Callout, Disclosure, Divider, DividerColor, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*,
}; };
use util::{ResultExt as _, maybe}; use util::ResultExt as _;
use workspace::{CollaboratorId, Workspace}; use workspace::{CollaboratorId, Workspace};
use zed_llm_client::CompletionIntent; use zed_llm_client::CompletionIntent;
@ -1257,24 +1256,8 @@ impl MessageEditor {
Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
}) })
.unwrap_or(zed_llm_client::Plan::ZedFree); .unwrap_or(zed_llm_client::Plan::ZedFree);
let usage = self.thread.read(cx).last_usage().or_else(|| {
maybe!({
let amount = user_store.model_request_usage_amount()?;
let limit = user_store.model_request_usage_limit()?.variant?;
Some(RequestUsage { let usage = user_store.model_request_usage()?;
amount: amount as i32,
limit: match limit {
proto::usage_limit::Variant::Limited(limited) => {
zed_llm_client::UsageLimit::Limited(limited.limit as i32)
}
proto::usage_limit::Variant::Unlimited(_) => {
zed_llm_client::UsageLimit::Unlimited
}
},
})
})
})?;
Some( Some(
div() div()

View file

@ -7,6 +7,7 @@ use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use collections::HashMap; use collections::HashMap;
use editor::display_map::CreaseMetadata; use editor::display_map::CreaseMetadata;
use feature_flags::{self, FeatureFlagAppExt}; use feature_flags::{self, FeatureFlagAppExt};
@ -22,8 +23,8 @@ use language_model::{
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent, LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel, ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
StopReason, TokenUsage, TokenUsage,
}; };
use postage::stream::Stream as _; use postage::stream::Stream as _;
use project::Project; use project::Project;
@ -38,7 +39,7 @@ use ui::Window;
use util::{ResultExt as _, post_inc}; use util::{ResultExt as _, post_inc};
use uuid::Uuid; use uuid::Uuid;
use zed_llm_client::{CompletionIntent, CompletionRequestStatus}; use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
use crate::ThreadStore; use crate::ThreadStore;
use crate::agent_profile::AgentProfile; use crate::agent_profile::AgentProfile;
@ -350,7 +351,6 @@ pub struct Thread {
request_token_usage: Vec<TokenUsage>, request_token_usage: Vec<TokenUsage>,
cumulative_token_usage: TokenUsage, cumulative_token_usage: TokenUsage,
exceeded_window_error: Option<ExceededWindowError>, exceeded_window_error: Option<ExceededWindowError>,
last_usage: Option<RequestUsage>,
tool_use_limit_reached: bool, tool_use_limit_reached: bool,
feedback: Option<ThreadFeedback>, feedback: Option<ThreadFeedback>,
message_feedback: HashMap<MessageId, ThreadFeedback>, message_feedback: HashMap<MessageId, ThreadFeedback>,
@ -443,7 +443,6 @@ impl Thread {
request_token_usage: Vec::new(), request_token_usage: Vec::new(),
cumulative_token_usage: TokenUsage::default(), cumulative_token_usage: TokenUsage::default(),
exceeded_window_error: None, exceeded_window_error: None,
last_usage: None,
tool_use_limit_reached: false, tool_use_limit_reached: false,
feedback: None, feedback: None,
message_feedback: HashMap::default(), message_feedback: HashMap::default(),
@ -568,7 +567,6 @@ impl Thread {
request_token_usage: serialized.request_token_usage, request_token_usage: serialized.request_token_usage,
cumulative_token_usage: serialized.cumulative_token_usage, cumulative_token_usage: serialized.cumulative_token_usage,
exceeded_window_error: None, exceeded_window_error: None,
last_usage: None,
tool_use_limit_reached: serialized.tool_use_limit_reached, tool_use_limit_reached: serialized.tool_use_limit_reached,
feedback: None, feedback: None,
message_feedback: HashMap::default(), message_feedback: HashMap::default(),
@ -875,10 +873,6 @@ impl Thread {
.unwrap_or(false) .unwrap_or(false)
} }
pub fn last_usage(&self) -> Option<RequestUsage> {
self.last_usage
}
pub fn tool_use_limit_reached(&self) -> bool { pub fn tool_use_limit_reached(&self) -> bool {
self.tool_use_limit_reached self.tool_use_limit_reached
} }
@ -1658,9 +1652,7 @@ impl Thread {
CompletionRequestStatus::UsageUpdated { CompletionRequestStatus::UsageUpdated {
amount, limit amount, limit
} => { } => {
let usage = RequestUsage { limit, amount: amount as i32 }; thread.update_model_request_usage(amount as u32, limit, cx);
thread.last_usage = Some(usage);
} }
CompletionRequestStatus::ToolUseLimitReached => { CompletionRequestStatus::ToolUseLimitReached => {
thread.tool_use_limit_reached = true; thread.tool_use_limit_reached = true;
@ -1871,11 +1863,8 @@ impl Thread {
LanguageModelCompletionEvent::StatusUpdate( LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::UsageUpdated { amount, limit }, CompletionRequestStatus::UsageUpdated { amount, limit },
) => { ) => {
this.update(cx, |thread, _cx| { this.update(cx, |thread, cx| {
thread.last_usage = Some(RequestUsage { thread.update_model_request_usage(amount as u32, limit, cx);
limit,
amount: amount as i32,
});
})?; })?;
continue; continue;
} }
@ -2757,6 +2746,20 @@ impl Thread {
} }
} }
fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
self.project.update(cx, |project, cx| {
project.user_store().update(cx, |user_store, cx| {
user_store.update_model_request_usage(
ModelRequestUsage(RequestUsage {
amount: amount as i32,
limit,
}),
cx,
)
})
});
}
pub fn deny_tool_use( pub fn deny_tool_use(
&mut self, &mut self,
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,

View file

@ -1,18 +1,17 @@
use client::zed_urls; use client::{ModelRequestUsage, RequestUsage, zed_urls};
use component::{empty_example, example_group_with_title, single_example}; use component::{empty_example, example_group_with_title, single_example};
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window}; use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
use language_model::RequestUsage;
use ui::{Callout, prelude::*}; use ui::{Callout, prelude::*};
use zed_llm_client::{Plan, UsageLimit}; use zed_llm_client::{Plan, UsageLimit};
#[derive(IntoElement, RegisterComponent)] #[derive(IntoElement, RegisterComponent)]
pub struct UsageCallout { pub struct UsageCallout {
plan: Plan, plan: Plan,
usage: RequestUsage, usage: ModelRequestUsage,
} }
impl UsageCallout { impl UsageCallout {
pub fn new(plan: Plan, usage: RequestUsage) -> Self { pub fn new(plan: Plan, usage: ModelRequestUsage) -> Self {
Self { plan, usage } Self { plan, usage }
} }
} }
@ -128,10 +127,10 @@ impl Component for UsageCallout {
"Approaching limit (90%)", "Approaching limit (90%)",
UsageCallout::new( UsageCallout::new(
Plan::ZedFree, Plan::ZedFree,
RequestUsage { ModelRequestUsage(RequestUsage {
limit: UsageLimit::Limited(50), limit: UsageLimit::Limited(50),
amount: 45, // 90% of limit amount: 45, // 90% of limit
}, }),
) )
.into_any_element(), .into_any_element(),
), ),
@ -139,10 +138,10 @@ impl Component for UsageCallout {
"Limit reached (100%)", "Limit reached (100%)",
UsageCallout::new( UsageCallout::new(
Plan::ZedFree, Plan::ZedFree,
RequestUsage { ModelRequestUsage(RequestUsage {
limit: UsageLimit::Limited(50), limit: UsageLimit::Limited(50),
amount: 50, // 100% of limit amount: 50, // 100% of limit
}, }),
) )
.into_any_element(), .into_any_element(),
), ),
@ -156,10 +155,10 @@ impl Component for UsageCallout {
"Approaching limit (90%)", "Approaching limit (90%)",
UsageCallout::new( UsageCallout::new(
Plan::ZedProTrial, Plan::ZedProTrial,
RequestUsage { ModelRequestUsage(RequestUsage {
limit: UsageLimit::Limited(150), limit: UsageLimit::Limited(150),
amount: 135, // 90% of limit amount: 135, // 90% of limit
}, }),
) )
.into_any_element(), .into_any_element(),
), ),
@ -167,10 +166,10 @@ impl Component for UsageCallout {
"Limit reached (100%)", "Limit reached (100%)",
UsageCallout::new( UsageCallout::new(
Plan::ZedProTrial, Plan::ZedProTrial,
RequestUsage { ModelRequestUsage(RequestUsage {
limit: UsageLimit::Limited(150), limit: UsageLimit::Limited(150),
amount: 150, // 100% of limit amount: 150, // 100% of limit
}, }),
) )
.into_any_element(), .into_any_element(),
), ),
@ -184,10 +183,10 @@ impl Component for UsageCallout {
"Limit reached (100%)", "Limit reached (100%)",
UsageCallout::new( UsageCallout::new(
Plan::ZedPro, Plan::ZedPro,
RequestUsage { ModelRequestUsage(RequestUsage {
limit: UsageLimit::Limited(500), limit: UsageLimit::Limited(500),
amount: 500, // 100% of limit amount: 500, // 100% of limit
}, }),
) )
.into_any_element(), .into_any_element(),
), ),

View file

@ -24,6 +24,7 @@ chrono = { workspace = true, features = ["serde"] }
clock.workspace = true clock.workspace = true
collections.workspace = true collections.workspace = true
credentials_provider.workspace = true credentials_provider.workspace = true
derive_more.workspace = true
feature_flags.workspace = true feature_flags.workspace = true
futures.workspace = true futures.workspace = true
gpui.workspace = true gpui.workspace = true
@ -57,6 +58,7 @@ worktree.workspace = true
telemetry.workspace = true telemetry.workspace = true
tokio.workspace = true tokio.workspace = true
workspace-hack.workspace = true workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies] [dev-dependencies]
clock = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] }

View file

@ -2,16 +2,25 @@ use super::{Client, Status, TypedEnvelope, proto};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::{HashMap, HashSet, hash_map::Entry}; use collections::{HashMap, HashSet, hash_map::Entry};
use derive_more::Deref;
use feature_flags::FeatureFlagAppExt; use feature_flags::FeatureFlagAppExt;
use futures::{Future, StreamExt, channel::mpsc}; use futures::{Future, StreamExt, channel::mpsc};
use gpui::{ use gpui::{
App, AsyncApp, Context, Entity, EventEmitter, SharedString, SharedUri, Task, WeakEntity, App, AsyncApp, Context, Entity, EventEmitter, SharedString, SharedUri, Task, WeakEntity,
}; };
use http_client::http::{HeaderMap, HeaderValue};
use postage::{sink::Sink, watch}; use postage::{sink::Sink, watch};
use rpc::proto::{RequestMessage, UsersResponse}; use rpc::proto::{RequestMessage, UsersResponse};
use std::sync::{Arc, Weak}; use std::{
str::FromStr as _,
sync::{Arc, Weak},
};
use text::ReplicaId; use text::ReplicaId;
use util::{TryFutureExt as _, maybe}; use util::{TryFutureExt as _, maybe};
use zed_llm_client::{
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};
pub type UserId = u64; pub type UserId = u64;
@ -104,10 +113,8 @@ pub struct UserStore {
current_plan: Option<proto::Plan>, current_plan: Option<proto::Plan>,
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>, subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
trial_started_at: Option<DateTime<Utc>>, trial_started_at: Option<DateTime<Utc>>,
model_request_usage_amount: Option<u32>, model_request_usage: Option<ModelRequestUsage>,
model_request_usage_limit: Option<proto::UsageLimit>, edit_prediction_usage: Option<EditPredictionUsage>,
edit_predictions_usage_amount: Option<u32>,
edit_predictions_usage_limit: Option<proto::UsageLimit>,
is_usage_based_billing_enabled: Option<bool>, is_usage_based_billing_enabled: Option<bool>,
account_too_young: Option<bool>, account_too_young: Option<bool>,
has_overdue_invoices: Option<bool>, has_overdue_invoices: Option<bool>,
@ -155,6 +162,18 @@ enum UpdateContacts {
Clear(postage::barrier::Sender), Clear(postage::barrier::Sender),
} }
#[derive(Debug, Clone, Copy, Deref)]
pub struct ModelRequestUsage(pub RequestUsage);
#[derive(Debug, Clone, Copy, Deref)]
pub struct EditPredictionUsage(pub RequestUsage);
#[derive(Debug, Clone, Copy)]
pub struct RequestUsage {
pub limit: UsageLimit,
pub amount: i32,
}
impl UserStore { impl UserStore {
pub fn new(client: Arc<Client>, cx: &Context<Self>) -> Self { pub fn new(client: Arc<Client>, cx: &Context<Self>) -> Self {
let (mut current_user_tx, current_user_rx) = watch::channel(); let (mut current_user_tx, current_user_rx) = watch::channel();
@ -172,10 +191,8 @@ impl UserStore {
current_plan: None, current_plan: None,
subscription_period: None, subscription_period: None,
trial_started_at: None, trial_started_at: None,
model_request_usage_amount: None, model_request_usage: None,
model_request_usage_limit: None, edit_prediction_usage: None,
edit_predictions_usage_amount: None,
edit_predictions_usage_limit: None,
is_usage_based_billing_enabled: None, is_usage_based_billing_enabled: None,
account_too_young: None, account_too_young: None,
has_overdue_invoices: None, has_overdue_invoices: None,
@ -356,10 +373,19 @@ impl UserStore {
this.has_overdue_invoices = message.payload.has_overdue_invoices; this.has_overdue_invoices = message.payload.has_overdue_invoices;
if let Some(usage) = message.payload.usage { if let Some(usage) = message.payload.usage {
this.model_request_usage_amount = Some(usage.model_requests_usage_amount); // limits are always present even though they are wrapped in Option
this.model_request_usage_limit = usage.model_requests_usage_limit; this.model_request_usage = usage
this.edit_predictions_usage_amount = Some(usage.edit_predictions_usage_amount); .model_requests_usage_limit
this.edit_predictions_usage_limit = usage.edit_predictions_usage_limit; .and_then(|limit| {
RequestUsage::from_proto(usage.model_requests_usage_amount, limit)
})
.map(ModelRequestUsage);
this.edit_prediction_usage = usage
.edit_predictions_usage_limit
.and_then(|limit| {
RequestUsage::from_proto(usage.edit_predictions_usage_amount, limit)
})
.map(EditPredictionUsage);
} }
cx.notify(); cx.notify();
@ -367,6 +393,20 @@ impl UserStore {
Ok(()) Ok(())
} }
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
self.model_request_usage = Some(usage);
cx.notify();
}
pub fn update_edit_prediction_usage(
&mut self,
usage: EditPredictionUsage,
cx: &mut Context<Self>,
) {
self.edit_prediction_usage = Some(usage);
cx.notify();
}
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> { fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
match message { match message {
UpdateContacts::Wait(barrier) => { UpdateContacts::Wait(barrier) => {
@ -739,20 +779,12 @@ impl UserStore {
self.is_usage_based_billing_enabled self.is_usage_based_billing_enabled
} }
pub fn model_request_usage_amount(&self) -> Option<u32> { pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
self.model_request_usage_amount self.model_request_usage
} }
pub fn model_request_usage_limit(&self) -> Option<proto::UsageLimit> { pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
self.model_request_usage_limit.clone() self.edit_prediction_usage
}
pub fn edit_predictions_usage_amount(&self) -> Option<u32> {
self.edit_predictions_usage_amount
}
pub fn edit_predictions_usage_limit(&self) -> Option<proto::UsageLimit> {
self.edit_predictions_usage_limit.clone()
} }
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> { pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
@ -917,3 +949,63 @@ impl Collaborator {
}) })
} }
} }
impl RequestUsage {
pub fn over_limit(&self) -> bool {
match self.limit {
UsageLimit::Limited(limit) => self.amount >= limit,
UsageLimit::Unlimited => false,
}
}
pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option<Self> {
let limit = match limit.variant? {
proto::usage_limit::Variant::Limited(limited) => {
UsageLimit::Limited(limited.limit as i32)
}
proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited,
};
Some(RequestUsage {
limit,
amount: amount as i32,
})
}
fn from_headers(
limit_name: &str,
amount_name: &str,
headers: &HeaderMap<HeaderValue>,
) -> Result<Self> {
let limit = headers
.get(limit_name)
.with_context(|| format!("missing {limit_name:?} header"))?;
let limit = UsageLimit::from_str(limit.to_str()?)?;
let amount = headers
.get(amount_name)
.with_context(|| format!("missing {amount_name:?} header"))?;
let amount = amount.to_str()?.parse::<i32>()?;
Ok(Self { limit, amount })
}
}
impl ModelRequestUsage {
pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
Ok(Self(RequestUsage::from_headers(
MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME,
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME,
headers,
)?))
}
}
impl EditPredictionUsage {
pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
Ok(Self(RequestUsage::from_headers(
EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME,
headers,
)?))
}
}

View file

@ -12,9 +12,8 @@ workspace = true
path = "src/inline_completion.rs" path = "src/inline_completion.rs"
[dependencies] [dependencies]
anyhow.workspace = true client.workspace = true
gpui.workspace = true gpui.workspace = true
language.workspace = true language.workspace = true
project.workspace = true project.workspace = true
workspace-hack.workspace = true workspace-hack.workspace = true
zed_llm_client.workspace = true

View file

@ -1,14 +1,9 @@
use std::ops::Range; use std::ops::Range;
use std::str::FromStr as _;
use anyhow::{Context as _, Result}; use client::EditPredictionUsage;
use gpui::http_client::http::{HeaderMap, HeaderValue};
use gpui::{App, Context, Entity, SharedString}; use gpui::{App, Context, Entity, SharedString};
use language::Buffer; use language::Buffer;
use project::Project; use project::Project;
use zed_llm_client::{
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};
// TODO: Find a better home for `Direction`. // TODO: Find a better home for `Direction`.
// //
@ -59,39 +54,6 @@ impl DataCollectionState {
} }
} }
#[derive(Debug, Clone, Copy)]
pub struct EditPredictionUsage {
pub limit: UsageLimit,
pub amount: i32,
}
impl EditPredictionUsage {
pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
let limit = headers
.get(EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME)
.with_context(|| {
format!("missing {EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME:?} header")
})?;
let limit = UsageLimit::from_str(limit.to_str()?)?;
let amount = headers
.get(EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME)
.with_context(|| {
format!("missing {EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME:?} header")
})?;
let amount = amount.to_str()?.parse::<i32>()?;
Ok(Self { limit, amount })
}
pub fn over_limit(&self) -> bool {
match self.limit {
UsageLimit::Limited(limit) => self.amount >= limit,
UsageLimit::Unlimited => false,
}
}
}
pub trait EditPredictionProvider: 'static + Sized { pub trait EditPredictionProvider: 'static + Sized {
fn name() -> &'static str; fn name() -> &'static str;
fn display_name() -> &'static str; fn display_name() -> &'static str;

View file

@ -8,27 +8,22 @@ mod telemetry;
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
pub mod fake_provider; pub mod fake_provider;
use anyhow::{Context as _, Result}; use anyhow::Result;
use client::Client; use client::Client;
use futures::FutureExt; use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::http::{HeaderMap, HeaderValue};
use icons::IconName; use icons::IconName;
use parking_lot::Mutex; use parking_lot::Mutex;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::fmt; use std::fmt;
use std::ops::{Add, Sub}; use std::ops::{Add, Sub};
use std::str::FromStr as _;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use thiserror::Error; use thiserror::Error;
use util::serde::is_default; use util::serde::is_default;
use zed_llm_client::{ use zed_llm_client::CompletionRequestStatus;
CompletionRequestStatus, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME,
MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};
pub use crate::model::*; pub use crate::model::*;
pub use crate::rate_limiter::*; pub use crate::rate_limiter::*;
@ -106,32 +101,6 @@ pub enum StopReason {
Refusal, Refusal,
} }
#[derive(Debug, Clone, Copy)]
pub struct RequestUsage {
pub limit: UsageLimit,
pub amount: i32,
}
impl RequestUsage {
pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
let limit = headers
.get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
.with_context(|| {
format!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header")
})?;
let limit = UsageLimit::from_str(limit.to_str()?)?;
let amount = headers
.get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
.with_context(|| {
format!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header")
})?;
let amount = amount.to_str()?.parse::<i32>()?;
Ok(Self { limit, amount })
}
}
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)] #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage { pub struct TokenUsage {
#[serde(default, skip_serializing_if = "is_default")] #[serde(default, skip_serializing_if = "is_default")]

View file

@ -1,6 +1,6 @@
use anthropic::{AnthropicModelMode, parse_prompt_too_long}; use anthropic::{AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use client::{Client, UserStore, zed_urls}; use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use futures::{ use futures::{
AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
}; };
@ -14,7 +14,7 @@ use language_model::{
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice, LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage, LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_ID,
}; };
use language_model::{ use language_model::{
@ -530,7 +530,7 @@ pub struct CloudLanguageModel {
struct PerformLlmCompletionResponse { struct PerformLlmCompletionResponse {
response: Response<AsyncBody>, response: Response<AsyncBody>,
usage: Option<RequestUsage>, usage: Option<ModelRequestUsage>,
tool_use_limit_reached: bool, tool_use_limit_reached: bool,
includes_status_messages: bool, includes_status_messages: bool,
} }
@ -581,7 +581,7 @@ impl CloudLanguageModel {
let usage = if includes_status_messages { let usage = if includes_status_messages {
None None
} else { } else {
RequestUsage::from_headers(response.headers()).ok() ModelRequestUsage::from_headers(response.headers()).ok()
}; };
return Ok(PerformLlmCompletionResponse { return Ok(PerformLlmCompletionResponse {
@ -1002,7 +1002,7 @@ where
} }
fn usage_updated_event<T>( fn usage_updated_event<T>(
usage: Option<RequestUsage>, usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> { ) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
futures::stream::iter(usage.map(|usage| { futures::stream::iter(usage.map(|usage| {
Ok(CloudCompletionEvent::Status( Ok(CloudCompletionEvent::Status(

View file

@ -9,14 +9,14 @@ mod rate_completion_modal;
pub(crate) use completion_diff_element::*; pub(crate) use completion_diff_element::*;
use db::kvp::KEY_VALUE_STORE; use db::kvp::KEY_VALUE_STORE;
pub use init::*; pub use init::*;
use inline_completion::{DataCollectionState, EditPredictionUsage}; use inline_completion::DataCollectionState;
use license_detection::LICENSE_FILES_TO_CHECK; use license_detection::LICENSE_FILES_TO_CHECK;
pub use license_detection::is_license_eligible_for_data_collection; pub use license_detection::is_license_eligible_for_data_collection;
pub use rate_completion_modal::*; pub use rate_completion_modal::*;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use arrayvec::ArrayVec; use arrayvec::ArrayVec;
use client::{Client, UserStore}; use client::{Client, EditPredictionUsage, UserStore};
use collections::{HashMap, HashSet, VecDeque}; use collections::{HashMap, HashSet, VecDeque};
use futures::AsyncReadExt; use futures::AsyncReadExt;
use gpui::{ use gpui::{
@ -48,7 +48,7 @@ use std::{
}; };
use telemetry_events::InlineCompletionRating; use telemetry_events::InlineCompletionRating;
use thiserror::Error; use thiserror::Error;
use util::{ResultExt, maybe}; use util::ResultExt;
use uuid::Uuid; use uuid::Uuid;
use workspace::Workspace; use workspace::Workspace;
use workspace::notifications::{ErrorMessagePrompt, NotificationId}; use workspace::notifications::{ErrorMessagePrompt, NotificationId};
@ -188,7 +188,6 @@ pub struct Zeta {
data_collection_choice: Entity<DataCollectionChoice>, data_collection_choice: Entity<DataCollectionChoice>,
llm_token: LlmApiToken, llm_token: LlmApiToken,
_llm_token_subscription: Subscription, _llm_token_subscription: Subscription,
last_usage: Option<EditPredictionUsage>,
/// Whether the terms of service have been accepted. /// Whether the terms of service have been accepted.
tos_accepted: bool, tos_accepted: bool,
/// Whether an update to a newer version of Zed is required to continue using Zeta. /// Whether an update to a newer version of Zed is required to continue using Zeta.
@ -234,25 +233,7 @@ impl Zeta {
} }
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> { pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
self.last_usage.or_else(|| { self.user_store.read(cx).edit_prediction_usage()
let user_store = self.user_store.read(cx);
maybe!({
let amount = user_store.edit_predictions_usage_amount()?;
let limit = user_store.edit_predictions_usage_limit()?.variant?;
Some(EditPredictionUsage {
amount: amount as i32,
limit: match limit {
proto::usage_limit::Variant::Limited(limited) => {
zed_llm_client::UsageLimit::Limited(limited.limit as i32)
}
proto::usage_limit::Variant::Unlimited(_) => {
zed_llm_client::UsageLimit::Unlimited
}
},
})
})
})
} }
fn new( fn new(
@ -287,7 +268,6 @@ impl Zeta {
.detach_and_log_err(cx); .detach_and_log_err(cx);
}, },
), ),
last_usage: None,
tos_accepted: user_store tos_accepted: user_store
.read(cx) .read(cx)
.current_user_has_accepted_terms() .current_user_has_accepted_terms()
@ -533,8 +513,10 @@ impl Zeta {
log::debug!("completion response: {}", &response.output_excerpt); log::debug!("completion response: {}", &response.output_excerpt);
if let Some(usage) = usage { if let Some(usage) = usage {
this.update(cx, |this, _cx| { this.update(cx, |this, cx| {
this.last_usage = Some(usage); this.user_store.update(cx, |user_store, cx| {
user_store.update_edit_prediction_usage(usage, cx);
});
}) })
.ok(); .ok();
} }
@ -874,8 +856,9 @@ and then another
if response.status().is_success() { if response.status().is_success() {
if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() { if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
this.last_usage = Some(usage); this.user_store.update(cx, |user_store, cx| {
cx.notify(); user_store.update_edit_prediction_usage(usage, cx);
});
})?; })?;
} }