collab: Adapt rate limits based on plan (#15548)
This PR updates the rate limits to adapt based on the user's current plan. For the free plan's rate limits, I took one-tenth of the existing rate limits (which are now the Pro limits). We can adjust them as needed. Release Notes: - N/A --------- Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
7a0149f17c
commit
8c54a46202
5 changed files with 195 additions and 106 deletions
|
@ -199,6 +199,23 @@ impl Session {
|
|||
}
|
||||
}
|
||||
|
||||
/// Resolves the plan that applies to this session.
///
/// Returns `proto::Plan::ZedPro` when the session belongs to staff or when the
/// database reports an active billing subscription for the session's user;
/// returns `proto::Plan::Free` otherwise, including when the session has no
/// user id.
///
/// # Errors
///
/// Propagates any error returned by `has_active_billing_subscription`.
pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
|
||||
// Staff short-circuit to ZedPro before any database access.
if self.is_staff() {
|
||||
return Ok(proto::Plan::ZedPro);
|
||||
}
|
||||
|
||||
// Without a user id there is no subscription to look up, so the session is Free.
let Some(user_id) = self.user_id() else {
|
||||
return Ok(proto::Plan::Free);
|
||||
};
|
||||
|
||||
let db = self.db().await;
|
||||
// An active billing subscription is what distinguishes ZedPro from Free here.
if db.has_active_billing_subscription(user_id).await? {
|
||||
Ok(proto::Plan::ZedPro)
|
||||
} else {
|
||||
Ok(proto::Plan::Free)
|
||||
}
|
||||
}
|
||||
|
||||
fn dev_server_id(&self) -> Option<DevServerId> {
|
||||
match &self.principal {
|
||||
Principal::User(_) | Principal::Impersonated { .. } => None,
|
||||
|
@ -3537,15 +3554,8 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
|
|||
version.0.minor() < 139
|
||||
}
|
||||
|
||||
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let active_subscriptions = db.get_active_billing_subscriptions(user_id).await?;
|
||||
|
||||
let plan = if session.is_staff() || !active_subscriptions.is_empty() {
|
||||
proto::Plan::ZedPro
|
||||
} else {
|
||||
proto::Plan::Free
|
||||
};
|
||||
async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
|
||||
let plan = session.current_plan().await?;
|
||||
|
||||
session
|
||||
.peer
|
||||
|
@ -4532,22 +4542,41 @@ async fn acknowledge_buffer_version(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
struct CompleteWithLanguageModelRateLimit;
|
||||
struct ZedProCompleteWithLanguageModelRateLimit;
|
||||
|
||||
impl RateLimit for CompleteWithLanguageModelRateLimit {
|
||||
fn capacity() -> usize {
|
||||
impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(120) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration() -> chrono::Duration {
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name() -> &'static str {
|
||||
"complete-with-language-model"
|
||||
fn db_name(&self) -> &'static str {
|
||||
"zed-pro:complete-with-language-model"
|
||||
}
|
||||
}
|
||||
|
||||
/// Rate limit applied to language-model completion requests for sessions on
/// the Free plan.
struct FreeCompleteWithLanguageModelRateLimit;
|
||||
|
||||
impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
|
||||
/// Capacity of this limit.
///
/// Read from the `COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE`
/// environment variable on every call; falls back to `120 / 10` when the
/// variable is unset or does not parse as a `usize`.
fn capacity(&self) -> usize {
|
||||
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
// Default is one-tenth of the 120 used as the ZedPro default.
.unwrap_or(120 / 10) // Picked arbitrarily
|
||||
}
|
||||
|
||||
/// Refill period for this limit: one hour.
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
/// Name identifying this limit; the `free:` prefix keeps it distinct from
/// the `zed-pro:` variant.
// NOTE(review): presumably the key under which usage is persisted; confirm
// against the rate limiter implementation.
fn db_name(&self) -> &'static str {
|
||||
"free:complete-with-language-model"
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4562,9 +4591,14 @@ async fn complete_with_language_model(
|
|||
};
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
|
||||
proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
|
||||
};
|
||||
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
|
||||
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||
|
@ -4602,9 +4636,14 @@ async fn stream_complete_with_language_model(
|
|||
};
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
|
||||
proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
|
||||
};
|
||||
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
|
||||
match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||
|
@ -4684,9 +4723,14 @@ async fn count_language_model_tokens(
|
|||
};
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
|
||||
proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
|
||||
};
|
||||
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<CountLanguageModelTokensRateLimit>(session.user_id())
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
|
||||
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||
|
@ -4713,41 +4757,79 @@ async fn count_language_model_tokens(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
struct CountLanguageModelTokensRateLimit;
|
||||
struct ZedProCountLanguageModelTokensRateLimit;
|
||||
|
||||
impl RateLimit for CountLanguageModelTokensRateLimit {
|
||||
fn capacity() -> usize {
|
||||
impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(600) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration() -> chrono::Duration {
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name() -> &'static str {
|
||||
"count-language-model-tokens"
|
||||
fn db_name(&self) -> &'static str {
|
||||
"zed-pro:count-language-model-tokens"
|
||||
}
|
||||
}
|
||||
|
||||
struct ComputeEmbeddingsRateLimit;
|
||||
struct FreeCountLanguageModelTokensRateLimit;
|
||||
|
||||
impl RateLimit for ComputeEmbeddingsRateLimit {
|
||||
fn capacity() -> usize {
|
||||
/// Rate limit settings for token-counting requests on the Free plan.
impl RateLimit for FreeCountLanguageModelTokensRateLimit {
|
||||
/// Capacity of this limit.
///
/// Read from the `COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE`
/// environment variable on every call; falls back to `600 / 10` when the
/// variable is unset or does not parse as a `usize`.
fn capacity(&self) -> usize {
|
||||
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
// Default is one-tenth of the 600 used as the ZedPro default.
.unwrap_or(600 / 10) // Picked arbitrarily
|
||||
}
|
||||
|
||||
/// Refill period for this limit: one hour.
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
/// Name identifying this limit; the `free:` prefix keeps it distinct from
/// the `zed-pro:` variant.
// NOTE(review): presumably the key under which usage is persisted; confirm
// against the rate limiter implementation.
fn db_name(&self) -> &'static str {
|
||||
"free:count-language-model-tokens"
|
||||
}
|
||||
}
|
||||
|
||||
struct ZedProComputeEmbeddingsRateLimit;
|
||||
|
||||
impl RateLimit for ZedProComputeEmbeddingsRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(5000) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration() -> chrono::Duration {
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name() -> &'static str {
|
||||
"compute-embeddings"
|
||||
fn db_name(&self) -> &'static str {
|
||||
"zed-pro:compute-embeddings"
|
||||
}
|
||||
}
|
||||
|
||||
/// Rate limit applied to embedding computation requests for sessions on the
/// Free plan.
struct FreeComputeEmbeddingsRateLimit;
|
||||
|
||||
impl RateLimit for FreeComputeEmbeddingsRateLimit {
|
||||
/// Capacity of this limit.
///
/// Read from the `EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE` environment
/// variable on every call; falls back to `5000 / 10` when the variable is
/// unset or does not parse as a `usize`.
fn capacity(&self) -> usize {
|
||||
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
// Default is one-tenth of the 5000 used as the ZedPro default.
.unwrap_or(5000 / 10) // Picked arbitrarily
|
||||
}
|
||||
|
||||
/// Refill period for this limit: one hour.
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
/// Name identifying this limit; the `free:` prefix keeps it distinct from
/// the `zed-pro:` variant.
// NOTE(review): presumably the key under which usage is persisted; confirm
// against the rate limiter implementation.
fn db_name(&self) -> &'static str {
|
||||
"free:compute-embeddings"
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4760,9 +4842,14 @@ async fn compute_embeddings(
|
|||
let api_key = api_key.context("no OpenAI API key configured on the server")?;
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
|
||||
proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
|
||||
};
|
||||
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<ComputeEmbeddingsRateLimit>(session.user_id())
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
|
||||
let embeddings = match request.model.as_str() {
|
||||
|
@ -4834,10 +4921,10 @@ async fn authorize_access_to_language_models(session: &UserSession) -> Result<()
|
|||
let db = session.db().await;
|
||||
let flags = db.get_user_flags(session.user_id()).await?;
|
||||
if flags.iter().any(|flag| flag == "language-models") {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("permission denied"))?
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err(anyhow!("permission denied"))?
|
||||
}
|
||||
|
||||
/// Get a Supermaven API key for the user
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue