diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index d4d73e1edd..793ef35be2 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -6,6 +6,7 @@ mod terminal; pub use connection::*; pub use diff::*; pub use mention::*; +use serde::{Deserialize, Serialize}; pub use terminal::*; use action_log::ActionLog; @@ -664,6 +665,12 @@ impl PlanEntry { } } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TokenUsage { + pub max_tokens: u64, + pub used_tokens: u64, +} + #[derive(Debug, Clone)] pub struct RetryStatus { pub last_error: SharedString, @@ -683,12 +690,14 @@ pub struct AcpThread { send_task: Option>, connection: Rc, session_id: acp::SessionId, + token_usage: Option, } #[derive(Debug)] pub enum AcpThreadEvent { NewEntry, TitleUpdated, + TokenUsageUpdated, EntryUpdated(usize), EntriesRemoved(Range), ToolAuthorizationRequired, @@ -748,6 +757,7 @@ impl AcpThread { send_task: None, connection, session_id, + token_usage: None, } } @@ -787,6 +797,10 @@ impl AcpThread { } } + pub fn token_usage(&self) -> Option<&TokenUsage> { + self.token_usage.as_ref() + } + pub fn has_pending_edit_tool_calls(&self) -> bool { for entry in self.entries.iter().rev() { match entry { @@ -937,6 +951,11 @@ impl AcpThread { Ok(()) } + pub fn update_token_usage(&mut self, usage: Option, cx: &mut Context) { + self.token_usage = usage; + cx.emit(AcpThreadEvent::TokenUsageUpdated); + } + pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context) { cx.emit(AcpThreadEvent::Retry(status)); } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index b09f383029..8cae975ce5 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -10,7 +10,7 @@ use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use ui::{App, IconName}; use uuid::Uuid; -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct UserMessageId(Arc); impl UserMessageId { diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 890f7e774b..d18773ff7b 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -66,6 +66,7 @@ zstd.workspace = true [dev-dependencies] agent = { workspace = true, "features" = ["test-support"] } +assistant_context = { workspace = true, "features" = ["test-support"] } ctor.workspace = true client = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 48f46a52fc..6303144d96 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,8 +1,8 @@ +use crate::HistoryStore; use crate::{ - ContextServerRegistry, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent, - templates::Templates, + ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization, + UserMessageContent, templates::Templates, }; -use crate::{HistoryStore, ThreadsDatabase}; use acp_thread::{AcpThread, AgentModelSelector}; use action_log::ActionLog; use agent_client_protocol as acp; @@ -673,6 +673,11 @@ impl NativeAgentConnection { thread.update_tool_call(update, cx) })??; } + ThreadEvent::TokenUsageUpdate(usage) => { + acp_thread.update(cx, |thread, cx| { + thread.update_token_usage(Some(usage), cx) + })?; + } ThreadEvent::TitleUpdate(title) => { acp_thread .update(cx, |thread, cx| thread.update_title(title, cx))??; @@ -895,10 +900,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection { cx: &mut App, ) -> Option> { self.0.update(cx, |agent, _cx| { - agent - .sessions - .get(session_id) - .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _) + agent.sessions.get(session_id).map(|session| { + Rc::new(NativeAgentSessionEditor { + thread: session.thread.clone(), + acp_thread: session.acp_thread.clone(), + }) as _ + }) }) } @@ -907,14 +914,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection { } } -struct NativeAgentSessionEditor(Entity); +struct NativeAgentSessionEditor { + thread: Entity, + acp_thread: WeakEntity, +} impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { - Task::ready( - self.0 - .update(cx, |thread, cx| thread.truncate(message_id, cx)), - ) + match self.thread.update(cx, |thread, cx| { + thread.truncate(message_id.clone(), cx)?; + Ok(thread.latest_token_usage()) + }) { + Ok(usage) => { + self.acp_thread + .update(cx, |thread, cx| { + thread.update_token_usage(usage, cx); + }) + .ok(); + Task::ready(Ok(())) + } + Err(error) => Task::ready(Err(error)), + } } } diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index 27a109c573..610a2575c4 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -1,4 +1,5 @@ use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; +use acp_thread::UserMessageId; use agent::thread_store; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, CompletionMode}; @@ -42,7 +43,7 @@ pub struct DbThread { #[serde(default)] pub cumulative_token_usage: language_model::TokenUsage, #[serde(default)] - pub request_token_usage: Vec, + pub request_token_usage: HashMap, #[serde(default)] pub model: Option, #[serde(default)] @@ -67,7 +68,10 @@ impl DbThread { fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result { let mut messages = Vec::new(); - for msg in thread.messages { + let mut request_token_usage = HashMap::default(); + + let mut last_user_message_id = None; + for (ix, msg) in thread.messages.into_iter().enumerate() { let message = match msg.role { language_model::Role::User => { let mut content = Vec::new(); @@ -93,9 +97,12 @@ impl DbThread { content.push(UserMessageContent::Text(msg.context)); } + let id = UserMessageId::new(); + last_user_message_id = Some(id.clone()); + crate::Message::User(UserMessage { // MessageId from old format can't be meaningfully converted, so generate a new one - id: acp_thread::UserMessageId::new(), + id, content, }) } @@ -154,6 +161,12 @@ impl DbThread { ); } + if let Some(last_user_message_id) = &last_user_message_id + && let Some(token_usage) = thread.request_token_usage.get(ix).copied() + { + request_token_usage.insert(last_user_message_id.clone(), token_usage); + } + crate::Message::Agent(AgentMessage { content, tool_results, @@ -175,7 +188,7 @@ impl DbThread { summary: thread.detailed_summary_state, initial_project_snapshot: thread.initial_project_snapshot, cumulative_token_usage: thread.cumulative_token_usage, - request_token_usage: thread.request_token_usage, + request_token_usage, model: thread.model, completion_mode: thread.completion_mode, profile: thread.profile, diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 7fa12e5711..d07ca42d3b 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1117,7 +1117,7 @@ async fn test_refusal(cx: &mut TestAppContext) { } #[gpui::test] -async fn test_truncate(cx: &mut TestAppContext) { +async fn test_truncate_first_message(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); @@ -1137,9 +1137,18 @@ async fn test_truncate(cx: &mut TestAppContext) { Hello "} ); + assert_eq!(thread.latest_token_usage(), None); }); fake_model.send_last_completion_stream_text_chunk("Hey!"); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( + language_model::TokenUsage { + input_tokens: 32_000, + output_tokens: 16_000, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + )); cx.run_until_parked(); thread.read_with(cx, |thread, _| { assert_eq!( @@ -1154,6 +1163,13 @@ async fn test_truncate(cx: &mut TestAppContext) { Hey! "} ); + assert_eq!( + thread.latest_token_usage(), + Some(acp_thread::TokenUsage { + used_tokens: 32_000 + 16_000, + max_tokens: 1_000_000, + }) + ); }); thread @@ -1162,6 +1178,7 @@ async fn test_truncate(cx: &mut TestAppContext) { cx.run_until_parked(); thread.read_with(cx, |thread, _| { assert_eq!(thread.to_markdown(), ""); + assert_eq!(thread.latest_token_usage(), None); }); // Ensure we can still send a new message after truncation. @@ -1182,6 +1199,14 @@ async fn test_truncate(cx: &mut TestAppContext) { }); cx.run_until_parked(); fake_model.send_last_completion_stream_text_chunk("Ahoy!"); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( + language_model::TokenUsage { + input_tokens: 40_000, + output_tokens: 20_000, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + )); cx.run_until_parked(); thread.read_with(cx, |thread, _| { assert_eq!( @@ -1196,9 +1221,126 @@ async fn test_truncate(cx: &mut TestAppContext) { Ahoy! "} ); + + assert_eq!( + thread.latest_token_usage(), + Some(acp_thread::TokenUsage { + used_tokens: 40_000 + 20_000, + max_tokens: 1_000_000, + }) + ); }); } +#[gpui::test] +async fn test_truncate_second_message(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Message 1"], cx) + }) + .unwrap(); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Message 1 response"); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( + language_model::TokenUsage { + input_tokens: 32_000, + output_tokens: 16_000, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let assert_first_message_state = |cx: &mut TestAppContext| { + thread.clone().read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Message 1 + + ## Assistant + + Message 1 response + "} + ); + + assert_eq!( + thread.latest_token_usage(), + Some(acp_thread::TokenUsage { + used_tokens: 32_000 + 16_000, + max_tokens: 1_000_000, + }) + ); + }); + }; + + assert_first_message_state(cx); + + let second_message_id = UserMessageId::new(); + thread + .update(cx, |thread, cx| { + thread.send(second_message_id.clone(), ["Message 2"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("Message 2 response"); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( + language_model::TokenUsage { + input_tokens: 40_000, + output_tokens: 20_000, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Message 1 + + ## Assistant + + Message 1 response + + ## User + + Message 2 + + ## Assistant + + Message 2 response + "} + ); + + assert_eq!( + thread.latest_token_usage(), + Some(acp_thread::TokenUsage { + used_tokens: 40_000 + 20_000, + max_tokens: 1_000_000, + }) + ); + }); + + thread + .update(cx, |thread, cx| thread.truncate(second_message_id, cx)) + .unwrap(); + cx.run_until_parked(); + + assert_first_message_state(cx); +} + #[gpui::test] async fn test_title_generation(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index ba5cd1f477..4bc45f1544 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -13,7 +13,7 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; use chrono::{DateTime, Utc}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus}; -use collections::IndexMap; +use collections::{HashMap, IndexMap}; use fs::Fs; use futures::{ FutureExt, @@ -24,8 +24,8 @@ use futures::{ use git::repository::DiffType; use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity}; use language_model::{ - LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, - LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest, + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt, + LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage, @@ -481,6 +481,7 @@ pub enum ThreadEvent { ToolCall(acp::ToolCall), ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), + TokenUsageUpdate(acp_thread::TokenUsage), TitleUpdate(SharedString), Retry(acp_thread::RetryStatus), Stop(acp::StopReason), @@ -509,8 +510,7 @@ pub struct Thread { pending_message: Option, tools: BTreeMap>, tool_use_limit_reached: bool, - #[allow(unused)] - request_token_usage: Vec, + request_token_usage: HashMap, #[allow(unused)] cumulative_token_usage: TokenUsage, #[allow(unused)] @@ -548,7 +548,7 @@ impl Thread { pending_message: None, tools: BTreeMap::default(), tool_use_limit_reached: false, - request_token_usage: Vec::new(), + request_token_usage: HashMap::default(), cumulative_token_usage: TokenUsage::default(), initial_project_snapshot: { let project_snapshot = Self::project_snapshot(project.clone(), cx); @@ -951,6 +951,15 @@ impl Thread { self.flush_pending_message(cx); } + pub fn update_token_usage(&mut self, update: language_model::TokenUsage) { + let Some(last_user_message) = self.last_user_message() else { + return; + }; + + self.request_token_usage + .insert(last_user_message.id.clone(), update); + } + pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context) -> Result<()> { self.cancel(cx); let Some(position) = self.messages.iter().position( @@ -958,11 +967,31 @@ impl Thread { ) else { return Err(anyhow!("Message not found")); }; - self.messages.truncate(position); + + for message in self.messages.drain(position..) { + match message { + Message::User(message) => { + self.request_token_usage.remove(&message.id); + } + Message::Agent(_) | Message::Resume => {} + } + } + cx.notify(); Ok(()) } + pub fn latest_token_usage(&self) -> Option { + let last_user_message = self.last_user_message()?; + let tokens = self.request_token_usage.get(&last_user_message.id)?; + let model = self.model.clone()?; + + Some(acp_thread::TokenUsage { + max_tokens: model.max_token_count_for_mode(self.completion_mode.into()), + used_tokens: tokens.total_tokens(), + }) + } + pub fn resume( &mut self, cx: &mut Context, @@ -1148,6 +1177,21 @@ impl Thread { )) => { *tool_use_limit_reached = true; } + Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => { + let usage = acp_thread::TokenUsage { + max_tokens: model.max_token_count_for_mode( + request + .mode + .unwrap_or(cloud_llm_client::CompletionMode::Normal), + ), + used_tokens: token_usage.total_tokens(), + }; + + this.update(cx, |this, _cx| this.update_token_usage(token_usage)) + .ok(); + + event_stream.send_token_usage_update(usage); + } Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => { *refusal = true; return Ok(FuturesUnordered::default()); @@ -1532,6 +1576,16 @@ impl Thread { }) })) } + fn last_user_message(&self) -> Option<&UserMessage> { + self.messages + .iter() + .rev() + .find_map(|message| match message { + Message::User(user_message) => Some(user_message), + Message::Agent(_) => None, + Message::Resume => None, + }) + } fn pending_message(&mut self) -> &mut AgentMessage { self.pending_message.get_or_insert_default() @@ -2051,6 +2105,12 @@ impl ThreadEventStream { .ok(); } + fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) { + self.0 + .unbounded_send(Ok(ThreadEvent::TokenUsageUpdate(usage))) + .ok(); + } + fn send_retry(&self, status: acp_thread::RetryStatus) { self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok(); } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 9f1e8d857f..878891c6f1 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -816,7 +816,7 @@ impl AcpThreadView { self.thread_retry_status.take(); self.thread_state = ThreadState::ServerExited { status: *status }; } - AcpThreadEvent::TitleUpdated => {} + AcpThreadEvent::TitleUpdated | AcpThreadEvent::TokenUsageUpdated => {} } cx.notify(); } @@ -2794,6 +2794,7 @@ impl AcpThreadView { .child( h_flex() .gap_1() + .children(self.render_token_usage(cx)) .children(self.profile_selector.clone()) .children(self.model_selector.clone()) .child(self.render_send_button(cx)), @@ -2816,6 +2817,44 @@ impl AcpThreadView { .thread(acp_thread.session_id(), cx) } + fn render_token_usage(&self, cx: &mut Context) -> Option
{ + let thread = self.thread()?.read(cx); + let usage = thread.token_usage()?; + let is_generating = thread.status() != ThreadStatus::Idle; + + let used = crate::text_thread_editor::humanize_token_count(usage.used_tokens); + let max = crate::text_thread_editor::humanize_token_count(usage.max_tokens); + + Some( + h_flex() + .flex_shrink_0() + .gap_0p5() + .mr_1() + .child( + Label::new(used) + .size(LabelSize::Small) + .color(Color::Muted) + .map(|label| { + if is_generating { + label + .with_animation( + "used-tokens-label", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.6, 1.)), + |label, delta| label.alpha(delta), + ) + .into_any() + } else { + label.into_any_element() + } + }), + ) + .child(Label::new("/").size(LabelSize::Small).color(Color::Muted)) + .child(Label::new(max).size(LabelSize::Small).color(Color::Muted)), + ) + } + fn toggle_burn_mode( &mut self, _: &ToggleBurnMode, diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 9d2ee0bf89..a695136562 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1526,6 +1526,7 @@ impl AgentDiff { self.update_reviewing_editors(workspace, window, cx); } AcpThreadEvent::TitleUpdated + | AcpThreadEvent::TokenUsageUpdated | AcpThreadEvent::EntriesRemoved(_) | AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::Retry(_) => {}