diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index e64e5c7360..bc70f81f3e 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -1,5 +1,6 @@ use std::path::PathBuf; use std::sync::Arc; +use std::time::Duration; use anyhow::{Result, anyhow}; use assistant_context_editor::{ @@ -14,9 +15,9 @@ use client::zed_urls; use editor::{Editor, MultiBuffer}; use fs::Fs; use gpui::{ - Action, AnyElement, App, AsyncWindowContext, Corner, Entity, EventEmitter, FocusHandle, - Focusable, FontWeight, KeyContext, Pixels, Subscription, Task, UpdateGlobal, WeakEntity, - action_with_deprecated_aliases, prelude::*, + Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, Corner, Entity, + EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, Pixels, Subscription, Task, + UpdateGlobal, WeakEntity, action_with_deprecated_aliases, prelude::*, pulsating_between, }; use language::LanguageRegistry; use language_model::{LanguageModelProviderTosView, LanguageModelRegistry}; @@ -38,7 +39,7 @@ use crate::active_thread::ActiveThread; use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent}; use crate::history_store::{HistoryEntry, HistoryStore}; use crate::message_editor::MessageEditor; -use crate::thread::{Thread, ThreadError, ThreadId}; +use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio}; use crate::thread_history::{PastContext, PastThread, ThreadHistory}; use crate::thread_store::ThreadStore; use crate::{ @@ -715,18 +716,21 @@ impl Panel for AssistantPanel { impl AssistantPanel { fn render_toolbar(&self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let thread = self.thread.read(cx); - let is_empty = thread.is_empty(); + let active_thread = self.thread.read(cx); + let thread = active_thread.thread().read(cx); + let token_usage = thread.total_token_usage(cx); + let thread_id = thread.id().clone(); - let thread_id = thread.thread().read(cx).id().clone(); + let is_generating = thread.is_generating(); + let is_empty = active_thread.is_empty(); let focus_handle = self.focus_handle(cx); let title = match self.active_view { ActiveView::Thread => { if is_empty { - thread.summary_or_default(cx) + active_thread.summary_or_default(cx) } else { - thread + active_thread .summary(cx) .unwrap_or_else(|| SharedString::from("Loading Summary…")) } @@ -742,6 +746,12 @@ impl AssistantPanel { ActiveView::Configuration => "Settings".into(), }; + let show_token_count = match self.active_view { + ActiveView::Thread => !is_empty, + ActiveView::PromptEditor => self.context_editor.is_some(), + _ => false, + }; + h_flex() .id("assistant-toolbar") .h(Tab::container_height(cx)) @@ -764,12 +774,67 @@ impl AssistantPanel { .pl_2() .gap_2() .bg(cx.theme().colors().tab_bar_background) - .children(if matches!(self.active_view, ActiveView::PromptEditor) { - self.context_editor - .as_ref() - .and_then(|editor| render_remaining_tokens(editor, cx)) - } else { - None + .when(show_token_count, |parent| match self.active_view { + ActiveView::Thread => { + if token_usage.total == 0 { + return parent; + } + + let token_color = match token_usage.ratio { + TokenUsageRatio::Normal => Color::Muted, + TokenUsageRatio::Warning => Color::Warning, + TokenUsageRatio::Exceeded => Color::Error, + }; + + parent.child( + h_flex() + .gap_0p5() + .child( + Label::new(assistant_context_editor::humanize_token_count( + token_usage.total, + )) + .size(LabelSize::Small) + .color(token_color) + .map(|label| { + if is_generating { + label + .with_animation( + "used-tokens-label", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between( + 0.6, 1., + )), + |label, delta| label.alpha(delta), + ) + .into_any() + } else { + label.into_any_element() + } + }), + ) + .child( + Label::new("/").size(LabelSize::Small).color(Color::Muted), + ) + .child( + Label::new(assistant_context_editor::humanize_token_count( + token_usage.max, + )) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + } + ActiveView::PromptEditor => { + let Some(editor) = self.context_editor.as_ref() else { + return parent; + }; + let Some(element) = render_remaining_tokens(editor, cx) else { + return parent; + }; + parent.child(element) + } + _ => parent, }) .child( h_flex() diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index 538f45d9b1..b1e98bf49d 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -28,7 +28,7 @@ use crate::context_picker::{ConfirmBehavior, ContextPicker, ContextPickerComplet use crate::context_store::{ContextStore, refresh_context_store_text}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::profile_selector::ProfileSelector; -use crate::thread::{RequestKind, Thread}; +use crate::thread::{RequestKind, Thread, TokenUsageRatio}; use crate::thread_store::ThreadStore; use crate::{ AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ThreadEvent, @@ -338,7 +338,7 @@ impl Render for MessageEditor { let thread = self.thread.read(cx); let is_generating = thread.is_generating(); - let is_too_long = thread.is_getting_too_long(cx); + let total_token_usage = thread.total_token_usage(cx); let is_model_selected = self.is_model_selected(cx); let is_editor_empty = self.is_editor_empty(cx); let needs_confirmation = @@ -788,7 +788,7 @@ impl Render for MessageEditor { ), ) ) - .when(is_too_long, |parent| { + .when(total_token_usage.ratio != TokenUsageRatio::Normal, |parent| { parent.child( h_flex() .p_2() diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 9619cdb492..8b5941dff6 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -214,6 +214,21 @@ pub enum DetailedSummaryState { }, } +#[derive(Default)] +pub struct TotalTokenUsage { + pub total: usize, + pub max: usize, + pub ratio: TokenUsageRatio, +} + +#[derive(Default, PartialEq, Eq)] +pub enum TokenUsageRatio { + #[default] + Normal, + Warning, + Exceeded, +} + /// A thread of conversation with the LLM. pub struct Thread { id: ThreadId, @@ -1723,26 +1738,33 @@ impl Thread { self.cumulative_token_usage.clone() } - pub fn is_getting_too_long(&self, cx: &App) -> bool { + pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage { let model_registry = LanguageModelRegistry::read_global(cx); let Some(model) = model_registry.active_model() else { - return false; + return TotalTokenUsage::default(); }; - let max_tokens = model.max_token_count(); - - let current_usage = - self.cumulative_token_usage.input_tokens + self.cumulative_token_usage.output_tokens; + let max = model.max_token_count(); #[cfg(debug_assertions)] let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") - .unwrap_or("0.9".to_string()) + .unwrap_or("0.8".to_string()) .parse() .unwrap(); #[cfg(not(debug_assertions))] - let warning_threshold: f32 = 0.9; + let warning_threshold: f32 = 0.8; - current_usage as f32 >= (max_tokens as f32 * warning_threshold) + let total = self.cumulative_token_usage.total_tokens() as usize; + + let ratio = if total >= max { + TokenUsageRatio::Exceeded + } else if total as f32 / max as f32 >= warning_threshold { + TokenUsageRatio::Warning + } else { + TokenUsageRatio::Normal + }; + + TotalTokenUsage { total, max, ratio } } pub fn deny_tool_use( diff --git a/crates/assistant_context_editor/src/context_editor.rs b/crates/assistant_context_editor/src/context_editor.rs index c155c0b2bf..121ee9345d 100644 --- a/crates/assistant_context_editor/src/context_editor.rs +++ b/crates/assistant_context_editor/src/context_editor.rs @@ -3703,6 +3703,18 @@ pub fn humanize_token_count(count: usize) -> String { format!("{}.{}k", thousands, hundreds) } } + 1_000_000..=9_999_999 => { + let millions = count / 1_000_000; + let hundred_thousands = (count % 1_000_000 + 50_000) / 100_000; + if hundred_thousands == 0 { + format!("{}M", millions) + } else if hundred_thousands == 10 { + format!("{}M", millions + 1) + } else { + format!("{}.{}M", millions, hundred_thousands) + } + } + 10_000_000.. => format!("{}M", (count + 500_000) / 1_000_000), _ => format!("{}k", (count + 500) / 1000), } } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 23564f2d61..aa060f7b30 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -95,6 +95,12 @@ pub struct TokenUsage { pub cache_read_input_tokens: u32, } +impl TokenUsage { + pub fn total_tokens(&self) -> u32 { + self.input_tokens + self.output_tokens + } +} + impl Add for TokenUsage { type Output = Self;