collab: Add ability to revoke LLM service access tokens (#16143)

This PR adds the ability to revoke access tokens for the LLM service.

There is a new `revoked_access_tokens` table that contains the
identifiers (`jti`) of revoked access tokens.

To revoke an access token, insert a record into this table:

```sql
insert into revoked_access_tokens (jti) values ('1e887b9e-37f5-49e8-8feb-3274e5a86b67');
```

We now attach the `jti` as `authn.jti` to the tracing spans so that we
can associate an access token with a given request to the LLM service.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-08-12 21:47:05 -04:00 committed by GitHub
parent 0bc9fc9487
commit b4c22cc861
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 54 additions and 0 deletions

View file

@ -0,0 +1,7 @@
create table revoked_access_tokens (
id serial primary key,
jti text not null,
revoked_at timestamp without time zone not null default now()
);
create unique index uix_revoked_access_tokens_on_jti on revoked_access_tokens (jti);

View file

@ -131,6 +131,15 @@ async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoR
let state = req.extensions().get::<Arc<LlmState>>().unwrap(); let state = req.extensions().get::<Arc<LlmState>>().unwrap();
match LlmTokenClaims::validate(&token, &state.config) { match LlmTokenClaims::validate(&token, &state.config) {
Ok(claims) => { Ok(claims) => {
if state.db.is_access_token_revoked(&claims.jti).await? {
return Err(Error::http(
StatusCode::UNAUTHORIZED,
"unauthorized".to_string(),
));
}
tracing::Span::current().record("authn.jti", &claims.jti);
req.extensions_mut().insert(claims); req.extensions_mut().insert(claims);
Ok::<_, Error>(next.run(req).await.into_response()) Ok::<_, Error>(next.run(req).await.into_response())
} }

View file

@ -7,3 +7,4 @@ id_type!(ModelId);
id_type!(ProviderId); id_type!(ProviderId);
id_type!(UsageId); id_type!(UsageId);
id_type!(UsageMeasureId); id_type!(UsageMeasureId);
id_type!(RevokedAccessTokenId);

View file

@ -1,4 +1,5 @@
use super::*; use super::*;
pub mod providers; pub mod providers;
pub mod revoked_access_tokens;
pub mod usages; pub mod usages;

View file

@ -0,0 +1,15 @@
use super::*;
impl LlmDatabase {
/// Returns whether the access token with the given `jti` has been revoked.
pub async fn is_access_token_revoked(&self, jti: &str) -> Result<bool> {
self.transaction(|tx| async move {
Ok(revoked_access_token::Entity::find()
.filter(revoked_access_token::Column::Jti.eq(jti))
.one(&*tx)
.await?
.is_some())
})
.await
}
}

View file

@ -1,5 +1,6 @@
pub mod lifetime_usage; pub mod lifetime_usage;
pub mod model; pub mod model;
pub mod provider; pub mod provider;
pub mod revoked_access_token;
pub mod usage; pub mod usage;
pub mod usage_measure; pub mod usage_measure;

View file

@ -0,0 +1,19 @@
use chrono::NaiveDateTime;
use sea_orm::entity::prelude::*;
use crate::llm::db::RevokedAccessTokenId;
/// A revoked access token.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "revoked_access_tokens")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: RevokedAccessTokenId,
pub jti: String,
pub revoked_at: NaiveDateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -150,6 +150,7 @@ async fn main() -> Result<()> {
"http_request", "http_request",
method = ?request.method(), method = ?request.method(),
matched_path, matched_path,
authn.jti = tracing::field::Empty
) )
}) })
.on_response( .on_response(