Track cumulative token usage in assistant2 when using anthropic API (#26738)
Release Notes: - N/A
This commit is contained in:
parent
e3c0f56a96
commit
8e0e291bd5
8 changed files with 136 additions and 25 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -7173,6 +7173,7 @@ dependencies = [
|
||||||
"http_client",
|
"http_client",
|
||||||
"language_model",
|
"language_model",
|
||||||
"lmstudio",
|
"lmstudio",
|
||||||
|
"log",
|
||||||
"menu",
|
"menu",
|
||||||
"mistral",
|
"mistral",
|
||||||
"ollama",
|
"ollama",
|
||||||
|
|
|
@ -553,7 +553,7 @@ pub struct Metadata {
|
||||||
pub user_id: Option<String>,
|
pub user_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||||
pub struct Usage {
|
pub struct Usage {
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
pub input_tokens: Option<u32>,
|
pub input_tokens: Option<u32>,
|
||||||
|
|
|
@ -11,7 +11,7 @@ use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
||||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
|
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
|
||||||
Role, StopReason,
|
Role, StopReason, TokenUsage,
|
||||||
};
|
};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
|
use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
|
||||||
|
@ -81,6 +81,7 @@ pub struct Thread {
|
||||||
tool_use: ToolUseState,
|
tool_use: ToolUseState,
|
||||||
scripting_session: Entity<ScriptingSession>,
|
scripting_session: Entity<ScriptingSession>,
|
||||||
scripting_tool_use: ToolUseState,
|
scripting_tool_use: ToolUseState,
|
||||||
|
cumulative_token_usage: TokenUsage,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Thread {
|
impl Thread {
|
||||||
|
@ -109,6 +110,7 @@ impl Thread {
|
||||||
tool_use: ToolUseState::new(),
|
tool_use: ToolUseState::new(),
|
||||||
scripting_session,
|
scripting_session,
|
||||||
scripting_tool_use: ToolUseState::new(),
|
scripting_tool_use: ToolUseState::new(),
|
||||||
|
cumulative_token_usage: TokenUsage::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,6 +160,8 @@ impl Thread {
|
||||||
tool_use,
|
tool_use,
|
||||||
scripting_session,
|
scripting_session,
|
||||||
scripting_tool_use,
|
scripting_tool_use,
|
||||||
|
// TODO: persist token usage?
|
||||||
|
cumulative_token_usage: TokenUsage::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -490,6 +494,7 @@ impl Thread {
|
||||||
let stream_completion = async {
|
let stream_completion = async {
|
||||||
let mut events = stream.await?;
|
let mut events = stream.await?;
|
||||||
let mut stop_reason = StopReason::EndTurn;
|
let mut stop_reason = StopReason::EndTurn;
|
||||||
|
let mut current_token_usage = TokenUsage::default();
|
||||||
|
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
let event = event?;
|
let event = event?;
|
||||||
|
@ -502,6 +507,12 @@ impl Thread {
|
||||||
LanguageModelCompletionEvent::Stop(reason) => {
|
LanguageModelCompletionEvent::Stop(reason) => {
|
||||||
stop_reason = reason;
|
stop_reason = reason;
|
||||||
}
|
}
|
||||||
|
LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
|
||||||
|
thread.cumulative_token_usage =
|
||||||
|
thread.cumulative_token_usage.clone() + token_usage.clone()
|
||||||
|
- current_token_usage.clone();
|
||||||
|
current_token_usage = token_usage;
|
||||||
|
}
|
||||||
LanguageModelCompletionEvent::Text(chunk) => {
|
LanguageModelCompletionEvent::Text(chunk) => {
|
||||||
if let Some(last_message) = thread.messages.last_mut() {
|
if let Some(last_message) = thread.messages.last_mut() {
|
||||||
if last_message.role == Role::Assistant {
|
if last_message.role == Role::Assistant {
|
||||||
|
@ -843,6 +854,10 @@ impl Thread {
|
||||||
|
|
||||||
Ok(String::from_utf8_lossy(&markdown).to_string())
|
Ok(String::from_utf8_lossy(&markdown).to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn cumulative_token_usage(&self) -> TokenUsage {
|
||||||
|
self.cumulative_token_usage.clone()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
|
|
@ -2254,6 +2254,7 @@ impl AssistantContext {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
LanguageModelCompletionEvent::ToolUse(_) => {}
|
LanguageModelCompletionEvent::ToolUse(_) => {}
|
||||||
|
LanguageModelCompletionEvent::UsageUpdate(_) => {}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -17,9 +17,11 @@ use proto::Plan;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
use std::ops::{Add, Sub};
|
||||||
use std::{future::Future, sync::Arc};
|
use std::{future::Future, sync::Arc};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use ui::IconName;
|
use ui::IconName;
|
||||||
|
use util::serde::is_default;
|
||||||
|
|
||||||
pub use crate::model::*;
|
pub use crate::model::*;
|
||||||
pub use crate::rate_limiter::*;
|
pub use crate::rate_limiter::*;
|
||||||
|
@ -59,6 +61,7 @@ pub enum LanguageModelCompletionEvent {
|
||||||
Text(String),
|
Text(String),
|
||||||
ToolUse(LanguageModelToolUse),
|
ToolUse(LanguageModelToolUse),
|
||||||
StartMessage { message_id: String },
|
StartMessage { message_id: String },
|
||||||
|
UsageUpdate(TokenUsage),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
||||||
|
@ -69,6 +72,46 @@ pub enum StopReason {
|
||||||
ToolUse,
|
ToolUse,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Default)]
|
||||||
|
pub struct TokenUsage {
|
||||||
|
#[serde(default, skip_serializing_if = "is_default")]
|
||||||
|
pub input_tokens: u32,
|
||||||
|
#[serde(default, skip_serializing_if = "is_default")]
|
||||||
|
pub output_tokens: u32,
|
||||||
|
#[serde(default, skip_serializing_if = "is_default")]
|
||||||
|
pub cache_creation_input_tokens: u32,
|
||||||
|
#[serde(default, skip_serializing_if = "is_default")]
|
||||||
|
pub cache_read_input_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Add<TokenUsage> for TokenUsage {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn add(self, other: Self) -> Self {
|
||||||
|
Self {
|
||||||
|
input_tokens: self.input_tokens + other.input_tokens,
|
||||||
|
output_tokens: self.output_tokens + other.output_tokens,
|
||||||
|
cache_creation_input_tokens: self.cache_creation_input_tokens
|
||||||
|
+ other.cache_creation_input_tokens,
|
||||||
|
cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Sub<TokenUsage> for TokenUsage {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn sub(self, other: Self) -> Self {
|
||||||
|
Self {
|
||||||
|
input_tokens: self.input_tokens - other.input_tokens,
|
||||||
|
output_tokens: self.output_tokens - other.output_tokens,
|
||||||
|
cache_creation_input_tokens: self.cache_creation_input_tokens
|
||||||
|
- other.cache_creation_input_tokens,
|
||||||
|
cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
|
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
|
||||||
pub struct LanguageModelToolUseId(Arc<str>);
|
pub struct LanguageModelToolUseId(Arc<str>);
|
||||||
|
|
||||||
|
@ -176,6 +219,7 @@ pub trait LanguageModel: Send + Sync {
|
||||||
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
|
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
|
||||||
Ok(LanguageModelCompletionEvent::Stop(_)) => None,
|
Ok(LanguageModelCompletionEvent::Stop(_)) => None,
|
||||||
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
|
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
|
||||||
|
Ok(LanguageModelCompletionEvent::UsageUpdate(_)) => None,
|
||||||
Err(err) => Some(Err(err)),
|
Err(err) => Some(Err(err)),
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
|
@ -33,6 +33,7 @@ gpui_tokio.workspace = true
|
||||||
http_client.workspace = true
|
http_client.workspace = true
|
||||||
language_model.workspace = true
|
language_model.workspace = true
|
||||||
lmstudio = { workspace = true, features = ["schemars"] }
|
lmstudio = { workspace = true, features = ["schemars"] }
|
||||||
|
log.workspace = true
|
||||||
menu.workspace = true
|
menu.workspace = true
|
||||||
mistral = { workspace = true, features = ["schemars"] }
|
mistral = { workspace = true, features = ["schemars"] }
|
||||||
ollama = { workspace = true, features = ["schemars"] }
|
ollama = { workspace = true, features = ["schemars"] }
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::ui::InstructionListItem;
|
use crate::ui::InstructionListItem;
|
||||||
use crate::AllLanguageModelSettings;
|
use crate::AllLanguageModelSettings;
|
||||||
use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent};
|
use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent, Usage};
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use collections::{BTreeMap, HashMap};
|
use collections::{BTreeMap, HashMap};
|
||||||
use credentials_provider::CredentialsProvider;
|
use credentials_provider::CredentialsProvider;
|
||||||
|
@ -582,12 +582,16 @@ pub fn map_to_language_model_completion_events(
|
||||||
struct State {
|
struct State {
|
||||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||||
tool_uses_by_index: HashMap<usize, RawToolUse>,
|
tool_uses_by_index: HashMap<usize, RawToolUse>,
|
||||||
|
usage: Usage,
|
||||||
|
stop_reason: StopReason,
|
||||||
}
|
}
|
||||||
|
|
||||||
futures::stream::unfold(
|
futures::stream::unfold(
|
||||||
State {
|
State {
|
||||||
events,
|
events,
|
||||||
tool_uses_by_index: HashMap::default(),
|
tool_uses_by_index: HashMap::default(),
|
||||||
|
usage: Usage::default(),
|
||||||
|
stop_reason: StopReason::EndTurn,
|
||||||
},
|
},
|
||||||
|mut state| async move {
|
|mut state| async move {
|
||||||
while let Some(event) = state.events.next().await {
|
while let Some(event) = state.events.next().await {
|
||||||
|
@ -599,7 +603,7 @@ pub fn map_to_language_model_completion_events(
|
||||||
} => match content_block {
|
} => match content_block {
|
||||||
ResponseContent::Text { text } => {
|
ResponseContent::Text { text } => {
|
||||||
return Some((
|
return Some((
|
||||||
Some(Ok(LanguageModelCompletionEvent::Text(text))),
|
vec![Ok(LanguageModelCompletionEvent::Text(text))],
|
||||||
state,
|
state,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
@ -612,28 +616,25 @@ pub fn map_to_language_model_completion_events(
|
||||||
input_json: String::new(),
|
input_json: String::new(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
return Some((None, state));
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Event::ContentBlockDelta { index, delta } => match delta {
|
Event::ContentBlockDelta { index, delta } => match delta {
|
||||||
ContentDelta::TextDelta { text } => {
|
ContentDelta::TextDelta { text } => {
|
||||||
return Some((
|
return Some((
|
||||||
Some(Ok(LanguageModelCompletionEvent::Text(text))),
|
vec![Ok(LanguageModelCompletionEvent::Text(text))],
|
||||||
state,
|
state,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
ContentDelta::InputJsonDelta { partial_json } => {
|
ContentDelta::InputJsonDelta { partial_json } => {
|
||||||
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
|
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
|
||||||
tool_use.input_json.push_str(&partial_json);
|
tool_use.input_json.push_str(&partial_json);
|
||||||
return Some((None, state));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Event::ContentBlockStop { index } => {
|
Event::ContentBlockStop { index } => {
|
||||||
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
|
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
|
||||||
return Some((
|
return Some((
|
||||||
Some(maybe!({
|
vec![maybe!({
|
||||||
Ok(LanguageModelCompletionEvent::ToolUse(
|
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||||
LanguageModelToolUse {
|
LanguageModelToolUse {
|
||||||
id: tool_use.id.into(),
|
id: tool_use.id.into(),
|
||||||
|
@ -650,44 +651,63 @@ pub fn map_to_language_model_completion_events(
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
))
|
))
|
||||||
})),
|
})],
|
||||||
state,
|
state,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Event::MessageStart { message } => {
|
Event::MessageStart { message } => {
|
||||||
|
update_usage(&mut state.usage, &message.usage);
|
||||||
return Some((
|
return Some((
|
||||||
Some(Ok(LanguageModelCompletionEvent::StartMessage {
|
vec![
|
||||||
message_id: message.id,
|
Ok(LanguageModelCompletionEvent::StartMessage {
|
||||||
})),
|
message_id: message.id,
|
||||||
|
}),
|
||||||
|
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
|
||||||
|
&state.usage,
|
||||||
|
))),
|
||||||
|
],
|
||||||
state,
|
state,
|
||||||
))
|
));
|
||||||
}
|
}
|
||||||
Event::MessageDelta { delta, .. } => {
|
Event::MessageDelta { delta, usage } => {
|
||||||
|
update_usage(&mut state.usage, &usage);
|
||||||
if let Some(stop_reason) = delta.stop_reason.as_deref() {
|
if let Some(stop_reason) = delta.stop_reason.as_deref() {
|
||||||
let stop_reason = match stop_reason {
|
state.stop_reason = match stop_reason {
|
||||||
"end_turn" => StopReason::EndTurn,
|
"end_turn" => StopReason::EndTurn,
|
||||||
"max_tokens" => StopReason::MaxTokens,
|
"max_tokens" => StopReason::MaxTokens,
|
||||||
"tool_use" => StopReason::ToolUse,
|
"tool_use" => StopReason::ToolUse,
|
||||||
_ => StopReason::EndTurn,
|
_ => {
|
||||||
|
log::error!(
|
||||||
|
"Unexpected anthropic stop_reason: {stop_reason}"
|
||||||
|
);
|
||||||
|
StopReason::EndTurn
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
return Some((
|
|
||||||
Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))),
|
|
||||||
state,
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
return Some((
|
||||||
|
vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
|
||||||
|
convert_usage(&state.usage),
|
||||||
|
))],
|
||||||
|
state,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Event::MessageStop => {
|
||||||
|
return Some((
|
||||||
|
vec![Ok(LanguageModelCompletionEvent::Stop(state.stop_reason))],
|
||||||
|
state,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
Event::Error { error } => {
|
Event::Error { error } => {
|
||||||
return Some((
|
return Some((
|
||||||
Some(Err(anyhow!(AnthropicError::ApiError(error)))),
|
vec![Err(anyhow!(AnthropicError::ApiError(error)))],
|
||||||
state,
|
state,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
},
|
},
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
return Some((Some(Err(anyhow!(err))), state));
|
return Some((vec![Err(anyhow!(err))], state));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -695,7 +715,32 @@ pub fn map_to_language_model_completion_events(
|
||||||
None
|
None
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.filter_map(|event| async move { event })
|
.flat_map(futures::stream::iter)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Updates usage data by preferring counts from `new`.
|
||||||
|
fn update_usage(usage: &mut Usage, new: &Usage) {
|
||||||
|
if let Some(input_tokens) = new.input_tokens {
|
||||||
|
usage.input_tokens = Some(input_tokens);
|
||||||
|
}
|
||||||
|
if let Some(output_tokens) = new.output_tokens {
|
||||||
|
usage.output_tokens = Some(output_tokens);
|
||||||
|
}
|
||||||
|
if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
|
||||||
|
usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
|
||||||
|
}
|
||||||
|
if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
|
||||||
|
usage.cache_read_input_tokens = Some(cache_read_input_tokens);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
|
||||||
|
language_model::TokenUsage {
|
||||||
|
input_tokens: usage.input_tokens.unwrap_or(0),
|
||||||
|
output_tokens: usage.output_tokens.unwrap_or(0),
|
||||||
|
cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
|
||||||
|
cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ConfigurationView {
|
struct ConfigurationView {
|
||||||
|
|
|
@ -1,3 +1,7 @@
|
||||||
pub const fn default_true() -> bool {
|
pub const fn default_true() -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_default<T: Default + PartialEq>(value: &T) -> bool {
|
||||||
|
*value == T::default()
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue