Lookup access tokens by id when authenticating a connection
This avoids the cost of hashing an access token multiple times, to compare it to all known access tokens for a given user. Co-authored-by: Antonio Scandurra <antonio@zed.dev>
This commit is contained in:
parent
3464961aa4
commit
26dae3c04e
3 changed files with 86 additions and 46 deletions
|
@ -1,5 +1,5 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
db::{self, UserId},
|
db::{self, AccessTokenId, UserId},
|
||||||
AppState, Error, Result,
|
AppState, Error, Result,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{anyhow, Context};
|
||||||
|
@ -13,6 +13,7 @@ use scrypt::{
|
||||||
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||||
Scrypt,
|
Scrypt,
|
||||||
};
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||||
|
@ -42,20 +43,19 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let mut credentials_valid = false;
|
|
||||||
let state = req.extensions().get::<Arc<AppState>>().unwrap();
|
let state = req.extensions().get::<Arc<AppState>>().unwrap();
|
||||||
if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
|
let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
|
||||||
if state.config.api_token == admin_token {
|
state.config.api_token == admin_token
|
||||||
credentials_valid = true;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
for password_hash in state.db.get_access_token_hashes(user_id).await? {
|
let access_token: AccessTokenJson = serde_json::from_str(&access_token)?;
|
||||||
if verify_access_token(access_token, &password_hash)? {
|
|
||||||
credentials_valid = true;
|
let token = state.db.get_access_token(access_token.id).await?;
|
||||||
break;
|
if token.user_id != user_id {
|
||||||
}
|
return Err(anyhow!("no such access token"))?;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
verify_access_token(&access_token.token, &token.hash)?
|
||||||
|
};
|
||||||
|
|
||||||
if credentials_valid {
|
if credentials_valid {
|
||||||
let user = state
|
let user = state
|
||||||
|
@ -75,13 +75,26 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
||||||
|
|
||||||
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
|
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct AccessTokenJson {
|
||||||
|
version: usize,
|
||||||
|
id: AccessTokenId,
|
||||||
|
token: String,
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
|
pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
|
||||||
|
const VERSION: usize = 1;
|
||||||
let access_token = rpc::auth::random_token();
|
let access_token = rpc::auth::random_token();
|
||||||
let access_token_hash =
|
let access_token_hash =
|
||||||
hash_access_token(&access_token).context("failed to hash access token")?;
|
hash_access_token(&access_token).context("failed to hash access token")?;
|
||||||
db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
|
let id = db
|
||||||
|
.create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(access_token)
|
Ok(serde_json::to_string(&AccessTokenJson {
|
||||||
|
version: VERSION,
|
||||||
|
id,
|
||||||
|
token: access_token,
|
||||||
|
})?)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn hash_access_token(token: &str) -> Result<String> {
|
fn hash_access_token(token: &str) -> Result<String> {
|
||||||
|
|
|
@ -2746,16 +2746,16 @@ impl Database {
|
||||||
|
|
||||||
// access tokens
|
// access tokens
|
||||||
|
|
||||||
pub async fn create_access_token_hash(
|
pub async fn create_access_token(
|
||||||
&self,
|
&self,
|
||||||
user_id: UserId,
|
user_id: UserId,
|
||||||
access_token_hash: &str,
|
access_token_hash: &str,
|
||||||
max_access_token_count: usize,
|
max_access_token_count: usize,
|
||||||
) -> Result<()> {
|
) -> Result<AccessTokenId> {
|
||||||
self.transaction(|tx| async {
|
self.transaction(|tx| async {
|
||||||
let tx = tx;
|
let tx = tx;
|
||||||
|
|
||||||
access_token::ActiveModel {
|
let token = access_token::ActiveModel {
|
||||||
user_id: ActiveValue::set(user_id),
|
user_id: ActiveValue::set(user_id),
|
||||||
hash: ActiveValue::set(access_token_hash.into()),
|
hash: ActiveValue::set(access_token_hash.into()),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
@ -2778,26 +2778,20 @@ impl Database {
|
||||||
)
|
)
|
||||||
.exec(&*tx)
|
.exec(&*tx)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(token.id)
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
|
pub async fn get_access_token(
|
||||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
|
&self,
|
||||||
enum QueryAs {
|
access_token_id: AccessTokenId,
|
||||||
Hash,
|
) -> Result<access_token::Model> {
|
||||||
}
|
|
||||||
|
|
||||||
self.transaction(|tx| async move {
|
self.transaction(|tx| async move {
|
||||||
Ok(access_token::Entity::find()
|
Ok(access_token::Entity::find_by_id(access_token_id)
|
||||||
.select_only()
|
.one(&*tx)
|
||||||
.column(access_token::Column::Hash)
|
.await?
|
||||||
.filter(access_token::Column::UserId.eq(user_id))
|
.ok_or_else(|| anyhow!("no such access token"))?)
|
||||||
.order_by_desc(access_token::Column::Id)
|
|
||||||
.into_values::<_, QueryAs>()
|
|
||||||
.all(&*tx)
|
|
||||||
.await?)
|
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
@ -177,30 +177,63 @@ test_both_dbs!(
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.user_id;
|
.user_id;
|
||||||
|
|
||||||
db.create_access_token_hash(user, "h1", 3).await.unwrap();
|
let token_1 = db.create_access_token(user, "h1", 2).await.unwrap();
|
||||||
db.create_access_token_hash(user, "h2", 3).await.unwrap();
|
let token_2 = db.create_access_token(user, "h2", 2).await.unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
db.get_access_token_hashes(user).await.unwrap(),
|
db.get_access_token(token_1).await.unwrap(),
|
||||||
&["h2".to_string(), "h1".to_string()]
|
access_token::Model {
|
||||||
|
id: token_1,
|
||||||
|
user_id: user,
|
||||||
|
hash: "h1".into(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
db.get_access_token(token_2).await.unwrap(),
|
||||||
|
access_token::Model {
|
||||||
|
id: token_2,
|
||||||
|
user_id: user,
|
||||||
|
hash: "h2".into()
|
||||||
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
db.create_access_token_hash(user, "h3", 3).await.unwrap();
|
let token_3 = db.create_access_token(user, "h3", 2).await.unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
db.get_access_token_hashes(user).await.unwrap(),
|
db.get_access_token(token_3).await.unwrap(),
|
||||||
&["h3".to_string(), "h2".to_string(), "h1".to_string(),]
|
access_token::Model {
|
||||||
|
id: token_3,
|
||||||
|
user_id: user,
|
||||||
|
hash: "h3".into()
|
||||||
|
}
|
||||||
);
|
);
|
||||||
|
assert_eq!(
|
||||||
|
db.get_access_token(token_2).await.unwrap(),
|
||||||
|
access_token::Model {
|
||||||
|
id: token_2,
|
||||||
|
user_id: user,
|
||||||
|
hash: "h2".into()
|
||||||
|
}
|
||||||
|
);
|
||||||
|
assert!(db.get_access_token(token_1).await.is_err());
|
||||||
|
|
||||||
db.create_access_token_hash(user, "h4", 3).await.unwrap();
|
let token_4 = db.create_access_token(user, "h4", 2).await.unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
db.get_access_token_hashes(user).await.unwrap(),
|
db.get_access_token(token_4).await.unwrap(),
|
||||||
&["h4".to_string(), "h3".to_string(), "h2".to_string(),]
|
access_token::Model {
|
||||||
|
id: token_4,
|
||||||
|
user_id: user,
|
||||||
|
hash: "h4".into()
|
||||||
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
db.create_access_token_hash(user, "h5", 3).await.unwrap();
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
db.get_access_token_hashes(user).await.unwrap(),
|
db.get_access_token(token_3).await.unwrap(),
|
||||||
&["h5".to_string(), "h4".to_string(), "h3".to_string()]
|
access_token::Model {
|
||||||
|
id: token_3,
|
||||||
|
user_id: user,
|
||||||
|
hash: "h3".into()
|
||||||
|
}
|
||||||
);
|
);
|
||||||
|
assert!(db.get_access_token(token_2).await.is_err());
|
||||||
|
assert!(db.get_access_token(token_1).await.is_err());
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue