Track cumulative token usage in assistant2 when using anthropic API (#26738)

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-03-13 16:56:16 -06:00 committed by GitHub
parent e3c0f56a96
commit 8e0e291bd5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 136 additions and 25 deletions

View file

@ -33,6 +33,7 @@ gpui_tokio.workspace = true
http_client.workspace = true
language_model.workspace = true
lmstudio = { workspace = true, features = ["schemars"] }
log.workspace = true
menu.workspace = true
mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }

View file

@ -1,6 +1,6 @@
use crate::ui::InstructionListItem;
use crate::AllLanguageModelSettings;
use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent};
use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent, Usage};
use anyhow::{anyhow, Context as _, Result};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
@ -582,12 +582,16 @@ pub fn map_to_language_model_completion_events(
struct State {
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
tool_uses_by_index: HashMap<usize, RawToolUse>,
usage: Usage,
stop_reason: StopReason,
}
futures::stream::unfold(
State {
events,
tool_uses_by_index: HashMap::default(),
usage: Usage::default(),
stop_reason: StopReason::EndTurn,
},
|mut state| async move {
while let Some(event) = state.events.next().await {
@ -599,7 +603,7 @@ pub fn map_to_language_model_completion_events(
} => match content_block {
ResponseContent::Text { text } => {
return Some((
Some(Ok(LanguageModelCompletionEvent::Text(text))),
vec![Ok(LanguageModelCompletionEvent::Text(text))],
state,
));
}
@ -612,28 +616,25 @@ pub fn map_to_language_model_completion_events(
input_json: String::new(),
},
);
return Some((None, state));
}
},
Event::ContentBlockDelta { index, delta } => match delta {
ContentDelta::TextDelta { text } => {
return Some((
Some(Ok(LanguageModelCompletionEvent::Text(text))),
vec![Ok(LanguageModelCompletionEvent::Text(text))],
state,
));
}
ContentDelta::InputJsonDelta { partial_json } => {
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
tool_use.input_json.push_str(&partial_json);
return Some((None, state));
}
}
},
Event::ContentBlockStop { index } => {
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
return Some((
Some(maybe!({
vec![maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.into(),
@ -650,44 +651,63 @@ pub fn map_to_language_model_completion_events(
},
},
))
})),
})],
state,
));
}
}
Event::MessageStart { message } => {
update_usage(&mut state.usage, &message.usage);
return Some((
Some(Ok(LanguageModelCompletionEvent::StartMessage {
message_id: message.id,
})),
vec![
Ok(LanguageModelCompletionEvent::StartMessage {
message_id: message.id,
}),
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
&state.usage,
))),
],
state,
))
));
}
Event::MessageDelta { delta, .. } => {
Event::MessageDelta { delta, usage } => {
update_usage(&mut state.usage, &usage);
if let Some(stop_reason) = delta.stop_reason.as_deref() {
let stop_reason = match stop_reason {
state.stop_reason = match stop_reason {
"end_turn" => StopReason::EndTurn,
"max_tokens" => StopReason::MaxTokens,
"tool_use" => StopReason::ToolUse,
_ => StopReason::EndTurn,
_ => {
log::error!(
"Unexpected anthropic stop_reason: {stop_reason}"
);
StopReason::EndTurn
}
};
return Some((
Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))),
state,
));
}
return Some((
vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&state.usage),
))],
state,
));
}
Event::MessageStop => {
return Some((
vec![Ok(LanguageModelCompletionEvent::Stop(state.stop_reason))],
state,
));
}
Event::Error { error } => {
return Some((
Some(Err(anyhow!(AnthropicError::ApiError(error)))),
vec![Err(anyhow!(AnthropicError::ApiError(error)))],
state,
));
}
_ => {}
},
Err(err) => {
return Some((Some(Err(anyhow!(err))), state));
return Some((vec![Err(anyhow!(err))], state));
}
}
}
@ -695,7 +715,32 @@ pub fn map_to_language_model_completion_events(
None
},
)
.filter_map(|event| async move { event })
.flat_map(futures::stream::iter)
}
/// Updates usage data by preferring counts from `new`.
fn update_usage(usage: &mut Usage, new: &Usage) {
if let Some(input_tokens) = new.input_tokens {
usage.input_tokens = Some(input_tokens);
}
if let Some(output_tokens) = new.output_tokens {
usage.output_tokens = Some(output_tokens);
}
if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
}
if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
usage.cache_read_input_tokens = Some(cache_read_input_tokens);
}
}
fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
language_model::TokenUsage {
input_tokens: usage.input_tokens.unwrap_or(0),
output_tokens: usage.output_tokens.unwrap_or(0),
cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
}
}
struct ConfigurationView {