collab: Remove LLM service (#28728)

This PR removes the LLM service from collab, as it has been moved to
Cloudflare.

Release Notes:

- N/A
Marshall Bowers, 2025-04-14 19:47:14 -04:00, committed by GitHub
parent 12b012eab3
commit fc1252b0cd
17 changed files with 8 additions and 2315 deletions

@@ -1,330 +0,0 @@
use reqwest::StatusCode;
use rpc::LanguageModelProvider;
use crate::llm::LlmTokenClaims;
use crate::{Config, Error, Result};
pub fn authorize_access_to_language_model(
config: &Config,
claims: &LlmTokenClaims,
country_code: Option<&str>,
provider: LanguageModelProvider,
model: &str,
) -> Result<()> {
authorize_access_for_country(config, country_code, provider)?;
authorize_access_to_model(config, claims, provider, model)?;
Ok(())
}
fn authorize_access_to_model(
config: &Config,
claims: &LlmTokenClaims,
provider: LanguageModelProvider,
model: &str,
) -> Result<()> {
if claims.is_staff {
return Ok(());
}
if provider == LanguageModelProvider::Anthropic {
if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" {
return Ok(());
}
if claims.has_llm_closed_beta_feature_flag
&& Some(model) == config.llm_closed_beta_model_name.as_deref()
{
return Ok(());
}
}
Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to model {model:?} is not included in your plan"),
))
}
fn authorize_access_for_country(
config: &Config,
country_code: Option<&str>,
provider: LanguageModelProvider,
) -> Result<()> {
// In development we won't have the `CF-IPCountry` header, so we can't check
// the country code.
//
// This shouldn't be necessary, as anyone running in development will need to provide
// their own API credentials in order to use an LLM provider.
if config.is_development() {
return Ok(());
}
// https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
let country_code = match country_code {
// `XX` - Used for clients without country code data.
None | Some("XX") => Err(Error::http(
StatusCode::BAD_REQUEST,
"no country code".to_string(),
))?,
// `T1` - Used for clients using the Tor network.
Some("T1") => Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to {provider:?} models is not available over Tor"),
))?,
Some(country_code) => country_code,
};
let is_country_supported_by_provider = match provider {
LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
};
if !is_country_supported_by_provider {
Err(Error::http(
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
format!(
"access to {provider:?} models is not available in your region ({country_code})"
),
))?
}
Ok(())
}
#[cfg(test)]
mod tests {
use axum::response::IntoResponse;
use pretty_assertions::assert_eq;
use rpc::proto::Plan;
use super::*;
#[gpui::test]
async fn test_authorize_access_to_language_model_with_supported_country(
_cx: &mut gpui::TestAppContext,
) {
let config = Config::test();
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
is_staff: true,
..Default::default()
};
let cases = vec![
(LanguageModelProvider::Anthropic, "US"), // United States
(LanguageModelProvider::Anthropic, "GB"), // United Kingdom
(LanguageModelProvider::OpenAi, "US"), // United States
(LanguageModelProvider::OpenAi, "GB"), // United Kingdom
(LanguageModelProvider::Google, "US"), // United States
(LanguageModelProvider::Google, "GB"), // United Kingdom
];
for (provider, country_code) in cases {
authorize_access_to_language_model(
&config,
&claims,
Some(country_code),
provider,
"the-model",
)
.unwrap_or_else(|_| {
panic!("expected authorization to return Ok for {provider:?}: {country_code}")
})
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_with_unsupported_country(
_cx: &mut gpui::TestAppContext,
) {
let config = Config::test();
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
..Default::default()
};
let cases = vec![
(LanguageModelProvider::Anthropic, "AF"), // Afghanistan
(LanguageModelProvider::Anthropic, "BY"), // Belarus
(LanguageModelProvider::Anthropic, "CF"), // Central African Republic
(LanguageModelProvider::Anthropic, "CN"), // China
(LanguageModelProvider::Anthropic, "CU"), // Cuba
(LanguageModelProvider::Anthropic, "ER"), // Eritrea
(LanguageModelProvider::Anthropic, "ET"), // Ethiopia
(LanguageModelProvider::Anthropic, "IR"), // Iran
(LanguageModelProvider::Anthropic, "KP"), // North Korea
(LanguageModelProvider::Anthropic, "XK"), // Kosovo
(LanguageModelProvider::Anthropic, "LY"), // Libya
(LanguageModelProvider::Anthropic, "MM"), // Myanmar
(LanguageModelProvider::Anthropic, "RU"), // Russia
(LanguageModelProvider::Anthropic, "SO"), // Somalia
(LanguageModelProvider::Anthropic, "SS"), // South Sudan
(LanguageModelProvider::Anthropic, "SD"), // Sudan
(LanguageModelProvider::Anthropic, "SY"), // Syria
(LanguageModelProvider::Anthropic, "VE"), // Venezuela
(LanguageModelProvider::Anthropic, "YE"), // Yemen
(LanguageModelProvider::OpenAi, "KP"), // North Korea
(LanguageModelProvider::Google, "KP"), // North Korea
];
for (provider, country_code) in cases {
let error_response = authorize_access_to_language_model(
&config,
&claims,
Some(country_code),
provider,
"the-model",
)
.expect_err(&format!(
"expected authorization to return an error for {provider:?}: {country_code}"
))
.into_response();
assert_eq!(
error_response.status(),
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
);
let response_body = hyper::body::to_bytes(error_response.into_body())
.await
.unwrap()
.to_vec();
assert_eq!(
String::from_utf8(response_body).unwrap(),
format!(
"access to {provider:?} models is not available in your region ({country_code})"
)
);
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
let config = Config::test();
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
..Default::default()
};
let cases = vec![
(LanguageModelProvider::Anthropic, "T1"), // Tor
(LanguageModelProvider::OpenAi, "T1"), // Tor
(LanguageModelProvider::Google, "T1"), // Tor
];
for (provider, country_code) in cases {
let error_response = authorize_access_to_language_model(
&config,
&claims,
Some(country_code),
provider,
"the-model",
)
.expect_err(&format!(
"expected authorization to return an error for {provider:?}: {country_code}"
))
.into_response();
assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
let response_body = hyper::body::to_bytes(error_response.into_body())
.await
.unwrap()
.to_vec();
assert_eq!(
String::from_utf8(response_body).unwrap(),
format!("access to {provider:?} models is not available over Tor")
);
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_based_on_plan() {
let config = Config::test();
let test_cases = vec![
// Pro plan should have access to claude-3.5-sonnet
(
Plan::ZedPro,
LanguageModelProvider::Anthropic,
"claude-3-5-sonnet",
true,
),
// Free plan should have access to claude-3.5-sonnet
(
Plan::Free,
LanguageModelProvider::Anthropic,
"claude-3-5-sonnet",
true,
),
// Pro plan should NOT have access to other Anthropic models
(
Plan::ZedPro,
LanguageModelProvider::Anthropic,
"claude-3-opus",
false,
),
];
for (plan, provider, model, expected_access) in test_cases {
let claims = LlmTokenClaims {
plan,
..Default::default()
};
let result =
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
if expected_access {
assert!(
result.is_ok(),
"Expected access to be granted for plan {:?}, provider {:?}, model {}",
plan,
provider,
model
);
} else {
let error = result.expect_err(&format!(
"Expected access to be denied for plan {:?}, provider {:?}, model {}",
plan, provider, model
));
let response = error.into_response();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
}
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_for_staff() {
let config = Config::test();
let claims = LlmTokenClaims {
is_staff: true,
..Default::default()
};
// Staff should have access to all models
let test_cases = vec![
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
(LanguageModelProvider::Anthropic, "claude-2"),
(LanguageModelProvider::Anthropic, "claude-123-agi"),
(LanguageModelProvider::OpenAi, "gpt-4"),
(LanguageModelProvider::Google, "gemini-pro"),
];
for (provider, model) in test_cases {
let result =
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
assert!(
result.is_ok(),
"Expected staff to have access to provider {:?}, model {}",
provider,
model
);
}
}
}

@@ -20,7 +20,6 @@ use std::future::Future;
use std::sync::Arc;
use anyhow::anyhow;
pub use queries::usages::{ActiveUserCount, TokenUsage};
pub use sea_orm::ConnectOptions;
use sea_orm::prelude::*;
use sea_orm::{

@@ -2,5 +2,4 @@ use super::*;
pub mod billing_events;
pub mod providers;
pub mod revoked_access_tokens;
pub mod usages;

@@ -1,15 +0,0 @@
use super::*;
impl LlmDatabase {
/// Returns whether the access token with the given `jti` has been revoked.
pub async fn is_access_token_revoked(&self, jti: &str) -> Result<bool> {
self.transaction(|tx| async move {
Ok(revoked_access_token::Entity::find()
.filter(revoked_access_token::Column::Jti.eq(jti))
.one(&*tx)
.await?
.is_some())
})
.await
}
}

@@ -1,56 +1,12 @@
use crate::db::UserId;
use crate::llm::Cents;
use chrono::{Datelike, Duration};
use chrono::Datelike;
use futures::StreamExt as _;
use rpc::LanguageModelProvider;
use sea_orm::QuerySelect;
use std::{iter, str::FromStr};
use std::str::FromStr;
use strum::IntoEnumIterator as _;
use super::*;
#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)]
pub struct TokenUsage {
pub input: usize,
pub input_cache_creation: usize,
pub input_cache_read: usize,
pub output: usize,
}
impl TokenUsage {
pub fn total(&self) -> usize {
self.input + self.input_cache_creation + self.input_cache_read + self.output
}
}
#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)]
pub struct Usage {
pub requests_this_minute: usize,
pub tokens_this_minute: usize,
pub input_tokens_this_minute: usize,
pub output_tokens_this_minute: usize,
pub tokens_this_day: usize,
pub tokens_this_month: TokenUsage,
pub spending_this_month: Cents,
pub lifetime_spending: Cents,
}
#[derive(Debug, PartialEq, Clone)]
pub struct ApplicationWideUsage {
pub provider: LanguageModelProvider,
pub model: String,
pub requests_this_minute: usize,
pub tokens_this_minute: usize,
pub input_tokens_this_minute: usize,
pub output_tokens_this_minute: usize,
}
#[derive(Clone, Copy, Debug, Default)]
pub struct ActiveUserCount {
pub users_in_recent_minutes: usize,
pub users_in_recent_days: usize,
}
impl LlmDatabase {
pub async fn initialize_usage_measures(&mut self) -> Result<()> {
let all_measures = self
@@ -90,100 +46,6 @@ impl LlmDatabase {
Ok(())
}
pub async fn get_application_wide_usages_by_model(
&self,
now: DateTimeUtc,
) -> Result<Vec<ApplicationWideUsage>> {
self.transaction(|tx| async move {
let past_minute = now - Duration::minutes(1);
let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute];
let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
let input_tokens_per_minute =
self.usage_measure_ids[&UsageMeasure::InputTokensPerMinute];
let output_tokens_per_minute =
self.usage_measure_ids[&UsageMeasure::OutputTokensPerMinute];
let mut results = Vec::new();
for ((provider, model_name), model) in self.models.iter() {
let mut usages = usage::Entity::find()
.filter(
usage::Column::Timestamp
.gte(past_minute.naive_utc())
.and(usage::Column::IsStaff.eq(false))
.and(usage::Column::ModelId.eq(model.id))
.and(
usage::Column::MeasureId
.eq(requests_per_minute)
.or(usage::Column::MeasureId.eq(tokens_per_minute)),
),
)
.stream(&*tx)
.await?;
let mut requests_this_minute = 0;
let mut tokens_this_minute = 0;
let mut input_tokens_this_minute = 0;
let mut output_tokens_this_minute = 0;
while let Some(usage) = usages.next().await {
let usage = usage?;
if usage.measure_id == requests_per_minute {
requests_this_minute += Self::get_live_buckets(
&usage,
now.naive_utc(),
UsageMeasure::RequestsPerMinute,
)
.0
.iter()
.copied()
.sum::<i64>() as usize;
} else if usage.measure_id == tokens_per_minute {
tokens_this_minute += Self::get_live_buckets(
&usage,
now.naive_utc(),
UsageMeasure::TokensPerMinute,
)
.0
.iter()
.copied()
.sum::<i64>() as usize;
} else if usage.measure_id == input_tokens_per_minute {
input_tokens_this_minute += Self::get_live_buckets(
&usage,
now.naive_utc(),
UsageMeasure::InputTokensPerMinute,
)
.0
.iter()
.copied()
.sum::<i64>() as usize;
} else if usage.measure_id == output_tokens_per_minute {
output_tokens_this_minute += Self::get_live_buckets(
&usage,
now.naive_utc(),
UsageMeasure::OutputTokensPerMinute,
)
.0
.iter()
.copied()
.sum::<i64>() as usize;
}
}
results.push(ApplicationWideUsage {
provider: *provider,
model: model_name.clone(),
requests_this_minute,
tokens_this_minute,
input_tokens_this_minute,
output_tokens_this_minute,
})
}
Ok(results)
})
.await
}
pub async fn get_user_spending_for_month(
&self,
user_id: UserId,
@@ -223,499 +85,6 @@ impl LlmDatabase {
})
.await
}
pub async fn get_usage(
&self,
user_id: UserId,
provider: LanguageModelProvider,
model_name: &str,
now: DateTimeUtc,
) -> Result<Usage> {
self.transaction(|tx| async move {
let model = self
.models
.get(&(provider, model_name.to_string()))
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
let usages = usage::Entity::find()
.filter(
usage::Column::UserId
.eq(user_id)
.and(usage::Column::ModelId.eq(model.id)),
)
.all(&*tx)
.await?;
let month = now.date_naive().month() as i32;
let year = now.date_naive().year();
let monthly_usage = monthly_usage::Entity::find()
.filter(
monthly_usage::Column::UserId
.eq(user_id)
.and(monthly_usage::Column::ModelId.eq(model.id))
.and(monthly_usage::Column::Month.eq(month))
.and(monthly_usage::Column::Year.eq(year)),
)
.one(&*tx)
.await?;
let lifetime_usage = lifetime_usage::Entity::find()
.filter(
lifetime_usage::Column::UserId
.eq(user_id)
.and(lifetime_usage::Column::ModelId.eq(model.id)),
)
.one(&*tx)
.await?;
let requests_this_minute =
self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
let tokens_this_minute =
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
let input_tokens_this_minute =
self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMinute)?;
let output_tokens_this_minute =
self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMinute)?;
let tokens_this_day =
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
let spending_this_month = if let Some(monthly_usage) = &monthly_usage {
calculate_spending(
model,
monthly_usage.input_tokens as usize,
monthly_usage.cache_creation_input_tokens as usize,
monthly_usage.cache_read_input_tokens as usize,
monthly_usage.output_tokens as usize,
)
} else {
Cents::ZERO
};
let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage {
calculate_spending(
model,
lifetime_usage.input_tokens as usize,
lifetime_usage.cache_creation_input_tokens as usize,
lifetime_usage.cache_read_input_tokens as usize,
lifetime_usage.output_tokens as usize,
)
} else {
Cents::ZERO
};
Ok(Usage {
requests_this_minute,
tokens_this_minute,
input_tokens_this_minute,
output_tokens_this_minute,
tokens_this_day,
tokens_this_month: TokenUsage {
input: monthly_usage
.as_ref()
.map_or(0, |usage| usage.input_tokens as usize),
input_cache_creation: monthly_usage
.as_ref()
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
input_cache_read: monthly_usage
.as_ref()
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
output: monthly_usage
.as_ref()
.map_or(0, |usage| usage.output_tokens as usize),
},
spending_this_month,
lifetime_spending,
})
})
.await
}
pub async fn record_usage(
&self,
user_id: UserId,
is_staff: bool,
provider: LanguageModelProvider,
model_name: &str,
tokens: TokenUsage,
has_llm_subscription: bool,
max_monthly_spend: Cents,
free_tier_monthly_spending_limit: Cents,
now: DateTimeUtc,
) -> Result<Usage> {
self.transaction(|tx| async move {
let model = self.model(provider, model_name)?;
let usages = usage::Entity::find()
.filter(
usage::Column::UserId
.eq(user_id)
.and(usage::Column::ModelId.eq(model.id)),
)
.all(&*tx)
.await?;
let requests_this_minute = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::RequestsPerMinute,
now,
1,
&tx,
)
.await?;
let tokens_this_minute = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::TokensPerMinute,
now,
tokens.total(),
&tx,
)
.await?;
let input_tokens_this_minute = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::InputTokensPerMinute,
now,
// Cache read input tokens are not counted for the purposes of rate limits (but they are still billed).
tokens.input + tokens.input_cache_creation,
&tx,
)
.await?;
let output_tokens_this_minute = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::OutputTokensPerMinute,
now,
tokens.output,
&tx,
)
.await?;
let tokens_this_day = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::TokensPerDay,
now,
tokens.total(),
&tx,
)
.await?;
let month = now.date_naive().month() as i32;
let year = now.date_naive().year();
// Update monthly usage
let monthly_usage = monthly_usage::Entity::find()
.filter(
monthly_usage::Column::UserId
.eq(user_id)
.and(monthly_usage::Column::ModelId.eq(model.id))
.and(monthly_usage::Column::Month.eq(month))
.and(monthly_usage::Column::Year.eq(year)),
)
.one(&*tx)
.await?;
let monthly_usage = match monthly_usage {
Some(usage) => {
monthly_usage::Entity::update(monthly_usage::ActiveModel {
id: ActiveValue::unchanged(usage.id),
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
),
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
..Default::default()
})
.exec(&*tx)
.await?
}
None => {
monthly_usage::ActiveModel {
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
month: ActiveValue::set(month),
year: ActiveValue::set(year),
input_tokens: ActiveValue::set(tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
..Default::default()
}
.insert(&*tx)
.await?
}
};
let spending_this_month = calculate_spending(
model,
monthly_usage.input_tokens as usize,
monthly_usage.cache_creation_input_tokens as usize,
monthly_usage.cache_read_input_tokens as usize,
monthly_usage.output_tokens as usize,
);
if !is_staff
&& spending_this_month > free_tier_monthly_spending_limit
&& has_llm_subscription
&& (spending_this_month - free_tier_monthly_spending_limit) <= max_monthly_spend
{
billing_event::ActiveModel {
id: ActiveValue::not_set(),
idempotency_key: ActiveValue::not_set(),
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
input_tokens: ActiveValue::set(tokens.input as i64),
input_cache_creation_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
),
input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
}
.insert(&*tx)
.await?;
}
// Update lifetime usage
let lifetime_usage = lifetime_usage::Entity::find()
.filter(
lifetime_usage::Column::UserId
.eq(user_id)
.and(lifetime_usage::Column::ModelId.eq(model.id)),
)
.one(&*tx)
.await?;
let lifetime_usage = match lifetime_usage {
Some(usage) => {
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
id: ActiveValue::unchanged(usage.id),
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
),
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
..Default::default()
})
.exec(&*tx)
.await?
}
None => {
lifetime_usage::ActiveModel {
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
input_tokens: ActiveValue::set(tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
..Default::default()
}
.insert(&*tx)
.await?
}
};
let lifetime_spending = calculate_spending(
model,
lifetime_usage.input_tokens as usize,
lifetime_usage.cache_creation_input_tokens as usize,
lifetime_usage.cache_read_input_tokens as usize,
lifetime_usage.output_tokens as usize,
);
Ok(Usage {
requests_this_minute,
tokens_this_minute,
input_tokens_this_minute,
output_tokens_this_minute,
tokens_this_day,
tokens_this_month: TokenUsage {
input: monthly_usage.input_tokens as usize,
input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
input_cache_read: monthly_usage.cache_read_input_tokens as usize,
output: monthly_usage.output_tokens as usize,
},
spending_this_month,
lifetime_spending,
})
})
.await
}
/// Returns the active user count for the specified model.
pub async fn get_active_user_count(
&self,
provider: LanguageModelProvider,
model_name: &str,
now: DateTimeUtc,
) -> Result<ActiveUserCount> {
self.transaction(|tx| async move {
let minute_since = now - Duration::minutes(5);
let day_since = now - Duration::days(5);
let model = self
.models
.get(&(provider, model_name.to_string()))
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
let users_in_recent_minutes = usage::Entity::find()
.filter(
usage::Column::ModelId
.eq(model.id)
.and(usage::Column::MeasureId.eq(tokens_per_minute))
.and(usage::Column::Timestamp.gte(minute_since.naive_utc()))
.and(usage::Column::IsStaff.eq(false)),
)
.select_only()
.column(usage::Column::UserId)
.group_by(usage::Column::UserId)
.count(&*tx)
.await? as usize;
let users_in_recent_days = usage::Entity::find()
.filter(
usage::Column::ModelId
.eq(model.id)
.and(usage::Column::MeasureId.eq(tokens_per_minute))
.and(usage::Column::Timestamp.gte(day_since.naive_utc()))
.and(usage::Column::IsStaff.eq(false)),
)
.select_only()
.column(usage::Column::UserId)
.group_by(usage::Column::UserId)
.count(&*tx)
.await? as usize;
Ok(ActiveUserCount {
users_in_recent_minutes,
users_in_recent_days,
})
})
.await
}
async fn update_usage_for_measure(
&self,
user_id: UserId,
is_staff: bool,
model_id: ModelId,
usages: &[usage::Model],
usage_measure: UsageMeasure,
now: DateTimeUtc,
usage_to_add: usize,
tx: &DatabaseTransaction,
) -> Result<usize> {
let now = now.naive_utc();
let measure_id = *self
.usage_measure_ids
.get(&usage_measure)
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
let mut id = None;
let mut timestamp = now;
let mut buckets = vec![0_i64];
if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) {
id = Some(old_usage.id);
let (live_buckets, buckets_since) =
Self::get_live_buckets(old_usage, now, usage_measure);
if !live_buckets.is_empty() {
buckets.clear();
buckets.extend_from_slice(live_buckets);
buckets.extend(iter::repeat(0).take(buckets_since));
timestamp =
old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32);
}
}
*buckets.last_mut().unwrap() += usage_to_add as i64;
let total_usage = buckets.iter().sum::<i64>() as usize;
let mut model = usage::ActiveModel {
user_id: ActiveValue::set(user_id),
is_staff: ActiveValue::set(is_staff),
model_id: ActiveValue::set(model_id),
measure_id: ActiveValue::set(measure_id),
timestamp: ActiveValue::set(timestamp),
buckets: ActiveValue::set(buckets),
..Default::default()
};
if let Some(id) = id {
model.id = ActiveValue::unchanged(id);
model.update(tx).await?;
} else {
usage::Entity::insert(model)
.exec_without_returning(tx)
.await?;
}
Ok(total_usage)
}
fn get_usage_for_measure(
&self,
usages: &[usage::Model],
now: DateTimeUtc,
usage_measure: UsageMeasure,
) -> Result<usize> {
let now = now.naive_utc();
let measure_id = *self
.usage_measure_ids
.get(&usage_measure)
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else {
return Ok(0);
};
let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure);
Ok(live_buckets.iter().sum::<i64>() as _)
}
fn get_live_buckets(
usage: &usage::Model,
now: chrono::NaiveDateTime,
measure: UsageMeasure,
) -> (&[i64], usize) {
let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0);
let buckets_since_usage =
seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32;
let buckets_since_usage = buckets_since_usage.ceil() as usize;
let mut live_buckets = &[] as &[i64];
if buckets_since_usage < measure.bucket_count() {
let expired_bucket_count =
(usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count());
live_buckets = &usage.buckets[expired_bucket_count..];
while live_buckets.first() == Some(&0) {
live_buckets = &live_buckets[1..];
}
}
(live_buckets, buckets_since_usage)
}
}
fn calculate_spending(
@@ -741,32 +110,3 @@ fn calculate_spending(
+ output_token_cost;
Cents::new(spending as u32)
}
const MINUTE_BUCKET_COUNT: usize = 12;
const DAY_BUCKET_COUNT: usize = 48;
impl UsageMeasure {
fn bucket_count(&self) -> usize {
match self {
UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
UsageMeasure::TokensPerMinute
| UsageMeasure::InputTokensPerMinute
| UsageMeasure::OutputTokensPerMinute => MINUTE_BUCKET_COUNT,
UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
}
}
fn total_duration(&self) -> Duration {
match self {
UsageMeasure::RequestsPerMinute => Duration::minutes(1),
UsageMeasure::TokensPerMinute
| UsageMeasure::InputTokensPerMinute
| UsageMeasure::OutputTokensPerMinute => Duration::minutes(1),
UsageMeasure::TokensPerDay => Duration::hours(24),
}
}
fn bucket_duration(&self) -> Duration {
self.total_duration() / self.bucket_count() as i32
}
}

@@ -1,8 +1,6 @@
pub mod billing_event;
pub mod lifetime_usage;
pub mod model;
pub mod monthly_usage;
pub mod provider;
pub mod revoked_access_token;
pub mod usage;
pub mod usage_measure;

@@ -1,20 +0,0 @@
use crate::{db::UserId, llm::db::ModelId};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "lifetime_usages")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub user_id: UserId,
pub model_id: ModelId,
pub input_tokens: i64,
pub cache_creation_input_tokens: i64,
pub cache_read_input_tokens: i64,
pub output_tokens: i64,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

