Apply rate limits in LLM service (#15997)

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
Max Brunsfeld 2024-08-08 15:46:33 -07:00 committed by GitHub
parent 2bc503771b
commit 06625bfe94
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 983 additions and 227 deletions

View file

@ -1,66 +1,115 @@
use sea_orm::sea_query::OnConflict;
use sea_orm::QueryOrder;
use super::*;
use sea_orm::QueryOrder;
use std::str::FromStr;
use strum::IntoEnumIterator as _;
/// Rate limits configured for a single language model.
///
/// The values are stored as `i32` and passed straight into the model row's
/// columns of the same names when models are inserted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelRateLimits {
    /// Maximum number of requests allowed per minute.
    pub max_requests_per_minute: i32,
    /// Maximum number of tokens allowed per minute.
    pub max_tokens_per_minute: i32,
    /// Maximum number of tokens allowed per day.
    pub max_tokens_per_day: i32,
}
impl LlmDatabase {
pub async fn initialize_providers(&self) -> Result<()> {
pub async fn initialize_providers(&mut self) -> Result<()> {
self.provider_ids = self
.transaction(|tx| async move {
let existing_providers = provider::Entity::find().all(&*tx).await?;
let mut new_providers = LanguageModelProvider::iter()
.filter(|provider| {
!existing_providers
.iter()
.any(|p| p.name == provider.to_string())
})
.map(|provider| provider::ActiveModel {
name: ActiveValue::set(provider.to_string()),
..Default::default()
})
.peekable();
if new_providers.peek().is_some() {
provider::Entity::insert_many(new_providers)
.exec(&*tx)
.await?;
}
let all_providers: HashMap<_, _> = provider::Entity::find()
.all(&*tx)
.await?
.iter()
.filter_map(|provider| {
LanguageModelProvider::from_str(&provider.name)
.ok()
.map(|p| (p, provider.id))
})
.collect();
Ok(all_providers)
})
.await?;
Ok(())
}
/// Reloads every row of the model table and caches the result in
/// `self.models`, keyed by `(provider, model name)`.
///
/// A model whose `provider_id` has no entry in `self.provider_ids` is
/// omitted from the cache.
pub async fn initialize_models(&mut self) -> Result<()> {
    let provider_ids = &self.provider_ids;
    let models_by_key = self
        .transaction(|tx| async move {
            let rows = model::Entity::find().all(&*tx).await?;
            let by_key: HashMap<_, _> = rows
                .into_iter()
                .filter_map(|row| {
                    // Reverse lookup: which provider owns this row's id?
                    let (provider, _) = provider_ids
                        .iter()
                        .find(|(_, id)| **id == row.provider_id)?;
                    let key = (*provider, row.name.clone());
                    Some((key, row))
                })
                .collect();
            Ok(by_key)
        })
        .await?;
    self.models = models_by_key;
    Ok(())
}
pub async fn insert_models(
&mut self,
models: &[(LanguageModelProvider, String, ModelRateLimits)],
) -> Result<()> {
let all_provider_ids = &self.provider_ids;
self.transaction(|tx| async move {
let providers_and_models = vec![
("anthropic", "claude-3-5-sonnet"),
("anthropic", "claude-3-opus"),
("anthropic", "claude-3-sonnet"),
("anthropic", "claude-3-haiku"),
];
for (provider_name, model_name) in providers_and_models {
let insert_provider = provider::Entity::insert(provider::ActiveModel {
name: ActiveValue::set(provider_name.to_owned()),
model::Entity::insert_many(models.into_iter().map(|(provider, name, rate_limits)| {
let provider_id = all_provider_ids[&provider];
model::ActiveModel {
provider_id: ActiveValue::set(provider_id),
name: ActiveValue::set(name.clone()),
max_requests_per_minute: ActiveValue::set(rate_limits.max_requests_per_minute),
max_tokens_per_minute: ActiveValue::set(rate_limits.max_tokens_per_minute),
max_tokens_per_day: ActiveValue::set(rate_limits.max_tokens_per_day),
..Default::default()
})
.on_conflict(
OnConflict::columns([provider::Column::Name])
.update_column(provider::Column::Name)
.to_owned(),
);
let provider = if tx.support_returning() {
insert_provider.exec_with_returning(&*tx).await?
} else {
insert_provider.exec_without_returning(&*tx).await?;
provider::Entity::find()
.filter(provider::Column::Name.eq(provider_name))
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("failed to insert provider"))?
};
model::Entity::insert(model::ActiveModel {
provider_id: ActiveValue::set(provider.id),
name: ActiveValue::set(model_name.to_owned()),
..Default::default()
})
.on_conflict(
OnConflict::columns([model::Column::ProviderId, model::Column::Name])
.update_column(model::Column::Name)
.to_owned(),
)
.exec_without_returning(&*tx)
.await?;
}
}
}))
.exec_without_returning(&*tx)
.await?;
Ok(())
})
.await
.await?;
self.initialize_models().await
}
/// Returns the list of LLM providers.
pub async fn list_providers(&self) -> Result<Vec<provider::Model>> {
pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
self.transaction(|tx| async move {
Ok(provider::Entity::find()
.order_by_asc(provider::Column::Name)
.all(&*tx)
.await?)
.await?
.into_iter()
.filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
.collect())
})
.await
}