collab: Add usage-based billing for LLM interactions (#19081)
This PR adds usage-based billing for LLM interactions in the Assistant.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
parent
f1c45d988e
commit
22ea7cef7a
20 changed files with 918 additions and 280 deletions
|
@ -20,13 +20,14 @@ use axum::{
|
|||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use collections::HashMap;
|
||||
use db::TokenUsage;
|
||||
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
|
||||
use futures::{Stream, StreamExt as _};
|
||||
use isahc_http_client::IsahcHttpClient;
|
||||
use rpc::ListModelsResponse;
|
||||
use rpc::{
|
||||
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
};
|
||||
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
|
@ -418,10 +419,7 @@ async fn perform_completion(
|
|||
claims,
|
||||
provider: params.provider,
|
||||
model,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
tokens: TokenUsage::default(),
|
||||
inner_stream: stream,
|
||||
})))
|
||||
}
|
||||
|
@ -476,6 +474,19 @@ async fn check_usage_limit(
|
|||
"Maximum spending limit reached for this month.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if usage.spending_this_month >= Cents(claims.max_monthly_spend_in_cents) {
|
||||
return Err(Error::Http(
|
||||
StatusCode::FORBIDDEN,
|
||||
"Maximum spending limit reached for this month.".to_string(),
|
||||
[(
|
||||
HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
|
||||
HeaderValue::from_static("true"),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -598,10 +609,7 @@ struct TokenCountingStream<S> {
|
|||
claims: LlmTokenClaims,
|
||||
provider: LanguageModelProvider,
|
||||
model: String,
|
||||
input_tokens: usize,
|
||||
output_tokens: usize,
|
||||
cache_creation_input_tokens: usize,
|
||||
cache_read_input_tokens: usize,
|
||||
tokens: TokenUsage,
|
||||
inner_stream: S,
|
||||
}
|
||||
|
||||
|
@ -615,10 +623,10 @@ where
|
|||
match Pin::new(&mut self.inner_stream).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(mut chunk))) => {
|
||||
chunk.bytes.push(b'\n');
|
||||
self.input_tokens += chunk.input_tokens;
|
||||
self.output_tokens += chunk.output_tokens;
|
||||
self.cache_creation_input_tokens += chunk.cache_creation_input_tokens;
|
||||
self.cache_read_input_tokens += chunk.cache_read_input_tokens;
|
||||
self.tokens.input += chunk.input_tokens;
|
||||
self.tokens.output += chunk.output_tokens;
|
||||
self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
|
||||
self.tokens.input_cache_read += chunk.cache_read_input_tokens;
|
||||
Poll::Ready(Some(Ok(chunk.bytes)))
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
|
||||
|
@ -634,10 +642,7 @@ impl<S> Drop for TokenCountingStream<S> {
|
|||
let claims = self.claims.clone();
|
||||
let provider = self.provider;
|
||||
let model = std::mem::take(&mut self.model);
|
||||
let input_token_count = self.input_tokens;
|
||||
let output_token_count = self.output_tokens;
|
||||
let cache_creation_input_token_count = self.cache_creation_input_tokens;
|
||||
let cache_read_input_token_count = self.cache_read_input_tokens;
|
||||
let tokens = self.tokens;
|
||||
self.state.executor.spawn_detached(async move {
|
||||
let usage = state
|
||||
.db
|
||||
|
@ -646,10 +651,9 @@ impl<S> Drop for TokenCountingStream<S> {
|
|||
claims.is_staff,
|
||||
provider,
|
||||
&model,
|
||||
input_token_count,
|
||||
cache_creation_input_token_count,
|
||||
cache_read_input_token_count,
|
||||
output_token_count,
|
||||
tokens,
|
||||
claims.has_llm_subscription,
|
||||
Cents(claims.max_monthly_spend_in_cents),
|
||||
Utc::now(),
|
||||
)
|
||||
.await
|
||||
|
@ -679,22 +683,23 @@ impl<S> Drop for TokenCountingStream<S> {
|
|||
},
|
||||
model,
|
||||
provider: provider.to_string(),
|
||||
input_token_count: input_token_count as u64,
|
||||
cache_creation_input_token_count: cache_creation_input_token_count
|
||||
as u64,
|
||||
cache_read_input_token_count: cache_read_input_token_count as u64,
|
||||
output_token_count: output_token_count as u64,
|
||||
input_token_count: tokens.input as u64,
|
||||
cache_creation_input_token_count: tokens.input_cache_creation as u64,
|
||||
cache_read_input_token_count: tokens.input_cache_read as u64,
|
||||
output_token_count: tokens.output as u64,
|
||||
requests_this_minute: usage.requests_this_minute as u64,
|
||||
tokens_this_minute: usage.tokens_this_minute as u64,
|
||||
tokens_this_day: usage.tokens_this_day as u64,
|
||||
input_tokens_this_month: usage.input_tokens_this_month as u64,
|
||||
input_tokens_this_month: usage.tokens_this_month.input as u64,
|
||||
cache_creation_input_tokens_this_month: usage
|
||||
.cache_creation_input_tokens_this_month
|
||||
.tokens_this_month
|
||||
.input_cache_creation
|
||||
as u64,
|
||||
cache_read_input_tokens_this_month: usage
|
||||
.cache_read_input_tokens_this_month
|
||||
.tokens_this_month
|
||||
.input_cache_read
|
||||
as u64,
|
||||
output_tokens_this_month: usage.output_tokens_this_month as u64,
|
||||
output_tokens_this_month: usage.tokens_this_month.output as u64,
|
||||
spending_this_month: usage.spending_this_month.0 as u64,
|
||||
lifetime_spending: usage.lifetime_spending.0 as u64,
|
||||
},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue