diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index ffb25c39da..80f2682f4a 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,59 +1,62 @@ use crate::{ - db::{Db, User, UserId}, - AppState, Result, + auth, + db::{User, UserId}, + AppState, Error, Result, }; use anyhow::anyhow; use axum::{ body::Body, - extract::Path, - http::{Request, StatusCode}, - response::{IntoResponse, Response}, - routing::{get, put}, + extract::{Path, Query}, + http::StatusCode, + routing::{delete, get, post, put}, Json, Router, }; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::sync::Arc; pub fn add_routes(router: Router, app: Arc) -> Router { router .route("/users", { let app = app.clone(); - get(move |req| get_users(req, app)) + get(move || get_users(app)) }) .route("/users", { let app = app.clone(); - get(move |params| create_user(params, app)) + post(move |params| create_user(params, app)) }) .route("/users/:id", { let app = app.clone(); put(move |user_id, params| update_user(user_id, params, app)) }) + .route("/users/:id", { + let app = app.clone(); + delete(move |user_id| destroy_user(user_id, app)) + }) + .route("/users/:github_login", { + let app = app.clone(); + get(move |github_login| get_user(github_login, app)) + }) + .route("/users/:github_login/access_tokens", { + let app = app.clone(); + post(move |github_login, params| create_access_token(github_login, params, app)) + }) } -// pub fn add_routes(app: &mut tide::Server>) { -// app.at("/users").get(get_users); -// app.at("/users").post(create_user); -// app.at("/users/:id").put(update_user); -// app.at("/users/:id").delete(destroy_user); -// app.at("/users/:github_login").get(get_user); -// app.at("/users/:github_login/access_tokens") -// .post(create_access_token); -// } - -async fn get_users(request: Request, app: Arc) -> Result>> { - // request.require_token().await?; - +async fn get_users(app: Arc) -> Result>> { let users = app.db.get_all_users().await?; Ok(Json(users)) } #[derive(Deserialize)] -struct CreateUser { +struct CreateUserParams { github_login: String, admin: bool, } -async fn create_user(Json(params): Json, app: Arc) -> Result> { +async fn create_user( + Json(params): Json, + app: Arc, +) -> Result> { let user_id = app .db .create_user(¶ms.github_login, params.admin) @@ -69,102 +72,88 @@ async fn create_user(Json(params): Json, app: Arc) -> Resu } #[derive(Deserialize)] -struct UpdateUser { +struct UpdateUserParams { admin: bool, } async fn update_user( Path(user_id): Path, - Json(params): Json, + Json(params): Json, app: Arc, -) -> Result { - let user_id = UserId(user_id); - app.db.set_user_is_admin(user_id, params.admin).await?; +) -> Result<()> { + app.db + .set_user_is_admin(UserId(user_id), params.admin) + .await?; Ok(()) } -// async fn update_user(mut request: Request) -> tide::Result { -// request.require_token().await?; +async fn destroy_user(Path(user_id): Path, app: Arc) -> Result<()> { + app.db.destroy_user(UserId(user_id)).await?; + Ok(()) +} -// #[derive(Deserialize)] -// struct Params { -// admin: bool, -// } +async fn get_user(Path(login): Path, app: Arc) -> Result> { + let user = app + .db + .get_user_by_github_login(&login) + .await? + .ok_or_else(|| anyhow!("user not found"))?; + Ok(Json(user)) +} -// request -// .db() -// .set_user_is_admin(user_id, params.admin) -// .await?; +#[derive(Deserialize)] +struct CreateAccessTokenQueryParams { + public_key: String, + impersonate: Option, +} -// Ok(tide::Response::builder(StatusCode::Ok).build()) -// } +#[derive(Serialize)] +struct CreateAccessTokenResponse { + user_id: UserId, + encrypted_access_token: String, +} -// async fn destroy_user(request: Request) -> tide::Result { -// request.require_token().await?; -// let user_id = UserId( -// request -// .param("id")? -// .parse::() -// .map_err(|error| surf::Error::from_str(StatusCode::BadRequest, error.to_string()))?, -// ); +async fn create_access_token( + Path(login): Path, + Query(params): Query, + app: Arc, +) -> Result> { + // request.require_token().await?; -// request.db().destroy_user(user_id).await?; + let user = app + .db + .get_user_by_github_login(&login) + .await? + .ok_or_else(|| anyhow!("user not found"))?; -// Ok(tide::Response::builder(StatusCode::Ok).build()) -// } + let mut user_id = user.id; + if let Some(impersonate) = params.impersonate { + if user.admin { + if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? { + user_id = impersonated_user.id; + } else { + return Err(Error::Http( + StatusCode::UNPROCESSABLE_ENTITY, + format!("user {impersonate} does not exist"), + )); + } + } else { + return Err(Error::Http( + StatusCode::UNAUTHORIZED, + format!("you do not have permission to impersonate other users"), + )); + } + } -// async fn create_access_token(request: Request) -> tide::Result { -// request.require_token().await?; + let access_token = auth::create_access_token(app.db.as_ref(), user_id).await?; + let encrypted_access_token = + auth::encrypt_access_token(&access_token, params.public_key.clone())?; -// let user = request -// .db() -// .get_user_by_github_login(request.param("github_login")?) -// .await? -// .ok_or_else(|| surf::Error::from_str(StatusCode::NotFound, "user not found"))?; - -// #[derive(Deserialize)] -// struct QueryParams { -// public_key: String, -// impersonate: Option, -// } - -// let query_params: QueryParams = request.query().map_err(|_| { -// surf::Error::from_str(StatusCode::UnprocessableEntity, "invalid query params") -// })?; - -// let mut user_id = user.id; -// if let Some(impersonate) = query_params.impersonate { -// if user.admin { -// if let Some(impersonated_user) = -// request.db().get_user_by_github_login(&impersonate).await? -// { -// user_id = impersonated_user.id; -// } else { -// return Ok(tide::Response::builder(StatusCode::UnprocessableEntity) -// .body(format!( -// "Can't impersonate non-existent user {}", -// impersonate -// )) -// .build()); -// } -// } else { -// return Ok(tide::Response::builder(StatusCode::Unauthorized) -// .body(format!( -// "Can't impersonate user {} because the real user isn't an admin", -// impersonate -// )) -// .build()); -// } -// } - -// let access_token = auth::create_access_token(request.db().as_ref(), user_id).await?; -// let encrypted_access_token = -// auth::encrypt_access_token(&access_token, query_params.public_key.clone())?; - -// Ok(tide::Response::builder(StatusCode::Ok) -// .body(json!({"user_id": user_id, "encrypted_access_token": encrypted_access_token})) -// .build()) -// } + Ok(Json(CreateAccessTokenResponse { + user_id, + encrypted_access_token, + })) +} // #[async_trait] // pub trait RequestExt { diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 9bbd949641..4fb31749e8 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -1,18 +1,10 @@ -// use super::{ -// db::{self, UserId}, -// errors::TideResultExt, -// }; -// use crate::Request; -// use anyhow::{anyhow, Context}; -// use rand::thread_rng; -// use rpc::auth as zed_auth; -// use scrypt::{ -// password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, -// Scrypt, -// }; -// use std::convert::TryFrom; -// use surf::StatusCode; -// use tide::Error; +use super::db::{self, UserId}; +use anyhow::{Context, Result}; +use rand::thread_rng; +use scrypt::{ + password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, + Scrypt, +}; // pub async fn process_auth_header(request: &Request) -> tide::Result { // let mut auth_header = request @@ -58,45 +50,45 @@ // Ok(user_id) // } -// const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; +const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; -// pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> tide::Result { -// let access_token = zed_auth::random_token(); -// let access_token_hash = -// hash_access_token(&access_token).context("failed to hash access token")?; -// db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE) -// .await?; -// Ok(access_token) -// } +pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> Result { + let access_token = rpc::auth::random_token(); + let access_token_hash = + hash_access_token(&access_token).context("failed to hash access token")?; + db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE) + .await?; + Ok(access_token) +} -// fn hash_access_token(token: &str) -> tide::Result { -// // Avoid slow hashing in debug mode. -// let params = if cfg!(debug_assertions) { -// scrypt::Params::new(1, 1, 1).unwrap() -// } else { -// scrypt::Params::recommended() -// }; +fn hash_access_token(token: &str) -> Result { + // Avoid slow hashing in debug mode. + let params = if cfg!(debug_assertions) { + scrypt::Params::new(1, 1, 1).unwrap() + } else { + scrypt::Params::recommended() + }; -// Ok(Scrypt -// .hash_password( -// token.as_bytes(), -// None, -// params, -// &SaltString::generate(thread_rng()), -// )? -// .to_string()) -// } + Ok(Scrypt + .hash_password( + token.as_bytes(), + None, + params, + &SaltString::generate(thread_rng()), + )? + .to_string()) +} -// pub fn encrypt_access_token(access_token: &str, public_key: String) -> tide::Result { -// let native_app_public_key = -// zed_auth::PublicKey::try_from(public_key).context("failed to parse app public key")?; -// let encrypted_access_token = native_app_public_key -// .encrypt_string(&access_token) -// .context("failed to encrypt access token with public key")?; -// Ok(encrypted_access_token) -// } +pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result { + let native_app_public_key = + rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?; + let encrypted_access_token = native_app_public_key + .encrypt_string(&access_token) + .context("failed to encrypt access token with public key")?; + Ok(encrypted_access_token) +} -// pub fn verify_access_token(token: &str, hash: &str) -> tide::Result { -// let hash = PasswordHash::new(hash)?; -// Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok()) -// } +pub fn verify_access_token(token: &str, hash: &str) -> Result { + let hash = PasswordHash::new(hash)?; + Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok()) +} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 6cd264074b..c0ea6ba77c 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -94,33 +94,47 @@ pub async fn run_server( Ok(()) } -type Result = std::result::Result; +pub type Result = std::result::Result; -struct Error(anyhow::Error); +pub enum Error { + Http(StatusCode, String), + Internal(anyhow::Error), +} impl From for Error where E: Into, { fn from(error: E) -> Self { - Self(error.into()) + Self::Internal(error.into()) } } impl IntoResponse for Error { fn into_response(self) -> axum::response::Response { - (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &self.0)).into_response() + match self { + Error::Http(code, message) => (code, message).into_response(), + Error::Internal(error) => { + (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() + } + } } } impl std::fmt::Debug for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) + match self { + Error::Http(code, message) => (code, message).fmt(f), + Error::Internal(error) => error.fmt(f), + } } } impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) + match self { + Error::Http(code, message) => write!(f, "{code}: {message}"), + Error::Internal(error) => error.fmt(f), + } } }