diff --git a/Cargo.lock b/Cargo.lock
index 5ca8da4976..63504ee912 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7173,6 +7173,7 @@ dependencies = [
  "http_client",
  "language_model",
  "lmstudio",
+ "log",
  "menu",
  "mistral",
  "ollama",
diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index c64d621143..e886e38976 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -553,7 +553,7 @@ pub struct Metadata {
     pub user_id: Option<String>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Default)]
 pub struct Usage {
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub input_tokens: Option<u32>,
diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs
index 7aa9040dce..0898cd9606 100644
--- a/crates/assistant2/src/thread.rs
+++ b/crates/assistant2/src/thread.rs
@@ -11,7 +11,7 @@ use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
     LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
-    Role, StopReason,
+    Role, StopReason, TokenUsage,
 };
 use project::Project;
 use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
@@ -81,6 +81,7 @@ pub struct Thread {
     tool_use: ToolUseState,
     scripting_session: Entity<ScriptingSession>,
     scripting_tool_use: ToolUseState,
+    cumulative_token_usage: TokenUsage,
 }
 
 impl Thread {
@@ -109,6 +110,7 @@ impl Thread {
             tool_use: ToolUseState::new(),
             scripting_session,
             scripting_tool_use: ToolUseState::new(),
+            cumulative_token_usage: TokenUsage::default(),
         }
     }
@@ -158,6 +160,8 @@ impl Thread {
             tool_use,
             scripting_session,
             scripting_tool_use,
+            // TODO: persist token usage?
+            cumulative_token_usage: TokenUsage::default(),
         }
     }
@@ -490,6 +494,7 @@ impl Thread {
         let stream_completion = async {
             let mut events = stream.await?;
             let mut stop_reason = StopReason::EndTurn;
+            let mut current_token_usage = TokenUsage::default();
 
             while let Some(event) = events.next().await {
                 let event = event?;
@@ -502,6 +507,12 @@ impl Thread {
                         LanguageModelCompletionEvent::Stop(reason) => {
                             stop_reason = reason;
                         }
+                        LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
+                            thread.cumulative_token_usage =
+                                thread.cumulative_token_usage.clone() + token_usage.clone()
+                                    - current_token_usage.clone();
+                            current_token_usage = token_usage;
+                        }
                         LanguageModelCompletionEvent::Text(chunk) => {
                             if let Some(last_message) = thread.messages.last_mut() {
                                 if last_message.role == Role::Assistant {
@@ -843,6 +854,10 @@ impl Thread {
         Ok(String::from_utf8_lossy(&markdown).to_string())
     }
+
+    pub fn cumulative_token_usage(&self) -> TokenUsage {
+        self.cumulative_token_usage.clone()
+    }
 }
 
 #[derive(Debug, Clone)]
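Note on the `thread.rs` hunks above: providers report usage as running totals for the in-flight request, so each `UsageUpdate` supersedes the previous one rather than adding to it. The handler therefore tracks `current_token_usage` and folds only the unseen difference into `cumulative_token_usage`. A minimal sketch of that accounting, with made-up numbers and a pared-down two-field struct (illustrative only, not the `Thread` implementation):

```rust
#[derive(Clone, Copy, Default, Debug, PartialEq)]
struct Usage {
    input_tokens: u32,
    output_tokens: u32,
}

fn main() {
    // Totals carried over from earlier requests in the thread.
    let mut cumulative = Usage { input_tokens: 900, output_tokens: 350 };
    // Running total already counted for the current request.
    let mut current = Usage::default();

    // The provider streams running totals: 100/0, then 100/40, then 100/90.
    for update in [
        Usage { input_tokens: 100, output_tokens: 0 },
        Usage { input_tokens: 100, output_tokens: 40 },
        Usage { input_tokens: 100, output_tokens: 90 },
    ] {
        // cumulative += update - current: only the new delta is added.
        cumulative.input_tokens += update.input_tokens - current.input_tokens;
        cumulative.output_tokens += update.output_tokens - current.output_tokens;
        current = update;
    }

    // The request's final totals (100/90) were counted exactly once.
    assert_eq!(cumulative, Usage { input_tokens: 1000, output_tokens: 440 });
}
```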
diff --git a/crates/assistant_context_editor/src/context.rs b/crates/assistant_context_editor/src/context.rs
index 5431ee9081..4120f5b22e 100644
--- a/crates/assistant_context_editor/src/context.rs
+++ b/crates/assistant_context_editor/src/context.rs
@@ -2254,6 +2254,7 @@ impl AssistantContext {
                         );
                     }
                     LanguageModelCompletionEvent::ToolUse(_) => {}
+                    LanguageModelCompletionEvent::UsageUpdate(_) => {}
                 }
             });
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 7f759e8586..8ed190e731 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -17,9 +17,11 @@ use proto::Plan;
 use schemars::JsonSchema;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use std::fmt;
+use std::ops::{Add, Sub};
 use std::{future::Future, sync::Arc};
 use thiserror::Error;
 use ui::IconName;
+use util::serde::is_default;
 
 pub use crate::model::*;
 pub use crate::rate_limiter::*;
@@ -59,6 +61,7 @@ pub enum LanguageModelCompletionEvent {
     Text(String),
     ToolUse(LanguageModelToolUse),
     StartMessage { message_id: String },
+    UsageUpdate(TokenUsage),
 }
 
 #[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
@@ -69,6 +72,46 @@ pub enum StopReason {
     ToolUse,
 }
 
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Default)]
+pub struct TokenUsage {
+    #[serde(default, skip_serializing_if = "is_default")]
+    pub input_tokens: u32,
+    #[serde(default, skip_serializing_if = "is_default")]
+    pub output_tokens: u32,
+    #[serde(default, skip_serializing_if = "is_default")]
+    pub cache_creation_input_tokens: u32,
+    #[serde(default, skip_serializing_if = "is_default")]
+    pub cache_read_input_tokens: u32,
+}
+
+impl Add for TokenUsage {
+    type Output = Self;
+
+    fn add(self, other: Self) -> Self {
+        Self {
+            input_tokens: self.input_tokens + other.input_tokens,
+            output_tokens: self.output_tokens + other.output_tokens,
+            cache_creation_input_tokens: self.cache_creation_input_tokens
+                + other.cache_creation_input_tokens,
+            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
+        }
+    }
+}
+
+impl Sub for TokenUsage {
+    type Output = Self;
+
+    fn sub(self, other: Self) -> Self {
+        Self {
+            input_tokens: self.input_tokens - other.input_tokens,
+            output_tokens: self.output_tokens - other.output_tokens,
+            cache_creation_input_tokens: self.cache_creation_input_tokens
+                - other.cache_creation_input_tokens,
+            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
+        }
+    }
+}
+
 #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
 pub struct LanguageModelToolUseId(Arc<str>);
@@ -176,6 +219,7 @@ pub trait LanguageModel: Send + Sync {
                 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
+                Ok(LanguageModelCompletionEvent::UsageUpdate(_)) => None,
                 Err(err) => Some(Err(err)),
             }
         }))
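Note on `language_model.rs` above: `TokenUsage` pairs `#[serde(default)]` with `skip_serializing_if = "is_default"`, so zeroed counters are omitted on the wire and restored on deserialization, and the field-wise `Add`/`Sub` impls are what make the `cumulative + new - current` expression in `thread.rs` typecheck. A small round-trip sketch, assuming `serde` and `serde_json` are available and inlining a copy of `is_default` (two fields stand in for the real four):

```rust
use serde::{Deserialize, Serialize};

fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    *value == T::default()
}

#[derive(Debug, PartialEq, Default, Serialize, Deserialize)]
struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    input_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    output_tokens: u32,
}

fn main() -> serde_json::Result<()> {
    let usage = TokenUsage { input_tokens: 128, output_tokens: 0 };
    // The zeroed counter is skipped entirely.
    let json = serde_json::to_string(&usage)?;
    assert_eq!(json, r#"{"input_tokens":128}"#);
    // On the way back in, the missing field falls back to its default.
    let back: TokenUsage = serde_json::from_str(&json)?;
    assert_eq!(back, usage);
    Ok(())
}
```

One caveat the field-wise `Sub` inherits from `u32`: subtraction panics on underflow in debug builds, so it relies on providers never reporting a regressing running total.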
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml
index cda61c1cfa..7fd9638ba5 100644
--- a/crates/language_models/Cargo.toml
+++ b/crates/language_models/Cargo.toml
@@ -33,6 +33,7 @@ gpui_tokio.workspace = true
 http_client.workspace = true
 language_model.workspace = true
 lmstudio = { workspace = true, features = ["schemars"] }
+log.workspace = true
 menu.workspace = true
 mistral = { workspace = true, features = ["schemars"] }
 ollama = { workspace = true, features = ["schemars"] }
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index c82bdc72a3..cc55d11595 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -1,6 +1,6 @@
 use crate::ui::InstructionListItem;
 use crate::AllLanguageModelSettings;
-use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent};
+use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent, Usage};
 use anyhow::{anyhow, Context as _, Result};
 use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;
@@ -582,12 +582,16 @@ pub fn map_to_language_model_completion_events(
     struct State {
         events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
         tool_uses_by_index: HashMap<usize, RawToolUse>,
+        usage: Usage,
+        stop_reason: StopReason,
     }
 
     futures::stream::unfold(
         State {
             events,
             tool_uses_by_index: HashMap::default(),
+            usage: Usage::default(),
+            stop_reason: StopReason::EndTurn,
         },
         |mut state| async move {
             while let Some(event) = state.events.next().await {
@@ -599,7 +603,7 @@ pub fn map_to_language_model_completion_events(
                         } => match content_block {
                             ResponseContent::Text { text } => {
                                 return Some((
-                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    vec![Ok(LanguageModelCompletionEvent::Text(text))],
                                     state,
                                 ));
                             }
@@ -612,28 +616,25 @@ pub fn map_to_language_model_completion_events(
                                         input_json: String::new(),
                                     },
                                 );
-
-                                return Some((None, state));
                             }
                         },
                         Event::ContentBlockDelta { index, delta } => match delta {
                             ContentDelta::TextDelta { text } => {
                                 return Some((
-                                    Some(Ok(LanguageModelCompletionEvent::Text(text))),
+                                    vec![Ok(LanguageModelCompletionEvent::Text(text))],
                                     state,
                                 ));
                             }
                             ContentDelta::InputJsonDelta { partial_json } => {
                                 if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
                                     tool_use.input_json.push_str(&partial_json);
-                                    return Some((None, state));
                                 }
                             }
                         },
                         Event::ContentBlockStop { index } => {
                             if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
                                 return Some((
-                                    Some(maybe!({
+                                    vec![maybe!({
                                         Ok(LanguageModelCompletionEvent::ToolUse(
                                             LanguageModelToolUse {
                                                 id: tool_use.id.into(),
@@ -650,44 +651,63 @@ pub fn map_to_language_model_completion_events(
                                             },
                                         },
                                     ))
-                                })),
+                                })],
                                     state,
                                 ));
                             }
                         }
                         Event::MessageStart { message } => {
+                            update_usage(&mut state.usage, &message.usage);
                             return Some((
-                                Some(Ok(LanguageModelCompletionEvent::StartMessage {
-                                    message_id: message.id,
-                                })),
+                                vec![
+                                    Ok(LanguageModelCompletionEvent::StartMessage {
+                                        message_id: message.id,
+                                    }),
+                                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
+                                        &state.usage,
+                                    ))),
+                                ],
                                 state,
-                            ))
+                            ));
                         }
-                        Event::MessageDelta { delta, .. } => {
+                        Event::MessageDelta { delta, usage } => {
+                            update_usage(&mut state.usage, &usage);
                             if let Some(stop_reason) = delta.stop_reason.as_deref() {
-                                let stop_reason = match stop_reason {
+                                state.stop_reason = match stop_reason {
                                     "end_turn" => StopReason::EndTurn,
                                     "max_tokens" => StopReason::MaxTokens,
                                     "tool_use" => StopReason::ToolUse,
-                                    _ => StopReason::EndTurn,
+                                    _ => {
+                                        log::error!(
+                                            "Unexpected anthropic stop_reason: {stop_reason}"
+                                        );
+                                        StopReason::EndTurn
+                                    }
                                 };
-
-                                return Some((
-                                    Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))),
-                                    state,
-                                ));
                             }
+                            return Some((
+                                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
+                                    convert_usage(&state.usage),
+                                ))],
+                                state,
+                            ));
+                        }
+                        Event::MessageStop => {
+                            return Some((
+                                vec![Ok(LanguageModelCompletionEvent::Stop(state.stop_reason))],
+                                state,
+                            ));
+                        }
                         Event::Error { error } => {
                             return Some((
-                                Some(Err(anyhow!(AnthropicError::ApiError(error)))),
+                                vec![Err(anyhow!(AnthropicError::ApiError(error)))],
                                 state,
                             ));
                         }
                         _ => {}
                     },
                     Err(err) => {
-                        return Some((Some(Err(anyhow!(err))), state));
+                        return Some((vec![Err(anyhow!(err))], state));
                    }
                }
            }
@@ -695,7 +715,32 @@ pub fn map_to_language_model_completion_events(
             None
         },
     )
-    .filter_map(|event| async move { event })
+    .flat_map(futures::stream::iter)
+}
+
+/// Updates usage data by preferring counts from `new`.
+fn update_usage(usage: &mut Usage, new: &Usage) {
+    if let Some(input_tokens) = new.input_tokens {
+        usage.input_tokens = Some(input_tokens);
+    }
+    if let Some(output_tokens) = new.output_tokens {
+        usage.output_tokens = Some(output_tokens);
+    }
+    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
+        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
+    }
+    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
+        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
+    }
+}
+
+fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
+    language_model::TokenUsage {
+        input_tokens: usage.input_tokens.unwrap_or(0),
+        output_tokens: usage.output_tokens.unwrap_or(0),
+        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
+        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
+    }
 }
 
 struct ConfigurationView {
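Note on the provider rewrite above: the `unfold` closure now yields a `Vec` of completion events per Anthropic event instead of an `Option`, and `.flat_map(futures::stream::iter)` flattens those batches. That is what lets `message_start` emit both a `StartMessage` and a `UsageUpdate` in one step, while an empty `Vec` replaces the old `Some((None, state))` no-op that `filter_map` used to swallow. A toy sketch of that stream shape, with integers standing in for completion events:

```rust
use futures::{executor::block_on, stream, StreamExt};

fn main() {
    // Each step consumes one upstream "event" and yields zero or more outputs.
    let events = stream::unfold(0u32, |n| async move {
        match n {
            0 => Some((vec![], n + 1)),       // nothing to emit for this event
            1 => Some((vec![10], n + 1)),     // one output event
            2 => Some((vec![20, 21], n + 1)), // one input fans out into two
            _ => None,                        // upstream exhausted
        }
    });

    // flat_map(stream::iter) splices each batch into a single flat stream.
    let flattened: Vec<u32> = block_on(events.flat_map(stream::iter).collect());
    assert_eq!(flattened, vec![10, 20, 21]);
}
```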
diff --git a/crates/util/src/serde.rs b/crates/util/src/serde.rs
index be948c659f..4aa4bb1a49 100644
--- a/crates/util/src/serde.rs
+++ b/crates/util/src/serde.rs
@@ -1,3 +1,7 @@
 pub const fn default_true() -> bool {
     true
 }
+
+pub fn is_default<T: Default + PartialEq>(value: &T) -> bool {
+    *value == T::default()
+}
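Note on the `util` helper above: `is_default` only needs `T: Default + PartialEq`, so one predicate serves every counter on `TokenUsage`, and any other type with a cheap equality check. A quick sanity check of that contract:

```rust
pub fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    *value == T::default()
}

fn main() {
    assert!(is_default(&0u32));
    assert!(!is_default(&42u32));
    assert!(is_default(&String::new()));
    assert!(is_default(&Option::<bool>::None));
    assert!(!is_default(&Some(false)));
}
```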