From 538fc23a77cc306657ee7221c9a516707f474f31 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Mon, 25 Apr 2022 20:05:09 -0600 Subject: [PATCH] WIP --- crates/collab/src/api.rs | 79 +++++++++++++++++++--------------- crates/collab/src/auth.rs | 90 ++++++++++++++++++++++----------------- crates/collab/src/main.rs | 23 +++------- 3 files changed, 99 insertions(+), 93 deletions(-) diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 80f2682f4a..180066907e 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -7,42 +7,45 @@ use anyhow::anyhow; use axum::{ body::Body, extract::{Path, Query}, - http::StatusCode, - routing::{delete, get, post, put}, - Json, Router, + http::{self, Request, StatusCode}, + middleware::{self, Next}, + response::IntoResponse, + routing::{get, post, put}, + Extension, Json, Router, }; use serde::{Deserialize, Serialize}; use std::sync::Arc; +use tower::ServiceBuilder; -pub fn add_routes(router: Router, app: Arc) -> Router { - router - .route("/users", { - let app = app.clone(); - get(move || get_users(app)) - }) - .route("/users", { - let app = app.clone(); - post(move |params| create_user(params, app)) - }) - .route("/users/:id", { - let app = app.clone(); - put(move |user_id, params| update_user(user_id, params, app)) - }) - .route("/users/:id", { - let app = app.clone(); - delete(move |user_id| destroy_user(user_id, app)) - }) - .route("/users/:github_login", { - let app = app.clone(); - get(move |github_login| get_user(github_login, app)) - }) - .route("/users/:github_login/access_tokens", { - let app = app.clone(); - post(move |github_login, params| create_access_token(github_login, params, app)) - }) +pub fn routes(state: Arc) -> Router { + Router::new() + .route("/users", get(get_users).post(create_user)) + .route("/users/:id", put(update_user).delete(destroy_user)) + .route("/users/:gh_login", get(get_user)) + .route("/users/:gh_login/access_tokens", post(create_access_token)) + .layer( + ServiceBuilder::new() + .layer(Extension(state)) + .layer(middleware::from_fn(validate_api_token)), + ) } -async fn get_users(app: Arc) -> Result>> { +pub async fn validate_api_token(req: Request, next: Next) -> impl IntoResponse { + let mut auth_header = req + .headers() + .get(http::header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()) + .ok_or_else(|| { + Error::Http( + StatusCode::BAD_REQUEST, + "missing authorization header".to_string(), + ) + })?; + + Ok::<_, Error>(next.run(req).await) +} + +async fn get_users(Extension(app): Extension>) -> Result>> { let users = app.db.get_all_users().await?; Ok(Json(users)) } @@ -55,7 +58,7 @@ struct CreateUserParams { async fn create_user( Json(params): Json, - app: Arc, + Extension(app): Extension>, ) -> Result> { let user_id = app .db @@ -79,7 +82,7 @@ struct UpdateUserParams { async fn update_user( Path(user_id): Path, Json(params): Json, - app: Arc, + Extension(app): Extension>, ) -> Result<()> { app.db .set_user_is_admin(UserId(user_id), params.admin) @@ -87,12 +90,18 @@ async fn update_user( Ok(()) } -async fn destroy_user(Path(user_id): Path, app: Arc) -> Result<()> { +async fn destroy_user( + Path(user_id): Path, + Extension(app): Extension>, +) -> Result<()> { app.db.destroy_user(UserId(user_id)).await?; Ok(()) } -async fn get_user(Path(login): Path, app: Arc) -> Result> { +async fn get_user( + Path(login): Path, + Extension(app): Extension>, +) -> Result> { let user = app .db .get_user_by_github_login(&login) @@ -116,7 +125,7 @@ struct CreateAccessTokenResponse { async fn create_access_token( Path(login): Path, Query(params): Query, - app: Arc, + Extension(app): Extension>, ) -> Result> { // request.require_token().await?; diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 4fb31749e8..39ae919a69 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -1,54 +1,64 @@ +use std::sync::Arc; + use super::db::{self, UserId}; +use crate::{AppState, Error}; use anyhow::{Context, Result}; +use axum::{ + http::{self, Request, StatusCode}, + middleware::Next, + response::IntoResponse, +}; use rand::thread_rng; use scrypt::{ password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Scrypt, }; -// pub async fn process_auth_header(request: &Request) -> tide::Result { -// let mut auth_header = request -// .header("Authorization") -// .ok_or_else(|| { -// Error::new( -// StatusCode::BadRequest, -// anyhow!("missing authorization header"), -// ) -// })? -// .last() -// .as_str() -// .split_whitespace(); -// let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| { -// Error::new( -// StatusCode::BadRequest, -// anyhow!("missing user id in authorization header"), -// ) -// })?); -// let access_token = auth_header.next().ok_or_else(|| { -// Error::new( -// StatusCode::BadRequest, -// anyhow!("missing access token in authorization header"), -// ) -// })?; +pub async fn validate_header(req: Request, next: Next) -> impl IntoResponse { + let mut auth_header = req + .headers() + .get(http::header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()) + .ok_or_else(|| { + Error::Http( + StatusCode::BAD_REQUEST, + "missing authorization header".to_string(), + ) + })? + .split_whitespace(); -// let state = request.state().clone(); -// let mut credentials_valid = false; -// for password_hash in state.db.get_access_token_hashes(user_id).await? { -// if verify_access_token(&access_token, &password_hash)? { -// credentials_valid = true; -// break; -// } -// } + let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| { + Error::Http( + StatusCode::BAD_REQUEST, + "missing user id in authorization header".to_string(), + ) + })?); -// if !credentials_valid { -// Err(Error::new( -// StatusCode::Unauthorized, -// anyhow!("invalid credentials"), -// ))?; -// } + let access_token = auth_header.next().ok_or_else(|| { + Error::Http( + StatusCode::BAD_REQUEST, + "missing access token in authorization header".to_string(), + ) + })?; -// Ok(user_id) -// } + let state = req.extensions().get::>().unwrap(); + let mut credentials_valid = false; + for password_hash in state.db.get_access_token_hashes(user_id).await? { + if verify_access_token(&access_token, &password_hash)? { + credentials_valid = true; + break; + } + } + + if !credentials_valid { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "invalid credentials".to_string(), + ))?; + } + + Ok::<_, Error>(next.run(req).await) +} const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index c0ea6ba77c..b3da6df4a8 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -11,8 +11,6 @@ use db::{Db, PostgresDb}; use serde::Deserialize; use std::{net::TcpListener, sync::Arc}; -// type Request = tide::Request>; - #[derive(Default, Deserialize)] pub struct Config { pub http_port: u16, @@ -22,31 +20,20 @@ pub struct Config { pub struct AppState { db: Arc, - config: Config, + api_token: String, } impl AppState { async fn new(config: Config) -> Result> { let db = PostgresDb::new(&config.database_url, 5).await?; - let this = Self { db: Arc::new(db), - config, + api_token: config.api_token.clone(), }; Ok(Arc::new(this)) } } -// trait RequestExt { -// fn db(&self) -> &Arc; -// } - -// impl RequestExt for Request { -// fn db(&self) -> &Arc { -// &self.data::>().unwrap().db -// } -// } - #[tokio::main] async fn main() -> Result<()> { if std::env::var("LOG_JSON").is_ok() { @@ -68,7 +55,7 @@ async fn main() -> Result<()> { run_server( state.clone(), rpc, - TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port)) + TcpListener::bind(&format!("0.0.0.0:{}", config.http_port)) .expect("failed to bind TCP listener"), ) .await?; @@ -80,11 +67,11 @@ pub async fn run_server( peer: Arc, listener: TcpListener, ) -> Result<()> { - let app = Router::::new(); // TODO: Compression on API routes? // TODO: Authenticate API routes. - let app = api::add_routes(app, state); + let app = Router::::new().merge(api::routes(state.clone())); + // TODO: Add rpc routes axum::Server::from_tcp(listener)?