From ed4b29f80c971df820ad1477102ed4ccf034450f Mon Sep 17 00:00:00 2001
From: Umesh Yadav <23421535+imumesh18@users.noreply.github.com>
Date: Tue, 17 Jun 2025 16:16:29 +0530
Subject: [PATCH] language_models: Improve token counting for providers
 (#32853)

We now push usage data whenever we receive it from the provider, so the
token counts remain correct after the turn has ended.

- [x] Ollama
- [x] Copilot
- [x] Mistral
- [x] OpenRouter
- [x] LMStudio

I've put all the changes into a single PR, but I'm open to moving them
into separate PRs if that makes review and testing easier.

Release Notes:

- N/A
---
 crates/copilot/src/copilot_chat.rs                  | 14 ++++++++++++++
 .../language_models/src/provider/copilot_chat.rs    | 13 ++++++++++++-
 crates/language_models/src/provider/lmstudio.rs     | 11 ++++++++++-
 crates/language_models/src/provider/mistral.rs      | 11 ++++++++++-
 crates/language_models/src/provider/ollama.rs       |  8 +++++++-
 crates/language_models/src/provider/open_router.rs  | 14 ++++++++++++--
 crates/mistral/src/mistral.rs                       |  1 +
 crates/ollama/src/ollama.rs                         |  2 ++
 crates/open_router/src/open_router.rs               |  6 ++++++
 9 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs
index 0f81df2e08..c89e4cdb98 100644
--- a/crates/copilot/src/copilot_chat.rs
+++ b/crates/copilot/src/copilot_chat.rs
@@ -311,6 +311,20 @@ pub struct FunctionContent {
 pub struct ResponseEvent {
     pub choices: Vec<ResponseChoice>,
     pub id: String,
+    pub usage: Option<Usage>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Usage {
+    pub completion_tokens: u32,
+    pub prompt_tokens: u32,
+    pub prompt_tokens_details: PromptTokensDetails,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct PromptTokensDetails {
+    pub cached_tokens: u32,
 }
 
 #[derive(Debug, Deserialize)]
diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
index 475f77c318..e0ccbcbae6 100644
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ b/crates/language_models/src/provider/copilot_chat.rs
@@ -24,7 +24,7 @@ use language_model::{
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
     LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
-    StopReason,
+    StopReason, TokenUsage,
 };
 use settings::SettingsStore;
 use std::time::Duration;
@@ -378,6 +378,17 @@ pub fn map_to_language_model_completion_events(
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+                TokenUsage {
+                    input_tokens: usage.prompt_tokens,
+                    output_tokens: usage.completion_tokens,
+                    cache_creation_input_tokens: 0,
+                    cache_read_input_tokens: 0,
+                },
+            )));
+        }
+
         match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(
diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs
index 8cb0829c2a..0a75ef2f88 100644
--- a/crates/language_models/src/provider/lmstudio.rs
+++ b/crates/language_models/src/provider/lmstudio.rs
@@ -7,7 +7,7 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    StopReason,
+    StopReason, TokenUsage,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -528,6 +528,15 @@ impl LmStudioEventMapper {
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
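All four OpenAI-compatible providers in this patch (Copilot Chat and LM Studio above, Mistral and OpenRouter below) follow the same shape: the provider attaches a `usage` object to its final streamed chunk, and the event mapper converts it into a `LanguageModelCompletionEvent::UsageUpdate`. Below is a minimal, self-contained sketch of that conversion; the `Usage` and `TokenUsage` structs are simplified stand-ins for the real definitions in the provider crates and `language_model`, and `serde`/`serde_json` are assumed as dependencies.

```rust
use serde::Deserialize;

// Simplified stand-in for the per-provider usage payload; the real
// structs (e.g. the new copilot_chat::Usage) carry more fields, which
// serde ignores by default.
#[derive(Deserialize, Debug)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

// Simplified stand-in for language_model::TokenUsage.
#[derive(Debug)]
struct TokenUsage {
    input_tokens: u32,
    output_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
}

// Providers attach `usage` only to the final streamed chunk, so most
// chunks map to `None` and the turn ends with exactly one update.
fn usage_update(usage: Option<Usage>) -> Option<TokenUsage> {
    usage.map(|usage| TokenUsage {
        input_tokens: usage.prompt_tokens,
        output_tokens: usage.completion_tokens,
        // None of these providers report prompt-cache activity through
        // this path, so the cache counters stay zero.
        cache_creation_input_tokens: 0,
        cache_read_input_tokens: 0,
    })
}

fn main() {
    let final_chunk = r#"{"prompt_tokens": 42, "completion_tokens": 7}"#;
    let usage: Usage = serde_json::from_str(final_chunk).unwrap();
    let update = usage_update(Some(usage)).unwrap();
    assert_eq!(update.input_tokens, 42);
    assert_eq!(update.output_tokens, 7);
    println!("{update:?}");
}
```

Pushing the update as its own event, rather than deriving it from the finish reason, is what keeps the count correct even when a provider sends `usage` and `finish_reason` in different chunks.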
diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs
index d00af8ecd6..84b7131c7d 100644
--- a/crates/language_models/src/provider/mistral.rs
+++ b/crates/language_models/src/provider/mistral.rs
@@ -13,7 +13,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -626,6 +626,15 @@ impl MistralEventMapper {
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         if let Some(finish_reason) = choice.finish_reason.as_deref() {
             match finish_reason {
                 "stop" => {
diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs
index fca78e4791..42ccd97089 100644
--- a/crates/language_models/src/provider/ollama.rs
+++ b/crates/language_models/src/provider/ollama.rs
@@ -8,7 +8,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
-    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason,
+    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
 };
 use ollama::{
     ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
@@ -507,6 +507,12 @@ fn map_to_language_model_completion_events(
         };
 
         if delta.done {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: delta.prompt_eval_count.unwrap_or(0),
+                output_tokens: delta.eval_count.unwrap_or(0),
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
             if state.used_tools {
                 state.used_tools = false;
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs
index 3d1cefa07f..450d56a1b2 100644
--- a/crates/language_models/src/provider/open_router.rs
+++ b/crates/language_models/src/provider/open_router.rs
@@ -12,7 +12,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use open_router::{Model, ResponseStreamEvent, list_models, stream_completion};
 use schemars::JsonSchema;
@@ -467,6 +467,7 @@ pub fn into_open_router(
         } else {
             None
         },
+        usage: open_router::RequestUsage { include: true },
         tools: request
             .tools
             .into_iter()
@@ -581,6 +582,15 @@ impl OpenRouterEventMapper {
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
@@ -609,7 +619,7 @@ impl OpenRouterEventMapper {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
             }
             Some(stop_reason) => {
-                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}",);
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
             }
             None => {}
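Ollama is the odd one out: there is no nested `usage` object. The token counts arrive as top-level `prompt_eval_count` and `eval_count` fields on the final message, the one with `done: true`, which is why the mapper above emits its `UsageUpdate` inside the `if delta.done` branch. Here is a small sketch of that response shape, with the field types assumed from how the mapper consumes them (`.unwrap_or(0)` into the usage counters):

```rust
use serde::Deserialize;

// Trimmed-down version of ollama::ChatResponseDelta; the real struct
// also carries the message payload, timings, and done_reason.
#[derive(Deserialize, Debug)]
struct ChatResponseDelta {
    done: bool,
    // Intermediate deltas omit both counts, hence Option + unwrap_or(0).
    prompt_eval_count: Option<u32>,
    eval_count: Option<u32>,
}

// Returns (input, output) token counts, but only for the final delta.
fn final_counts(delta: &ChatResponseDelta) -> Option<(u32, u32)> {
    delta.done.then(|| {
        (
            delta.prompt_eval_count.unwrap_or(0),
            delta.eval_count.unwrap_or(0),
        )
    })
}

fn main() {
    let last = r#"{"done": true, "prompt_eval_count": 26, "eval_count": 298}"#;
    let delta: ChatResponseDelta = serde_json::from_str(last).unwrap();
    assert_eq!(final_counts(&delta), Some((26, 298)));

    let mid = r#"{"done": false}"#;
    let delta: ChatResponseDelta = serde_json::from_str(mid).unwrap();
    assert_eq!(final_counts(&delta), None);
}
```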
diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs
index 7ad3b1c294..4fc976860c 100644
--- a/crates/mistral/src/mistral.rs
+++ b/crates/mistral/src/mistral.rs
@@ -379,6 +379,7 @@ pub struct StreamResponse {
     pub created: u64,
     pub model: String,
     pub choices: Vec<StreamChoice>,
+    pub usage: Option<Usage>,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs
index 95a7ded680..e17b08cde6 100644
--- a/crates/ollama/src/ollama.rs
+++ b/crates/ollama/src/ollama.rs
@@ -183,6 +183,8 @@ pub struct ChatResponseDelta {
     pub done_reason: Option<String>,
     #[allow(unused)]
     pub done: bool,
+    pub prompt_eval_count: Option<u32>,
+    pub eval_count: Option<u32>,
 }
 
 #[derive(Serialize, Deserialize)]
diff --git a/crates/open_router/src/open_router.rs b/crates/open_router/src/open_router.rs
index ad3009b48f..407ed416ec 100644
--- a/crates/open_router/src/open_router.rs
+++ b/crates/open_router/src/open_router.rs
@@ -127,6 +127,12 @@ pub struct Request {
     pub parallel_tool_calls: Option<bool>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tools: Vec<ToolDefinition>,
+    pub usage: RequestUsage,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct RequestUsage {
+    pub include: bool,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
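One request-side detail is easy to miss: OpenRouter only returns usage accounting when the request opts in, which is what the new `usage: open_router::RequestUsage { include: true }` field in `into_open_router` does. Here is a minimal sketch of the resulting wire format, using a stripped-down `Request` (the real struct also carries the model, messages, tools, and the other fields shown above):

```rust
use serde::Serialize;

// Mirrors the new open_router::RequestUsage added by this patch.
#[derive(Serialize)]
struct RequestUsage {
    include: bool,
}

// Stripped-down request body; illustrative only.
#[derive(Serialize)]
struct Request {
    model: String,
    usage: RequestUsage,
}

fn main() {
    let request = Request {
        model: "openrouter/auto".into(),
        usage: RequestUsage { include: true },
    };
    // Serializes to {"model":"openrouter/auto","usage":{"include":true}}.
    // The `usage.include` flag asks OpenRouter to attach a `usage` object
    // to the final streamed chunk, which the mapper above turns into a
    // UsageUpdate event.
    println!("{}", serde_json::to_string(&request).unwrap());
}
```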