collab: Remove LLM service (#28728)

This PR removes the LLM service from collab, as it has been moved to
Cloudflare.

Release Notes:

- N/A
Marshall Bowers authored 2025-04-14 19:47:14 -04:00, committed by GitHub
parent 12b012eab3
commit fc1252b0cd
17 changed files with 8 additions and 2315 deletions


@@ -117,12 +117,10 @@ jobs:
export ZED_KUBE_NAMESPACE=production
export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=10
export ZED_API_LOAD_BALANCER_SIZE_UNIT=2
-export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=2
elif [[ $GITHUB_REF_NAME = "collab-staging" ]]; then
export ZED_KUBE_NAMESPACE=staging
export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=1
export ZED_API_LOAD_BALANCER_SIZE_UNIT=1
-export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=1
else
echo "cowardly refusing to deploy from an unknown branch"
exit 1
@@ -147,9 +145,3 @@ jobs:
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
-
-export ZED_SERVICE_NAME=llm
-export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_LLM_LOAD_BALANCER_SIZE_UNIT
-envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
-kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
-echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"

Cargo.lock (generated)

@@ -2942,7 +2942,6 @@ dependencies = [
name = "collab"
version = "0.44.0"
dependencies = [
-"anthropic",
"anyhow",
"assistant",
"assistant_context_editor",


@@ -18,7 +18,6 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
test-support = ["sqlite"]

[dependencies]
-anthropic.workspace = true
anyhow.workspace = true
async-stripe.workspace = true
async-tungstenite.workspace = true


@@ -253,7 +253,6 @@ impl Config {
pub enum ServiceMode {
    Api,
    Collab,
-    Llm,
    All,
}

@@ -265,10 +264,6 @@
    pub fn is_api(&self) -> bool {
        matches!(self, Self::Api | Self::All)
    }
-
-    pub fn is_llm(&self) -> bool {
-        matches!(self, Self::Llm | Self::All)
-    }
}

pub struct AppState {

@@ -1,448 +1,10 @@
-mod authorization;
pub mod db;
mod token;

-use crate::api::CloudflareIpCountryHeader;
-use crate::api::events::SnowflakeRow;
-use crate::build_kinesis_client;
-use crate::rpc::MIN_ACCOUNT_AGE_FOR_LLM_USE;
-use crate::{Cents, Config, Error, Result, db::UserId, executor::Executor};
-use anyhow::{Context as _, anyhow};
-use authorization::authorize_access_to_language_model;
-use axum::routing::get;
-use axum::{
-    Extension, Json, Router, TypedHeader,
-    body::Body,
-    http::{self, HeaderName, HeaderValue, Request, StatusCode},
-    middleware::{self, Next},
-    response::{IntoResponse, Response},
-    routing::post,
-};
-use chrono::{DateTime, Duration, Utc};
-use collections::HashMap;
-use db::TokenUsage;
-use db::{ActiveUserCount, LlmDatabase, usage_measure::UsageMeasure};
-use futures::{Stream, StreamExt as _};
-use reqwest_client::ReqwestClient;
-use rpc::{
-    EXPIRED_LLM_TOKEN_HEADER_NAME, LanguageModelProvider, PerformCompletionParams, proto::Plan,
-};
-use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
-use serde_json::json;
-use std::{
-    pin::Pin,
-    sync::Arc,
-    task::{Context, Poll},
-};
-use strum::IntoEnumIterator;
-use tokio::sync::RwLock;
-use util::ResultExt;
+use crate::Cents;

pub use token::*;
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
pub struct LlmState {
pub config: Config,
pub executor: Executor,
pub db: Arc<LlmDatabase>,
pub http_client: ReqwestClient,
pub kinesis_client: Option<aws_sdk_kinesis::Client>,
active_user_count_by_model:
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
}
impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
let database_url = config
.llm_database_url
.as_ref()
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
let max_connections = config
.llm_database_max_connections
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
let mut db_options = db::ConnectOptions::new(database_url);
db_options.max_connections(max_connections);
let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
db.initialize().await?;
let db = Arc::new(db);
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
let http_client =
ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?;
let this = Self {
executor,
db,
http_client,
kinesis_client: if config.kinesis_access_key.is_some() {
build_kinesis_client(&config).await.log_err()
} else {
None
},
active_user_count_by_model: RwLock::new(HashMap::default()),
config,
};
Ok(Arc::new(this))
}
pub async fn get_active_user_count(
&self,
provider: LanguageModelProvider,
model: &str,
) -> Result<ActiveUserCount> {
let now = Utc::now();
{
let active_user_count_by_model = self.active_user_count_by_model.read().await;
if let Some((last_updated, count)) =
active_user_count_by_model.get(&(provider, model.to_string()))
{
if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
return Ok(*count);
}
}
}
let mut cache = self.active_user_count_by_model.write().await;
let new_count = self.db.get_active_user_count(provider, model, now).await?;
cache.insert((provider, model.to_string()), (now, new_count));
Ok(new_count)
}
}
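Note: get_active_user_count above amounts to a small TTL cache behind a reader/writer lock: the hot path takes only the read lock, and the write lock is taken just to refresh a stale or missing entry. A minimal standalone sketch of that pattern, using std's blocking RwLock in place of tokio's async one (the 30-second TTL matches ACTIVE_USER_COUNT_CACHE_DURATION; the other names are hypothetical):

    use std::collections::HashMap;
    use std::sync::RwLock;
    use std::time::{Duration, Instant};

    // Sketch of the cache pattern; the removed code keyed entries by
    // (LanguageModelProvider, String) and used tokio::sync::RwLock.
    struct TtlCache<K, V> {
        ttl: Duration,
        entries: RwLock<HashMap<K, (Instant, V)>>,
    }

    impl<K: std::hash::Hash + Eq, V: Copy> TtlCache<K, V> {
        fn new(ttl: Duration) -> Self {
            Self {
                ttl,
                entries: RwLock::new(HashMap::new()),
            }
        }

        fn get_or_refresh(&self, key: K, compute: impl FnOnce() -> V) -> V {
            // Fast path: read lock only.
            if let Some((inserted_at, value)) = self.entries.read().unwrap().get(&key) {
                if inserted_at.elapsed() < self.ttl {
                    return *value;
                }
            }
            // Slow path: recompute, then take the write lock to refresh the entry.
            let value = compute();
            self.entries.write().unwrap().insert(key, (Instant::now(), value));
            value
        }
    }

    fn main() {
        let cache = TtlCache::new(Duration::from_secs(30));
        let count = cache.get_or_refresh(("anthropic", "claude-3-5-sonnet"), || 42_usize);
        assert_eq!(count, 42);
    }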
pub fn routes() -> Router<(), Body> {
Router::new()
.route("/models", get(list_models))
.route("/completion", post(perform_completion))
.layer(middleware::from_fn(validate_api_token))
}
async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
let token = req
.headers()
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or_else(|| {
Error::http(
StatusCode::BAD_REQUEST,
"missing authorization header".to_string(),
)
})?
.strip_prefix("Bearer ")
.ok_or_else(|| {
Error::http(
StatusCode::BAD_REQUEST,
"invalid authorization header".to_string(),
)
})?;
let state = req.extensions().get::<Arc<LlmState>>().unwrap();
match LlmTokenClaims::validate(token, &state.config) {
Ok(claims) => {
if state.db.is_access_token_revoked(&claims.jti).await? {
return Err(Error::http(
StatusCode::UNAUTHORIZED,
"unauthorized".to_string(),
));
}
tracing::Span::current()
.record("user_id", claims.user_id)
.record("login", claims.github_user_login.clone())
.record("authn.jti", &claims.jti)
.record("is_staff", claims.is_staff);
req.extensions_mut().insert(claims);
Ok::<_, Error>(next.run(req).await.into_response())
}
Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
StatusCode::UNAUTHORIZED,
"unauthorized".to_string(),
[(
HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
HeaderValue::from_static("true"),
)]
.into_iter()
.collect(),
)),
Err(_err) => Err(Error::http(
StatusCode::UNAUTHORIZED,
"unauthorized".to_string(),
)),
}
}
async fn list_models(
Extension(state): Extension<Arc<LlmState>>,
Extension(claims): Extension<LlmTokenClaims>,
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
) -> Result<Json<ListModelsResponse>> {
let country_code = country_code_header.map(|header| header.to_string());
let mut accessible_models = Vec::new();
for (provider, model) in state.db.all_models() {
let authorize_result = authorize_access_to_language_model(
&state.config,
&claims,
country_code.as_deref(),
provider,
&model.name,
);
if authorize_result.is_ok() {
accessible_models.push(rpc::LanguageModel {
provider,
name: model.name,
});
}
}
Ok(Json(ListModelsResponse {
models: accessible_models,
}))
}
async fn perform_completion(
Extension(state): Extension<Arc<LlmState>>,
Extension(claims): Extension<LlmTokenClaims>,
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
let model = normalize_model_name(
state.db.model_names_for_provider(params.provider),
params.model,
);
let bypass_account_age_check = claims.has_llm_subscription || claims.bypass_account_age_check;
if !bypass_account_age_check {
if Utc::now().naive_utc() - claims.account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
Err(anyhow!("account too young"))?
}
}
authorize_access_to_language_model(
&state.config,
&claims,
country_code_header
.map(|header| header.to_string())
.as_deref(),
params.provider,
&model,
)?;
check_usage_limit(&state, params.provider, &model, &claims).await?;
let stream = match params.provider {
LanguageModelProvider::Anthropic => {
let api_key = if claims.is_staff {
state
.config
.anthropic_staff_api_key
.as_ref()
.context("no Anthropic AI staff API key configured on the server")?
} else {
state
.config
.anthropic_api_key
.as_ref()
.context("no Anthropic AI API key configured on the server")?
};
let mut request: anthropic::Request =
serde_json::from_str(params.provider_request.get())?;
// Override the model on the request with the latest version of the model that is
// known to the server.
//
// Right now, we use the version that's defined in `model.id()`, but we will likely
// want to change this code once a new version of an Anthropic model is released,
// so that users can use the new version, without having to update Zed.
request.model = match model.as_str() {
"claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
"claude-3-7-sonnet" => anthropic::Model::Claude3_7Sonnet.id().to_string(),
"claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
"claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
"claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
_ => request.model,
};
let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
&state.http_client,
anthropic::ANTHROPIC_API_URL,
api_key,
request,
)
.await
.map_err(|err| match err {
anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
Some(anthropic::ApiErrorCode::RateLimitError) => {
tracing::info!(
target: "upstream rate limit exceeded",
user_id = claims.user_id,
login = claims.github_user_login,
authn.jti = claims.jti,
is_staff = claims.is_staff,
provider = params.provider.to_string(),
model = model
);
Error::http(
StatusCode::TOO_MANY_REQUESTS,
"Upstream Anthropic rate limit exceeded.".to_string(),
)
}
Some(anthropic::ApiErrorCode::InvalidRequestError) => {
Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
}
Some(anthropic::ApiErrorCode::OverloadedError) => {
Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
}
Some(_) => {
Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
}
None => Error::Internal(anyhow!(err)),
},
anthropic::AnthropicError::Other(err) => Error::Internal(err),
})?;
if let Some(rate_limit_info) = rate_limit_info {
tracing::info!(
target: "upstream rate limit",
is_staff = claims.is_staff,
provider = params.provider.to_string(),
model = model,
tokens_remaining = rate_limit_info.tokens.as_ref().map(|limits| limits.remaining),
input_tokens_remaining = rate_limit_info.input_tokens.as_ref().map(|limits| limits.remaining),
output_tokens_remaining = rate_limit_info.output_tokens.as_ref().map(|limits| limits.remaining),
requests_remaining = rate_limit_info.requests.as_ref().map(|limits| limits.remaining),
requests_reset = ?rate_limit_info.requests.as_ref().map(|limits| limits.reset),
tokens_reset = ?rate_limit_info.tokens.as_ref().map(|limits| limits.reset),
input_tokens_reset = ?rate_limit_info.input_tokens.as_ref().map(|limits| limits.reset),
output_tokens_reset = ?rate_limit_info.output_tokens.as_ref().map(|limits| limits.reset),
);
}
chunks
.map(move |event| {
let chunk = event?;
let (
input_tokens,
output_tokens,
cache_creation_input_tokens,
cache_read_input_tokens,
) = match &chunk {
anthropic::Event::MessageStart {
message: anthropic::Response { usage, .. },
}
| anthropic::Event::MessageDelta { usage, .. } => (
usage.input_tokens.unwrap_or(0) as usize,
usage.output_tokens.unwrap_or(0) as usize,
usage.cache_creation_input_tokens.unwrap_or(0) as usize,
usage.cache_read_input_tokens.unwrap_or(0) as usize,
),
_ => (0, 0, 0, 0),
};
anyhow::Ok(CompletionChunk {
bytes: serde_json::to_vec(&chunk).unwrap(),
input_tokens,
output_tokens,
cache_creation_input_tokens,
cache_read_input_tokens,
})
})
.boxed()
}
LanguageModelProvider::OpenAi => {
let api_key = state
.config
.openai_api_key
.as_ref()
.context("no OpenAI API key configured on the server")?;
let chunks = open_ai::stream_completion(
&state.http_client,
open_ai::OPEN_AI_API_URL,
api_key,
serde_json::from_str(params.provider_request.get())?,
)
.await?;
chunks
.map(|event| {
event.map(|chunk| {
let input_tokens =
chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
let output_tokens =
chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
CompletionChunk {
bytes: serde_json::to_vec(&chunk).unwrap(),
input_tokens,
output_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
}
})
})
.boxed()
}
LanguageModelProvider::Google => {
let api_key = state
.config
.google_ai_api_key
.as_ref()
.context("no Google AI API key configured on the server")?;
let chunks = google_ai::stream_generate_content(
&state.http_client,
google_ai::API_URL,
api_key,
serde_json::from_str(params.provider_request.get())?,
)
.await?;
chunks
.map(|event| {
event.map(|chunk| {
// TODO - implement token counting for Google AI
CompletionChunk {
bytes: serde_json::to_vec(&chunk).unwrap(),
input_tokens: 0,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
}
})
})
.boxed()
}
};
Ok(Response::new(Body::wrap_stream(TokenCountingStream {
state,
claims,
provider: params.provider,
model,
tokens: TokenUsage::default(),
inner_stream: stream,
})))
}
fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
if let Some(known_model_name) = known_models
.iter()
.filter(|known_model_name| name.starts_with(known_model_name.as_str()))
.max_by_key(|known_model_name| known_model_name.len())
{
known_model_name.to_string()
} else {
name
}
}
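Note: normalize_model_name maps a request's model string to the longest known model name that prefixes it, so dated variants collapse onto the base name while unknown names pass through unchanged. A tiny self-contained illustration of that rule (the model list here is hypothetical):

    // Same longest-prefix rule as normalize_model_name above.
    fn normalize(known: &[&str], name: &str) -> String {
        known
            .iter()
            .filter(|known| name.starts_with(**known))
            .max_by_key(|known| known.len())
            .map(|known| known.to_string())
            .unwrap_or_else(|| name.to_string())
    }

    fn main() {
        let known = ["claude-3-5-sonnet", "claude-3-5-sonnet-extended"];
        // The longest matching prefix wins.
        assert_eq!(
            normalize(&known, "claude-3-5-sonnet-extended-20250101"),
            "claude-3-5-sonnet-extended"
        );
        // Unknown names are returned as-is.
        assert_eq!(normalize(&known, "gpt-4"), "gpt-4");
    }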
/// The maximum monthly spending an individual user can reach on the free tier
/// before they have to pay.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
@@ -452,330 +14,3 @@ pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
///
/// Used to prevent surprise bills.
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
async fn check_usage_limit(
state: &Arc<LlmState>,
provider: LanguageModelProvider,
model_name: &str,
claims: &LlmTokenClaims,
) -> Result<()> {
if claims.is_staff {
return Ok(());
}
let user_id = UserId::from_proto(claims.user_id);
let model = state.db.model(provider, model_name)?;
let free_tier = claims.free_tier_monthly_spending_limit();
let spending_this_month = state
.db
.get_user_spending_for_month(user_id, Utc::now())
.await?;
if spending_this_month >= free_tier {
if !claims.has_llm_subscription {
return Err(Error::http(
StatusCode::PAYMENT_REQUIRED,
"Maximum spending limit reached for this month.".to_string(),
));
}
let monthly_spend = spending_this_month.saturating_sub(free_tier);
if monthly_spend >= Cents(claims.max_monthly_spend_in_cents) {
return Err(Error::Http(
StatusCode::FORBIDDEN,
"Maximum spending limit reached for this month.".to_string(),
[(
HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
HeaderValue::from_static("true"),
)]
.into_iter()
.collect(),
));
}
}
let active_users = state.get_active_user_count(provider, model_name).await?;
let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
let users_in_recent_days = active_users.users_in_recent_days.max(1);
let per_user_max_requests_per_minute =
model.max_requests_per_minute as usize / users_in_recent_minutes;
let per_user_max_tokens_per_minute =
model.max_tokens_per_minute as usize / users_in_recent_minutes;
let per_user_max_input_tokens_per_minute =
model.max_input_tokens_per_minute as usize / users_in_recent_minutes;
let per_user_max_output_tokens_per_minute =
model.max_output_tokens_per_minute as usize / users_in_recent_minutes;
let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;
let usage = state
.db
.get_usage(user_id, provider, model_name, Utc::now())
.await?;
let checks = match (provider, model_name) {
(LanguageModelProvider::Anthropic, "claude-3-7-sonnet") => vec![
(
usage.requests_this_minute,
per_user_max_requests_per_minute,
UsageMeasure::RequestsPerMinute,
),
(
usage.input_tokens_this_minute,
per_user_max_tokens_per_minute,
UsageMeasure::InputTokensPerMinute,
),
(
usage.output_tokens_this_minute,
per_user_max_tokens_per_minute,
UsageMeasure::OutputTokensPerMinute,
),
(
usage.tokens_this_day,
per_user_max_tokens_per_day,
UsageMeasure::TokensPerDay,
),
],
_ => vec![
(
usage.requests_this_minute,
per_user_max_requests_per_minute,
UsageMeasure::RequestsPerMinute,
),
(
usage.tokens_this_minute,
per_user_max_tokens_per_minute,
UsageMeasure::TokensPerMinute,
),
(
usage.tokens_this_day,
per_user_max_tokens_per_day,
UsageMeasure::TokensPerDay,
),
],
};
for (used, limit, usage_measure) in checks {
if used > limit {
let resource = match usage_measure {
UsageMeasure::RequestsPerMinute => "requests_per_minute",
UsageMeasure::TokensPerMinute => "tokens_per_minute",
UsageMeasure::InputTokensPerMinute => "input_tokens_per_minute",
UsageMeasure::OutputTokensPerMinute => "output_tokens_per_minute",
UsageMeasure::TokensPerDay => "tokens_per_day",
};
tracing::info!(
target: "user rate limit",
user_id = claims.user_id,
login = claims.github_user_login,
authn.jti = claims.jti,
is_staff = claims.is_staff,
provider = provider.to_string(),
model = model.name,
usage_measure = resource,
requests_this_minute = usage.requests_this_minute,
tokens_this_minute = usage.tokens_this_minute,
input_tokens_this_minute = usage.input_tokens_this_minute,
output_tokens_this_minute = usage.output_tokens_this_minute,
tokens_this_day = usage.tokens_this_day,
users_in_recent_minutes = users_in_recent_minutes,
users_in_recent_days = users_in_recent_days,
max_requests_per_minute = per_user_max_requests_per_minute,
max_tokens_per_minute = per_user_max_tokens_per_minute,
max_input_tokens_per_minute = per_user_max_input_tokens_per_minute,
max_output_tokens_per_minute = per_user_max_output_tokens_per_minute,
max_tokens_per_day = per_user_max_tokens_per_day,
);
SnowflakeRow::new(
"Language Model Rate Limited",
Some(claims.metrics_id),
claims.is_staff,
claims.system_id.clone(),
json!({
"usage": usage,
"users_in_recent_minutes": users_in_recent_minutes,
"users_in_recent_days": users_in_recent_days,
"max_requests_per_minute": per_user_max_requests_per_minute,
"max_tokens_per_minute": per_user_max_tokens_per_minute,
"max_input_tokens_per_minute": per_user_max_input_tokens_per_minute,
"max_output_tokens_per_minute": per_user_max_output_tokens_per_minute,
"max_tokens_per_day": per_user_max_tokens_per_day,
"plan": match claims.plan {
Plan::Free => "free".to_string(),
Plan::ZedPro => "zed_pro".to_string(),
},
"model": model.name.clone(),
"provider": provider.to_string(),
"usage_measure": resource.to_string(),
}),
)
.write(&state.kinesis_client, &state.config.kinesis_stream)
.await
.log_err();
return Err(Error::http(
StatusCode::TOO_MANY_REQUESTS,
format!("Rate limit exceeded. Maximum {} reached.", resource),
));
}
}
Ok(())
}
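Note: the limits enforced by check_usage_limit are dynamic: each model's global per-minute and per-day budgets are divided evenly among recently active users, clamped so an idle model never divides by zero. A minimal sketch of that computation (struct and numbers hypothetical):

    // Sketch of the dynamic per-user limit computation used above: global
    // model limits are shared evenly among recently active users.
    struct Limits {
        max_requests_per_minute: usize,
        max_tokens_per_minute: usize,
    }

    fn per_user_limits(model: &Limits, users_in_recent_minutes: usize) -> Limits {
        // Clamp to 1 so an idle model doesn't divide by zero.
        let users = users_in_recent_minutes.max(1);
        Limits {
            max_requests_per_minute: model.max_requests_per_minute / users,
            max_tokens_per_minute: model.max_tokens_per_minute / users,
        }
    }

    fn main() {
        let model = Limits {
            max_requests_per_minute: 1000,
            max_tokens_per_minute: 2_000_000,
        };
        // With 50 recently active users, each user gets 1/50th of the headroom.
        let per_user = per_user_limits(&model, 50);
        assert_eq!(per_user.max_requests_per_minute, 20);
        assert_eq!(per_user.max_tokens_per_minute, 40_000);
    }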
struct CompletionChunk {
bytes: Vec<u8>,
input_tokens: usize,
output_tokens: usize,
cache_creation_input_tokens: usize,
cache_read_input_tokens: usize,
}
struct TokenCountingStream<S> {
state: Arc<LlmState>,
claims: LlmTokenClaims,
provider: LanguageModelProvider,
model: String,
tokens: TokenUsage,
inner_stream: S,
}
impl<S> Stream for TokenCountingStream<S>
where
S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
{
type Item = Result<Vec<u8>, anyhow::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match Pin::new(&mut self.inner_stream).poll_next(cx) {
Poll::Ready(Some(Ok(mut chunk))) => {
chunk.bytes.push(b'\n');
self.tokens.input += chunk.input_tokens;
self.tokens.output += chunk.output_tokens;
self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
self.tokens.input_cache_read += chunk.cache_read_input_tokens;
Poll::Ready(Some(Ok(chunk.bytes)))
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
impl<S> Drop for TokenCountingStream<S> {
fn drop(&mut self) {
let state = self.state.clone();
let claims = self.claims.clone();
let provider = self.provider;
let model = std::mem::take(&mut self.model);
let tokens = self.tokens;
self.state.executor.spawn_detached(async move {
let usage = state
.db
.record_usage(
UserId::from_proto(claims.user_id),
claims.is_staff,
provider,
&model,
tokens,
claims.has_llm_subscription,
Cents(claims.max_monthly_spend_in_cents),
claims.free_tier_monthly_spending_limit(),
Utc::now(),
)
.await
.log_err();
if let Some(usage) = usage {
tracing::info!(
target: "user usage",
user_id = claims.user_id,
login = claims.github_user_login,
authn.jti = claims.jti,
is_staff = claims.is_staff,
provider = provider.to_string(),
model = model,
requests_this_minute = usage.requests_this_minute,
tokens_this_minute = usage.tokens_this_minute,
input_tokens_this_minute = usage.input_tokens_this_minute,
output_tokens_this_minute = usage.output_tokens_this_minute,
);
let properties = json!({
"has_llm_subscription": claims.has_llm_subscription,
"max_monthly_spend_in_cents": claims.max_monthly_spend_in_cents,
"plan": match claims.plan {
Plan::Free => "free".to_string(),
Plan::ZedPro => "zed_pro".to_string(),
},
"model": model,
"provider": provider,
"usage": usage,
"tokens": tokens
});
SnowflakeRow::new(
"Language Model Used",
Some(claims.metrics_id),
claims.is_staff,
claims.system_id.clone(),
properties,
)
.write(&state.kinesis_client, &state.config.kinesis_stream)
.await
.log_err();
}
})
}
}
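Note: recording usage from Drop means it still happens when a client disconnects mid-stream, and since Drop cannot be async the work is handed off to a detached task. A trimmed-down sketch of the pattern (the removed code used its own executor rather than tokio::spawn):

    // Sketch: accumulate counters while the stream is polled, then persist
    // them from Drop by spawning a detached task.
    struct UsageRecorder {
        tokens: usize,
    }

    impl Drop for UsageRecorder {
        fn drop(&mut self) {
            let tokens = self.tokens;
            // Panics if dropped outside a tokio runtime; acceptable for a sketch.
            tokio::spawn(async move {
                // The removed code wrote a usage row to the database here.
                println!("recording {tokens} tokens");
            });
        }
    }

    #[tokio::main]
    async fn main() {
        drop(UsageRecorder { tokens: 1234 }); // triggers the detached recording task
        tokio::task::yield_now().await; // give the spawned task a chance to run
    }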
pub fn log_usage_periodically(state: Arc<LlmState>) {
state.executor.clone().spawn_detached(async move {
loop {
state
.executor
.sleep(std::time::Duration::from_secs(30))
.await;
for provider in LanguageModelProvider::iter() {
for model in state.db.model_names_for_provider(provider) {
if let Some(active_user_count) = state
.get_active_user_count(provider, &model)
.await
.log_err()
{
tracing::info!(
target: "active user counts",
provider = provider.to_string(),
model = model,
users_in_recent_minutes = active_user_count.users_in_recent_minutes,
users_in_recent_days = active_user_count.users_in_recent_days,
);
}
}
}
if let Some(usages) = state
.db
.get_application_wide_usages_by_model(Utc::now())
.await
.log_err()
{
for usage in usages {
tracing::info!(
target: "computed usage",
provider = usage.provider.to_string(),
model = usage.model,
requests_this_minute = usage.requests_this_minute,
tokens_this_minute = usage.tokens_this_minute,
input_tokens_this_minute = usage.input_tokens_this_minute,
output_tokens_this_minute = usage.output_tokens_this_minute,
);
}
}
}
})
}


@@ -1,330 +0,0 @@
use reqwest::StatusCode;
use rpc::LanguageModelProvider;
use crate::llm::LlmTokenClaims;
use crate::{Config, Error, Result};
pub fn authorize_access_to_language_model(
config: &Config,
claims: &LlmTokenClaims,
country_code: Option<&str>,
provider: LanguageModelProvider,
model: &str,
) -> Result<()> {
authorize_access_for_country(config, country_code, provider)?;
authorize_access_to_model(config, claims, provider, model)?;
Ok(())
}
fn authorize_access_to_model(
config: &Config,
claims: &LlmTokenClaims,
provider: LanguageModelProvider,
model: &str,
) -> Result<()> {
if claims.is_staff {
return Ok(());
}
if provider == LanguageModelProvider::Anthropic {
if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" {
return Ok(());
}
if claims.has_llm_closed_beta_feature_flag
&& Some(model) == config.llm_closed_beta_model_name.as_deref()
{
return Ok(());
}
}
Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to model {model:?} is not included in your plan"),
))
}
fn authorize_access_for_country(
config: &Config,
country_code: Option<&str>,
provider: LanguageModelProvider,
) -> Result<()> {
// In development we won't have the `CF-IPCountry` header, so we can't check
// the country code.
//
// This shouldn't be necessary, as anyone running in development will need to provide
// their own API credentials in order to use an LLM provider.
if config.is_development() {
return Ok(());
}
// https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
let country_code = match country_code {
// `XX` - Used for clients without country code data.
None | Some("XX") => Err(Error::http(
StatusCode::BAD_REQUEST,
"no country code".to_string(),
))?,
// `T1` - Used for clients using the Tor network.
Some("T1") => Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to {provider:?} models is not available over Tor"),
))?,
Some(country_code) => country_code,
};
let is_country_supported_by_provider = match provider {
LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
};
if !is_country_supported_by_provider {
Err(Error::http(
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
format!(
"access to {provider:?} models is not available in your region ({country_code})"
),
))?
}
Ok(())
}
#[cfg(test)]
mod tests {
use axum::response::IntoResponse;
use pretty_assertions::assert_eq;
use rpc::proto::Plan;
use super::*;
#[gpui::test]
async fn test_authorize_access_to_language_model_with_supported_country(
_cx: &mut gpui::TestAppContext,
) {
let config = Config::test();
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
is_staff: true,
..Default::default()
};
let cases = vec![
(LanguageModelProvider::Anthropic, "US"), // United States
(LanguageModelProvider::Anthropic, "GB"), // United Kingdom
(LanguageModelProvider::OpenAi, "US"), // United States
(LanguageModelProvider::OpenAi, "GB"), // United Kingdom
(LanguageModelProvider::Google, "US"), // United States
(LanguageModelProvider::Google, "GB"), // United Kingdom
];
for (provider, country_code) in cases {
authorize_access_to_language_model(
&config,
&claims,
Some(country_code),
provider,
"the-model",
)
.unwrap_or_else(|_| {
panic!("expected authorization to return Ok for {provider:?}: {country_code}")
})
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_with_unsupported_country(
_cx: &mut gpui::TestAppContext,
) {
let config = Config::test();
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
..Default::default()
};
let cases = vec![
(LanguageModelProvider::Anthropic, "AF"), // Afghanistan
(LanguageModelProvider::Anthropic, "BY"), // Belarus
(LanguageModelProvider::Anthropic, "CF"), // Central African Republic
(LanguageModelProvider::Anthropic, "CN"), // China
(LanguageModelProvider::Anthropic, "CU"), // Cuba
(LanguageModelProvider::Anthropic, "ER"), // Eritrea
(LanguageModelProvider::Anthropic, "ET"), // Ethiopia
(LanguageModelProvider::Anthropic, "IR"), // Iran
(LanguageModelProvider::Anthropic, "KP"), // North Korea
(LanguageModelProvider::Anthropic, "XK"), // Kosovo
(LanguageModelProvider::Anthropic, "LY"), // Libya
(LanguageModelProvider::Anthropic, "MM"), // Myanmar
(LanguageModelProvider::Anthropic, "RU"), // Russia
(LanguageModelProvider::Anthropic, "SO"), // Somalia
(LanguageModelProvider::Anthropic, "SS"), // South Sudan
(LanguageModelProvider::Anthropic, "SD"), // Sudan
(LanguageModelProvider::Anthropic, "SY"), // Syria
(LanguageModelProvider::Anthropic, "VE"), // Venezuela
(LanguageModelProvider::Anthropic, "YE"), // Yemen
(LanguageModelProvider::OpenAi, "KP"), // North Korea
(LanguageModelProvider::Google, "KP"), // North Korea
];
for (provider, country_code) in cases {
let error_response = authorize_access_to_language_model(
&config,
&claims,
Some(country_code),
provider,
"the-model",
)
.expect_err(&format!(
"expected authorization to return an error for {provider:?}: {country_code}"
))
.into_response();
assert_eq!(
error_response.status(),
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
);
let response_body = hyper::body::to_bytes(error_response.into_body())
.await
.unwrap()
.to_vec();
assert_eq!(
String::from_utf8(response_body).unwrap(),
format!(
"access to {provider:?} models is not available in your region ({country_code})"
)
);
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
let config = Config::test();
let claims = LlmTokenClaims {
user_id: 99,
plan: Plan::ZedPro,
..Default::default()
};
let cases = vec![
(LanguageModelProvider::Anthropic, "T1"), // Tor
(LanguageModelProvider::OpenAi, "T1"), // Tor
(LanguageModelProvider::Google, "T1"), // Tor
];
for (provider, country_code) in cases {
let error_response = authorize_access_to_language_model(
&config,
&claims,
Some(country_code),
provider,
"the-model",
)
.expect_err(&format!(
"expected authorization to return an error for {provider:?}: {country_code}"
))
.into_response();
assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
let response_body = hyper::body::to_bytes(error_response.into_body())
.await
.unwrap()
.to_vec();
assert_eq!(
String::from_utf8(response_body).unwrap(),
format!("access to {provider:?} models is not available over Tor")
);
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_based_on_plan() {
let config = Config::test();
let test_cases = vec![
// Pro plan should have access to claude-3.5-sonnet
(
Plan::ZedPro,
LanguageModelProvider::Anthropic,
"claude-3-5-sonnet",
true,
),
// Free plan should have access to claude-3.5-sonnet
(
Plan::Free,
LanguageModelProvider::Anthropic,
"claude-3-5-sonnet",
true,
),
// Pro plan should NOT have access to other Anthropic models
(
Plan::ZedPro,
LanguageModelProvider::Anthropic,
"claude-3-opus",
false,
),
];
for (plan, provider, model, expected_access) in test_cases {
let claims = LlmTokenClaims {
plan,
..Default::default()
};
let result =
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
if expected_access {
assert!(
result.is_ok(),
"Expected access to be granted for plan {:?}, provider {:?}, model {}",
plan,
provider,
model
);
} else {
let error = result.expect_err(&format!(
"Expected access to be denied for plan {:?}, provider {:?}, model {}",
plan, provider, model
));
let response = error.into_response();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
}
}
}
#[gpui::test]
async fn test_authorize_access_to_language_model_for_staff() {
let config = Config::test();
let claims = LlmTokenClaims {
is_staff: true,
..Default::default()
};
// Staff should have access to all models
let test_cases = vec![
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
(LanguageModelProvider::Anthropic, "claude-2"),
(LanguageModelProvider::Anthropic, "claude-123-agi"),
(LanguageModelProvider::OpenAi, "gpt-4"),
(LanguageModelProvider::Google, "gemini-pro"),
];
for (provider, model) in test_cases {
let result =
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
assert!(
result.is_ok(),
"Expected staff to have access to provider {:?}, model {}",
provider,
model
);
}
}
}


@@ -20,7 +20,6 @@ use std::future::Future;
use std::sync::Arc;

use anyhow::anyhow;
-pub use queries::usages::{ActiveUserCount, TokenUsage};
pub use sea_orm::ConnectOptions;
use sea_orm::prelude::*;
use sea_orm::{

@@ -2,5 +2,4 @@ use super::*;

pub mod billing_events;
pub mod providers;
-pub mod revoked_access_tokens;
pub mod usages;


@@ -1,15 +0,0 @@
use super::*;
impl LlmDatabase {
/// Returns whether the access token with the given `jti` has been revoked.
pub async fn is_access_token_revoked(&self, jti: &str) -> Result<bool> {
self.transaction(|tx| async move {
Ok(revoked_access_token::Entity::find()
.filter(revoked_access_token::Column::Jti.eq(jti))
.one(&*tx)
.await?
.is_some())
})
.await
}
}


@@ -1,56 +1,12 @@
use crate::db::UserId;
use crate::llm::Cents;
-use chrono::{Datelike, Duration};
+use chrono::Datelike;
use futures::StreamExt as _;
-use rpc::LanguageModelProvider;
-use sea_orm::QuerySelect;
-use std::{iter, str::FromStr};
+use std::str::FromStr;
use strum::IntoEnumIterator as _;

use super::*;
#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)]
pub struct TokenUsage {
pub input: usize,
pub input_cache_creation: usize,
pub input_cache_read: usize,
pub output: usize,
}
impl TokenUsage {
pub fn total(&self) -> usize {
self.input + self.input_cache_creation + self.input_cache_read + self.output
}
}
#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)]
pub struct Usage {
pub requests_this_minute: usize,
pub tokens_this_minute: usize,
pub input_tokens_this_minute: usize,
pub output_tokens_this_minute: usize,
pub tokens_this_day: usize,
pub tokens_this_month: TokenUsage,
pub spending_this_month: Cents,
pub lifetime_spending: Cents,
}
#[derive(Debug, PartialEq, Clone)]
pub struct ApplicationWideUsage {
pub provider: LanguageModelProvider,
pub model: String,
pub requests_this_minute: usize,
pub tokens_this_minute: usize,
pub input_tokens_this_minute: usize,
pub output_tokens_this_minute: usize,
}
#[derive(Clone, Copy, Debug, Default)]
pub struct ActiveUserCount {
pub users_in_recent_minutes: usize,
pub users_in_recent_days: usize,
}
impl LlmDatabase {
    pub async fn initialize_usage_measures(&mut self) -> Result<()> {
        let all_measures = self
@@ -90,100 +46,6 @@ impl LlmDatabase {
        Ok(())
    }
pub async fn get_application_wide_usages_by_model(
&self,
now: DateTimeUtc,
) -> Result<Vec<ApplicationWideUsage>> {
self.transaction(|tx| async move {
let past_minute = now - Duration::minutes(1);
let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute];
let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
let input_tokens_per_minute =
self.usage_measure_ids[&UsageMeasure::InputTokensPerMinute];
let output_tokens_per_minute =
self.usage_measure_ids[&UsageMeasure::OutputTokensPerMinute];
let mut results = Vec::new();
for ((provider, model_name), model) in self.models.iter() {
let mut usages = usage::Entity::find()
.filter(
usage::Column::Timestamp
.gte(past_minute.naive_utc())
.and(usage::Column::IsStaff.eq(false))
.and(usage::Column::ModelId.eq(model.id))
.and(
usage::Column::MeasureId
.eq(requests_per_minute)
.or(usage::Column::MeasureId.eq(tokens_per_minute)),
),
)
.stream(&*tx)
.await?;
let mut requests_this_minute = 0;
let mut tokens_this_minute = 0;
let mut input_tokens_this_minute = 0;
let mut output_tokens_this_minute = 0;
while let Some(usage) = usages.next().await {
let usage = usage?;
if usage.measure_id == requests_per_minute {
requests_this_minute += Self::get_live_buckets(
&usage,
now.naive_utc(),
UsageMeasure::RequestsPerMinute,
)
.0
.iter()
.copied()
.sum::<i64>() as usize;
} else if usage.measure_id == tokens_per_minute {
tokens_this_minute += Self::get_live_buckets(
&usage,
now.naive_utc(),
UsageMeasure::TokensPerMinute,
)
.0
.iter()
.copied()
.sum::<i64>() as usize;
} else if usage.measure_id == input_tokens_per_minute {
input_tokens_this_minute += Self::get_live_buckets(
&usage,
now.naive_utc(),
UsageMeasure::InputTokensPerMinute,
)
.0
.iter()
.copied()
.sum::<i64>() as usize;
} else if usage.measure_id == output_tokens_per_minute {
output_tokens_this_minute += Self::get_live_buckets(
&usage,
now.naive_utc(),
UsageMeasure::OutputTokensPerMinute,
)
.0
.iter()
.copied()
.sum::<i64>() as usize;
}
}
results.push(ApplicationWideUsage {
provider: *provider,
model: model_name.clone(),
requests_this_minute,
tokens_this_minute,
input_tokens_this_minute,
output_tokens_this_minute,
})
}
Ok(results)
})
.await
}
    pub async fn get_user_spending_for_month(
        &self,
        user_id: UserId,
@@ -223,499 +85,6 @@
        })
        .await
    }
pub async fn get_usage(
&self,
user_id: UserId,
provider: LanguageModelProvider,
model_name: &str,
now: DateTimeUtc,
) -> Result<Usage> {
self.transaction(|tx| async move {
let model = self
.models
.get(&(provider, model_name.to_string()))
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
let usages = usage::Entity::find()
.filter(
usage::Column::UserId
.eq(user_id)
.and(usage::Column::ModelId.eq(model.id)),
)
.all(&*tx)
.await?;
let month = now.date_naive().month() as i32;
let year = now.date_naive().year();
let monthly_usage = monthly_usage::Entity::find()
.filter(
monthly_usage::Column::UserId
.eq(user_id)
.and(monthly_usage::Column::ModelId.eq(model.id))
.and(monthly_usage::Column::Month.eq(month))
.and(monthly_usage::Column::Year.eq(year)),
)
.one(&*tx)
.await?;
let lifetime_usage = lifetime_usage::Entity::find()
.filter(
lifetime_usage::Column::UserId
.eq(user_id)
.and(lifetime_usage::Column::ModelId.eq(model.id)),
)
.one(&*tx)
.await?;
let requests_this_minute =
self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
let tokens_this_minute =
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
let input_tokens_this_minute =
self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMinute)?;
let output_tokens_this_minute =
self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMinute)?;
let tokens_this_day =
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
let spending_this_month = if let Some(monthly_usage) = &monthly_usage {
calculate_spending(
model,
monthly_usage.input_tokens as usize,
monthly_usage.cache_creation_input_tokens as usize,
monthly_usage.cache_read_input_tokens as usize,
monthly_usage.output_tokens as usize,
)
} else {
Cents::ZERO
};
let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage {
calculate_spending(
model,
lifetime_usage.input_tokens as usize,
lifetime_usage.cache_creation_input_tokens as usize,
lifetime_usage.cache_read_input_tokens as usize,
lifetime_usage.output_tokens as usize,
)
} else {
Cents::ZERO
};
Ok(Usage {
requests_this_minute,
tokens_this_minute,
input_tokens_this_minute,
output_tokens_this_minute,
tokens_this_day,
tokens_this_month: TokenUsage {
input: monthly_usage
.as_ref()
.map_or(0, |usage| usage.input_tokens as usize),
input_cache_creation: monthly_usage
.as_ref()
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
input_cache_read: monthly_usage
.as_ref()
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
output: monthly_usage
.as_ref()
.map_or(0, |usage| usage.output_tokens as usize),
},
spending_this_month,
lifetime_spending,
})
})
.await
}
pub async fn record_usage(
&self,
user_id: UserId,
is_staff: bool,
provider: LanguageModelProvider,
model_name: &str,
tokens: TokenUsage,
has_llm_subscription: bool,
max_monthly_spend: Cents,
free_tier_monthly_spending_limit: Cents,
now: DateTimeUtc,
) -> Result<Usage> {
self.transaction(|tx| async move {
let model = self.model(provider, model_name)?;
let usages = usage::Entity::find()
.filter(
usage::Column::UserId
.eq(user_id)
.and(usage::Column::ModelId.eq(model.id)),
)
.all(&*tx)
.await?;
let requests_this_minute = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::RequestsPerMinute,
now,
1,
&tx,
)
.await?;
let tokens_this_minute = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::TokensPerMinute,
now,
tokens.total(),
&tx,
)
.await?;
let input_tokens_this_minute = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::InputTokensPerMinute,
now,
// Cache read input tokens are not counted for the purposes of rate limits (but they are still billed).
tokens.input + tokens.input_cache_creation,
&tx,
)
.await?;
let output_tokens_this_minute = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::OutputTokensPerMinute,
now,
tokens.output,
&tx,
)
.await?;
let tokens_this_day = self
.update_usage_for_measure(
user_id,
is_staff,
model.id,
&usages,
UsageMeasure::TokensPerDay,
now,
tokens.total(),
&tx,
)
.await?;
let month = now.date_naive().month() as i32;
let year = now.date_naive().year();
// Update monthly usage
let monthly_usage = monthly_usage::Entity::find()
.filter(
monthly_usage::Column::UserId
.eq(user_id)
.and(monthly_usage::Column::ModelId.eq(model.id))
.and(monthly_usage::Column::Month.eq(month))
.and(monthly_usage::Column::Year.eq(year)),
)
.one(&*tx)
.await?;
let monthly_usage = match monthly_usage {
Some(usage) => {
monthly_usage::Entity::update(monthly_usage::ActiveModel {
id: ActiveValue::unchanged(usage.id),
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
),
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
..Default::default()
})
.exec(&*tx)
.await?
}
None => {
monthly_usage::ActiveModel {
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
month: ActiveValue::set(month),
year: ActiveValue::set(year),
input_tokens: ActiveValue::set(tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
..Default::default()
}
.insert(&*tx)
.await?
}
};
let spending_this_month = calculate_spending(
model,
monthly_usage.input_tokens as usize,
monthly_usage.cache_creation_input_tokens as usize,
monthly_usage.cache_read_input_tokens as usize,
monthly_usage.output_tokens as usize,
);
if !is_staff
&& spending_this_month > free_tier_monthly_spending_limit
&& has_llm_subscription
&& (spending_this_month - free_tier_monthly_spending_limit) <= max_monthly_spend
{
billing_event::ActiveModel {
id: ActiveValue::not_set(),
idempotency_key: ActiveValue::not_set(),
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
input_tokens: ActiveValue::set(tokens.input as i64),
input_cache_creation_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
),
input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
}
.insert(&*tx)
.await?;
}
// Update lifetime usage
let lifetime_usage = lifetime_usage::Entity::find()
.filter(
lifetime_usage::Column::UserId
.eq(user_id)
.and(lifetime_usage::Column::ModelId.eq(model.id)),
)
.one(&*tx)
.await?;
let lifetime_usage = match lifetime_usage {
Some(usage) => {
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
id: ActiveValue::unchanged(usage.id),
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
),
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
..Default::default()
})
.exec(&*tx)
.await?
}
None => {
lifetime_usage::ActiveModel {
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
input_tokens: ActiveValue::set(tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
),
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
..Default::default()
}
.insert(&*tx)
.await?
}
};
let lifetime_spending = calculate_spending(
model,
lifetime_usage.input_tokens as usize,
lifetime_usage.cache_creation_input_tokens as usize,
lifetime_usage.cache_read_input_tokens as usize,
lifetime_usage.output_tokens as usize,
);
Ok(Usage {
requests_this_minute,
tokens_this_minute,
input_tokens_this_minute,
output_tokens_this_minute,
tokens_this_day,
tokens_this_month: TokenUsage {
input: monthly_usage.input_tokens as usize,
input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
input_cache_read: monthly_usage.cache_read_input_tokens as usize,
output: monthly_usage.output_tokens as usize,
},
spending_this_month,
lifetime_spending,
})
})
.await
}
/// Returns the active user count for the specified model.
pub async fn get_active_user_count(
&self,
provider: LanguageModelProvider,
model_name: &str,
now: DateTimeUtc,
) -> Result<ActiveUserCount> {
self.transaction(|tx| async move {
let minute_since = now - Duration::minutes(5);
let day_since = now - Duration::days(5);
let model = self
.models
.get(&(provider, model_name.to_string()))
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
let users_in_recent_minutes = usage::Entity::find()
.filter(
usage::Column::ModelId
.eq(model.id)
.and(usage::Column::MeasureId.eq(tokens_per_minute))
.and(usage::Column::Timestamp.gte(minute_since.naive_utc()))
.and(usage::Column::IsStaff.eq(false)),
)
.select_only()
.column(usage::Column::UserId)
.group_by(usage::Column::UserId)
.count(&*tx)
.await? as usize;
let users_in_recent_days = usage::Entity::find()
.filter(
usage::Column::ModelId
.eq(model.id)
.and(usage::Column::MeasureId.eq(tokens_per_minute))
.and(usage::Column::Timestamp.gte(day_since.naive_utc()))
.and(usage::Column::IsStaff.eq(false)),
)
.select_only()
.column(usage::Column::UserId)
.group_by(usage::Column::UserId)
.count(&*tx)
.await? as usize;
Ok(ActiveUserCount {
users_in_recent_minutes,
users_in_recent_days,
})
})
.await
}
async fn update_usage_for_measure(
&self,
user_id: UserId,
is_staff: bool,
model_id: ModelId,
usages: &[usage::Model],
usage_measure: UsageMeasure,
now: DateTimeUtc,
usage_to_add: usize,
tx: &DatabaseTransaction,
) -> Result<usize> {
let now = now.naive_utc();
let measure_id = *self
.usage_measure_ids
.get(&usage_measure)
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
let mut id = None;
let mut timestamp = now;
let mut buckets = vec![0_i64];
if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) {
id = Some(old_usage.id);
let (live_buckets, buckets_since) =
Self::get_live_buckets(old_usage, now, usage_measure);
if !live_buckets.is_empty() {
buckets.clear();
buckets.extend_from_slice(live_buckets);
buckets.extend(iter::repeat(0).take(buckets_since));
timestamp =
old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32);
}
}
*buckets.last_mut().unwrap() += usage_to_add as i64;
let total_usage = buckets.iter().sum::<i64>() as usize;
let mut model = usage::ActiveModel {
user_id: ActiveValue::set(user_id),
is_staff: ActiveValue::set(is_staff),
model_id: ActiveValue::set(model_id),
measure_id: ActiveValue::set(measure_id),
timestamp: ActiveValue::set(timestamp),
buckets: ActiveValue::set(buckets),
..Default::default()
};
if let Some(id) = id {
model.id = ActiveValue::unchanged(id);
model.update(tx).await?;
} else {
usage::Entity::insert(model)
.exec_without_returning(tx)
.await?;
}
Ok(total_usage)
}
fn get_usage_for_measure(
&self,
usages: &[usage::Model],
now: DateTimeUtc,
usage_measure: UsageMeasure,
) -> Result<usize> {
let now = now.naive_utc();
let measure_id = *self
.usage_measure_ids
.get(&usage_measure)
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else {
return Ok(0);
};
let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure);
Ok(live_buckets.iter().sum::<i64>() as _)
}
fn get_live_buckets(
usage: &usage::Model,
now: chrono::NaiveDateTime,
measure: UsageMeasure,
) -> (&[i64], usize) {
let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0);
let buckets_since_usage =
seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32;
let buckets_since_usage = buckets_since_usage.ceil() as usize;
let mut live_buckets = &[] as &[i64];
if buckets_since_usage < measure.bucket_count() {
let expired_bucket_count =
(usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count());
live_buckets = &usage.buckets[expired_bucket_count..];
while live_buckets.first() == Some(&0) {
live_buckets = &live_buckets[1..];
}
}
(live_buckets, buckets_since_usage)
}
}

fn calculate_spending(
@@ -741,32 +110,3 @@
        + output_token_cost;
    Cents::new(spending as u32)
}
const MINUTE_BUCKET_COUNT: usize = 12;
const DAY_BUCKET_COUNT: usize = 48;
impl UsageMeasure {
fn bucket_count(&self) -> usize {
match self {
UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
UsageMeasure::TokensPerMinute
| UsageMeasure::InputTokensPerMinute
| UsageMeasure::OutputTokensPerMinute => MINUTE_BUCKET_COUNT,
UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
}
}
fn total_duration(&self) -> Duration {
match self {
UsageMeasure::RequestsPerMinute => Duration::minutes(1),
UsageMeasure::TokensPerMinute
| UsageMeasure::InputTokensPerMinute
| UsageMeasure::OutputTokensPerMinute => Duration::minutes(1),
UsageMeasure::TokensPerDay => Duration::hours(24),
}
}
fn bucket_duration(&self) -> Duration {
self.total_duration() / self.bucket_count() as i32
}
}
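Note: the buckets above implement a coarse sliding window: a one-minute measure is split into 12 five-second buckets, a one-day measure into 48 half-hour buckets, and buckets that rotate out of the window stop counting toward the total. A self-contained sketch of the same idea (the storage here is hypothetical; the removed code persisted its buckets in usage rows):

    use std::collections::VecDeque;

    // Usage within `window_secs` is split across fixed-width buckets, and
    // expired buckets stop contributing to the running total.
    struct BucketedWindow {
        bucket_secs: i64,
        bucket_count: usize,
        buckets: VecDeque<(i64, i64)>, // (bucket index, amount)
    }

    impl BucketedWindow {
        fn new(window_secs: i64, bucket_count: usize) -> Self {
            Self {
                bucket_secs: window_secs / bucket_count as i64,
                bucket_count,
                buckets: VecDeque::new(),
            }
        }

        fn record(&mut self, now_secs: i64, amount: i64) {
            let index = now_secs / self.bucket_secs;
            match self.buckets.back_mut() {
                Some((last, total)) if *last == index => *total += amount,
                _ => self.buckets.push_back((index, amount)),
            }
        }

        fn total(&mut self, now_secs: i64) -> i64 {
            let index = now_secs / self.bucket_secs;
            // Drop buckets that have rotated out of the window.
            while matches!(self.buckets.front(), Some((i, _)) if index - i >= self.bucket_count as i64) {
                self.buckets.pop_front();
            }
            self.buckets.iter().map(|(_, amount)| amount).sum()
        }
    }

    fn main() {
        // One-minute window split into 12 five-second buckets, as above.
        let mut tokens_per_minute = BucketedWindow::new(60, 12);
        tokens_per_minute.record(0, 1000);
        tokens_per_minute.record(10, 2000);
        assert_eq!(tokens_per_minute.total(10), 3000);
        assert_eq!(tokens_per_minute.total(61), 2000); // the first bucket expired
    }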


@@ -1,8 +1,6 @@
pub mod billing_event;
-pub mod lifetime_usage;
pub mod model;
pub mod monthly_usage;
pub mod provider;
-pub mod revoked_access_token;
pub mod usage;
pub mod usage_measure;


@@ -1,20 +0,0 @@
use crate::{db::UserId, llm::db::ModelId};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "lifetime_usages")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub user_id: UserId,
pub model_id: ModelId,
pub input_tokens: i64,
pub cache_creation_input_tokens: i64,
pub cache_read_input_tokens: i64,
pub output_tokens: i64,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}


@@ -1,19 +0,0 @@
use chrono::NaiveDateTime;
use sea_orm::entity::prelude::*;
use crate::llm::db::RevokedAccessTokenId;
/// A revoked access token.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "revoked_access_tokens")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: RevokedAccessTokenId,
pub jti: String,
pub revoked_at: NaiveDateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}


@@ -1,6 +1,4 @@
-mod billing_tests;
mod provider_tests;
-mod usage_tests;

use gpui::BackgroundExecutor;
use parking_lot::Mutex;


@@ -1,152 +0,0 @@
use crate::{
Cents,
db::UserId,
llm::{
FREE_TIER_MONTHLY_SPENDING_LIMIT,
db::{LlmDatabase, TokenUsage, queries::providers::ModelParams},
},
test_llm_db,
};
use chrono::{DateTime, Utc};
use pretty_assertions::assert_eq;
use rpc::LanguageModelProvider;
test_llm_db!(
test_billing_limit_exceeded,
test_billing_limit_exceeded_postgres
);
async fn test_billing_limit_exceeded(db: &mut LlmDatabase) {
let provider = LanguageModelProvider::Anthropic;
let model = "fake-claude-limerick";
const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5;
const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5;
// Initialize the database and insert the model
db.initialize().await.unwrap();
db.insert_models(&[ModelParams {
provider,
name: model.to_string(),
max_requests_per_minute: 5,
max_tokens_per_minute: 10_000,
max_tokens_per_day: 50_000,
price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS,
price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS,
}])
.await
.unwrap();
// Set a fixed datetime for consistent testing
let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
.unwrap()
.with_timezone(&Utc);
let user_id = UserId::from_proto(123);
let max_monthly_spend = Cents::from_dollars(11);
// Record usage that brings us close to the limit but doesn't exceed it
// Let's say we use $10.50 worth of tokens
let tokens_to_use = 210_000_000; // This will cost $10.50 at $0.05 per 1 million tokens
let usage = TokenUsage {
input: tokens_to_use,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
};
// Verify that before we record any usage, there are 0 billing events
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 0);
db.record_usage(
user_id,
false,
provider,
model,
usage,
true,
max_monthly_spend,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
// Verify the recorded usage and spending
let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
// Verify that we exceeded the free tier usage
assert_eq!(recorded_usage.spending_this_month, Cents::new(1050));
assert!(recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT);
// Verify that there is one `billing_event` record
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 1);
let (billing_event, _model) = &billing_events[0];
assert_eq!(billing_event.user_id, user_id);
assert_eq!(billing_event.input_tokens, tokens_to_use as i64);
assert_eq!(billing_event.input_cache_creation_tokens, 0);
assert_eq!(billing_event.input_cache_read_tokens, 0);
assert_eq!(billing_event.output_tokens, 0);
// Record usage that puts us at $20.50
let usage_2 = TokenUsage {
input: 200_000_000, // This will cost $10 more, pushing us from $10.50 to $20.50,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
};
db.record_usage(
user_id,
false,
provider,
model,
usage_2,
true,
max_monthly_spend,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
// Verify the updated usage and spending
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(updated_usage.spending_this_month, Cents::new(2050));
// Verify that there are now two billing events
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 2);
let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $20.50 to $21.50, which is over the $11 monthly maximum limit
let usage_exceeding = TokenUsage {
input: tokens_to_exceed,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
};
// This should not create another billing event, since this request pushes spending past the user's maximum monthly spend
db.record_usage(
user_id,
false,
provider,
model,
usage_exceeding,
true,
max_monthly_spend,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
// Verify the updated usage and spending
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(updated_usage.spending_this_month, Cents::new(2150));
// Verify that we never exceed the user's maximum monthly spend,
// and avoid charging them beyond it.
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 2);
}
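Note: the dollar figures in this test follow from prices apparently being cents per million tokens (the asserted Cents::new(1050) for 210M tokens at a price of 5 only works out that way); a few lines of arithmetic confirm the assertions:

    // Worked arithmetic behind the assertions above (cents per million tokens).
    fn cost_in_cents(tokens: u64, price_per_million_tokens: u64) -> u64 {
        tokens / 1_000_000 * price_per_million_tokens
    }

    fn main() {
        assert_eq!(cost_in_cents(210_000_000, 5), 1050); // $10.50: past the $10 free tier
        assert_eq!(cost_in_cents(200_000_000, 5), 1000); // +$10.00 => $20.50 total
        assert_eq!(cost_in_cents(20_000_000, 5), 100); // +$1.00 => $21.50, beyond the $11 cap
    }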


@@ -1,306 +0,0 @@
use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT;
use crate::{
Cents,
db::UserId,
llm::db::{
LlmDatabase, TokenUsage,
queries::{providers::ModelParams, usages::Usage},
},
test_llm_db,
};
use chrono::{DateTime, Duration, Utc};
use pretty_assertions::assert_eq;
use rpc::LanguageModelProvider;
test_llm_db!(test_tracking_usage, test_tracking_usage_postgres);
async fn test_tracking_usage(db: &mut LlmDatabase) {
let provider = LanguageModelProvider::Anthropic;
let model = "claude-3-5-sonnet";
db.initialize().await.unwrap();
db.insert_models(&[ModelParams {
provider,
name: model.to_string(),
max_requests_per_minute: 5,
max_tokens_per_minute: 10_000,
max_tokens_per_day: 50_000,
price_per_million_input_tokens: 50,
price_per_million_output_tokens: 50,
}])
.await
.unwrap();
// We're using a fixed datetime to prevent flakiness based on the clock.
let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
.unwrap()
.with_timezone(&Utc);
let user_id = UserId::from_proto(123);
let now = t0;
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let now = t0 + Duration::seconds(10);
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 2000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 2,
tokens_this_minute: 3000,
input_tokens_this_minute: 3000,
output_tokens_this_minute: 0,
tokens_this_day: 3000,
tokens_this_month: TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
let now = t0 + Duration::seconds(60);
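// At t0 + 60s, the rolling one-minute window no longer includes the request
// made at t0, so only the second request's 2000 tokens remain.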
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 1,
tokens_this_minute: 2000,
input_tokens_this_minute: 2000,
output_tokens_this_minute: 0,
tokens_this_day: 3000,
tokens_this_month: TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
let now = t0 + Duration::seconds(60);
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 2,
tokens_this_minute: 5000,
input_tokens_this_minute: 5000,
output_tokens_this_minute: 0,
tokens_this_day: 6000,
tokens_this_month: TokenUsage {
input: 6000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
let t1 = t0 + Duration::hours(24);
let now = t1;
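// 24 hours later, the rolling one-day window has likewise dropped the 1000
// tokens recorded exactly at t0, leaving 5000 of the day's 6000 tokens.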
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 0,
tokens_this_minute: 0,
input_tokens_this_minute: 0,
output_tokens_this_minute: 0,
tokens_this_day: 5000,
tokens_this_month: TokenUsage {
input: 6000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 4000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 1,
tokens_this_minute: 4000,
input_tokens_this_minute: 4000,
output_tokens_this_minute: 0,
tokens_this_day: 9000,
tokens_this_month: TokenUsage {
input: 10000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
// We're using a fixed datetime to prevent flakiness based on the clock.
let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z")
.unwrap()
.with_timezone(&Utc);
// Test cache creation input tokens
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 500,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
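// Cache-creation tokens are expected to count toward the per-minute and
// per-day totals (1000 input + 500 cache creation = 1500), while the
// monthly TokenUsage keeps them in a separate bucket.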
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 1,
tokens_this_minute: 1500,
input_tokens_this_minute: 1500,
output_tokens_this_minute: 0,
tokens_this_day: 1500,
tokens_this_month: TokenUsage {
input: 1000,
input_cache_creation: 500,
input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
// Test cache read input tokens
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 0,
input_cache_read: 300,
output: 0,
},
false,
Cents::ZERO,
FREE_TIER_MONTHLY_SPENDING_LIMIT,
now,
)
.await
.unwrap();
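// Cache reads appear to count toward tokens_this_minute (2500 + 300 = 2800)
// but not toward input_tokens_this_minute (1500 + 1000 = 2500).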
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
usage,
Usage {
requests_this_minute: 2,
tokens_this_minute: 2800,
input_tokens_this_minute: 2500,
output_tokens_this_minute: 0,
tokens_this_day: 2800,
tokens_this_month: TokenUsage {
input: 2000,
input_cache_creation: 500,
input_cache_read: 300,
output: 0,
},
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
}

View file

@ -9,14 +9,14 @@ use axum::{
 use collab::api::CloudflareIpCountryHeader;
 use collab::api::billing::sync_llm_usage_with_stripe_periodically;
-use collab::llm::{db::LlmDatabase, log_usage_periodically};
+use collab::llm::db::LlmDatabase;
 use collab::migrations::run_database_migrations;
 use collab::user_backfiller::spawn_user_backfiller;
 use collab::{
     AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
     env, executor::Executor, rpc::ResultExt,
 };
-use collab::{ServiceMode, api::billing::poll_stripe_events_periodically, llm::LlmState};
+use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
 use db::Database;
 use std::{
     env::args,
@ -74,11 +74,10 @@ async fn main() -> Result<()> {
     let mode = match args.next().as_deref() {
         Some("collab") => ServiceMode::Collab,
         Some("api") => ServiceMode::Api,
-        Some("llm") => ServiceMode::Llm,
         Some("all") => ServiceMode::All,
         _ => {
             return Err(anyhow!(
-                "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
+                "usage: collab <version | migrate | seed | serve <api|collab|all>>"
             ))?;
         }
     };
@ -97,20 +96,9 @@ async fn main() -> Result<()> {
     let mut on_shutdown = None;
-    if mode.is_llm() {
-        setup_llm_database(&config).await?;
-        let state = LlmState::new(config.clone(), Executor::Production).await?;
-        log_usage_periodically(state.clone());
-        app = app
-            .merge(collab::llm::routes())
-            .layer(Extension(state.clone()));
-    }
     if mode.is_collab() || mode.is_api() {
         setup_app_database(&config).await?;
+        setup_llm_database(&config).await?;
         let state = AppState::new(config, Executor::Production).await?;
@ -336,18 +324,11 @@ async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown")) format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown"))
} }
async fn handle_liveness_probe( async fn handle_liveness_probe(app_state: Option<Extension<Arc<AppState>>>) -> Result<String> {
app_state: Option<Extension<Arc<AppState>>>,
llm_state: Option<Extension<Arc<LlmState>>>,
) -> Result<String> {
if let Some(state) = app_state { if let Some(state) = app_state {
state.db.get_all_users(0, 1).await?; state.db.get_all_users(0, 1).await?;
} }
if let Some(llm_state) = llm_state {
llm_state.db.list_providers().await?;
}
Ok("ok".to_string()) Ok("ok".to_string())
} }