assistant: Limit model access for Zed AI users to Claude-3.5-sonnet (#15904)
This prevents users from accessing other models, such as OpenAI's GPT-4 or Google's Gemini-Pro. Staff members can still access all models. Co-authored-by: Thorsten <thorsten@zed.dev> Release Notes: - N/A --------- Co-authored-by: Thorsten <thorsten@zed.dev>
This commit is contained in:
parent
efbf7ada28
commit
3a52d6cc52
3 changed files with 130 additions and 5 deletions
|
@ -6,21 +6,40 @@ use crate::{Config, Error, Result};
|
||||||
|
|
||||||
pub fn authorize_access_to_language_model(
|
pub fn authorize_access_to_language_model(
|
||||||
config: &Config,
|
config: &Config,
|
||||||
_claims: &LlmTokenClaims,
|
claims: &LlmTokenClaims,
|
||||||
country_code: Option<String>,
|
country_code: Option<String>,
|
||||||
provider: LanguageModelProvider,
|
provider: LanguageModelProvider,
|
||||||
model: &str,
|
model: &str,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
authorize_access_for_country(config, country_code, provider, model)?;
|
authorize_access_for_country(config, country_code, provider)?;
|
||||||
|
authorize_access_to_model(claims, provider, model)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn authorize_access_to_model(
|
||||||
|
claims: &LlmTokenClaims,
|
||||||
|
provider: LanguageModelProvider,
|
||||||
|
model: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
if claims.is_staff {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
match (provider, model) {
|
||||||
|
(LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3.5-sonnet") => {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
_ => Err(Error::http(
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
format!("access to model {model:?} is not included in your plan"),
|
||||||
|
))?,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn authorize_access_for_country(
|
fn authorize_access_for_country(
|
||||||
config: &Config,
|
config: &Config,
|
||||||
country_code: Option<String>,
|
country_code: Option<String>,
|
||||||
provider: LanguageModelProvider,
|
provider: LanguageModelProvider,
|
||||||
_model: &str,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// In development we won't have the `CF-IPCountry` header, so we can't check
|
// In development we won't have the `CF-IPCountry` header, so we can't check
|
||||||
// the country code.
|
// the country code.
|
||||||
|
@ -79,6 +98,7 @@ mod tests {
|
||||||
let claims = LlmTokenClaims {
|
let claims = LlmTokenClaims {
|
||||||
user_id: 99,
|
user_id: 99,
|
||||||
plan: Plan::ZedPro,
|
plan: Plan::ZedPro,
|
||||||
|
is_staff: true,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -210,4 +230,101 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_authorize_access_to_language_model_based_on_plan() {
|
||||||
|
let config = Config::test();
|
||||||
|
|
||||||
|
let test_cases = vec![
|
||||||
|
// Pro plan should have access to claude-3.5-sonnet
|
||||||
|
(
|
||||||
|
Plan::ZedPro,
|
||||||
|
LanguageModelProvider::Anthropic,
|
||||||
|
"claude-3.5-sonnet",
|
||||||
|
true,
|
||||||
|
),
|
||||||
|
// Free plan should have access to claude-3.5-sonnet
|
||||||
|
(
|
||||||
|
Plan::Free,
|
||||||
|
LanguageModelProvider::Anthropic,
|
||||||
|
"claude-3.5-sonnet",
|
||||||
|
true,
|
||||||
|
),
|
||||||
|
// Pro plan should NOT have access to other Anthropic models
|
||||||
|
(
|
||||||
|
Plan::ZedPro,
|
||||||
|
LanguageModelProvider::Anthropic,
|
||||||
|
"claude-3-opus",
|
||||||
|
false,
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (plan, provider, model, expected_access) in test_cases {
|
||||||
|
let claims = LlmTokenClaims {
|
||||||
|
plan,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = authorize_access_to_language_model(
|
||||||
|
&config,
|
||||||
|
&claims,
|
||||||
|
Some("US".into()),
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
);
|
||||||
|
|
||||||
|
if expected_access {
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"Expected access to be granted for plan {:?}, provider {:?}, model {}",
|
||||||
|
plan,
|
||||||
|
provider,
|
||||||
|
model
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
let error = result.expect_err(&format!(
|
||||||
|
"Expected access to be denied for plan {:?}, provider {:?}, model {}",
|
||||||
|
plan, provider, model
|
||||||
|
));
|
||||||
|
let response = error.into_response();
|
||||||
|
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_authorize_access_to_language_model_for_staff() {
|
||||||
|
let config = Config::test();
|
||||||
|
|
||||||
|
let claims = LlmTokenClaims {
|
||||||
|
is_staff: true,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Staff should have access to all models
|
||||||
|
let test_cases = vec![
|
||||||
|
(LanguageModelProvider::Anthropic, "claude-3.5-sonnet"),
|
||||||
|
(LanguageModelProvider::Anthropic, "claude-2"),
|
||||||
|
(LanguageModelProvider::Anthropic, "claude-123-agi"),
|
||||||
|
(LanguageModelProvider::OpenAi, "gpt-4"),
|
||||||
|
(LanguageModelProvider::Google, "gemini-pro"),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (provider, model) in test_cases {
|
||||||
|
let result = authorize_access_to_language_model(
|
||||||
|
&config,
|
||||||
|
&claims,
|
||||||
|
Some("US".into()),
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"Expected staff to have access to provider {:?}, model {}",
|
||||||
|
provider,
|
||||||
|
model
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,13 +13,19 @@ pub struct LlmTokenClaims {
|
||||||
pub exp: u64,
|
pub exp: u64,
|
||||||
pub jti: String,
|
pub jti: String,
|
||||||
pub user_id: u64,
|
pub user_id: u64,
|
||||||
|
pub is_staff: bool,
|
||||||
pub plan: rpc::proto::Plan,
|
pub plan: rpc::proto::Plan,
|
||||||
}
|
}
|
||||||
|
|
||||||
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
||||||
|
|
||||||
impl LlmTokenClaims {
|
impl LlmTokenClaims {
|
||||||
pub fn create(user_id: UserId, plan: rpc::proto::Plan, config: &Config) -> Result<String> {
|
pub fn create(
|
||||||
|
user_id: UserId,
|
||||||
|
is_staff: bool,
|
||||||
|
plan: rpc::proto::Plan,
|
||||||
|
config: &Config,
|
||||||
|
) -> Result<String> {
|
||||||
let secret = config
|
let secret = config
|
||||||
.llm_api_secret
|
.llm_api_secret
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
@ -31,6 +37,7 @@ impl LlmTokenClaims {
|
||||||
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
|
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
|
||||||
jti: uuid::Uuid::new_v4().to_string(),
|
jti: uuid::Uuid::new_v4().to_string(),
|
||||||
user_id: user_id.to_proto(),
|
user_id: user_id.to_proto(),
|
||||||
|
is_staff,
|
||||||
plan,
|
plan,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -5164,6 +5164,7 @@ async fn get_llm_api_token(
|
||||||
|
|
||||||
let token = LlmTokenClaims::create(
|
let token = LlmTokenClaims::create(
|
||||||
session.user_id(),
|
session.user_id(),
|
||||||
|
session.is_staff(),
|
||||||
session.current_plan().await?,
|
session.current_plan().await?,
|
||||||
&session.app_state.config,
|
&session.app_state.config,
|
||||||
)?;
|
)?;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue