Lookup access tokens by id when authenticating a connection
This avoids the cost of hashing an access token multiple times, to compare it to all known access tokens for a given user. Co-authored-by: Antonio Scandurra <antonio@zed.dev>
This commit is contained in:
parent
3464961aa4
commit
26dae3c04e
3 changed files with 86 additions and 46 deletions
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
db::{self, UserId},
|
||||
db::{self, AccessTokenId, UserId},
|
||||
AppState, Error, Result,
|
||||
};
|
||||
use anyhow::{anyhow, Context};
|
||||
|
@ -13,6 +13,7 @@ use scrypt::{
|
|||
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||
Scrypt,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||
|
@ -42,20 +43,19 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
|||
)
|
||||
})?;
|
||||
|
||||
let mut credentials_valid = false;
|
||||
let state = req.extensions().get::<Arc<AppState>>().unwrap();
|
||||
if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
|
||||
if state.config.api_token == admin_token {
|
||||
credentials_valid = true;
|
||||
}
|
||||
let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
|
||||
state.config.api_token == admin_token
|
||||
} else {
|
||||
for password_hash in state.db.get_access_token_hashes(user_id).await? {
|
||||
if verify_access_token(access_token, &password_hash)? {
|
||||
credentials_valid = true;
|
||||
break;
|
||||
}
|
||||
let access_token: AccessTokenJson = serde_json::from_str(&access_token)?;
|
||||
|
||||
let token = state.db.get_access_token(access_token.id).await?;
|
||||
if token.user_id != user_id {
|
||||
return Err(anyhow!("no such access token"))?;
|
||||
}
|
||||
}
|
||||
|
||||
verify_access_token(&access_token.token, &token.hash)?
|
||||
};
|
||||
|
||||
if credentials_valid {
|
||||
let user = state
|
||||
|
@ -75,13 +75,26 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
|||
|
||||
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct AccessTokenJson {
|
||||
version: usize,
|
||||
id: AccessTokenId,
|
||||
token: String,
|
||||
}
|
||||
|
||||
pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
|
||||
const VERSION: usize = 1;
|
||||
let access_token = rpc::auth::random_token();
|
||||
let access_token_hash =
|
||||
hash_access_token(&access_token).context("failed to hash access token")?;
|
||||
db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
|
||||
let id = db
|
||||
.create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
|
||||
.await?;
|
||||
Ok(access_token)
|
||||
Ok(serde_json::to_string(&AccessTokenJson {
|
||||
version: VERSION,
|
||||
id,
|
||||
token: access_token,
|
||||
})?)
|
||||
}
|
||||
|
||||
fn hash_access_token(token: &str) -> Result<String> {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue