Apply rate limits in LLM service (#15997)
Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
parent 2bc503771b
commit 06625bfe94

21 changed files with 983 additions and 227 deletions
@@ -2,24 +2,25 @@ mod authorization;
 pub mod db;
 mod token;

-use crate::api::CloudflareIpCountryHeader;
-use crate::llm::authorization::authorize_access_to_language_model;
-use crate::llm::db::LlmDatabase;
-use crate::{executor::Executor, Config, Error, Result};
+use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
 use anyhow::{anyhow, Context as _};
-use axum::TypedHeader;
+use authorization::authorize_access_to_language_model;
 use axum::{
     body::Body,
     http::{self, HeaderName, HeaderValue, Request, StatusCode},
     middleware::{self, Next},
     response::{IntoResponse, Response},
     routing::post,
-    Extension, Json, Router,
+    Extension, Json, Router, TypedHeader,
 };
+use chrono::{DateTime, Duration, Utc};
+use db::{ActiveUserCount, LlmDatabase};
 use futures::StreamExt as _;
 use http_client::IsahcHttpClient;
 use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use std::sync::Arc;
+use tokio::sync::RwLock;
+use util::ResultExt;

 pub use token::*;

@@ -28,8 +29,11 @@ pub struct LlmState {
     pub executor: Executor,
     pub db: Option<Arc<LlmDatabase>>,
     pub http_client: IsahcHttpClient,
+    active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
 }

+const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
+
 impl LlmState {
     pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
         // TODO: This is temporary until we have the LLM database stood up.
@@ -44,7 +48,8 @@ impl LlmState {

             let mut db_options = db::ConnectOptions::new(database_url);
             db_options.max_connections(max_connections);
-            let db = LlmDatabase::new(db_options, executor.clone()).await?;
+            let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
+            db.initialize().await?;

             Some(Arc::new(db))
         } else {
@@ -57,15 +62,41 @@ impl LlmState {
             .build()
             .context("failed to construct http client")?;

+        let initial_active_user_count = if let Some(db) = &db {
+            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?))
+        } else {
+            None
+        };
+
         let this = Self {
             config,
             executor,
             db,
             http_client,
+            active_user_count: RwLock::new(initial_active_user_count),
         };

         Ok(Arc::new(this))
     }

+    pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
+        let now = Utc::now();
+
+        if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
+            if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
+                return Ok(*count);
+            }
+        }
+
+        if let Some(db) = &self.db {
+            let mut cache = self.active_user_count.write().await;
+            let new_count = db.get_active_user_count(now).await?;
+            *cache = Some((now, new_count));
+            Ok(new_count)
+        } else {
+            Ok(ActiveUserCount::default())
+        }
+    }
 }

 pub fn routes() -> Router<(), Body> {
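The `get_active_user_count` method added here is a small time-to-live cache guarded by a `tokio::sync::RwLock`: readers reuse the cached count while it is fresh, and a writer refreshes it once it goes stale. The standalone sketch below shows the same read-then-refresh pattern in isolation; the `CachedCount` and `Count` types and the example values are illustrative stand-ins, not code from this commit.

use chrono::{DateTime, Duration, Utc};
use tokio::sync::RwLock;

// Simplified stand-in for the cached value; in the commit this is an
// `ActiveUserCount` loaded from the LLM database.
#[derive(Clone, Copy, Default, Debug, PartialEq)]
struct Count(usize);

struct CachedCount {
    ttl: Duration,
    cached: RwLock<Option<(DateTime<Utc>, Count)>>,
}

impl CachedCount {
    async fn get(&self, load: impl std::future::Future<Output = Count>) -> Count {
        let now = Utc::now();

        // Fast path: a shared read lock suffices while the value is fresh.
        if let Some((last_updated, count)) = self.cached.read().await.as_ref() {
            if now - *last_updated < self.ttl {
                return *count;
            }
        }

        // Slow path: take the write lock, reload, and stamp the new value.
        let mut cache = self.cached.write().await;
        let count = load.await;
        *cache = Some((now, count));
        count
    }
}

#[tokio::main]
async fn main() {
    let cache = CachedCount {
        ttl: Duration::seconds(30),
        cached: RwLock::new(None),
    };

    // The second call lands inside the 30-second window, so the value from
    // the first load is reused and the second `async` block is never polled.
    let first = cache.get(async { Count(42) }).await;
    let second = cache.get(async { Count(99) }).await;
    assert_eq!(first, Count(42));
    assert_eq!(second, Count(42));
}

Holding only the read lock on the hot path keeps concurrent completion requests from serializing on the counter.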
@@ -122,14 +153,22 @@ async fn perform_completion(
     country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
     Json(params): Json<PerformCompletionParams>,
 ) -> Result<impl IntoResponse> {
+    let model = normalize_model_name(params.provider, params.model);
+
     authorize_access_to_language_model(
         &state.config,
         &claims,
         country_code_header.map(|header| header.to_string()),
         params.provider,
-        &params.model,
+        &model,
     )?;

+    let user_id = claims.user_id as i32;
+
+    if state.db.is_some() {
+        check_usage_limit(&state, params.provider, &model, &claims).await?;
+    }
+
     match params.provider {
         LanguageModelProvider::Anthropic => {
             let api_key = state
@@ -160,9 +199,31 @@ async fn perform_completion(
             )
             .await?;

-            let stream = chunks.map(|event| {
+            let mut recorder = state.db.clone().map(|db| UsageRecorder {
+                db,
+                executor: state.executor.clone(),
+                user_id,
+                provider: params.provider,
+                model,
+                token_count: 0,
+            });
+
+            let stream = chunks.map(move |event| {
                 let mut buffer = Vec::new();
                 event.map(|chunk| {
+                    match &chunk {
+                        anthropic::Event::MessageStart {
+                            message: anthropic::Response { usage, .. },
+                        }
+                        | anthropic::Event::MessageDelta { usage, .. } => {
+                            if let Some(recorder) = &mut recorder {
+                                recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
+                                recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
+                            }
+                        }
+                        _ => {}
+                    }
+
                     buffer.clear();
                     serde_json::to_writer(&mut buffer, &chunk).unwrap();
                     buffer.push(b'\n');
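The `recorder` moved into the stream closure above tallies the input and output token counts reported by Anthropic's `MessageStart` and `MessageDelta` events; the `UsageRecorder` type defined later in this diff then writes the total to the database from its `Drop` impl, so usage is captured even if the client abandons the stream partway through. A minimal sketch of that drop-flush idea, with hypothetical names and numbers:

// Token counts are accumulated while the stream is consumed and the total
// is reported when the recorder goes out of scope.
struct TokenTally {
    model: String,
    token_count: usize,
}

impl Drop for TokenTally {
    fn drop(&mut self) {
        // In the commit this spawns a detached task that calls
        // `LlmDatabase::record_usage`; printing stands in for that here.
        println!("{}: {} tokens used", self.model, self.token_count);
    }
}

fn main() {
    let mut recorder = TokenTally {
        model: "claude-3-5-sonnet".into(),
        token_count: 0,
    };

    // Pretend these are the token counts reported by streamed events.
    for tokens in [17, 256, 64] {
        recorder.token_count += tokens;
    }

    // `recorder` is dropped here and the total (337 tokens) is flushed,
    // whether or not the loop above ran to completion.
}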
@@ -259,3 +320,102 @@ async fn perform_completion(
         }
     }
 }
+
+fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
+    match provider {
+        LanguageModelProvider::Anthropic => {
+            for prefix in &[
+                "claude-3-5-sonnet",
+                "claude-3-haiku",
+                "claude-3-opus",
+                "claude-3-sonnet",
+            ] {
+                if name.starts_with(prefix) {
+                    return prefix.to_string();
+                }
+            }
+        }
+        LanguageModelProvider::OpenAi => {}
+        LanguageModelProvider::Google => {}
+        LanguageModelProvider::Zed => {}
+    }
+
+    name
+}
+
+async fn check_usage_limit(
+    state: &Arc<LlmState>,
+    provider: LanguageModelProvider,
+    model_name: &str,
+    claims: &LlmTokenClaims,
+) -> Result<()> {
+    let db = state
+        .db
+        .as_ref()
+        .ok_or_else(|| anyhow!("LLM database not configured"))?;
+    let model = db.model(provider, model_name)?;
+    let usage = db
+        .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
+        .await?;
+
+    let active_users = state.get_active_user_count().await?;
+
+    let per_user_max_requests_per_minute =
+        model.max_requests_per_minute as usize / active_users.users_in_recent_minutes.max(1);
+    let per_user_max_tokens_per_minute =
+        model.max_tokens_per_minute as usize / active_users.users_in_recent_minutes.max(1);
+    let per_user_max_tokens_per_day =
+        model.max_tokens_per_day as usize / active_users.users_in_recent_days.max(1);
+
+    let checks = [
+        (
+            usage.requests_this_minute,
+            per_user_max_requests_per_minute,
+            "requests per minute",
+        ),
+        (
+            usage.tokens_this_minute,
+            per_user_max_tokens_per_minute,
+            "tokens per minute",
+        ),
+        (
+            usage.tokens_this_day,
+            per_user_max_tokens_per_day,
+            "tokens per day",
+        ),
+    ];
+
+    for (usage, limit, resource) in checks {
+        if usage > limit {
+            return Err(Error::http(
+                StatusCode::TOO_MANY_REQUESTS,
+                format!("Rate limit exceeded. Maximum {} reached.", resource),
+            ));
+        }
+    }
+
+    Ok(())
+}
+struct UsageRecorder {
+    db: Arc<LlmDatabase>,
+    executor: Executor,
+    user_id: i32,
+    provider: LanguageModelProvider,
+    model: String,
+    token_count: usize,
+}
+
+impl Drop for UsageRecorder {
+    fn drop(&mut self) {
+        let db = self.db.clone();
+        let user_id = self.user_id;
+        let provider = self.provider;
+        let model = std::mem::take(&mut self.model);
+        let token_count = self.token_count;
+        self.executor.spawn_detached(async move {
+            db.record_usage(user_id, provider, &model, token_count, Utc::now())
+                .await
+                .log_err();
+        })
+    }
+}
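To make the arithmetic in `check_usage_limit` concrete: each model-wide limit is split evenly across the users active in the relevant window, and `max(1)` prevents a division by zero when nobody has been active. The limits and user counts in the sketch below are made-up example values, not figures from this commit.

fn per_user_limit(model_limit: usize, active_users: usize) -> usize {
    // One user's share of a model-wide limit, guarding against zero users.
    model_limit / active_users.max(1)
}

fn main() {
    // Hypothetical model-wide limits and activity counts.
    let max_requests_per_minute = 1_000;
    let max_tokens_per_day = 10_000_000;
    let users_in_recent_minutes = 50;
    let users_in_recent_days = 400;

    // 1_000 requests/min shared by 50 recently active users -> 20 each.
    assert_eq!(per_user_limit(max_requests_per_minute, users_in_recent_minutes), 20);

    // 10_000_000 tokens/day shared by 400 daily users -> 25_000 each.
    assert_eq!(per_user_limit(max_tokens_per_day, users_in_recent_days), 25_000);

    // With no recent activity, the sole caller gets the whole limit.
    assert_eq!(per_user_limit(max_requests_per_minute, 0), 1_000);
}

Because the divisor tracks recent activity, the per-user ceilings loosen automatically during quiet periods and tighten as more users compete for the same model.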