agent2: Fix token count not updating when changing model/toggling burn mode (#36562)
Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
44941b5dfe
commit
4290f043cd
3 changed files with 69 additions and 33 deletions
|
@ -26,6 +26,7 @@ assistant_context.workspace = true
|
|||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use crate::HistoryStore;
|
||||
use crate::{
|
||||
ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
|
||||
UserMessageContent, templates::Templates,
|
||||
};
|
||||
use crate::{HistoryStore, TokenUsageUpdated};
|
||||
use acp_thread::{AcpThread, AgentModelSelector};
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol as acp;
|
||||
|
@ -253,6 +253,7 @@ impl NativeAgent {
|
|||
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
}),
|
||||
cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
|
||||
cx.observe(&thread_handle, move |this, thread, cx| {
|
||||
this.save_thread(thread.clone(), cx)
|
||||
}),
|
||||
|
@ -440,6 +441,23 @@ impl NativeAgent {
|
|||
})
|
||||
}
|
||||
|
||||
fn handle_thread_token_usage_updated(
|
||||
&mut self,
|
||||
thread: Entity<Thread>,
|
||||
usage: &TokenUsageUpdated,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(session) = self.sessions.get(thread.read(cx).id()) else {
|
||||
return;
|
||||
};
|
||||
session
|
||||
.acp_thread
|
||||
.update(cx, |acp_thread, cx| {
|
||||
acp_thread.update_token_usage(usage.0.clone(), cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn handle_project_event(
|
||||
&mut self,
|
||||
_project: Entity<Project>,
|
||||
|
@ -695,11 +713,6 @@ impl NativeAgentConnection {
|
|||
thread.update_tool_call(update, cx)
|
||||
})??;
|
||||
}
|
||||
ThreadEvent::TokenUsageUpdate(usage) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.update_token_usage(Some(usage), cx)
|
||||
})?;
|
||||
}
|
||||
ThreadEvent::TitleUpdate(title) => {
|
||||
acp_thread
|
||||
.update(cx, |thread, cx| thread.update_title(title, cx))??;
|
||||
|
|
|
@ -15,7 +15,8 @@ use agent_settings::{
|
|||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::adapt_schema_to_format;
|
||||
use chrono::{DateTime, Utc};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
|
||||
use client::{ModelRequestUsage, RequestUsage};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
use collections::{HashMap, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
|
@ -25,7 +26,9 @@ use futures::{
|
|||
stream::FuturesUnordered,
|
||||
};
|
||||
use git::repository::DiffType;
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
|
||||
use gpui::{
|
||||
App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
|
||||
LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
|
||||
|
@ -484,7 +487,6 @@ pub enum ThreadEvent {
|
|||
ToolCall(acp::ToolCall),
|
||||
ToolCallUpdate(acp_thread::ToolCallUpdate),
|
||||
ToolCallAuthorization(ToolCallAuthorization),
|
||||
TokenUsageUpdate(acp_thread::TokenUsage),
|
||||
TitleUpdate(SharedString),
|
||||
Retry(acp_thread::RetryStatus),
|
||||
Stop(acp::StopReason),
|
||||
|
@ -873,7 +875,12 @@ impl Thread {
|
|||
}
|
||||
|
||||
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
|
||||
let old_usage = self.latest_token_usage();
|
||||
self.model = Some(model);
|
||||
let new_usage = self.latest_token_usage();
|
||||
if old_usage != new_usage {
|
||||
cx.emit(TokenUsageUpdated(new_usage));
|
||||
}
|
||||
cx.notify()
|
||||
}
|
||||
|
||||
|
@ -891,7 +898,12 @@ impl Thread {
|
|||
}
|
||||
|
||||
pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
|
||||
let old_usage = self.latest_token_usage();
|
||||
self.completion_mode = mode;
|
||||
let new_usage = self.latest_token_usage();
|
||||
if old_usage != new_usage {
|
||||
cx.emit(TokenUsageUpdated(new_usage));
|
||||
}
|
||||
cx.notify()
|
||||
}
|
||||
|
||||
|
@ -953,13 +965,15 @@ impl Thread {
|
|||
self.flush_pending_message(cx);
|
||||
}
|
||||
|
||||
// Stores `update` in `request_token_usage` under the id of the message
// returned by `last_user_message()`, then emits `TokenUsageUpdated` with the
// value of `latest_token_usage()` and notifies observers.
// Returns early, without storing or emitting anything, when
// `last_user_message()` yields `None`.
// NOTE(review): two signature lines follow because this text is a rendered
// diff — presumably the `pub fn` form without `cx` is the pre-change line and
// the private form taking `cx` is the post-change line; confirm against the
// repository before relying on either.
pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
|
||||
fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context<Self>) {
|
||||
let Some(last_user_message) = self.last_user_message() else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.request_token_usage
|
||||
.insert(last_user_message.id.clone(), update);
|
||||
// The event carries the recomputed latest usage, not the raw `update`.
cx.emit(TokenUsageUpdated(self.latest_token_usage()));
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
|
||||
|
@ -1180,20 +1194,15 @@ impl Thread {
|
|||
)) => {
|
||||
*tool_use_limit_reached = true;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::UsageUpdated { amount, limit },
|
||||
)) => {
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_model_request_usage(amount, limit, cx)
|
||||
})?;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
let usage = acp_thread::TokenUsage {
|
||||
max_tokens: model.max_token_count_for_mode(
|
||||
request
|
||||
.mode
|
||||
.unwrap_or(cloud_llm_client::CompletionMode::Normal),
|
||||
),
|
||||
used_tokens: token_usage.total_tokens(),
|
||||
};
|
||||
|
||||
this.update(cx, |this, _cx| this.update_token_usage(token_usage))
|
||||
.ok();
|
||||
|
||||
event_stream.send_token_usage_update(usage);
|
||||
this.update(cx, |this, cx| this.update_token_usage(token_usage, cx))?;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
|
||||
*refusal = true;
|
||||
|
@ -1214,8 +1223,7 @@ impl Thread {
|
|||
event_stream,
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
})?;
|
||||
}
|
||||
Err(error) => {
|
||||
let completion_mode =
|
||||
|
@ -1325,8 +1333,8 @@ impl Thread {
|
|||
json_parse_error,
|
||||
)));
|
||||
}
|
||||
UsageUpdate(_) | StatusUpdate(_) => {}
|
||||
Stop(_) => unreachable!(),
|
||||
StatusUpdate(_) => {}
|
||||
UsageUpdate(_) | Stop(_) => unreachable!(),
|
||||
}
|
||||
|
||||
None
|
||||
|
@ -1506,6 +1514,21 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context<Self>) {
|
||||
self.project
|
||||
.read(cx)
|
||||
.user_store()
|
||||
.update(cx, |user_store, cx| {
|
||||
user_store.update_model_request_usage(
|
||||
ModelRequestUsage(RequestUsage {
|
||||
amount: amount as i32,
|
||||
limit,
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
/// Returns the thread's title, falling back to the placeholder
/// `"New Thread"` when no title is set.
pub fn title(&self) -> SharedString {
    match &self.title {
        Some(title) => title.clone(),
        None => "New Thread".into(),
    }
}
|
||||
|
@ -1636,6 +1659,7 @@ impl Thread {
|
|||
})
|
||||
}))
|
||||
}
|
||||
|
||||
fn last_user_message(&self) -> Option<&UserMessage> {
|
||||
self.messages
|
||||
.iter()
|
||||
|
@ -1934,6 +1958,10 @@ impl RunningTurn {
|
|||
}
|
||||
}
|
||||
|
||||
/// Event payload wrapping an optional `acp_thread::TokenUsage`.
pub struct TokenUsageUpdated(pub Option<acp_thread::TokenUsage>);
|
||||
|
||||
// Declares `Thread` as an emitter of `TokenUsageUpdated` events.
impl EventEmitter<TokenUsageUpdated> for Thread {}
|
||||
|
||||
pub trait AgentTool
|
||||
where
|
||||
Self: 'static + Sized,
|
||||
|
@ -2166,12 +2194,6 @@ impl ThreadEventStream {
|
|||
.ok();
|
||||
}
|
||||
|
||||
fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) {
|
||||
self.0
|
||||
.unbounded_send(Ok(ThreadEvent::TokenUsageUpdate(usage)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_retry(&self, status: acp_thread::RetryStatus) {
|
||||
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue