Add tracing needed for LLM rate limit dashboards (#16388)

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Author: Max Brunsfeld
Committed: 2024-08-16 14:52:31 -07:00 (via GitHub)
parent 9ef3306f55
commit 1b1070e0f7
6 changed files with 227 additions and 29 deletions


@@ -217,7 +217,7 @@ async fn perform_completion(
                 _ => request.model,
             };

-            let chunks = anthropic::stream_completion(
+            let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
                 &state.http_client,
                 anthropic::ANTHROPIC_API_URL,
                 api_key,
@@ -245,6 +245,18 @@ async fn perform_completion(
                 anthropic::AnthropicError::Other(err) => Error::Internal(err),
             })?;

+            if let Some(rate_limit_info) = rate_limit_info {
+                tracing::info!(
+                    target: "upstream rate limit",
+                    provider = params.provider.to_string(),
+                    model = model,
+                    tokens_remaining = rate_limit_info.tokens_remaining,
+                    requests_remaining = rate_limit_info.requests_remaining,
+                    requests_reset = ?rate_limit_info.requests_reset,
+                    tokens_reset = ?rate_limit_info.tokens_reset,
+                );
+            }
+
             chunks
                 .map(move |event| {
                     let chunk = event?;
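
Context for the new helper: Anthropic reports remaining quota in `anthropic-ratelimit-*` response headers. The sketch below shows roughly how `stream_completion_with_rate_limit_info` could derive the fields logged above from those headers. The struct shape is inferred from the traced fields, and the helper names are illustrative, not the `anthropic` crate's actual code; it assumes the `http` and `chrono` crates.

```rust
use chrono::{DateTime, Utc};
use http::HeaderMap;

// Hypothetical shape, inferred from the fields logged in the hunk above.
#[derive(Debug)]
pub struct RateLimitInfo {
    pub requests_remaining: Option<u64>,
    pub tokens_remaining: Option<u64>,
    pub requests_reset: Option<DateTime<Utc>>,
    pub tokens_reset: Option<DateTime<Utc>>,
}

fn header_u64(headers: &HeaderMap, name: &str) -> Option<u64> {
    headers.get(name)?.to_str().ok()?.parse().ok()
}

fn header_time(headers: &HeaderMap, name: &str) -> Option<DateTime<Utc>> {
    // Anthropic's reset headers carry an RFC 3339 timestamp.
    let raw = headers.get(name)?.to_str().ok()?;
    DateTime::parse_from_rfc3339(raw)
        .ok()
        .map(|t| t.with_timezone(&Utc))
}

pub fn rate_limit_info_from_headers(headers: &HeaderMap) -> RateLimitInfo {
    RateLimitInfo {
        requests_remaining: header_u64(headers, "anthropic-ratelimit-requests-remaining"),
        tokens_remaining: header_u64(headers, "anthropic-ratelimit-tokens-remaining"),
        requests_reset: header_time(headers, "anthropic-ratelimit-requests-reset"),
        tokens_reset: header_time(headers, "anthropic-ratelimit-tokens-reset"),
    }
}
```
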
@@ -540,33 +552,74 @@ impl<S> Drop for TokenCountingStream<S> {
             .await
             .log_err();

-            if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
-                report_llm_usage(
-                    clickhouse_client,
-                    LlmUsageEventRow {
-                        time: Utc::now().timestamp_millis(),
-                        user_id: claims.user_id as i32,
-                        is_staff: claims.is_staff,
-                        plan: match claims.plan {
-                            Plan::Free => "free".to_string(),
-                            Plan::ZedPro => "zed_pro".to_string(),
+            if let Some(usage) = usage {
+                tracing::info!(
+                    target: "user usage",
+                    user_id = claims.user_id,
+                    login = claims.github_user_login,
+                    authn.jti = claims.jti,
+                    requests_this_minute = usage.requests_this_minute,
+                    tokens_this_minute = usage.tokens_this_minute,
+                );
+
+                if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
+                    report_llm_usage(
+                        clickhouse_client,
+                        LlmUsageEventRow {
+                            time: Utc::now().timestamp_millis(),
+                            user_id: claims.user_id as i32,
+                            is_staff: claims.is_staff,
+                            plan: match claims.plan {
+                                Plan::Free => "free".to_string(),
+                                Plan::ZedPro => "zed_pro".to_string(),
+                            },
+                            model,
+                            provider: provider.to_string(),
+                            input_token_count: input_token_count as u64,
+                            output_token_count: output_token_count as u64,
+                            requests_this_minute: usage.requests_this_minute as u64,
+                            tokens_this_minute: usage.tokens_this_minute as u64,
+                            tokens_this_day: usage.tokens_this_day as u64,
+                            input_tokens_this_month: usage.input_tokens_this_month as u64,
+                            output_tokens_this_month: usage.output_tokens_this_month as u64,
+                            spending_this_month: usage.spending_this_month as u64,
+                            lifetime_spending: usage.lifetime_spending as u64,
                         },
-                        model,
-                        provider: provider.to_string(),
-                        input_token_count: input_token_count as u64,
-                        output_token_count: output_token_count as u64,
-                        requests_this_minute: usage.requests_this_minute as u64,
-                        tokens_this_minute: usage.tokens_this_minute as u64,
-                        tokens_this_day: usage.tokens_this_day as u64,
-                        input_tokens_this_month: usage.input_tokens_this_month as u64,
-                        output_tokens_this_month: usage.output_tokens_this_month as u64,
-                        spending_this_month: usage.spending_this_month as u64,
-                        lifetime_spending: usage.lifetime_spending as u64,
-                    },
-                )
-                .await
-                .log_err();
+                    )
+                    .await
+                    .log_err();
+                }
             }
         })
     }
 }
+
+pub fn log_usage_periodically(state: Arc<LlmState>) {
+    state.executor.clone().spawn_detached(async move {
+        loop {
+            state
+                .executor
+                .sleep(std::time::Duration::from_secs(30))
+                .await;
+
+            let Some(usages) = state
+                .db
+                .get_application_wide_usages_by_model(Utc::now())
+                .await
+                .log_err()
+            else {
+                continue;
+            };
+
+            for usage in usages {
+                tracing::info!(
+                    target: "computed usage",
+                    provider = usage.provider.to_string(),
+                    model = usage.model,
+                    requests_this_minute = usage.requests_this_minute,
+                    tokens_this_minute = usage.tokens_this_minute,
+                );
+            }
+        }
+    })
+}
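
One practical note: the new events use bespoke targets containing spaces ("upstream rate limit", "user usage", "computed usage"), which plain `RUST_LOG`/`EnvFilter` directives cannot easily select. A minimal sketch, not part of this commit, of how a deployment could route just these targets to a JSON sink for the dashboards, assuming `tracing-subscriber` with the `json` and `registry` features:

```rust
use tracing_subscriber::{filter::filter_fn, fmt, prelude::*};

fn init_dashboard_logging() {
    // Targets emitted by the tracing calls added in this commit.
    let dashboard_targets = ["upstream rate limit", "user usage", "computed usage"];
    tracing_subscriber::registry()
        .with(
            fmt::layer()
                .json() // one JSON object per event, convenient for dashboard ingestion
                .with_filter(filter_fn(move |metadata| {
                    dashboard_targets.contains(&metadata.target())
                })),
        )
        .init();
}
```
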


@@ -1,5 +1,6 @@
 use crate::db::UserId;
 use chrono::Duration;
+use futures::StreamExt as _;
 use rpc::LanguageModelProvider;
 use sea_orm::QuerySelect;
 use std::{iter, str::FromStr};
@@ -18,6 +19,14 @@ pub struct Usage {
     pub lifetime_spending: usize,
 }

+#[derive(Debug, PartialEq, Clone)]
+pub struct ApplicationWideUsage {
+    pub provider: LanguageModelProvider,
+    pub model: String,
+    pub requests_this_minute: usize,
+    pub tokens_this_minute: usize,
+}
+
 #[derive(Clone, Copy, Debug, Default)]
 pub struct ActiveUserCount {
     pub users_in_recent_minutes: usize,
@@ -63,6 +72,71 @@ impl LlmDatabase {
         Ok(())
     }

+    pub async fn get_application_wide_usages_by_model(
+        &self,
+        now: DateTimeUtc,
+    ) -> Result<Vec<ApplicationWideUsage>> {
+        self.transaction(|tx| async move {
+            let past_minute = now - Duration::minutes(1);
+            let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute];
+            let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
+
+            let mut results = Vec::new();
+            for (provider, model) in self.models.keys().cloned() {
+                let mut usages = usage::Entity::find()
+                    .filter(
+                        usage::Column::Timestamp
+                            .gte(past_minute.naive_utc())
+                            .and(usage::Column::IsStaff.eq(false))
+                            .and(
+                                usage::Column::MeasureId
+                                    .eq(requests_per_minute)
+                                    .or(usage::Column::MeasureId.eq(tokens_per_minute)),
+                            ),
+                    )
+                    .stream(&*tx)
+                    .await?;
+
+                let mut requests_this_minute = 0;
+                let mut tokens_this_minute = 0;
+                while let Some(usage) = usages.next().await {
+                    let usage = usage?;
+                    if usage.measure_id == requests_per_minute {
+                        requests_this_minute += Self::get_live_buckets(
+                            &usage,
+                            now.naive_utc(),
+                            UsageMeasure::RequestsPerMinute,
+                        )
+                        .0
+                        .iter()
+                        .copied()
+                        .sum::<i64>() as usize;
+                    } else if usage.measure_id == tokens_per_minute {
+                        tokens_this_minute += Self::get_live_buckets(
+                            &usage,
+                            now.naive_utc(),
+                            UsageMeasure::TokensPerMinute,
+                        )
+                        .0
+                        .iter()
+                        .copied()
+                        .sum::<i64>() as usize;
+                    }
+                }
+
+                results.push(ApplicationWideUsage {
+                    provider,
+                    model,
+                    requests_this_minute,
+                    tokens_this_minute,
+                })
+            }
+            Ok(results)
+        })
+        .await
+    }
+
     pub async fn get_usage(
         &self,
         user_id: UserId,
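
For readers outside this file: `get_live_buckets` (defined elsewhere in `usages.rs`) returns the per-bucket counters that still fall inside a measure's window, and the sums above total the live buckets for the past minute. Below is a self-contained model of that windowed sum; the names, layout, and signature are hypothetical and only illustrate the idea, not the real method.

```rust
use chrono::{Duration, NaiveDateTime};

/// Hypothetical model of bucketed usage: `buckets[i]` counts events in a
/// fixed-width time slice, newest bucket last, ending at `last_bucket_end`.
fn live_sum(
    buckets: &[i64],
    last_bucket_end: NaiveDateTime,
    bucket_width: Duration,
    window: Duration,
    now: NaiveDateTime,
) -> i64 {
    buckets
        .iter()
        .rev() // walk from the newest bucket backwards in time
        .enumerate()
        .filter(|(i, _)| {
            // A bucket is "live" if it ends inside the window (now - window, now].
            let bucket_end = last_bucket_end - bucket_width * (*i as i32);
            bucket_end > now - window
        })
        .map(|(_, count)| *count)
        .sum()
}
```
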


@@ -5,7 +5,7 @@ use axum::{
     routing::get,
     Extension, Router,
 };
-use collab::llm::db::LlmDatabase;
+use collab::llm::{db::LlmDatabase, log_usage_periodically};
 use collab::migrations::run_database_migrations;
 use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
 use collab::{
@@ -95,6 +95,8 @@ async fn main() -> Result<()> {
             let state = LlmState::new(config.clone(), Executor::Production).await?;

+            log_usage_periodically(state.clone());
+
             app = app
                 .merge(collab::llm::routes())
                 .layer(Extension(state.clone()));