diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 2a39440af8..bc32a79622 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -26,6 +26,7 @@ assistant_context.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true chrono.workspace = true +client.workspace = true cloud_llm_client.workspace = true collections.workspace = true context_server.workspace = true diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 3c605de803..ab5716d8ad 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,8 +1,8 @@ -use crate::HistoryStore; use crate::{ ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization, UserMessageContent, templates::Templates, }; +use crate::{HistoryStore, TokenUsageUpdated}; use acp_thread::{AcpThread, AgentModelSelector}; use action_log::ActionLog; use agent_client_protocol as acp; @@ -253,6 +253,7 @@ impl NativeAgent { cx.observe_release(&acp_thread, |this, acp_thread, _cx| { this.sessions.remove(acp_thread.session_id()); }), + cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated), cx.observe(&thread_handle, move |this, thread, cx| { this.save_thread(thread.clone(), cx) }), @@ -440,6 +441,23 @@ impl NativeAgent { }) } + fn handle_thread_token_usage_updated( + &mut self, + thread: Entity, + usage: &TokenUsageUpdated, + cx: &mut Context, + ) { + let Some(session) = self.sessions.get(thread.read(cx).id()) else { + return; + }; + session + .acp_thread + .update(cx, |acp_thread, cx| { + acp_thread.update_token_usage(usage.0.clone(), cx); + }) + .ok(); + } + fn handle_project_event( &mut self, _project: Entity, @@ -695,11 +713,6 @@ impl NativeAgentConnection { thread.update_tool_call(update, cx) })??; } - ThreadEvent::TokenUsageUpdate(usage) => { - acp_thread.update(cx, |thread, cx| { - thread.update_token_usage(Some(usage), cx) - })?; - } ThreadEvent::TitleUpdate(title) => { acp_thread .update(cx, |thread, cx| thread.update_title(title, cx))??; diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index c1778bf38b..b6405dbcbd 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -15,7 +15,8 @@ use agent_settings::{ use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; use chrono::{DateTime, Utc}; -use cloud_llm_client::{CompletionIntent, CompletionRequestStatus}; +use client::{ModelRequestUsage, RequestUsage}; +use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; use collections::{HashMap, IndexMap}; use fs::Fs; use futures::{ @@ -25,7 +26,9 @@ use futures::{ stream::FuturesUnordered, }; use git::repository::DiffType; -use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity}; +use gpui::{ + App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, +}; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt, LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest, @@ -484,7 +487,6 @@ pub enum ThreadEvent { ToolCall(acp::ToolCall), ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), - TokenUsageUpdate(acp_thread::TokenUsage), TitleUpdate(SharedString), Retry(acp_thread::RetryStatus), Stop(acp::StopReason), @@ -873,7 +875,12 @@ impl Thread { } pub fn set_model(&mut self, model: Arc, cx: &mut Context) { + let old_usage = self.latest_token_usage(); self.model = Some(model); + let new_usage = self.latest_token_usage(); + if old_usage != new_usage { + cx.emit(TokenUsageUpdated(new_usage)); + } cx.notify() } @@ -891,7 +898,12 @@ impl Thread { } pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context) { + let old_usage = self.latest_token_usage(); self.completion_mode = mode; + let new_usage = self.latest_token_usage(); + if old_usage != new_usage { + cx.emit(TokenUsageUpdated(new_usage)); + } cx.notify() } @@ -953,13 +965,15 @@ impl Thread { self.flush_pending_message(cx); } - pub fn update_token_usage(&mut self, update: language_model::TokenUsage) { + fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context) { let Some(last_user_message) = self.last_user_message() else { return; }; self.request_token_usage .insert(last_user_message.id.clone(), update); + cx.emit(TokenUsageUpdated(self.latest_token_usage())); + cx.notify(); } pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context) -> Result<()> { @@ -1180,20 +1194,15 @@ impl Thread { )) => { *tool_use_limit_reached = true; } + Ok(LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::UsageUpdated { amount, limit }, + )) => { + this.update(cx, |this, cx| { + this.update_model_request_usage(amount, limit, cx) + })?; + } Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => { - let usage = acp_thread::TokenUsage { - max_tokens: model.max_token_count_for_mode( - request - .mode - .unwrap_or(cloud_llm_client::CompletionMode::Normal), - ), - used_tokens: token_usage.total_tokens(), - }; - - this.update(cx, |this, _cx| this.update_token_usage(token_usage)) - .ok(); - - event_stream.send_token_usage_update(usage); + this.update(cx, |this, cx| this.update_token_usage(token_usage, cx))?; } Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => { *refusal = true; @@ -1214,8 +1223,7 @@ impl Thread { event_stream, cx, )); - }) - .ok(); + })?; } Err(error) => { let completion_mode = @@ -1325,8 +1333,8 @@ impl Thread { json_parse_error, ))); } - UsageUpdate(_) | StatusUpdate(_) => {} - Stop(_) => unreachable!(), + StatusUpdate(_) => {} + UsageUpdate(_) | Stop(_) => unreachable!(), } None @@ -1506,6 +1514,21 @@ impl Thread { } } + fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context) { + self.project + .read(cx) + .user_store() + .update(cx, |user_store, cx| { + user_store.update_model_request_usage( + ModelRequestUsage(RequestUsage { + amount: amount as i32, + limit, + }), + cx, + ) + }); + } + pub fn title(&self) -> SharedString { self.title.clone().unwrap_or("New Thread".into()) } @@ -1636,6 +1659,7 @@ impl Thread { }) })) } + fn last_user_message(&self) -> Option<&UserMessage> { self.messages .iter() @@ -1934,6 +1958,10 @@ impl RunningTurn { } } +pub struct TokenUsageUpdated(pub Option); + +impl EventEmitter for Thread {} + pub trait AgentTool where Self: 'static + Sized, @@ -2166,12 +2194,6 @@ impl ThreadEventStream { .ok(); } - fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) { - self.0 - .unbounded_send(Ok(ThreadEvent::TokenUsageUpdate(usage))) - .ok(); - } - fn send_retry(&self, status: acp_thread::RetryStatus) { self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok(); }