Track cumulative token usage in assistant2 when using the Anthropic API (#26738)
Release Notes:

- N/A
parent e3c0f56a96 · commit 8e0e291bd5
8 changed files with 136 additions and 25 deletions
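Background for the change: Anthropic's streaming Messages API reports token usage piecemeal — `message_start` carries the input-token (and cache-token) counts, while each `message_delta` carries the cumulative output-token count so far. The provider therefore keeps a running `Usage` in its stream state, overwriting whichever fields a new event supplies, and emits cumulative `UsageUpdate` events that the assistant can read off directly. A minimal, self-contained sketch of that merge pattern (the `PartialUsage` type and its field set are illustrative stand-ins, not the actual Zed/Anthropic types):

// Illustrative only: later events overwrite the fields they carry and leave
// the rest untouched, so the accumulator always holds the latest cumulative
// counts reported by the API.
#[derive(Default, Debug, Clone, Copy, PartialEq)]
struct PartialUsage {
    input_tokens: Option<u32>,
    output_tokens: Option<u32>,
}

fn merge(acc: &mut PartialUsage, new: &PartialUsage) {
    if let Some(n) = new.input_tokens {
        acc.input_tokens = Some(n);
    }
    if let Some(n) = new.output_tokens {
        acc.output_tokens = Some(n);
    }
}

fn main() {
    let mut acc = PartialUsage::default();
    // message_start: input tokens are known, output has barely begun.
    merge(&mut acc, &PartialUsage { input_tokens: Some(1200), output_tokens: Some(1) });
    // message_delta: output tokens grow as the response streams.
    merge(&mut acc, &PartialUsage { input_tokens: None, output_tokens: Some(85) });
    assert_eq!(acc, PartialUsage { input_tokens: Some(1200), output_tokens: Some(85) });
}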
@@ -33,6 +33,7 @@ gpui_tokio.workspace = true
 http_client.workspace = true
 language_model.workspace = true
 lmstudio = { workspace = true, features = ["schemars"] }
 log.workspace = true
 menu.workspace = true
 mistral = { workspace = true, features = ["schemars"] }
 ollama = { workspace = true, features = ["schemars"] }
@@ -1,6 +1,6 @@
 use crate::ui::InstructionListItem;
 use crate::AllLanguageModelSettings;
-use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent};
+use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent, Usage};
 use anyhow::{anyhow, Context as _, Result};
 use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;
@@ -582,12 +582,16 @@ pub fn map_to_language_model_completion_events(
     struct State {
         events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
         tool_uses_by_index: HashMap<usize, RawToolUse>,
+        usage: Usage,
+        stop_reason: StopReason,
     }
 
     futures::stream::unfold(
         State {
             events,
             tool_uses_by_index: HashMap::default(),
+            usage: Usage::default(),
+            stop_reason: StopReason::EndTurn,
         },
         |mut state| async move {
             while let Some(event) = state.events.next().await {
@@ -599,7 +603,7 @@ pub fn map_to_language_model_completion_events(
                         } => match content_block {
                             ResponseContent::Text { text } => {
                                 return Some((
-                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    vec![Ok(LanguageModelCompletionEvent::Text(text))],
                                     state,
                                 ));
                             }
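The switch from `Some(Ok(event))` to `vec![Ok(event)]` here (and in the hunks that follow) changes what each `unfold` step yields: a step can now emit zero, one, or several completion events, and the final hunk flattens the stream of vectors with `flat_map(futures::stream::iter)`; arms with nothing to report simply fall through the loop instead of yielding a placeholder. A standalone sketch of that shape, assuming only the `futures` crate (the numbers are arbitrary, not Zed types):

use futures::{stream, Stream, StreamExt};

// Each unfold step yields a Vec of items plus the next state;
// flat_map(stream::iter) turns the stream of Vec<T> back into a stream of T.
fn multi_event_stream() -> impl Stream<Item = u32> {
    stream::unfold(0u32, |step| async move {
        if step < 3 {
            // One input step can produce more than one output item.
            Some((vec![step * 10, step * 10 + 1], step + 1))
        } else {
            None
        }
    })
    .flat_map(stream::iter)
}

fn main() {
    let collected: Vec<u32> =
        futures::executor::block_on(multi_event_stream().collect::<Vec<_>>());
    assert_eq!(collected, vec![0, 1, 10, 11, 20, 21]);
}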
@@ -612,28 +616,25 @@ pub fn map_to_language_model_completion_events(
                                         input_json: String::new(),
                                     },
                                 );
-
-                                return Some((None, state));
                             }
                         },
                         Event::ContentBlockDelta { index, delta } => match delta {
                             ContentDelta::TextDelta { text } => {
                                 return Some((
-                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    vec![Ok(LanguageModelCompletionEvent::Text(text))],
                                     state,
                                 ));
                             }
                             ContentDelta::InputJsonDelta { partial_json } => {
                                 if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
                                     tool_use.input_json.push_str(&partial_json);
-                                    return Some((None, state));
                                 }
                             }
                         },
                         Event::ContentBlockStop { index } => {
                             if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
                                 return Some((
-                                    Some(maybe!({
+                                    vec![maybe!({
                                         Ok(LanguageModelCompletionEvent::ToolUse(
                                             LanguageModelToolUse {
                                                 id: tool_use.id.into(),
@@ -650,44 +651,63 @@ pub fn map_to_language_model_completion_events(
                                                 },
                                             },
                                         ))
-                                    })),
+                                    })],
                                     state,
                                 ));
                             }
                         }
                         Event::MessageStart { message } => {
+                            update_usage(&mut state.usage, &message.usage);
                             return Some((
-                                Some(Ok(LanguageModelCompletionEvent::StartMessage {
-                                    message_id: message.id,
-                                })),
+                                vec![
+                                    Ok(LanguageModelCompletionEvent::StartMessage {
+                                        message_id: message.id,
+                                    }),
+                                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
+                                        &state.usage,
+                                    ))),
+                                ],
                                 state,
-                            ))
+                            ));
                         }
-                        Event::MessageDelta { delta, .. } => {
+                        Event::MessageDelta { delta, usage } => {
+                            update_usage(&mut state.usage, &usage);
                             if let Some(stop_reason) = delta.stop_reason.as_deref() {
-                                let stop_reason = match stop_reason {
+                                state.stop_reason = match stop_reason {
                                     "end_turn" => StopReason::EndTurn,
                                     "max_tokens" => StopReason::MaxTokens,
                                     "tool_use" => StopReason::ToolUse,
-                                    _ => StopReason::EndTurn,
+                                    _ => {
+                                        log::error!(
+                                            "Unexpected anthropic stop_reason: {stop_reason}"
+                                        );
+                                        StopReason::EndTurn
+                                    }
                                 };
-
-                                return Some((
-                                    Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))),
-                                    state,
-                                ));
                             }
+
+                            return Some((
+                                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
+                                    convert_usage(&state.usage),
+                                ))],
+                                state,
+                            ));
                         }
+                        Event::MessageStop => {
+                            return Some((
+                                vec![Ok(LanguageModelCompletionEvent::Stop(state.stop_reason))],
+                                state,
+                            ));
+                        }
                         Event::Error { error } => {
                             return Some((
-                                Some(Err(anyhow!(AnthropicError::ApiError(error)))),
+                                vec![Err(anyhow!(AnthropicError::ApiError(error)))],
                                 state,
                             ));
                         }
                         _ => {}
                     },
                     Err(err) => {
-                        return Some((Some(Err(anyhow!(err))), state));
+                        return Some((vec![Err(anyhow!(err))], state));
                    }
                }
            }
@@ -695,7 +715,32 @@ pub fn map_to_language_model_completion_events(
             None
         },
     )
-    .filter_map(|event| async move { event })
+    .flat_map(futures::stream::iter)
 }
 
+/// Updates usage data by preferring counts from `new`.
+fn update_usage(usage: &mut Usage, new: &Usage) {
+    if let Some(input_tokens) = new.input_tokens {
+        usage.input_tokens = Some(input_tokens);
+    }
+    if let Some(output_tokens) = new.output_tokens {
+        usage.output_tokens = Some(output_tokens);
+    }
+    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
+        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
+    }
+    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
+        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
+    }
+}
+
+fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
+    language_model::TokenUsage {
+        input_tokens: usage.input_tokens.unwrap_or(0),
+        output_tokens: usage.output_tokens.unwrap_or(0),
+        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
+        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
+    }
+}
+
 struct ConfigurationView {
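The hunks above cover only the provider side; the assistant2 changes in this commit (not shown here) consume the emitted `UsageUpdate` events to keep a running per-thread total. Because each `UsageUpdate` carries cumulative totals for one request, a consumer replaces the per-request figure as updates arrive and only adds it into the thread total once. A rough sketch of such a consumer — the `CompletionEvent` and `TokenUsage` shapes below are simplified stand-ins, not the actual `language_model` types:

// Simplified stand-ins for the real event and usage types.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
struct TokenUsage {
    input_tokens: u32,
    output_tokens: u32,
}

enum CompletionEvent {
    Text(String),
    UsageUpdate(TokenUsage),
    Stop,
}

// Each UsageUpdate is cumulative for the request, so keep only the latest
// one and fold it into the thread total after the stream ends.
fn track(events: &[CompletionEvent], thread_total: &mut TokenUsage) {
    let mut request_usage = TokenUsage::default();
    for event in events {
        match event {
            CompletionEvent::UsageUpdate(usage) => request_usage = *usage,
            CompletionEvent::Text(_) | CompletionEvent::Stop => {}
        }
    }
    thread_total.input_tokens += request_usage.input_tokens;
    thread_total.output_tokens += request_usage.output_tokens;
}

fn main() {
    let events = vec![
        CompletionEvent::UsageUpdate(TokenUsage { input_tokens: 1200, output_tokens: 1 }),
        CompletionEvent::Text("Hello".into()),
        CompletionEvent::UsageUpdate(TokenUsage { input_tokens: 1200, output_tokens: 85 }),
        CompletionEvent::Stop,
    ];
    let mut total = TokenUsage::default();
    track(&events, &mut total);
    assert_eq!(total, TokenUsage { input_tokens: 1200, output_tokens: 85 });
}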