For impersonating access tokens, store impersonatee in the new column

This way, we don't need an index on both columns
This commit is contained in:
Max Brunsfeld 2024-01-17 17:58:59 -08:00
parent 69bff7bb77
commit 9f04fd9019
8 changed files with 38 additions and 45 deletions

View file

@ -19,11 +19,10 @@ CREATE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id");
CREATE TABLE "access_tokens" ( CREATE TABLE "access_tokens" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT, "id" INTEGER PRIMARY KEY AUTOINCREMENT,
"user_id" INTEGER REFERENCES users (id), "user_id" INTEGER REFERENCES users (id),
"impersonator_id" INTEGER REFERENCES users (id), "impersonated_user_id" INTEGER REFERENCES users (id),
"hash" VARCHAR(128) "hash" VARCHAR(128)
); );
CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id"); CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id");
CREATE INDEX "index_access_tokens_impersonator_id" ON "access_tokens" ("impersonator_id");
CREATE TABLE "contacts" ( CREATE TABLE "contacts" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT, "id" INTEGER PRIMARY KEY AUTOINCREMENT,

View file

@ -1,3 +1 @@
ALTER TABLE access_tokens ADD COLUMN impersonator_id integer; ALTER TABLE access_tokens ADD COLUMN impersonated_user_id integer;
CREATE INDEX "index_access_tokens_impersonator_id" ON "access_tokens" ("impersonator_id");

View file

@ -156,13 +156,11 @@ async fn create_access_token(
.await? .await?
.ok_or_else(|| anyhow!("user not found"))?; .ok_or_else(|| anyhow!("user not found"))?;
let mut user_id = user.id; let mut impersonated_user_id = None;
let mut impersonator_id = None;
if let Some(impersonate) = params.impersonate { if let Some(impersonate) = params.impersonate {
if user.admin { if user.admin {
if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? { if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
impersonator_id = Some(user_id); impersonated_user_id = Some(impersonated_user.id);
user_id = impersonated_user.id;
} else { } else {
return Err(Error::Http( return Err(Error::Http(
StatusCode::UNPROCESSABLE_ENTITY, StatusCode::UNPROCESSABLE_ENTITY,
@ -177,12 +175,13 @@ async fn create_access_token(
} }
} }
let access_token = auth::create_access_token(app.db.as_ref(), user_id, impersonator_id).await?; let access_token =
auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
let encrypted_access_token = let encrypted_access_token =
auth::encrypt_access_token(&access_token, params.public_key.clone())?; auth::encrypt_access_token(&access_token, params.public_key.clone())?;
Ok(Json(CreateAccessTokenResponse { Ok(Json(CreateAccessTokenResponse {
user_id, user_id: impersonated_user_id.unwrap_or(user_id),
encrypted_access_token, encrypted_access_token,
})) }))
} }

View file

@ -120,7 +120,7 @@ struct AccessTokenJson {
pub async fn create_access_token( pub async fn create_access_token(
db: &db::Database, db: &db::Database,
user_id: UserId, user_id: UserId,
impersonator_id: Option<UserId>, impersonated_user_id: Option<UserId>,
) -> Result<String> { ) -> Result<String> {
const VERSION: usize = 1; const VERSION: usize = 1;
let access_token = rpc::auth::random_token(); let access_token = rpc::auth::random_token();
@ -129,7 +129,7 @@ pub async fn create_access_token(
let id = db let id = db
.create_access_token( .create_access_token(
user_id, user_id,
impersonator_id, impersonated_user_id,
&access_token_hash, &access_token_hash,
MAX_ACCESS_TOKENS_TO_STORE, MAX_ACCESS_TOKENS_TO_STORE,
) )
@ -185,7 +185,8 @@ pub async fn verify_access_token(
let token: AccessTokenJson = serde_json::from_str(&token)?; let token: AccessTokenJson = serde_json::from_str(&token)?;
let db_token = db.get_access_token(token.id).await?; let db_token = db.get_access_token(token.id).await?;
if db_token.user_id != user_id { let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
if token_user_id != user_id {
return Err(anyhow!("no such access token"))?; return Err(anyhow!("no such access token"))?;
} }
@ -199,6 +200,10 @@ pub async fn verify_access_token(
METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64); METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64);
Ok(VerifyAccessTokenResult { Ok(VerifyAccessTokenResult {
is_valid, is_valid,
impersonator_id: db_token.impersonator_id, impersonator_id: if db_token.impersonated_user_id.is_some() {
Some(db_token.user_id)
} else {
None
},
}) })
} }

View file

@ -6,7 +6,7 @@ impl Database {
pub async fn create_access_token( pub async fn create_access_token(
&self, &self,
user_id: UserId, user_id: UserId,
impersonator_id: Option<UserId>, impersonated_user_id: Option<UserId>,
access_token_hash: &str, access_token_hash: &str,
max_access_token_count: usize, max_access_token_count: usize,
) -> Result<AccessTokenId> { ) -> Result<AccessTokenId> {
@ -15,28 +15,20 @@ impl Database {
let token = access_token::ActiveModel { let token = access_token::ActiveModel {
user_id: ActiveValue::set(user_id), user_id: ActiveValue::set(user_id),
impersonator_id: ActiveValue::set(impersonator_id), impersonated_user_id: ActiveValue::set(impersonated_user_id),
hash: ActiveValue::set(access_token_hash.into()), hash: ActiveValue::set(access_token_hash.into()),
..Default::default() ..Default::default()
} }
.insert(&*tx) .insert(&*tx)
.await?; .await?;
let existing_token_filter = if let Some(impersonator_id) = impersonator_id {
access_token::Column::ImpersonatorId.eq(impersonator_id)
} else {
access_token::Column::UserId
.eq(user_id)
.and(access_token::Column::ImpersonatorId.is_null())
};
access_token::Entity::delete_many() access_token::Entity::delete_many()
.filter( .filter(
access_token::Column::Id.in_subquery( access_token::Column::Id.in_subquery(
Query::select() Query::select()
.column(access_token::Column::Id) .column(access_token::Column::Id)
.from(access_token::Entity) .from(access_token::Entity)
.cond_where(existing_token_filter) .and_where(access_token::Column::UserId.eq(user_id))
.order_by(access_token::Column::Id, sea_orm::Order::Desc) .order_by(access_token::Column::Id, sea_orm::Order::Desc)
.limit(10000) .limit(10000)
.offset(max_access_token_count as u64) .offset(max_access_token_count as u64)

View file

@ -7,7 +7,7 @@ pub struct Model {
#[sea_orm(primary_key)] #[sea_orm(primary_key)]
pub id: AccessTokenId, pub id: AccessTokenId,
pub user_id: UserId, pub user_id: UserId,
pub impersonator_id: Option<UserId>, pub impersonated_user_id: Option<UserId>,
pub hash: String, pub hash: String,
} }

View file

@ -178,7 +178,7 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
access_token::Model { access_token::Model {
id: token_1, id: token_1,
user_id: user_1, user_id: user_1,
impersonator_id: None, impersonated_user_id: None,
hash: "h1".into(), hash: "h1".into(),
} }
); );
@ -187,7 +187,7 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
access_token::Model { access_token::Model {
id: token_2, id: token_2,
user_id: user_1, user_id: user_1,
impersonator_id: None, impersonated_user_id: None,
hash: "h2".into() hash: "h2".into()
} }
); );
@ -198,7 +198,7 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
access_token::Model { access_token::Model {
id: token_3, id: token_3,
user_id: user_1, user_id: user_1,
impersonator_id: None, impersonated_user_id: None,
hash: "h3".into() hash: "h3".into()
} }
); );
@ -207,7 +207,7 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
access_token::Model { access_token::Model {
id: token_2, id: token_2,
user_id: user_1, user_id: user_1,
impersonator_id: None, impersonated_user_id: None,
hash: "h2".into() hash: "h2".into()
} }
); );
@ -219,7 +219,7 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
access_token::Model { access_token::Model {
id: token_4, id: token_4,
user_id: user_1, user_id: user_1,
impersonator_id: None, impersonated_user_id: None,
hash: "h4".into() hash: "h4".into()
} }
); );
@ -228,7 +228,7 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
access_token::Model { access_token::Model {
id: token_3, id: token_3,
user_id: user_1, user_id: user_1,
impersonator_id: None, impersonated_user_id: None,
hash: "h3".into() hash: "h3".into()
} }
); );
@ -238,15 +238,15 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
// An access token for user 2 impersonating user 1 does not // An access token for user 2 impersonating user 1 does not
// count against user 1's access token limit (of 2). // count against user 1's access token limit (of 2).
let token_5 = db let token_5 = db
.create_access_token(user_1, Some(user_2), "h5", 2) .create_access_token(user_2, Some(user_1), "h5", 2)
.await .await
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
db.get_access_token(token_5).await.unwrap(), db.get_access_token(token_5).await.unwrap(),
access_token::Model { access_token::Model {
id: token_5, id: token_5,
user_id: user_1, user_id: user_2,
impersonator_id: Some(user_2), impersonated_user_id: Some(user_1),
hash: "h5".into() hash: "h5".into()
} }
); );
@ -255,7 +255,7 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
access_token::Model { access_token::Model {
id: token_3, id: token_3,
user_id: user_1, user_id: user_1,
impersonator_id: None, impersonated_user_id: None,
hash: "h3".into() hash: "h3".into()
} }
); );
@ -263,19 +263,19 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
// Only a limited number (2) of access tokens are stored for user 2 // Only a limited number (2) of access tokens are stored for user 2
// impersonating other users. // impersonating other users.
let token_6 = db let token_6 = db
.create_access_token(user_1, Some(user_2), "h6", 2) .create_access_token(user_2, Some(user_1), "h6", 2)
.await .await
.unwrap(); .unwrap();
let token_7 = db let token_7 = db
.create_access_token(user_1, Some(user_2), "h7", 2) .create_access_token(user_2, Some(user_1), "h7", 2)
.await .await
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
db.get_access_token(token_6).await.unwrap(), db.get_access_token(token_6).await.unwrap(),
access_token::Model { access_token::Model {
id: token_6, id: token_6,
user_id: user_1, user_id: user_2,
impersonator_id: Some(user_2), impersonated_user_id: Some(user_1),
hash: "h6".into() hash: "h6".into()
} }
); );
@ -283,8 +283,8 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
db.get_access_token(token_7).await.unwrap(), db.get_access_token(token_7).await.unwrap(),
access_token::Model { access_token::Model {
id: token_7, id: token_7,
user_id: user_1, user_id: user_2,
impersonator_id: Some(user_2), impersonated_user_id: Some(user_1),
hash: "h7".into() hash: "h7".into()
} }
); );
@ -294,7 +294,7 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
access_token::Model { access_token::Model {
id: token_3, id: token_3,
user_id: user_1, user_id: user_1,
impersonator_id: None, impersonated_user_id: None,
hash: "h3".into() hash: "h3".into()
} }
); );

View file

@ -14,7 +14,7 @@
- Ensure that the Xcode command line tools are using your newly installed copy of Xcode: - Ensure that the Xcode command line tools are using your newly installed copy of Xcode:
``` ```
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer. sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
``` ```
* Install the Rust wasm toolchain: * Install the Rust wasm toolchain: