diff --git a/crates/collab/migrations_llm/20240813002237_add_revoked_access_tokens_table.sql b/crates/collab/migrations_llm/20240813002237_add_revoked_access_tokens_table.sql new file mode 100644 index 0000000000..c30e58a6dd --- /dev/null +++ b/crates/collab/migrations_llm/20240813002237_add_revoked_access_tokens_table.sql @@ -0,0 +1,7 @@ +create table revoked_access_tokens ( + id serial primary key, + jti text not null, + revoked_at timestamp without time zone not null default now() +); + +create unique index uix_revoked_access_tokens_on_jti on revoked_access_tokens (jti); diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 42f4db7a38..609610f15c 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -131,6 +131,15 @@ async fn validate_api_token(mut req: Request, next: Next) -> impl IntoR let state = req.extensions().get::>().unwrap(); match LlmTokenClaims::validate(&token, &state.config) { Ok(claims) => { + if state.db.is_access_token_revoked(&claims.jti).await? { + return Err(Error::http( + StatusCode::UNAUTHORIZED, + "unauthorized".to_string(), + )); + } + + tracing::Span::current().record("authn.jti", &claims.jti); + req.extensions_mut().insert(claims); Ok::<_, Error>(next.run(req).await.into_response()) } diff --git a/crates/collab/src/llm/db/ids.rs b/crates/collab/src/llm/db/ids.rs index d0705024df..8cc8a0f974 100644 --- a/crates/collab/src/llm/db/ids.rs +++ b/crates/collab/src/llm/db/ids.rs @@ -7,3 +7,4 @@ id_type!(ModelId); id_type!(ProviderId); id_type!(UsageId); id_type!(UsageMeasureId); +id_type!(RevokedAccessTokenId); diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs index ded7a54cfb..907d0589f3 100644 --- a/crates/collab/src/llm/db/queries.rs +++ b/crates/collab/src/llm/db/queries.rs @@ -1,4 +1,5 @@ use super::*; pub mod providers; +pub mod revoked_access_tokens; pub mod usages; diff --git a/crates/collab/src/llm/db/queries/revoked_access_tokens.rs b/crates/collab/src/llm/db/queries/revoked_access_tokens.rs new file mode 100644 index 0000000000..31d70192a0 --- /dev/null +++ b/crates/collab/src/llm/db/queries/revoked_access_tokens.rs @@ -0,0 +1,15 @@ +use super::*; + +impl LlmDatabase { + /// Returns whether the access token with the given `jti` has been revoked. + pub async fn is_access_token_revoked(&self, jti: &str) -> Result { + self.transaction(|tx| async move { + Ok(revoked_access_token::Entity::find() + .filter(revoked_access_token::Column::Jti.eq(jti)) + .one(&*tx) + .await? + .is_some()) + }) + .await + } +} diff --git a/crates/collab/src/llm/db/tables.rs b/crates/collab/src/llm/db/tables.rs index 2333c20a2e..4beefe2b5d 100644 --- a/crates/collab/src/llm/db/tables.rs +++ b/crates/collab/src/llm/db/tables.rs @@ -1,5 +1,6 @@ pub mod lifetime_usage; pub mod model; pub mod provider; +pub mod revoked_access_token; pub mod usage; pub mod usage_measure; diff --git a/crates/collab/src/llm/db/tables/revoked_access_token.rs b/crates/collab/src/llm/db/tables/revoked_access_token.rs new file mode 100644 index 0000000000..364963be88 --- /dev/null +++ b/crates/collab/src/llm/db/tables/revoked_access_token.rs @@ -0,0 +1,19 @@ +use chrono::NaiveDateTime; +use sea_orm::entity::prelude::*; + +use crate::llm::db::RevokedAccessTokenId; + +/// A revoked access token. +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "revoked_access_tokens")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: RevokedAccessTokenId, + pub jti: String, + pub revoked_at: NaiveDateTime, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index e0f3c7e573..6994109443 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -150,6 +150,7 @@ async fn main() -> Result<()> { "http_request", method = ?request.method(), matched_path, + authn.jti = tracing::field::Empty ) }) .on_response(