This commit is contained in:
Nathan Sobo 2022-04-25 20:05:09 -06:00
parent 35bec69fa4
commit 538fc23a77
3 changed files with 99 additions and 93 deletions

View file

@ -7,42 +7,45 @@ use anyhow::anyhow;
use axum::{
body::Body,
extract::{Path, Query},
http::StatusCode,
routing::{delete, get, post, put},
Json, Router,
http::{self, Request, StatusCode},
middleware::{self, Next},
response::IntoResponse,
routing::{get, post, put},
Extension, Json, Router,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tower::ServiceBuilder;
pub fn add_routes(router: Router<Body>, app: Arc<AppState>) -> Router<Body> {
router
.route("/users", {
let app = app.clone();
get(move || get_users(app))
})
.route("/users", {
let app = app.clone();
post(move |params| create_user(params, app))
})
.route("/users/:id", {
let app = app.clone();
put(move |user_id, params| update_user(user_id, params, app))
})
.route("/users/:id", {
let app = app.clone();
delete(move |user_id| destroy_user(user_id, app))
})
.route("/users/:github_login", {
let app = app.clone();
get(move |github_login| get_user(github_login, app))
})
.route("/users/:github_login/access_tokens", {
let app = app.clone();
post(move |github_login, params| create_access_token(github_login, params, app))
})
pub fn routes(state: Arc<AppState>) -> Router<Body> {
Router::new()
.route("/users", get(get_users).post(create_user))
.route("/users/:id", put(update_user).delete(destroy_user))
.route("/users/:gh_login", get(get_user))
.route("/users/:gh_login/access_tokens", post(create_access_token))
.layer(
ServiceBuilder::new()
.layer(Extension(state))
.layer(middleware::from_fn(validate_api_token)),
)
}
async fn get_users(app: Arc<AppState>) -> Result<Json<Vec<User>>> {
pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
let mut auth_header = req
.headers()
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or_else(|| {
Error::Http(
StatusCode::BAD_REQUEST,
"missing authorization header".to_string(),
)
})?;
Ok::<_, Error>(next.run(req).await)
}
async fn get_users(Extension(app): Extension<Arc<AppState>>) -> Result<Json<Vec<User>>> {
let users = app.db.get_all_users().await?;
Ok(Json(users))
}
@ -55,7 +58,7 @@ struct CreateUserParams {
async fn create_user(
Json(params): Json<CreateUserParams>,
app: Arc<AppState>,
Extension(app): Extension<Arc<AppState>>,
) -> Result<Json<User>> {
let user_id = app
.db
@ -79,7 +82,7 @@ struct UpdateUserParams {
async fn update_user(
Path(user_id): Path<i32>,
Json(params): Json<UpdateUserParams>,
app: Arc<AppState>,
Extension(app): Extension<Arc<AppState>>,
) -> Result<()> {
app.db
.set_user_is_admin(UserId(user_id), params.admin)
@ -87,12 +90,18 @@ async fn update_user(
Ok(())
}
async fn destroy_user(Path(user_id): Path<i32>, app: Arc<AppState>) -> Result<()> {
async fn destroy_user(
Path(user_id): Path<i32>,
Extension(app): Extension<Arc<AppState>>,
) -> Result<()> {
app.db.destroy_user(UserId(user_id)).await?;
Ok(())
}
async fn get_user(Path(login): Path<String>, app: Arc<AppState>) -> Result<Json<User>> {
async fn get_user(
Path(login): Path<String>,
Extension(app): Extension<Arc<AppState>>,
) -> Result<Json<User>> {
let user = app
.db
.get_user_by_github_login(&login)
@ -116,7 +125,7 @@ struct CreateAccessTokenResponse {
async fn create_access_token(
Path(login): Path<String>,
Query(params): Query<CreateAccessTokenQueryParams>,
app: Arc<AppState>,
Extension(app): Extension<Arc<AppState>>,
) -> Result<Json<CreateAccessTokenResponse>> {
// request.require_token().await?;

View file

@ -1,54 +1,64 @@
use std::sync::Arc;
use super::db::{self, UserId};
use crate::{AppState, Error};
use anyhow::{Context, Result};
use axum::{
http::{self, Request, StatusCode},
middleware::Next,
response::IntoResponse,
};
use rand::thread_rng;
use scrypt::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Scrypt,
};
// pub async fn process_auth_header(request: &Request) -> tide::Result<UserId> {
// let mut auth_header = request
// .header("Authorization")
// .ok_or_else(|| {
// Error::new(
// StatusCode::BadRequest,
// anyhow!("missing authorization header"),
// )
// })?
// .last()
// .as_str()
// .split_whitespace();
// let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
// Error::new(
// StatusCode::BadRequest,
// anyhow!("missing user id in authorization header"),
// )
// })?);
// let access_token = auth_header.next().ok_or_else(|| {
// Error::new(
// StatusCode::BadRequest,
// anyhow!("missing access token in authorization header"),
// )
// })?;
pub async fn validate_header<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
let mut auth_header = req
.headers()
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.ok_or_else(|| {
Error::Http(
StatusCode::BAD_REQUEST,
"missing authorization header".to_string(),
)
})?
.split_whitespace();
// let state = request.state().clone();
// let mut credentials_valid = false;
// for password_hash in state.db.get_access_token_hashes(user_id).await? {
// if verify_access_token(&access_token, &password_hash)? {
// credentials_valid = true;
// break;
// }
// }
let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
Error::Http(
StatusCode::BAD_REQUEST,
"missing user id in authorization header".to_string(),
)
})?);
// if !credentials_valid {
// Err(Error::new(
// StatusCode::Unauthorized,
// anyhow!("invalid credentials"),
// ))?;
// }
let access_token = auth_header.next().ok_or_else(|| {
Error::Http(
StatusCode::BAD_REQUEST,
"missing access token in authorization header".to_string(),
)
})?;
// Ok(user_id)
// }
let state = req.extensions().get::<Arc<AppState>>().unwrap();
let mut credentials_valid = false;
for password_hash in state.db.get_access_token_hashes(user_id).await? {
if verify_access_token(&access_token, &password_hash)? {
credentials_valid = true;
break;
}
}
if !credentials_valid {
Err(Error::Http(
StatusCode::UNAUTHORIZED,
"invalid credentials".to_string(),
))?;
}
Ok::<_, Error>(next.run(req).await)
}
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;

View file

@ -11,8 +11,6 @@ use db::{Db, PostgresDb};
use serde::Deserialize;
use std::{net::TcpListener, sync::Arc};
// type Request = tide::Request<Arc<AppState>>;
#[derive(Default, Deserialize)]
pub struct Config {
pub http_port: u16,
@ -22,31 +20,20 @@ pub struct Config {
pub struct AppState {
db: Arc<dyn Db>,
config: Config,
api_token: String,
}
impl AppState {
async fn new(config: Config) -> Result<Arc<Self>> {
let db = PostgresDb::new(&config.database_url, 5).await?;
let this = Self {
db: Arc::new(db),
config,
api_token: config.api_token.clone(),
};
Ok(Arc::new(this))
}
}
// trait RequestExt {
// fn db(&self) -> &Arc<dyn Db>;
// }
// impl RequestExt for Request<Body> {
// fn db(&self) -> &Arc<dyn Db> {
// &self.data::<Arc<AppState>>().unwrap().db
// }
// }
#[tokio::main]
async fn main() -> Result<()> {
if std::env::var("LOG_JSON").is_ok() {
@ -68,7 +55,7 @@ async fn main() -> Result<()> {
run_server(
state.clone(),
rpc,
TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
.expect("failed to bind TCP listener"),
)
.await?;
@ -80,11 +67,11 @@ pub async fn run_server(
peer: Arc<Peer>,
listener: TcpListener,
) -> Result<()> {
let app = Router::<Body>::new();
// TODO: Compression on API routes?
// TODO: Authenticate API routes.
let app = api::add_routes(app, state);
let app = Router::<Body>::new().merge(api::routes(state.clone()));
// TODO: Add rpc routes
axum::Server::from_tcp(listener)?