language_models: Improve token counting for providers (#32853)
We push the usage data whenever we receive it from the provider, to make sure the counting is correct after the turn has ended.

- [x] Ollama
- [x] Copilot
- [x] Mistral
- [x] OpenRouter
- [x] LMStudio

I've put all the changes into a single PR, but I'm open to moving them to separate PRs if that makes the review and testing easier.

Release Notes:

- N/A
parent d4c9522da7
commit ed4b29f80c

9 changed files with 74 additions and 6 deletions
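All five providers follow the same pattern: whenever a streamed event carries a usage payload, the event mapper pushes a `UsageUpdate` immediately, so the final counts are in place once the turn has ended. A minimal sketch of that pattern, using simplified stand-ins for the `Usage`, `TokenUsage`, and `LanguageModelCompletionEvent` types that appear in the diffs below (not the real crate definitions):

```rust
// Minimal sketch of the shared mapping pattern. `Usage`, `TokenUsage`, and
// `LanguageModelCompletionEvent` are simplified stand-ins for the real types
// in the diffs below, not the actual crate definitions.
#[derive(Debug)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

#[derive(Debug)]
struct TokenUsage {
    input_tokens: u32,
    output_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
}

#[derive(Debug)]
enum LanguageModelCompletionEvent {
    UsageUpdate(TokenUsage),
}

// None of these providers report cache token counts, so both cache fields
// are always zero.
fn usage_update(usage: Usage) -> LanguageModelCompletionEvent {
    LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
        input_tokens: usage.prompt_tokens,
        output_tokens: usage.completion_tokens,
        cache_creation_input_tokens: 0,
        cache_read_input_tokens: 0,
    })
}

fn main() {
    let event = usage_update(Usage {
        prompt_tokens: 120,
        completion_tokens: 42,
    });
    println!("{event:?}");
}
```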
@@ -311,6 +311,20 @@ pub struct FunctionContent {
 pub struct ResponseEvent {
     pub choices: Vec<ResponseChoice>,
     pub id: String,
+    pub usage: Option<Usage>,
 }
 
+#[derive(Deserialize, Debug)]
+pub struct Usage {
+    pub completion_tokens: u32,
+    pub prompt_tokens: u32,
+    pub prompt_tokens_details: PromptTokensDetails,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct PromptTokensDetails {
+    pub cached_tokens: u32,
+}
+
 #[derive(Debug, Deserialize)]
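As a quick sanity check, the new structs deserialize an OpenAI-style usage object directly. A standalone sketch mirroring the struct definitions from the hunk above; the exact JSON field set is an assumption based on the OpenAI-compatible stream format Copilot Chat speaks:

```rust
// Standalone sketch: deserialize an OpenAI-style usage object into structs
// mirroring the ones added above. Needs the serde and serde_json crates.
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct Usage {
    completion_tokens: u32,
    prompt_tokens: u32,
    prompt_tokens_details: PromptTokensDetails,
    total_tokens: u32,
}

#[derive(Deserialize, Debug)]
struct PromptTokensDetails {
    cached_tokens: u32,
}

fn main() {
    // Illustrative payload; the exact field set is an assumption, not a
    // verbatim Copilot response.
    let json = r#"{
        "completion_tokens": 42,
        "prompt_tokens": 120,
        "prompt_tokens_details": { "cached_tokens": 100 },
        "total_tokens": 162
    }"#;
    let usage: Usage = serde_json::from_str(json).expect("valid usage payload");
    assert_eq!(usage.prompt_tokens + usage.completion_tokens, usage.total_tokens);
    println!("cached prompt tokens: {}", usage.prompt_tokens_details.cached_tokens);
}
```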
@@ -24,7 +24,7 @@ use language_model::{
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
     LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
-    StopReason,
+    StopReason, TokenUsage,
 };
 use settings::SettingsStore;
 use std::time::Duration;
@@ -378,6 +378,17 @@ pub fn map_to_language_model_completion_events(
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+                TokenUsage {
+                    input_tokens: usage.prompt_tokens,
+                    output_tokens: usage.completion_tokens,
+                    cache_creation_input_tokens: 0,
+                    cache_read_input_tokens: 0,
+                },
+            )));
+        }
+
         match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(
@@ -7,7 +7,7 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    StopReason,
+    StopReason, TokenUsage,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -528,6 +528,15 @@ impl LmStudioEventMapper {
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
@@ -13,7 +13,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -626,6 +626,15 @@ impl MistralEventMapper {
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         if let Some(finish_reason) = choice.finish_reason.as_deref() {
             match finish_reason {
                 "stop" => {
@@ -8,7 +8,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
-    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason,
+    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
 };
 use ollama::{
     ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
@@ -507,6 +507,12 @@ fn map_to_language_model_completion_events(
         };
 
         if delta.done {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: delta.prompt_eval_count.unwrap_or(0),
+                output_tokens: delta.eval_count.unwrap_or(0),
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
             if state.used_tools {
                 state.used_tools = false;
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
@@ -12,7 +12,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use open_router::{Model, ResponseStreamEvent, list_models, stream_completion};
 use schemars::JsonSchema;
@@ -467,6 +467,7 @@ pub fn into_open_router(
         } else {
             None
         },
+        usage: open_router::RequestUsage { include: true },
         tools: request
             .tools
             .into_iter()
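OpenRouter only includes token accounting in its responses when the request opts in, which is what the new `usage` field does. A sketch of what the serialized opt-in looks like, using a pared-down stand-in for `open_router::Request` (the real struct carries many more fields, and the model id here is hypothetical):

```rust
// Sketch of the opt-in serialization. `Request` is a pared-down stand-in
// for open_router::Request, keeping only the fields relevant here.
use serde::Serialize;

#[derive(Serialize)]
struct RequestUsage {
    include: bool,
}

#[derive(Serialize)]
struct Request {
    model: String,
    usage: RequestUsage,
}

fn main() {
    let request = Request {
        model: "example/model".to_string(), // hypothetical model id
        usage: RequestUsage { include: true },
    };
    // Prints: {"model":"example/model","usage":{"include":true}}
    println!("{}", serde_json::to_string(&request).expect("serializable"));
}
```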
@@ -581,6 +582,15 @@ impl OpenRouterEventMapper
             }
         }
 
+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
         match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
@@ -609,7 +619,7 @@ impl OpenRouterEventMapper
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
             }
             Some(stop_reason) => {
-                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}",);
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
             }
             None => {}
@@ -379,6 +379,7 @@ pub struct StreamResponse {
     pub created: u64,
     pub model: String,
     pub choices: Vec<StreamChoice>,
+    pub usage: Option<Usage>,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -183,6 +183,8 @@ pub struct ChatResponseDelta {
     pub done_reason: Option<String>,
     #[allow(unused)]
     pub done: bool,
+    pub prompt_eval_count: Option<u32>,
+    pub eval_count: Option<u32>,
 }
 
 #[derive(Serialize, Deserialize)]
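Ollama reports token counts only on the final chunk of a stream, the one with `done: true`, and both counts can be absent, which is why the mapper earlier in the diff falls back to `unwrap_or(0)`. A small parsing sketch using a trimmed-down stand-in for `ChatResponseDelta` (field names match the hunk above; the sample payload is illustrative):

```rust
// Sketch: parse the final chunk of an Ollama /api/chat stream. The struct
// is a trimmed-down stand-in for ollama::ChatResponseDelta.
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct ChatResponseDelta {
    done: bool,
    // Both counts can be missing, hence Option and the unwrap_or(0)
    // fallback in the Ollama mapper hunk.
    prompt_eval_count: Option<u32>,
    eval_count: Option<u32>,
}

fn main() {
    // Illustrative final chunk; earlier chunks have "done": false and
    // typically omit the count fields.
    let json = r#"{ "done": true, "prompt_eval_count": 26, "eval_count": 282 }"#;
    let delta: ChatResponseDelta = serde_json::from_str(json).expect("valid chunk");
    if delta.done {
        let input_tokens = delta.prompt_eval_count.unwrap_or(0);
        let output_tokens = delta.eval_count.unwrap_or(0);
        println!("input: {input_tokens}, output: {output_tokens}");
    }
}
```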
@@ -127,6 +127,12 @@ pub struct Request {
     pub parallel_tool_calls: Option<bool>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tools: Vec<ToolDefinition>,
+    pub usage: RequestUsage,
 }
 
+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct RequestUsage {
+    pub include: bool,
+}
+
 #[derive(Debug, Serialize, Deserialize)]