Add logic for closed beta LLM models (#16482)
Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
parent 41fc6d0885
commit b5bd8a5c5d
9 changed files with 104 additions and 47 deletions

@@ -139,6 +139,11 @@ spec:
               secretKeyRef:
                 name: anthropic
                 key: staff_api_key
+          - name: LLM_CLOSED_BETA_MODEL_NAME
+            valueFrom:
+              secretKeyRef:
+                name: llm-closed-beta
+                key: model_name
           - name: GOOGLE_AI_API_KEY
             valueFrom:
               secretKeyRef:

@@ -168,6 +168,7 @@ pub struct Config {
     pub google_ai_api_key: Option<Arc<str>>,
     pub anthropic_api_key: Option<Arc<str>>,
     pub anthropic_staff_api_key: Option<Arc<str>>,
+    pub llm_closed_beta_model_name: Option<Arc<str>>,
     pub qwen2_7b_api_key: Option<Arc<str>>,
     pub qwen2_7b_api_url: Option<Arc<str>>,
     pub zed_client_checksum_seed: Option<String>,

@@ -219,6 +220,7 @@ impl Config {
             google_ai_api_key: None,
             anthropic_api_key: None,
             anthropic_staff_api_key: None,
+            llm_closed_beta_model_name: None,
             clickhouse_url: None,
             clickhouse_user: None,
             clickhouse_password: None,

@@ -12,11 +12,12 @@ pub fn authorize_access_to_language_model(
     model: &str,
 ) -> Result<()> {
     authorize_access_for_country(config, country_code, provider)?;
-    authorize_access_to_model(claims, provider, model)?;
+    authorize_access_to_model(config, claims, provider, model)?;
     Ok(())
 }
 
 fn authorize_access_to_model(
+    config: &Config,
     claims: &LlmTokenClaims,
     provider: LanguageModelProvider,
     model: &str,

@@ -25,13 +26,25 @@ fn authorize_access_to_model(
         return Ok(());
     }
 
-    match (provider, model) {
-        (LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
-        _ => Err(Error::http(
-            StatusCode::FORBIDDEN,
-            format!("access to model {model:?} is not included in your plan"),
-        ))?,
+    match provider {
+        LanguageModelProvider::Anthropic => {
+            if model == "claude-3-5-sonnet" {
+                return Ok(());
+            }
+
+            if claims.has_llm_closed_beta_feature_flag
+                && Some(model) == config.llm_closed_beta_model_name.as_deref()
+            {
+                return Ok(());
+            }
+        }
+        _ => {}
     }
+
+    Err(Error::http(
+        StatusCode::FORBIDDEN,
+        format!("access to model {model:?} is not included in your plan"),
+    ))
 }
 
 fn authorize_access_for_country(

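A standalone sketch of the gating rule this hunk introduces, written as a free function over simplified parameters rather than the crate's real Config/LlmTokenClaims types (the provider match is collapsed, and "some-beta-model" is a made-up name):

fn model_is_allowed(
    is_staff: bool,
    has_llm_closed_beta_flag: bool,
    closed_beta_model_name: Option<&str>,
    model: &str,
) -> bool {
    // Staff bypass, mirroring the early return above.
    if is_staff {
        return true;
    }
    // The generally available model stays accessible to everyone with a plan.
    if model == "claude-3-5-sonnet" {
        return true;
    }
    // Closed-beta models require the feature flag AND a server-side name match.
    has_llm_closed_beta_flag && closed_beta_model_name == Some(model)
}

fn main() {
    assert!(model_is_allowed(false, false, None, "claude-3-5-sonnet"));
    assert!(model_is_allowed(false, true, Some("some-beta-model"), "some-beta-model"));
    assert!(!model_is_allowed(false, false, Some("some-beta-model"), "some-beta-model"));
    assert!(!model_is_allowed(false, true, None, "some-beta-model"));
}
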
@@ -82,12 +82,13 @@ impl LlmDatabase {
         let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
 
         let mut results = Vec::new();
-        for (provider, model) in self.models.keys().cloned() {
+        for ((provider, model_name), model) in self.models.iter() {
             let mut usages = usage::Entity::find()
                 .filter(
                     usage::Column::Timestamp
                         .gte(past_minute.naive_utc())
                         .and(usage::Column::IsStaff.eq(false))
+                        .and(usage::Column::ModelId.eq(model.id))
                         .and(
                             usage::Column::MeasureId
                                 .eq(requests_per_minute)

@@ -125,8 +126,8 @@ impl LlmDatabase {
             }
 
             results.push(ApplicationWideUsage {
-                provider,
-                model,
+                provider: *provider,
+                model: model_name.clone(),
                 requests_this_minute,
                 tokens_this_minute,
             })

@@ -20,6 +20,8 @@ pub struct LlmTokenClaims {
     #[serde(default)]
     pub github_user_login: Option<String>,
     pub is_staff: bool,
+    #[serde(default)]
+    pub has_llm_closed_beta_feature_flag: bool,
     pub plan: rpc::proto::Plan,
 }
 

@@ -30,6 +32,7 @@ impl LlmTokenClaims {
         user_id: UserId,
         github_user_login: String,
         is_staff: bool,
+        has_llm_closed_beta_feature_flag: bool,
         plan: rpc::proto::Plan,
         config: &Config,
     ) -> Result<String> {

@@ -46,6 +49,7 @@
             user_id: user_id.to_proto(),
             github_user_login: Some(github_user_login),
             is_staff,
+            has_llm_closed_beta_feature_flag,
             plan,
         };
 

@@ -4918,7 +4918,10 @@ async fn get_llm_api_token(
     let db = session.db().await;
 
     let flags = db.get_user_flags(session.user_id()).await?;
-    if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
+    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
+    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
+
+    if !session.is_staff() && !has_language_models_feature_flag {
         Err(anyhow!("permission denied"))?
     }
 

@@ -4943,6 +4946,7 @@ async fn get_llm_api_token(
         user.id,
         user.github_login.clone(),
         session.is_staff(),
+        has_llm_closed_beta_feature_flag,
         session.current_plan(db).await?,
         &session.app_state.config,
     )?;

@@ -667,6 +667,7 @@ impl TestServer {
             google_ai_api_key: None,
             anthropic_api_key: None,
             anthropic_staff_api_key: None,
+            llm_closed_beta_model_name: None,
             clickhouse_url: None,
             clickhouse_user: None,
             clickhouse_password: None,

@@ -43,6 +43,11 @@ impl FeatureFlag for LanguageModels {
     const NAME: &'static str = "language-models";
 }
 
+pub struct LlmClosedBeta {}
+impl FeatureFlag for LlmClosedBeta {
+    const NAME: &'static str = "llm-closed-beta";
+}
+
 pub struct ZedPro {}
 impl FeatureFlag for ZedPro {
     const NAME: &'static str = "zed-pro";

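For reference, a self-contained sketch of the FeatureFlag pattern this hunk extends: each flag is a zero-sized type carrying its server-side flag name, so call sites can query it by type. The has_flag helper below is a simplified stand-in for the crate's cx.has_flag::<T>() extension method, not its real signature:

trait FeatureFlag {
    const NAME: &'static str;
}

struct LlmClosedBeta {}
impl FeatureFlag for LlmClosedBeta {
    const NAME: &'static str = "llm-closed-beta";
}

// Checks a user's flag list against the flag type's name.
fn has_flag<T: FeatureFlag>(user_flags: &[String]) -> bool {
    user_flags.iter().any(|flag| flag == T::NAME)
}

fn main() {
    let flags = vec!["llm-closed-beta".to_string()];
    assert!(has_flag::<LlmClosedBeta>(&flags));
    assert!(!has_flag::<LlmClosedBeta>(&[]));
}
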
@@ -8,7 +8,7 @@ use anthropic::AnthropicError;
 use anyhow::{anyhow, Result};
 use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
-use feature_flags::{FeatureFlagAppExt, ZedPro};
+use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
 use futures::{
     future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
     TryStreamExt as _,

@@ -26,7 +26,10 @@ use smol::{
     io::{AsyncReadExt, BufReader},
     lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
 };
-use std::{future, sync::Arc};
+use std::{
+    future,
+    sync::{Arc, LazyLock},
+};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
 

@@ -37,6 +40,18 @@ use super::anthropic::count_anthropic_tokens;
 pub const PROVIDER_ID: &str = "zed.dev";
 pub const PROVIDER_NAME: &str = "Zed";
 
+const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
+    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
+
+fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
+    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
+        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
+            .map(|json| serde_json::from_str(json).unwrap())
+            .unwrap_or(Vec::new())
+    });
+    ADDITIONAL_MODELS.as_slice()
+}
+
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct ZedDotDevSettings {
     pub available_models: Vec<AvailableModel>,

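A self-contained sketch of the option_env!/LazyLock pattern used above, with a simplified stand-in for AvailableModel (the real struct has more fields, such as display_name and cache_configuration); it assumes the serde and serde_json crates and a made-up model name:

use std::sync::LazyLock;

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Model {
    provider: String,
    name: String,
    max_tokens: usize,
}

// Baked into the binary at compile time; None if the variable was unset at build time.
const MODELS_JSON: Option<&str> = option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

fn additional_models() -> &'static [Model] {
    // Parsed once, on first access.
    static MODELS: LazyLock<Vec<Model>> = LazyLock::new(|| {
        MODELS_JSON
            .map(|json| serde_json::from_str(json).expect("invalid additional-models JSON"))
            .unwrap_or_default()
    });
    MODELS.as_slice()
}

fn main() {
    // Example build-time value (hypothetical model name):
    // ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON='[{"provider":"anthropic","name":"some-beta-model","max_tokens":200000}]'
    println!("{:?}", additional_models());
}
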
@@ -200,40 +215,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             for model in ZedModel::iter() {
                 models.insert(model.id().to_string(), CloudModel::Zed(model));
             }
-
-            // Override with available models from settings
-            for model in &AllLanguageModelSettings::get_global(cx)
-                .zed_dot_dev
-                .available_models
-            {
-                let model = match model.provider {
-                    AvailableProvider::Anthropic => {
-                        CloudModel::Anthropic(anthropic::Model::Custom {
-                            name: model.name.clone(),
-                            display_name: model.display_name.clone(),
-                            max_tokens: model.max_tokens,
-                            tool_override: model.tool_override.clone(),
-                            cache_configuration: model.cache_configuration.as_ref().map(|config| {
-                                anthropic::AnthropicModelCacheConfiguration {
-                                    max_cache_anchors: config.max_cache_anchors,
-                                    should_speculate: config.should_speculate,
-                                    min_total_token: config.min_total_token,
-                                }
-                            }),
-                            max_output_tokens: model.max_output_tokens,
-                        })
-                    }
-                    AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
-                        name: model.name.clone(),
-                        max_tokens: model.max_tokens,
-                    }),
-                    AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
-                        name: model.name.clone(),
-                        max_tokens: model.max_tokens,
-                    }),
-                };
-                models.insert(model.id().to_string(), model.clone());
-            }
         } else {
             models.insert(
                 anthropic::Model::Claude3_5Sonnet.id().to_string(),

@@ -241,6 +222,47 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             );
         }
 
+        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
+            zed_cloud_provider_additional_models()
+        } else {
+            &[]
+        };
+
+        // Override with available models from settings
+        for model in AllLanguageModelSettings::get_global(cx)
+            .zed_dot_dev
+            .available_models
+            .iter()
+            .chain(llm_closed_beta_models)
+            .cloned()
+        {
+            let model = match model.provider {
+                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
+                    name: model.name.clone(),
+                    display_name: model.display_name.clone(),
+                    max_tokens: model.max_tokens,
+                    tool_override: model.tool_override.clone(),
+                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
+                        anthropic::AnthropicModelCacheConfiguration {
+                            max_cache_anchors: config.max_cache_anchors,
+                            should_speculate: config.should_speculate,
+                            min_total_token: config.min_total_token,
+                        }
+                    }),
+                    max_output_tokens: model.max_output_tokens,
+                }),
+                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
+                    name: model.name.clone(),
+                    max_tokens: model.max_tokens,
+                }),
+                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
+                    name: model.name.clone(),
+                    max_tokens: model.max_tokens,
+                }),
+            };
+            models.insert(model.id().to_string(), model.clone());
+        }
+
         models
             .into_values()
             .map(|model| {

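A small sketch of the merge step above: settings models and the flag-gated closed-beta models are chained into one iterator, and inserting by id means a later entry replaces a built-in with the same id. The Model type and "some-beta-model" are simplified stand-ins:

use std::collections::BTreeMap;

#[derive(Clone, Debug)]
struct Model {
    id: String,
    max_tokens: usize,
}

fn main() {
    let mut models: BTreeMap<String, Model> = BTreeMap::new();
    models.insert(
        "claude-3-5-sonnet".into(),
        Model { id: "claude-3-5-sonnet".into(), max_tokens: 200_000 },
    );

    let from_settings = vec![Model { id: "claude-3-5-sonnet".into(), max_tokens: 100_000 }];
    let closed_beta = vec![Model { id: "some-beta-model".into(), max_tokens: 200_000 }];

    for model in from_settings.iter().chain(closed_beta.iter()).cloned() {
        models.insert(model.id.clone(), model);
    }

    // The settings entry overrode the built-in definition; the closed-beta model was added.
    assert_eq!(models["claude-3-5-sonnet"].max_tokens, 100_000);
    assert_eq!(models.len(), 2);
}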