Store the impersonator id on access tokens created via ZED_IMPERSONATE

* Use the impersonator id to prevent these tokens from counting
  against the impersonated user when limiting the users' total
  of access tokens.
* When connecting using an access token with an impersonator
  add the impersonator as a field to the tracing span that wraps
  the task for that connection.
* Disallow impersonating users via the admin API token in production,
  because when using the admin API token, we aren't able to identify
  the impersonator.

Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-01-17 15:46:36 -08:00
parent 9521f49160
commit ab1bea515c
9 changed files with 198 additions and 39 deletions

View file

@ -27,6 +27,9 @@ lazy_static! {
.unwrap();
}
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Impersonator(pub Option<db::User>);
/// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN
/// and one for the access tokens that we issue.
pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
@ -57,28 +60,50 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
})?;
let state = req.extensions().get::<Arc<AppState>>().unwrap();
let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
state.config.api_token == admin_token
// In development, allow impersonation using the admin API token.
// Don't allow this in production because we can't tell who is doing
// the impersonating.
let validate_result = if let (Some(admin_token), true) = (
access_token.strip_prefix("ADMIN_TOKEN:"),
state.config.is_development(),
) {
Ok(VerifyAccessTokenResult {
is_valid: state.config.api_token == admin_token,
impersonator_id: None,
})
} else {
verify_access_token(&access_token, user_id, &state.db)
.await
.unwrap_or(false)
verify_access_token(&access_token, user_id, &state.db).await
};
if credentials_valid {
let user = state
.db
.get_user_by_id(user_id)
.await?
.ok_or_else(|| anyhow!("user {} not found", user_id))?;
req.extensions_mut().insert(user);
Ok::<_, Error>(next.run(req).await)
} else {
Err(Error::Http(
StatusCode::UNAUTHORIZED,
"invalid credentials".to_string(),
))
if let Ok(validate_result) = validate_result {
if validate_result.is_valid {
let user = state
.db
.get_user_by_id(user_id)
.await?
.ok_or_else(|| anyhow!("user {} not found", user_id))?;
let impersonator = if let Some(impersonator_id) = validate_result.impersonator_id {
let impersonator = state
.db
.get_user_by_id(impersonator_id)
.await?
.ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
Some(impersonator)
} else {
None
};
req.extensions_mut().insert(user);
req.extensions_mut().insert(Impersonator(impersonator));
return Ok::<_, Error>(next.run(req).await);
}
}
Err(Error::Http(
StatusCode::UNAUTHORIZED,
"invalid credentials".to_string(),
))
}
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
@ -92,13 +117,22 @@ struct AccessTokenJson {
/// Creates a new access token to identify the given user. before returning it, you should
/// encrypt it with the user's public key.
pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
pub async fn create_access_token(
db: &db::Database,
user_id: UserId,
impersonator_id: Option<UserId>,
) -> Result<String> {
const VERSION: usize = 1;
let access_token = rpc::auth::random_token();
let access_token_hash =
hash_access_token(&access_token).context("failed to hash access token")?;
let id = db
.create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
.create_access_token(
user_id,
impersonator_id,
&access_token_hash,
MAX_ACCESS_TOKENS_TO_STORE,
)
.await?;
Ok(serde_json::to_string(&AccessTokenJson {
version: VERSION,
@ -137,8 +171,17 @@ pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<St
Ok(encrypted_access_token)
}
pub struct VerifyAccessTokenResult {
pub is_valid: bool,
pub impersonator_id: Option<UserId>,
}
/// verify access token returns true if the given token is valid for the given user.
pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc<Database>) -> Result<bool> {
pub async fn verify_access_token(
token: &str,
user_id: UserId,
db: &Arc<Database>,
) -> Result<VerifyAccessTokenResult> {
let token: AccessTokenJson = serde_json::from_str(&token)?;
let db_token = db.get_access_token(token.id).await?;
@ -154,5 +197,8 @@ pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc<Database
let duration = t0.elapsed();
log::info!("hashed access token in {:?}", duration);
METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64);
Ok(is_valid)
Ok(VerifyAccessTokenResult {
is_valid,
impersonator_id: db_token.impersonator_id,
})
}