collab: Track active user counts independently for each model (#16624)

This PR fixes an issue where the active user count spanned individual
models.

We now track the active user counts on a per-model basis.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-08-21 17:19:47 -04:00 committed by GitHub
parent f85ca387a7
commit 0229d3ccac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 40 additions and 22 deletions

View file

@ -18,6 +18,7 @@ use axum::{
Extension, Json, Router, TypedHeader,
};
use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
@ -41,7 +42,8 @@ pub struct LlmState {
pub db: Arc<LlmDatabase>,
pub http_client: IsahcHttpClient,
pub clickhouse_client: Option<clickhouse::Client>,
active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
active_user_count_by_model:
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
}
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
@ -69,9 +71,6 @@ impl LlmState {
.build()
.context("failed to construct http client")?;
let initial_active_user_count =
Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
let this = Self {
executor,
db,
@ -80,25 +79,34 @@ impl LlmState {
.clickhouse_url
.as_ref()
.and_then(|_| build_clickhouse_client(&config).log_err()),
active_user_count: RwLock::new(initial_active_user_count),
active_user_count_by_model: RwLock::new(HashMap::default()),
config,
};
Ok(Arc::new(this))
}
pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
pub async fn get_active_user_count(
&self,
provider: LanguageModelProvider,
model: &str,
) -> Result<ActiveUserCount> {
let now = Utc::now();
if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
return Ok(*count);
{
let active_user_count_by_model = self.active_user_count_by_model.read().await;
if let Some((last_updated, count)) =
active_user_count_by_model.get(&(provider, model.to_string()))
{
if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
return Ok(*count);
}
}
}
let mut cache = self.active_user_count.write().await;
let new_count = self.db.get_active_user_count(now).await?;
*cache = Some((now, new_count));
let mut cache = self.active_user_count_by_model.write().await;
let new_count = self.db.get_active_user_count(provider, model, now).await?;
cache.insert((provider, model.to_string()), (now, new_count));
Ok(new_count)
}
}
@ -419,7 +427,7 @@ async fn check_usage_limit(
)
.await?;
let active_users = state.get_active_user_count().await?;
let active_users = state.get_active_user_count(provider, model_name).await?;
let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
let users_in_recent_days = active_users.users_in_recent_days.max(1);

View file

@ -343,15 +343,27 @@ impl LlmDatabase {
.await
}
pub async fn get_active_user_count(&self, now: DateTimeUtc) -> Result<ActiveUserCount> {
/// Returns the active user count for the specified model.
pub async fn get_active_user_count(
&self,
provider: LanguageModelProvider,
model_name: &str,
now: DateTimeUtc,
) -> Result<ActiveUserCount> {
self.transaction(|tx| async move {
let minute_since = now - Duration::minutes(5);
let day_since = now - Duration::days(5);
let model = self
.models
.get(&(provider, model_name.to_string()))
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
let users_in_recent_minutes = usage::Entity::find()
.filter(
usage::Column::Timestamp
.gte(minute_since.naive_utc())
usage::Column::ModelId
.eq(model.id)
.and(usage::Column::Timestamp.gte(minute_since.naive_utc()))
.and(usage::Column::IsStaff.eq(false)),
)
.select_only()
@ -362,8 +374,9 @@ impl LlmDatabase {
let users_in_recent_days = usage::Entity::find()
.filter(
usage::Column::Timestamp
.gte(day_since.naive_utc())
usage::Column::ModelId
.eq(model.id)
.and(usage::Column::Timestamp.gte(day_since.naive_utc()))
.and(usage::Column::IsStaff.eq(false)),
)
.select_only()

View file

@ -302,10 +302,7 @@ async fn handle_liveness_probe(
}
if let Some(llm_state) = llm_state {
llm_state
.db
.get_active_user_count(chrono::Utc::now())
.await?;
llm_state.db.list_providers().await?;
}
Ok("ok".to_string())