Authorize access to language model providers based on country (#15859)
This PR updates the LLM service to authorize access to language model providers based on the requester's country. We detect the country using Cloudflare's [`CF-IPCountry`](https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry) header. The country code is then checked against the list of supported countries for the given LLM provider. Countries that are not supported will receive an `HTTP 451: Unavailable For Legal Reasons` response. Release Notes: - N/A
This commit is contained in:
parent
9c6ccaffe3
commit
cf5f4dddf5
13 changed files with 921 additions and 1 deletions
|
@ -90,6 +90,7 @@ fs = { workspace = true, features = ["test-support"] }
|
|||
git = { workspace = true, features = ["test-support"] }
|
||||
git_hosting_providers.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
hyper.workspace = true
|
||||
indoc.workspace = true
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
language_model = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
@ -185,6 +185,46 @@ impl Config {
|
|||
_ => "https://zed.dev",
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn test() -> Self {
|
||||
Self {
|
||||
http_port: 0,
|
||||
database_url: "".into(),
|
||||
database_max_connections: 0,
|
||||
api_token: "".into(),
|
||||
invite_link_prefix: "".into(),
|
||||
live_kit_server: None,
|
||||
live_kit_key: None,
|
||||
live_kit_secret: None,
|
||||
llm_api_secret: None,
|
||||
rust_log: None,
|
||||
log_json: None,
|
||||
zed_environment: "test".into(),
|
||||
blob_store_url: None,
|
||||
blob_store_region: None,
|
||||
blob_store_access_key: None,
|
||||
blob_store_secret_key: None,
|
||||
blob_store_bucket: None,
|
||||
openai_api_key: None,
|
||||
google_ai_api_key: None,
|
||||
anthropic_api_key: None,
|
||||
clickhouse_url: None,
|
||||
clickhouse_user: None,
|
||||
clickhouse_password: None,
|
||||
clickhouse_database: None,
|
||||
zed_client_checksum_seed: None,
|
||||
slack_panics_webhook: None,
|
||||
auto_join_channel_id: None,
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
stripe_price_id: None,
|
||||
supermaven_admin_api_key: None,
|
||||
qwen2_7b_api_key: None,
|
||||
qwen2_7b_api_url: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The service mode that collab should run in.
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
mod authorization;
|
||||
mod token;
|
||||
|
||||
use crate::api::CloudflareIpCountryHeader;
|
||||
use crate::llm::authorization::authorize_access_to_language_model;
|
||||
use crate::{executor::Executor, Config, Error, Result};
|
||||
use anyhow::Context as _;
|
||||
use axum::TypedHeader;
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{self, HeaderName, HeaderValue, Request, StatusCode},
|
||||
|
@ -91,9 +95,18 @@ async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoR
|
|||
|
||||
async fn perform_completion(
|
||||
Extension(state): Extension<Arc<LlmState>>,
|
||||
Extension(_claims): Extension<LlmTokenClaims>,
|
||||
Extension(claims): Extension<LlmTokenClaims>,
|
||||
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||
Json(params): Json<PerformCompletionParams>,
|
||||
) -> Result<impl IntoResponse> {
|
||||
authorize_access_to_language_model(
|
||||
&state.config,
|
||||
&claims,
|
||||
country_code_header.map(|header| header.to_string()),
|
||||
params.provider,
|
||||
¶ms.model,
|
||||
)?;
|
||||
|
||||
match params.provider {
|
||||
LanguageModelProvider::Anthropic => {
|
||||
let api_key = state
|
||||
|
|
213
crates/collab/src/llm/authorization.rs
Normal file
213
crates/collab/src/llm/authorization.rs
Normal file
|
@ -0,0 +1,213 @@
|
|||
use reqwest::StatusCode;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
use crate::llm::LlmTokenClaims;
|
||||
use crate::{Config, Error, Result};
|
||||
|
||||
pub fn authorize_access_to_language_model(
|
||||
config: &Config,
|
||||
_claims: &LlmTokenClaims,
|
||||
country_code: Option<String>,
|
||||
provider: LanguageModelProvider,
|
||||
model: &str,
|
||||
) -> Result<()> {
|
||||
authorize_access_for_country(config, country_code, provider, model)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn authorize_access_for_country(
|
||||
config: &Config,
|
||||
country_code: Option<String>,
|
||||
provider: LanguageModelProvider,
|
||||
_model: &str,
|
||||
) -> Result<()> {
|
||||
// In development we won't have the `CF-IPCountry` header, so we can't check
|
||||
// the country code.
|
||||
//
|
||||
// This shouldn't be necessary, as anyone running in development will need to provide
|
||||
// their own API credentials in order to use an LLM provider.
|
||||
if config.is_development() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
|
||||
let country_code = match country_code.as_deref() {
|
||||
// `XX` - Used for clients without country code data.
|
||||
None | Some("XX") => Err(Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"no country code".to_string(),
|
||||
))?,
|
||||
// `T1` - Used for clients using the Tor network.
|
||||
Some("T1") => Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
format!("access to {provider:?} models is not available over Tor"),
|
||||
))?,
|
||||
Some(country_code) => country_code,
|
||||
};
|
||||
|
||||
let is_country_supported_by_provider = match provider {
|
||||
LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code),
|
||||
LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code),
|
||||
LanguageModelProvider::Google => google_ai::is_supported_country(country_code),
|
||||
LanguageModelProvider::Zed => true,
|
||||
};
|
||||
if !is_country_supported_by_provider {
|
||||
Err(Error::http(
|
||||
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS,
|
||||
format!("access to {provider:?} models is not available in your region"),
|
||||
))?
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use axum::response::IntoResponse;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::proto::Plan;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_with_supported_country(
|
||||
_cx: &mut gpui::TestAppContext,
|
||||
) {
|
||||
let config = Config::test();
|
||||
|
||||
let claims = LlmTokenClaims {
|
||||
user_id: 99,
|
||||
plan: Plan::ZedPro,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cases = vec![
|
||||
(LanguageModelProvider::Anthropic, "US"), // United States
|
||||
(LanguageModelProvider::Anthropic, "GB"), // United Kingdom
|
||||
(LanguageModelProvider::OpenAi, "US"), // United States
|
||||
(LanguageModelProvider::OpenAi, "GB"), // United Kingdom
|
||||
(LanguageModelProvider::Google, "US"), // United States
|
||||
(LanguageModelProvider::Google, "GB"), // United Kingdom
|
||||
];
|
||||
|
||||
for (provider, country_code) in cases {
|
||||
authorize_access_to_language_model(
|
||||
&config,
|
||||
&claims,
|
||||
Some(country_code.into()),
|
||||
provider,
|
||||
"the-model",
|
||||
)
|
||||
.unwrap_or_else(|_| {
|
||||
panic!("expected authorization to return Ok for {provider:?}: {country_code}")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_with_unsupported_country(
|
||||
_cx: &mut gpui::TestAppContext,
|
||||
) {
|
||||
let config = Config::test();
|
||||
|
||||
let claims = LlmTokenClaims {
|
||||
user_id: 99,
|
||||
plan: Plan::ZedPro,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cases = vec![
|
||||
(LanguageModelProvider::Anthropic, "AF"), // Afghanistan
|
||||
(LanguageModelProvider::Anthropic, "BY"), // Belarus
|
||||
(LanguageModelProvider::Anthropic, "CF"), // Central African Republic
|
||||
(LanguageModelProvider::Anthropic, "CN"), // China
|
||||
(LanguageModelProvider::Anthropic, "CU"), // Cuba
|
||||
(LanguageModelProvider::Anthropic, "ER"), // Eritrea
|
||||
(LanguageModelProvider::Anthropic, "ET"), // Ethiopia
|
||||
(LanguageModelProvider::Anthropic, "IR"), // Iran
|
||||
(LanguageModelProvider::Anthropic, "KP"), // North Korea
|
||||
(LanguageModelProvider::Anthropic, "XK"), // Kosovo
|
||||
(LanguageModelProvider::Anthropic, "LY"), // Libya
|
||||
(LanguageModelProvider::Anthropic, "MM"), // Myanmar
|
||||
(LanguageModelProvider::Anthropic, "RU"), // Russia
|
||||
(LanguageModelProvider::Anthropic, "SO"), // Somalia
|
||||
(LanguageModelProvider::Anthropic, "SS"), // South Sudan
|
||||
(LanguageModelProvider::Anthropic, "SD"), // Sudan
|
||||
(LanguageModelProvider::Anthropic, "SY"), // Syria
|
||||
(LanguageModelProvider::Anthropic, "VE"), // Venezuela
|
||||
(LanguageModelProvider::Anthropic, "YE"), // Yemen
|
||||
(LanguageModelProvider::OpenAi, "KP"), // North Korea
|
||||
(LanguageModelProvider::Google, "KP"), // North Korea
|
||||
];
|
||||
|
||||
for (provider, country_code) in cases {
|
||||
let error_response = authorize_access_to_language_model(
|
||||
&config,
|
||||
&claims,
|
||||
Some(country_code.into()),
|
||||
provider,
|
||||
"the-model",
|
||||
)
|
||||
.expect_err(&format!(
|
||||
"expected authorization to return an error for {provider:?}: {country_code}"
|
||||
))
|
||||
.into_response();
|
||||
|
||||
assert_eq!(
|
||||
error_response.status(),
|
||||
StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
|
||||
);
|
||||
let response_body = hyper::body::to_bytes(error_response.into_body())
|
||||
.await
|
||||
.unwrap()
|
||||
.to_vec();
|
||||
assert_eq!(
|
||||
String::from_utf8(response_body).unwrap(),
|
||||
format!("access to {provider:?} models is not available in your region")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) {
|
||||
let config = Config::test();
|
||||
|
||||
let claims = LlmTokenClaims {
|
||||
user_id: 99,
|
||||
plan: Plan::ZedPro,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cases = vec![
|
||||
(LanguageModelProvider::Anthropic, "T1"), // Tor
|
||||
(LanguageModelProvider::OpenAi, "T1"), // Tor
|
||||
(LanguageModelProvider::Google, "T1"), // Tor
|
||||
(LanguageModelProvider::Zed, "T1"), // Tor
|
||||
];
|
||||
|
||||
for (provider, country_code) in cases {
|
||||
let error_response = authorize_access_to_language_model(
|
||||
&config,
|
||||
&claims,
|
||||
Some(country_code.into()),
|
||||
provider,
|
||||
"the-model",
|
||||
)
|
||||
.expect_err(&format!(
|
||||
"expected authorization to return an error for {provider:?}: {country_code}"
|
||||
))
|
||||
.into_response();
|
||||
|
||||
assert_eq!(error_response.status(), StatusCode::FORBIDDEN);
|
||||
let response_body = hyper::body::to_bytes(error_response.into_body())
|
||||
.await
|
||||
.unwrap()
|
||||
.to_vec();
|
||||
assert_eq!(
|
||||
String::from_utf8(response_body).unwrap(),
|
||||
format!("access to {provider:?} models is not available over Tor")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue