Remove more unused code related to GitHub auth and errors

This commit is contained in:
Nathan Sobo 2022-04-21 08:57:49 -06:00
parent 9150b77471
commit 9f0b044ba0
6 changed files with 9 additions and 571 deletions

View file

@ -2,28 +2,22 @@ use super::{
db::{self, UserId},
errors::TideResultExt,
};
use crate::{github, AppState, Request, RequestExt as _};
use crate::{github, Request, RequestExt as _};
use anyhow::{anyhow, Context};
use async_trait::async_trait;
pub use oauth2::basic::BasicClient as Client;
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl,
TokenResponse as _, TokenUrl,
};
use rand::thread_rng;
use rpc::auth as zed_auth;
use scrypt::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Scrypt,
};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, convert::TryFrom, sync::Arc};
use surf::{StatusCode, Url};
use tide::{log, Error, Server};
use serde::Serialize;
use std::convert::TryFrom;
use surf::StatusCode;
use tide::Error;
static CURRENT_GITHUB_USER: &'static str = "current_github_user";
static GITHUB_AUTH_URL: &'static str = "https://github.com/login/oauth/authorize";
static GITHUB_TOKEN_URL: &'static str = "https://github.com/login/oauth/access_token";
#[derive(Serialize)]
pub struct User {
@ -99,172 +93,6 @@ impl RequestExt for Request {
}
}
pub fn build_client(client_id: &str, client_secret: &str) -> Client {
Client::new(
ClientId::new(client_id.to_string()),
Some(oauth2::ClientSecret::new(client_secret.to_string())),
AuthUrl::new(GITHUB_AUTH_URL.into()).unwrap(),
Some(TokenUrl::new(GITHUB_TOKEN_URL.into()).unwrap()),
)
}
pub fn add_routes(app: &mut Server<Arc<AppState>>) {
app.at("/sign_in").get(get_sign_in);
app.at("/sign_out").post(post_sign_out);
app.at("/auth_callback").get(get_auth_callback);
app.at("/native_app_signin").get(get_sign_in);
app.at("/native_app_signin_succeeded")
.get(get_app_signin_success);
}
#[derive(Debug, Deserialize)]
struct NativeAppSignInParams {
native_app_port: String,
native_app_public_key: String,
impersonate: Option<String>,
}
async fn get_sign_in(mut request: Request) -> tide::Result {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
request
.session_mut()
.insert("pkce_verifier", pkce_verifier)?;
let mut redirect_url = Url::parse(&format!(
"{}://{}/auth_callback",
request
.header("X-Forwarded-Proto")
.and_then(|values| values.get(0))
.map(|value| value.as_str())
.unwrap_or("http"),
request.host().unwrap()
))?;
let app_sign_in_params: Option<NativeAppSignInParams> = request.query().ok();
if let Some(query) = app_sign_in_params {
let mut redirect_query = redirect_url.query_pairs_mut();
redirect_query
.clear()
.append_pair("native_app_port", &query.native_app_port)
.append_pair("native_app_public_key", &query.native_app_public_key);
if let Some(impersonate) = &query.impersonate {
redirect_query.append_pair("impersonate", impersonate);
}
}
let (auth_url, csrf_token) = request
.state()
.auth_client
.authorize_url(CsrfToken::new_random)
.set_redirect_uri(Cow::Owned(RedirectUrl::from_url(redirect_url)))
.set_pkce_challenge(pkce_challenge)
.url();
request
.session_mut()
.insert("auth_csrf_token", csrf_token)?;
Ok(tide::Redirect::new(auth_url).into())
}
async fn get_app_signin_success(_: Request) -> tide::Result {
Ok(tide::Redirect::new("/").into())
}
async fn get_auth_callback(mut request: Request) -> tide::Result {
#[derive(Debug, Deserialize)]
struct Query {
code: String,
state: String,
#[serde(flatten)]
native_app_sign_in_params: Option<NativeAppSignInParams>,
}
let query: Query = request.query()?;
let pkce_verifier = request
.session()
.get("pkce_verifier")
.ok_or_else(|| anyhow!("could not retrieve pkce_verifier from session"))?;
let csrf_token = request
.session()
.get::<CsrfToken>("auth_csrf_token")
.ok_or_else(|| anyhow!("could not retrieve auth_csrf_token from session"))?;
if &query.state != csrf_token.secret() {
return Err(anyhow!("csrf token does not match").into());
}
let github_access_token = request
.state()
.auth_client
.exchange_code(AuthorizationCode::new(query.code))
.set_pkce_verifier(pkce_verifier)
.request_async(oauth2_surf::http_client)
.await
.context("failed to exchange oauth code")?
.access_token()
.secret()
.clone();
let user_details = request
.state()
.github_client
.user(github_access_token)
.details()
.await
.context("failed to fetch user")?;
let user = request
.db()
.get_user_by_github_login(&user_details.login)
.await?;
request
.session_mut()
.insert(CURRENT_GITHUB_USER, user_details.clone())?;
// When signing in from the native app, generate a new access token for the current user. Return
// a redirect so that the user's browser sends this access token to the locally-running app.
if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
let mut user_id = user.id;
if let Some(impersonated_login) = app_sign_in_params.impersonate {
log::info!("attempting to impersonate user @{}", impersonated_login);
if let Some(user) = request.db().get_users_by_ids(vec![user_id]).await?.first() {
if user.admin {
user_id = request.db().create_user(&impersonated_login, false).await?;
log::info!("impersonating user {}", user_id.0);
} else {
log::info!("refusing to impersonate user");
}
}
}
let access_token = create_access_token(request.db().as_ref(), user_id).await?;
let encrypted_access_token = encrypt_access_token(
&access_token,
app_sign_in_params.native_app_public_key.clone(),
)?;
return Ok(tide::Redirect::new(&format!(
"http://127.0.0.1:{}?user_id={}&access_token={}",
app_sign_in_params.native_app_port, user_id.0, encrypted_access_token,
))
.into());
}
Ok(tide::Redirect::new("/").into())
}
async fn post_sign_out(mut request: Request) -> tide::Result {
request.session_mut().remove(CURRENT_GITHUB_USER);
Ok(tide::Redirect::new("/").into())
}
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> tide::Result<String> {

View file

@ -1,47 +1,3 @@
use crate::{AppState, LayoutData, Request, RequestExt};
use async_trait::async_trait;
use serde::Serialize;
use std::sync::Arc;
use tide::http::mime;
pub struct Middleware;
#[async_trait]
impl tide::Middleware<Arc<AppState>> for Middleware {
async fn handle(
&self,
mut request: Request,
next: tide::Next<'_, Arc<AppState>>,
) -> tide::Result {
let app = request.state().clone();
let layout_data = request.layout_data().await?;
let mut response = next.run(request).await;
#[derive(Serialize)]
struct ErrorData {
#[serde(flatten)]
layout: Arc<LayoutData>,
status: u16,
reason: &'static str,
}
if !response.status().is_success() {
response.set_body(app.render_template(
"error.hbs",
&ErrorData {
layout: layout_data,
status: response.status().into(),
reason: response.status().canonical_reason(),
},
)?);
response.set_content_type(mime::HTML);
}
Ok(response)
}
}
// Allow tide Results to accept context like other Results do when
// using anyhow.
pub trait TideResultExt {

View file

@ -1,43 +1 @@
use std::{future::Future, time::Instant};
use async_std::sync::Mutex;
#[derive(Default)]
pub struct Expiring<T>(Mutex<Option<ExpiringState<T>>>);
pub struct ExpiringState<T> {
value: T,
expires_at: Instant,
}
impl<T: Clone> Expiring<T> {
pub async fn get_or_refresh<F, G>(&self, f: F) -> tide::Result<T>
where
F: FnOnce() -> G,
G: Future<Output = tide::Result<(T, Instant)>>,
{
let mut state = self.0.lock().await;
if let Some(state) = state.as_mut() {
if Instant::now() >= state.expires_at {
let (value, expires_at) = f().await?;
state.value = value.clone();
state.expires_at = expires_at;
Ok(value)
} else {
Ok(state.value.clone())
}
} else {
let (value, expires_at) = f().await?;
*state = Some(ExpiringState {
value: value.clone(),
expires_at,
});
Ok(value)
}
}
pub async fn clear(&self) {
self.0.lock().await.take();
}
}

View file

@ -1,12 +1,4 @@
use crate::expiring::Expiring;
use anyhow::{anyhow, Context};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
future::Future,
sync::Arc,
time::{Duration, Instant},
};
use surf::{http::Method, RequestBuilder, Url};
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize, Serialize)]
pub struct Release {
@ -23,259 +15,14 @@ pub struct Asset {
pub url: String,
}
pub struct AppClient {
id: usize,
private_key: String,
jwt_bearer_header: Expiring<String>,
}
#[derive(Deserialize)]
struct Installation {
#[allow(unused)]
id: usize,
}
impl AppClient {
#[cfg(test)]
pub fn test() -> Arc<Self> {
Arc::new(Self {
id: Default::default(),
private_key: Default::default(),
jwt_bearer_header: Default::default(),
})
}
pub fn new(id: usize, private_key: String) -> Arc<Self> {
Arc::new(Self {
id,
private_key,
jwt_bearer_header: Default::default(),
})
}
pub async fn repo(self: &Arc<Self>, nwo: String) -> tide::Result<RepoClient> {
let installation: Installation = self
.request(
Method::Get,
&format!("/repos/{}/installation", &nwo),
|refresh| self.bearer_header(refresh),
)
.await?;
Ok(RepoClient {
app: self.clone(),
nwo,
installation_id: installation.id,
installation_token_header: Default::default(),
})
}
pub fn user(self: &Arc<Self>, access_token: String) -> UserClient {
UserClient {
app: self.clone(),
access_token,
}
}
async fn request<T, F, G>(
&self,
method: Method,
path: &str,
get_auth_header: F,
) -> tide::Result<T>
where
T: DeserializeOwned,
F: Fn(bool) -> G,
G: Future<Output = tide::Result<String>>,
{
let mut retried = false;
loop {
let response = RequestBuilder::new(
method,
Url::parse(&format!("https://api.github.com{}", path))?,
)
.header("Accept", "application/vnd.github.v3+json")
.header("Authorization", get_auth_header(retried).await?)
.recv_json()
.await;
if let Err(error) = response.as_ref() {
if error.status() == 401 && !retried {
retried = true;
continue;
}
}
return response;
}
}
async fn bearer_header(&self, refresh: bool) -> tide::Result<String> {
if refresh {
self.jwt_bearer_header.clear().await;
}
self.jwt_bearer_header
.get_or_refresh(|| async {
use jwt_simple::{algorithms::RS256KeyPair, prelude::*};
use std::time;
let key_pair = RS256KeyPair::from_pem(&self.private_key)
.with_context(|| format!("invalid private key {:?}", self.private_key))?;
let mut claims = Claims::create(Duration::from_mins(10));
claims.issued_at = Some(Clock::now_since_epoch() - Duration::from_mins(1));
claims.issuer = Some(self.id.to_string());
let token = key_pair.sign(claims).context("failed to sign claims")?;
let expires_at = time::Instant::now() + time::Duration::from_secs(9 * 60);
Ok((format!("Bearer {}", token), expires_at))
})
.await
}
async fn installation_token_header(
&self,
header: &Expiring<String>,
installation_id: usize,
refresh: bool,
) -> tide::Result<String> {
if refresh {
header.clear().await;
}
header
.get_or_refresh(|| async {
#[derive(Debug, Deserialize)]
struct AccessToken {
token: String,
}
let access_token: AccessToken = self
.request(
Method::Post,
&format!("/app/installations/{}/access_tokens", installation_id),
|refresh| self.bearer_header(refresh),
)
.await?;
let header = format!("Token {}", access_token.token);
let expires_at = Instant::now() + Duration::from_secs(60 * 30);
Ok((header, expires_at))
})
.await
}
}
pub struct RepoClient {
app: Arc<AppClient>,
nwo: String,
installation_id: usize,
installation_token_header: Expiring<String>,
}
impl RepoClient {
#[cfg(test)]
pub fn test(app_client: &Arc<AppClient>) -> Self {
Self {
app: app_client.clone(),
nwo: String::new(),
installation_id: 0,
installation_token_header: Default::default(),
}
}
pub async fn releases(&self) -> tide::Result<Vec<Release>> {
self.get(&format!("/repos/{}/releases?per_page=100", self.nwo))
.await
}
pub async fn release_asset(&self, tag: &str, name: &str) -> tide::Result<surf::Body> {
let release: Release = self
.get(&format!("/repos/{}/releases/tags/{}", self.nwo, tag))
.await?;
let asset = release
.assets
.iter()
.find(|asset| asset.name == name)
.ok_or_else(|| anyhow!("no asset found with name {}", name))?;
let request = surf::get(&asset.url)
.header("Accept", "application/octet-stream'")
.header(
"Authorization",
self.installation_token_header(false).await?,
);
let client = surf::client();
let mut response = client.send(request).await?;
// Avoid using `surf::middleware::Redirect` because that type forwards
// the original request headers to the redirect URI. In this case, the
// redirect will be to S3, which forbids us from supplying an
// `Authorization` header.
if response.status().is_redirection() {
if let Some(url) = response.header("location") {
let request = surf::get(url.as_str()).header("Accept", "application/octet-stream");
response = client.send(request).await?;
}
}
if !response.status().is_success() {
Err(anyhow!("failed to fetch release asset {} {}", tag, name))?;
}
Ok(response.take_body())
}
async fn get<T: DeserializeOwned>(&self, path: &str) -> tide::Result<T> {
self.request::<T>(Method::Get, path).await
}
async fn request<T: DeserializeOwned>(&self, method: Method, path: &str) -> tide::Result<T> {
Ok(self
.app
.request(method, path, |refresh| {
self.installation_token_header(refresh)
})
.await?)
}
async fn installation_token_header(&self, refresh: bool) -> tide::Result<String> {
self.app
.installation_token_header(
&self.installation_token_header,
self.installation_id,
refresh,
)
.await
}
}
pub struct UserClient {
app: Arc<AppClient>,
access_token: String,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct User {
pub login: String,
pub avatar_url: String,
}
impl UserClient {
pub async fn details(&self) -> tide::Result<User> {
Ok(self
.app
.request(Method::Get, "/user", |_| async {
Ok(self.access_token_header())
})
.await?)
}
fn access_token_header(&self) -> String {
format!("Token {}", self.access_token)
}
}

View file

@ -8,17 +8,14 @@ mod expiring;
mod github;
mod rpc;
use self::errors::TideResultExt as _;
use ::rpc::Peer;
use anyhow::Result;
use async_std::net::TcpListener;
use async_trait::async_trait;
use auth::RequestExt as _;
use db::{Db, PostgresDb};
use handlebars::{Handlebars, TemplateRenderError};
use handlebars::Handlebars;
use parking_lot::RwLock;
use rust_embed::RustEmbed;
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use std::sync::Arc;
use surf::http::cookies::SameSite;
use tide::sessions::SessionMiddleware;
@ -45,28 +42,16 @@ pub struct Config {
pub struct AppState {
db: Arc<dyn Db>,
handlebars: RwLock<Handlebars<'static>>,
auth_client: auth::Client,
github_client: Arc<github::AppClient>,
repo_client: github::RepoClient,
config: Config,
}
impl AppState {
async fn new(config: Config) -> tide::Result<Arc<Self>> {
let db = PostgresDb::new(&config.database_url, 5).await?;
let github_client =
github::AppClient::new(config.github_app_id, config.github_private_key.clone());
let repo_client = github_client
.repo("zed-industries/zed".into())
.await
.context("failed to initialize github client")?;
let this = Self {
db: Arc::new(db),
handlebars: Default::default(),
auth_client: auth::build_client(&config.github_client_id, &config.github_client_secret),
github_client,
repo_client,
config,
};
this.register_partials();
@ -87,49 +72,20 @@ impl AppState {
}
}
}
fn render_template(
&self,
path: &'static str,
data: &impl Serialize,
) -> Result<String, TemplateRenderError> {
#[cfg(debug_assertions)]
self.register_partials();
self.handlebars.read().render_template(
std::str::from_utf8(&Templates::get(path).unwrap().data).unwrap(),
data,
)
}
}
#[async_trait]
trait RequestExt {
async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>>;
fn db(&self) -> &Arc<dyn Db>;
}
#[async_trait]
impl RequestExt for Request {
async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>> {
if self.ext::<Arc<LayoutData>>().is_none() {
self.set_ext(Arc::new(LayoutData {
current_user: self.current_user().await?,
}));
}
Ok(self.ext::<Arc<LayoutData>>().unwrap().clone())
}
fn db(&self) -> &Arc<dyn Db> {
&self.state().db
}
}
#[derive(Serialize)]
struct LayoutData {
current_user: Option<auth::User>,
}
#[async_std::main]
async fn main() -> tide::Result<()> {
if std::env::var("LOG_JSON").is_ok() {
@ -173,9 +129,7 @@ pub async fn run_server(
)
.with_same_site_policy(SameSite::Lax), // Required obtain our session in /auth_callback
);
web.with(errors::Middleware);
api::add_routes(&mut web);
auth::add_routes(&mut web);
let mut assets = tide::new();
assets.with(CompressMiddleware::new());

View file

@ -1180,9 +1180,8 @@ fn header_contains_ignore_case<T>(
mod tests {
use super::*;
use crate::{
auth,
db::{tests::TestDb, UserId},
github, AppState, Config,
AppState, Config,
};
use ::rpc::Peer;
use client::{
@ -5731,13 +5730,9 @@ mod tests {
let mut config = Config::default();
config.session_secret = "a".repeat(32);
config.database_url = test_db.url.clone();
let github_client = github::AppClient::test();
Arc::new(AppState {
db: test_db.db().clone(),
handlebars: Default::default(),
auth_client: auth::build_client("", ""),
repo_client: github::RepoClient::test(&github_client),
github_client,
config,
})
}