From 1e8b50f471ca11b29278bcd2a29f7d556b93bdaf Mon Sep 17 00:00:00 2001 From: Thomas Mickley-Doyle Date: Wed, 26 Mar 2025 17:21:01 -0500 Subject: [PATCH] Add token usage to `LanguageModelTextStream` (#27490) Release Notes: - N/A --------- Co-authored-by: Michael Sloan --- crates/assistant/src/inline_assistant.rs | 3 +- crates/assistant2/src/buffer_codegen.rs | 21 ++++++++++-- crates/language_model/src/language_model.rs | 38 +++++++++++++++------ 3 files changed, 49 insertions(+), 13 deletions(-) diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index 35b8c76866..db48b4e3f2 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -3712,7 +3712,7 @@ mod tests { language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher, Point, }; - use language_model::LanguageModelRegistry; + use language_model::{LanguageModelRegistry, TokenUsage}; use rand::prelude::*; use serde::Serialize; use settings::SettingsStore; @@ -4091,6 +4091,7 @@ mod tests { future::ready(Ok(LanguageModelTextStream { message_id: None, stream: chunks_rx.map(Ok).boxed(), + last_token_usage: Arc::new(Mutex::new(TokenUsage::default())), })), cx, ); diff --git a/crates/assistant2/src/buffer_codegen.rs b/crates/assistant2/src/buffer_codegen.rs index 230670b260..d13c31829b 100644 --- a/crates/assistant2/src/buffer_codegen.rs +++ b/crates/assistant2/src/buffer_codegen.rs @@ -482,11 +482,17 @@ impl CodegenAlternative { self.generation = cx.spawn(async move |codegen, cx| { let stream = stream.await; + let token_usage = stream + .as_ref() + .ok() + .map(|stream| stream.last_token_usage.clone()); let message_id = stream .as_ref() .ok() .and_then(|stream| stream.message_id.clone()); let generate = async { + let model_telemetry_id = model_telemetry_id.clone(); + let model_provider_id = model_provider_id.clone(); let (mut diff_tx, mut diff_rx) = mpsc::channel(1); let executor = cx.background_executor().clone(); let message_id = message_id.clone(); @@ -596,7 +602,7 @@ impl CodegenAlternative { kind: AssistantKind::Inline, phase: AssistantPhase::Response, model: model_telemetry_id, - model_provider: model_provider_id.to_string(), + model_provider: model_provider_id, response_latency, error_message, language_name: language_name.map(|name| name.to_proto()), @@ -677,6 +683,16 @@ impl CodegenAlternative { } this.elapsed_time = Some(elapsed_time); this.completion = Some(completion.lock().clone()); + if let Some(usage) = token_usage { + let usage = usage.lock(); + telemetry::event!( + "Inline Assistant Completion", + model = model_telemetry_id, + model_provider = model_provider_id, + input_tokens = usage.input_tokens, + output_tokens = usage.output_tokens, + ) + } cx.emit(CodegenEvent::Finished); cx.notify(); }) @@ -1021,7 +1037,7 @@ mod tests { language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher, Point, }; - use language_model::LanguageModelRegistry; + use language_model::{LanguageModelRegistry, TokenUsage}; use rand::prelude::*; use serde::Serialize; use settings::SettingsStore; @@ -1405,6 +1421,7 @@ mod tests { future::ready(Ok(LanguageModelTextStream { message_id: None, stream: chunks_rx.map(Ok).boxed(), + last_token_usage: Arc::new(Mutex::new(TokenUsage::default())), })), cx, ); diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 7620e50732..c876c9c611 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -14,6 +14,7 @@ use futures::FutureExt; use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; use icons::IconName; +use parking_lot::Mutex; use proto::Plan; use schemars::JsonSchema; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -141,6 +142,8 @@ pub struct LanguageModelToolUse { pub struct LanguageModelTextStream { pub message_id: Option, pub stream: BoxStream<'static, Result>, + // Has complete token usage after the stream has finished + pub last_token_usage: Arc>, } impl Default for LanguageModelTextStream { @@ -148,6 +151,7 @@ impl Default for LanguageModelTextStream { Self { message_id: None, stream: Box::pin(futures::stream::empty()), + last_token_usage: Arc::new(Mutex::new(TokenUsage::default())), } } } @@ -200,6 +204,7 @@ pub trait LanguageModel: Send + Sync { let mut events = events.await?.fuse(); let mut message_id = None; let mut first_item_text = None; + let last_token_usage = Arc::new(Mutex::new(TokenUsage::default())); if let Some(first_event) = events.next().await { match first_event { @@ -214,20 +219,33 @@ pub trait LanguageModel: Send + Sync { } let stream = futures::stream::iter(first_item_text.map(Ok)) - .chain(events.filter_map(|result| async move { - match result { - Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None, - Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)), - Ok(LanguageModelCompletionEvent::Thinking(_)) => None, - Ok(LanguageModelCompletionEvent::Stop(_)) => None, - Ok(LanguageModelCompletionEvent::ToolUse(_)) => None, - Ok(LanguageModelCompletionEvent::UsageUpdate(_)) => None, - Err(err) => Some(Err(err)), + .chain(events.filter_map({ + let last_token_usage = last_token_usage.clone(); + move |result| { + let last_token_usage = last_token_usage.clone(); + async move { + match result { + Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None, + Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)), + Ok(LanguageModelCompletionEvent::Thinking(_)) => None, + Ok(LanguageModelCompletionEvent::Stop(_)) => None, + Ok(LanguageModelCompletionEvent::ToolUse(_)) => None, + Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => { + *last_token_usage.lock() = token_usage; + None + } + Err(err) => Some(Err(err)), + } + } } })) .boxed(); - Ok(LanguageModelTextStream { message_id, stream }) + Ok(LanguageModelTextStream { + message_id, + stream, + last_token_usage, + }) } .boxed() }