diff --git a/Cargo.lock b/Cargo.lock
index 89c04bb990..0ef11676eb 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -953,6 +953,7 @@ dependencies = [
"tokio",
"tokio-tungstenite",
"toml",
+ "tower",
"util",
"workspace",
]
diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml
index 1b44d1228b..8489cc2be6 100644
--- a/crates/collab/Cargo.toml
+++ b/crates/collab/Cargo.toml
@@ -20,7 +20,7 @@ util = { path = "../util" }
anyhow = "1.0.40"
async-trait = "0.1.50"
async-tungstenite = "0.16"
-axum = "0.5"
+axum = { version = "0.5", features = ["json"] }
base64 = "0.13"
envy = "0.4.2"
env_logger = "0.8"
@@ -36,6 +36,7 @@ serde_json = "1.0"
sha-1 = "0.9"
tokio = { version = "1", features = ["full"] }
tokio-tungstenite = "0.17"
+tower = "0.4"
time = "0.2"
toml = "0.5.8"
diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs
index 3bb231e0f3..ffb25c39da 100644
--- a/crates/collab/src/api.rs
+++ b/crates/collab/src/api.rs
@@ -1,20 +1,33 @@
-// use crate::{auth, db::UserId, AppState, Request, RequestExt as _};
-use anyhow::Result;
+use crate::{
+ db::{Db, User, UserId},
+ AppState, Result,
+};
+use anyhow::anyhow;
use axum::{
body::Body,
- http::{Request, Response, StatusCode},
- routing::get,
- Router,
+ extract::Path,
+ http::{Request, StatusCode},
+ response::{IntoResponse, Response},
+ routing::{get, put},
+ Json, Router,
};
use serde::Deserialize;
-use serde_json::json;
use std::sync::Arc;
-use crate::AppState;
-// use surf::StatusCode;
-
-pub fn add_routes(router: Router
) -> Router {
- router.route("/users", get(get_users))
+pub fn add_routes(router: Router, app: Arc) -> Router {
+ router
+ .route("/users", {
+ let app = app.clone();
+ get(move |req| get_users(req, app))
+ })
+ .route("/users", {
+ let app = app.clone();
+ get(move |params| create_user(params, app))
+ })
+ .route("/users/:id", {
+ let app = app.clone();
+ put(move |user_id, params| update_user(user_id, params, app))
+ })
}
// pub fn add_routes(app: &mut tide::Server>) {
@@ -27,65 +40,48 @@ pub fn add_routes(router: Router) -> Router {
// .post(create_access_token);
// }
-async fn get_users(request: Request) -> Result, (StatusCode, String)> {
+async fn get_users(request: Request, app: Arc) -> Result>> {
// request.require_token().await?;
- // let users = request.db().get_all_users().await?;
-
- // Body::from
-
- // let body = "Hello World";
- // Ok(Response::builder()
- // .header(CONTENT_LENGTH, body.len() as u64)
- // .header(CONTENT_TYPE, "text/plain")
- // .body(Body::from(body))?)
-
- // Ok(tide::Response::builder(StatusCode::Ok)
- // .body(tide::Body::from_json(&users)?)
- // .build())
- todo!()
+ let users = app.db.get_all_users().await?;
+ Ok(Json(users))
}
-// async fn get_user(request: Request) -> tide::Result {
-// request.require_token().await?;
+#[derive(Deserialize)]
+struct CreateUser {
+ github_login: String,
+ admin: bool,
+}
-// let user = request
-// .db()
-// .get_user_by_github_login(request.param("github_login")?)
-// .await?
-// .ok_or_else(|| surf::Error::from_str(404, "user not found"))?;
+async fn create_user(Json(params): Json, app: Arc) -> Result> {
+ let user_id = app
+ .db
+ .create_user(¶ms.github_login, params.admin)
+ .await?;
-// Ok(tide::Response::builder(StatusCode::Ok)
-// .body(tide::Body::from_json(&user)?)
-// .build())
-// }
+ let user = app
+ .db
+ .get_user_by_id(user_id)
+ .await?
+ .ok_or_else(|| anyhow!("couldn't find the user we just created"))?;
-// async fn create_user(mut request: Request) -> tide::Result {
-// request.require_token().await?;
+ Ok(Json(user))
+}
-// #[derive(Deserialize)]
-// struct Params {
-// github_login: String,
-// admin: bool,
-// }
-// let params = request.body_json::().await?;
+#[derive(Deserialize)]
+struct UpdateUser {
+ admin: bool,
+}
-// let user_id = request
-// .db()
-// .create_user(¶ms.github_login, params.admin)
-// .await?;
-
-// let user = request.db().get_user_by_id(user_id).await?.ok_or_else(|| {
-// surf::Error::from_str(
-// StatusCode::InternalServerError,
-// "couldn't find the user we just created",
-// )
-// })?;
-
-// Ok(tide::Response::builder(StatusCode::Ok)
-// .body(tide::Body::from_json(&user)?)
-// .build())
-// }
+async fn update_user(
+ Path(user_id): Path,
+ Json(params): Json,
+ app: Arc,
+) -> Result {
+ let user_id = UserId(user_id);
+ app.db.set_user_is_admin(user_id, params.admin).await?;
+ Ok(())
+}
// async fn update_user(mut request: Request) -> tide::Result {
// request.require_token().await?;
@@ -94,13 +90,6 @@ async fn get_users(request: Request) -> Result, (StatusCode
// struct Params {
// admin: bool,
// }
-// let user_id = UserId(
-// request
-// .param("id")?
-// .parse::()
-// .map_err(|error| surf::Error::from_str(StatusCode::BadRequest, error.to_string()))?,
-// );
-// let params = request.body_json::().await?;
// request
// .db()
diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs
index b7737fd17c..6cd264074b 100644
--- a/crates/collab/src/main.rs
+++ b/crates/collab/src/main.rs
@@ -5,8 +5,7 @@ mod env;
mod rpc;
use ::rpc::Peer;
-use anyhow::Result;
-use axum::{body::Body, http::StatusCode, Router};
+use axum::{body::Body, http::StatusCode, response::IntoResponse, Router};
use db::{Db, PostgresDb};
use serde::Deserialize;
@@ -76,24 +75,16 @@ async fn main() -> Result<()> {
Ok(())
}
-async fn handle_anyhow_error(err: anyhow::Error) -> (StatusCode, String) {
- (
- StatusCode::INTERNAL_SERVER_ERROR,
- format!("Something went wrong: {}", err),
- )
-}
-
pub async fn run_server(
state: Arc,
peer: Arc,
listener: TcpListener,
) -> Result<()> {
let app = Router::::new();
- // TODO: Assign app state to request somehow
// TODO: Compression on API routes?
// TODO: Authenticate API routes.
- let app = api::add_routes(app);
+ let app = api::add_routes(app, state);
// TODO: Add rpc routes
axum::Server::from_tcp(listener)?
@@ -102,3 +93,34 @@ pub async fn run_server(
Ok(())
}
+
+type Result = std::result::Result;
+
+struct Error(anyhow::Error);
+
+impl From for Error
+where
+ E: Into,
+{
+ fn from(error: E) -> Self {
+ Self(error.into())
+ }
+}
+
+impl IntoResponse for Error {
+ fn into_response(self) -> axum::response::Response {
+ (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &self.0)).into_response()
+ }
+}
+
+impl std::fmt::Debug for Error {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+impl std::fmt::Display for Error {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.fmt(f)
+ }
+}