collab: Remove LLM service (#28728)
This PR removes the LLM service from collab, as it has been moved to Cloudflare. Release Notes: - N/A
This commit is contained in:
parent
12b012eab3
commit
fc1252b0cd
17 changed files with 8 additions and 2315 deletions
8
.github/workflows/deploy_collab.yml
vendored
8
.github/workflows/deploy_collab.yml
vendored
|
@ -117,12 +117,10 @@ jobs:
|
|||
export ZED_KUBE_NAMESPACE=production
|
||||
export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=10
|
||||
export ZED_API_LOAD_BALANCER_SIZE_UNIT=2
|
||||
export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=2
|
||||
elif [[ $GITHUB_REF_NAME = "collab-staging" ]]; then
|
||||
export ZED_KUBE_NAMESPACE=staging
|
||||
export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=1
|
||||
export ZED_API_LOAD_BALANCER_SIZE_UNIT=1
|
||||
export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=1
|
||||
else
|
||||
echo "cowardly refusing to deploy from an unknown branch"
|
||||
exit 1
|
||||
|
@ -147,9 +145,3 @@ jobs:
|
|||
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
|
||||
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
|
||||
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
|
||||
|
||||
export ZED_SERVICE_NAME=llm
|
||||
export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_LLM_LOAD_BALANCER_SIZE_UNIT
|
||||
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
|
||||
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
|
||||
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
|
||||
|
|
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -2942,7 +2942,6 @@ dependencies = [
|
|||
name = "collab"
|
||||
version = "0.44.0"
|
||||
dependencies = [
|
||||
"anthropic",
|
||||
"anyhow",
|
||||
"assistant",
|
||||
"assistant_context_editor",
|
||||
|
|
|
@ -18,7 +18,6 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
|
|||
test-support = ["sqlite"]
|
||||
|
||||
[dependencies]
|
||||
anthropic.workspace = true
|
||||
anyhow.workspace = true
|
||||
async-stripe.workspace = true
|
||||
async-tungstenite.workspace = true
|
||||
|
|
|
@ -253,7 +253,6 @@ impl Config {
|
|||
pub enum ServiceMode {
|
||||
Api,
|
||||
Collab,
|
||||
Llm,
|
||||
All,
|
||||
}
|
||||
|
||||
|
@ -265,10 +264,6 @@ impl ServiceMode {
|
|||
pub fn is_api(&self) -> bool {
|
||||
matches!(self, Self::Api | Self::All)
|
||||
}
|
||||
|
||||
pub fn is_llm(&self) -> bool {
|
||||
matches!(self, Self::Llm | Self::All)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AppState {
|
||||
|
|
|
@ -1,448 +1,10 @@
|
|||
mod authorization;
|
||||
pub mod db;
|
||||
mod token;
|
||||
|
||||
use crate::api::CloudflareIpCountryHeader;
|
||||
use crate::api::events::SnowflakeRow;
|
||||
use crate::build_kinesis_client;
|
||||
use crate::rpc::MIN_ACCOUNT_AGE_FOR_LLM_USE;
|
||||
use crate::{Cents, Config, Error, Result, db::UserId, executor::Executor};
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use authorization::authorize_access_to_language_model;
|
||||
use axum::routing::get;
|
||||
use axum::{
|
||||
Extension, Json, Router, TypedHeader,
|
||||
body::Body,
|
||||
http::{self, HeaderName, HeaderValue, Request, StatusCode},
|
||||
middleware::{self, Next},
|
||||
response::{IntoResponse, Response},
|
||||
routing::post,
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use collections::HashMap;
|
||||
use db::TokenUsage;
|
||||
use db::{ActiveUserCount, LlmDatabase, usage_measure::UsageMeasure};
|
||||
use futures::{Stream, StreamExt as _};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use rpc::{
|
||||
EXPIRED_LLM_TOKEN_HEADER_NAME, LanguageModelProvider, PerformCompletionParams, proto::Plan,
|
||||
};
|
||||
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
|
||||
use serde_json::json;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use strum::IntoEnumIterator;
|
||||
use tokio::sync::RwLock;
|
||||
use util::ResultExt;
|
||||
use crate::Cents;
|
||||
|
||||
pub use token::*;
|
||||
|
||||
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
|
||||
|
||||
pub struct LlmState {
|
||||
pub config: Config,
|
||||
pub executor: Executor,
|
||||
pub db: Arc<LlmDatabase>,
|
||||
pub http_client: ReqwestClient,
|
||||
pub kinesis_client: Option<aws_sdk_kinesis::Client>,
|
||||
active_user_count_by_model:
|
||||
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
|
||||
}
|
||||
|
||||
impl LlmState {
|
||||
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
||||
let database_url = config
|
||||
.llm_database_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
|
||||
let max_connections = config
|
||||
.llm_database_max_connections
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
|
||||
|
||||
let mut db_options = db::ConnectOptions::new(database_url);
|
||||
db_options.max_connections(max_connections);
|
||||
let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
|
||||
db.initialize().await?;
|
||||
|
||||
let db = Arc::new(db);
|
||||
|
||||
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
||||
let http_client =
|
||||
ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?;
|
||||
|
||||
let this = Self {
|
||||
executor,
|
||||
db,
|
||||
http_client,
|
||||
kinesis_client: if config.kinesis_access_key.is_some() {
|
||||
build_kinesis_client(&config).await.log_err()
|
||||
} else {
|
||||
None
|
||||
},
|
||||
active_user_count_by_model: RwLock::new(HashMap::default()),
|
||||
config,
|
||||
};
|
||||
|
||||
Ok(Arc::new(this))
|
||||
}
|
||||
|
||||
pub async fn get_active_user_count(
|
||||
&self,
|
||||
provider: LanguageModelProvider,
|
||||
model: &str,
|
||||
) -> Result<ActiveUserCount> {
|
||||
let now = Utc::now();
|
||||
|
||||
{
|
||||
let active_user_count_by_model = self.active_user_count_by_model.read().await;
|
||||
if let Some((last_updated, count)) =
|
||||
active_user_count_by_model.get(&(provider, model.to_string()))
|
||||
{
|
||||
if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
|
||||
return Ok(*count);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut cache = self.active_user_count_by_model.write().await;
|
||||
let new_count = self.db.get_active_user_count(provider, model, now).await?;
|
||||
cache.insert((provider, model.to_string()), (now, new_count));
|
||||
Ok(new_count)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn routes() -> Router<(), Body> {
|
||||
Router::new()
|
||||
.route("/models", get(list_models))
|
||||
.route("/completion", post(perform_completion))
|
||||
.layer(middleware::from_fn(validate_api_token))
|
||||
}
|
||||
|
||||
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||
let token = req
|
||||
.headers()
|
||||
.get(http::header::AUTHORIZATION)
|
||||
.and_then(|header| header.to_str().ok())
|
||||
.ok_or_else(|| {
|
||||
Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"missing authorization header".to_string(),
|
||||
)
|
||||
})?
|
||||
.strip_prefix("Bearer ")
|
||||
.ok_or_else(|| {
|
||||
Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"invalid authorization header".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let state = req.extensions().get::<Arc<LlmState>>().unwrap();
|
||||
match LlmTokenClaims::validate(token, &state.config) {
|
||||
Ok(claims) => {
|
||||
if state.db.is_access_token_revoked(&claims.jti).await? {
|
||||
return Err(Error::http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"unauthorized".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
tracing::Span::current()
|
||||
.record("user_id", claims.user_id)
|
||||
.record("login", claims.github_user_login.clone())
|
||||
.record("authn.jti", &claims.jti)
|
||||
.record("is_staff", claims.is_staff);
|
||||
|
||||
req.extensions_mut().insert(claims);
|
||||
Ok::<_, Error>(next.run(req).await.into_response())
|
||||
}
|
||||
Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"unauthorized".to_string(),
|
||||
[(
|
||||
HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
|
||||
HeaderValue::from_static("true"),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
)),
|
||||
Err(_err) => Err(Error::http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"unauthorized".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_models(
|
||||
Extension(state): Extension<Arc<LlmState>>,
|
||||
Extension(claims): Extension<LlmTokenClaims>,
|
||||
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||
) -> Result<Json<ListModelsResponse>> {
|
||||
let country_code = country_code_header.map(|header| header.to_string());
|
||||
|
||||
let mut accessible_models = Vec::new();
|
||||
|
||||
for (provider, model) in state.db.all_models() {
|
||||
let authorize_result = authorize_access_to_language_model(
|
||||
&state.config,
|
||||
&claims,
|
||||
country_code.as_deref(),
|
||||
provider,
|
||||
&model.name,
|
||||
);
|
||||
|
||||
if authorize_result.is_ok() {
|
||||
accessible_models.push(rpc::LanguageModel {
|
||||
provider,
|
||||
name: model.name,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(ListModelsResponse {
|
||||
models: accessible_models,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn perform_completion(
|
||||
Extension(state): Extension<Arc<LlmState>>,
|
||||
Extension(claims): Extension<LlmTokenClaims>,
|
||||
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||
Json(params): Json<PerformCompletionParams>,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let model = normalize_model_name(
|
||||
state.db.model_names_for_provider(params.provider),
|
||||
params.model,
|
||||
);
|
||||
|
||||
let bypass_account_age_check = claims.has_llm_subscription || claims.bypass_account_age_check;
|
||||
if !bypass_account_age_check {
|
||||
if Utc::now().naive_utc() - claims.account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
|
||||
Err(anyhow!("account too young"))?
|
||||
}
|
||||
}
|
||||
|
||||
authorize_access_to_language_model(
|
||||
&state.config,
|
||||
&claims,
|
||||
country_code_header
|
||||
.map(|header| header.to_string())
|
||||
.as_deref(),
|
||||
params.provider,
|
||||
&model,
|
||||
)?;
|
||||
|
||||
check_usage_limit(&state, params.provider, &model, &claims).await?;
|
||||
|
||||
let stream = match params.provider {
|
||||
LanguageModelProvider::Anthropic => {
|
||||
let api_key = if claims.is_staff {
|
||||
state
|
||||
.config
|
||||
.anthropic_staff_api_key
|
||||
.as_ref()
|
||||
.context("no Anthropic AI staff API key configured on the server")?
|
||||
} else {
|
||||
state
|
||||
.config
|
||||
.anthropic_api_key
|
||||
.as_ref()
|
||||
.context("no Anthropic AI API key configured on the server")?
|
||||
};
|
||||
|
||||
let mut request: anthropic::Request =
|
||||
serde_json::from_str(params.provider_request.get())?;
|
||||
|
||||
// Override the model on the request with the latest version of the model that is
|
||||
// known to the server.
|
||||
//
|
||||
// Right now, we use the version that's defined in `model.id()`, but we will likely
|
||||
// want to change this code once a new version of an Anthropic model is released,
|
||||
// so that users can use the new version, without having to update Zed.
|
||||
request.model = match model.as_str() {
|
||||
"claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
|
||||
"claude-3-7-sonnet" => anthropic::Model::Claude3_7Sonnet.id().to_string(),
|
||||
"claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
|
||||
"claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
|
||||
"claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
|
||||
_ => request.model,
|
||||
};
|
||||
|
||||
let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
|
||||
&state.http_client,
|
||||
anthropic::ANTHROPIC_API_URL,
|
||||
api_key,
|
||||
request,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
|
||||
Some(anthropic::ApiErrorCode::RateLimitError) => {
|
||||
tracing::info!(
|
||||
target: "upstream rate limit exceeded",
|
||||
user_id = claims.user_id,
|
||||
login = claims.github_user_login,
|
||||
authn.jti = claims.jti,
|
||||
is_staff = claims.is_staff,
|
||||
provider = params.provider.to_string(),
|
||||
model = model
|
||||
);
|
||||
|
||||
Error::http(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
"Upstream Anthropic rate limit exceeded.".to_string(),
|
||||
)
|
||||
}
|
||||
Some(anthropic::ApiErrorCode::InvalidRequestError) => {
|
||||
Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
|
||||
}
|
||||
Some(anthropic::ApiErrorCode::OverloadedError) => {
|
||||
Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
|
||||
}
|
||||
Some(_) => {
|
||||
Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
|
||||
}
|
||||
None => Error::Internal(anyhow!(err)),
|
||||
},
|
||||
anthropic::AnthropicError::Other(err) => Error::Internal(err),
|
||||
})?;
|
||||
|
||||
if let Some(rate_limit_info) = rate_limit_info {
|
||||
tracing::info!(
|
||||
target: "upstream rate limit",
|
||||
is_staff = claims.is_staff,
|
||||
provider = params.provider.to_string(),
|
||||
model = model,
|
||||
tokens_remaining = rate_limit_info.tokens.as_ref().map(|limits| limits.remaining),
|
||||
input_tokens_remaining = rate_limit_info.input_tokens.as_ref().map(|limits| limits.remaining),
|
||||
output_tokens_remaining = rate_limit_info.output_tokens.as_ref().map(|limits| limits.remaining),
|
||||
requests_remaining = rate_limit_info.requests.as_ref().map(|limits| limits.remaining),
|
||||
requests_reset = ?rate_limit_info.requests.as_ref().map(|limits| limits.reset),
|
||||
tokens_reset = ?rate_limit_info.tokens.as_ref().map(|limits| limits.reset),
|
||||
input_tokens_reset = ?rate_limit_info.input_tokens.as_ref().map(|limits| limits.reset),
|
||||
output_tokens_reset = ?rate_limit_info.output_tokens.as_ref().map(|limits| limits.reset),
|
||||
);
|
||||
}
|
||||
|
||||
chunks
|
||||
.map(move |event| {
|
||||
let chunk = event?;
|
||||
let (
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_input_tokens,
|
||||
cache_read_input_tokens,
|
||||
) = match &chunk {
|
||||
anthropic::Event::MessageStart {
|
||||
message: anthropic::Response { usage, .. },
|
||||
}
|
||||
| anthropic::Event::MessageDelta { usage, .. } => (
|
||||
usage.input_tokens.unwrap_or(0) as usize,
|
||||
usage.output_tokens.unwrap_or(0) as usize,
|
||||
usage.cache_creation_input_tokens.unwrap_or(0) as usize,
|
||||
usage.cache_read_input_tokens.unwrap_or(0) as usize,
|
||||
),
|
||||
_ => (0, 0, 0, 0),
|
||||
};
|
||||
|
||||
anyhow::Ok(CompletionChunk {
|
||||
bytes: serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_input_tokens,
|
||||
cache_read_input_tokens,
|
||||
})
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
LanguageModelProvider::OpenAi => {
|
||||
let api_key = state
|
||||
.config
|
||||
.openai_api_key
|
||||
.as_ref()
|
||||
.context("no OpenAI API key configured on the server")?;
|
||||
let chunks = open_ai::stream_completion(
|
||||
&state.http_client,
|
||||
open_ai::OPEN_AI_API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(params.provider_request.get())?,
|
||||
)
|
||||
.await?;
|
||||
|
||||
chunks
|
||||
.map(|event| {
|
||||
event.map(|chunk| {
|
||||
let input_tokens =
|
||||
chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
|
||||
let output_tokens =
|
||||
chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
|
||||
CompletionChunk {
|
||||
bytes: serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
}
|
||||
})
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
LanguageModelProvider::Google => {
|
||||
let api_key = state
|
||||
.config
|
||||
.google_ai_api_key
|
||||
.as_ref()
|
||||
.context("no Google AI API key configured on the server")?;
|
||||
let chunks = google_ai::stream_generate_content(
|
||||
&state.http_client,
|
||||
google_ai::API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(params.provider_request.get())?,
|
||||
)
|
||||
.await?;
|
||||
|
||||
chunks
|
||||
.map(|event| {
|
||||
event.map(|chunk| {
|
||||
// TODO - implement token counting for Google AI
|
||||
CompletionChunk {
|
||||
bytes: serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
}
|
||||
})
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Response::new(Body::wrap_stream(TokenCountingStream {
|
||||
state,
|
||||
claims,
|
||||
provider: params.provider,
|
||||
model,
|
||||
tokens: TokenUsage::default(),
|
||||
inner_stream: stream,
|
||||
})))
|
||||
}
|
||||
|
||||
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
|
||||
if let Some(known_model_name) = known_models
|
||||
.iter()
|
||||
.filter(|known_model_name| name.starts_with(known_model_name.as_str()))
|
||||
.max_by_key(|known_model_name| known_model_name.len())
|
||||
{
|
||||
known_model_name.to_string()
|
||||
} else {
|
||||
name
|
||||
}
|
||||
}
|
||||
|
||||
/// The maximum monthly spending an individual user can reach on the free tier
|
||||
/// before they have to pay.
|
||||
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
|
||||
|
@ -452,330 +14,3 @@ pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
|
|||
///
|
||||
/// Used to prevent surprise bills.
|
||||
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
|
||||
|
||||
async fn check_usage_limit(
|
||||
state: &Arc<LlmState>,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
claims: &LlmTokenClaims,
|
||||
) -> Result<()> {
|
||||
if claims.is_staff {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let user_id = UserId::from_proto(claims.user_id);
|
||||
let model = state.db.model(provider, model_name)?;
|
||||
let free_tier = claims.free_tier_monthly_spending_limit();
|
||||
|
||||
let spending_this_month = state
|
||||
.db
|
||||
.get_user_spending_for_month(user_id, Utc::now())
|
||||
.await?;
|
||||
if spending_this_month >= free_tier {
|
||||
if !claims.has_llm_subscription {
|
||||
return Err(Error::http(
|
||||
StatusCode::PAYMENT_REQUIRED,
|
||||
"Maximum spending limit reached for this month.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let monthly_spend = spending_this_month.saturating_sub(free_tier);
|
||||
if monthly_spend >= Cents(claims.max_monthly_spend_in_cents) {
|
||||
return Err(Error::Http(
|
||||
StatusCode::FORBIDDEN,
|
||||
"Maximum spending limit reached for this month.".to_string(),
|
||||
[(
|
||||
HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
|
||||
HeaderValue::from_static("true"),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let active_users = state.get_active_user_count(provider, model_name).await?;
|
||||
|
||||
let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
|
||||
let users_in_recent_days = active_users.users_in_recent_days.max(1);
|
||||
|
||||
let per_user_max_requests_per_minute =
|
||||
model.max_requests_per_minute as usize / users_in_recent_minutes;
|
||||
let per_user_max_tokens_per_minute =
|
||||
model.max_tokens_per_minute as usize / users_in_recent_minutes;
|
||||
let per_user_max_input_tokens_per_minute =
|
||||
model.max_input_tokens_per_minute as usize / users_in_recent_minutes;
|
||||
let per_user_max_output_tokens_per_minute =
|
||||
model.max_output_tokens_per_minute as usize / users_in_recent_minutes;
|
||||
let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;
|
||||
|
||||
let usage = state
|
||||
.db
|
||||
.get_usage(user_id, provider, model_name, Utc::now())
|
||||
.await?;
|
||||
|
||||
let checks = match (provider, model_name) {
|
||||
(LanguageModelProvider::Anthropic, "claude-3-7-sonnet") => vec![
|
||||
(
|
||||
usage.requests_this_minute,
|
||||
per_user_max_requests_per_minute,
|
||||
UsageMeasure::RequestsPerMinute,
|
||||
),
|
||||
(
|
||||
usage.input_tokens_this_minute,
|
||||
per_user_max_tokens_per_minute,
|
||||
UsageMeasure::InputTokensPerMinute,
|
||||
),
|
||||
(
|
||||
usage.output_tokens_this_minute,
|
||||
per_user_max_tokens_per_minute,
|
||||
UsageMeasure::OutputTokensPerMinute,
|
||||
),
|
||||
(
|
||||
usage.tokens_this_day,
|
||||
per_user_max_tokens_per_day,
|
||||
UsageMeasure::TokensPerDay,
|
||||
),
|
||||
],
|
||||
_ => vec![
|
||||
(
|
||||
usage.requests_this_minute,
|
||||
per_user_max_requests_per_minute,
|
||||
UsageMeasure::RequestsPerMinute,
|
||||
),
|
||||
(
|
||||
usage.tokens_this_minute,
|
||||
per_user_max_tokens_per_minute,
|
||||
UsageMeasure::TokensPerMinute,
|
||||
),
|
||||
(
|
||||
usage.tokens_this_day,
|
||||
per_user_max_tokens_per_day,
|
||||
UsageMeasure::TokensPerDay,
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
for (used, limit, usage_measure) in checks {
|
||||
if used > limit {
|
||||
let resource = match usage_measure {
|
||||
UsageMeasure::RequestsPerMinute => "requests_per_minute",
|
||||
UsageMeasure::TokensPerMinute => "tokens_per_minute",
|
||||
UsageMeasure::InputTokensPerMinute => "input_tokens_per_minute",
|
||||
UsageMeasure::OutputTokensPerMinute => "output_tokens_per_minute",
|
||||
UsageMeasure::TokensPerDay => "tokens_per_day",
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
target: "user rate limit",
|
||||
user_id = claims.user_id,
|
||||
login = claims.github_user_login,
|
||||
authn.jti = claims.jti,
|
||||
is_staff = claims.is_staff,
|
||||
provider = provider.to_string(),
|
||||
model = model.name,
|
||||
usage_measure = resource,
|
||||
requests_this_minute = usage.requests_this_minute,
|
||||
tokens_this_minute = usage.tokens_this_minute,
|
||||
input_tokens_this_minute = usage.input_tokens_this_minute,
|
||||
output_tokens_this_minute = usage.output_tokens_this_minute,
|
||||
tokens_this_day = usage.tokens_this_day,
|
||||
users_in_recent_minutes = users_in_recent_minutes,
|
||||
users_in_recent_days = users_in_recent_days,
|
||||
max_requests_per_minute = per_user_max_requests_per_minute,
|
||||
max_tokens_per_minute = per_user_max_tokens_per_minute,
|
||||
max_input_tokens_per_minute = per_user_max_input_tokens_per_minute,
|
||||
max_output_tokens_per_minute = per_user_max_output_tokens_per_minute,
|
||||
max_tokens_per_day = per_user_max_tokens_per_day,
|
||||
);
|
||||
|
||||
SnowflakeRow::new(
|
||||
"Language Model Rate Limited",
|
||||
Some(claims.metrics_id),
|
||||
claims.is_staff,
|
||||
claims.system_id.clone(),
|
||||
json!({
|
||||
"usage": usage,
|
||||
"users_in_recent_minutes": users_in_recent_minutes,
|
||||
"users_in_recent_days": users_in_recent_days,
|
||||
"max_requests_per_minute": per_user_max_requests_per_minute,
|
||||
"max_tokens_per_minute": per_user_max_tokens_per_minute,
|
||||
"max_input_tokens_per_minute": per_user_max_input_tokens_per_minute,
|
||||
"max_output_tokens_per_minute": per_user_max_output_tokens_per_minute,
|
||||
"max_tokens_per_day": per_user_max_tokens_per_day,
|
||||
"plan": match claims.plan {
|
||||
Plan::Free => "free".to_string(),
|
||||
Plan::ZedPro => "zed_pro".to_string(),
|
||||
},
|
||||
"model": model.name.clone(),
|
||||
"provider": provider.to_string(),
|
||||
"usage_measure": resource.to_string(),
|
||||
}),
|
||||
)
|
||||
.write(&state.kinesis_client, &state.config.kinesis_stream)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
return Err(Error::http(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
format!("Rate limit exceeded. Maximum {} reached.", resource),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct CompletionChunk {
|
||||
bytes: Vec<u8>,
|
||||
input_tokens: usize,
|
||||
output_tokens: usize,
|
||||
cache_creation_input_tokens: usize,
|
||||
cache_read_input_tokens: usize,
|
||||
}
|
||||
|
||||
struct TokenCountingStream<S> {
|
||||
state: Arc<LlmState>,
|
||||
claims: LlmTokenClaims,
|
||||
provider: LanguageModelProvider,
|
||||
model: String,
|
||||
tokens: TokenUsage,
|
||||
inner_stream: S,
|
||||
}
|
||||
|
||||
impl<S> Stream for TokenCountingStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
|
||||
{
|
||||
type Item = Result<Vec<u8>, anyhow::Error>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
match Pin::new(&mut self.inner_stream).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(mut chunk))) => {
|
||||
chunk.bytes.push(b'\n');
|
||||
self.tokens.input += chunk.input_tokens;
|
||||
self.tokens.output += chunk.output_tokens;
|
||||
self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
|
||||
self.tokens.input_cache_read += chunk.cache_read_input_tokens;
|
||||
Poll::Ready(Some(Ok(chunk.bytes)))
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Drop for TokenCountingStream<S> {
|
||||
fn drop(&mut self) {
|
||||
let state = self.state.clone();
|
||||
let claims = self.claims.clone();
|
||||
let provider = self.provider;
|
||||
let model = std::mem::take(&mut self.model);
|
||||
let tokens = self.tokens;
|
||||
self.state.executor.spawn_detached(async move {
|
||||
let usage = state
|
||||
.db
|
||||
.record_usage(
|
||||
UserId::from_proto(claims.user_id),
|
||||
claims.is_staff,
|
||||
provider,
|
||||
&model,
|
||||
tokens,
|
||||
claims.has_llm_subscription,
|
||||
Cents(claims.max_monthly_spend_in_cents),
|
||||
claims.free_tier_monthly_spending_limit(),
|
||||
Utc::now(),
|
||||
)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
if let Some(usage) = usage {
|
||||
tracing::info!(
|
||||
target: "user usage",
|
||||
user_id = claims.user_id,
|
||||
login = claims.github_user_login,
|
||||
authn.jti = claims.jti,
|
||||
is_staff = claims.is_staff,
|
||||
provider = provider.to_string(),
|
||||
model = model,
|
||||
requests_this_minute = usage.requests_this_minute,
|
||||
tokens_this_minute = usage.tokens_this_minute,
|
||||
input_tokens_this_minute = usage.input_tokens_this_minute,
|
||||
output_tokens_this_minute = usage.output_tokens_this_minute,
|
||||
);
|
||||
|
||||
let properties = json!({
|
||||
"has_llm_subscription": claims.has_llm_subscription,
|
||||
"max_monthly_spend_in_cents": claims.max_monthly_spend_in_cents,
|
||||
"plan": match claims.plan {
|
||||
Plan::Free => "free".to_string(),
|
||||
Plan::ZedPro => "zed_pro".to_string(),
|
||||
},
|
||||
"model": model,
|
||||
"provider": provider,
|
||||
"usage": usage,
|
||||
"tokens": tokens
|
||||
});
|
||||
SnowflakeRow::new(
|
||||
"Language Model Used",
|
||||
Some(claims.metrics_id),
|
||||
claims.is_staff,
|
||||
claims.system_id.clone(),
|
||||
properties,
|
||||
)
|
||||
.write(&state.kinesis_client, &state.config.kinesis_stream)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn log_usage_periodically(state: Arc<LlmState>) {
|
||||
state.executor.clone().spawn_detached(async move {
|
||||
loop {
|
||||
state
|
||||
.executor
|
||||
.sleep(std::time::Duration::from_secs(30))
|
||||
.await;
|
||||
|
||||
for provider in LanguageModelProvider::iter() {
|
||||
for model in state.db.model_names_for_provider(provider) {
|
||||
if let Some(active_user_count) = state
|
||||
.get_active_user_count(provider, &model)
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
tracing::info!(
|
||||
target: "active user counts",
|
||||
provider = provider.to_string(),
|
||||
model = model,
|
||||
users_in_recent_minutes = active_user_count.users_in_recent_minutes,
|
||||
users_in_recent_days = active_user_count.users_in_recent_days,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(usages) = state
|
||||
.db
|
||||
.get_application_wide_usages_by_model(Utc::now())
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
for usage in usages {
|
||||
tracing::info!(
|
||||
target: "computed usage",
|
||||
provider = usage.provider.to_string(),
|
||||
model = usage.model,
|
||||
requests_this_minute = usage.requests_this_minute,
|
||||
tokens_this_minute = usage.tokens_this_minute,
|
||||
input_tokens_this_minute = usage.input_tokens_this_minute,
|
||||
output_tokens_this_minute = usage.output_tokens_this_minute,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,330 +0,0 @@
|
|||
use reqwest::StatusCode;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
use crate::llm::LlmTokenClaims;
|
||||
use crate::{Config, Error, Result};
|
||||
|
||||
pub fn authorize_access_to_language_model(
|
||||
config: &Config,
|
||||
claims: &LlmTokenClaims,
|
||||
country_code: Option<&str>,
|
||||
provider: LanguageModelProvider,
|
||||
model: &str,
|
||||
) -> Result<()> {
|
||||
authorize_access_for_country(config, country_code, provider)?;
|
||||
authorize_access_to_model(config, claims, provider, model)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn authorize_access_to_model(
|
||||
config: &Config,
|
||||
claims: &LlmTokenClaims,
|
||||
provider: LanguageModelProvider,
|
||||
model: &str,
|
||||
) -> Result<()> {
|
||||
if claims.is_staff {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if provider == LanguageModelProvider::Anthropic {
|
||||
if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if claims.has_llm_closed_beta_feature_flag
|
||||
&& Some(model) == config.llm_closed_beta_model_name.as_deref()
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
format!("access to model {model:?} is not included in your plan"),
|
||||
))
|
||||
}
|
||||
|
||||
fn authorize_access_for_country(
|
||||
config: &Config,
|
||||
country_code: Option<&str>,
|
||||
provider: LanguageModelProvider,
|
||||
) -> Result<()> {
|
||||
// In development we won't have the `CF-IPCountry` header, so we can't check
|
||||
// the country code.
|
||||
//
|
||||
// This shouldn't be necessary, as anyone running in development will need to provide
|
||||
// their own API credentials in order to use an LLM provider.
|
||||
if config.is_development() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
|
||||
let country_code = match country_code {
|
||||
// `XX` - Used for clients without country code data.
|
||||
None | Some("XX") => Err(Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"no country code".to_string(),
|
||||
))?,
|
||||
// `T1` - Used for clients using the Tor network.
|
||||
Some("T1") => Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
format!("access to {provider:?} models is not available over Tor"),
|
||||
))?,
|
||||
Some(country_code) => country_code,
|
||||
};
|
||||
|
||||
let is_country_supported_by_provider = match provider {
|
||||
LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
|
||||
LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
|
||||
LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
|
||||
};
|
||||
if !is_country_supported_by_provider {
|
||||
Err(Error::http(
|
||||
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
|
||||
format!(
|
||||
"access to {provider:?} models is not available in your region ({country_code})"
|
||||
),
|
||||
))?
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use axum::response::IntoResponse;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::proto::Plan;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_with_supported_country(
|
||||
_cx: &mut gpui::TestAppContext,
|
||||
) {
|
||||
let config = Config::test();
|
||||
|
||||
let claims = LlmTokenClaims {
|
||||
user_id: 99,
|
||||
plan: Plan::ZedPro,
|
||||
is_staff: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cases = vec![
|
||||
(LanguageModelProvider::Anthropic, "US"), // United States
|
||||
(LanguageModelProvider::Anthropic, "GB"), // United Kingdom
|
||||
(LanguageModelProvider::OpenAi, "US"), // United States
|
||||
(LanguageModelProvider::OpenAi, "GB"), // United Kingdom
|
||||
(LanguageModelProvider::Google, "US"), // United States
|
||||
(LanguageModelProvider::Google, "GB"), // United Kingdom
|
||||
];
|
||||
|
||||
for (provider, country_code) in cases {
|
||||
authorize_access_to_language_model(
|
||||
&config,
|
||||
&claims,
|
||||
Some(country_code),
|
||||
provider,
|
||||
"the-model",
|
||||
)
|
||||
.unwrap_or_else(|_| {
|
||||
panic!("expected authorization to return Ok for {provider:?}: {country_code}")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_with_unsupported_country(
|
||||
_cx: &mut gpui::TestAppContext,
|
||||
) {
|
||||
let config = Config::test();
|
||||
|
||||
let claims = LlmTokenClaims {
|
||||
user_id: 99,
|
||||
plan: Plan::ZedPro,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cases = vec![
|
||||
(LanguageModelProvider::Anthropic, "AF"), // Afghanistan
|
||||
(LanguageModelProvider::Anthropic, "BY"), // Belarus
|
||||
(LanguageModelProvider::Anthropic, "CF"), // Central African Republic
|
||||
(LanguageModelProvider::Anthropic, "CN"), // China
|
||||
(LanguageModelProvider::Anthropic, "CU"), // Cuba
|
||||
(LanguageModelProvider::Anthropic, "ER"), // Eritrea
|
||||
(LanguageModelProvider::Anthropic, "ET"), // Ethiopia
|
||||
(LanguageModelProvider::Anthropic, "IR"), // Iran
|
||||
(LanguageModelProvider::Anthropic, "KP"), // North Korea
|
||||
(LanguageModelProvider::Anthropic, "XK"), // Kosovo
|
||||
(LanguageModelProvider::Anthropic, "LY"), // Libya
|
||||
(LanguageModelProvider::Anthropic, "MM"), // Myanmar
|
||||
(LanguageModelProvider::Anthropic, "RU"), // Russia
|
||||
(LanguageModelProvider::Anthropic, "SO"), // Somalia
|
||||
(LanguageModelProvider::Anthropic, "SS"), // South Sudan
|
||||
(LanguageModelProvider::Anthropic, "SD"), // Sudan
|
||||
(LanguageModelProvider::Anthropic, "SY"), // Syria
|
||||
(LanguageModelProvider::Anthropic, "VE"), // Venezuela
|
||||
(LanguageModelProvider::Anthropic, "YE"), // Yemen
|
||||
(LanguageModelProvider::OpenAi, "KP"), // North Korea
|
||||
(LanguageModelProvider::Google, "KP"), // North Korea
|
||||
];
|
||||
|
||||
for (provider, country_code) in cases {
|
||||
let error_response = authorize_access_to_language_model(
|
||||
&config,
|
||||
&claims,
|
||||
Some(country_code),
|
||||
provider,
|
||||
"the-model",
|
||||
)
|
||||
.expect_err(&format!(
|
||||
"expected authorization to return an error for {provider:?}: {country_code}"
|
||||
))
|
||||
.into_response();
|
||||
|
||||
assert_eq!(
|
||||
error_response.status(),
|
||||
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
|
||||
);
|
||||
let response_body = hyper::body::to_bytes(error_response.into_body())
|
||||
.await
|
||||
.unwrap()
|
||||
.to_vec();
|
||||
assert_eq!(
|
||||
String::from_utf8(response_body).unwrap(),
|
||||
format!(
|
||||
"access to {provider:?} models is not available in your region ({country_code})"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
|
||||
let config = Config::test();
|
||||
|
||||
let claims = LlmTokenClaims {
|
||||
user_id: 99,
|
||||
plan: Plan::ZedPro,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cases = vec![
|
||||
(LanguageModelProvider::Anthropic, "T1"), // Tor
|
||||
(LanguageModelProvider::OpenAi, "T1"), // Tor
|
||||
(LanguageModelProvider::Google, "T1"), // Tor
|
||||
];
|
||||
|
||||
for (provider, country_code) in cases {
|
||||
let error_response = authorize_access_to_language_model(
|
||||
&config,
|
||||
&claims,
|
||||
Some(country_code),
|
||||
provider,
|
||||
"the-model",
|
||||
)
|
||||
.expect_err(&format!(
|
||||
"expected authorization to return an error for {provider:?}: {country_code}"
|
||||
))
|
||||
.into_response();
|
||||
|
||||
assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
|
||||
let response_body = hyper::body::to_bytes(error_response.into_body())
|
||||
.await
|
||||
.unwrap()
|
||||
.to_vec();
|
||||
assert_eq!(
|
||||
String::from_utf8(response_body).unwrap(),
|
||||
format!("access to {provider:?} models is not available over Tor")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_based_on_plan() {
|
||||
let config = Config::test();
|
||||
|
||||
let test_cases = vec![
|
||||
// Pro plan should have access to claude-3.5-sonnet
|
||||
(
|
||||
Plan::ZedPro,
|
||||
LanguageModelProvider::Anthropic,
|
||||
"claude-3-5-sonnet",
|
||||
true,
|
||||
),
|
||||
// Free plan should have access to claude-3.5-sonnet
|
||||
(
|
||||
Plan::Free,
|
||||
LanguageModelProvider::Anthropic,
|
||||
"claude-3-5-sonnet",
|
||||
true,
|
||||
),
|
||||
// Pro plan should NOT have access to other Anthropic models
|
||||
(
|
||||
Plan::ZedPro,
|
||||
LanguageModelProvider::Anthropic,
|
||||
"claude-3-opus",
|
||||
false,
|
||||
),
|
||||
];
|
||||
|
||||
for (plan, provider, model, expected_access) in test_cases {
|
||||
let claims = LlmTokenClaims {
|
||||
plan,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result =
|
||||
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
|
||||
|
||||
if expected_access {
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Expected access to be granted for plan {:?}, provider {:?}, model {}",
|
||||
plan,
|
||||
provider,
|
||||
model
|
||||
);
|
||||
} else {
|
||||
let error = result.expect_err(&format!(
|
||||
"Expected access to be denied for plan {:?}, provider {:?}, model {}",
|
||||
plan, provider, model
|
||||
));
|
||||
let response = error.into_response();
|
||||
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_for_staff() {
|
||||
let config = Config::test();
|
||||
|
||||
let claims = LlmTokenClaims {
|
||||
is_staff: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Staff should have access to all models
|
||||
let test_cases = vec![
|
||||
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
|
||||
(LanguageModelProvider::Anthropic, "claude-2"),
|
||||
(LanguageModelProvider::Anthropic, "claude-123-agi"),
|
||||
(LanguageModelProvider::OpenAi, "gpt-4"),
|
||||
(LanguageModelProvider::Google, "gemini-pro"),
|
||||
];
|
||||
|
||||
for (provider, model) in test_cases {
|
||||
let result =
|
||||
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Expected staff to have access to provider {:?}, model {}",
|
||||
provider,
|
||||
model
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,7 +20,6 @@ use std::future::Future;
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::anyhow;
|
||||
pub use queries::usages::{ActiveUserCount, TokenUsage};
|
||||
pub use sea_orm::ConnectOptions;
|
||||
use sea_orm::prelude::*;
|
||||
use sea_orm::{
|
||||
|
|
|
@ -2,5 +2,4 @@ use super::*;
|
|||
|
||||
pub mod billing_events;
|
||||
pub mod providers;
|
||||
pub mod revoked_access_tokens;
|
||||
pub mod usages;
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
use super::*;
|
||||
|
||||
impl LlmDatabase {
|
||||
/// Returns whether the access token with the given `jti` has been revoked.
|
||||
pub async fn is_access_token_revoked(&self, jti: &str) -> Result<bool> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(revoked_access_token::Entity::find()
|
||||
.filter(revoked_access_token::Column::Jti.eq(jti))
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.is_some())
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
|
@ -1,56 +1,12 @@
|
|||
use crate::db::UserId;
|
||||
use crate::llm::Cents;
|
||||
use chrono::{Datelike, Duration};
|
||||
use chrono::Datelike;
|
||||
use futures::StreamExt as _;
|
||||
use rpc::LanguageModelProvider;
|
||||
use sea_orm::QuerySelect;
|
||||
use std::{iter, str::FromStr};
|
||||
use std::str::FromStr;
|
||||
use strum::IntoEnumIterator as _;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)]
|
||||
pub struct TokenUsage {
|
||||
pub input: usize,
|
||||
pub input_cache_creation: usize,
|
||||
pub input_cache_read: usize,
|
||||
pub output: usize,
|
||||
}
|
||||
|
||||
impl TokenUsage {
|
||||
pub fn total(&self) -> usize {
|
||||
self.input + self.input_cache_creation + self.input_cache_read + self.output
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)]
|
||||
pub struct Usage {
|
||||
pub requests_this_minute: usize,
|
||||
pub tokens_this_minute: usize,
|
||||
pub input_tokens_this_minute: usize,
|
||||
pub output_tokens_this_minute: usize,
|
||||
pub tokens_this_day: usize,
|
||||
pub tokens_this_month: TokenUsage,
|
||||
pub spending_this_month: Cents,
|
||||
pub lifetime_spending: Cents,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct ApplicationWideUsage {
|
||||
pub provider: LanguageModelProvider,
|
||||
pub model: String,
|
||||
pub requests_this_minute: usize,
|
||||
pub tokens_this_minute: usize,
|
||||
pub input_tokens_this_minute: usize,
|
||||
pub output_tokens_this_minute: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub struct ActiveUserCount {
|
||||
pub users_in_recent_minutes: usize,
|
||||
pub users_in_recent_days: usize,
|
||||
}
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn initialize_usage_measures(&mut self) -> Result<()> {
|
||||
let all_measures = self
|
||||
|
@ -90,100 +46,6 @@ impl LlmDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_application_wide_usages_by_model(
|
||||
&self,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Vec<ApplicationWideUsage>> {
|
||||
self.transaction(|tx| async move {
|
||||
let past_minute = now - Duration::minutes(1);
|
||||
let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute];
|
||||
let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
|
||||
let input_tokens_per_minute =
|
||||
self.usage_measure_ids[&UsageMeasure::InputTokensPerMinute];
|
||||
let output_tokens_per_minute =
|
||||
self.usage_measure_ids[&UsageMeasure::OutputTokensPerMinute];
|
||||
|
||||
let mut results = Vec::new();
|
||||
for ((provider, model_name), model) in self.models.iter() {
|
||||
let mut usages = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::Timestamp
|
||||
.gte(past_minute.naive_utc())
|
||||
.and(usage::Column::IsStaff.eq(false))
|
||||
.and(usage::Column::ModelId.eq(model.id))
|
||||
.and(
|
||||
usage::Column::MeasureId
|
||||
.eq(requests_per_minute)
|
||||
.or(usage::Column::MeasureId.eq(tokens_per_minute)),
|
||||
),
|
||||
)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
|
||||
let mut requests_this_minute = 0;
|
||||
let mut tokens_this_minute = 0;
|
||||
let mut input_tokens_this_minute = 0;
|
||||
let mut output_tokens_this_minute = 0;
|
||||
while let Some(usage) = usages.next().await {
|
||||
let usage = usage?;
|
||||
if usage.measure_id == requests_per_minute {
|
||||
requests_this_minute += Self::get_live_buckets(
|
||||
&usage,
|
||||
now.naive_utc(),
|
||||
UsageMeasure::RequestsPerMinute,
|
||||
)
|
||||
.0
|
||||
.iter()
|
||||
.copied()
|
||||
.sum::<i64>() as usize;
|
||||
} else if usage.measure_id == tokens_per_minute {
|
||||
tokens_this_minute += Self::get_live_buckets(
|
||||
&usage,
|
||||
now.naive_utc(),
|
||||
UsageMeasure::TokensPerMinute,
|
||||
)
|
||||
.0
|
||||
.iter()
|
||||
.copied()
|
||||
.sum::<i64>() as usize;
|
||||
} else if usage.measure_id == input_tokens_per_minute {
|
||||
input_tokens_this_minute += Self::get_live_buckets(
|
||||
&usage,
|
||||
now.naive_utc(),
|
||||
UsageMeasure::InputTokensPerMinute,
|
||||
)
|
||||
.0
|
||||
.iter()
|
||||
.copied()
|
||||
.sum::<i64>() as usize;
|
||||
} else if usage.measure_id == output_tokens_per_minute {
|
||||
output_tokens_this_minute += Self::get_live_buckets(
|
||||
&usage,
|
||||
now.naive_utc(),
|
||||
UsageMeasure::OutputTokensPerMinute,
|
||||
)
|
||||
.0
|
||||
.iter()
|
||||
.copied()
|
||||
.sum::<i64>() as usize;
|
||||
}
|
||||
}
|
||||
|
||||
results.push(ApplicationWideUsage {
|
||||
provider: *provider,
|
||||
model: model_name.clone(),
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
input_tokens_this_minute,
|
||||
output_tokens_this_minute,
|
||||
})
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_user_spending_for_month(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
|
@ -223,499 +85,6 @@ impl LlmDatabase {
|
|||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_usage(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Usage> {
|
||||
self.transaction(|tx| async move {
|
||||
let model = self
|
||||
.models
|
||||
.get(&(provider, model_name.to_string()))
|
||||
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
|
||||
|
||||
let usages = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(usage::Column::ModelId.eq(model.id)),
|
||||
)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let month = now.date_naive().month() as i32;
|
||||
let year = now.date_naive().year();
|
||||
let monthly_usage = monthly_usage::Entity::find()
|
||||
.filter(
|
||||
monthly_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(monthly_usage::Column::ModelId.eq(model.id))
|
||||
.and(monthly_usage::Column::Month.eq(month))
|
||||
.and(monthly_usage::Column::Year.eq(year)),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
let lifetime_usage = lifetime_usage::Entity::find()
|
||||
.filter(
|
||||
lifetime_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(lifetime_usage::Column::ModelId.eq(model.id)),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
|
||||
let requests_this_minute =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
|
||||
let tokens_this_minute =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
|
||||
let input_tokens_this_minute =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMinute)?;
|
||||
let output_tokens_this_minute =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMinute)?;
|
||||
let tokens_this_day =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
|
||||
let spending_this_month = if let Some(monthly_usage) = &monthly_usage {
|
||||
calculate_spending(
|
||||
model,
|
||||
monthly_usage.input_tokens as usize,
|
||||
monthly_usage.cache_creation_input_tokens as usize,
|
||||
monthly_usage.cache_read_input_tokens as usize,
|
||||
monthly_usage.output_tokens as usize,
|
||||
)
|
||||
} else {
|
||||
Cents::ZERO
|
||||
};
|
||||
let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage {
|
||||
calculate_spending(
|
||||
model,
|
||||
lifetime_usage.input_tokens as usize,
|
||||
lifetime_usage.cache_creation_input_tokens as usize,
|
||||
lifetime_usage.cache_read_input_tokens as usize,
|
||||
lifetime_usage.output_tokens as usize,
|
||||
)
|
||||
} else {
|
||||
Cents::ZERO
|
||||
};
|
||||
|
||||
Ok(Usage {
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
input_tokens_this_minute,
|
||||
output_tokens_this_minute,
|
||||
tokens_this_day,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.input_tokens as usize),
|
||||
input_cache_creation: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
|
||||
input_cache_read: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
|
||||
output: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.output_tokens as usize),
|
||||
},
|
||||
spending_this_month,
|
||||
lifetime_spending,
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn record_usage(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
is_staff: bool,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
tokens: TokenUsage,
|
||||
has_llm_subscription: bool,
|
||||
max_monthly_spend: Cents,
|
||||
free_tier_monthly_spending_limit: Cents,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Usage> {
|
||||
self.transaction(|tx| async move {
|
||||
let model = self.model(provider, model_name)?;
|
||||
|
||||
let usages = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(usage::Column::ModelId.eq(model.id)),
|
||||
)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let requests_this_minute = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::RequestsPerMinute,
|
||||
now,
|
||||
1,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let tokens_this_minute = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::TokensPerMinute,
|
||||
now,
|
||||
tokens.total(),
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let input_tokens_this_minute = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::InputTokensPerMinute,
|
||||
now,
|
||||
// Cache read input tokens are not counted for the purposes of rate limits (but they are still billed).
|
||||
tokens.input + tokens.input_cache_creation,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let output_tokens_this_minute = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::OutputTokensPerMinute,
|
||||
now,
|
||||
tokens.output,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let tokens_this_day = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::TokensPerDay,
|
||||
now,
|
||||
tokens.total(),
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let month = now.date_naive().month() as i32;
|
||||
let year = now.date_naive().year();
|
||||
|
||||
// Update monthly usage
|
||||
let monthly_usage = monthly_usage::Entity::find()
|
||||
.filter(
|
||||
monthly_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(monthly_usage::Column::ModelId.eq(model.id))
|
||||
.and(monthly_usage::Column::Month.eq(month))
|
||||
.and(monthly_usage::Column::Year.eq(year)),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
|
||||
let monthly_usage = match monthly_usage {
|
||||
Some(usage) => {
|
||||
monthly_usage::Entity::update(monthly_usage::ActiveModel {
|
||||
id: ActiveValue::unchanged(usage.id),
|
||||
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(
|
||||
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
monthly_usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
month: ActiveValue::set(month),
|
||||
year: ActiveValue::set(year),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&*tx)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
let spending_this_month = calculate_spending(
|
||||
model,
|
||||
monthly_usage.input_tokens as usize,
|
||||
monthly_usage.cache_creation_input_tokens as usize,
|
||||
monthly_usage.cache_read_input_tokens as usize,
|
||||
monthly_usage.output_tokens as usize,
|
||||
);
|
||||
|
||||
if !is_staff
|
||||
&& spending_this_month > free_tier_monthly_spending_limit
|
||||
&& has_llm_subscription
|
||||
&& (spending_this_month - free_tier_monthly_spending_limit) <= max_monthly_spend
|
||||
{
|
||||
billing_event::ActiveModel {
|
||||
id: ActiveValue::not_set(),
|
||||
idempotency_key: ActiveValue::not_set(),
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
input_cache_creation_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
}
|
||||
.insert(&*tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Update lifetime usage
|
||||
let lifetime_usage = lifetime_usage::Entity::find()
|
||||
.filter(
|
||||
lifetime_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(lifetime_usage::Column::ModelId.eq(model.id)),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
|
||||
let lifetime_usage = match lifetime_usage {
|
||||
Some(usage) => {
|
||||
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
|
||||
id: ActiveValue::unchanged(usage.id),
|
||||
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(
|
||||
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
lifetime_usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&*tx)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
let lifetime_spending = calculate_spending(
|
||||
model,
|
||||
lifetime_usage.input_tokens as usize,
|
||||
lifetime_usage.cache_creation_input_tokens as usize,
|
||||
lifetime_usage.cache_read_input_tokens as usize,
|
||||
lifetime_usage.output_tokens as usize,
|
||||
);
|
||||
|
||||
Ok(Usage {
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
input_tokens_this_minute,
|
||||
output_tokens_this_minute,
|
||||
tokens_this_day,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: monthly_usage.input_tokens as usize,
|
||||
input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
|
||||
input_cache_read: monthly_usage.cache_read_input_tokens as usize,
|
||||
output: monthly_usage.output_tokens as usize,
|
||||
},
|
||||
spending_this_month,
|
||||
lifetime_spending,
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the active user count for the specified model.
|
||||
pub async fn get_active_user_count(
|
||||
&self,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<ActiveUserCount> {
|
||||
self.transaction(|tx| async move {
|
||||
let minute_since = now - Duration::minutes(5);
|
||||
let day_since = now - Duration::days(5);
|
||||
|
||||
let model = self
|
||||
.models
|
||||
.get(&(provider, model_name.to_string()))
|
||||
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
|
||||
|
||||
let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
|
||||
|
||||
let users_in_recent_minutes = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::ModelId
|
||||
.eq(model.id)
|
||||
.and(usage::Column::MeasureId.eq(tokens_per_minute))
|
||||
.and(usage::Column::Timestamp.gte(minute_since.naive_utc()))
|
||||
.and(usage::Column::IsStaff.eq(false)),
|
||||
)
|
||||
.select_only()
|
||||
.column(usage::Column::UserId)
|
||||
.group_by(usage::Column::UserId)
|
||||
.count(&*tx)
|
||||
.await? as usize;
|
||||
|
||||
let users_in_recent_days = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::ModelId
|
||||
.eq(model.id)
|
||||
.and(usage::Column::MeasureId.eq(tokens_per_minute))
|
||||
.and(usage::Column::Timestamp.gte(day_since.naive_utc()))
|
||||
.and(usage::Column::IsStaff.eq(false)),
|
||||
)
|
||||
.select_only()
|
||||
.column(usage::Column::UserId)
|
||||
.group_by(usage::Column::UserId)
|
||||
.count(&*tx)
|
||||
.await? as usize;
|
||||
|
||||
Ok(ActiveUserCount {
|
||||
users_in_recent_minutes,
|
||||
users_in_recent_days,
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn update_usage_for_measure(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
is_staff: bool,
|
||||
model_id: ModelId,
|
||||
usages: &[usage::Model],
|
||||
usage_measure: UsageMeasure,
|
||||
now: DateTimeUtc,
|
||||
usage_to_add: usize,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<usize> {
|
||||
let now = now.naive_utc();
|
||||
let measure_id = *self
|
||||
.usage_measure_ids
|
||||
.get(&usage_measure)
|
||||
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
|
||||
|
||||
let mut id = None;
|
||||
let mut timestamp = now;
|
||||
let mut buckets = vec![0_i64];
|
||||
|
||||
if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) {
|
||||
id = Some(old_usage.id);
|
||||
let (live_buckets, buckets_since) =
|
||||
Self::get_live_buckets(old_usage, now, usage_measure);
|
||||
if !live_buckets.is_empty() {
|
||||
buckets.clear();
|
||||
buckets.extend_from_slice(live_buckets);
|
||||
buckets.extend(iter::repeat(0).take(buckets_since));
|
||||
timestamp =
|
||||
old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32);
|
||||
}
|
||||
}
|
||||
|
||||
*buckets.last_mut().unwrap() += usage_to_add as i64;
|
||||
let total_usage = buckets.iter().sum::<i64>() as usize;
|
||||
|
||||
let mut model = usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
is_staff: ActiveValue::set(is_staff),
|
||||
model_id: ActiveValue::set(model_id),
|
||||
measure_id: ActiveValue::set(measure_id),
|
||||
timestamp: ActiveValue::set(timestamp),
|
||||
buckets: ActiveValue::set(buckets),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if let Some(id) = id {
|
||||
model.id = ActiveValue::unchanged(id);
|
||||
model.update(tx).await?;
|
||||
} else {
|
||||
usage::Entity::insert(model)
|
||||
.exec_without_returning(tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(total_usage)
|
||||
}
|
||||
|
||||
fn get_usage_for_measure(
|
||||
&self,
|
||||
usages: &[usage::Model],
|
||||
now: DateTimeUtc,
|
||||
usage_measure: UsageMeasure,
|
||||
) -> Result<usize> {
|
||||
let now = now.naive_utc();
|
||||
let measure_id = *self
|
||||
.usage_measure_ids
|
||||
.get(&usage_measure)
|
||||
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
|
||||
let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else {
|
||||
return Ok(0);
|
||||
};
|
||||
|
||||
let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure);
|
||||
Ok(live_buckets.iter().sum::<i64>() as _)
|
||||
}
|
||||
|
||||
fn get_live_buckets(
|
||||
usage: &usage::Model,
|
||||
now: chrono::NaiveDateTime,
|
||||
measure: UsageMeasure,
|
||||
) -> (&[i64], usize) {
|
||||
let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0);
|
||||
let buckets_since_usage =
|
||||
seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32;
|
||||
let buckets_since_usage = buckets_since_usage.ceil() as usize;
|
||||
let mut live_buckets = &[] as &[i64];
|
||||
if buckets_since_usage < measure.bucket_count() {
|
||||
let expired_bucket_count =
|
||||
(usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count());
|
||||
live_buckets = &usage.buckets[expired_bucket_count..];
|
||||
while live_buckets.first() == Some(&0) {
|
||||
live_buckets = &live_buckets[1..];
|
||||
}
|
||||
}
|
||||
(live_buckets, buckets_since_usage)
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_spending(
|
||||
|
@ -741,32 +110,3 @@ fn calculate_spending(
|
|||
+ output_token_cost;
|
||||
Cents::new(spending as u32)
|
||||
}
|
||||
|
||||
const MINUTE_BUCKET_COUNT: usize = 12;
|
||||
const DAY_BUCKET_COUNT: usize = 48;
|
||||
|
||||
impl UsageMeasure {
|
||||
fn bucket_count(&self) -> usize {
|
||||
match self {
|
||||
UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
|
||||
UsageMeasure::TokensPerMinute
|
||||
| UsageMeasure::InputTokensPerMinute
|
||||
| UsageMeasure::OutputTokensPerMinute => MINUTE_BUCKET_COUNT,
|
||||
UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
|
||||
}
|
||||
}
|
||||
|
||||
fn total_duration(&self) -> Duration {
|
||||
match self {
|
||||
UsageMeasure::RequestsPerMinute => Duration::minutes(1),
|
||||
UsageMeasure::TokensPerMinute
|
||||
| UsageMeasure::InputTokensPerMinute
|
||||
| UsageMeasure::OutputTokensPerMinute => Duration::minutes(1),
|
||||
UsageMeasure::TokensPerDay => Duration::hours(24),
|
||||
}
|
||||
}
|
||||
|
||||
fn bucket_duration(&self) -> Duration {
|
||||
self.total_duration() / self.bucket_count() as i32
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
pub mod billing_event;
|
||||
pub mod lifetime_usage;
|
||||
pub mod model;
|
||||
pub mod monthly_usage;
|
||||
pub mod provider;
|
||||
pub mod revoked_access_token;
|
||||
pub mod usage;
|
||||
pub mod usage_measure;
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
use crate::{db::UserId, llm::db::ModelId};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "lifetime_usages")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: i32,
|
||||
pub user_id: UserId,
|
||||
pub model_id: ModelId,
|
||||
pub input_tokens: i64,
|
||||
pub cache_creation_input_tokens: i64,
|
||||
pub cache_read_input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,19 +0,0 @@
|
|||
use chrono::NaiveDateTime;
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
use crate::llm::db::RevokedAccessTokenId;
|
||||
|
||||
/// A revoked access token.
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "revoked_access_tokens")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: RevokedAccessTokenId,
|
||||
pub jti: String,
|
||||
pub revoked_at: NaiveDateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,6 +1,4 @@
|
|||
mod billing_tests;
|
||||
mod provider_tests;
|
||||
mod usage_tests;
|
||||
|
||||
use gpui::BackgroundExecutor;
|
||||
use parking_lot::Mutex;
|
||||
|
|
|
@ -1,152 +0,0 @@
|
|||
use crate::{
|
||||
Cents,
|
||||
db::UserId,
|
||||
llm::{
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
db::{LlmDatabase, TokenUsage, queries::providers::ModelParams},
|
||||
},
|
||||
test_llm_db,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
test_llm_db!(
|
||||
test_billing_limit_exceeded,
|
||||
test_billing_limit_exceeded_postgres
|
||||
);
|
||||
|
||||
async fn test_billing_limit_exceeded(db: &mut LlmDatabase) {
|
||||
let provider = LanguageModelProvider::Anthropic;
|
||||
let model = "fake-claude-limerick";
|
||||
const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5;
|
||||
const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5;
|
||||
|
||||
// Initialize the database and insert the model
|
||||
db.initialize().await.unwrap();
|
||||
db.insert_models(&[ModelParams {
|
||||
provider,
|
||||
name: model.to_string(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 10_000,
|
||||
max_tokens_per_day: 50_000,
|
||||
price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS,
|
||||
price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS,
|
||||
}])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Set a fixed datetime for consistent testing
|
||||
let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
|
||||
.unwrap()
|
||||
.with_timezone(&Utc);
|
||||
|
||||
let user_id = UserId::from_proto(123);
|
||||
|
||||
let max_monthly_spend = Cents::from_dollars(11);
|
||||
|
||||
// Record usage that brings us close to the limit but doesn't exceed it
|
||||
// Let's say we use $10.50 worth of tokens
|
||||
let tokens_to_use = 210_000_000; // This will cost $10.50 at $0.05 per 1 million tokens
|
||||
let usage = TokenUsage {
|
||||
input: tokens_to_use,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
|
||||
// Verify that before we record any usage, there are 0 billing events
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 0);
|
||||
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage,
|
||||
true,
|
||||
max_monthly_spend,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify the recorded usage and spending
|
||||
let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
// Verify that we exceeded the free tier usage
|
||||
assert_eq!(recorded_usage.spending_this_month, Cents::new(1050));
|
||||
assert!(recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT);
|
||||
|
||||
// Verify that there is one `billing_event` record
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 1);
|
||||
|
||||
let (billing_event, _model) = &billing_events[0];
|
||||
assert_eq!(billing_event.user_id, user_id);
|
||||
assert_eq!(billing_event.input_tokens, tokens_to_use as i64);
|
||||
assert_eq!(billing_event.input_cache_creation_tokens, 0);
|
||||
assert_eq!(billing_event.input_cache_read_tokens, 0);
|
||||
assert_eq!(billing_event.output_tokens, 0);
|
||||
|
||||
// Record usage that puts us at $20.50
|
||||
let usage_2 = TokenUsage {
|
||||
input: 200_000_000, // This will cost $10 more, pushing us from $10.50 to $20.50,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage_2,
|
||||
true,
|
||||
max_monthly_spend,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify the updated usage and spending
|
||||
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(updated_usage.spending_this_month, Cents::new(2050));
|
||||
|
||||
// Verify that there are now two billing events
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 2);
|
||||
|
||||
let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $20.50 to $21.50, which is over the $11 monthly maximum limit
|
||||
let usage_exceeding = TokenUsage {
|
||||
input: tokens_to_exceed,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
|
||||
// This should still create a billing event as it's the first request that exceeds the limit
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage_exceeding,
|
||||
true,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
max_monthly_spend,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
// Verify the updated usage and spending
|
||||
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(updated_usage.spending_this_month, Cents::new(2150));
|
||||
|
||||
// Verify that we never exceed the user max spending for the user
|
||||
// and avoid charging them.
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 2);
|
||||
}
|
|
@ -1,306 +0,0 @@
|
|||
use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT;
|
||||
use crate::{
|
||||
Cents,
|
||||
db::UserId,
|
||||
llm::db::{
|
||||
LlmDatabase, TokenUsage,
|
||||
queries::{providers::ModelParams, usages::Usage},
|
||||
},
|
||||
test_llm_db,
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
test_llm_db!(test_tracking_usage, test_tracking_usage_postgres);
|
||||
|
||||
async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
let provider = LanguageModelProvider::Anthropic;
|
||||
let model = "claude-3-5-sonnet";
|
||||
|
||||
db.initialize().await.unwrap();
|
||||
db.insert_models(&[ModelParams {
|
||||
provider,
|
||||
name: model.to_string(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 10_000,
|
||||
max_tokens_per_day: 50_000,
|
||||
price_per_million_input_tokens: 50,
|
||||
price_per_million_output_tokens: 50,
|
||||
}])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We're using a fixed datetime to prevent flakiness based on the clock.
|
||||
let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
|
||||
.unwrap()
|
||||
.with_timezone(&Utc);
|
||||
let user_id = UserId::from_proto(123);
|
||||
|
||||
let now = t0;
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let now = t0 + Duration::seconds(10);
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 2000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 3000,
|
||||
input_tokens_this_minute: 3000,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 3000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
let now = t0 + Duration::seconds(60);
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 2000,
|
||||
input_tokens_this_minute: 2000,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 3000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
let now = t0 + Duration::seconds(60);
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 5000,
|
||||
input_tokens_this_minute: 5000,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 6000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 6000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
let t1 = t0 + Duration::hours(24);
|
||||
let now = t1;
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 0,
|
||||
tokens_this_minute: 0,
|
||||
input_tokens_this_minute: 0,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 5000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 6000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 4000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 4000,
|
||||
input_tokens_this_minute: 4000,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 9000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 10000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
// We're using a fixed datetime to prevent flakiness based on the clock.
|
||||
let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z")
|
||||
.unwrap()
|
||||
.with_timezone(&Utc);
|
||||
|
||||
// Test cache creation input tokens
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 1500,
|
||||
input_tokens_this_minute: 1500,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 1500,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
// Test cache read input tokens
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 300,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 2800,
|
||||
input_tokens_this_minute: 2500,
|
||||
output_tokens_this_minute: 0,
|
||||
tokens_this_day: 2800,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 2000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 300,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
}
|
|
@ -9,14 +9,14 @@ use axum::{
|
|||
|
||||
use collab::api::CloudflareIpCountryHeader;
|
||||
use collab::api::billing::sync_llm_usage_with_stripe_periodically;
|
||||
use collab::llm::{db::LlmDatabase, log_usage_periodically};
|
||||
use collab::llm::db::LlmDatabase;
|
||||
use collab::migrations::run_database_migrations;
|
||||
use collab::user_backfiller::spawn_user_backfiller;
|
||||
use collab::{
|
||||
AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
|
||||
env, executor::Executor, rpc::ResultExt,
|
||||
};
|
||||
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically, llm::LlmState};
|
||||
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
|
||||
use db::Database;
|
||||
use std::{
|
||||
env::args,
|
||||
|
@ -74,11 +74,10 @@ async fn main() -> Result<()> {
|
|||
let mode = match args.next().as_deref() {
|
||||
Some("collab") => ServiceMode::Collab,
|
||||
Some("api") => ServiceMode::Api,
|
||||
Some("llm") => ServiceMode::Llm,
|
||||
Some("all") => ServiceMode::All,
|
||||
_ => {
|
||||
return Err(anyhow!(
|
||||
"usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
|
||||
"usage: collab <version | migrate | seed | serve <api|collab|all>>"
|
||||
))?;
|
||||
}
|
||||
};
|
||||
|
@ -97,20 +96,9 @@ async fn main() -> Result<()> {
|
|||
|
||||
let mut on_shutdown = None;
|
||||
|
||||
if mode.is_llm() {
|
||||
setup_llm_database(&config).await?;
|
||||
|
||||
let state = LlmState::new(config.clone(), Executor::Production).await?;
|
||||
|
||||
log_usage_periodically(state.clone());
|
||||
|
||||
app = app
|
||||
.merge(collab::llm::routes())
|
||||
.layer(Extension(state.clone()));
|
||||
}
|
||||
|
||||
if mode.is_collab() || mode.is_api() {
|
||||
setup_app_database(&config).await?;
|
||||
setup_llm_database(&config).await?;
|
||||
|
||||
let state = AppState::new(config, Executor::Production).await?;
|
||||
|
||||
|
@ -336,18 +324,11 @@ async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
|
|||
format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown"))
|
||||
}
|
||||
|
||||
async fn handle_liveness_probe(
|
||||
app_state: Option<Extension<Arc<AppState>>>,
|
||||
llm_state: Option<Extension<Arc<LlmState>>>,
|
||||
) -> Result<String> {
|
||||
async fn handle_liveness_probe(app_state: Option<Extension<Arc<AppState>>>) -> Result<String> {
|
||||
if let Some(state) = app_state {
|
||||
state.db.get_all_users(0, 1).await?;
|
||||
}
|
||||
|
||||
if let Some(llm_state) = llm_state {
|
||||
llm_state.db.list_providers().await?;
|
||||
}
|
||||
|
||||
Ok("ok".to_string())
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue