agent2: Fix token count not updating when changing model/toggling burn mode (#36562)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Bennet Bo Fenner 2025-08-20 11:29:05 +02:00 committed by GitHub
parent 44941b5dfe
commit 4290f043cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 69 additions and 33 deletions

View file

@ -26,6 +26,7 @@ assistant_context.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
context_server.workspace = true

View file

@ -1,8 +1,8 @@
use crate::HistoryStore;
use crate::{
ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
UserMessageContent, templates::Templates,
};
use crate::{HistoryStore, TokenUsageUpdated};
use acp_thread::{AcpThread, AgentModelSelector};
use action_log::ActionLog;
use agent_client_protocol as acp;
@ -253,6 +253,7 @@ impl NativeAgent {
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
}),
cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
cx.observe(&thread_handle, move |this, thread, cx| {
this.save_thread(thread.clone(), cx)
}),
@ -440,6 +441,23 @@ impl NativeAgent {
})
}
/// Relays a `TokenUsageUpdated` event from a `Thread` to the ACP thread of
/// the session registered under that thread's id.
///
/// Nothing happens when no such session exists, and an error returned by the
/// ACP thread update is discarded.
fn handle_thread_token_usage_updated(
    &mut self,
    thread: Entity<Thread>,
    usage: &TokenUsageUpdated,
    cx: &mut Context<Self>,
) {
    if let Some(session) = self.sessions.get(thread.read(cx).id()) {
        // Best effort: the result of the update is intentionally dropped.
        session
            .acp_thread
            .update(cx, |acp_thread, cx| {
                acp_thread.update_token_usage(usage.0.clone(), cx);
            })
            .ok();
    }
}
fn handle_project_event(
&mut self,
_project: Entity<Project>,
@ -695,11 +713,6 @@ impl NativeAgentConnection {
thread.update_tool_call(update, cx)
})??;
}
ThreadEvent::TokenUsageUpdate(usage) => {
acp_thread.update(cx, |thread, cx| {
thread.update_token_usage(Some(usage), cx)
})?;
}
ThreadEvent::TitleUpdate(title) => {
acp_thread
.update(cx, |thread, cx| thread.update_title(title, cx))??;

View file

@ -15,7 +15,8 @@ use agent_settings::{
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use chrono::{DateTime, Utc};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
use collections::{HashMap, IndexMap};
use fs::Fs;
use futures::{
@ -25,7 +26,9 @@ use futures::{
stream::FuturesUnordered,
};
use git::repository::DiffType;
use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
@ -484,7 +487,6 @@ pub enum ThreadEvent {
ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
TokenUsageUpdate(acp_thread::TokenUsage),
TitleUpdate(SharedString),
Retry(acp_thread::RetryStatus),
Stop(acp::StopReason),
@ -873,7 +875,12 @@ impl Thread {
}
/// Installs `model` as the thread's model and notifies observers.
///
/// `latest_token_usage()` is sampled before and after the swap; a
/// `TokenUsageUpdated` event is emitted only when the two samples differ.
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
    let usage_before = self.latest_token_usage();
    self.model = Some(model);
    let usage_after = self.latest_token_usage();

    let usage_changed = usage_before != usage_after;
    if usage_changed {
        cx.emit(TokenUsageUpdated(usage_after));
    }
    cx.notify();
}
@ -891,7 +898,12 @@ impl Thread {
}
/// Switches the thread's completion mode and notifies observers.
///
/// `latest_token_usage()` is sampled before and after the switch; a
/// `TokenUsageUpdated` event is emitted only when the two samples differ.
pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
    let usage_before = self.latest_token_usage();
    self.completion_mode = mode;
    let usage_after = self.latest_token_usage();

    let usage_changed = usage_before != usage_after;
    if usage_changed {
        cx.emit(TokenUsageUpdated(usage_after));
    }
    cx.notify();
}
@ -953,13 +965,15 @@ impl Thread {
self.flush_pending_message(cx);
}
pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
/// Stores `update` as the token usage for the last user message, then emits
/// `TokenUsageUpdated` with the latest usage and notifies observers.
///
/// Returns without doing anything when the thread has no user message.
fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context<Self>) {
    // Copy the id out first so the shared borrow of `self` ends before the map is mutated.
    let Some(message_id) = self.last_user_message().map(|message| message.id.clone()) else {
        return;
    };
    self.request_token_usage.insert(message_id, update);

    let latest_usage = self.latest_token_usage();
    cx.emit(TokenUsageUpdated(latest_usage));
    cx.notify();
}
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
@ -1180,20 +1194,15 @@ impl Thread {
)) => {
*tool_use_limit_reached = true;
}
Ok(LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::UsageUpdated { amount, limit },
)) => {
this.update(cx, |this, cx| {
this.update_model_request_usage(amount, limit, cx)
})?;
}
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
let usage = acp_thread::TokenUsage {
max_tokens: model.max_token_count_for_mode(
request
.mode
.unwrap_or(cloud_llm_client::CompletionMode::Normal),
),
used_tokens: token_usage.total_tokens(),
};
this.update(cx, |this, _cx| this.update_token_usage(token_usage))
.ok();
event_stream.send_token_usage_update(usage);
this.update(cx, |this, cx| this.update_token_usage(token_usage, cx))?;
}
Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
*refusal = true;
@ -1214,8 +1223,7 @@ impl Thread {
event_stream,
cx,
));
})
.ok();
})?;
}
Err(error) => {
let completion_mode =
@ -1325,8 +1333,8 @@ impl Thread {
json_parse_error,
)));
}
UsageUpdate(_) | StatusUpdate(_) => {}
Stop(_) => unreachable!(),
StatusUpdate(_) => {}
UsageUpdate(_) | Stop(_) => unreachable!(),
}
None
@ -1506,6 +1514,21 @@ impl Thread {
}
}
/// Reports the latest model-request usage to the project's user store.
///
/// `amount` is converted to the `i32` held by `RequestUsage`. A value that does
/// not fit is clamped to `i32::MAX` rather than wrapping to a negative number.
fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context<Self>) {
    // A bare `as i32` cast silently wraps for values above `i32::MAX`; clamp instead.
    let amount = i32::try_from(amount).unwrap_or(i32::MAX);
    self.project
        .read(cx)
        .user_store()
        .update(cx, |user_store, cx| {
            user_store.update_model_request_usage(
                ModelRequestUsage(RequestUsage { amount, limit }),
                cx,
            )
        });
}
/// Returns the thread's title, or `"New Thread"` when no title has been set.
pub fn title(&self) -> SharedString {
    match &self.title {
        Some(title) => title.clone(),
        None => "New Thread".into(),
    }
}
@ -1636,6 +1659,7 @@ impl Thread {
})
}))
}
fn last_user_message(&self) -> Option<&UserMessage> {
self.messages
.iter()
@ -1934,6 +1958,10 @@ impl RunningTurn {
}
}
/// Event emitted by [`Thread`] to announce its token usage.
///
/// The payload is an optional `acp_thread::TokenUsage`; the emitters in this
/// file fill it with the value returned by `Thread::latest_token_usage()`.
pub struct TokenUsageUpdated(pub Option<acp_thread::TokenUsage>);
// Lets `Thread` emit `TokenUsageUpdated` through `cx.emit`, so other entities can subscribe to it.
impl EventEmitter<TokenUsageUpdated> for Thread {}
pub trait AgentTool
where
Self: 'static + Sized,
@ -2166,12 +2194,6 @@ impl ThreadEventStream {
.ok();
}
/// Sends a `ThreadEvent::TokenUsageUpdate` carrying `usage` through the
/// stream's sender; a failed send is ignored.
fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) {
    let event = ThreadEvent::TokenUsageUpdate(usage);
    self.0.unbounded_send(Ok(event)).ok();
}
/// Sends a `ThreadEvent::Retry` carrying `status` through the stream's
/// sender; a failed send is ignored.
fn send_retry(&self, status: acp_thread::RetryStatus) {
    let event = ThreadEvent::Retry(status);
    self.0.unbounded_send(Ok(event)).ok();
}