diff --git a/.github/workflows/deploy_collab.yml b/.github/workflows/deploy_collab.yml index d921a08bf1..eb5875afcc 100644 --- a/.github/workflows/deploy_collab.yml +++ b/.github/workflows/deploy_collab.yml @@ -117,12 +117,10 @@ jobs: export ZED_KUBE_NAMESPACE=production export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=10 export ZED_API_LOAD_BALANCER_SIZE_UNIT=2 - export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=2 elif [[ $GITHUB_REF_NAME = "collab-staging" ]]; then export ZED_KUBE_NAMESPACE=staging export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=1 export ZED_API_LOAD_BALANCER_SIZE_UNIT=1 - export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=1 else echo "cowardly refusing to deploy from an unknown branch" exit 1 @@ -147,9 +145,3 @@ jobs: envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f - kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}" - - export ZED_SERVICE_NAME=llm - export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_LLM_LOAD_BALANCER_SIZE_UNIT - envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f - - kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch - echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}" diff --git a/Cargo.lock b/Cargo.lock index b1c75fe3f6..b95c9dce18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2942,7 +2942,6 @@ dependencies = [ name = "collab" version = "0.44.0" dependencies = [ - "anthropic", "anyhow", "assistant", "assistant_context_editor", diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index f60131e0de..c4aa90e2c2 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -18,7 +18,6 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"] test-support = ["sqlite"] [dependencies] -anthropic.workspace = true anyhow.workspace = true async-stripe.workspace = true async-tungstenite.workspace = true diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 334d015d4b..2e682d2878 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -253,7 +253,6 @@ impl Config { pub enum ServiceMode { Api, Collab, - Llm, All, } @@ -265,10 +264,6 @@ impl ServiceMode { pub fn is_api(&self) -> bool { matches!(self, Self::Api | Self::All) } - - pub fn is_llm(&self) -> bool { - matches!(self, Self::Llm | Self::All) - } } pub struct AppState { diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 8c6fd772df..13d503e7d4 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,448 +1,10 @@ -mod authorization; pub mod db; mod token; -use crate::api::CloudflareIpCountryHeader; -use crate::api::events::SnowflakeRow; -use crate::build_kinesis_client; -use crate::rpc::MIN_ACCOUNT_AGE_FOR_LLM_USE; -use crate::{Cents, Config, Error, Result, db::UserId, executor::Executor}; -use anyhow::{Context as _, anyhow}; -use authorization::authorize_access_to_language_model; -use axum::routing::get; -use axum::{ - Extension, Json, Router, TypedHeader, - body::Body, - http::{self, HeaderName, HeaderValue, Request, StatusCode}, - middleware::{self, Next}, - response::{IntoResponse, Response}, - routing::post, -}; -use chrono::{DateTime, Duration, Utc}; -use collections::HashMap; -use db::TokenUsage; -use db::{ActiveUserCount, LlmDatabase, usage_measure::UsageMeasure}; -use futures::{Stream, StreamExt as _}; -use reqwest_client::ReqwestClient; -use rpc::{ - EXPIRED_LLM_TOKEN_HEADER_NAME, LanguageModelProvider, PerformCompletionParams, proto::Plan, -}; -use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME}; -use serde_json::json; -use std::{ - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; -use strum::IntoEnumIterator; -use tokio::sync::RwLock; -use util::ResultExt; +use crate::Cents; pub use token::*; -const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30); - -pub struct LlmState { - pub config: Config, - pub executor: Executor, - pub db: Arc, - pub http_client: ReqwestClient, - pub kinesis_client: Option, - active_user_count_by_model: - RwLock, ActiveUserCount)>>, -} - -impl LlmState { - pub async fn new(config: Config, executor: Executor) -> Result> { - let database_url = config - .llm_database_url - .as_ref() - .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?; - let max_connections = config - .llm_database_max_connections - .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?; - - let mut db_options = db::ConnectOptions::new(database_url); - db_options.max_connections(max_connections); - let mut db = LlmDatabase::new(db_options, executor.clone()).await?; - db.initialize().await?; - - let db = Arc::new(db); - - let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION")); - let http_client = - ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?; - - let this = Self { - executor, - db, - http_client, - kinesis_client: if config.kinesis_access_key.is_some() { - build_kinesis_client(&config).await.log_err() - } else { - None - }, - active_user_count_by_model: RwLock::new(HashMap::default()), - config, - }; - - Ok(Arc::new(this)) - } - - pub async fn get_active_user_count( - &self, - provider: LanguageModelProvider, - model: &str, - ) -> Result { - let now = Utc::now(); - - { - let active_user_count_by_model = self.active_user_count_by_model.read().await; - if let Some((last_updated, count)) = - active_user_count_by_model.get(&(provider, model.to_string())) - { - if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION { - return Ok(*count); - } - } - } - - let mut cache = self.active_user_count_by_model.write().await; - let new_count = self.db.get_active_user_count(provider, model, now).await?; - cache.insert((provider, model.to_string()), (now, new_count)); - Ok(new_count) - } -} - -pub fn routes() -> Router<(), Body> { - Router::new() - .route("/models", get(list_models)) - .route("/completion", post(perform_completion)) - .layer(middleware::from_fn(validate_api_token)) -} - -async fn validate_api_token(mut req: Request, next: Next) -> impl IntoResponse { - let token = req - .headers() - .get(http::header::AUTHORIZATION) - .and_then(|header| header.to_str().ok()) - .ok_or_else(|| { - Error::http( - StatusCode::BAD_REQUEST, - "missing authorization header".to_string(), - ) - })? - .strip_prefix("Bearer ") - .ok_or_else(|| { - Error::http( - StatusCode::BAD_REQUEST, - "invalid authorization header".to_string(), - ) - })?; - - let state = req.extensions().get::>().unwrap(); - match LlmTokenClaims::validate(token, &state.config) { - Ok(claims) => { - if state.db.is_access_token_revoked(&claims.jti).await? { - return Err(Error::http( - StatusCode::UNAUTHORIZED, - "unauthorized".to_string(), - )); - } - - tracing::Span::current() - .record("user_id", claims.user_id) - .record("login", claims.github_user_login.clone()) - .record("authn.jti", &claims.jti) - .record("is_staff", claims.is_staff); - - req.extensions_mut().insert(claims); - Ok::<_, Error>(next.run(req).await.into_response()) - } - Err(ValidateLlmTokenError::Expired) => Err(Error::Http( - StatusCode::UNAUTHORIZED, - "unauthorized".to_string(), - [( - HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME), - HeaderValue::from_static("true"), - )] - .into_iter() - .collect(), - )), - Err(_err) => Err(Error::http( - StatusCode::UNAUTHORIZED, - "unauthorized".to_string(), - )), - } -} - -async fn list_models( - Extension(state): Extension>, - Extension(claims): Extension, - country_code_header: Option>, -) -> Result> { - let country_code = country_code_header.map(|header| header.to_string()); - - let mut accessible_models = Vec::new(); - - for (provider, model) in state.db.all_models() { - let authorize_result = authorize_access_to_language_model( - &state.config, - &claims, - country_code.as_deref(), - provider, - &model.name, - ); - - if authorize_result.is_ok() { - accessible_models.push(rpc::LanguageModel { - provider, - name: model.name, - }); - } - } - - Ok(Json(ListModelsResponse { - models: accessible_models, - })) -} - -async fn perform_completion( - Extension(state): Extension>, - Extension(claims): Extension, - country_code_header: Option>, - Json(params): Json, -) -> Result { - let model = normalize_model_name( - state.db.model_names_for_provider(params.provider), - params.model, - ); - - let bypass_account_age_check = claims.has_llm_subscription || claims.bypass_account_age_check; - if !bypass_account_age_check { - if Utc::now().naive_utc() - claims.account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE { - Err(anyhow!("account too young"))? - } - } - - authorize_access_to_language_model( - &state.config, - &claims, - country_code_header - .map(|header| header.to_string()) - .as_deref(), - params.provider, - &model, - )?; - - check_usage_limit(&state, params.provider, &model, &claims).await?; - - let stream = match params.provider { - LanguageModelProvider::Anthropic => { - let api_key = if claims.is_staff { - state - .config - .anthropic_staff_api_key - .as_ref() - .context("no Anthropic AI staff API key configured on the server")? - } else { - state - .config - .anthropic_api_key - .as_ref() - .context("no Anthropic AI API key configured on the server")? - }; - - let mut request: anthropic::Request = - serde_json::from_str(params.provider_request.get())?; - - // Override the model on the request with the latest version of the model that is - // known to the server. - // - // Right now, we use the version that's defined in `model.id()`, but we will likely - // want to change this code once a new version of an Anthropic model is released, - // so that users can use the new version, without having to update Zed. - request.model = match model.as_str() { - "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(), - "claude-3-7-sonnet" => anthropic::Model::Claude3_7Sonnet.id().to_string(), - "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(), - "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(), - "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(), - _ => request.model, - }; - - let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info( - &state.http_client, - anthropic::ANTHROPIC_API_URL, - api_key, - request, - ) - .await - .map_err(|err| match err { - anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() { - Some(anthropic::ApiErrorCode::RateLimitError) => { - tracing::info!( - target: "upstream rate limit exceeded", - user_id = claims.user_id, - login = claims.github_user_login, - authn.jti = claims.jti, - is_staff = claims.is_staff, - provider = params.provider.to_string(), - model = model - ); - - Error::http( - StatusCode::TOO_MANY_REQUESTS, - "Upstream Anthropic rate limit exceeded.".to_string(), - ) - } - Some(anthropic::ApiErrorCode::InvalidRequestError) => { - Error::http(StatusCode::BAD_REQUEST, api_error.message.clone()) - } - Some(anthropic::ApiErrorCode::OverloadedError) => { - Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone()) - } - Some(_) => { - Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone()) - } - None => Error::Internal(anyhow!(err)), - }, - anthropic::AnthropicError::Other(err) => Error::Internal(err), - })?; - - if let Some(rate_limit_info) = rate_limit_info { - tracing::info!( - target: "upstream rate limit", - is_staff = claims.is_staff, - provider = params.provider.to_string(), - model = model, - tokens_remaining = rate_limit_info.tokens.as_ref().map(|limits| limits.remaining), - input_tokens_remaining = rate_limit_info.input_tokens.as_ref().map(|limits| limits.remaining), - output_tokens_remaining = rate_limit_info.output_tokens.as_ref().map(|limits| limits.remaining), - requests_remaining = rate_limit_info.requests.as_ref().map(|limits| limits.remaining), - requests_reset = ?rate_limit_info.requests.as_ref().map(|limits| limits.reset), - tokens_reset = ?rate_limit_info.tokens.as_ref().map(|limits| limits.reset), - input_tokens_reset = ?rate_limit_info.input_tokens.as_ref().map(|limits| limits.reset), - output_tokens_reset = ?rate_limit_info.output_tokens.as_ref().map(|limits| limits.reset), - ); - } - - chunks - .map(move |event| { - let chunk = event?; - let ( - input_tokens, - output_tokens, - cache_creation_input_tokens, - cache_read_input_tokens, - ) = match &chunk { - anthropic::Event::MessageStart { - message: anthropic::Response { usage, .. }, - } - | anthropic::Event::MessageDelta { usage, .. } => ( - usage.input_tokens.unwrap_or(0) as usize, - usage.output_tokens.unwrap_or(0) as usize, - usage.cache_creation_input_tokens.unwrap_or(0) as usize, - usage.cache_read_input_tokens.unwrap_or(0) as usize, - ), - _ => (0, 0, 0, 0), - }; - - anyhow::Ok(CompletionChunk { - bytes: serde_json::to_vec(&chunk).unwrap(), - input_tokens, - output_tokens, - cache_creation_input_tokens, - cache_read_input_tokens, - }) - }) - .boxed() - } - LanguageModelProvider::OpenAi => { - let api_key = state - .config - .openai_api_key - .as_ref() - .context("no OpenAI API key configured on the server")?; - let chunks = open_ai::stream_completion( - &state.http_client, - open_ai::OPEN_AI_API_URL, - api_key, - serde_json::from_str(params.provider_request.get())?, - ) - .await?; - - chunks - .map(|event| { - event.map(|chunk| { - let input_tokens = - chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize; - let output_tokens = - chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize; - CompletionChunk { - bytes: serde_json::to_vec(&chunk).unwrap(), - input_tokens, - output_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - } - }) - }) - .boxed() - } - LanguageModelProvider::Google => { - let api_key = state - .config - .google_ai_api_key - .as_ref() - .context("no Google AI API key configured on the server")?; - let chunks = google_ai::stream_generate_content( - &state.http_client, - google_ai::API_URL, - api_key, - serde_json::from_str(params.provider_request.get())?, - ) - .await?; - - chunks - .map(|event| { - event.map(|chunk| { - // TODO - implement token counting for Google AI - CompletionChunk { - bytes: serde_json::to_vec(&chunk).unwrap(), - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - } - }) - }) - .boxed() - } - }; - - Ok(Response::new(Body::wrap_stream(TokenCountingStream { - state, - claims, - provider: params.provider, - model, - tokens: TokenUsage::default(), - inner_stream: stream, - }))) -} - -fn normalize_model_name(known_models: Vec, name: String) -> String { - if let Some(known_model_name) = known_models - .iter() - .filter(|known_model_name| name.starts_with(known_model_name.as_str())) - .max_by_key(|known_model_name| known_model_name.len()) - { - known_model_name.to_string() - } else { - name - } -} - /// The maximum monthly spending an individual user can reach on the free tier /// before they have to pay. pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10); @@ -452,330 +14,3 @@ pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10); /// /// Used to prevent surprise bills. pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10); - -async fn check_usage_limit( - state: &Arc, - provider: LanguageModelProvider, - model_name: &str, - claims: &LlmTokenClaims, -) -> Result<()> { - if claims.is_staff { - return Ok(()); - } - - let user_id = UserId::from_proto(claims.user_id); - let model = state.db.model(provider, model_name)?; - let free_tier = claims.free_tier_monthly_spending_limit(); - - let spending_this_month = state - .db - .get_user_spending_for_month(user_id, Utc::now()) - .await?; - if spending_this_month >= free_tier { - if !claims.has_llm_subscription { - return Err(Error::http( - StatusCode::PAYMENT_REQUIRED, - "Maximum spending limit reached for this month.".to_string(), - )); - } - - let monthly_spend = spending_this_month.saturating_sub(free_tier); - if monthly_spend >= Cents(claims.max_monthly_spend_in_cents) { - return Err(Error::Http( - StatusCode::FORBIDDEN, - "Maximum spending limit reached for this month.".to_string(), - [( - HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME), - HeaderValue::from_static("true"), - )] - .into_iter() - .collect(), - )); - } - } - - let active_users = state.get_active_user_count(provider, model_name).await?; - - let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1); - let users_in_recent_days = active_users.users_in_recent_days.max(1); - - let per_user_max_requests_per_minute = - model.max_requests_per_minute as usize / users_in_recent_minutes; - let per_user_max_tokens_per_minute = - model.max_tokens_per_minute as usize / users_in_recent_minutes; - let per_user_max_input_tokens_per_minute = - model.max_input_tokens_per_minute as usize / users_in_recent_minutes; - let per_user_max_output_tokens_per_minute = - model.max_output_tokens_per_minute as usize / users_in_recent_minutes; - let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days; - - let usage = state - .db - .get_usage(user_id, provider, model_name, Utc::now()) - .await?; - - let checks = match (provider, model_name) { - (LanguageModelProvider::Anthropic, "claude-3-7-sonnet") => vec![ - ( - usage.requests_this_minute, - per_user_max_requests_per_minute, - UsageMeasure::RequestsPerMinute, - ), - ( - usage.input_tokens_this_minute, - per_user_max_tokens_per_minute, - UsageMeasure::InputTokensPerMinute, - ), - ( - usage.output_tokens_this_minute, - per_user_max_tokens_per_minute, - UsageMeasure::OutputTokensPerMinute, - ), - ( - usage.tokens_this_day, - per_user_max_tokens_per_day, - UsageMeasure::TokensPerDay, - ), - ], - _ => vec![ - ( - usage.requests_this_minute, - per_user_max_requests_per_minute, - UsageMeasure::RequestsPerMinute, - ), - ( - usage.tokens_this_minute, - per_user_max_tokens_per_minute, - UsageMeasure::TokensPerMinute, - ), - ( - usage.tokens_this_day, - per_user_max_tokens_per_day, - UsageMeasure::TokensPerDay, - ), - ], - }; - - for (used, limit, usage_measure) in checks { - if used > limit { - let resource = match usage_measure { - UsageMeasure::RequestsPerMinute => "requests_per_minute", - UsageMeasure::TokensPerMinute => "tokens_per_minute", - UsageMeasure::InputTokensPerMinute => "input_tokens_per_minute", - UsageMeasure::OutputTokensPerMinute => "output_tokens_per_minute", - UsageMeasure::TokensPerDay => "tokens_per_day", - }; - - tracing::info!( - target: "user rate limit", - user_id = claims.user_id, - login = claims.github_user_login, - authn.jti = claims.jti, - is_staff = claims.is_staff, - provider = provider.to_string(), - model = model.name, - usage_measure = resource, - requests_this_minute = usage.requests_this_minute, - tokens_this_minute = usage.tokens_this_minute, - input_tokens_this_minute = usage.input_tokens_this_minute, - output_tokens_this_minute = usage.output_tokens_this_minute, - tokens_this_day = usage.tokens_this_day, - users_in_recent_minutes = users_in_recent_minutes, - users_in_recent_days = users_in_recent_days, - max_requests_per_minute = per_user_max_requests_per_minute, - max_tokens_per_minute = per_user_max_tokens_per_minute, - max_input_tokens_per_minute = per_user_max_input_tokens_per_minute, - max_output_tokens_per_minute = per_user_max_output_tokens_per_minute, - max_tokens_per_day = per_user_max_tokens_per_day, - ); - - SnowflakeRow::new( - "Language Model Rate Limited", - Some(claims.metrics_id), - claims.is_staff, - claims.system_id.clone(), - json!({ - "usage": usage, - "users_in_recent_minutes": users_in_recent_minutes, - "users_in_recent_days": users_in_recent_days, - "max_requests_per_minute": per_user_max_requests_per_minute, - "max_tokens_per_minute": per_user_max_tokens_per_minute, - "max_input_tokens_per_minute": per_user_max_input_tokens_per_minute, - "max_output_tokens_per_minute": per_user_max_output_tokens_per_minute, - "max_tokens_per_day": per_user_max_tokens_per_day, - "plan": match claims.plan { - Plan::Free => "free".to_string(), - Plan::ZedPro => "zed_pro".to_string(), - }, - "model": model.name.clone(), - "provider": provider.to_string(), - "usage_measure": resource.to_string(), - }), - ) - .write(&state.kinesis_client, &state.config.kinesis_stream) - .await - .log_err(); - - return Err(Error::http( - StatusCode::TOO_MANY_REQUESTS, - format!("Rate limit exceeded. Maximum {} reached.", resource), - )); - } - } - - Ok(()) -} - -struct CompletionChunk { - bytes: Vec, - input_tokens: usize, - output_tokens: usize, - cache_creation_input_tokens: usize, - cache_read_input_tokens: usize, -} - -struct TokenCountingStream { - state: Arc, - claims: LlmTokenClaims, - provider: LanguageModelProvider, - model: String, - tokens: TokenUsage, - inner_stream: S, -} - -impl Stream for TokenCountingStream -where - S: Stream> + Unpin, -{ - type Item = Result, anyhow::Error>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match Pin::new(&mut self.inner_stream).poll_next(cx) { - Poll::Ready(Some(Ok(mut chunk))) => { - chunk.bytes.push(b'\n'); - self.tokens.input += chunk.input_tokens; - self.tokens.output += chunk.output_tokens; - self.tokens.input_cache_creation += chunk.cache_creation_input_tokens; - self.tokens.input_cache_read += chunk.cache_read_input_tokens; - Poll::Ready(Some(Ok(chunk.bytes))) - } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } -} - -impl Drop for TokenCountingStream { - fn drop(&mut self) { - let state = self.state.clone(); - let claims = self.claims.clone(); - let provider = self.provider; - let model = std::mem::take(&mut self.model); - let tokens = self.tokens; - self.state.executor.spawn_detached(async move { - let usage = state - .db - .record_usage( - UserId::from_proto(claims.user_id), - claims.is_staff, - provider, - &model, - tokens, - claims.has_llm_subscription, - Cents(claims.max_monthly_spend_in_cents), - claims.free_tier_monthly_spending_limit(), - Utc::now(), - ) - .await - .log_err(); - - if let Some(usage) = usage { - tracing::info!( - target: "user usage", - user_id = claims.user_id, - login = claims.github_user_login, - authn.jti = claims.jti, - is_staff = claims.is_staff, - provider = provider.to_string(), - model = model, - requests_this_minute = usage.requests_this_minute, - tokens_this_minute = usage.tokens_this_minute, - input_tokens_this_minute = usage.input_tokens_this_minute, - output_tokens_this_minute = usage.output_tokens_this_minute, - ); - - let properties = json!({ - "has_llm_subscription": claims.has_llm_subscription, - "max_monthly_spend_in_cents": claims.max_monthly_spend_in_cents, - "plan": match claims.plan { - Plan::Free => "free".to_string(), - Plan::ZedPro => "zed_pro".to_string(), - }, - "model": model, - "provider": provider, - "usage": usage, - "tokens": tokens - }); - SnowflakeRow::new( - "Language Model Used", - Some(claims.metrics_id), - claims.is_staff, - claims.system_id.clone(), - properties, - ) - .write(&state.kinesis_client, &state.config.kinesis_stream) - .await - .log_err(); - } - }) - } -} - -pub fn log_usage_periodically(state: Arc) { - state.executor.clone().spawn_detached(async move { - loop { - state - .executor - .sleep(std::time::Duration::from_secs(30)) - .await; - - for provider in LanguageModelProvider::iter() { - for model in state.db.model_names_for_provider(provider) { - if let Some(active_user_count) = state - .get_active_user_count(provider, &model) - .await - .log_err() - { - tracing::info!( - target: "active user counts", - provider = provider.to_string(), - model = model, - users_in_recent_minutes = active_user_count.users_in_recent_minutes, - users_in_recent_days = active_user_count.users_in_recent_days, - ); - } - } - } - - if let Some(usages) = state - .db - .get_application_wide_usages_by_model(Utc::now()) - .await - .log_err() - { - for usage in usages { - tracing::info!( - target: "computed usage", - provider = usage.provider.to_string(), - model = usage.model, - requests_this_minute = usage.requests_this_minute, - tokens_this_minute = usage.tokens_this_minute, - input_tokens_this_minute = usage.input_tokens_this_minute, - output_tokens_this_minute = usage.output_tokens_this_minute, - ); - } - } - } - }) -} diff --git a/crates/collab/src/llm/authorization.rs b/crates/collab/src/llm/authorization.rs deleted file mode 100644 index 1ce7d7afdc..0000000000 --- a/crates/collab/src/llm/authorization.rs +++ /dev/null @@ -1,330 +0,0 @@ -use reqwest::StatusCode; -use rpc::LanguageModelProvider; - -use crate::llm::LlmTokenClaims; -use crate::{Config, Error, Result}; - -pub fn authorize_access_to_language_model( - config: &Config, - claims: &LlmTokenClaims, - country_code: Option<&str>, - provider: LanguageModelProvider, - model: &str, -) -> Result<()> { - authorize_access_for_country(config, country_code, provider)?; - authorize_access_to_model(config, claims, provider, model)?; - Ok(()) -} - -fn authorize_access_to_model( - config: &Config, - claims: &LlmTokenClaims, - provider: LanguageModelProvider, - model: &str, -) -> Result<()> { - if claims.is_staff { - return Ok(()); - } - - if provider == LanguageModelProvider::Anthropic { - if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" { - return Ok(()); - } - - if claims.has_llm_closed_beta_feature_flag - && Some(model) == config.llm_closed_beta_model_name.as_deref() - { - return Ok(()); - } - } - - Err(Error::http( - StatusCode::FORBIDDEN, - format!("access to model {model:?} is not included in your plan"), - )) -} - -fn authorize_access_for_country( - config: &Config, - country_code: Option<&str>, - provider: LanguageModelProvider, -) -> Result<()> { - // In development we won't have the `CF-IPCountry` header, so we can't check - // the country code. - // - // This shouldn't be necessary, as anyone running in development will need to provide - // their own API credentials in order to use an LLM provider. - if config.is_development() { - return Ok(()); - } - - // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry - let country_code = match country_code { - // `XX` - Used for clients without country code data. - None | Some("XX") => Err(Error::http( - StatusCode::BAD_REQUEST, - "no country code".to_string(), - ))?, - // `T1` - Used for clients using the Tor network. - Some("T1") => Err(Error::http( - StatusCode::FORBIDDEN, - format!("access to {provider:?} models is not available over Tor"), - ))?, - Some(country_code) => country_code, - }; - - let is_country_supported_by_provider = match provider { - LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code), - LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code), - LanguageModelProvider::Google => google_ai::is_supported_country(country_code), - }; - if !is_country_supported_by_provider { - Err(Error::http( - StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS, - format!( - "access to {provider:?} models is not available in your region ({country_code})" - ), - ))? - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use axum::response::IntoResponse; - use pretty_assertions::assert_eq; - use rpc::proto::Plan; - - use super::*; - - #[gpui::test] - async fn test_authorize_access_to_language_model_with_supported_country( - _cx: &mut gpui::TestAppContext, - ) { - let config = Config::test(); - - let claims = LlmTokenClaims { - user_id: 99, - plan: Plan::ZedPro, - is_staff: true, - ..Default::default() - }; - - let cases = vec![ - (LanguageModelProvider::Anthropic, "US"), // United States - (LanguageModelProvider::Anthropic, "GB"), // United Kingdom - (LanguageModelProvider::OpenAi, "US"), // United States - (LanguageModelProvider::OpenAi, "GB"), // United Kingdom - (LanguageModelProvider::Google, "US"), // United States - (LanguageModelProvider::Google, "GB"), // United Kingdom - ]; - - for (provider, country_code) in cases { - authorize_access_to_language_model( - &config, - &claims, - Some(country_code), - provider, - "the-model", - ) - .unwrap_or_else(|_| { - panic!("expected authorization to return Ok for {provider:?}: {country_code}") - }) - } - } - - #[gpui::test] - async fn test_authorize_access_to_language_model_with_unsupported_country( - _cx: &mut gpui::TestAppContext, - ) { - let config = Config::test(); - - let claims = LlmTokenClaims { - user_id: 99, - plan: Plan::ZedPro, - ..Default::default() - }; - - let cases = vec![ - (LanguageModelProvider::Anthropic, "AF"), // Afghanistan - (LanguageModelProvider::Anthropic, "BY"), // Belarus - (LanguageModelProvider::Anthropic, "CF"), // Central African Republic - (LanguageModelProvider::Anthropic, "CN"), // China - (LanguageModelProvider::Anthropic, "CU"), // Cuba - (LanguageModelProvider::Anthropic, "ER"), // Eritrea - (LanguageModelProvider::Anthropic, "ET"), // Ethiopia - (LanguageModelProvider::Anthropic, "IR"), // Iran - (LanguageModelProvider::Anthropic, "KP"), // North Korea - (LanguageModelProvider::Anthropic, "XK"), // Kosovo - (LanguageModelProvider::Anthropic, "LY"), // Libya - (LanguageModelProvider::Anthropic, "MM"), // Myanmar - (LanguageModelProvider::Anthropic, "RU"), // Russia - (LanguageModelProvider::Anthropic, "SO"), // Somalia - (LanguageModelProvider::Anthropic, "SS"), // South Sudan - (LanguageModelProvider::Anthropic, "SD"), // Sudan - (LanguageModelProvider::Anthropic, "SY"), // Syria - (LanguageModelProvider::Anthropic, "VE"), // Venezuela - (LanguageModelProvider::Anthropic, "YE"), // Yemen - (LanguageModelProvider::OpenAi, "KP"), // North Korea - (LanguageModelProvider::Google, "KP"), // North Korea - ]; - - for (provider, country_code) in cases { - let error_response = authorize_access_to_language_model( - &config, - &claims, - Some(country_code), - provider, - "the-model", - ) - .expect_err(&format!( - "expected authorization to return an error for {provider:?}: {country_code}" - )) - .into_response(); - - assert_eq!( - error_response.status(), - StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS - ); - let response_body = hyper::body::to_bytes(error_response.into_body()) - .await - .unwrap() - .to_vec(); - assert_eq!( - String::from_utf8(response_body).unwrap(), - format!( - "access to {provider:?} models is not available in your region ({country_code})" - ) - ); - } - } - - #[gpui::test] - async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) { - let config = Config::test(); - - let claims = LlmTokenClaims { - user_id: 99, - plan: Plan::ZedPro, - ..Default::default() - }; - - let cases = vec![ - (LanguageModelProvider::Anthropic, "T1"), // Tor - (LanguageModelProvider::OpenAi, "T1"), // Tor - (LanguageModelProvider::Google, "T1"), // Tor - ]; - - for (provider, country_code) in cases { - let error_response = authorize_access_to_language_model( - &config, - &claims, - Some(country_code), - provider, - "the-model", - ) - .expect_err(&format!( - "expected authorization to return an error for {provider:?}: {country_code}" - )) - .into_response(); - - assert_eq!(error_response.status(), StatusCode::FORBIDDEN); - let response_body = hyper::body::to_bytes(error_response.into_body()) - .await - .unwrap() - .to_vec(); - assert_eq!( - String::from_utf8(response_body).unwrap(), - format!("access to {provider:?} models is not available over Tor") - ); - } - } - - #[gpui::test] - async fn test_authorize_access_to_language_model_based_on_plan() { - let config = Config::test(); - - let test_cases = vec![ - // Pro plan should have access to claude-3.5-sonnet - ( - Plan::ZedPro, - LanguageModelProvider::Anthropic, - "claude-3-5-sonnet", - true, - ), - // Free plan should have access to claude-3.5-sonnet - ( - Plan::Free, - LanguageModelProvider::Anthropic, - "claude-3-5-sonnet", - true, - ), - // Pro plan should NOT have access to other Anthropic models - ( - Plan::ZedPro, - LanguageModelProvider::Anthropic, - "claude-3-opus", - false, - ), - ]; - - for (plan, provider, model, expected_access) in test_cases { - let claims = LlmTokenClaims { - plan, - ..Default::default() - }; - - let result = - authorize_access_to_language_model(&config, &claims, Some("US"), provider, model); - - if expected_access { - assert!( - result.is_ok(), - "Expected access to be granted for plan {:?}, provider {:?}, model {}", - plan, - provider, - model - ); - } else { - let error = result.expect_err(&format!( - "Expected access to be denied for plan {:?}, provider {:?}, model {}", - plan, provider, model - )); - let response = error.into_response(); - assert_eq!(response.status(), StatusCode::FORBIDDEN); - } - } - } - - #[gpui::test] - async fn test_authorize_access_to_language_model_for_staff() { - let config = Config::test(); - - let claims = LlmTokenClaims { - is_staff: true, - ..Default::default() - }; - - // Staff should have access to all models - let test_cases = vec![ - (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"), - (LanguageModelProvider::Anthropic, "claude-2"), - (LanguageModelProvider::Anthropic, "claude-123-agi"), - (LanguageModelProvider::OpenAi, "gpt-4"), - (LanguageModelProvider::Google, "gemini-pro"), - ]; - - for (provider, model) in test_cases { - let result = - authorize_access_to_language_model(&config, &claims, Some("US"), provider, model); - - assert!( - result.is_ok(), - "Expected staff to have access to provider {:?}, model {}", - provider, - model - ); - } - } -} diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index 6a46184171..f56e9e61e3 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -20,7 +20,6 @@ use std::future::Future; use std::sync::Arc; use anyhow::anyhow; -pub use queries::usages::{ActiveUserCount, TokenUsage}; pub use sea_orm::ConnectOptions; use sea_orm::prelude::*; use sea_orm::{ diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs index 79a17999b7..4a4a10fb51 100644 --- a/crates/collab/src/llm/db/queries.rs +++ b/crates/collab/src/llm/db/queries.rs @@ -2,5 +2,4 @@ use super::*; pub mod billing_events; pub mod providers; -pub mod revoked_access_tokens; pub mod usages; diff --git a/crates/collab/src/llm/db/queries/revoked_access_tokens.rs b/crates/collab/src/llm/db/queries/revoked_access_tokens.rs deleted file mode 100644 index 31d70192a0..0000000000 --- a/crates/collab/src/llm/db/queries/revoked_access_tokens.rs +++ /dev/null @@ -1,15 +0,0 @@ -use super::*; - -impl LlmDatabase { - /// Returns whether the access token with the given `jti` has been revoked. - pub async fn is_access_token_revoked(&self, jti: &str) -> Result { - self.transaction(|tx| async move { - Ok(revoked_access_token::Entity::find() - .filter(revoked_access_token::Column::Jti.eq(jti)) - .one(&*tx) - .await? - .is_some()) - }) - .await - } -} diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index 3dee5a41f6..6313e7572c 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -1,56 +1,12 @@ use crate::db::UserId; use crate::llm::Cents; -use chrono::{Datelike, Duration}; +use chrono::Datelike; use futures::StreamExt as _; -use rpc::LanguageModelProvider; -use sea_orm::QuerySelect; -use std::{iter, str::FromStr}; +use std::str::FromStr; use strum::IntoEnumIterator as _; use super::*; -#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)] -pub struct TokenUsage { - pub input: usize, - pub input_cache_creation: usize, - pub input_cache_read: usize, - pub output: usize, -} - -impl TokenUsage { - pub fn total(&self) -> usize { - self.input + self.input_cache_creation + self.input_cache_read + self.output - } -} - -#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)] -pub struct Usage { - pub requests_this_minute: usize, - pub tokens_this_minute: usize, - pub input_tokens_this_minute: usize, - pub output_tokens_this_minute: usize, - pub tokens_this_day: usize, - pub tokens_this_month: TokenUsage, - pub spending_this_month: Cents, - pub lifetime_spending: Cents, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct ApplicationWideUsage { - pub provider: LanguageModelProvider, - pub model: String, - pub requests_this_minute: usize, - pub tokens_this_minute: usize, - pub input_tokens_this_minute: usize, - pub output_tokens_this_minute: usize, -} - -#[derive(Clone, Copy, Debug, Default)] -pub struct ActiveUserCount { - pub users_in_recent_minutes: usize, - pub users_in_recent_days: usize, -} - impl LlmDatabase { pub async fn initialize_usage_measures(&mut self) -> Result<()> { let all_measures = self @@ -90,100 +46,6 @@ impl LlmDatabase { Ok(()) } - pub async fn get_application_wide_usages_by_model( - &self, - now: DateTimeUtc, - ) -> Result> { - self.transaction(|tx| async move { - let past_minute = now - Duration::minutes(1); - let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute]; - let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute]; - let input_tokens_per_minute = - self.usage_measure_ids[&UsageMeasure::InputTokensPerMinute]; - let output_tokens_per_minute = - self.usage_measure_ids[&UsageMeasure::OutputTokensPerMinute]; - - let mut results = Vec::new(); - for ((provider, model_name), model) in self.models.iter() { - let mut usages = usage::Entity::find() - .filter( - usage::Column::Timestamp - .gte(past_minute.naive_utc()) - .and(usage::Column::IsStaff.eq(false)) - .and(usage::Column::ModelId.eq(model.id)) - .and( - usage::Column::MeasureId - .eq(requests_per_minute) - .or(usage::Column::MeasureId.eq(tokens_per_minute)), - ), - ) - .stream(&*tx) - .await?; - - let mut requests_this_minute = 0; - let mut tokens_this_minute = 0; - let mut input_tokens_this_minute = 0; - let mut output_tokens_this_minute = 0; - while let Some(usage) = usages.next().await { - let usage = usage?; - if usage.measure_id == requests_per_minute { - requests_this_minute += Self::get_live_buckets( - &usage, - now.naive_utc(), - UsageMeasure::RequestsPerMinute, - ) - .0 - .iter() - .copied() - .sum::() as usize; - } else if usage.measure_id == tokens_per_minute { - tokens_this_minute += Self::get_live_buckets( - &usage, - now.naive_utc(), - UsageMeasure::TokensPerMinute, - ) - .0 - .iter() - .copied() - .sum::() as usize; - } else if usage.measure_id == input_tokens_per_minute { - input_tokens_this_minute += Self::get_live_buckets( - &usage, - now.naive_utc(), - UsageMeasure::InputTokensPerMinute, - ) - .0 - .iter() - .copied() - .sum::() as usize; - } else if usage.measure_id == output_tokens_per_minute { - output_tokens_this_minute += Self::get_live_buckets( - &usage, - now.naive_utc(), - UsageMeasure::OutputTokensPerMinute, - ) - .0 - .iter() - .copied() - .sum::() as usize; - } - } - - results.push(ApplicationWideUsage { - provider: *provider, - model: model_name.clone(), - requests_this_minute, - tokens_this_minute, - input_tokens_this_minute, - output_tokens_this_minute, - }) - } - - Ok(results) - }) - .await - } - pub async fn get_user_spending_for_month( &self, user_id: UserId, @@ -223,499 +85,6 @@ impl LlmDatabase { }) .await } - - pub async fn get_usage( - &self, - user_id: UserId, - provider: LanguageModelProvider, - model_name: &str, - now: DateTimeUtc, - ) -> Result { - self.transaction(|tx| async move { - let model = self - .models - .get(&(provider, model_name.to_string())) - .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?; - - let usages = usage::Entity::find() - .filter( - usage::Column::UserId - .eq(user_id) - .and(usage::Column::ModelId.eq(model.id)), - ) - .all(&*tx) - .await?; - - let month = now.date_naive().month() as i32; - let year = now.date_naive().year(); - let monthly_usage = monthly_usage::Entity::find() - .filter( - monthly_usage::Column::UserId - .eq(user_id) - .and(monthly_usage::Column::ModelId.eq(model.id)) - .and(monthly_usage::Column::Month.eq(month)) - .and(monthly_usage::Column::Year.eq(year)), - ) - .one(&*tx) - .await?; - let lifetime_usage = lifetime_usage::Entity::find() - .filter( - lifetime_usage::Column::UserId - .eq(user_id) - .and(lifetime_usage::Column::ModelId.eq(model.id)), - ) - .one(&*tx) - .await?; - - let requests_this_minute = - self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?; - let tokens_this_minute = - self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?; - let input_tokens_this_minute = - self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMinute)?; - let output_tokens_this_minute = - self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMinute)?; - let tokens_this_day = - self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?; - let spending_this_month = if let Some(monthly_usage) = &monthly_usage { - calculate_spending( - model, - monthly_usage.input_tokens as usize, - monthly_usage.cache_creation_input_tokens as usize, - monthly_usage.cache_read_input_tokens as usize, - monthly_usage.output_tokens as usize, - ) - } else { - Cents::ZERO - }; - let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage { - calculate_spending( - model, - lifetime_usage.input_tokens as usize, - lifetime_usage.cache_creation_input_tokens as usize, - lifetime_usage.cache_read_input_tokens as usize, - lifetime_usage.output_tokens as usize, - ) - } else { - Cents::ZERO - }; - - Ok(Usage { - requests_this_minute, - tokens_this_minute, - input_tokens_this_minute, - output_tokens_this_minute, - tokens_this_day, - tokens_this_month: TokenUsage { - input: monthly_usage - .as_ref() - .map_or(0, |usage| usage.input_tokens as usize), - input_cache_creation: monthly_usage - .as_ref() - .map_or(0, |usage| usage.cache_creation_input_tokens as usize), - input_cache_read: monthly_usage - .as_ref() - .map_or(0, |usage| usage.cache_read_input_tokens as usize), - output: monthly_usage - .as_ref() - .map_or(0, |usage| usage.output_tokens as usize), - }, - spending_this_month, - lifetime_spending, - }) - }) - .await - } - - pub async fn record_usage( - &self, - user_id: UserId, - is_staff: bool, - provider: LanguageModelProvider, - model_name: &str, - tokens: TokenUsage, - has_llm_subscription: bool, - max_monthly_spend: Cents, - free_tier_monthly_spending_limit: Cents, - now: DateTimeUtc, - ) -> Result { - self.transaction(|tx| async move { - let model = self.model(provider, model_name)?; - - let usages = usage::Entity::find() - .filter( - usage::Column::UserId - .eq(user_id) - .and(usage::Column::ModelId.eq(model.id)), - ) - .all(&*tx) - .await?; - - let requests_this_minute = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::RequestsPerMinute, - now, - 1, - &tx, - ) - .await?; - let tokens_this_minute = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::TokensPerMinute, - now, - tokens.total(), - &tx, - ) - .await?; - let input_tokens_this_minute = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::InputTokensPerMinute, - now, - // Cache read input tokens are not counted for the purposes of rate limits (but they are still billed). - tokens.input + tokens.input_cache_creation, - &tx, - ) - .await?; - let output_tokens_this_minute = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::OutputTokensPerMinute, - now, - tokens.output, - &tx, - ) - .await?; - let tokens_this_day = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::TokensPerDay, - now, - tokens.total(), - &tx, - ) - .await?; - - let month = now.date_naive().month() as i32; - let year = now.date_naive().year(); - - // Update monthly usage - let monthly_usage = monthly_usage::Entity::find() - .filter( - monthly_usage::Column::UserId - .eq(user_id) - .and(monthly_usage::Column::ModelId.eq(model.id)) - .and(monthly_usage::Column::Month.eq(month)) - .and(monthly_usage::Column::Year.eq(year)), - ) - .one(&*tx) - .await?; - - let monthly_usage = match monthly_usage { - Some(usage) => { - monthly_usage::Entity::update(monthly_usage::ActiveModel { - id: ActiveValue::unchanged(usage.id), - input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64), - cache_creation_input_tokens: ActiveValue::set( - usage.cache_creation_input_tokens + tokens.input_cache_creation as i64, - ), - cache_read_input_tokens: ActiveValue::set( - usage.cache_read_input_tokens + tokens.input_cache_read as i64, - ), - output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64), - ..Default::default() - }) - .exec(&*tx) - .await? - } - None => { - monthly_usage::ActiveModel { - user_id: ActiveValue::set(user_id), - model_id: ActiveValue::set(model.id), - month: ActiveValue::set(month), - year: ActiveValue::set(year), - input_tokens: ActiveValue::set(tokens.input as i64), - cache_creation_input_tokens: ActiveValue::set( - tokens.input_cache_creation as i64, - ), - cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64), - output_tokens: ActiveValue::set(tokens.output as i64), - ..Default::default() - } - .insert(&*tx) - .await? - } - }; - - let spending_this_month = calculate_spending( - model, - monthly_usage.input_tokens as usize, - monthly_usage.cache_creation_input_tokens as usize, - monthly_usage.cache_read_input_tokens as usize, - monthly_usage.output_tokens as usize, - ); - - if !is_staff - && spending_this_month > free_tier_monthly_spending_limit - && has_llm_subscription - && (spending_this_month - free_tier_monthly_spending_limit) <= max_monthly_spend - { - billing_event::ActiveModel { - id: ActiveValue::not_set(), - idempotency_key: ActiveValue::not_set(), - user_id: ActiveValue::set(user_id), - model_id: ActiveValue::set(model.id), - input_tokens: ActiveValue::set(tokens.input as i64), - input_cache_creation_tokens: ActiveValue::set( - tokens.input_cache_creation as i64, - ), - input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64), - output_tokens: ActiveValue::set(tokens.output as i64), - } - .insert(&*tx) - .await?; - } - - // Update lifetime usage - let lifetime_usage = lifetime_usage::Entity::find() - .filter( - lifetime_usage::Column::UserId - .eq(user_id) - .and(lifetime_usage::Column::ModelId.eq(model.id)), - ) - .one(&*tx) - .await?; - - let lifetime_usage = match lifetime_usage { - Some(usage) => { - lifetime_usage::Entity::update(lifetime_usage::ActiveModel { - id: ActiveValue::unchanged(usage.id), - input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64), - cache_creation_input_tokens: ActiveValue::set( - usage.cache_creation_input_tokens + tokens.input_cache_creation as i64, - ), - cache_read_input_tokens: ActiveValue::set( - usage.cache_read_input_tokens + tokens.input_cache_read as i64, - ), - output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64), - ..Default::default() - }) - .exec(&*tx) - .await? - } - None => { - lifetime_usage::ActiveModel { - user_id: ActiveValue::set(user_id), - model_id: ActiveValue::set(model.id), - input_tokens: ActiveValue::set(tokens.input as i64), - cache_creation_input_tokens: ActiveValue::set( - tokens.input_cache_creation as i64, - ), - cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64), - output_tokens: ActiveValue::set(tokens.output as i64), - ..Default::default() - } - .insert(&*tx) - .await? - } - }; - - let lifetime_spending = calculate_spending( - model, - lifetime_usage.input_tokens as usize, - lifetime_usage.cache_creation_input_tokens as usize, - lifetime_usage.cache_read_input_tokens as usize, - lifetime_usage.output_tokens as usize, - ); - - Ok(Usage { - requests_this_minute, - tokens_this_minute, - input_tokens_this_minute, - output_tokens_this_minute, - tokens_this_day, - tokens_this_month: TokenUsage { - input: monthly_usage.input_tokens as usize, - input_cache_creation: monthly_usage.cache_creation_input_tokens as usize, - input_cache_read: monthly_usage.cache_read_input_tokens as usize, - output: monthly_usage.output_tokens as usize, - }, - spending_this_month, - lifetime_spending, - }) - }) - .await - } - - /// Returns the active user count for the specified model. - pub async fn get_active_user_count( - &self, - provider: LanguageModelProvider, - model_name: &str, - now: DateTimeUtc, - ) -> Result { - self.transaction(|tx| async move { - let minute_since = now - Duration::minutes(5); - let day_since = now - Duration::days(5); - - let model = self - .models - .get(&(provider, model_name.to_string())) - .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?; - - let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute]; - - let users_in_recent_minutes = usage::Entity::find() - .filter( - usage::Column::ModelId - .eq(model.id) - .and(usage::Column::MeasureId.eq(tokens_per_minute)) - .and(usage::Column::Timestamp.gte(minute_since.naive_utc())) - .and(usage::Column::IsStaff.eq(false)), - ) - .select_only() - .column(usage::Column::UserId) - .group_by(usage::Column::UserId) - .count(&*tx) - .await? as usize; - - let users_in_recent_days = usage::Entity::find() - .filter( - usage::Column::ModelId - .eq(model.id) - .and(usage::Column::MeasureId.eq(tokens_per_minute)) - .and(usage::Column::Timestamp.gte(day_since.naive_utc())) - .and(usage::Column::IsStaff.eq(false)), - ) - .select_only() - .column(usage::Column::UserId) - .group_by(usage::Column::UserId) - .count(&*tx) - .await? as usize; - - Ok(ActiveUserCount { - users_in_recent_minutes, - users_in_recent_days, - }) - }) - .await - } - - async fn update_usage_for_measure( - &self, - user_id: UserId, - is_staff: bool, - model_id: ModelId, - usages: &[usage::Model], - usage_measure: UsageMeasure, - now: DateTimeUtc, - usage_to_add: usize, - tx: &DatabaseTransaction, - ) -> Result { - let now = now.naive_utc(); - let measure_id = *self - .usage_measure_ids - .get(&usage_measure) - .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?; - - let mut id = None; - let mut timestamp = now; - let mut buckets = vec![0_i64]; - - if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) { - id = Some(old_usage.id); - let (live_buckets, buckets_since) = - Self::get_live_buckets(old_usage, now, usage_measure); - if !live_buckets.is_empty() { - buckets.clear(); - buckets.extend_from_slice(live_buckets); - buckets.extend(iter::repeat(0).take(buckets_since)); - timestamp = - old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32); - } - } - - *buckets.last_mut().unwrap() += usage_to_add as i64; - let total_usage = buckets.iter().sum::() as usize; - - let mut model = usage::ActiveModel { - user_id: ActiveValue::set(user_id), - is_staff: ActiveValue::set(is_staff), - model_id: ActiveValue::set(model_id), - measure_id: ActiveValue::set(measure_id), - timestamp: ActiveValue::set(timestamp), - buckets: ActiveValue::set(buckets), - ..Default::default() - }; - - if let Some(id) = id { - model.id = ActiveValue::unchanged(id); - model.update(tx).await?; - } else { - usage::Entity::insert(model) - .exec_without_returning(tx) - .await?; - } - - Ok(total_usage) - } - - fn get_usage_for_measure( - &self, - usages: &[usage::Model], - now: DateTimeUtc, - usage_measure: UsageMeasure, - ) -> Result { - let now = now.naive_utc(); - let measure_id = *self - .usage_measure_ids - .get(&usage_measure) - .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?; - let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else { - return Ok(0); - }; - - let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure); - Ok(live_buckets.iter().sum::() as _) - } - - fn get_live_buckets( - usage: &usage::Model, - now: chrono::NaiveDateTime, - measure: UsageMeasure, - ) -> (&[i64], usize) { - let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0); - let buckets_since_usage = - seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32; - let buckets_since_usage = buckets_since_usage.ceil() as usize; - let mut live_buckets = &[] as &[i64]; - if buckets_since_usage < measure.bucket_count() { - let expired_bucket_count = - (usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count()); - live_buckets = &usage.buckets[expired_bucket_count..]; - while live_buckets.first() == Some(&0) { - live_buckets = &live_buckets[1..]; - } - } - (live_buckets, buckets_since_usage) - } } fn calculate_spending( @@ -741,32 +110,3 @@ fn calculate_spending( + output_token_cost; Cents::new(spending as u32) } - -const MINUTE_BUCKET_COUNT: usize = 12; -const DAY_BUCKET_COUNT: usize = 48; - -impl UsageMeasure { - fn bucket_count(&self) -> usize { - match self { - UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT, - UsageMeasure::TokensPerMinute - | UsageMeasure::InputTokensPerMinute - | UsageMeasure::OutputTokensPerMinute => MINUTE_BUCKET_COUNT, - UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT, - } - } - - fn total_duration(&self) -> Duration { - match self { - UsageMeasure::RequestsPerMinute => Duration::minutes(1), - UsageMeasure::TokensPerMinute - | UsageMeasure::InputTokensPerMinute - | UsageMeasure::OutputTokensPerMinute => Duration::minutes(1), - UsageMeasure::TokensPerDay => Duration::hours(24), - } - } - - fn bucket_duration(&self) -> Duration { - self.total_duration() / self.bucket_count() as i32 - } -} diff --git a/crates/collab/src/llm/db/tables.rs b/crates/collab/src/llm/db/tables.rs index 407c5c8fd0..5f2d357a87 100644 --- a/crates/collab/src/llm/db/tables.rs +++ b/crates/collab/src/llm/db/tables.rs @@ -1,8 +1,6 @@ pub mod billing_event; -pub mod lifetime_usage; pub mod model; pub mod monthly_usage; pub mod provider; -pub mod revoked_access_token; pub mod usage; pub mod usage_measure; diff --git a/crates/collab/src/llm/db/tables/lifetime_usage.rs b/crates/collab/src/llm/db/tables/lifetime_usage.rs deleted file mode 100644 index fc8354699b..0000000000 --- a/crates/collab/src/llm/db/tables/lifetime_usage.rs +++ /dev/null @@ -1,20 +0,0 @@ -use crate::{db::UserId, llm::db::ModelId}; -use sea_orm::entity::prelude::*; - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "lifetime_usages")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: i32, - pub user_id: UserId, - pub model_id: ModelId, - pub input_tokens: i64, - pub cache_creation_input_tokens: i64, - pub cache_read_input_tokens: i64, - pub output_tokens: i64, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/revoked_access_token.rs b/crates/collab/src/llm/db/tables/revoked_access_token.rs deleted file mode 100644 index 364963be88..0000000000 --- a/crates/collab/src/llm/db/tables/revoked_access_token.rs +++ /dev/null @@ -1,19 +0,0 @@ -use chrono::NaiveDateTime; -use sea_orm::entity::prelude::*; - -use crate::llm::db::RevokedAccessTokenId; - -/// A revoked access token. -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "revoked_access_tokens")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: RevokedAccessTokenId, - pub jti: String, - pub revoked_at: NaiveDateTime, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tests.rs b/crates/collab/src/llm/db/tests.rs index 59f92958c7..43a1b8b0d4 100644 --- a/crates/collab/src/llm/db/tests.rs +++ b/crates/collab/src/llm/db/tests.rs @@ -1,6 +1,4 @@ -mod billing_tests; mod provider_tests; -mod usage_tests; use gpui::BackgroundExecutor; use parking_lot::Mutex; diff --git a/crates/collab/src/llm/db/tests/billing_tests.rs b/crates/collab/src/llm/db/tests/billing_tests.rs deleted file mode 100644 index 3a95610bc2..0000000000 --- a/crates/collab/src/llm/db/tests/billing_tests.rs +++ /dev/null @@ -1,152 +0,0 @@ -use crate::{ - Cents, - db::UserId, - llm::{ - FREE_TIER_MONTHLY_SPENDING_LIMIT, - db::{LlmDatabase, TokenUsage, queries::providers::ModelParams}, - }, - test_llm_db, -}; -use chrono::{DateTime, Utc}; -use pretty_assertions::assert_eq; -use rpc::LanguageModelProvider; - -test_llm_db!( - test_billing_limit_exceeded, - test_billing_limit_exceeded_postgres -); - -async fn test_billing_limit_exceeded(db: &mut LlmDatabase) { - let provider = LanguageModelProvider::Anthropic; - let model = "fake-claude-limerick"; - const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5; - const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5; - - // Initialize the database and insert the model - db.initialize().await.unwrap(); - db.insert_models(&[ModelParams { - provider, - name: model.to_string(), - max_requests_per_minute: 5, - max_tokens_per_minute: 10_000, - max_tokens_per_day: 50_000, - price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS, - price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS, - }]) - .await - .unwrap(); - - // Set a fixed datetime for consistent testing - let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z") - .unwrap() - .with_timezone(&Utc); - - let user_id = UserId::from_proto(123); - - let max_monthly_spend = Cents::from_dollars(11); - - // Record usage that brings us close to the limit but doesn't exceed it - // Let's say we use $10.50 worth of tokens - let tokens_to_use = 210_000_000; // This will cost $10.50 at $0.05 per 1 million tokens - let usage = TokenUsage { - input: tokens_to_use, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }; - - // Verify that before we record any usage, there are 0 billing events - let billing_events = db.get_billing_events().await.unwrap(); - assert_eq!(billing_events.len(), 0); - - db.record_usage( - user_id, - false, - provider, - model, - usage, - true, - max_monthly_spend, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - // Verify the recorded usage and spending - let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - // Verify that we exceeded the free tier usage - assert_eq!(recorded_usage.spending_this_month, Cents::new(1050)); - assert!(recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT); - - // Verify that there is one `billing_event` record - let billing_events = db.get_billing_events().await.unwrap(); - assert_eq!(billing_events.len(), 1); - - let (billing_event, _model) = &billing_events[0]; - assert_eq!(billing_event.user_id, user_id); - assert_eq!(billing_event.input_tokens, tokens_to_use as i64); - assert_eq!(billing_event.input_cache_creation_tokens, 0); - assert_eq!(billing_event.input_cache_read_tokens, 0); - assert_eq!(billing_event.output_tokens, 0); - - // Record usage that puts us at $20.50 - let usage_2 = TokenUsage { - input: 200_000_000, // This will cost $10 more, pushing us from $10.50 to $20.50, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }; - db.record_usage( - user_id, - false, - provider, - model, - usage_2, - true, - max_monthly_spend, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - // Verify the updated usage and spending - let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!(updated_usage.spending_this_month, Cents::new(2050)); - - // Verify that there are now two billing events - let billing_events = db.get_billing_events().await.unwrap(); - assert_eq!(billing_events.len(), 2); - - let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $20.50 to $21.50, which is over the $11 monthly maximum limit - let usage_exceeding = TokenUsage { - input: tokens_to_exceed, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }; - - // This should still create a billing event as it's the first request that exceeds the limit - db.record_usage( - user_id, - false, - provider, - model, - usage_exceeding, - true, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - max_monthly_spend, - now, - ) - .await - .unwrap(); - // Verify the updated usage and spending - let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!(updated_usage.spending_this_month, Cents::new(2150)); - - // Verify that we never exceed the user max spending for the user - // and avoid charging them. - let billing_events = db.get_billing_events().await.unwrap(); - assert_eq!(billing_events.len(), 2); -} diff --git a/crates/collab/src/llm/db/tests/usage_tests.rs b/crates/collab/src/llm/db/tests/usage_tests.rs deleted file mode 100644 index 0a4ef7f4cf..0000000000 --- a/crates/collab/src/llm/db/tests/usage_tests.rs +++ /dev/null @@ -1,306 +0,0 @@ -use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT; -use crate::{ - Cents, - db::UserId, - llm::db::{ - LlmDatabase, TokenUsage, - queries::{providers::ModelParams, usages::Usage}, - }, - test_llm_db, -}; -use chrono::{DateTime, Duration, Utc}; -use pretty_assertions::assert_eq; -use rpc::LanguageModelProvider; - -test_llm_db!(test_tracking_usage, test_tracking_usage_postgres); - -async fn test_tracking_usage(db: &mut LlmDatabase) { - let provider = LanguageModelProvider::Anthropic; - let model = "claude-3-5-sonnet"; - - db.initialize().await.unwrap(); - db.insert_models(&[ModelParams { - provider, - name: model.to_string(), - max_requests_per_minute: 5, - max_tokens_per_minute: 10_000, - max_tokens_per_day: 50_000, - price_per_million_input_tokens: 50, - price_per_million_output_tokens: 50, - }]) - .await - .unwrap(); - - // We're using a fixed datetime to prevent flakiness based on the clock. - let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z") - .unwrap() - .with_timezone(&Utc); - let user_id = UserId::from_proto(123); - - let now = t0; - db.record_usage( - user_id, - false, - provider, - model, - TokenUsage { - input: 1000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - false, - Cents::ZERO, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - let now = t0 + Duration::seconds(10); - db.record_usage( - user_id, - false, - provider, - model, - TokenUsage { - input: 2000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - false, - Cents::ZERO, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 2, - tokens_this_minute: 3000, - input_tokens_this_minute: 3000, - output_tokens_this_minute: 0, - tokens_this_day: 3000, - tokens_this_month: TokenUsage { - input: 3000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); - - let now = t0 + Duration::seconds(60); - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 1, - tokens_this_minute: 2000, - input_tokens_this_minute: 2000, - output_tokens_this_minute: 0, - tokens_this_day: 3000, - tokens_this_month: TokenUsage { - input: 3000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); - - let now = t0 + Duration::seconds(60); - db.record_usage( - user_id, - false, - provider, - model, - TokenUsage { - input: 3000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - false, - Cents::ZERO, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 2, - tokens_this_minute: 5000, - input_tokens_this_minute: 5000, - output_tokens_this_minute: 0, - tokens_this_day: 6000, - tokens_this_month: TokenUsage { - input: 6000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); - - let t1 = t0 + Duration::hours(24); - let now = t1; - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 0, - tokens_this_minute: 0, - input_tokens_this_minute: 0, - output_tokens_this_minute: 0, - tokens_this_day: 5000, - tokens_this_month: TokenUsage { - input: 6000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); - - db.record_usage( - user_id, - false, - provider, - model, - TokenUsage { - input: 4000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - false, - Cents::ZERO, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 1, - tokens_this_minute: 4000, - input_tokens_this_minute: 4000, - output_tokens_this_minute: 0, - tokens_this_day: 9000, - tokens_this_month: TokenUsage { - input: 10000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); - - // We're using a fixed datetime to prevent flakiness based on the clock. - let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z") - .unwrap() - .with_timezone(&Utc); - - // Test cache creation input tokens - db.record_usage( - user_id, - false, - provider, - model, - TokenUsage { - input: 1000, - input_cache_creation: 500, - input_cache_read: 0, - output: 0, - }, - false, - Cents::ZERO, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 1, - tokens_this_minute: 1500, - input_tokens_this_minute: 1500, - output_tokens_this_minute: 0, - tokens_this_day: 1500, - tokens_this_month: TokenUsage { - input: 1000, - input_cache_creation: 500, - input_cache_read: 0, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); - - // Test cache read input tokens - db.record_usage( - user_id, - false, - provider, - model, - TokenUsage { - input: 1000, - input_cache_creation: 0, - input_cache_read: 300, - output: 0, - }, - false, - Cents::ZERO, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 2, - tokens_this_minute: 2800, - input_tokens_this_minute: 2500, - output_tokens_this_minute: 0, - tokens_this_day: 2800, - tokens_this_month: TokenUsage { - input: 2000, - input_cache_creation: 500, - input_cache_read: 300, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); -} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 30dab40cce..8f850ee847 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -9,14 +9,14 @@ use axum::{ use collab::api::CloudflareIpCountryHeader; use collab::api::billing::sync_llm_usage_with_stripe_periodically; -use collab::llm::{db::LlmDatabase, log_usage_periodically}; +use collab::llm::db::LlmDatabase; use collab::migrations::run_database_migrations; use collab::user_backfiller::spawn_user_backfiller; use collab::{ AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, rpc::ResultExt, }; -use collab::{ServiceMode, api::billing::poll_stripe_events_periodically, llm::LlmState}; +use collab::{ServiceMode, api::billing::poll_stripe_events_periodically}; use db::Database; use std::{ env::args, @@ -74,11 +74,10 @@ async fn main() -> Result<()> { let mode = match args.next().as_deref() { Some("collab") => ServiceMode::Collab, Some("api") => ServiceMode::Api, - Some("llm") => ServiceMode::Llm, Some("all") => ServiceMode::All, _ => { return Err(anyhow!( - "usage: collab >" + "usage: collab >" ))?; } }; @@ -97,20 +96,9 @@ async fn main() -> Result<()> { let mut on_shutdown = None; - if mode.is_llm() { - setup_llm_database(&config).await?; - - let state = LlmState::new(config.clone(), Executor::Production).await?; - - log_usage_periodically(state.clone()); - - app = app - .merge(collab::llm::routes()) - .layer(Extension(state.clone())); - } - if mode.is_collab() || mode.is_api() { setup_app_database(&config).await?; + setup_llm_database(&config).await?; let state = AppState::new(config, Executor::Production).await?; @@ -336,18 +324,11 @@ async fn handle_root(Extension(mode): Extension) -> String { format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown")) } -async fn handle_liveness_probe( - app_state: Option>>, - llm_state: Option>>, -) -> Result { +async fn handle_liveness_probe(app_state: Option>>) -> Result { if let Some(state) = app_state { state.db.get_all_users(0, 1).await?; } - if let Some(llm_state) = llm_state { - llm_state.db.list_providers().await?; - } - Ok("ok".to_string()) }