@@ -1,19 +0,0 @@
use chrono::NaiveDateTime;
use sea_orm::entity::prelude::*;
use crate::llm::db::RevokedAccessTokenId;
/// A revoked access token.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "revoked_access_tokens")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: RevokedAccessTokenId,
pub jti: String,
pub revoked_at: NaiveDateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

@@ -1,6 +1,4 @@
mod billing_tests;
mod provider_tests;
mod usage_tests;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;

@@ -1,152 +0,0 @@
use crate::{
Cents,
db::UserId,
llm::{
FREE_TIER_MONTHLY_SPENDING_LIMIT,
db::{LlmDatabase, TokenUsage, queries::providers::ModelParams},
},
test_llm_db,
};
use chrono::{DateTime, Utc};
use pretty_assertions::assert_eq;
use rpc::LanguageModelProvider;
test_llm_db!(
test_billing_limit_exceeded,
test_billing_limit_exceeded_postgres
);
async fn test_billing_limit_exceeded(db: &mut LlmDatabase) {
let provider = LanguageModelProvider::Anthropic;
let model = "fake-claude-limerick";
const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5;
const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5;
// Initialize the database and insert the model
db.initialize().await.unwrap();
db.insert_models(&[ModelParams {
provider,
name: model.to_string(),
max_requests_per_minute: 5,
max_tokens_per_minute: 10_000,
max_tokens_per_day: 50_000,
price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS,
price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS,
}])
.await
.unwrap();
// Set a fixed datetime for consistent testing
let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
.unwrap()
.with_timezone(&Utc);
let user_id = UserId::from_proto(123);
let max_monthly_spend = Cents::from_dollars(11);
// Record usage that brings us close to the limit but doesn't exceed it
// Let's say we use $10.50 worth of tokens
let tokens_to_use = 210_000_000; // This will cost $10.50 at $0.05 per 1 million tokens
let usage = TokenUsage {
input: tokens_to_use,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
};
// Verify that before we record any usage, there are 0 billing events
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 0);
db.record_usage(
user_id,
false,
provider,
model,
usage,
true,
max_monthly_spend,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
// Verify the recorded usage and spending
let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
// Verify that we exceeded the free tier usage
assert_eq!(recorded_usage.spending_this_month, Cents::new(1050));
assert!(recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT);
// Verify that there is one `billing_event` record
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 1);
let (billing_event, _model) = &billing_events[0];
assert_eq!(billing_event.user_id, user_id);
assert_eq!(billing_event.input_tokens, tokens_to_use as i64);
assert_eq!(billing_event.input_cache_creation_tokens, 0);
assert_eq!(billing_event.input_cache_read_tokens, 0);
assert_eq!(billing_event.output_tokens, 0);
// Record usage that puts us at $20.50
let usage_2 = TokenUsage {
input: 200_000_000, // This will cost $10 more, pushing us from $10.50 to $20.50,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
};
db.record_usage(
user_id,
false,
provider,
model,
usage_2,
true,
max_monthly_spend,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
// Verify the updated usage and spending
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(updated_usage.spending_this_month, Cents::new(2050));
// Verify that there are now two billing events
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 2);
let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $20.50 to $21.50, which is over the $11 monthly maximum limit
let usage_exceeding = TokenUsage {
input: tokens_to_exceed,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
};
// This should not create another billing event, since this request pushes the user past their maximum monthly spend
db.record_usage(
user_id,
false,
provider,
model,
usage_exceeding,
true,
max_monthly_spend,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
// Verify the updated usage and spending
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(updated_usage.spending_this_month, Cents::new(2150));
// Verify that we never exceed the user's maximum monthly spend
// and avoid charging them.
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 2);
}

