diff --git a/Cargo.lock b/Cargo.lock index 2353733dc0..bfc797d6cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3324,7 +3324,6 @@ dependencies = [ "http_client", "hyper 0.14.32", "indoc", - "jsonwebtoken", "language", "language_model", "livekit_api", @@ -3370,7 +3369,6 @@ dependencies = [ "telemetry_events", "text", "theme", - "thiserror 2.0.12", "time", "tokio", "toml 0.8.20", diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 9af95317e6..9a867f9e05 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -39,7 +39,6 @@ futures.workspace = true gpui.workspace = true hex.workspace = true http_client.workspace = true -jsonwebtoken.workspace = true livekit_api.workspace = true log.workspace = true nanoid.workspace = true @@ -65,7 +64,6 @@ subtle.workspace = true supermaven_api.workspace = true telemetry_events.workspace = true text.workspace = true -thiserror.workspace = true time.workspace = true tokio = { workspace = true, features = ["full"] } toml.workspace = true diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index de74858168..ca8e89bc6d 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,7 +1,4 @@ pub mod db; -mod token; - -pub use token::*; pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial"; diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs deleted file mode 100644 index da01c7f3be..0000000000 --- a/crates/collab/src/llm/token.rs +++ /dev/null @@ -1,146 +0,0 @@ -use crate::db::billing_subscription::SubscriptionKind; -use crate::db::{billing_customer, billing_subscription, user}; -use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG}; -use crate::{Config, db::billing_preference}; -use anyhow::{Context as _, Result}; -use chrono::{NaiveDateTime, Utc}; -use cloud_llm_client::Plan; -use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; -use serde::{Deserialize, Serialize}; -use std::time::Duration; -use thiserror::Error; -use uuid::Uuid; - -#[derive(Clone, Debug, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct LlmTokenClaims { - pub iat: u64, - pub exp: u64, - pub jti: String, - pub user_id: u64, - pub system_id: Option, - pub metrics_id: Uuid, - pub github_user_login: String, - pub account_created_at: NaiveDateTime, - pub is_staff: bool, - pub has_llm_closed_beta_feature_flag: bool, - pub bypass_account_age_check: bool, - pub use_llm_request_queue: bool, - pub plan: Plan, - pub has_extended_trial: bool, - pub subscription_period: (NaiveDateTime, NaiveDateTime), - pub enable_model_request_overages: bool, - pub model_request_overages_spend_limit_in_cents: u32, - pub can_use_web_search_tool: bool, - #[serde(default)] - pub has_overdue_invoices: bool, -} - -const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60); - -impl LlmTokenClaims { - pub fn create( - user: &user::Model, - is_staff: bool, - billing_customer: billing_customer::Model, - billing_preferences: Option, - feature_flags: &Vec, - subscription: billing_subscription::Model, - system_id: Option, - config: &Config, - ) -> Result { - let secret = config - .llm_api_secret - .as_ref() - .context("no LLM API secret")?; - - let plan = if is_staff { - Plan::ZedPro - } else { - subscription.kind.map_or(Plan::ZedFree, |kind| match kind { - SubscriptionKind::ZedFree => Plan::ZedFree, - SubscriptionKind::ZedPro => Plan::ZedPro, - SubscriptionKind::ZedProTrial => Plan::ZedProTrial, - }) - }; - let subscription_period = - billing_subscription::Model::current_period(Some(subscription), is_staff) - .map(|(start, end)| (start.naive_utc(), end.naive_utc())) - .context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?; - - let now = Utc::now(); - let claims = Self { - iat: now.timestamp() as u64, - exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64, - jti: uuid::Uuid::new_v4().to_string(), - user_id: user.id.to_proto(), - system_id, - metrics_id: user.metrics_id, - github_user_login: user.github_login.clone(), - account_created_at: user.account_created_at(), - is_staff, - has_llm_closed_beta_feature_flag: feature_flags - .iter() - .any(|flag| flag == "llm-closed-beta"), - bypass_account_age_check: feature_flags - .iter() - .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG), - can_use_web_search_tool: true, - use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"), - plan, - has_extended_trial: feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG), - subscription_period, - enable_model_request_overages: billing_preferences - .as_ref() - .map_or(false, |preferences| { - preferences.model_request_overages_enabled - }), - model_request_overages_spend_limit_in_cents: billing_preferences - .as_ref() - .map_or(0, |preferences| { - preferences.model_request_overages_spend_limit_in_cents as u32 - }), - has_overdue_invoices: billing_customer.has_overdue_invoices, - }; - - Ok(jsonwebtoken::encode( - &Header::default(), - &claims, - &EncodingKey::from_secret(secret.as_ref()), - )?) - } - - pub fn validate(token: &str, config: &Config) -> Result { - let secret = config - .llm_api_secret - .as_ref() - .context("no LLM API secret")?; - - match jsonwebtoken::decode::( - token, - &DecodingKey::from_secret(secret.as_ref()), - &Validation::default(), - ) { - Ok(token) => Ok(token.claims), - Err(e) => { - if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature { - Err(ValidateLlmTokenError::Expired) - } else { - Err(ValidateLlmTokenError::JwtError(e)) - } - } - } - } -} - -#[derive(Error, Debug)] -pub enum ValidateLlmTokenError { - #[error("access token is expired")] - Expired, - #[error("access token validation error: {0}")] - JwtError(#[from] jsonwebtoken::errors::Error), - #[error("{0}")] - Other(#[from] anyhow::Error), -} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 584970a4c6..715ff4e67d 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,14 +1,12 @@ mod connection_pool; -use crate::api::billing::find_or_create_billing_customer; use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; use crate::db::billing_subscription::SubscriptionKind; use crate::llm::db::LlmDatabase; use crate::llm::{ - AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims, + AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, MIN_ACCOUNT_AGE_FOR_LLM_USE, }; -use crate::stripe_client::StripeCustomerId; use crate::{ AppState, Error, Result, auth, db::{ @@ -218,6 +216,7 @@ struct Session { /// The GeoIP country code for the user. #[allow(unused)] geoip_country_code: Option, + #[allow(unused)] system_id: Option, _executor: Executor, } @@ -464,7 +463,6 @@ impl Server { .add_message_handler(unfollow) .add_message_handler(update_followers) .add_request_handler(get_private_user_info) - .add_request_handler(get_llm_api_token) .add_request_handler(accept_terms_of_service) .add_message_handler(acknowledge_channel_message) .add_message_handler(acknowledge_buffer_version) @@ -4251,96 +4249,6 @@ async fn accept_terms_of_service( accepted_tos_at: accepted_tos_at.timestamp() as u64, })?; - // When the user accepts the terms of service, we want to refresh their LLM - // token to grant access. - session - .peer - .send(session.connection_id, proto::RefreshLlmToken {})?; - - Ok(()) -} - -async fn get_llm_api_token( - _request: proto::GetLlmToken, - response: Response, - session: MessageContext, -) -> Result<()> { - let db = session.db().await; - - let flags = db.get_user_flags(session.user_id()).await?; - - let user_id = session.user_id(); - let user = db - .get_user_by_id(user_id) - .await? - .with_context(|| format!("user {user_id} not found"))?; - - if user.accepted_tos_at.is_none() { - Err(anyhow!("terms of service not accepted"))? - } - - let stripe_client = session - .app_state - .stripe_client - .as_ref() - .context("failed to retrieve Stripe client")?; - - let stripe_billing = session - .app_state - .stripe_billing - .as_ref() - .context("failed to retrieve Stripe billing object")?; - - let billing_customer = if let Some(billing_customer) = - db.get_billing_customer_by_user_id(user.id).await? - { - billing_customer - } else { - let customer_id = stripe_billing - .find_or_create_customer_by_email(user.email_address.as_deref()) - .await?; - - find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id) - .await? - .context("billing customer not found")? - }; - - let billing_subscription = - if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? { - billing_subscription - } else { - let stripe_customer_id = - StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); - - let stripe_subscription = stripe_billing - .subscribe_to_zed_free(stripe_customer_id) - .await?; - - db.create_billing_subscription(&db::CreateBillingSubscriptionParams { - billing_customer_id: billing_customer.id, - kind: Some(SubscriptionKind::ZedFree), - stripe_subscription_id: stripe_subscription.id.to_string(), - stripe_subscription_status: stripe_subscription.status.into(), - stripe_cancellation_reason: None, - stripe_current_period_start: Some(stripe_subscription.current_period_start), - stripe_current_period_end: Some(stripe_subscription.current_period_end), - }) - .await? - }; - - let billing_preferences = db.get_billing_preferences(user.id).await?; - - let token = LlmTokenClaims::create( - &user, - session.is_staff(), - billing_customer, - billing_preferences, - &flags, - billing_subscription, - session.system_id.clone(), - &session.app_state.config, - )?; - response.send(proto::GetLlmTokenResponse { token })?; Ok(()) } diff --git a/crates/proto/proto/ai.proto b/crates/proto/proto/ai.proto index 67c2224387..1064ed2f8d 100644 --- a/crates/proto/proto/ai.proto +++ b/crates/proto/proto/ai.proto @@ -158,14 +158,6 @@ message SynchronizeContextsResponse { repeated ContextVersion contexts = 1; } -message GetLlmToken {} - -message GetLlmTokenResponse { - string token = 1; -} - -message RefreshLlmToken {} - enum LanguageModelRole { LanguageModelUser = 0; LanguageModelAssistant = 1; diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 856a793c2f..b6c7fc3cac 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -250,10 +250,6 @@ message Envelope { AddWorktree add_worktree = 222; AddWorktreeResponse add_worktree_response = 223; - GetLlmToken get_llm_token = 235; - GetLlmTokenResponse get_llm_token_response = 236; - RefreshLlmToken refresh_llm_token = 259; - LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241; LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242; @@ -419,7 +415,9 @@ message Envelope { reserved 221; reserved 224 to 229; reserved 230 to 231; + reserved 235 to 236; reserved 246; + reserved 259; reserved 270; reserved 247 to 254; reserved 255 to 256; diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index a5dd97661f..8be9fed172 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -119,8 +119,6 @@ messages!( (GetTypeDefinitionResponse, Background), (GetImplementation, Background), (GetImplementationResponse, Background), - (GetLlmToken, Background), - (GetLlmTokenResponse, Background), (OpenUnstagedDiff, Foreground), (OpenUnstagedDiffResponse, Foreground), (OpenUncommittedDiff, Foreground), @@ -196,7 +194,6 @@ messages!( (PrepareRenameResponse, Background), (ProjectEntryResponse, Foreground), (RefreshInlayHints, Foreground), - (RefreshLlmToken, Background), (RegisterBufferWithLanguageServers, Background), (RejoinChannelBuffers, Foreground), (RejoinChannelBuffersResponse, Foreground), @@ -354,7 +351,6 @@ request_messages!( (GetDocumentHighlights, GetDocumentHighlightsResponse), (GetDocumentSymbols, GetDocumentSymbolsResponse), (GetHover, GetHoverResponse), - (GetLlmToken, GetLlmTokenResponse), (GetNotifications, GetNotificationsResponse), (GetPrivateUserInfo, GetPrivateUserInfoResponse), (GetProjectSymbols, GetProjectSymbolsResponse),