diff --git a/Cargo.lock b/Cargo.lock index 89c04bb990..0ef11676eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -953,6 +953,7 @@ dependencies = [ "tokio", "tokio-tungstenite", "toml", + "tower", "util", "workspace", ] diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 1b44d1228b..8489cc2be6 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -20,7 +20,7 @@ util = { path = "../util" } anyhow = "1.0.40" async-trait = "0.1.50" async-tungstenite = "0.16" -axum = "0.5" +axum = { version = "0.5", features = ["json"] } base64 = "0.13" envy = "0.4.2" env_logger = "0.8" @@ -36,6 +36,7 @@ serde_json = "1.0" sha-1 = "0.9" tokio = { version = "1", features = ["full"] } tokio-tungstenite = "0.17" +tower = "0.4" time = "0.2" toml = "0.5.8" diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 3bb231e0f3..ffb25c39da 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,20 +1,33 @@ -// use crate::{auth, db::UserId, AppState, Request, RequestExt as _}; -use anyhow::Result; +use crate::{ + db::{Db, User, UserId}, + AppState, Result, +}; +use anyhow::anyhow; use axum::{ body::Body, - http::{Request, Response, StatusCode}, - routing::get, - Router, + extract::Path, + http::{Request, StatusCode}, + response::{IntoResponse, Response}, + routing::{get, put}, + Json, Router, }; use serde::Deserialize; -use serde_json::json; use std::sync::Arc; -use crate::AppState; -// use surf::StatusCode; - -pub fn add_routes(router: Router) -> Router { - router.route("/users", get(get_users)) +pub fn add_routes(router: Router, app: Arc) -> Router { + router + .route("/users", { + let app = app.clone(); + get(move |req| get_users(req, app)) + }) + .route("/users", { + let app = app.clone(); + get(move |params| create_user(params, app)) + }) + .route("/users/:id", { + let app = app.clone(); + put(move |user_id, params| update_user(user_id, params, app)) + }) } // pub fn add_routes(app: &mut tide::Server>) { @@ -27,65 +40,48 @@ pub fn add_routes(router: Router) -> Router { // .post(create_access_token); // } -async fn get_users(request: Request) -> Result, (StatusCode, String)> { +async fn get_users(request: Request, app: Arc) -> Result>> { // request.require_token().await?; - // let users = request.db().get_all_users().await?; - - // Body::from - - // let body = "Hello World"; - // Ok(Response::builder() - // .header(CONTENT_LENGTH, body.len() as u64) - // .header(CONTENT_TYPE, "text/plain") - // .body(Body::from(body))?) - - // Ok(tide::Response::builder(StatusCode::Ok) - // .body(tide::Body::from_json(&users)?) - // .build()) - todo!() + let users = app.db.get_all_users().await?; + Ok(Json(users)) } -// async fn get_user(request: Request) -> tide::Result { -// request.require_token().await?; +#[derive(Deserialize)] +struct CreateUser { + github_login: String, + admin: bool, +} -// let user = request -// .db() -// .get_user_by_github_login(request.param("github_login")?) -// .await? -// .ok_or_else(|| surf::Error::from_str(404, "user not found"))?; +async fn create_user(Json(params): Json, app: Arc) -> Result> { + let user_id = app + .db + .create_user(¶ms.github_login, params.admin) + .await?; -// Ok(tide::Response::builder(StatusCode::Ok) -// .body(tide::Body::from_json(&user)?) -// .build()) -// } + let user = app + .db + .get_user_by_id(user_id) + .await? + .ok_or_else(|| anyhow!("couldn't find the user we just created"))?; -// async fn create_user(mut request: Request) -> tide::Result { -// request.require_token().await?; + Ok(Json(user)) +} -// #[derive(Deserialize)] -// struct Params { -// github_login: String, -// admin: bool, -// } -// let params = request.body_json::().await?; +#[derive(Deserialize)] +struct UpdateUser { + admin: bool, +} -// let user_id = request -// .db() -// .create_user(¶ms.github_login, params.admin) -// .await?; - -// let user = request.db().get_user_by_id(user_id).await?.ok_or_else(|| { -// surf::Error::from_str( -// StatusCode::InternalServerError, -// "couldn't find the user we just created", -// ) -// })?; - -// Ok(tide::Response::builder(StatusCode::Ok) -// .body(tide::Body::from_json(&user)?) -// .build()) -// } +async fn update_user( + Path(user_id): Path, + Json(params): Json, + app: Arc, +) -> Result { + let user_id = UserId(user_id); + app.db.set_user_is_admin(user_id, params.admin).await?; + Ok(()) +} // async fn update_user(mut request: Request) -> tide::Result { // request.require_token().await?; @@ -94,13 +90,6 @@ async fn get_users(request: Request) -> Result, (StatusCode // struct Params { // admin: bool, // } -// let user_id = UserId( -// request -// .param("id")? -// .parse::() -// .map_err(|error| surf::Error::from_str(StatusCode::BadRequest, error.to_string()))?, -// ); -// let params = request.body_json::().await?; // request // .db() diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index b7737fd17c..6cd264074b 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -5,8 +5,7 @@ mod env; mod rpc; use ::rpc::Peer; -use anyhow::Result; -use axum::{body::Body, http::StatusCode, Router}; +use axum::{body::Body, http::StatusCode, response::IntoResponse, Router}; use db::{Db, PostgresDb}; use serde::Deserialize; @@ -76,24 +75,16 @@ async fn main() -> Result<()> { Ok(()) } -async fn handle_anyhow_error(err: anyhow::Error) -> (StatusCode, String) { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Something went wrong: {}", err), - ) -} - pub async fn run_server( state: Arc, peer: Arc, listener: TcpListener, ) -> Result<()> { let app = Router::::new(); - // TODO: Assign app state to request somehow // TODO: Compression on API routes? // TODO: Authenticate API routes. - let app = api::add_routes(app); + let app = api::add_routes(app, state); // TODO: Add rpc routes axum::Server::from_tcp(listener)? @@ -102,3 +93,34 @@ pub async fn run_server( Ok(()) } + +type Result = std::result::Result; + +struct Error(anyhow::Error); + +impl From for Error +where + E: Into, +{ + fn from(error: E) -> Self { + Self(error.into()) + } +} + +impl IntoResponse for Error { + fn into_response(self) -> axum::response::Response { + (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &self.0)).into_response() + } +} + +impl std::fmt::Debug for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +}