Add logic for closed beta LLM models (#16482)
Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
parent
41fc6d0885
commit
b5bd8a5c5d
9 changed files with 104 additions and 47 deletions
|
@ -168,6 +168,7 @@ pub struct Config {
|
|||
pub google_ai_api_key: Option<Arc<str>>,
|
||||
pub anthropic_api_key: Option<Arc<str>>,
|
||||
pub anthropic_staff_api_key: Option<Arc<str>>,
|
||||
pub llm_closed_beta_model_name: Option<Arc<str>>,
|
||||
pub qwen2_7b_api_key: Option<Arc<str>>,
|
||||
pub qwen2_7b_api_url: Option<Arc<str>>,
|
||||
pub zed_client_checksum_seed: Option<String>,
|
||||
|
@ -219,6 +220,7 @@ impl Config {
|
|||
google_ai_api_key: None,
|
||||
anthropic_api_key: None,
|
||||
anthropic_staff_api_key: None,
|
||||
llm_closed_beta_model_name: None,
|
||||
clickhouse_url: None,
|
||||
clickhouse_user: None,
|
||||
clickhouse_password: None,
|
||||
|
|
|
@ -12,11 +12,12 @@ pub fn authorize_access_to_language_model(
|
|||
model: &str,
|
||||
) -> Result<()> {
|
||||
authorize_access_for_country(config, country_code, provider)?;
|
||||
authorize_access_to_model(claims, provider, model)?;
|
||||
authorize_access_to_model(config, claims, provider, model)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn authorize_access_to_model(
|
||||
config: &Config,
|
||||
claims: &LlmTokenClaims,
|
||||
provider: LanguageModelProvider,
|
||||
model: &str,
|
||||
|
@ -25,13 +26,25 @@ fn authorize_access_to_model(
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
match (provider, model) {
|
||||
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
|
||||
_ => Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
format!("access to model {model:?} is not included in your plan"),
|
||||
))?,
|
||||
match provider {
|
||||
LanguageModelProvider::Anthropic => {
|
||||
if model == "claude-3-5-sonnet" {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if claims.has_llm_closed_beta_feature_flag
|
||||
&& Some(model) == config.llm_closed_beta_model_name.as_deref()
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
format!("access to model {model:?} is not included in your plan"),
|
||||
))
|
||||
}
|
||||
|
||||
fn authorize_access_for_country(
|
||||
|
|
|
@ -82,12 +82,13 @@ impl LlmDatabase {
|
|||
let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
|
||||
|
||||
let mut results = Vec::new();
|
||||
for (provider, model) in self.models.keys().cloned() {
|
||||
for ((provider, model_name), model) in self.models.iter() {
|
||||
let mut usages = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::Timestamp
|
||||
.gte(past_minute.naive_utc())
|
||||
.and(usage::Column::IsStaff.eq(false))
|
||||
.and(usage::Column::ModelId.eq(model.id))
|
||||
.and(
|
||||
usage::Column::MeasureId
|
||||
.eq(requests_per_minute)
|
||||
|
@ -125,8 +126,8 @@ impl LlmDatabase {
|
|||
}
|
||||
|
||||
results.push(ApplicationWideUsage {
|
||||
provider,
|
||||
model,
|
||||
provider: *provider,
|
||||
model: model_name.clone(),
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
})
|
||||
|
|
|
@ -20,6 +20,8 @@ pub struct LlmTokenClaims {
|
|||
#[serde(default)]
|
||||
pub github_user_login: Option<String>,
|
||||
pub is_staff: bool,
|
||||
#[serde(default)]
|
||||
pub has_llm_closed_beta_feature_flag: bool,
|
||||
pub plan: rpc::proto::Plan,
|
||||
}
|
||||
|
||||
|
@ -30,6 +32,7 @@ impl LlmTokenClaims {
|
|||
user_id: UserId,
|
||||
github_user_login: String,
|
||||
is_staff: bool,
|
||||
has_llm_closed_beta_feature_flag: bool,
|
||||
plan: rpc::proto::Plan,
|
||||
config: &Config,
|
||||
) -> Result<String> {
|
||||
|
@ -46,6 +49,7 @@ impl LlmTokenClaims {
|
|||
user_id: user_id.to_proto(),
|
||||
github_user_login: Some(github_user_login),
|
||||
is_staff,
|
||||
has_llm_closed_beta_feature_flag,
|
||||
plan,
|
||||
};
|
||||
|
||||
|
|
|
@ -4918,7 +4918,10 @@ async fn get_llm_api_token(
|
|||
let db = session.db().await;
|
||||
|
||||
let flags = db.get_user_flags(session.user_id()).await?;
|
||||
if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
|
||||
let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
|
||||
let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
|
||||
|
||||
if !session.is_staff() && !has_language_models_feature_flag {
|
||||
Err(anyhow!("permission denied"))?
|
||||
}
|
||||
|
||||
|
@ -4943,6 +4946,7 @@ async fn get_llm_api_token(
|
|||
user.id,
|
||||
user.github_login.clone(),
|
||||
session.is_staff(),
|
||||
has_llm_closed_beta_feature_flag,
|
||||
session.current_plan(db).await?,
|
||||
&session.app_state.config,
|
||||
)?;
|
||||
|
|
|
@ -667,6 +667,7 @@ impl TestServer {
|
|||
google_ai_api_key: None,
|
||||
anthropic_api_key: None,
|
||||
anthropic_staff_api_key: None,
|
||||
llm_closed_beta_model_name: None,
|
||||
clickhouse_url: None,
|
||||
clickhouse_user: None,
|
||||
clickhouse_password: None,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue