
This PR reworks how we do checks for model names in the LLM service. We now normalize the model names using the models defined in the database. Release Notes: - N/A
328 lines
11 KiB
Rust
328 lines
11 KiB
Rust
use reqwest::StatusCode;
|
|
use rpc::LanguageModelProvider;
|
|
|
|
use crate::llm::LlmTokenClaims;
|
|
use crate::{Config, Error, Result};
|
|
|
|
pub fn authorize_access_to_language_model(
|
|
config: &Config,
|
|
claims: &LlmTokenClaims,
|
|
country_code: Option<String>,
|
|
provider: LanguageModelProvider,
|
|
model: &str,
|
|
) -> Result<()> {
|
|
authorize_access_for_country(config, country_code, provider)?;
|
|
authorize_access_to_model(claims, provider, model)?;
|
|
Ok(())
|
|
}
|
|
|
|
fn authorize_access_to_model(
|
|
claims: &LlmTokenClaims,
|
|
provider: LanguageModelProvider,
|
|
model: &str,
|
|
) -> Result<()> {
|
|
if claims.is_staff {
|
|
return Ok(());
|
|
}
|
|
|
|
match (provider, model) {
|
|
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
|
|
_ => Err(Error::http(
|
|
StatusCode::FORBIDDEN,
|
|
format!("access to model {model:?} is not included in your plan"),
|
|
))?,
|
|
}
|
|
}
|
|
|
|
fn authorize_access_for_country(
|
|
config: &Config,
|
|
country_code: Option<String>,
|
|
provider: LanguageModelProvider,
|
|
) -> Result<()> {
|
|
// In development we won't have the `CF-IPCountry` header, so we can't check
|
|
// the country code.
|
|
//
|
|
// This shouldn't be necessary, as anyone running in development will need to provide
|
|
// their own API credentials in order to use an LLM provider.
|
|
if config.is_development() {
|
|
return Ok(());
|
|
}
|
|
|
|
// https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
|
|
let country_code = match country_code.as_deref() {
|
|
// `XX` - Used for clients without country code data.
|
|
None | Some("XX") => Err(Error::http(
|
|
StatusCode::BAD_REQUEST,
|
|
"no country code".to_string(),
|
|
))?,
|
|
// `T1` - Used for clients using the Tor network.
|
|
Some("T1") => Err(Error::http(
|
|
StatusCode::FORBIDDEN,
|
|
format!("access to {provider:?} models is not available over Tor"),
|
|
))?,
|
|
Some(country_code) => country_code,
|
|
};
|
|
|
|
let is_country_supported_by_provider = match provider {
|
|
LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
|
|
LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
|
|
LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
|
|
LanguageModelProvider::Zed => true,
|
|
};
|
|
if !is_country_supported_by_provider {
|
|
Err(Error::http(
|
|
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
|
|
format!("access to {provider:?} models is not available in your region"),
|
|
))?
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use axum::response::IntoResponse;
|
|
use pretty_assertions::assert_eq;
|
|
use rpc::proto::Plan;
|
|
|
|
use super::*;
|
|
|
|
#[gpui::test]
|
|
async fn test_authorize_access_to_language_model_with_supported_country(
|
|
_cx: &mut gpui::TestAppContext,
|
|
) {
|
|
let config = Config::test();
|
|
|
|
let claims = LlmTokenClaims {
|
|
user_id: 99,
|
|
plan: Plan::ZedPro,
|
|
is_staff: true,
|
|
..Default::default()
|
|
};
|
|
|
|
let cases = vec![
|
|
(LanguageModelProvider::Anthropic, "US"), // United States
|
|
(LanguageModelProvider::Anthropic, "GB"), // United Kingdom
|
|
(LanguageModelProvider::OpenAi, "US"), // United States
|
|
(LanguageModelProvider::OpenAi, "GB"), // United Kingdom
|
|
(LanguageModelProvider::Google, "US"), // United States
|
|
(LanguageModelProvider::Google, "GB"), // United Kingdom
|
|
];
|
|
|
|
for (provider, country_code) in cases {
|
|
authorize_access_to_language_model(
|
|
&config,
|
|
&claims,
|
|
Some(country_code.into()),
|
|
provider,
|
|
"the-model",
|
|
)
|
|
.unwrap_or_else(|_| {
|
|
panic!("expected authorization to return Ok for {provider:?}: {country_code}")
|
|
})
|
|
}
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_authorize_access_to_language_model_with_unsupported_country(
|
|
_cx: &mut gpui::TestAppContext,
|
|
) {
|
|
let config = Config::test();
|
|
|
|
let claims = LlmTokenClaims {
|
|
user_id: 99,
|
|
plan: Plan::ZedPro,
|
|
..Default::default()
|
|
};
|
|
|
|
let cases = vec![
|
|
(LanguageModelProvider::Anthropic, "AF"), // Afghanistan
|
|
(LanguageModelProvider::Anthropic, "BY"), // Belarus
|
|
(LanguageModelProvider::Anthropic, "CF"), // Central African Republic
|
|
(LanguageModelProvider::Anthropic, "CN"), // China
|
|
(LanguageModelProvider::Anthropic, "CU"), // Cuba
|
|
(LanguageModelProvider::Anthropic, "ER"), // Eritrea
|
|
(LanguageModelProvider::Anthropic, "ET"), // Ethiopia
|
|
(LanguageModelProvider::Anthropic, "IR"), // Iran
|
|
(LanguageModelProvider::Anthropic, "KP"), // North Korea
|
|
(LanguageModelProvider::Anthropic, "XK"), // Kosovo
|
|
(LanguageModelProvider::Anthropic, "LY"), // Libya
|
|
(LanguageModelProvider::Anthropic, "MM"), // Myanmar
|
|
(LanguageModelProvider::Anthropic, "RU"), // Russia
|
|
(LanguageModelProvider::Anthropic, "SO"), // Somalia
|
|
(LanguageModelProvider::Anthropic, "SS"), // South Sudan
|
|
(LanguageModelProvider::Anthropic, "SD"), // Sudan
|
|
(LanguageModelProvider::Anthropic, "SY"), // Syria
|
|
(LanguageModelProvider::Anthropic, "VE"), // Venezuela
|
|
(LanguageModelProvider::Anthropic, "YE"), // Yemen
|
|
(LanguageModelProvider::OpenAi, "KP"), // North Korea
|
|
(LanguageModelProvider::Google, "KP"), // North Korea
|
|
];
|
|
|
|
for (provider, country_code) in cases {
|
|
let error_response = authorize_access_to_language_model(
|
|
&config,
|
|
&claims,
|
|
Some(country_code.into()),
|
|
provider,
|
|
"the-model",
|
|
)
|
|
.expect_err(&format!(
|
|
"expected authorization to return an error for {provider:?}: {country_code}"
|
|
))
|
|
.into_response();
|
|
|
|
assert_eq!(
|
|
error_response.status(),
|
|
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
|
|
);
|
|
let response_body = hyper::body::to_bytes(error_response.into_body())
|
|
.await
|
|
.unwrap()
|
|
.to_vec();
|
|
assert_eq!(
|
|
String::from_utf8(response_body).unwrap(),
|
|
format!("access to {provider:?} models is not available in your region")
|
|
);
|
|
}
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
|
|
let config = Config::test();
|
|
|
|
let claims = LlmTokenClaims {
|
|
user_id: 99,
|
|
plan: Plan::ZedPro,
|
|
..Default::default()
|
|
};
|
|
|
|
let cases = vec![
|
|
(LanguageModelProvider::Anthropic, "T1"), // Tor
|
|
(LanguageModelProvider::OpenAi, "T1"), // Tor
|
|
(LanguageModelProvider::Google, "T1"), // Tor
|
|
(LanguageModelProvider::Zed, "T1"), // Tor
|
|
];
|
|
|
|
for (provider, country_code) in cases {
|
|
let error_response = authorize_access_to_language_model(
|
|
&config,
|
|
&claims,
|
|
Some(country_code.into()),
|
|
provider,
|
|
"the-model",
|
|
)
|
|
.expect_err(&format!(
|
|
"expected authorization to return an error for {provider:?}: {country_code}"
|
|
))
|
|
.into_response();
|
|
|
|
assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
|
|
let response_body = hyper::body::to_bytes(error_response.into_body())
|
|
.await
|
|
.unwrap()
|
|
.to_vec();
|
|
assert_eq!(
|
|
String::from_utf8(response_body).unwrap(),
|
|
format!("access to {provider:?} models is not available over Tor")
|
|
);
|
|
}
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_authorize_access_to_language_model_based_on_plan() {
|
|
let config = Config::test();
|
|
|
|
let test_cases = vec![
|
|
// Pro plan should have access to claude-3.5-sonnet
|
|
(
|
|
Plan::ZedPro,
|
|
LanguageModelProvider::Anthropic,
|
|
"claude-3-5-sonnet",
|
|
true,
|
|
),
|
|
// Free plan should have access to claude-3.5-sonnet
|
|
(
|
|
Plan::Free,
|
|
LanguageModelProvider::Anthropic,
|
|
"claude-3-5-sonnet",
|
|
true,
|
|
),
|
|
// Pro plan should NOT have access to other Anthropic models
|
|
(
|
|
Plan::ZedPro,
|
|
LanguageModelProvider::Anthropic,
|
|
"claude-3-opus",
|
|
false,
|
|
),
|
|
];
|
|
|
|
for (plan, provider, model, expected_access) in test_cases {
|
|
let claims = LlmTokenClaims {
|
|
plan,
|
|
..Default::default()
|
|
};
|
|
|
|
let result = authorize_access_to_language_model(
|
|
&config,
|
|
&claims,
|
|
Some("US".into()),
|
|
provider,
|
|
model,
|
|
);
|
|
|
|
if expected_access {
|
|
assert!(
|
|
result.is_ok(),
|
|
"Expected access to be granted for plan {:?}, provider {:?}, model {}",
|
|
plan,
|
|
provider,
|
|
model
|
|
);
|
|
} else {
|
|
let error = result.expect_err(&format!(
|
|
"Expected access to be denied for plan {:?}, provider {:?}, model {}",
|
|
plan, provider, model
|
|
));
|
|
let response = error.into_response();
|
|
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
|
}
|
|
}
|
|
}
|
|
|
|
#[gpui::test]
|
|
async fn test_authorize_access_to_language_model_for_staff() {
|
|
let config = Config::test();
|
|
|
|
let claims = LlmTokenClaims {
|
|
is_staff: true,
|
|
..Default::default()
|
|
};
|
|
|
|
// Staff should have access to all models
|
|
let test_cases = vec![
|
|
(LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
|
|
(LanguageModelProvider::Anthropic, "claude-2"),
|
|
(LanguageModelProvider::Anthropic, "claude-123-agi"),
|
|
(LanguageModelProvider::OpenAi, "gpt-4"),
|
|
(LanguageModelProvider::Google, "gemini-pro"),
|
|
];
|
|
|
|
for (provider, model) in test_cases {
|
|
let result = authorize_access_to_language_model(
|
|
&config,
|
|
&claims,
|
|
Some("US".into()),
|
|
provider,
|
|
model,
|
|
);
|
|
|
|
assert!(
|
|
result.is_ok(),
|
|
"Expected staff to have access to provider {:?}, model {}",
|
|
provider,
|
|
model
|
|
);
|
|
}
|
|
}
|
|
}
|