diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 0cd90ba796..fea66c3dea 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1,4 +1,4 @@ -use crate::context::{AssistantContext, ContextId}; +use crate::context::{AssistantContext, ContextId, format_context_as_string}; use crate::context_picker::MentionLink; use crate::thread::{ LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError, @@ -13,16 +13,18 @@ use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting}; use assistant_tool::ToolUseStatus; use collections::{HashMap, HashSet}; use editor::scroll::Autoscroll; -use editor::{Editor, EditorElement, EditorStyle, MultiBuffer}; +use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer}; use gpui::{ AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem, - DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Hsla, ListAlignment, ListState, - MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, Task, - TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle, + DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla, ListAlignment, + ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, + Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle, linear_color_stop, linear_gradient, list, percentage, pulsating_between, }; use language::{Buffer, LanguageRegistry}; -use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason}; +use language_model::{ + LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, Role, StopReason, +}; use markdown::parser::{CodeBlockKind, CodeBlockMetadata}; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown}; use project::ProjectItem as _; @@ -682,6 +684,9 @@ fn open_markdown_link( struct EditMessageState { editor: Entity, + last_estimated_token_count: Option, + _subscription: Subscription, + _update_token_count_task: Option>>, } impl ActiveThread { @@ -781,6 +786,13 @@ impl ActiveThread { self.last_error.take(); } + /// Returns the editing message id and the estimated token count in the content + pub fn editing_message_id(&self) -> Option<(MessageId, usize)> { + self.editing_message + .as_ref() + .map(|(id, state)| (*id, state.last_estimated_token_count.unwrap_or(0))) + } + fn push_message( &mut self, id: &MessageId, @@ -1126,15 +1138,91 @@ impl ActiveThread { editor.move_to_end(&editor::actions::MoveToEnd, window, cx); editor }); + let subscription = cx.subscribe(&editor, |this, _, event, cx| match event { + EditorEvent::BufferEdited => { + this.update_editing_message_token_count(true, cx); + } + _ => {} + }); self.editing_message = Some(( message_id, EditMessageState { editor: editor.clone(), + last_estimated_token_count: None, + _subscription: subscription, + _update_token_count_task: None, }, )); + self.update_editing_message_token_count(false, cx); cx.notify(); } + fn update_editing_message_token_count(&mut self, debounce: bool, cx: &mut Context) { + let Some((message_id, state)) = self.editing_message.as_mut() else { + return; + }; + + cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged); + state._update_token_count_task.take(); + + let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else { + state.last_estimated_token_count.take(); + return; + }; + + let editor = state.editor.clone(); + let thread = self.thread.clone(); + let message_id = *message_id; + + state._update_token_count_task = Some(cx.spawn(async move |this, cx| { + if debounce { + cx.background_executor() + .timer(Duration::from_millis(200)) + .await; + } + + let token_count = if let Some(task) = cx.update(|cx| { + let context = thread.read(cx).context_for_message(message_id); + let new_context = thread.read(cx).filter_new_context(context); + let context_text = + format_context_as_string(new_context, cx).unwrap_or(String::new()); + let message_text = editor.read(cx).text(cx); + + let content = context_text + &message_text; + + if content.is_empty() { + return None; + } + + let request = language_model::LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: language_model::Role::User, + content: vec![content.into()], + cache: false, + }], + tools: vec![], + stop: vec![], + temperature: None, + }; + + Some(default_model.model.count_tokens(request, cx)) + })? { + task.await? + } else { + 0 + }; + + this.update(cx, |this, cx| { + let Some((_message_id, state)) = this.editing_message.as_mut() else { + return; + }; + + state.last_estimated_token_count = Some(token_count); + cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged); + }) + })); + } + fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { self.editing_message.take(); cx.notify(); @@ -1676,6 +1764,9 @@ impl ActiveThread { "confirm-edit-message", "Regenerate", ) + .disabled( + edit_message_editor.read(cx).is_empty(cx), + ) .label_size(LabelSize::Small) .key_binding( KeyBinding::for_action_in( @@ -1738,8 +1829,16 @@ impl ActiveThread { ), }; + let after_editing_message = self + .editing_message + .as_ref() + .map_or(false, |(editing_message_id, _)| { + message_id > *editing_message_id + }); + v_flex() .w_full() + .when(after_editing_message, |parent| parent.opacity(0.2)) .when_some(checkpoint, |parent, checkpoint| { let mut is_pending = false; let mut error = None; @@ -2965,6 +3064,12 @@ impl ActiveThread { } } +pub enum ActiveThreadEvent { + EditingMessageTokenCountChanged, +} + +impl EventEmitter for ActiveThread {} + impl Render for ActiveThread { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index 3090f4682d..6b7ac9e849 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -5,7 +5,7 @@ use std::time::Duration; use anyhow::{Result, anyhow}; use assistant_context_editor::{ AssistantPanelDelegate, ConfigurationError, ContextEditor, SlashCommandCompletionProvider, - make_lsp_adapter_delegate, render_remaining_tokens, + humanize_token_count, make_lsp_adapter_delegate, render_remaining_tokens, }; use assistant_settings::{AssistantDockPosition, AssistantSettings}; use assistant_slash_command::SlashCommandWorkingSet; @@ -37,10 +37,10 @@ use workspace::dock::{DockPosition, Panel, PanelEvent}; use zed_actions::agent::OpenConfiguration; use zed_actions::assistant::{OpenPromptLibrary, ToggleFocus}; -use crate::active_thread::ActiveThread; +use crate::active_thread::{ActiveThread, ActiveThreadEvent}; use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent}; use crate::history_store::{HistoryEntry, HistoryStore}; -use crate::message_editor::MessageEditor; +use crate::message_editor::{MessageEditor, MessageEditorEvent}; use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio}; use crate::thread_history::{PastContext, PastThread, ThreadHistory}; use crate::thread_store::ThreadStore; @@ -181,8 +181,8 @@ pub struct AssistantPanel { language_registry: Arc, thread_store: Entity, thread: Entity, - _thread_subscription: Subscription, message_editor: Entity, + _active_thread_subscriptions: Vec, context_store: Entity, context_editor: Option>, configuration: Option>, @@ -264,6 +264,13 @@ impl AssistantPanel { ) }); + let message_editor_subscription = + cx.subscribe(&message_editor, |_, _, event, cx| match event { + MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => { + cx.notify(); + } + }); + let history_store = cx.new(|cx| HistoryStore::new(thread_store.clone(), context_store.clone(), cx)); @@ -288,6 +295,12 @@ impl AssistantPanel { ) }); + let active_thread_subscription = cx.subscribe(&thread, |_, _, event, cx| match &event { + ActiveThreadEvent::EditingMessageTokenCountChanged => { + cx.notify(); + } + }); + Self { active_view, workspace, @@ -296,8 +309,12 @@ impl AssistantPanel { language_registry, thread_store: thread_store.clone(), thread, - _thread_subscription: thread_subscription, message_editor, + _active_thread_subscriptions: vec![ + thread_subscription, + active_thread_subscription, + message_editor_subscription, + ], context_store, context_editor: None, configuration: None, @@ -382,6 +399,13 @@ impl AssistantPanel { .detach_and_log_err(cx); } + let thread_subscription = cx.subscribe(&thread, |_, _, event, cx| { + if let ThreadEvent::MessageAdded(_) = &event { + // needed to leave empty state + cx.notify(); + } + }); + self.thread = cx.new(|cx| { ActiveThread::new( thread.clone(), @@ -394,12 +418,12 @@ impl AssistantPanel { ) }); - self._thread_subscription = cx.subscribe(&thread, |_, _, event, cx| { - if let ThreadEvent::MessageAdded(_) = &event { - // needed to leave empty state - cx.notify(); - } - }); + let active_thread_subscription = + cx.subscribe(&self.thread, |_, _, event, cx| match &event { + ActiveThreadEvent::EditingMessageTokenCountChanged => { + cx.notify(); + } + }); self.message_editor = cx.new(|cx| { MessageEditor::new( @@ -413,6 +437,19 @@ impl AssistantPanel { ) }); self.message_editor.focus_handle(cx).focus(window); + + let message_editor_subscription = + cx.subscribe(&self.message_editor, |_, _, event, cx| match event { + MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => { + cx.notify(); + } + }); + + self._active_thread_subscriptions = vec![ + thread_subscription, + active_thread_subscription, + message_editor_subscription, + ]; } fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context) { @@ -538,6 +575,13 @@ impl AssistantPanel { Some(this.thread_store.downgrade()), ) }); + let thread_subscription = cx.subscribe(&thread, |_, _, event, cx| { + if let ThreadEvent::MessageAdded(_) = &event { + // needed to leave empty state + cx.notify(); + } + }); + this.thread = cx.new(|cx| { ActiveThread::new( thread.clone(), @@ -549,6 +593,14 @@ impl AssistantPanel { cx, ) }); + + let active_thread_subscription = + cx.subscribe(&this.thread, |_, _, event, cx| match &event { + ActiveThreadEvent::EditingMessageTokenCountChanged => { + cx.notify(); + } + }); + this.message_editor = cx.new(|cx| { MessageEditor::new( this.fs.clone(), @@ -561,6 +613,19 @@ impl AssistantPanel { ) }); this.message_editor.focus_handle(cx).focus(window); + + let message_editor_subscription = + cx.subscribe(&this.message_editor, |_, _, event, cx| match event { + MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => { + cx.notify(); + } + }); + + this._active_thread_subscriptions = vec![ + thread_subscription, + active_thread_subscription, + message_editor_subscription, + ]; }) }) } @@ -853,7 +918,7 @@ impl Panel for AssistantPanel { } impl AssistantPanel { - fn render_title_view(&self, _window: &mut Window, cx: &mut Context) -> AnyElement { + fn render_title_view(&self, _window: &mut Window, cx: &Context) -> AnyElement { const LOADING_SUMMARY_PLACEHOLDER: &str = "Loading Summary…"; let content = match &self.active_view { @@ -913,13 +978,8 @@ impl AssistantPanel { fn render_toolbar(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let active_thread = self.thread.read(cx); let thread = active_thread.thread().read(cx); - let token_usage = thread.total_token_usage(cx); let thread_id = thread.id().clone(); - - let is_generating = thread.is_generating(); let is_empty = active_thread.is_empty(); - let focus_handle = self.focus_handle(cx); - let is_history = matches!(self.active_view, ActiveView::History); let show_token_count = match &self.active_view { @@ -928,6 +988,8 @@ impl AssistantPanel { _ => false, }; + let focus_handle = self.focus_handle(cx); + let go_back_button = match &self.active_view { ActiveView::History | ActiveView::Configuration => Some( div().pl_1().child( @@ -974,69 +1036,9 @@ impl AssistantPanel { h_flex() .h_full() .gap_2() - .when(show_token_count, |parent| match self.active_view { - ActiveView::Thread { .. } => { - if token_usage.total == 0 { - return parent; - } - - let token_color = match token_usage.ratio { - TokenUsageRatio::Normal => Color::Muted, - TokenUsageRatio::Warning => Color::Warning, - TokenUsageRatio::Exceeded => Color::Error, - }; - - parent.child( - h_flex() - .flex_shrink_0() - .gap_0p5() - .child( - Label::new(assistant_context_editor::humanize_token_count( - token_usage.total, - )) - .size(LabelSize::Small) - .color(token_color) - .map(|label| { - if is_generating { - label - .with_animation( - "used-tokens-label", - Animation::new(Duration::from_secs(2)) - .repeat() - .with_easing(pulsating_between( - 0.6, 1., - )), - |label, delta| label.alpha(delta), - ) - .into_any() - } else { - label.into_any_element() - } - }), - ) - .child( - Label::new("/").size(LabelSize::Small).color(Color::Muted), - ) - .child( - Label::new(assistant_context_editor::humanize_token_count( - token_usage.max, - )) - .size(LabelSize::Small) - .color(Color::Muted), - ), - ) - } - ActiveView::PromptEditor => { - let Some(editor) = self.context_editor.as_ref() else { - return parent; - }; - let Some(element) = render_remaining_tokens(editor, cx) else { - return parent; - }; - parent.child(element) - } - _ => parent, - }) + .when(show_token_count, |parent| + parent.children(self.render_token_count(&thread, cx)) + ) .child( h_flex() .h_full() @@ -1132,6 +1134,111 @@ impl AssistantPanel { ) } + fn render_token_count(&self, thread: &Thread, cx: &App) -> Option { + let is_generating = thread.is_generating(); + let message_editor = self.message_editor.read(cx); + + let conversation_token_usage = thread.total_token_usage(cx); + let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) = + self.thread.read(cx).editing_message_id() + { + let combined = thread + .token_usage_up_to_message(editing_message_id, cx) + .add(unsent_tokens); + + (combined, unsent_tokens > 0) + } else { + let unsent_tokens = message_editor.last_estimated_token_count().unwrap_or(0); + let combined = conversation_token_usage.add(unsent_tokens); + + (combined, unsent_tokens > 0) + }; + + let is_waiting_to_update_token_count = message_editor.is_waiting_to_update_token_count(); + + match self.active_view { + ActiveView::Thread { .. } => { + if total_token_usage.total == 0 { + return None; + } + + let token_color = match total_token_usage.ratio() { + TokenUsageRatio::Normal if is_estimating => Color::Default, + TokenUsageRatio::Normal => Color::Muted, + TokenUsageRatio::Warning => Color::Warning, + TokenUsageRatio::Exceeded => Color::Error, + }; + + let token_count = h_flex() + .id("token-count") + .flex_shrink_0() + .gap_0p5() + .when(!is_generating && is_estimating, |parent| { + parent + .child( + h_flex() + .mr_0p5() + .size_2() + .justify_center() + .rounded_full() + .bg(cx.theme().colors().text.opacity(0.1)) + .child( + div().size_1().rounded_full().bg(cx.theme().colors().text), + ), + ) + .tooltip(move |window, cx| { + Tooltip::with_meta( + "Estimated New Token Count", + None, + format!( + "Current Conversation Tokens: {}", + humanize_token_count(conversation_token_usage.total) + ), + window, + cx, + ) + }) + }) + .child( + Label::new(humanize_token_count(total_token_usage.total)) + .size(LabelSize::Small) + .color(token_color) + .map(|label| { + if is_generating || is_waiting_to_update_token_count { + label + .with_animation( + "used-tokens-label", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.6, 1.)), + |label, delta| label.alpha(delta), + ) + .into_any() + } else { + label.into_any_element() + } + }), + ) + .child(Label::new("/").size(LabelSize::Small).color(Color::Muted)) + .child( + Label::new(humanize_token_count(total_token_usage.max)) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .into_any(); + + Some(token_count) + } + ActiveView::PromptEditor => { + let editor = self.context_editor.as_ref()?; + let element = render_remaining_tokens(editor, cx)?; + + Some(element.into_any_element()) + } + _ => None, + } + } + fn render_active_thread_or_empty_state( &self, window: &mut Window, diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index 59c01d16dd..df998b5102 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -2,22 +2,23 @@ use std::collections::BTreeMap; use std::sync::Arc; use crate::assistant_model_selector::ModelType; +use crate::context::format_context_as_string; use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip}; use buffer_diff::BufferDiff; use collections::HashSet; use editor::actions::MoveUp; use editor::{ - ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode, EditorStyle, - MultiBuffer, + ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorEvent, EditorMode, + EditorStyle, MultiBuffer, }; use file_icons::FileIcons; use fs::Fs; use gpui::{ - Animation, AnimationExt, App, Entity, Focusable, Subscription, TextStyle, WeakEntity, - linear_color_stop, linear_gradient, point, pulsating_between, + Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle, + WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between, }; use language::{Buffer, Language}; -use language_model::{ConfiguredModel, LanguageModelRegistry}; +use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage}; use language_model_selector::ToggleModelSelector; use multi_buffer; use project::Project; @@ -55,6 +56,8 @@ pub struct MessageEditor { edits_expanded: bool, editor_is_expanded: bool, waiting_for_summaries_to_send: bool, + last_estimated_token_count: Option, + update_token_count_task: Option>>, _subscriptions: Vec, } @@ -129,8 +132,18 @@ impl MessageEditor { let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(thread.read(cx).tools().clone(), cx)); - let subscriptions = - vec![cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event)]; + let subscriptions = vec![ + cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event), + cx.subscribe(&editor, |this, _, event, cx| match event { + EditorEvent::BufferEdited => { + this.message_or_context_changed(true, cx); + } + _ => {} + }), + cx.observe(&context_store, |this, _, cx| { + this.message_or_context_changed(false, cx); + }), + ]; Self { editor: editor.clone(), @@ -156,6 +169,8 @@ impl MessageEditor { waiting_for_summaries_to_send: false, profile_selector: cx .new(|cx| ProfileSelector::new(fs, thread_store, editor.focus_handle(cx), cx)), + last_estimated_token_count: None, + update_token_count_task: None, _subscriptions: subscriptions, } } @@ -256,6 +271,9 @@ impl MessageEditor { text }); + self.last_estimated_token_count.take(); + cx.emit(MessageEditorEvent::EstimatedTokenCount); + let refresh_task = refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx); @@ -937,6 +955,80 @@ impl MessageEditor { .label_size(LabelSize::Small), ) } + + pub fn last_estimated_token_count(&self) -> Option { + self.last_estimated_token_count + } + + pub fn is_waiting_to_update_token_count(&self) -> bool { + self.update_token_count_task.is_some() + } + + fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context) { + cx.emit(MessageEditorEvent::Changed); + self.update_token_count_task.take(); + + let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else { + self.last_estimated_token_count.take(); + return; + }; + + let context_store = self.context_store.clone(); + let editor = self.editor.clone(); + let thread = self.thread.clone(); + + self.update_token_count_task = Some(cx.spawn(async move |this, cx| { + if debounce { + cx.background_executor() + .timer(Duration::from_millis(200)) + .await; + } + + let token_count = if let Some(task) = cx.update(|cx| { + let context = context_store.read(cx).context().iter(); + let new_context = thread.read(cx).filter_new_context(context); + let context_text = + format_context_as_string(new_context, cx).unwrap_or(String::new()); + let message_text = editor.read(cx).text(cx); + + let content = context_text + &message_text; + + if content.is_empty() { + return None; + } + + let request = language_model::LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: language_model::Role::User, + content: vec![content.into()], + cache: false, + }], + tools: vec![], + stop: vec![], + temperature: None, + }; + + Some(default_model.model.count_tokens(request, cx)) + })? { + task.await? + } else { + 0 + }; + + this.update(cx, |this, cx| { + this.last_estimated_token_count = Some(token_count); + cx.emit(MessageEditorEvent::EstimatedTokenCount); + this.update_token_count_task.take(); + }) + })); + } +} + +impl EventEmitter for MessageEditor {} + +pub enum MessageEditorEvent { + EstimatedTokenCount, + Changed, } impl Focusable for MessageEditor { @@ -949,6 +1041,7 @@ impl Render for MessageEditor { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let thread = self.thread.read(cx); let total_token_usage = thread.total_token_usage(cx); + let token_usage_ratio = total_token_usage.ratio(); let action_log = self.thread.read(cx).action_log(); let changed_buffers = action_log.read(cx).changed_buffers(cx); @@ -997,15 +1090,8 @@ impl Render for MessageEditor { parent.child(self.render_changed_buffers(&changed_buffers, window, cx)) }) .child(self.render_editor(font_size, line_height, window, cx)) - .when( - total_token_usage.ratio != TokenUsageRatio::Normal, - |parent| { - parent.child(self.render_token_limit_callout( - line_height, - total_token_usage.ratio, - cx, - )) - }, - ) + .when(token_usage_ratio != TokenUsageRatio::Normal, |parent| { + parent.child(self.render_token_limit_callout(line_height, token_usage_ratio, cx)) + }) } } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 694b212e31..a0d6f99ea0 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -227,7 +227,33 @@ pub enum DetailedSummaryState { pub struct TotalTokenUsage { pub total: usize, pub max: usize, - pub ratio: TokenUsageRatio, +} + +impl TotalTokenUsage { + pub fn ratio(&self) -> TokenUsageRatio { + #[cfg(debug_assertions)] + let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") + .unwrap_or("0.8".to_string()) + .parse() + .unwrap(); + #[cfg(not(debug_assertions))] + let warning_threshold: f32 = 0.8; + + if self.total >= self.max { + TokenUsageRatio::Exceeded + } else if self.total as f32 / self.max as f32 >= warning_threshold { + TokenUsageRatio::Warning + } else { + TokenUsageRatio::Normal + } + } + + pub fn add(&self, tokens: usize) -> TotalTokenUsage { + TotalTokenUsage { + total: self.total + tokens, + max: self.max, + } + } } #[derive(Debug, Default, PartialEq, Eq)] @@ -261,6 +287,7 @@ pub struct Thread { last_restore_checkpoint: Option, pending_checkpoint: Option, initial_project_snapshot: Shared>>>, + request_token_usage: Vec, cumulative_token_usage: TokenUsage, exceeded_window_error: Option, feedback: Option, @@ -311,6 +338,7 @@ impl Thread { .spawn(async move { Some(project_snapshot.await) }) .shared() }, + request_token_usage: Vec::new(), cumulative_token_usage: TokenUsage::default(), exceeded_window_error: None, feedback: None, @@ -378,6 +406,7 @@ impl Thread { tool_use, action_log: cx.new(|_| ActionLog::new(project)), initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(), + request_token_usage: serialized.request_token_usage, cumulative_token_usage: serialized.cumulative_token_usage, exceeded_window_error: None, feedback: None, @@ -643,6 +672,18 @@ impl Thread { self.tool_use.message_has_tool_results(message_id) } + /// Filter out contexts that have already been included in previous messages + pub fn filter_new_context<'a>( + &self, + context: impl Iterator, + ) -> impl Iterator { + context.filter(|ctx| self.is_context_new(ctx)) + } + + fn is_context_new(&self, context: &AssistantContext) -> bool { + !self.context.contains_key(&context.id()) + } + pub fn insert_user_message( &mut self, text: impl Into, @@ -654,10 +695,9 @@ impl Thread { let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx); - // Filter out contexts that have already been included in previous messages let new_context: Vec<_> = context .into_iter() - .filter(|ctx| !self.context.contains_key(&ctx.id())) + .filter(|ctx| self.is_context_new(ctx)) .collect(); if !new_context.is_empty() { @@ -837,6 +877,7 @@ impl Thread { .collect(), initial_project_snapshot, cumulative_token_usage: this.cumulative_token_usage, + request_token_usage: this.request_token_usage.clone(), detailed_summary_state: this.detailed_summary_state.clone(), exceeded_window_error: this.exceeded_window_error.clone(), }) @@ -1022,7 +1063,6 @@ impl Thread { cx: &mut Context, ) { let pending_completion_id = post_inc(&mut self.completion_count); - let task = cx.spawn(async move |thread, cx| { let stream = model.stream_completion(request, &cx); let initial_token_usage = @@ -1048,6 +1088,7 @@ impl Thread { stop_reason = reason; } LanguageModelCompletionEvent::UsageUpdate(token_usage) => { + thread.update_token_usage_at_last_message(token_usage); thread.cumulative_token_usage = thread.cumulative_token_usage + token_usage - current_token_usage; @@ -1889,6 +1930,35 @@ impl Thread { self.cumulative_token_usage } + pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage { + let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else { + return TotalTokenUsage::default(); + }; + + let max = model.model.max_token_count(); + + let index = self + .messages + .iter() + .position(|msg| msg.id == message_id) + .unwrap_or(0); + + if index == 0 { + return TotalTokenUsage { total: 0, max }; + } + + let token_usage = &self + .request_token_usage + .get(index - 1) + .cloned() + .unwrap_or_default(); + + TotalTokenUsage { + total: token_usage.total_tokens() as usize, + max, + } + } + pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage { let model_registry = LanguageModelRegistry::read_global(cx); let Some(model) = model_registry.default_model() else { @@ -1902,30 +1972,33 @@ impl Thread { return TotalTokenUsage { total: exceeded_error.token_count, max, - ratio: TokenUsageRatio::Exceeded, }; } } - #[cfg(debug_assertions)] - let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") - .unwrap_or("0.8".to_string()) - .parse() - .unwrap(); - #[cfg(not(debug_assertions))] - let warning_threshold: f32 = 0.8; + let total = self + .token_usage_at_last_message() + .unwrap_or_default() + .total_tokens() as usize; - let total = self.cumulative_token_usage.total_tokens() as usize; + TotalTokenUsage { total, max } + } - let ratio = if total >= max { - TokenUsageRatio::Exceeded - } else if total as f32 / max as f32 >= warning_threshold { - TokenUsageRatio::Warning - } else { - TokenUsageRatio::Normal - }; + fn token_usage_at_last_message(&self) -> Option { + self.request_token_usage + .get(self.messages.len().saturating_sub(1)) + .or_else(|| self.request_token_usage.last()) + .cloned() + } - TotalTokenUsage { total, max, ratio } + fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) { + let placeholder = self.token_usage_at_last_message().unwrap_or_default(); + self.request_token_usage + .resize(self.messages.len(), placeholder); + + if let Some(last) = self.request_token_usage.last_mut() { + *last = token_usage; + } } pub fn deny_tool_use( diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 6fb0f6c7a2..a72313061c 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -509,6 +509,8 @@ pub struct SerializedThread { #[serde(default)] pub cumulative_token_usage: TokenUsage, #[serde(default)] + pub request_token_usage: Vec, + #[serde(default)] pub detailed_summary_state: DetailedSummaryState, #[serde(default)] pub exceeded_window_error: Option, @@ -597,6 +599,7 @@ impl LegacySerializedThread { messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(), initial_project_snapshot: self.initial_project_snapshot, cumulative_token_usage: TokenUsage::default(), + request_token_usage: Vec::new(), detailed_summary_state: DetailedSummaryState::default(), exceeded_window_error: None, } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index cf695023c8..35bf5d6094 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -97,7 +97,10 @@ pub struct TokenUsage { impl TokenUsage { pub fn total_tokens(&self) -> u32 { - self.input_tokens + self.output_tokens + self.input_tokens + + self.output_tokens + + self.cache_read_input_tokens + + self.cache_creation_input_tokens } } diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 7746d214b4..ee0f941afa 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -705,12 +705,12 @@ pub fn map_to_language_model_completion_events( update_usage(&mut state.usage, &message.usage); return Some(( vec![ - Ok(LanguageModelCompletionEvent::StartMessage { - message_id: message.id, - }), Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( &state.usage, ))), + Ok(LanguageModelCompletionEvent::StartMessage { + message_id: message.id, + }), ], state, )); diff --git a/crates/ui/src/components/tab.rs b/crates/ui/src/components/tab.rs index b9b4cb43ce..8b7d5bbdd4 100644 --- a/crates/ui/src/components/tab.rs +++ b/crates/ui/src/components/tab.rs @@ -73,11 +73,11 @@ impl Tab { self } - pub fn content_height(cx: &mut App) -> Pixels { + pub fn content_height(cx: &App) -> Pixels { DynamicSpacing::Base32.px(cx) - px(1.) } - pub fn container_height(cx: &mut App) -> Pixels { + pub fn container_height(cx: &App) -> Pixels { DynamicSpacing::Base32.px(cx) } }