collab: Add `GET /models` endpoint to LLM service (#17307)
This PR adds a `GET /models` endpoint to the LLM service. This endpoint returns the models that the authenticated user has access to. This is the first step towards populating the models for the hosted service from the server.

Release Notes:

- N/A
This commit is contained in:
parent
122f01f9e5
commit
30056254f3
4 changed files with 67 additions and 22 deletions
|
@ -9,6 +9,7 @@ use crate::{
|
|||
};
|
||||
use anyhow::{anyhow, Context as _};
|
||||
use authorization::authorize_access_to_language_model;
|
||||
use axum::routing::get;
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{self, HeaderName, HeaderValue, Request, StatusCode},
|
||||
|
@ -22,6 +23,7 @@ use collections::HashMap;
|
|||
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
|
||||
use futures::{Stream, StreamExt as _};
|
||||
use http_client::IsahcHttpClient;
|
||||
use rpc::ListModelsResponse;
|
||||
use rpc::{
|
||||
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
};
|
||||
|
@ -114,6 +116,7 @@ impl LlmState {
|
|||
|
||||
pub fn routes() -> Router<(), Body> {
|
||||
Router::new()
|
||||
.route("/models", get(list_models))
|
||||
.route("/completion", post(perform_completion))
|
||||
.layer(middleware::from_fn(validate_api_token))
|
||||
}
|
||||
|
@ -173,6 +176,37 @@ async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoR
|
|||
}
|
||||
}
|
||||
|
||||
async fn list_models(
|
||||
Extension(state): Extension<Arc<LlmState>>,
|
||||
Extension(claims): Extension<LlmTokenClaims>,
|
||||
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||
) -> Result<Json<ListModelsResponse>> {
|
||||
let country_code = country_code_header.map(|header| header.to_string());
|
||||
|
||||
let mut accessible_models = Vec::new();
|
||||
|
||||
for (provider, model) in state.db.all_models() {
|
||||
let authorize_result = authorize_access_to_language_model(
|
||||
&state.config,
|
||||
&claims,
|
||||
country_code.as_deref(),
|
||||
provider,
|
||||
&model.name,
|
||||
);
|
||||
|
||||
if authorize_result.is_ok() {
|
||||
accessible_models.push(rpc::LanguageModel {
|
||||
provider,
|
||||
name: model.name,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(ListModelsResponse {
|
||||
models: accessible_models,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn perform_completion(
|
||||
Extension(state): Extension<Arc<LlmState>>,
|
||||
Extension(claims): Extension<LlmTokenClaims>,
|
||||
|
@ -187,7 +221,9 @@ async fn perform_completion(
|
|||
authorize_access_to_language_model(
|
||||
&state.config,
|
||||
&claims,
|
||||
country_code_header.map(|header| header.to_string()),
|
||||
country_code_header
|
||||
.map(|header| header.to_string())
|
||||
.as_deref(),
|
||||
params.provider,
|
||||
&model,
|
||||
)?;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue