agent2: Token count (#36496)

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>
This commit is contained in:
Bennet Bo Fenner 2025-08-19 22:40:31 +02:00 committed by GitHub
parent 6825715503
commit 5fb68cb8be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 321 additions and 26 deletions

View file

@ -13,7 +13,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use chrono::{DateTime, Utc};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
use collections::IndexMap;
use collections::{HashMap, IndexMap};
use fs::Fs;
use futures::{
FutureExt,
@ -24,8 +24,8 @@ use futures::{
use git::repository::DiffType;
use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
@ -481,6 +481,7 @@ pub enum ThreadEvent {
ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
TokenUsageUpdate(acp_thread::TokenUsage),
TitleUpdate(SharedString),
Retry(acp_thread::RetryStatus),
Stop(acp::StopReason),
@ -509,8 +510,7 @@ pub struct Thread {
pending_message: Option<AgentMessage>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
tool_use_limit_reached: bool,
#[allow(unused)]
request_token_usage: Vec<TokenUsage>,
request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
#[allow(unused)]
cumulative_token_usage: TokenUsage,
#[allow(unused)]
@ -548,7 +548,7 @@ impl Thread {
pending_message: None,
tools: BTreeMap::default(),
tool_use_limit_reached: false,
request_token_usage: Vec::new(),
request_token_usage: HashMap::default(),
cumulative_token_usage: TokenUsage::default(),
initial_project_snapshot: {
let project_snapshot = Self::project_snapshot(project.clone(), cx);
@ -951,6 +951,15 @@ impl Thread {
self.flush_pending_message(cx);
}
/// Records `update` as the token usage of the most recent user message.
///
/// The usage is stored in `request_token_usage` under that message's id.
/// If the thread has no user message, the update is dropped.
pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
    if let Some(message) = self.last_user_message() {
        let message_id = message.id.clone();
        self.request_token_usage.insert(message_id, update);
    }
}
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
self.cancel(cx);
let Some(position) = self.messages.iter().position(
@ -958,11 +967,31 @@ impl Thread {
) else {
return Err(anyhow!("Message not found"));
};
self.messages.truncate(position);
for message in self.messages.drain(position..) {
match message {
Message::User(message) => {
self.request_token_usage.remove(&message.id);
}
Message::Agent(_) | Message::Resume => {}
}
}
cx.notify();
Ok(())
}
/// Returns the token usage recorded for the most recent user message,
/// together with the current model's token limit for the active completion mode.
///
/// Returns `None` when the thread has no user message, when no usage has been
/// recorded for that message, or when no model is set.
pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
    let last_user_message = self.last_user_message()?;
    let tokens = self.request_token_usage.get(&last_user_message.id)?;
    // Borrow the model instead of cloning it: it is only read for the limit lookup.
    let model = self.model.as_ref()?;
    Some(acp_thread::TokenUsage {
        max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
        used_tokens: tokens.total_tokens(),
    })
}
pub fn resume(
&mut self,
cx: &mut Context<Self>,
@ -1148,6 +1177,21 @@ impl Thread {
)) => {
*tool_use_limit_reached = true;
}
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
let usage = acp_thread::TokenUsage {
max_tokens: model.max_token_count_for_mode(
request
.mode
.unwrap_or(cloud_llm_client::CompletionMode::Normal),
),
used_tokens: token_usage.total_tokens(),
};
this.update(cx, |this, _cx| this.update_token_usage(token_usage))
.ok();
event_stream.send_token_usage_update(usage);
}
Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
*refusal = true;
return Ok(FuturesUnordered::default());
@ -1532,6 +1576,16 @@ impl Thread {
})
}))
}
/// Finds the most recent user message in the thread, if there is one.
///
/// Walks `self.messages` from the end, skipping agent and resume entries.
fn last_user_message(&self) -> Option<&UserMessage> {
    for message in self.messages.iter().rev() {
        match message {
            Message::User(user_message) => return Some(user_message),
            Message::Agent(_) | Message::Resume => {}
        }
    }
    None
}
fn pending_message(&mut self) -> &mut AgentMessage {
self.pending_message.get_or_insert_default()
@ -2051,6 +2105,12 @@ impl ThreadEventStream {
.ok();
}
/// Sends a `ThreadEvent::TokenUsageUpdate` carrying `usage` through the stream's sender.
///
/// The send result is discarded, so a failed send is ignored.
fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) {
    let event = ThreadEvent::TokenUsageUpdate(usage);
    self.0.unbounded_send(Ok(event)).ok();
}
/// Sends a `ThreadEvent::Retry` carrying `status` through the stream's sender.
///
/// The send result is discarded, so a failed send is ignored.
fn send_retry(&self, status: acp_thread::RetryStatus) {
    let event = ThreadEvent::Retry(status);
    self.0.unbounded_send(Ok(event)).ok();
}