collab: Add GET /models endpoint to LLM service (#17307)

This PR adds a `GET /models` endpoint to the LLM service.

This endpoint returns the models that the authenticated user has access
to.

This is the first step towards populating the models for the hosted
service from the server.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-09-03 11:41:32 -04:00 committed by GitHub
parent 122f01f9e5
commit 30056254f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 67 additions and 22 deletions

View file

@ -7,7 +7,7 @@ use crate::{Config, Error, Result};
pub fn authorize_access_to_language_model(
config: &Config,
claims: &LlmTokenClaims,
country_code: Option<String>,
country_code: Option<&str>,
provider: LanguageModelProvider,
model: &str,
) -> Result<()> {
@ -49,7 +49,7 @@ fn authorize_access_to_model(
fn authorize_access_for_country(
config: &Config,
country_code: Option<String>,
country_code: Option<&str>,
provider: LanguageModelProvider,
) -> Result<()> {
// In development we won't have the `CF-IPCountry` header, so we can't check
@ -62,7 +62,7 @@ fn authorize_access_for_country(
}
// https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry
let country_code = match country_code.as_deref() {
let country_code = match country_code {
// `XX` - Used for clients without country code data.
None | Some("XX") => Err(Error::http(
StatusCode::BAD_REQUEST,
@ -128,7 +128,7 @@ mod tests {
authorize_access_to_language_model(
&config,
&claims,
Some(country_code.into()),
Some(country_code),
provider,
"the-model",
)
@ -178,7 +178,7 @@ mod tests {
let error_response = authorize_access_to_language_model(
&config,
&claims,
Some(country_code.into()),
Some(country_code),
provider,
"the-model",
)
@ -223,7 +223,7 @@ mod tests {
let error_response = authorize_access_to_language_model(
&config,
&claims,
Some(country_code.into()),
Some(country_code),
provider,
"the-model",
)
@ -278,13 +278,8 @@ mod tests {
..Default::default()
};
let result = authorize_access_to_language_model(
&config,
&claims,
Some("US".into()),
provider,
model,
);
let result =
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
if expected_access {
assert!(
@ -324,13 +319,8 @@ mod tests {
];
for (provider, model) in test_cases {
let result = authorize_access_to_language_model(
&config,
&claims,
Some("US".into()),
provider,
model,
);
let result =
authorize_access_to_language_model(&config, &claims, Some("US"), provider, model);
assert!(
result.is_ok(),

View file

@ -67,6 +67,14 @@ impl LlmDatabase {
Ok(())
}
/// Returns the list of all known models, with their [`LanguageModelProvider`].
pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> {
self.models
.iter()
.map(|((model_provider, _model_name), model)| (*model_provider, model.clone()))
.collect::<Vec<_>>()
}
/// Returns the names of the known models for the given [`LanguageModelProvider`].
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
self.models