diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 91136b46d0..b6721cee84 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -2,28 +2,22 @@ use super::{ db::{self, UserId}, errors::TideResultExt, }; -use crate::{github, AppState, Request, RequestExt as _}; +use crate::{github, Request, RequestExt as _}; use anyhow::{anyhow, Context}; use async_trait::async_trait; pub use oauth2::basic::BasicClient as Client; -use oauth2::{ - AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, - TokenResponse as _, TokenUrl, -}; use rand::thread_rng; use rpc::auth as zed_auth; use scrypt::{ password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Scrypt, }; -use serde::{Deserialize, Serialize}; -use std::{borrow::Cow, convert::TryFrom, sync::Arc}; -use surf::{StatusCode, Url}; -use tide::{log, Error, Server}; +use serde::Serialize; +use std::convert::TryFrom; +use surf::StatusCode; +use tide::Error; static CURRENT_GITHUB_USER: &'static str = "current_github_user"; -static GITHUB_AUTH_URL: &'static str = "https://github.com/login/oauth/authorize"; -static GITHUB_TOKEN_URL: &'static str = "https://github.com/login/oauth/access_token"; #[derive(Serialize)] pub struct User { @@ -99,172 +93,6 @@ impl RequestExt for Request { } } -pub fn build_client(client_id: &str, client_secret: &str) -> Client { - Client::new( - ClientId::new(client_id.to_string()), - Some(oauth2::ClientSecret::new(client_secret.to_string())), - AuthUrl::new(GITHUB_AUTH_URL.into()).unwrap(), - Some(TokenUrl::new(GITHUB_TOKEN_URL.into()).unwrap()), - ) -} - -pub fn add_routes(app: &mut Server>) { - app.at("/sign_in").get(get_sign_in); - app.at("/sign_out").post(post_sign_out); - app.at("/auth_callback").get(get_auth_callback); - app.at("/native_app_signin").get(get_sign_in); - app.at("/native_app_signin_succeeded") - .get(get_app_signin_success); -} - -#[derive(Debug, Deserialize)] -struct NativeAppSignInParams { - native_app_port: String, - native_app_public_key: String, - impersonate: Option, -} - -async fn get_sign_in(mut request: Request) -> tide::Result { - let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); - - request - .session_mut() - .insert("pkce_verifier", pkce_verifier)?; - - let mut redirect_url = Url::parse(&format!( - "{}://{}/auth_callback", - request - .header("X-Forwarded-Proto") - .and_then(|values| values.get(0)) - .map(|value| value.as_str()) - .unwrap_or("http"), - request.host().unwrap() - ))?; - - let app_sign_in_params: Option = request.query().ok(); - if let Some(query) = app_sign_in_params { - let mut redirect_query = redirect_url.query_pairs_mut(); - redirect_query - .clear() - .append_pair("native_app_port", &query.native_app_port) - .append_pair("native_app_public_key", &query.native_app_public_key); - - if let Some(impersonate) = &query.impersonate { - redirect_query.append_pair("impersonate", impersonate); - } - } - - let (auth_url, csrf_token) = request - .state() - .auth_client - .authorize_url(CsrfToken::new_random) - .set_redirect_uri(Cow::Owned(RedirectUrl::from_url(redirect_url))) - .set_pkce_challenge(pkce_challenge) - .url(); - - request - .session_mut() - .insert("auth_csrf_token", csrf_token)?; - - Ok(tide::Redirect::new(auth_url).into()) -} - -async fn get_app_signin_success(_: Request) -> tide::Result { - Ok(tide::Redirect::new("/").into()) -} - -async fn get_auth_callback(mut request: Request) -> tide::Result { - #[derive(Debug, Deserialize)] - struct Query { - code: String, - state: String, - - #[serde(flatten)] - native_app_sign_in_params: Option, - } - - let query: Query = request.query()?; - - let pkce_verifier = request - .session() - .get("pkce_verifier") - .ok_or_else(|| anyhow!("could not retrieve pkce_verifier from session"))?; - - let csrf_token = request - .session() - .get::("auth_csrf_token") - .ok_or_else(|| anyhow!("could not retrieve auth_csrf_token from session"))?; - - if &query.state != csrf_token.secret() { - return Err(anyhow!("csrf token does not match").into()); - } - - let github_access_token = request - .state() - .auth_client - .exchange_code(AuthorizationCode::new(query.code)) - .set_pkce_verifier(pkce_verifier) - .request_async(oauth2_surf::http_client) - .await - .context("failed to exchange oauth code")? - .access_token() - .secret() - .clone(); - - let user_details = request - .state() - .github_client - .user(github_access_token) - .details() - .await - .context("failed to fetch user")?; - - let user = request - .db() - .get_user_by_github_login(&user_details.login) - .await?; - - request - .session_mut() - .insert(CURRENT_GITHUB_USER, user_details.clone())?; - - // When signing in from the native app, generate a new access token for the current user. Return - // a redirect so that the user's browser sends this access token to the locally-running app. - if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) { - let mut user_id = user.id; - if let Some(impersonated_login) = app_sign_in_params.impersonate { - log::info!("attempting to impersonate user @{}", impersonated_login); - if let Some(user) = request.db().get_users_by_ids(vec![user_id]).await?.first() { - if user.admin { - user_id = request.db().create_user(&impersonated_login, false).await?; - log::info!("impersonating user {}", user_id.0); - } else { - log::info!("refusing to impersonate user"); - } - } - } - - let access_token = create_access_token(request.db().as_ref(), user_id).await?; - let encrypted_access_token = encrypt_access_token( - &access_token, - app_sign_in_params.native_app_public_key.clone(), - )?; - - return Ok(tide::Redirect::new(&format!( - "http://127.0.0.1:{}?user_id={}&access_token={}", - app_sign_in_params.native_app_port, user_id.0, encrypted_access_token, - )) - .into()); - } - - Ok(tide::Redirect::new("/").into()) -} - -async fn post_sign_out(mut request: Request) -> tide::Result { - request.session_mut().remove(CURRENT_GITHUB_USER); - Ok(tide::Redirect::new("/").into()) -} - const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> tide::Result { diff --git a/crates/collab/src/errors.rs b/crates/collab/src/errors.rs index 1dbb44361f..93e46848a1 100644 --- a/crates/collab/src/errors.rs +++ b/crates/collab/src/errors.rs @@ -1,47 +1,3 @@ -use crate::{AppState, LayoutData, Request, RequestExt}; -use async_trait::async_trait; -use serde::Serialize; -use std::sync::Arc; -use tide::http::mime; - -pub struct Middleware; - -#[async_trait] -impl tide::Middleware> for Middleware { - async fn handle( - &self, - mut request: Request, - next: tide::Next<'_, Arc>, - ) -> tide::Result { - let app = request.state().clone(); - let layout_data = request.layout_data().await?; - - let mut response = next.run(request).await; - - #[derive(Serialize)] - struct ErrorData { - #[serde(flatten)] - layout: Arc, - status: u16, - reason: &'static str, - } - - if !response.status().is_success() { - response.set_body(app.render_template( - "error.hbs", - &ErrorData { - layout: layout_data, - status: response.status().into(), - reason: response.status().canonical_reason(), - }, - )?); - response.set_content_type(mime::HTML); - } - - Ok(response) - } -} - // Allow tide Results to accept context like other Results do when // using anyhow. pub trait TideResultExt { diff --git a/crates/collab/src/expiring.rs b/crates/collab/src/expiring.rs index ba974dc8e0..8b13789179 100644 --- a/crates/collab/src/expiring.rs +++ b/crates/collab/src/expiring.rs @@ -1,43 +1 @@ -use std::{future::Future, time::Instant}; -use async_std::sync::Mutex; - -#[derive(Default)] -pub struct Expiring(Mutex>>); - -pub struct ExpiringState { - value: T, - expires_at: Instant, -} - -impl Expiring { - pub async fn get_or_refresh(&self, f: F) -> tide::Result - where - F: FnOnce() -> G, - G: Future>, - { - let mut state = self.0.lock().await; - - if let Some(state) = state.as_mut() { - if Instant::now() >= state.expires_at { - let (value, expires_at) = f().await?; - state.value = value.clone(); - state.expires_at = expires_at; - Ok(value) - } else { - Ok(state.value.clone()) - } - } else { - let (value, expires_at) = f().await?; - *state = Some(ExpiringState { - value: value.clone(), - expires_at, - }); - Ok(value) - } - } - - pub async fn clear(&self) { - self.0.lock().await.take(); - } -} diff --git a/crates/collab/src/github.rs b/crates/collab/src/github.rs index e5bcb45f30..09cedf9019 100644 --- a/crates/collab/src/github.rs +++ b/crates/collab/src/github.rs @@ -1,12 +1,4 @@ -use crate::expiring::Expiring; -use anyhow::{anyhow, Context}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use std::{ - future::Future, - sync::Arc, - time::{Duration, Instant}, -}; -use surf::{http::Method, RequestBuilder, Url}; +use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, Serialize)] pub struct Release { @@ -23,259 +15,14 @@ pub struct Asset { pub url: String, } -pub struct AppClient { - id: usize, - private_key: String, - jwt_bearer_header: Expiring, -} - #[derive(Deserialize)] struct Installation { #[allow(unused)] id: usize, } -impl AppClient { - #[cfg(test)] - pub fn test() -> Arc { - Arc::new(Self { - id: Default::default(), - private_key: Default::default(), - jwt_bearer_header: Default::default(), - }) - } - - pub fn new(id: usize, private_key: String) -> Arc { - Arc::new(Self { - id, - private_key, - jwt_bearer_header: Default::default(), - }) - } - - pub async fn repo(self: &Arc, nwo: String) -> tide::Result { - let installation: Installation = self - .request( - Method::Get, - &format!("/repos/{}/installation", &nwo), - |refresh| self.bearer_header(refresh), - ) - .await?; - - Ok(RepoClient { - app: self.clone(), - nwo, - installation_id: installation.id, - installation_token_header: Default::default(), - }) - } - - pub fn user(self: &Arc, access_token: String) -> UserClient { - UserClient { - app: self.clone(), - access_token, - } - } - - async fn request( - &self, - method: Method, - path: &str, - get_auth_header: F, - ) -> tide::Result - where - T: DeserializeOwned, - F: Fn(bool) -> G, - G: Future>, - { - let mut retried = false; - - loop { - let response = RequestBuilder::new( - method, - Url::parse(&format!("https://api.github.com{}", path))?, - ) - .header("Accept", "application/vnd.github.v3+json") - .header("Authorization", get_auth_header(retried).await?) - .recv_json() - .await; - - if let Err(error) = response.as_ref() { - if error.status() == 401 && !retried { - retried = true; - continue; - } - } - - return response; - } - } - - async fn bearer_header(&self, refresh: bool) -> tide::Result { - if refresh { - self.jwt_bearer_header.clear().await; - } - - self.jwt_bearer_header - .get_or_refresh(|| async { - use jwt_simple::{algorithms::RS256KeyPair, prelude::*}; - use std::time; - - let key_pair = RS256KeyPair::from_pem(&self.private_key) - .with_context(|| format!("invalid private key {:?}", self.private_key))?; - let mut claims = Claims::create(Duration::from_mins(10)); - claims.issued_at = Some(Clock::now_since_epoch() - Duration::from_mins(1)); - claims.issuer = Some(self.id.to_string()); - let token = key_pair.sign(claims).context("failed to sign claims")?; - let expires_at = time::Instant::now() + time::Duration::from_secs(9 * 60); - - Ok((format!("Bearer {}", token), expires_at)) - }) - .await - } - - async fn installation_token_header( - &self, - header: &Expiring, - installation_id: usize, - refresh: bool, - ) -> tide::Result { - if refresh { - header.clear().await; - } - - header - .get_or_refresh(|| async { - #[derive(Debug, Deserialize)] - struct AccessToken { - token: String, - } - - let access_token: AccessToken = self - .request( - Method::Post, - &format!("/app/installations/{}/access_tokens", installation_id), - |refresh| self.bearer_header(refresh), - ) - .await?; - - let header = format!("Token {}", access_token.token); - let expires_at = Instant::now() + Duration::from_secs(60 * 30); - - Ok((header, expires_at)) - }) - .await - } -} - -pub struct RepoClient { - app: Arc, - nwo: String, - installation_id: usize, - installation_token_header: Expiring, -} - -impl RepoClient { - #[cfg(test)] - pub fn test(app_client: &Arc) -> Self { - Self { - app: app_client.clone(), - nwo: String::new(), - installation_id: 0, - installation_token_header: Default::default(), - } - } - - pub async fn releases(&self) -> tide::Result> { - self.get(&format!("/repos/{}/releases?per_page=100", self.nwo)) - .await - } - - pub async fn release_asset(&self, tag: &str, name: &str) -> tide::Result { - let release: Release = self - .get(&format!("/repos/{}/releases/tags/{}", self.nwo, tag)) - .await?; - - let asset = release - .assets - .iter() - .find(|asset| asset.name == name) - .ok_or_else(|| anyhow!("no asset found with name {}", name))?; - - let request = surf::get(&asset.url) - .header("Accept", "application/octet-stream'") - .header( - "Authorization", - self.installation_token_header(false).await?, - ); - - let client = surf::client(); - let mut response = client.send(request).await?; - - // Avoid using `surf::middleware::Redirect` because that type forwards - // the original request headers to the redirect URI. In this case, the - // redirect will be to S3, which forbids us from supplying an - // `Authorization` header. - if response.status().is_redirection() { - if let Some(url) = response.header("location") { - let request = surf::get(url.as_str()).header("Accept", "application/octet-stream"); - response = client.send(request).await?; - } - } - - if !response.status().is_success() { - Err(anyhow!("failed to fetch release asset {} {}", tag, name))?; - } - - Ok(response.take_body()) - } - - async fn get(&self, path: &str) -> tide::Result { - self.request::(Method::Get, path).await - } - - async fn request(&self, method: Method, path: &str) -> tide::Result { - Ok(self - .app - .request(method, path, |refresh| { - self.installation_token_header(refresh) - }) - .await?) - } - - async fn installation_token_header(&self, refresh: bool) -> tide::Result { - self.app - .installation_token_header( - &self.installation_token_header, - self.installation_id, - refresh, - ) - .await - } -} - -pub struct UserClient { - app: Arc, - access_token: String, -} - #[derive(Clone, Debug, Deserialize, Serialize)] pub struct User { pub login: String, pub avatar_url: String, } - -impl UserClient { - pub async fn details(&self) -> tide::Result { - Ok(self - .app - .request(Method::Get, "/user", |_| async { - Ok(self.access_token_header()) - }) - .await?) - } - - fn access_token_header(&self) -> String { - format!("Token {}", self.access_token) - } -} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index b505b95a8e..005175a665 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -8,17 +8,14 @@ mod expiring; mod github; mod rpc; -use self::errors::TideResultExt as _; use ::rpc::Peer; -use anyhow::Result; use async_std::net::TcpListener; use async_trait::async_trait; -use auth::RequestExt as _; use db::{Db, PostgresDb}; -use handlebars::{Handlebars, TemplateRenderError}; +use handlebars::Handlebars; use parking_lot::RwLock; use rust_embed::RustEmbed; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use std::sync::Arc; use surf::http::cookies::SameSite; use tide::sessions::SessionMiddleware; @@ -45,28 +42,16 @@ pub struct Config { pub struct AppState { db: Arc, handlebars: RwLock>, - auth_client: auth::Client, - github_client: Arc, - repo_client: github::RepoClient, config: Config, } impl AppState { async fn new(config: Config) -> tide::Result> { let db = PostgresDb::new(&config.database_url, 5).await?; - let github_client = - github::AppClient::new(config.github_app_id, config.github_private_key.clone()); - let repo_client = github_client - .repo("zed-industries/zed".into()) - .await - .context("failed to initialize github client")?; let this = Self { db: Arc::new(db), handlebars: Default::default(), - auth_client: auth::build_client(&config.github_client_id, &config.github_client_secret), - github_client, - repo_client, config, }; this.register_partials(); @@ -87,49 +72,20 @@ impl AppState { } } } - - fn render_template( - &self, - path: &'static str, - data: &impl Serialize, - ) -> Result { - #[cfg(debug_assertions)] - self.register_partials(); - - self.handlebars.read().render_template( - std::str::from_utf8(&Templates::get(path).unwrap().data).unwrap(), - data, - ) - } } #[async_trait] trait RequestExt { - async fn layout_data(&mut self) -> tide::Result>; fn db(&self) -> &Arc; } #[async_trait] impl RequestExt for Request { - async fn layout_data(&mut self) -> tide::Result> { - if self.ext::>().is_none() { - self.set_ext(Arc::new(LayoutData { - current_user: self.current_user().await?, - })); - } - Ok(self.ext::>().unwrap().clone()) - } - fn db(&self) -> &Arc { &self.state().db } } -#[derive(Serialize)] -struct LayoutData { - current_user: Option, -} - #[async_std::main] async fn main() -> tide::Result<()> { if std::env::var("LOG_JSON").is_ok() { @@ -173,9 +129,7 @@ pub async fn run_server( ) .with_same_site_policy(SameSite::Lax), // Required obtain our session in /auth_callback ); - web.with(errors::Middleware); api::add_routes(&mut web); - auth::add_routes(&mut web); let mut assets = tide::new(); assets.with(CompressMiddleware::new()); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index ba93be570c..8c36366861 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1180,9 +1180,8 @@ fn header_contains_ignore_case( mod tests { use super::*; use crate::{ - auth, db::{tests::TestDb, UserId}, - github, AppState, Config, + AppState, Config, }; use ::rpc::Peer; use client::{ @@ -5731,13 +5730,9 @@ mod tests { let mut config = Config::default(); config.session_secret = "a".repeat(32); config.database_url = test_db.url.clone(); - let github_client = github::AppClient::test(); Arc::new(AppState { db: test_db.db().clone(), handlebars: Default::default(), - auth_client: auth::build_client("", ""), - repo_client: github::RepoClient::test(&github_client), - github_client, config, }) }