Authorize access to language model providers based on country (#15859)

This PR updates the LLM service to authorize access to language model
providers based on the requester's country.

We detect the country using Cloudflare's
[`CF-IPCountry`](https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry)
header.

The country code is then checked against the list of supported countries
for the given LLM provider. Countries that are not supported will
receive an `HTTP 451: Unavailable For Legal Reasons` response.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-08-06 11:49:04 -04:00 committed by GitHub
parent 9c6ccaffe3
commit cf5f4dddf5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 921 additions and 1 deletions

View file

@ -1,3 +1,5 @@
mod supported_countries;
use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
@ -7,6 +9,8 @@ use serde_json::Value;
use std::{convert::TryFrom, future::Future, time::Duration};
use strum::EnumIter;
pub use supported_countries::*;
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {

View file

@ -0,0 +1,207 @@
use std::collections::HashSet;
use std::sync::LazyLock;
/// Returns whether the given country code is supported by OpenAI.
///
/// https://platform.openai.com/docs/supported-countries
pub fn is_supported_country(country_code: &str) -> bool {
SUPPORTED_COUNTRIES.contains(&country_code)
}
/// The list of country codes supported by OpenAI.
///
/// https://platform.openai.com/docs/supported-countries
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
vec![
"AL", // Albania
"DZ", // Algeria
"AF", // Afghanistan
"AD", // Andorra
"AO", // Angola
"AG", // Antigua and Barbuda
"AR", // Argentina
"AM", // Armenia
"AU", // Australia
"AT", // Austria
"AZ", // Azerbaijan
"BS", // Bahamas
"BH", // Bahrain
"BD", // Bangladesh
"BB", // Barbados
"BE", // Belgium
"BZ", // Belize
"BJ", // Benin
"BT", // Bhutan
"BO", // Bolivia
"BA", // Bosnia and Herzegovina
"BW", // Botswana
"BR", // Brazil
"BN", // Brunei
"BG", // Bulgaria
"BF", // Burkina Faso
"BI", // Burundi
"CV", // Cabo Verde
"KH", // Cambodia
"CM", // Cameroon
"CA", // Canada
"CF", // Central African Republic
"TD", // Chad
"CL", // Chile
"CO", // Colombia
"KM", // Comoros
"CG", // Congo (Brazzaville)
"CD", // Congo (DRC)
"CR", // Costa Rica
"CI", // Côte d'Ivoire
"HR", // Croatia
"CY", // Cyprus
"CZ", // Czechia (Czech Republic)
"DK", // Denmark
"DJ", // Djibouti
"DM", // Dominica
"DO", // Dominican Republic
"EC", // Ecuador
"EG", // Egypt
"SV", // El Salvador
"GQ", // Equatorial Guinea
"ER", // Eritrea
"EE", // Estonia
"SZ", // Eswatini (Swaziland)
"ET", // Ethiopia
"FJ", // Fiji
"FI", // Finland
"FR", // France
"GA", // Gabon
"GM", // Gambia
"GE", // Georgia
"DE", // Germany
"GH", // Ghana
"GR", // Greece
"GD", // Grenada
"GT", // Guatemala
"GN", // Guinea
"GW", // Guinea-Bissau
"GY", // Guyana
"HT", // Haiti
"VA", // Holy See (Vatican City)
"HN", // Honduras
"HU", // Hungary
"IS", // Iceland
"IN", // India
"ID", // Indonesia
"IQ", // Iraq
"IE", // Ireland
"IL", // Israel
"IT", // Italy
"JM", // Jamaica
"JP", // Japan
"JO", // Jordan
"KZ", // Kazakhstan
"KE", // Kenya
"KI", // Kiribati
"KW", // Kuwait
"KG", // Kyrgyzstan
"LA", // Laos
"LV", // Latvia
"LB", // Lebanon
"LS", // Lesotho
"LR", // Liberia
"LY", // Libya
"LI", // Liechtenstein
"LT", // Lithuania
"LU", // Luxembourg
"MG", // Madagascar
"MW", // Malawi
"MY", // Malaysia
"MV", // Maldives
"ML", // Mali
"MT", // Malta
"MH", // Marshall Islands
"MR", // Mauritania
"MU", // Mauritius
"MX", // Mexico
"FM", // Micronesia
"MD", // Moldova
"MC", // Monaco
"MN", // Mongolia
"ME", // Montenegro
"MA", // Morocco
"MZ", // Mozambique
"MM", // Myanmar
"NA", // Namibia
"NR", // Nauru
"NP", // Nepal
"NL", // Netherlands
"NZ", // New Zealand
"NI", // Nicaragua
"NE", // Niger
"NG", // Nigeria
"MK", // North Macedonia
"NO", // Norway
"OM", // Oman
"PK", // Pakistan
"PW", // Palau
"PS", // Palestine
"PA", // Panama
"PG", // Papua New Guinea
"PY", // Paraguay
"PE", // Peru
"PH", // Philippines
"PL", // Poland
"PT", // Portugal
"QA", // Qatar
"RO", // Romania
"RW", // Rwanda
"KN", // Saint Kitts and Nevis
"LC", // Saint Lucia
"VC", // Saint Vincent and the Grenadines
"WS", // Samoa
"SM", // San Marino
"ST", // Sao Tome and Principe
"SA", // Saudi Arabia
"SN", // Senegal
"RS", // Serbia
"SC", // Seychelles
"SL", // Sierra Leone
"SG", // Singapore
"SK", // Slovakia
"SI", // Slovenia
"SB", // Solomon Islands
"SO", // Somalia
"ZA", // South Africa
"KR", // South Korea
"SS", // South Sudan
"ES", // Spain
"LK", // Sri Lanka
"SR", // Suriname
"SE", // Sweden
"CH", // Switzerland
"SD", // Sudan
"TW", // Taiwan
"TJ", // Tajikistan
"TZ", // Tanzania
"TH", // Thailand
"TL", // Timor-Leste (East Timor)
"TG", // Togo
"TO", // Tonga
"TT", // Trinidad and Tobago
"TN", // Tunisia
"TR", // Turkey
"TM", // Turkmenistan
"TV", // Tuvalu
"UG", // Uganda
"UA", // Ukraine (with certain exceptions)
"AE", // United Arab Emirates
"GB", // United Kingdom
"US", // United States of America
"UY", // Uruguay
"UZ", // Uzbekistan
"VU", // Vanuatu
"VN", // Vietnam
"YE", // Yemen
"ZM", // Zambia
"ZW", // Zimbabwe
]
.into_iter()
.collect()
});