collab: Adapt rate limits based on plan (#15548)

This PR updates the rate limits to adapt based on the user's current
plan.

For the free plan rate limits I just took one-tenth of the existing rate
limits (which are now the Pro limits). We can adjust, as needed.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
This commit is contained in:
Marshall Bowers 2024-07-31 14:27:28 -04:00 committed by GitHub
parent 7a0149f17c
commit 8c54a46202
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 195 additions and 106 deletions

View file

@ -199,6 +199,23 @@ impl Session {
}
}
pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
if self.is_staff() {
return Ok(proto::Plan::ZedPro);
}
let Some(user_id) = self.user_id() else {
return Ok(proto::Plan::Free);
};
let db = self.db().await;
if db.has_active_billing_subscription(user_id).await? {
Ok(proto::Plan::ZedPro)
} else {
Ok(proto::Plan::Free)
}
}
fn dev_server_id(&self) -> Option<DevServerId> {
match &self.principal {
Principal::User(_) | Principal::Impersonated { .. } => None,
@ -3537,15 +3554,8 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
version.0.minor() < 139
}
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
let db = session.db().await;
let active_subscriptions = db.get_active_billing_subscriptions(user_id).await?;
let plan = if session.is_staff() || !active_subscriptions.is_empty() {
proto::Plan::ZedPro
} else {
proto::Plan::Free
};
async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
let plan = session.current_plan().await?;
session
.peer
@ -4532,22 +4542,41 @@ async fn acknowledge_buffer_version(
Ok(())
}
struct CompleteWithLanguageModelRateLimit;
struct ZedProCompleteWithLanguageModelRateLimit;
impl RateLimit for CompleteWithLanguageModelRateLimit {
fn capacity() -> usize {
impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
fn capacity(&self) -> usize {
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(120) // Picked arbitrarily
}
fn refill_duration() -> chrono::Duration {
fn refill_duration(&self) -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name() -> &'static str {
"complete-with-language-model"
fn db_name(&self) -> &'static str {
"zed-pro:complete-with-language-model"
}
}
struct FreeCompleteWithLanguageModelRateLimit;
impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
fn capacity(&self) -> usize {
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(120 / 10) // Picked arbitrarily
}
fn refill_duration(&self) -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name(&self) -> &'static str {
"free:complete-with-language-model"
}
}
@ -4562,9 +4591,14 @@ async fn complete_with_language_model(
};
authorize_access_to_language_models(&session).await?;
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
};
session
.rate_limiter
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.check(&*rate_limit, session.user_id())
.await?;
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
@ -4602,9 +4636,14 @@ async fn stream_complete_with_language_model(
};
authorize_access_to_language_models(&session).await?;
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
};
session
.rate_limiter
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.check(&*rate_limit, session.user_id())
.await?;
match proto::LanguageModelProvider::from_i32(request.provider) {
@ -4684,9 +4723,14 @@ async fn count_language_model_tokens(
};
authorize_access_to_language_models(&session).await?;
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
};
session
.rate_limiter
.check::<CountLanguageModelTokensRateLimit>(session.user_id())
.check(&*rate_limit, session.user_id())
.await?;
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
@ -4713,41 +4757,79 @@ async fn count_language_model_tokens(
Ok(())
}
struct CountLanguageModelTokensRateLimit;
struct ZedProCountLanguageModelTokensRateLimit;
impl RateLimit for CountLanguageModelTokensRateLimit {
fn capacity() -> usize {
impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
fn capacity(&self) -> usize {
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600) // Picked arbitrarily
}
fn refill_duration() -> chrono::Duration {
fn refill_duration(&self) -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name() -> &'static str {
"count-language-model-tokens"
fn db_name(&self) -> &'static str {
"zed-pro:count-language-model-tokens"
}
}
struct ComputeEmbeddingsRateLimit;
struct FreeCountLanguageModelTokensRateLimit;
impl RateLimit for ComputeEmbeddingsRateLimit {
fn capacity() -> usize {
impl RateLimit for FreeCountLanguageModelTokensRateLimit {
fn capacity(&self) -> usize {
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600 / 10) // Picked arbitrarily
}
fn refill_duration(&self) -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name(&self) -> &'static str {
"free:count-language-model-tokens"
}
}
struct ZedProComputeEmbeddingsRateLimit;
impl RateLimit for ZedProComputeEmbeddingsRateLimit {
fn capacity(&self) -> usize {
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(5000) // Picked arbitrarily
}
fn refill_duration() -> chrono::Duration {
fn refill_duration(&self) -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name() -> &'static str {
"compute-embeddings"
fn db_name(&self) -> &'static str {
"zed-pro:compute-embeddings"
}
}
struct FreeComputeEmbeddingsRateLimit;
impl RateLimit for FreeComputeEmbeddingsRateLimit {
fn capacity(&self) -> usize {
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(5000 / 10) // Picked arbitrarily
}
fn refill_duration(&self) -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name(&self) -> &'static str {
"free:compute-embeddings"
}
}
@ -4760,9 +4842,14 @@ async fn compute_embeddings(
let api_key = api_key.context("no OpenAI API key configured on the server")?;
authorize_access_to_language_models(&session).await?;
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
};
session
.rate_limiter
.check::<ComputeEmbeddingsRateLimit>(session.user_id())
.check(&*rate_limit, session.user_id())
.await?;
let embeddings = match request.model.as_str() {
@ -4834,10 +4921,10 @@ async fn authorize_access_to_language_models(session: &UserSession) -> Result<()
let db = session.db().await;
let flags = db.get_user_flags(session.user_id()).await?;
if flags.iter().any(|flag| flag == "language-models") {
Ok(())
} else {
Err(anyhow!("permission denied"))?
return Ok(());
}
Err(anyhow!("permission denied"))?
}
/// Get a Supermaven API key for the user