@@ -1,306 +0,0 @@
use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT;
use crate::{
Cents,
db::UserId,
llm::db::{
LlmDatabase, TokenUsage,
queries::{providers::ModelParams, usages::Usage},
},
test_llm_db,
};
use chrono::{DateTime, Duration, Utc};
use pretty_assertions::assert_eq;
use rpc::LanguageModelProvider;
test_llm_db!(test_tracking_usage, test_tracking_usage_postgres);
async fn test_tracking_usage(db: &mut LlmDatabase) {
let provider = LanguageModelProvider::Anthropic;
let model = "claude-3-5-sonnet";
db.initialize().await.unwrap();
db.insert_models(&[ModelParams {
provider,
name: model.to_string(),
max_requests_per_minute: 5,
max_tokens_per_minute: 10_000,
max_tokens_per_day: 50_000,
price_per_million_input_tokens: 50,
price_per_million_output_tokens: 50,
}])
.await
.unwrap();
// We're using a fixed datetime to prevent flakiness based on the clock.
let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
.unwrap()
.with_timezone(&Utc);
let user_id = UserId::from_proto(123);
let now = t0;
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let now = t0 + Duration::seconds(10);
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 2000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 2,
tokens_this_minute: 3000,
input_tokens_this_minute: 3000,
output_tokens_this_minute: 0,
tokens_this_day: 3000,
tokens_this_month: TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
let now = t0 + Duration::seconds(60);
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 1,
tokens_this_minute: 2000,
input_tokens_this_minute: 2000,
output_tokens_this_minute: 0,
tokens_this_day: 3000,
tokens_this_month: TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
let now = t0 + Duration::seconds(60);
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 2,
tokens_this_minute: 5000,
input_tokens_this_minute: 5000,
output_tokens_this_minute: 0,
tokens_this_day: 6000,
tokens_this_month: TokenUsage {
input: 6000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
let t1 = t0 + Duration::hours(24);
let now = t1;
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 0,
tokens_this_minute: 0,
input_tokens_this_minute: 0,
output_tokens_this_minute: 0,
tokens_this_day: 5000,
tokens_this_month: TokenUsage {
input: 6000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 4000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 1,
tokens_this_minute: 4000,
input_tokens_this_minute: 4000,
output_tokens_this_minute: 0,
tokens_this_day: 9000,
tokens_this_month: TokenUsage {
input: 10000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
// We're using a fixed datetime to prevent flakiness based on the clock.
let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z")
.unwrap()
.with_timezone(&Utc);
// Test cache creation input tokens
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 500,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 1,
tokens_this_minute: 1500,
input_tokens_this_minute: 1500,
output_tokens_this_minute: 0,
tokens_this_day: 1500,
tokens_this_month: TokenUsage {
input: 1000,
input_cache_creation: 500,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
// Test cache read input tokens
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 0,
input_cache_read: 300,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 2,
tokens_this_minute: 2800,
input_tokens_this_minute: 2500,
output_tokens_this_minute: 0,
tokens_this_day: 2800,
tokens_this_month: TokenUsage {
input: 2000,
input_cache_creation: 500,
input_cache_read: 300,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
}