diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs
index 80f2682f4a..180066907e 100644
--- a/crates/collab/src/api.rs
+++ b/crates/collab/src/api.rs
@@ -7,42 +7,45 @@ use anyhow::anyhow;
use axum::{
body::Body,
extract::{Path, Query},
- http::StatusCode,
- routing::{delete, get, post, put},
- Json, Router,
+ http::{self, Request, StatusCode},
+ middleware::{self, Next},
+ response::IntoResponse,
+ routing::{get, post, put},
+ Extension, Json, Router,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
+use tower::ServiceBuilder;
-pub fn add_routes(router: Router
, app: Arc) -> Router {
- router
- .route("/users", {
- let app = app.clone();
- get(move || get_users(app))
- })
- .route("/users", {
- let app = app.clone();
- post(move |params| create_user(params, app))
- })
- .route("/users/:id", {
- let app = app.clone();
- put(move |user_id, params| update_user(user_id, params, app))
- })
- .route("/users/:id", {
- let app = app.clone();
- delete(move |user_id| destroy_user(user_id, app))
- })
- .route("/users/:github_login", {
- let app = app.clone();
- get(move |github_login| get_user(github_login, app))
- })
- .route("/users/:github_login/access_tokens", {
- let app = app.clone();
- post(move |github_login, params| create_access_token(github_login, params, app))
- })
+pub fn routes(state: Arc) -> Router {
+ Router::new()
+ .route("/users", get(get_users).post(create_user))
+ .route("/users/:id", put(update_user).delete(destroy_user))
+ .route("/users/:gh_login", get(get_user))
+ .route("/users/:gh_login/access_tokens", post(create_access_token))
+ .layer(
+ ServiceBuilder::new()
+ .layer(Extension(state))
+ .layer(middleware::from_fn(validate_api_token)),
+ )
}
-async fn get_users(app: Arc) -> Result>> {
+pub async fn validate_api_token(req: Request, next: Next) -> impl IntoResponse {
+ let mut auth_header = req
+ .headers()
+ .get(http::header::AUTHORIZATION)
+ .and_then(|header| header.to_str().ok())
+ .ok_or_else(|| {
+ Error::Http(
+ StatusCode::BAD_REQUEST,
+ "missing authorization header".to_string(),
+ )
+ })?;
+
+ Ok::<_, Error>(next.run(req).await)
+}
+
+async fn get_users(Extension(app): Extension>) -> Result>> {
let users = app.db.get_all_users().await?;
Ok(Json(users))
}
@@ -55,7 +58,7 @@ struct CreateUserParams {
async fn create_user(
Json(params): Json,
- app: Arc,
+ Extension(app): Extension>,
) -> Result> {
let user_id = app
.db
@@ -79,7 +82,7 @@ struct UpdateUserParams {
async fn update_user(
Path(user_id): Path,
Json(params): Json,
- app: Arc,
+ Extension(app): Extension>,
) -> Result<()> {
app.db
.set_user_is_admin(UserId(user_id), params.admin)
@@ -87,12 +90,18 @@ async fn update_user(
Ok(())
}
-async fn destroy_user(Path(user_id): Path, app: Arc) -> Result<()> {
+async fn destroy_user(
+ Path(user_id): Path,
+ Extension(app): Extension>,
+) -> Result<()> {
app.db.destroy_user(UserId(user_id)).await?;
Ok(())
}
-async fn get_user(Path(login): Path, app: Arc) -> Result> {
+async fn get_user(
+ Path(login): Path,
+ Extension(app): Extension>,
+) -> Result> {
let user = app
.db
.get_user_by_github_login(&login)
@@ -116,7 +125,7 @@ struct CreateAccessTokenResponse {
async fn create_access_token(
Path(login): Path,
Query(params): Query,
- app: Arc,
+ Extension(app): Extension>,
) -> Result> {
// request.require_token().await?;
diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs
index 4fb31749e8..39ae919a69 100644
--- a/crates/collab/src/auth.rs
+++ b/crates/collab/src/auth.rs
@@ -1,54 +1,64 @@
+use std::sync::Arc;
+
use super::db::{self, UserId};
+use crate::{AppState, Error};
use anyhow::{Context, Result};
+use axum::{
+ http::{self, Request, StatusCode},
+ middleware::Next,
+ response::IntoResponse,
+};
use rand::thread_rng;
use scrypt::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Scrypt,
};
-// pub async fn process_auth_header(request: &Request) -> tide::Result {
-// let mut auth_header = request
-// .header("Authorization")
-// .ok_or_else(|| {
-// Error::new(
-// StatusCode::BadRequest,
-// anyhow!("missing authorization header"),
-// )
-// })?
-// .last()
-// .as_str()
-// .split_whitespace();
-// let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
-// Error::new(
-// StatusCode::BadRequest,
-// anyhow!("missing user id in authorization header"),
-// )
-// })?);
-// let access_token = auth_header.next().ok_or_else(|| {
-// Error::new(
-// StatusCode::BadRequest,
-// anyhow!("missing access token in authorization header"),
-// )
-// })?;
+pub async fn validate_header(req: Request, next: Next) -> impl IntoResponse {
+ let mut auth_header = req
+ .headers()
+ .get(http::header::AUTHORIZATION)
+ .and_then(|header| header.to_str().ok())
+ .ok_or_else(|| {
+ Error::Http(
+ StatusCode::BAD_REQUEST,
+ "missing authorization header".to_string(),
+ )
+ })?
+ .split_whitespace();
-// let state = request.state().clone();
-// let mut credentials_valid = false;
-// for password_hash in state.db.get_access_token_hashes(user_id).await? {
-// if verify_access_token(&access_token, &password_hash)? {
-// credentials_valid = true;
-// break;
-// }
-// }
+ let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
+ Error::Http(
+ StatusCode::BAD_REQUEST,
+ "missing user id in authorization header".to_string(),
+ )
+ })?);
-// if !credentials_valid {
-// Err(Error::new(
-// StatusCode::Unauthorized,
-// anyhow!("invalid credentials"),
-// ))?;
-// }
+ let access_token = auth_header.next().ok_or_else(|| {
+ Error::Http(
+ StatusCode::BAD_REQUEST,
+ "missing access token in authorization header".to_string(),
+ )
+ })?;
-// Ok(user_id)
-// }
+ let state = req.extensions().get::>().unwrap();
+ let mut credentials_valid = false;
+ for password_hash in state.db.get_access_token_hashes(user_id).await? {
+ if verify_access_token(&access_token, &password_hash)? {
+ credentials_valid = true;
+ break;
+ }
+ }
+
+ if !credentials_valid {
+ Err(Error::Http(
+ StatusCode::UNAUTHORIZED,
+ "invalid credentials".to_string(),
+ ))?;
+ }
+
+ Ok::<_, Error>(next.run(req).await)
+}
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs
index c0ea6ba77c..b3da6df4a8 100644
--- a/crates/collab/src/main.rs
+++ b/crates/collab/src/main.rs
@@ -11,8 +11,6 @@ use db::{Db, PostgresDb};
use serde::Deserialize;
use std::{net::TcpListener, sync::Arc};
-// type Request = tide::Request>;
-
#[derive(Default, Deserialize)]
pub struct Config {
pub http_port: u16,
@@ -22,31 +20,20 @@ pub struct Config {
pub struct AppState {
db: Arc,
- config: Config,
+ api_token: String,
}
impl AppState {
async fn new(config: Config) -> Result> {
let db = PostgresDb::new(&config.database_url, 5).await?;
-
let this = Self {
db: Arc::new(db),
- config,
+ api_token: config.api_token.clone(),
};
Ok(Arc::new(this))
}
}
-// trait RequestExt {
-// fn db(&self) -> &Arc;
-// }
-
-// impl RequestExt for Request {
-// fn db(&self) -> &Arc {
-// &self.data::>().unwrap().db
-// }
-// }
-
#[tokio::main]
async fn main() -> Result<()> {
if std::env::var("LOG_JSON").is_ok() {
@@ -68,7 +55,7 @@ async fn main() -> Result<()> {
run_server(
state.clone(),
rpc,
- TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
+ TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
.expect("failed to bind TCP listener"),
)
.await?;
@@ -80,11 +67,11 @@ pub async fn run_server(
peer: Arc,
listener: TcpListener,
) -> Result<()> {
- let app = Router::::new();
// TODO: Compression on API routes?
// TODO: Authenticate API routes.
- let app = api::add_routes(app, state);
+ let app = Router::::new().merge(api::routes(state.clone()));
+
// TODO: Add rpc routes
axum::Server::from_tcp(listener)?