Remove RPC messages pertaining to the LLM token (#36252)
This PR removes the RPC messages pertaining to the LLM token. We now retrieve the LLM token from Cloud. Release Notes: - N/A
This commit is contained in:
parent
257e0991d8
commit
75b832029a
8 changed files with 4 additions and 263 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -3324,7 +3324,6 @@ dependencies = [
|
|||
"http_client",
|
||||
"hyper 0.14.32",
|
||||
"indoc",
|
||||
"jsonwebtoken",
|
||||
"language",
|
||||
"language_model",
|
||||
"livekit_api",
|
||||
|
@ -3370,7 +3369,6 @@ dependencies = [
|
|||
"telemetry_events",
|
||||
"text",
|
||||
"theme",
|
||||
"thiserror 2.0.12",
|
||||
"time",
|
||||
"tokio",
|
||||
"toml 0.8.20",
|
||||
|
|
|
@ -39,7 +39,6 @@ futures.workspace = true
|
|||
gpui.workspace = true
|
||||
hex.workspace = true
|
||||
http_client.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
livekit_api.workspace = true
|
||||
log.workspace = true
|
||||
nanoid.workspace = true
|
||||
|
@ -65,7 +64,6 @@ subtle.workspace = true
|
|||
supermaven_api.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
toml.workspace = true
|
||||
|
|
|
@ -1,7 +1,4 @@
|
|||
pub mod db;
|
||||
mod token;
|
||||
|
||||
pub use token::*;
|
||||
|
||||
pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial";
|
||||
|
||||
|
|
|
@ -1,146 +0,0 @@
|
|||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::db::{billing_customer, billing_subscription, user};
|
||||
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG};
|
||||
use crate::{Config, db::billing_preference};
|
||||
use anyhow::{Context as _, Result};
|
||||
use chrono::{NaiveDateTime, Utc};
|
||||
use cloud_llm_client::Plan;
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct LlmTokenClaims {
|
||||
pub iat: u64,
|
||||
pub exp: u64,
|
||||
pub jti: String,
|
||||
pub user_id: u64,
|
||||
pub system_id: Option<String>,
|
||||
pub metrics_id: Uuid,
|
||||
pub github_user_login: String,
|
||||
pub account_created_at: NaiveDateTime,
|
||||
pub is_staff: bool,
|
||||
pub has_llm_closed_beta_feature_flag: bool,
|
||||
pub bypass_account_age_check: bool,
|
||||
pub use_llm_request_queue: bool,
|
||||
pub plan: Plan,
|
||||
pub has_extended_trial: bool,
|
||||
pub subscription_period: (NaiveDateTime, NaiveDateTime),
|
||||
pub enable_model_request_overages: bool,
|
||||
pub model_request_overages_spend_limit_in_cents: u32,
|
||||
pub can_use_web_search_tool: bool,
|
||||
#[serde(default)]
|
||||
pub has_overdue_invoices: bool,
|
||||
}
|
||||
|
||||
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
||||
|
||||
impl LlmTokenClaims {
|
||||
pub fn create(
|
||||
user: &user::Model,
|
||||
is_staff: bool,
|
||||
billing_customer: billing_customer::Model,
|
||||
billing_preferences: Option<billing_preference::Model>,
|
||||
feature_flags: &Vec<String>,
|
||||
subscription: billing_subscription::Model,
|
||||
system_id: Option<String>,
|
||||
config: &Config,
|
||||
) -> Result<String> {
|
||||
let secret = config
|
||||
.llm_api_secret
|
||||
.as_ref()
|
||||
.context("no LLM API secret")?;
|
||||
|
||||
let plan = if is_staff {
|
||||
Plan::ZedPro
|
||||
} else {
|
||||
subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
|
||||
SubscriptionKind::ZedFree => Plan::ZedFree,
|
||||
SubscriptionKind::ZedPro => Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
|
||||
})
|
||||
};
|
||||
let subscription_period =
|
||||
billing_subscription::Model::current_period(Some(subscription), is_staff)
|
||||
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
|
||||
.context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?;
|
||||
|
||||
let now = Utc::now();
|
||||
let claims = Self {
|
||||
iat: now.timestamp() as u64,
|
||||
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
|
||||
jti: uuid::Uuid::new_v4().to_string(),
|
||||
user_id: user.id.to_proto(),
|
||||
system_id,
|
||||
metrics_id: user.metrics_id,
|
||||
github_user_login: user.github_login.clone(),
|
||||
account_created_at: user.account_created_at(),
|
||||
is_staff,
|
||||
has_llm_closed_beta_feature_flag: feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == "llm-closed-beta"),
|
||||
bypass_account_age_check: feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG),
|
||||
can_use_web_search_tool: true,
|
||||
use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
|
||||
plan,
|
||||
has_extended_trial: feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
|
||||
subscription_period,
|
||||
enable_model_request_overages: billing_preferences
|
||||
.as_ref()
|
||||
.map_or(false, |preferences| {
|
||||
preferences.model_request_overages_enabled
|
||||
}),
|
||||
model_request_overages_spend_limit_in_cents: billing_preferences
|
||||
.as_ref()
|
||||
.map_or(0, |preferences| {
|
||||
preferences.model_request_overages_spend_limit_in_cents as u32
|
||||
}),
|
||||
has_overdue_invoices: billing_customer.has_overdue_invoices,
|
||||
};
|
||||
|
||||
Ok(jsonwebtoken::encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(secret.as_ref()),
|
||||
)?)
|
||||
}
|
||||
|
||||
pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
|
||||
let secret = config
|
||||
.llm_api_secret
|
||||
.as_ref()
|
||||
.context("no LLM API secret")?;
|
||||
|
||||
match jsonwebtoken::decode::<Self>(
|
||||
token,
|
||||
&DecodingKey::from_secret(secret.as_ref()),
|
||||
&Validation::default(),
|
||||
) {
|
||||
Ok(token) => Ok(token.claims),
|
||||
Err(e) => {
|
||||
if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
|
||||
Err(ValidateLlmTokenError::Expired)
|
||||
} else {
|
||||
Err(ValidateLlmTokenError::JwtError(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ValidateLlmTokenError {
|
||||
#[error("access token is expired")]
|
||||
Expired,
|
||||
#[error("access token validation error: {0}")]
|
||||
JwtError(#[from] jsonwebtoken::errors::Error),
|
||||
#[error("{0}")]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
|
@ -1,14 +1,12 @@
|
|||
mod connection_pool;
|
||||
|
||||
use crate::api::billing::find_or_create_billing_customer;
|
||||
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::llm::{
|
||||
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
|
||||
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG,
|
||||
MIN_ACCOUNT_AGE_FOR_LLM_USE,
|
||||
};
|
||||
use crate::stripe_client::StripeCustomerId;
|
||||
use crate::{
|
||||
AppState, Error, Result, auth,
|
||||
db::{
|
||||
|
@ -218,6 +216,7 @@ struct Session {
|
|||
/// The GeoIP country code for the user.
|
||||
#[allow(unused)]
|
||||
geoip_country_code: Option<String>,
|
||||
#[allow(unused)]
|
||||
system_id: Option<String>,
|
||||
_executor: Executor,
|
||||
}
|
||||
|
@ -464,7 +463,6 @@ impl Server {
|
|||
.add_message_handler(unfollow)
|
||||
.add_message_handler(update_followers)
|
||||
.add_request_handler(get_private_user_info)
|
||||
.add_request_handler(get_llm_api_token)
|
||||
.add_request_handler(accept_terms_of_service)
|
||||
.add_message_handler(acknowledge_channel_message)
|
||||
.add_message_handler(acknowledge_buffer_version)
|
||||
|
@ -4251,96 +4249,6 @@ async fn accept_terms_of_service(
|
|||
accepted_tos_at: accepted_tos_at.timestamp() as u64,
|
||||
})?;
|
||||
|
||||
// When the user accepts the terms of service, we want to refresh their LLM
|
||||
// token to grant access.
|
||||
session
|
||||
.peer
|
||||
.send(session.connection_id, proto::RefreshLlmToken {})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_llm_api_token(
|
||||
_request: proto::GetLlmToken,
|
||||
response: Response<proto::GetLlmToken>,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
let flags = db.get_user_flags(session.user_id()).await?;
|
||||
|
||||
let user_id = session.user_id();
|
||||
let user = db
|
||||
.get_user_by_id(user_id)
|
||||
.await?
|
||||
.with_context(|| format!("user {user_id} not found"))?;
|
||||
|
||||
if user.accepted_tos_at.is_none() {
|
||||
Err(anyhow!("terms of service not accepted"))?
|
||||
}
|
||||
|
||||
let stripe_client = session
|
||||
.app_state
|
||||
.stripe_client
|
||||
.as_ref()
|
||||
.context("failed to retrieve Stripe client")?;
|
||||
|
||||
let stripe_billing = session
|
||||
.app_state
|
||||
.stripe_billing
|
||||
.as_ref()
|
||||
.context("failed to retrieve Stripe billing object")?;
|
||||
|
||||
let billing_customer = if let Some(billing_customer) =
|
||||
db.get_billing_customer_by_user_id(user.id).await?
|
||||
{
|
||||
billing_customer
|
||||
} else {
|
||||
let customer_id = stripe_billing
|
||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||
.await?;
|
||||
|
||||
find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
|
||||
.await?
|
||||
.context("billing customer not found")?
|
||||
};
|
||||
|
||||
let billing_subscription =
|
||||
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
|
||||
billing_subscription
|
||||
} else {
|
||||
let stripe_customer_id =
|
||||
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
|
||||
|
||||
let stripe_subscription = stripe_billing
|
||||
.subscribe_to_zed_free(stripe_customer_id)
|
||||
.await?;
|
||||
|
||||
db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
|
||||
billing_customer_id: billing_customer.id,
|
||||
kind: Some(SubscriptionKind::ZedFree),
|
||||
stripe_subscription_id: stripe_subscription.id.to_string(),
|
||||
stripe_subscription_status: stripe_subscription.status.into(),
|
||||
stripe_cancellation_reason: None,
|
||||
stripe_current_period_start: Some(stripe_subscription.current_period_start),
|
||||
stripe_current_period_end: Some(stripe_subscription.current_period_end),
|
||||
})
|
||||
.await?
|
||||
};
|
||||
|
||||
let billing_preferences = db.get_billing_preferences(user.id).await?;
|
||||
|
||||
let token = LlmTokenClaims::create(
|
||||
&user,
|
||||
session.is_staff(),
|
||||
billing_customer,
|
||||
billing_preferences,
|
||||
&flags,
|
||||
billing_subscription,
|
||||
session.system_id.clone(),
|
||||
&session.app_state.config,
|
||||
)?;
|
||||
response.send(proto::GetLlmTokenResponse { token })?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -158,14 +158,6 @@ message SynchronizeContextsResponse {
|
|||
repeated ContextVersion contexts = 1;
|
||||
}
|
||||
|
||||
message GetLlmToken {}
|
||||
|
||||
message GetLlmTokenResponse {
|
||||
string token = 1;
|
||||
}
|
||||
|
||||
message RefreshLlmToken {}
|
||||
|
||||
enum LanguageModelRole {
|
||||
LanguageModelUser = 0;
|
||||
LanguageModelAssistant = 1;
|
||||
|
|
|
@ -250,10 +250,6 @@ message Envelope {
|
|||
AddWorktree add_worktree = 222;
|
||||
AddWorktreeResponse add_worktree_response = 223;
|
||||
|
||||
GetLlmToken get_llm_token = 235;
|
||||
GetLlmTokenResponse get_llm_token_response = 236;
|
||||
RefreshLlmToken refresh_llm_token = 259;
|
||||
|
||||
LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241;
|
||||
LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242;
|
||||
|
||||
|
@ -419,7 +415,9 @@ message Envelope {
|
|||
reserved 221;
|
||||
reserved 224 to 229;
|
||||
reserved 230 to 231;
|
||||
reserved 235 to 236;
|
||||
reserved 246;
|
||||
reserved 259;
|
||||
reserved 270;
|
||||
reserved 247 to 254;
|
||||
reserved 255 to 256;
|
||||
|
|
|
@ -119,8 +119,6 @@ messages!(
|
|||
(GetTypeDefinitionResponse, Background),
|
||||
(GetImplementation, Background),
|
||||
(GetImplementationResponse, Background),
|
||||
(GetLlmToken, Background),
|
||||
(GetLlmTokenResponse, Background),
|
||||
(OpenUnstagedDiff, Foreground),
|
||||
(OpenUnstagedDiffResponse, Foreground),
|
||||
(OpenUncommittedDiff, Foreground),
|
||||
|
@ -196,7 +194,6 @@ messages!(
|
|||
(PrepareRenameResponse, Background),
|
||||
(ProjectEntryResponse, Foreground),
|
||||
(RefreshInlayHints, Foreground),
|
||||
(RefreshLlmToken, Background),
|
||||
(RegisterBufferWithLanguageServers, Background),
|
||||
(RejoinChannelBuffers, Foreground),
|
||||
(RejoinChannelBuffersResponse, Foreground),
|
||||
|
@ -354,7 +351,6 @@ request_messages!(
|
|||
(GetDocumentHighlights, GetDocumentHighlightsResponse),
|
||||
(GetDocumentSymbols, GetDocumentSymbolsResponse),
|
||||
(GetHover, GetHoverResponse),
|
||||
(GetLlmToken, GetLlmTokenResponse),
|
||||
(GetNotifications, GetNotificationsResponse),
|
||||
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
|
||||
(GetProjectSymbols, GetProjectSymbolsResponse),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue