collab: Add usage-based billing for LLM interactions (#19081)

This PR adds usage-based billing for LLM interactions in the Assistant.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Marshall Bowers 2024-10-11 13:36:54 -04:00 committed by GitHub
parent f1c45d988e
commit 22ea7cef7a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 918 additions and 280 deletions

View file

@ -20,13 +20,14 @@ use axum::{
};
use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
use db::TokenUsage;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use isahc_http_client::IsahcHttpClient;
use rpc::ListModelsResponse;
use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
use std::{
pin::Pin,
sync::Arc,
@ -418,10 +419,7 @@ async fn perform_completion(
claims,
provider: params.provider,
model,
input_tokens: 0,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
tokens: TokenUsage::default(),
inner_stream: stream,
})))
}
@ -476,6 +474,19 @@ async fn check_usage_limit(
"Maximum spending limit reached for this month.".to_string(),
));
}
if usage.spending_this_month >= Cents(claims.max_monthly_spend_in_cents) {
return Err(Error::Http(
StatusCode::FORBIDDEN,
"Maximum spending limit reached for this month.".to_string(),
[(
HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
HeaderValue::from_static("true"),
)]
.into_iter()
.collect(),
));
}
}
}
@ -598,10 +609,7 @@ struct TokenCountingStream<S> {
claims: LlmTokenClaims,
provider: LanguageModelProvider,
model: String,
input_tokens: usize,
output_tokens: usize,
cache_creation_input_tokens: usize,
cache_read_input_tokens: usize,
tokens: TokenUsage,
inner_stream: S,
}
@ -615,10 +623,10 @@ where
match Pin::new(&mut self.inner_stream).poll_next(cx) {
Poll::Ready(Some(Ok(mut chunk))) => {
chunk.bytes.push(b'\n');
self.input_tokens += chunk.input_tokens;
self.output_tokens += chunk.output_tokens;
self.cache_creation_input_tokens += chunk.cache_creation_input_tokens;
self.cache_read_input_tokens += chunk.cache_read_input_tokens;
self.tokens.input += chunk.input_tokens;
self.tokens.output += chunk.output_tokens;
self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
self.tokens.input_cache_read += chunk.cache_read_input_tokens;
Poll::Ready(Some(Ok(chunk.bytes)))
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
@ -634,10 +642,7 @@ impl<S> Drop for TokenCountingStream<S> {
let claims = self.claims.clone();
let provider = self.provider;
let model = std::mem::take(&mut self.model);
let input_token_count = self.input_tokens;
let output_token_count = self.output_tokens;
let cache_creation_input_token_count = self.cache_creation_input_tokens;
let cache_read_input_token_count = self.cache_read_input_tokens;
let tokens = self.tokens;
self.state.executor.spawn_detached(async move {
let usage = state
.db
@ -646,10 +651,9 @@ impl<S> Drop for TokenCountingStream<S> {
claims.is_staff,
provider,
&model,
input_token_count,
cache_creation_input_token_count,
cache_read_input_token_count,
output_token_count,
tokens,
claims.has_llm_subscription,
Cents(claims.max_monthly_spend_in_cents),
Utc::now(),
)
.await
@ -679,22 +683,23 @@ impl<S> Drop for TokenCountingStream<S> {
},
model,
provider: provider.to_string(),
input_token_count: input_token_count as u64,
cache_creation_input_token_count: cache_creation_input_token_count
as u64,
cache_read_input_token_count: cache_read_input_token_count as u64,
output_token_count: output_token_count as u64,
input_token_count: tokens.input as u64,
cache_creation_input_token_count: tokens.input_cache_creation as u64,
cache_read_input_token_count: tokens.input_cache_read as u64,
output_token_count: tokens.output as u64,
requests_this_minute: usage.requests_this_minute as u64,
tokens_this_minute: usage.tokens_this_minute as u64,
tokens_this_day: usage.tokens_this_day as u64,
input_tokens_this_month: usage.input_tokens_this_month as u64,
input_tokens_this_month: usage.tokens_this_month.input as u64,
cache_creation_input_tokens_this_month: usage
.cache_creation_input_tokens_this_month
.tokens_this_month
.input_cache_creation
as u64,
cache_read_input_tokens_this_month: usage
.cache_read_input_tokens_this_month
.tokens_this_month
.input_cache_read
as u64,
output_tokens_this_month: usage.output_tokens_this_month as u64,
output_tokens_this_month: usage.tokens_this_month.output as u64,
spending_this_month: usage.spending_this_month.0 as u64,
lifetime_spending: usage.lifetime_spending.0 as u64,
},