agent: Fix conversation token usage and estimate unsent message (#28878)

The UI was mistakenly using the cumulative token usage for the token
counter. It now displays the last request's token count, plus an
estimate of the tokens in the message editor and in context entries that
haven't been sent yet.
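
As a rough sketch of how the counter is assembled (using the
`Thread::total_token_usage` and `TotalTokenUsage::add` APIs from this change;
`estimate_tokens` and its chars/4 heuristic are illustrative placeholders, not
the editor's actual estimator):

```rust
// Sketch only: combine the last request's reported usage with a rough
// estimate for text that hasn't been sent yet.
fn unsent_counter_usage(
    thread: &Thread,
    editor_text: &str,
    unsent_context: &[String],
    cx: &App,
) -> TotalTokenUsage {
    // Actual token usage reported by the last completed request.
    let last_request = thread.total_token_usage(cx);

    // Estimated tokens for the draft message and any context entries that
    // haven't been sent yet (placeholder heuristic: ~4 chars per token).
    let estimated = estimate_tokens(editor_text)
        + unsent_context
            .iter()
            .map(|c| estimate_tokens(c))
            .sum::<usize>();

    last_request.add(estimated)
}

fn estimate_tokens(text: &str) -> usize {
    text.len() / 4
}
```

The resulting `TotalTokenUsage` is then fed through `ratio()` to pick the
normal/warning/exceeded styling.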


https://github.com/user-attachments/assets/0438c501-b850-4397-9135-57214ca3c07a

Additionally, when the user edits a message, we display the actual token
count recorded up to that message, plus an estimate of the tokens in the
edited message.
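
A similar sketch for the editing path, assuming the `token_usage_up_to_message`
method added below; `message_id`, `draft`, and the chars/4 estimate are again
illustrative:

```rust
// Sketch only: when editing an existing message, start from the usage
// actually recorded for requests up to that message, then add an estimate
// for the draft being typed.
fn edited_message_usage(
    thread: &Thread,
    message_id: MessageId,
    draft: &str,
    cx: &App,
) -> TotalTokenUsage {
    let recorded = thread.token_usage_up_to_message(message_id, cx);
    recorded.add(draft.len() / 4) // placeholder chars-per-token estimate
}
```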

Note: We don't currently estimate the token delta when switching profiles. In
the future, we want to use the count-tokens API to measure every part of
the request and display a breakdown.

Release Notes:

- agent: Made the token count more accurate and added back estimation of
used tokens as you type and add context.

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Agus Zubiaga 2025-04-16 13:27:36 -06:00 committed by GitHub
parent 8de53bd89f
commit 0286b8ab3e
8 changed files with 507 additions and 130 deletions

@@ -227,7 +227,33 @@ pub enum DetailedSummaryState {
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
    pub ratio: TokenUsageRatio,
}

impl TotalTokenUsage {
    pub fn ratio(&self) -> TokenUsageRatio {
        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or("0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        if self.total >= self.max {
            TokenUsageRatio::Exceeded
        } else if self.total as f32 / self.max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        }
    }

    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
        TotalTokenUsage {
            total: self.total + tokens,
            max: self.max,
        }
    }
}

#[derive(Debug, Default, PartialEq, Eq)]
@@ -261,6 +287,7 @@ pub struct Thread {
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
@@ -311,6 +338,7 @@ impl Thread {
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
@@ -378,6 +406,7 @@ impl Thread {
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
@@ -643,6 +672,18 @@ impl Thread {
        self.tool_use.message_has_tool_results(message_id)
    }

    /// Filter out contexts that have already been included in previous messages
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }

    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }

    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
@@ -654,10 +695,9 @@ impl Thread {
        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Filter out contexts that have already been included in previous messages
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| !self.context.contains_key(&ctx.id()))
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
@@ -837,6 +877,7 @@ impl Thread {
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
@@ -1022,7 +1063,6 @@ impl Thread {
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            let initial_token_usage =
@@ -1048,6 +1088,7 @@ impl Thread {
                            stop_reason = reason;
                        }
                        LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                            thread.update_token_usage_at_last_message(token_usage);
                            thread.cumulative_token_usage = thread.cumulative_token_usage
                                + token_usage
                                - current_token_usage;
@@ -1889,6 +1930,35 @@ impl Thread {
        self.cumulative_token_usage
    }

    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
            return TotalTokenUsage::default();
        };

        let max = model.model.max_token_count();

        let index = self
            .messages
            .iter()
            .position(|msg| msg.id == message_id)
            .unwrap_or(0);

        if index == 0 {
            return TotalTokenUsage { total: 0, max };
        }

        let token_usage = &self
            .request_token_usage
            .get(index - 1)
            .cloned()
            .unwrap_or_default();

        TotalTokenUsage {
            total: token_usage.total_tokens() as usize,
            max,
        }
    }

    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
        let model_registry = LanguageModelRegistry::read_global(cx);
        let Some(model) = model_registry.default_model() else {
@@ -1902,30 +1972,33 @@ impl Thread {
                return TotalTokenUsage {
                    total: exceeded_error.token_count,
                    max,
                    ratio: TokenUsageRatio::Exceeded,
                };
            }
        }

        #[cfg(debug_assertions)]
        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
            .unwrap_or("0.8".to_string())
            .parse()
            .unwrap();
        #[cfg(not(debug_assertions))]
        let warning_threshold: f32 = 0.8;

        let total = self
            .token_usage_at_last_message()
            .unwrap_or_default()
            .total_tokens() as usize;
        let total = self.cumulative_token_usage.total_tokens() as usize;

        TotalTokenUsage { total, max }
    }

        let ratio = if total >= max {
            TokenUsageRatio::Exceeded
        } else if total as f32 / max as f32 >= warning_threshold {
            TokenUsageRatio::Warning
        } else {
            TokenUsageRatio::Normal
        };

    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
        self.request_token_usage
            .get(self.messages.len().saturating_sub(1))
            .or_else(|| self.request_token_usage.last())
            .cloned()
    }

        TotalTokenUsage { total, max, ratio }

    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
        self.request_token_usage
            .resize(self.messages.len(), placeholder);

        if let Some(last) = self.request_token_usage.last_mut() {
            *last = token_usage;
        }
    }

    pub fn deny_tool_use(