Track cumulative token usage in assistant2 when using anthropic API (#26738)

Release Notes:

- N/A
Michael Sloan 2025-03-13 16:56:16 -06:00 committed by GitHub
parent e3c0f56a96
commit 8e0e291bd5
8 changed files with 136 additions and 25 deletions

Cargo.lock (generated)

@@ -7173,6 +7173,7 @@ dependencies = [
  "http_client",
  "language_model",
  "lmstudio",
+ "log",
  "menu",
  "mistral",
  "ollama",

crates/anthropic/src/anthropic.rs

@@ -553,7 +553,7 @@ pub struct Metadata {
     pub user_id: Option<String>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Default)]
 pub struct Usage {
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub input_tokens: Option<u32>,

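For context on why `Default` is useful here: Anthropic's streaming API reports usage as running totals within a request — `message_start` carries the initial counts and later `message_delta` events refresh them (both feed `update_usage` in the provider changes below) — so the provider wants an all-`None` accumulator to start from. A small sketch of how a partial payload deserializes, assuming the four optional fields that appear across this diff:

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_creation_input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
}

fn main() -> serde_json::Result<()> {
    // A usage payload often carries only some counters; the remaining
    // fields fall back to `None` via `#[serde(default)]`.
    let partial: Usage = serde_json::from_str(r#"{"output_tokens": 42}"#)?;
    assert_eq!(partial.output_tokens, Some(42));
    assert!(partial.input_tokens.is_none());
    Ok(())
}
```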
crates/assistant2/src/thread.rs

@@ -11,7 +11,7 @@ use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
     LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
-    Role, StopReason,
+    Role, StopReason, TokenUsage,
 };
 use project::Project;
 use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
@@ -81,6 +81,7 @@ pub struct Thread {
     tool_use: ToolUseState,
     scripting_session: Entity<ScriptingSession>,
     scripting_tool_use: ToolUseState,
+    cumulative_token_usage: TokenUsage,
 }
 
 impl Thread {
@@ -109,6 +110,7 @@ impl Thread {
             tool_use: ToolUseState::new(),
             scripting_session,
             scripting_tool_use: ToolUseState::new(),
+            cumulative_token_usage: TokenUsage::default(),
         }
     }
@@ -158,6 +160,8 @@ impl Thread {
             tool_use,
             scripting_session,
             scripting_tool_use,
+            // TODO: persist token usage?
+            cumulative_token_usage: TokenUsage::default(),
         }
     }
@@ -490,6 +494,7 @@ impl Thread {
         let stream_completion = async {
             let mut events = stream.await?;
             let mut stop_reason = StopReason::EndTurn;
+            let mut current_token_usage = TokenUsage::default();
 
             while let Some(event) = events.next().await {
                 let event = event?;
@@ -502,6 +507,12 @@ impl Thread {
                         LanguageModelCompletionEvent::Stop(reason) => {
                             stop_reason = reason;
                         }
+                        LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
+                            thread.cumulative_token_usage =
+                                thread.cumulative_token_usage.clone() + token_usage.clone()
+                                    - current_token_usage.clone();
+                            current_token_usage = token_usage;
+                        }
                         LanguageModelCompletionEvent::Text(chunk) => {
                             if let Some(last_message) = thread.messages.last_mut() {
                                 if last_message.role == Role::Assistant {
@@ -843,6 +854,10 @@ impl Thread {
         Ok(String::from_utf8_lossy(&markdown).to_string())
     }
+
+    pub fn cumulative_token_usage(&self) -> TokenUsage {
+        self.cumulative_token_usage.clone()
+    }
 }
 
 #[derive(Debug, Clone)]

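Worth noting: `UsageUpdate` events carry the running totals for the current request, so the thread folds in only the increment since the previous update (`cumulative + new - previous`) and then remembers the new totals in `current_token_usage`. A standalone sketch of that bookkeeping, reduced to a single counter instead of the full `TokenUsage` struct (names are illustrative):

```rust
/// Illustrative stand-in for `Thread`, tracking one counter.
struct Thread {
    cumulative_tokens: u32,
}

impl Thread {
    /// `reports` are the running totals a single request streams back.
    fn handle_usage_updates(&mut self, reports: &[u32]) {
        // Mirrors `current_token_usage`: the last total seen this request.
        let mut current = 0;
        for &report in reports {
            // cumulative = cumulative + new - previous, as in the
            // `UsageUpdate` arm above.
            self.cumulative_tokens = self.cumulative_tokens + report - current;
            current = report;
        }
    }
}

fn main() {
    let mut thread = Thread { cumulative_tokens: 0 };
    thread.handle_usage_updates(&[10, 25, 40]); // first request
    thread.handle_usage_updates(&[8, 12]); // second request starts from zero
    assert_eq!(thread.cumulative_tokens, 52); // 40 + 12
}
```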
crates/assistant_context_editor/src/context.rs

@@ -2254,6 +2254,7 @@ impl AssistantContext {
                         );
                     }
                     LanguageModelCompletionEvent::ToolUse(_) => {}
+                    LanguageModelCompletionEvent::UsageUpdate(_) => {}
                 }
             });

crates/language_model/src/language_model.rs

@@ -17,9 +17,11 @@ use proto::Plan;
 use schemars::JsonSchema;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use std::fmt;
+use std::ops::{Add, Sub};
 use std::{future::Future, sync::Arc};
 use thiserror::Error;
 use ui::IconName;
+use util::serde::is_default;
 
 pub use crate::model::*;
 pub use crate::rate_limiter::*;
@@ -59,6 +61,7 @@ pub enum LanguageModelCompletionEvent {
     Text(String),
     ToolUse(LanguageModelToolUse),
     StartMessage { message_id: String },
+    UsageUpdate(TokenUsage),
 }
 
 #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
@@ -69,6 +72,46 @@ pub enum StopReason {
     ToolUse,
 }
 
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Default)]
+pub struct TokenUsage {
+    #[serde(default, skip_serializing_if = "is_default")]
+    pub input_tokens: u32,
+    #[serde(default, skip_serializing_if = "is_default")]
+    pub output_tokens: u32,
+    #[serde(default, skip_serializing_if = "is_default")]
+    pub cache_creation_input_tokens: u32,
+    #[serde(default, skip_serializing_if = "is_default")]
+    pub cache_read_input_tokens: u32,
+}
+
+impl Add<TokenUsage> for TokenUsage {
+    type Output = Self;
+
+    fn add(self, other: Self) -> Self {
+        Self {
+            input_tokens: self.input_tokens + other.input_tokens,
+            output_tokens: self.output_tokens + other.output_tokens,
+            cache_creation_input_tokens: self.cache_creation_input_tokens
+                + other.cache_creation_input_tokens,
+            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
+        }
+    }
+}
+
+impl Sub<TokenUsage> for TokenUsage {
+    type Output = Self;
+
+    fn sub(self, other: Self) -> Self {
+        Self {
+            input_tokens: self.input_tokens - other.input_tokens,
+            output_tokens: self.output_tokens - other.output_tokens,
+            cache_creation_input_tokens: self.cache_creation_input_tokens
+                - other.cache_creation_input_tokens,
+            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
+        }
+    }
+}
+
 #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
 pub struct LanguageModelToolUseId(Arc<str>);
 
@@ -176,6 +219,7 @@ pub trait LanguageModel: Send + Sync {
             Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
             Ok(LanguageModelCompletionEvent::Stop(_)) => None,
             Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
+            Ok(LanguageModelCompletionEvent::UsageUpdate(_)) => None,
             Err(err) => Some(Err(err)),
         }
     }))

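The field-wise `Add`/`Sub` impls exist so the thread's `cumulative + latest - previous` expression reads naturally. A trimmed, runnable sketch of the same arithmetic (two fields instead of four):

```rust
use std::ops::{Add, Sub};

#[derive(Debug, PartialEq, Clone, Copy, Default)]
struct TokenUsage {
    input_tokens: u32,
    output_tokens: u32,
}

impl Add for TokenUsage {
    type Output = Self;
    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
        }
    }
}

impl Sub for TokenUsage {
    type Output = Self;
    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
        }
    }
}

fn main() {
    let cumulative = TokenUsage { input_tokens: 100, output_tokens: 50 };
    let previous = TokenUsage { input_tokens: 30, output_tokens: 15 };
    let latest = TokenUsage { input_tokens: 30, output_tokens: 20 };
    // Fold in only the increment, as Thread does on each UsageUpdate.
    let updated = cumulative + latest - previous;
    assert_eq!(updated, TokenUsage { input_tokens: 100, output_tokens: 55 });
}
```

One design caveat: `Sub` uses plain `u32` subtraction, which panics on underflow in debug builds if a provider ever reports a total smaller than the previous one.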
crates/language_models/Cargo.toml

@@ -33,6 +33,7 @@ gpui_tokio.workspace = true
 http_client.workspace = true
 language_model.workspace = true
 lmstudio = { workspace = true, features = ["schemars"] }
+log.workspace = true
 menu.workspace = true
 mistral = { workspace = true, features = ["schemars"] }
 ollama = { workspace = true, features = ["schemars"] }

crates/language_models/src/provider/anthropic.rs

@@ -1,6 +1,6 @@
 use crate::ui::InstructionListItem;
 use crate::AllLanguageModelSettings;
-use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent};
+use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent, Usage};
 use anyhow::{anyhow, Context as _, Result};
 use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;
@@ -582,12 +582,16 @@ pub fn map_to_language_model_completion_events(
     struct State {
         events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
         tool_uses_by_index: HashMap<usize, RawToolUse>,
+        usage: Usage,
+        stop_reason: StopReason,
     }
 
     futures::stream::unfold(
         State {
             events,
             tool_uses_by_index: HashMap::default(),
+            usage: Usage::default(),
+            stop_reason: StopReason::EndTurn,
         },
         |mut state| async move {
             while let Some(event) = state.events.next().await {
@@ -599,7 +603,7 @@ pub fn map_to_language_model_completion_events(
                         } => match content_block {
                             ResponseContent::Text { text } => {
                                 return Some((
-                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    vec![Ok(LanguageModelCompletionEvent::Text(text))],
                                     state,
                                 ));
                             }
@@ -612,28 +616,25 @@ pub fn map_to_language_model_completion_events(
                                         input_json: String::new(),
                                     },
                                 );
-
-                                return Some((None, state));
                             }
                         },
                         Event::ContentBlockDelta { index, delta } => match delta {
                             ContentDelta::TextDelta { text } => {
                                 return Some((
-                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    vec![Ok(LanguageModelCompletionEvent::Text(text))],
                                     state,
                                 ));
                             }
                             ContentDelta::InputJsonDelta { partial_json } => {
                                 if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
                                     tool_use.input_json.push_str(&partial_json);
-
-                                    return Some((None, state));
                                 }
                             }
                         },
                         Event::ContentBlockStop { index } => {
                             if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
                                 return Some((
-                                    Some(maybe!({
+                                    vec![maybe!({
                                         Ok(LanguageModelCompletionEvent::ToolUse(
                                             LanguageModelToolUse {
                                                 id: tool_use.id.into(),
@@ -650,44 +651,63 @@ pub fn map_to_language_model_completion_events(
                                             },
                                         },
                                     ))
-                                    })),
+                                    })],
                                     state,
                                 ));
                             }
                         }
                         Event::MessageStart { message } => {
+                            update_usage(&mut state.usage, &message.usage);
                             return Some((
-                                Some(Ok(LanguageModelCompletionEvent::StartMessage {
-                                    message_id: message.id,
-                                })),
-                                state,
-                            ))
-                        }
-                        Event::MessageDelta { delta, .. } => {
-                            if let Some(stop_reason) = delta.stop_reason.as_deref() {
-                                let stop_reason = match stop_reason {
-                                    "end_turn" => StopReason::EndTurn,
-                                    "max_tokens" => StopReason::MaxTokens,
-                                    "tool_use" => StopReason::ToolUse,
-                                    _ => StopReason::EndTurn,
-                                };
-                                return Some((
-                                    Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))),
-                                    state,
-                                ));
-                            }
-                        }
+                                vec![
+                                    Ok(LanguageModelCompletionEvent::StartMessage {
+                                        message_id: message.id,
+                                    }),
+                                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
+                                        &state.usage,
+                                    ))),
+                                ],
+                                state,
+                            ));
+                        }
+                        Event::MessageDelta { delta, usage } => {
+                            update_usage(&mut state.usage, &usage);
+                            if let Some(stop_reason) = delta.stop_reason.as_deref() {
+                                state.stop_reason = match stop_reason {
+                                    "end_turn" => StopReason::EndTurn,
+                                    "max_tokens" => StopReason::MaxTokens,
+                                    "tool_use" => StopReason::ToolUse,
+                                    _ => {
+                                        log::error!(
+                                            "Unexpected anthropic stop_reason: {stop_reason}"
+                                        );
+                                        StopReason::EndTurn
+                                    }
+                                };
+                            }
+                            return Some((
+                                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
+                                    convert_usage(&state.usage),
+                                ))],
+                                state,
+                            ));
+                        }
+                        Event::MessageStop => {
+                            return Some((
+                                vec![Ok(LanguageModelCompletionEvent::Stop(state.stop_reason))],
+                                state,
+                            ));
+                        }
                         Event::Error { error } => {
                             return Some((
-                                Some(Err(anyhow!(AnthropicError::ApiError(error)))),
+                                vec![Err(anyhow!(AnthropicError::ApiError(error)))],
                                 state,
                             ));
                         }
                         _ => {}
                     },
                     Err(err) => {
-                        return Some((Some(Err(anyhow!(err))), state));
+                        return Some((vec![Err(anyhow!(err))], state));
                     }
                 }
             }
@@ -695,7 +715,32 @@ pub fn map_to_language_model_completion_events(
             None
         },
     )
-    .filter_map(|event| async move { event })
+    .flat_map(futures::stream::iter)
 }
 
+/// Updates usage data by preferring counts from `new`.
+fn update_usage(usage: &mut Usage, new: &Usage) {
+    if let Some(input_tokens) = new.input_tokens {
+        usage.input_tokens = Some(input_tokens);
+    }
+    if let Some(output_tokens) = new.output_tokens {
+        usage.output_tokens = Some(output_tokens);
+    }
+    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
+        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
+    }
+    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
+        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
+    }
+}
+
+fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
+    language_model::TokenUsage {
+        input_tokens: usage.input_tokens.unwrap_or(0),
+        output_tokens: usage.output_tokens.unwrap_or(0),
+        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
+        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
+    }
+}
+
 struct ConfigurationView {

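The unfold state machine now yields `Vec<LanguageModelCompletionEvent>` batches instead of `Option`s, so a single Anthropic event such as `MessageStart` can fan out into several completion events, and `.flat_map(futures::stream::iter)` splices each batch back into a flat stream; an empty `Vec` simply contributes nothing, replacing the old `.filter_map` over `None`. A minimal sketch of the pattern with plain string payloads:

```rust
use futures::executor::block_on;
use futures::stream::{self, StreamExt};

fn main() {
    // Each upstream event expands to zero or more completion events,
    // like the `Vec<Result<…>>` items returned from `unfold` above.
    let batches = stream::iter(vec![
        vec![],                 // e.g. ContentBlockStart: nothing to emit yet
        vec!["start", "usage"], // e.g. MessageStart: two events at once
        vec!["stop"],           // e.g. MessageStop
    ]);

    // `flat_map(stream::iter)` flattens the batches into one stream;
    // empty batches vanish without a dedicated filtering pass.
    let events = block_on(batches.flat_map(stream::iter).collect::<Vec<_>>());
    assert_eq!(events, vec!["start", "usage", "stop"]);
}
```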
crates/util/src/serde.rs

@@ -1,3 +1,7 @@
 pub const fn default_true() -> bool {
     true
 }
+
+pub fn is_default<T: Default + PartialEq>(value: &T) -> bool {
+    *value == T::default()
+}
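
As used by `TokenUsage` above, pairing `is_default` with `#[serde(default, skip_serializing_if = "is_default")]` keeps zero-valued counters out of the serialized form while still round-tripping cleanly. A sketch with a hypothetical `Counters` struct:

```rust
use serde::{Deserialize, Serialize};

pub fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    *value == T::default()
}

#[derive(Debug, PartialEq, Serialize, Deserialize, Default)]
struct Counters {
    #[serde(default, skip_serializing_if = "is_default")]
    input_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    output_tokens: u32,
}

fn main() -> serde_json::Result<()> {
    let c = Counters { input_tokens: 7, output_tokens: 0 };
    // Zero-valued (default) fields are skipped on the way out...
    assert_eq!(serde_json::to_string(&c)?, r#"{"input_tokens":7}"#);
    // ...and restored to their defaults on the way back in.
    assert_eq!(serde_json::from_str::<Counters>(r#"{"input_tokens":7}"#)?, c);
    Ok(())
}
```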