diff --git a/crates/collab/src/llm/authorization.rs b/crates/collab/src/llm/authorization.rs index 5c0295d9fb..865ca97f7a 100644 --- a/crates/collab/src/llm/authorization.rs +++ b/crates/collab/src/llm/authorization.rs @@ -26,7 +26,7 @@ fn authorize_access_to_model( } match (provider, model) { - (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3.5-sonnet") => { + (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => { Ok(()) } _ => Err(Error::http( @@ -240,14 +240,14 @@ mod tests { ( Plan::ZedPro, LanguageModelProvider::Anthropic, - "claude-3.5-sonnet", + "claude-3-5-sonnet", true, ), // Free plan should have access to claude-3.5-sonnet ( Plan::Free, LanguageModelProvider::Anthropic, - "claude-3.5-sonnet", + "claude-3-5-sonnet", true, ), // Pro plan should NOT have access to other Anthropic models @@ -303,7 +303,7 @@ mod tests { // Staff should have access to all models let test_cases = vec![ - (LanguageModelProvider::Anthropic, "claude-3.5-sonnet"), + (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"), (LanguageModelProvider::Anthropic, "claude-2"), (LanguageModelProvider::Anthropic, "claude-123-agi"), (LanguageModelProvider::OpenAi, "gpt-4"), diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index ca5e1990f4..c84f818635 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -71,7 +71,7 @@ use std::{ time::{Duration, Instant}, }; use time::OffsetDateTime; -use tokio::sync::{watch, Semaphore}; +use tokio::sync::{watch, MutexGuard, Semaphore}; use tower::ServiceBuilder; use tracing::{ field::{self}, @@ -192,7 +192,7 @@ impl Session { } } - pub async fn current_plan(&self) -> anyhow::Result { + pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result { if self.is_staff() { return Ok(proto::Plan::ZedPro); } @@ -201,7 +201,6 @@ impl Session { return Ok(proto::Plan::Free); }; - let db = self.db().await; if db.has_active_billing_subscription(user_id).await? { Ok(proto::Plan::ZedPro) } else { @@ -3500,7 +3499,7 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool { } async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> { - let plan = session.current_plan().await?; + let plan = session.current_plan(session.db().await).await?; session .peer @@ -4503,7 +4502,7 @@ async fn count_language_model_tokens( }; authorize_access_to_legacy_llm_endpoints(&session).await?; - let rate_limit: Box = match session.current_plan().await? { + let rate_limit: Box = match session.current_plan(session.db().await).await? { proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit), proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit), }; @@ -4623,7 +4622,7 @@ async fn compute_embeddings( let api_key = api_key.context("no OpenAI API key configured on the server")?; authorize_access_to_legacy_llm_endpoints(&session).await?; - let rate_limit: Box = match session.current_plan().await? { + let rate_limit: Box = match session.current_plan(session.db().await).await? { proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit), proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit), }; @@ -4940,11 +4939,10 @@ async fn get_llm_api_token( if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE { Err(anyhow!("account too young"))? } - let token = LlmTokenClaims::create( user.id, session.is_staff(), - session.current_plan().await?, + session.current_plan(db).await?, &session.app_state.config, )?; response.send(proto::GetLlmTokenResponse { token })?;