Remove RPC messages pertaining to the LLM token (#36252)
This PR removes the RPC messages pertaining to the LLM token. We now retrieve the LLM token from Cloud. Release Notes: - N/A
This commit is contained in:
parent
257e0991d8
commit
75b832029a
8 changed files with 4 additions and 263 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -3324,7 +3324,6 @@ dependencies = [
|
||||||
"http_client",
|
"http_client",
|
||||||
"hyper 0.14.32",
|
"hyper 0.14.32",
|
||||||
"indoc",
|
"indoc",
|
||||||
"jsonwebtoken",
|
|
||||||
"language",
|
"language",
|
||||||
"language_model",
|
"language_model",
|
||||||
"livekit_api",
|
"livekit_api",
|
||||||
|
@ -3370,7 +3369,6 @@ dependencies = [
|
||||||
"telemetry_events",
|
"telemetry_events",
|
||||||
"text",
|
"text",
|
||||||
"theme",
|
"theme",
|
||||||
"thiserror 2.0.12",
|
|
||||||
"time",
|
"time",
|
||||||
"tokio",
|
"tokio",
|
||||||
"toml 0.8.20",
|
"toml 0.8.20",
|
||||||
|
|
|
@ -39,7 +39,6 @@ futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
hex.workspace = true
|
hex.workspace = true
|
||||||
http_client.workspace = true
|
http_client.workspace = true
|
||||||
jsonwebtoken.workspace = true
|
|
||||||
livekit_api.workspace = true
|
livekit_api.workspace = true
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
nanoid.workspace = true
|
nanoid.workspace = true
|
||||||
|
@ -65,7 +64,6 @@ subtle.workspace = true
|
||||||
supermaven_api.workspace = true
|
supermaven_api.workspace = true
|
||||||
telemetry_events.workspace = true
|
telemetry_events.workspace = true
|
||||||
text.workspace = true
|
text.workspace = true
|
||||||
thiserror.workspace = true
|
|
||||||
time.workspace = true
|
time.workspace = true
|
||||||
tokio = { workspace = true, features = ["full"] }
|
tokio = { workspace = true, features = ["full"] }
|
||||||
toml.workspace = true
|
toml.workspace = true
|
||||||
|
|
|
@ -1,7 +1,4 @@
|
||||||
pub mod db;
|
pub mod db;
|
||||||
mod token;
|
|
||||||
|
|
||||||
pub use token::*;
|
|
||||||
|
|
||||||
pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial";
|
pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial";
|
||||||
|
|
||||||
|
|
|
@ -1,146 +0,0 @@
|
||||||
use crate::db::billing_subscription::SubscriptionKind;
|
|
||||||
use crate::db::{billing_customer, billing_subscription, user};
|
|
||||||
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG};
|
|
||||||
use crate::{Config, db::billing_preference};
|
|
||||||
use anyhow::{Context as _, Result};
|
|
||||||
use chrono::{NaiveDateTime, Utc};
|
|
||||||
use cloud_llm_client::Plan;
|
|
||||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::time::Duration;
|
|
||||||
use thiserror::Error;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "camelCase")]
|
|
||||||
pub struct LlmTokenClaims {
|
|
||||||
pub iat: u64,
|
|
||||||
pub exp: u64,
|
|
||||||
pub jti: String,
|
|
||||||
pub user_id: u64,
|
|
||||||
pub system_id: Option<String>,
|
|
||||||
pub metrics_id: Uuid,
|
|
||||||
pub github_user_login: String,
|
|
||||||
pub account_created_at: NaiveDateTime,
|
|
||||||
pub is_staff: bool,
|
|
||||||
pub has_llm_closed_beta_feature_flag: bool,
|
|
||||||
pub bypass_account_age_check: bool,
|
|
||||||
pub use_llm_request_queue: bool,
|
|
||||||
pub plan: Plan,
|
|
||||||
pub has_extended_trial: bool,
|
|
||||||
pub subscription_period: (NaiveDateTime, NaiveDateTime),
|
|
||||||
pub enable_model_request_overages: bool,
|
|
||||||
pub model_request_overages_spend_limit_in_cents: u32,
|
|
||||||
pub can_use_web_search_tool: bool,
|
|
||||||
#[serde(default)]
|
|
||||||
pub has_overdue_invoices: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
|
||||||
|
|
||||||
impl LlmTokenClaims {
|
|
||||||
pub fn create(
|
|
||||||
user: &user::Model,
|
|
||||||
is_staff: bool,
|
|
||||||
billing_customer: billing_customer::Model,
|
|
||||||
billing_preferences: Option<billing_preference::Model>,
|
|
||||||
feature_flags: &Vec<String>,
|
|
||||||
subscription: billing_subscription::Model,
|
|
||||||
system_id: Option<String>,
|
|
||||||
config: &Config,
|
|
||||||
) -> Result<String> {
|
|
||||||
let secret = config
|
|
||||||
.llm_api_secret
|
|
||||||
.as_ref()
|
|
||||||
.context("no LLM API secret")?;
|
|
||||||
|
|
||||||
let plan = if is_staff {
|
|
||||||
Plan::ZedPro
|
|
||||||
} else {
|
|
||||||
subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
|
|
||||||
SubscriptionKind::ZedFree => Plan::ZedFree,
|
|
||||||
SubscriptionKind::ZedPro => Plan::ZedPro,
|
|
||||||
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
|
|
||||||
})
|
|
||||||
};
|
|
||||||
let subscription_period =
|
|
||||||
billing_subscription::Model::current_period(Some(subscription), is_staff)
|
|
||||||
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
|
|
||||||
.context("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started.")?;
|
|
||||||
|
|
||||||
let now = Utc::now();
|
|
||||||
let claims = Self {
|
|
||||||
iat: now.timestamp() as u64,
|
|
||||||
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
|
|
||||||
jti: uuid::Uuid::new_v4().to_string(),
|
|
||||||
user_id: user.id.to_proto(),
|
|
||||||
system_id,
|
|
||||||
metrics_id: user.metrics_id,
|
|
||||||
github_user_login: user.github_login.clone(),
|
|
||||||
account_created_at: user.account_created_at(),
|
|
||||||
is_staff,
|
|
||||||
has_llm_closed_beta_feature_flag: feature_flags
|
|
||||||
.iter()
|
|
||||||
.any(|flag| flag == "llm-closed-beta"),
|
|
||||||
bypass_account_age_check: feature_flags
|
|
||||||
.iter()
|
|
||||||
.any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG),
|
|
||||||
can_use_web_search_tool: true,
|
|
||||||
use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
|
|
||||||
plan,
|
|
||||||
has_extended_trial: feature_flags
|
|
||||||
.iter()
|
|
||||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
|
|
||||||
subscription_period,
|
|
||||||
enable_model_request_overages: billing_preferences
|
|
||||||
.as_ref()
|
|
||||||
.map_or(false, |preferences| {
|
|
||||||
preferences.model_request_overages_enabled
|
|
||||||
}),
|
|
||||||
model_request_overages_spend_limit_in_cents: billing_preferences
|
|
||||||
.as_ref()
|
|
||||||
.map_or(0, |preferences| {
|
|
||||||
preferences.model_request_overages_spend_limit_in_cents as u32
|
|
||||||
}),
|
|
||||||
has_overdue_invoices: billing_customer.has_overdue_invoices,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(jsonwebtoken::encode(
|
|
||||||
&Header::default(),
|
|
||||||
&claims,
|
|
||||||
&EncodingKey::from_secret(secret.as_ref()),
|
|
||||||
)?)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn validate(token: &str, config: &Config) -> Result<LlmTokenClaims, ValidateLlmTokenError> {
|
|
||||||
let secret = config
|
|
||||||
.llm_api_secret
|
|
||||||
.as_ref()
|
|
||||||
.context("no LLM API secret")?;
|
|
||||||
|
|
||||||
match jsonwebtoken::decode::<Self>(
|
|
||||||
token,
|
|
||||||
&DecodingKey::from_secret(secret.as_ref()),
|
|
||||||
&Validation::default(),
|
|
||||||
) {
|
|
||||||
Ok(token) => Ok(token.claims),
|
|
||||||
Err(e) => {
|
|
||||||
if e.kind() == &jsonwebtoken::errors::ErrorKind::ExpiredSignature {
|
|
||||||
Err(ValidateLlmTokenError::Expired)
|
|
||||||
} else {
|
|
||||||
Err(ValidateLlmTokenError::JwtError(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
pub enum ValidateLlmTokenError {
|
|
||||||
#[error("access token is expired")]
|
|
||||||
Expired,
|
|
||||||
#[error("access token validation error: {0}")]
|
|
||||||
JwtError(#[from] jsonwebtoken::errors::Error),
|
|
||||||
#[error("{0}")]
|
|
||||||
Other(#[from] anyhow::Error),
|
|
||||||
}
|
|
|
@ -1,14 +1,12 @@
|
||||||
mod connection_pool;
|
mod connection_pool;
|
||||||
|
|
||||||
use crate::api::billing::find_or_create_billing_customer;
|
|
||||||
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
||||||
use crate::db::billing_subscription::SubscriptionKind;
|
use crate::db::billing_subscription::SubscriptionKind;
|
||||||
use crate::llm::db::LlmDatabase;
|
use crate::llm::db::LlmDatabase;
|
||||||
use crate::llm::{
|
use crate::llm::{
|
||||||
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
|
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG,
|
||||||
MIN_ACCOUNT_AGE_FOR_LLM_USE,
|
MIN_ACCOUNT_AGE_FOR_LLM_USE,
|
||||||
};
|
};
|
||||||
use crate::stripe_client::StripeCustomerId;
|
|
||||||
use crate::{
|
use crate::{
|
||||||
AppState, Error, Result, auth,
|
AppState, Error, Result, auth,
|
||||||
db::{
|
db::{
|
||||||
|
@ -218,6 +216,7 @@ struct Session {
|
||||||
/// The GeoIP country code for the user.
|
/// The GeoIP country code for the user.
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
geoip_country_code: Option<String>,
|
geoip_country_code: Option<String>,
|
||||||
|
#[allow(unused)]
|
||||||
system_id: Option<String>,
|
system_id: Option<String>,
|
||||||
_executor: Executor,
|
_executor: Executor,
|
||||||
}
|
}
|
||||||
|
@ -464,7 +463,6 @@ impl Server {
|
||||||
.add_message_handler(unfollow)
|
.add_message_handler(unfollow)
|
||||||
.add_message_handler(update_followers)
|
.add_message_handler(update_followers)
|
||||||
.add_request_handler(get_private_user_info)
|
.add_request_handler(get_private_user_info)
|
||||||
.add_request_handler(get_llm_api_token)
|
|
||||||
.add_request_handler(accept_terms_of_service)
|
.add_request_handler(accept_terms_of_service)
|
||||||
.add_message_handler(acknowledge_channel_message)
|
.add_message_handler(acknowledge_channel_message)
|
||||||
.add_message_handler(acknowledge_buffer_version)
|
.add_message_handler(acknowledge_buffer_version)
|
||||||
|
@ -4251,96 +4249,6 @@ async fn accept_terms_of_service(
|
||||||
accepted_tos_at: accepted_tos_at.timestamp() as u64,
|
accepted_tos_at: accepted_tos_at.timestamp() as u64,
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// When the user accepts the terms of service, we want to refresh their LLM
|
|
||||||
// token to grant access.
|
|
||||||
session
|
|
||||||
.peer
|
|
||||||
.send(session.connection_id, proto::RefreshLlmToken {})?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_llm_api_token(
|
|
||||||
_request: proto::GetLlmToken,
|
|
||||||
response: Response<proto::GetLlmToken>,
|
|
||||||
session: MessageContext,
|
|
||||||
) -> Result<()> {
|
|
||||||
let db = session.db().await;
|
|
||||||
|
|
||||||
let flags = db.get_user_flags(session.user_id()).await?;
|
|
||||||
|
|
||||||
let user_id = session.user_id();
|
|
||||||
let user = db
|
|
||||||
.get_user_by_id(user_id)
|
|
||||||
.await?
|
|
||||||
.with_context(|| format!("user {user_id} not found"))?;
|
|
||||||
|
|
||||||
if user.accepted_tos_at.is_none() {
|
|
||||||
Err(anyhow!("terms of service not accepted"))?
|
|
||||||
}
|
|
||||||
|
|
||||||
let stripe_client = session
|
|
||||||
.app_state
|
|
||||||
.stripe_client
|
|
||||||
.as_ref()
|
|
||||||
.context("failed to retrieve Stripe client")?;
|
|
||||||
|
|
||||||
let stripe_billing = session
|
|
||||||
.app_state
|
|
||||||
.stripe_billing
|
|
||||||
.as_ref()
|
|
||||||
.context("failed to retrieve Stripe billing object")?;
|
|
||||||
|
|
||||||
let billing_customer = if let Some(billing_customer) =
|
|
||||||
db.get_billing_customer_by_user_id(user.id).await?
|
|
||||||
{
|
|
||||||
billing_customer
|
|
||||||
} else {
|
|
||||||
let customer_id = stripe_billing
|
|
||||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
|
|
||||||
.await?
|
|
||||||
.context("billing customer not found")?
|
|
||||||
};
|
|
||||||
|
|
||||||
let billing_subscription =
|
|
||||||
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
|
|
||||||
billing_subscription
|
|
||||||
} else {
|
|
||||||
let stripe_customer_id =
|
|
||||||
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
|
|
||||||
|
|
||||||
let stripe_subscription = stripe_billing
|
|
||||||
.subscribe_to_zed_free(stripe_customer_id)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
|
|
||||||
billing_customer_id: billing_customer.id,
|
|
||||||
kind: Some(SubscriptionKind::ZedFree),
|
|
||||||
stripe_subscription_id: stripe_subscription.id.to_string(),
|
|
||||||
stripe_subscription_status: stripe_subscription.status.into(),
|
|
||||||
stripe_cancellation_reason: None,
|
|
||||||
stripe_current_period_start: Some(stripe_subscription.current_period_start),
|
|
||||||
stripe_current_period_end: Some(stripe_subscription.current_period_end),
|
|
||||||
})
|
|
||||||
.await?
|
|
||||||
};
|
|
||||||
|
|
||||||
let billing_preferences = db.get_billing_preferences(user.id).await?;
|
|
||||||
|
|
||||||
let token = LlmTokenClaims::create(
|
|
||||||
&user,
|
|
||||||
session.is_staff(),
|
|
||||||
billing_customer,
|
|
||||||
billing_preferences,
|
|
||||||
&flags,
|
|
||||||
billing_subscription,
|
|
||||||
session.system_id.clone(),
|
|
||||||
&session.app_state.config,
|
|
||||||
)?;
|
|
||||||
response.send(proto::GetLlmTokenResponse { token })?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -158,14 +158,6 @@ message SynchronizeContextsResponse {
|
||||||
repeated ContextVersion contexts = 1;
|
repeated ContextVersion contexts = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GetLlmToken {}
|
|
||||||
|
|
||||||
message GetLlmTokenResponse {
|
|
||||||
string token = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message RefreshLlmToken {}
|
|
||||||
|
|
||||||
enum LanguageModelRole {
|
enum LanguageModelRole {
|
||||||
LanguageModelUser = 0;
|
LanguageModelUser = 0;
|
||||||
LanguageModelAssistant = 1;
|
LanguageModelAssistant = 1;
|
||||||
|
|
|
@ -250,10 +250,6 @@ message Envelope {
|
||||||
AddWorktree add_worktree = 222;
|
AddWorktree add_worktree = 222;
|
||||||
AddWorktreeResponse add_worktree_response = 223;
|
AddWorktreeResponse add_worktree_response = 223;
|
||||||
|
|
||||||
GetLlmToken get_llm_token = 235;
|
|
||||||
GetLlmTokenResponse get_llm_token_response = 236;
|
|
||||||
RefreshLlmToken refresh_llm_token = 259;
|
|
||||||
|
|
||||||
LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241;
|
LspExtSwitchSourceHeader lsp_ext_switch_source_header = 241;
|
||||||
LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242;
|
LspExtSwitchSourceHeaderResponse lsp_ext_switch_source_header_response = 242;
|
||||||
|
|
||||||
|
@ -419,7 +415,9 @@ message Envelope {
|
||||||
reserved 221;
|
reserved 221;
|
||||||
reserved 224 to 229;
|
reserved 224 to 229;
|
||||||
reserved 230 to 231;
|
reserved 230 to 231;
|
||||||
|
reserved 235 to 236;
|
||||||
reserved 246;
|
reserved 246;
|
||||||
|
reserved 259;
|
||||||
reserved 270;
|
reserved 270;
|
||||||
reserved 247 to 254;
|
reserved 247 to 254;
|
||||||
reserved 255 to 256;
|
reserved 255 to 256;
|
||||||
|
|
|
@ -119,8 +119,6 @@ messages!(
|
||||||
(GetTypeDefinitionResponse, Background),
|
(GetTypeDefinitionResponse, Background),
|
||||||
(GetImplementation, Background),
|
(GetImplementation, Background),
|
||||||
(GetImplementationResponse, Background),
|
(GetImplementationResponse, Background),
|
||||||
(GetLlmToken, Background),
|
|
||||||
(GetLlmTokenResponse, Background),
|
|
||||||
(OpenUnstagedDiff, Foreground),
|
(OpenUnstagedDiff, Foreground),
|
||||||
(OpenUnstagedDiffResponse, Foreground),
|
(OpenUnstagedDiffResponse, Foreground),
|
||||||
(OpenUncommittedDiff, Foreground),
|
(OpenUncommittedDiff, Foreground),
|
||||||
|
@ -196,7 +194,6 @@ messages!(
|
||||||
(PrepareRenameResponse, Background),
|
(PrepareRenameResponse, Background),
|
||||||
(ProjectEntryResponse, Foreground),
|
(ProjectEntryResponse, Foreground),
|
||||||
(RefreshInlayHints, Foreground),
|
(RefreshInlayHints, Foreground),
|
||||||
(RefreshLlmToken, Background),
|
|
||||||
(RegisterBufferWithLanguageServers, Background),
|
(RegisterBufferWithLanguageServers, Background),
|
||||||
(RejoinChannelBuffers, Foreground),
|
(RejoinChannelBuffers, Foreground),
|
||||||
(RejoinChannelBuffersResponse, Foreground),
|
(RejoinChannelBuffersResponse, Foreground),
|
||||||
|
@ -354,7 +351,6 @@ request_messages!(
|
||||||
(GetDocumentHighlights, GetDocumentHighlightsResponse),
|
(GetDocumentHighlights, GetDocumentHighlightsResponse),
|
||||||
(GetDocumentSymbols, GetDocumentSymbolsResponse),
|
(GetDocumentSymbols, GetDocumentSymbolsResponse),
|
||||||
(GetHover, GetHoverResponse),
|
(GetHover, GetHoverResponse),
|
||||||
(GetLlmToken, GetLlmTokenResponse),
|
|
||||||
(GetNotifications, GetNotificationsResponse),
|
(GetNotifications, GetNotificationsResponse),
|
||||||
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
|
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
|
||||||
(GetProjectSymbols, GetProjectSymbolsResponse),
|
(GetProjectSymbols, GetProjectSymbolsResponse),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue