language_models: Improve token counting for providers (#32853)

We push the usage data whenever we receive it from the provider, so that the token count is correct once the turn has ended.
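
For context, here is a minimal sketch of how a consumer of the completion-event stream can track these updates, assuming each `UsageUpdate` carries the provider's totals for the turn so far (the types below are simplified stand-ins for the real `LanguageModelCompletionEvent` and `TokenUsage`, not Zed's actual code):

```rust
// Simplified stand-ins; the real types live in the language_model crate.
#[derive(Clone, Copy, Debug, Default)]
struct TokenUsage {
    input_tokens: u32,
    output_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
}

enum CompletionEvent {
    Text(String),
    UsageUpdate(TokenUsage),
    Stop,
}

/// Folds a finished stream of events into the usage shown for the turn.
/// Each update is assumed to carry totals, so the last one seen wins.
fn final_usage(events: impl IntoIterator<Item = CompletionEvent>) -> TokenUsage {
    let mut usage = TokenUsage::default();
    for event in events {
        if let CompletionEvent::UsageUpdate(update) = event {
            usage = update;
        }
    }
    usage
}

fn main() {
    let events = vec![
        CompletionEvent::Text("Hello".into()),
        CompletionEvent::UsageUpdate(TokenUsage {
            input_tokens: 42,
            output_tokens: 7,
            ..TokenUsage::default()
        }),
        CompletionEvent::Stop,
    ];
    println!("{:?}", final_usage(events));
}
```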

- [x] Ollama 
- [x] Copilot 
- [x] Mistral 
- [x] OpenRouter 
- [x] LMStudio

I've put all the changes into a single PR, but I'm open to moving them into separate PRs if that makes review and testing easier.

Release Notes:

- N/A
Umesh Yadav 2025-06-17 16:16:29 +05:30 committed by GitHub
parent d4c9522da7
commit ed4b29f80c
9 changed files with 74 additions and 6 deletions


@@ -311,6 +311,20 @@ pub struct FunctionContent
 pub struct ResponseEvent {
     pub choices: Vec<ResponseChoice>,
     pub id: String,
+    pub usage: Option<Usage>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Usage {
+    pub completion_tokens: u32,
+    pub prompt_tokens: u32,
+    pub prompt_tokens_details: PromptTokensDetails,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct PromptTokensDetails {
+    pub cached_tokens: u32,
 }

 #[derive(Debug, Deserialize)]
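
For illustration, this is roughly the JSON the new structs deserialize: OpenAI-compatible endpoints attach a `usage` object like the one below to the final streamed chunk (the numbers are made up and the exact field set can vary by provider). A self-contained sketch using the same struct shapes as the diff:

```rust
use serde::Deserialize;

// Same shapes as the structs added above, repeated here so the
// example compiles on its own.
#[derive(Deserialize, Debug)]
struct Usage {
    completion_tokens: u32,
    prompt_tokens: u32,
    prompt_tokens_details: PromptTokensDetails,
    total_tokens: u32,
}

#[derive(Deserialize, Debug)]
struct PromptTokensDetails {
    cached_tokens: u32,
}

fn main() -> serde_json::Result<()> {
    // Illustrative `usage` object from the last chunk of a streamed response.
    let json = r#"{
        "completion_tokens": 57,
        "prompt_tokens": 1024,
        "prompt_tokens_details": { "cached_tokens": 512 },
        "total_tokens": 1081
    }"#;

    let usage: Usage = serde_json::from_str(json)?;
    println!("{usage:?}");
    Ok(())
}
```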


@@ -24,7 +24,7 @@ use language_model::{
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
     LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
-    StopReason,
+    StopReason, TokenUsage,
 };
 use settings::SettingsStore;
 use std::time::Duration;
@@ -378,6 +378,17 @@ pub fn map_to_language_model_completion_events(
                 }
             }

+            if let Some(usage) = event.usage {
+                events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+                    TokenUsage {
+                        input_tokens: usage.prompt_tokens,
+                        output_tokens: usage.completion_tokens,
+                        cache_creation_input_tokens: 0,
+                        cache_read_input_tokens: 0,
+                    },
+                )));
+            }
+
             match choice.finish_reason.as_deref() {
                 Some("stop") => {
                     events.push(Ok(LanguageModelCompletionEvent::Stop(


@@ -7,7 +7,7 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    StopReason,
+    StopReason, TokenUsage,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -528,6 +528,15 @@ impl LmStudioEventMapper
                 }
             }

+            if let Some(usage) = event.usage {
+                events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                    input_tokens: usage.prompt_tokens,
+                    output_tokens: usage.completion_tokens,
+                    cache_creation_input_tokens: 0,
+                    cache_read_input_tokens: 0,
+                })));
+            }
+
             match choice.finish_reason.as_deref() {
                 Some("stop") => {
                     events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));


@@ -13,7 +13,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -626,6 +626,15 @@ impl MistralEventMapper
                 }
             }

+            if let Some(usage) = event.usage {
+                events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                    input_tokens: usage.prompt_tokens,
+                    output_tokens: usage.completion_tokens,
+                    cache_creation_input_tokens: 0,
+                    cache_read_input_tokens: 0,
+                })));
+            }
+
             if let Some(finish_reason) = choice.finish_reason.as_deref() {
                 match finish_reason {
                     "stop" => {


@@ -8,7 +8,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
-    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason,
+    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
 };
 use ollama::{
     ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
@@ -507,6 +507,12 @@ fn map_to_language_model_completion_events(
             };

             if delta.done {
+                events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                    input_tokens: delta.prompt_eval_count.unwrap_or(0),
+                    output_tokens: delta.eval_count.unwrap_or(0),
+                    cache_creation_input_tokens: 0,
+                    cache_read_input_tokens: 0,
+                })));
                 if state.used_tools {
                     state.used_tools = false;
                     events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));


@@ -12,7 +12,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use open_router::{Model, ResponseStreamEvent, list_models, stream_completion};
 use schemars::JsonSchema;
@@ -467,6 +467,7 @@ pub fn into_open_router(
         } else {
             None
         },
+        usage: open_router::RequestUsage { include: true },
         tools: request
             .tools
             .into_iter()
@@ -581,6 +582,15 @@ impl OpenRouterEventMapper
                 }
             }

+            if let Some(usage) = event.usage {
+                events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                    input_tokens: usage.prompt_tokens,
+                    output_tokens: usage.completion_tokens,
+                    cache_creation_input_tokens: 0,
+                    cache_read_input_tokens: 0,
+                })));
+            }
+
             match choice.finish_reason.as_deref() {
                 Some("stop") => {
                     events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
@@ -609,7 +619,7 @@ impl OpenRouterEventMapper
                     events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
                 }
                 Some(stop_reason) => {
-                    log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                    log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}",);
                     events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                 }
                 None => {}


@@ -379,6 +379,7 @@ pub struct StreamResponse
     pub created: u64,
     pub model: String,
     pub choices: Vec<StreamChoice>,
+    pub usage: Option<Usage>,
 }

 #[derive(Serialize, Deserialize, Debug)]


@@ -183,6 +183,8 @@ pub struct ChatResponseDelta
     pub done_reason: Option<String>,
     #[allow(unused)]
     pub done: bool,
+    pub prompt_eval_count: Option<u32>,
+    pub eval_count: Option<u32>,
 }

 #[derive(Serialize, Deserialize)]
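
The two fields added to `ChatResponseDelta` correspond to the token counts Ollama reports on the final streamed `/api/chat` object (when `done` is true). They are optional because Ollama can omit them in some responses, which is why the mapper above falls back to `unwrap_or(0)`. A hedged, self-contained sketch of that final chunk (the numbers and model id are made up):

```rust
use serde::Deserialize;

// Simplified stand-in for the ChatResponseDelta fields used here.
#[derive(Deserialize, Debug)]
struct FinalChunk {
    done: bool,
    prompt_eval_count: Option<u32>,
    eval_count: Option<u32>,
}

fn main() -> serde_json::Result<()> {
    // Illustrative final object from a streamed /api/chat response.
    let json = r#"{
        "model": "llama3.1",
        "done": true,
        "done_reason": "stop",
        "prompt_eval_count": 26,
        "eval_count": 298
    }"#;

    let chunk: FinalChunk = serde_json::from_str(json)?;
    assert!(chunk.done);
    let input_tokens = chunk.prompt_eval_count.unwrap_or(0);
    let output_tokens = chunk.eval_count.unwrap_or(0);
    println!("input: {input_tokens}, output: {output_tokens}");
    Ok(())
}
```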


@@ -127,6 +127,12 @@ pub struct Request
     pub parallel_tool_calls: Option<bool>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tools: Vec<ToolDefinition>,
+    pub usage: RequestUsage,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct RequestUsage {
+    pub include: bool,
 }

 #[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]