Move all crates to a top-level crates folder

Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>
Nathan Sobo 2021-10-04 13:22:21 -06:00
parent d768224182
commit fdfed3d7db
282 changed files with 195588 additions and 16 deletions

crates/server/src/admin.rs Normal file

@@ -0,0 +1,117 @@
use crate::{auth::RequestExt as _, db, AppState, LayoutData, Request, RequestExt as _};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use surf::http::mime;
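// Extension trait used by the admin routes to reject requests that are not from a
// signed-in admin: 401 when not logged in, 403 when the user is not an admin.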
#[async_trait]
pub trait RequestExt {
async fn require_admin(&self) -> tide::Result<()>;
}
#[async_trait]
impl RequestExt for Request {
async fn require_admin(&self) -> tide::Result<()> {
let current_user = self
.current_user()
.await?
.ok_or_else(|| tide::Error::from_str(401, "not logged in"))?;
if current_user.is_admin {
Ok(())
} else {
Err(tide::Error::from_str(
403,
"authenticated user is not an admin",
))
}
}
}
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>) {
app.at("/admin").get(get_admin_page);
app.at("/users").post(post_user);
app.at("/users/:id").put(put_user);
app.at("/users/:id/delete").post(delete_user);
app.at("/signups/:id/delete").post(delete_signup);
}
#[derive(Serialize)]
struct AdminData {
#[serde(flatten)]
layout: Arc<LayoutData>,
users: Vec<db::User>,
signups: Vec<db::Signup>,
}
async fn get_admin_page(mut request: Request) -> tide::Result {
request.require_admin().await?;
let data = AdminData {
layout: request.layout_data().await?,
users: request.db().get_all_users().await?,
signups: request.db().get_all_signups().await?,
};
Ok(tide::Response::builder(200)
.body(request.state().render_template("admin.hbs", &data)?)
.content_type(mime::HTML)
.build())
}
async fn post_user(mut request: Request) -> tide::Result {
request.require_admin().await?;
#[derive(Deserialize)]
struct Form {
github_login: String,
#[serde(default)]
admin: bool,
}
let form = request.body_form::<Form>().await?;
let github_login = form
.github_login
.strip_prefix("@")
.unwrap_or(&form.github_login);
if !github_login.is_empty() {
request.db().create_user(github_login, form.admin).await?;
}
Ok(tide::Redirect::new("/admin").into())
}
async fn put_user(mut request: Request) -> tide::Result {
request.require_admin().await?;
let user_id = request.param("id")?.parse()?;
#[derive(Deserialize)]
struct Body {
admin: bool,
}
let body: Body = request.body_json().await?;
request
.db()
.set_user_is_admin(db::UserId(user_id), body.admin)
.await?;
Ok(tide::Response::builder(200).build())
}
async fn delete_user(request: Request) -> tide::Result {
request.require_admin().await?;
let user_id = db::UserId(request.param("id")?.parse()?);
request.db().delete_user(user_id).await?;
Ok(tide::Redirect::new("/admin").into())
}
async fn delete_signup(request: Request) -> tide::Result {
request.require_admin().await?;
let signup_id = db::SignupId(request.param("id")?.parse()?);
request.db().delete_signup(signup_id).await?;
Ok(tide::Redirect::new("/admin").into())
}

@@ -0,0 +1,29 @@
use anyhow::anyhow;
use rust_embed::RustEmbed;
use tide::{http::mime, Server};
#[derive(RustEmbed)]
#[folder = "static"]
struct Static;
pub fn add_routes(app: &mut Server<()>) {
app.at("/*path").get(get_static_asset);
}
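// Serves assets embedded from the "static" directory at build time, picking the
// content type from the path's leading segment ("svg/..." or "styles/...").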
async fn get_static_asset(request: tide::Request<()>) -> tide::Result {
let path = request.param("path").unwrap();
let content = Static::get(path).ok_or_else(|| anyhow!("asset not found at {}", path))?;
let content_type = if path.starts_with("svg") {
mime::SVG
} else if path.starts_with("styles") {
mime::CSS
} else {
mime::BYTE_STREAM
};
Ok(tide::Response::builder(200)
.content_type(content_type)
.body(content.data.as_ref())
.build())
}

crates/server/src/auth.rs Normal file

@@ -0,0 +1,295 @@
use super::{
db::{self, UserId},
errors::TideResultExt,
};
use crate::{github, AppState, Request, RequestExt as _};
use anyhow::{anyhow, Context};
use async_trait::async_trait;
pub use oauth2::basic::BasicClient as Client;
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl,
TokenResponse as _, TokenUrl,
};
use rand::thread_rng;
use scrypt::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Scrypt,
};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, convert::TryFrom, sync::Arc};
use surf::{StatusCode, Url};
use tide::{log, Error, Server};
use zrpc::auth as zed_auth;
static CURRENT_GITHUB_USER: &'static str = "current_github_user";
static GITHUB_AUTH_URL: &'static str = "https://github.com/login/oauth/authorize";
static GITHUB_TOKEN_URL: &'static str = "https://github.com/login/oauth/access_token";
#[derive(Serialize)]
pub struct User {
pub github_login: String,
pub avatar_url: String,
pub is_insider: bool,
pub is_admin: bool,
}
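// Requests authenticated with an "Authorization: <user_id> <access_token>" header
// are verified against the scrypt token hashes stored for that user.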
pub async fn process_auth_header(request: &Request) -> tide::Result<UserId> {
let mut auth_header = request
.header("Authorization")
.ok_or_else(|| {
Error::new(
StatusCode::BadRequest,
anyhow!("missing authorization header"),
)
})?
.last()
.as_str()
.split_whitespace();
let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
Error::new(
StatusCode::BadRequest,
anyhow!("missing user id in authorization header"),
)
})?);
let access_token = auth_header.next().ok_or_else(|| {
Error::new(
StatusCode::BadRequest,
anyhow!("missing access token in authorization header"),
)
})?;
let state = request.state().clone();
let mut credentials_valid = false;
for password_hash in state.db.get_access_token_hashes(user_id).await? {
if verify_access_token(&access_token, &password_hash)? {
credentials_valid = true;
break;
}
}
if !credentials_valid {
Err(Error::new(
StatusCode::Unauthorized,
anyhow!("invalid credentials"),
))?;
}
Ok(user_id)
}
#[async_trait]
pub trait RequestExt {
async fn current_user(&self) -> tide::Result<Option<User>>;
}
#[async_trait]
impl RequestExt for Request {
async fn current_user(&self) -> tide::Result<Option<User>> {
if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
let user = self.db().get_user_by_github_login(&details.login).await?;
Ok(Some(User {
github_login: details.login,
avatar_url: details.avatar_url,
is_insider: user.is_some(),
is_admin: user.map_or(false, |user| user.admin),
}))
} else {
Ok(None)
}
}
}
pub fn build_client(client_id: &str, client_secret: &str) -> Client {
Client::new(
ClientId::new(client_id.to_string()),
Some(oauth2::ClientSecret::new(client_secret.to_string())),
AuthUrl::new(GITHUB_AUTH_URL.into()).unwrap(),
Some(TokenUrl::new(GITHUB_TOKEN_URL.into()).unwrap()),
)
}
pub fn add_routes(app: &mut Server<Arc<AppState>>) {
app.at("/sign_in").get(get_sign_in);
app.at("/sign_out").post(post_sign_out);
app.at("/auth_callback").get(get_auth_callback);
}
#[derive(Debug, Deserialize)]
struct NativeAppSignInParams {
native_app_port: String,
native_app_public_key: String,
impersonate: Option<String>,
}
async fn get_sign_in(mut request: Request) -> tide::Result {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
request
.session_mut()
.insert("pkce_verifier", pkce_verifier)?;
let mut redirect_url = Url::parse(&format!(
"{}://{}/auth_callback",
request
.header("X-Forwarded-Proto")
.and_then(|values| values.get(0))
.map(|value| value.as_str())
.unwrap_or("http"),
request.host().unwrap()
))?;
let app_sign_in_params: Option<NativeAppSignInParams> = request.query().ok();
if let Some(query) = app_sign_in_params {
let mut redirect_query = redirect_url.query_pairs_mut();
redirect_query
.clear()
.append_pair("native_app_port", &query.native_app_port)
.append_pair("native_app_public_key", &query.native_app_public_key);
if let Some(impersonate) = &query.impersonate {
redirect_query.append_pair("impersonate", impersonate);
}
}
let (auth_url, csrf_token) = request
.state()
.auth_client
.authorize_url(CsrfToken::new_random)
.set_redirect_uri(Cow::Owned(RedirectUrl::from_url(redirect_url)))
.set_pkce_challenge(pkce_challenge)
.url();
request
.session_mut()
.insert("auth_csrf_token", csrf_token)?;
Ok(tide::Redirect::new(auth_url).into())
}
async fn get_auth_callback(mut request: Request) -> tide::Result {
#[derive(Debug, Deserialize)]
struct Query {
code: String,
state: String,
#[serde(flatten)]
native_app_sign_in_params: Option<NativeAppSignInParams>,
}
let query: Query = request.query()?;
let pkce_verifier = request
.session()
.get("pkce_verifier")
.ok_or_else(|| anyhow!("could not retrieve pkce_verifier from session"))?;
let csrf_token = request
.session()
.get::<CsrfToken>("auth_csrf_token")
.ok_or_else(|| anyhow!("could not retrieve auth_csrf_token from session"))?;
if &query.state != csrf_token.secret() {
return Err(anyhow!("csrf token does not match").into());
}
let github_access_token = request
.state()
.auth_client
.exchange_code(AuthorizationCode::new(query.code))
.set_pkce_verifier(pkce_verifier)
.request_async(oauth2_surf::http_client)
.await
.context("failed to exchange oauth code")?
.access_token()
.secret()
.clone();
let user_details = request
.state()
.github_client
.user(github_access_token)
.details()
.await
.context("failed to fetch user")?;
let user = request
.db()
.get_user_by_github_login(&user_details.login)
.await?;
request
.session_mut()
.insert(CURRENT_GITHUB_USER, user_details.clone())?;
// When signing in from the native app, generate a new access token for the current user. Return
// a redirect so that the user's browser sends this access token to the locally-running app.
if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
let mut user_id = user.id;
if let Some(impersonated_login) = app_sign_in_params.impersonate {
log::info!("attempting to impersonate user @{}", impersonated_login);
if let Some(user) = request.db().get_users_by_ids([user_id]).await?.first() {
if user.admin {
user_id = request.db().create_user(&impersonated_login, false).await?;
log::info!("impersonating user {}", user_id.0);
} else {
log::info!("refusing to impersonate user");
}
}
}
let access_token = create_access_token(request.db(), user_id).await?;
let native_app_public_key =
zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
.context("failed to parse app public key")?;
let encrypted_access_token = native_app_public_key
.encrypt_string(&access_token)
.context("failed to encrypt access token with public key")?;
return Ok(tide::Redirect::new(&format!(
"http://127.0.0.1:{}?user_id={}&access_token={}",
app_sign_in_params.native_app_port, user_id.0, encrypted_access_token,
))
.into());
}
Ok(tide::Redirect::new("/").into())
}
async fn post_sign_out(mut request: Request) -> tide::Result {
request.session_mut().remove(CURRENT_GITHUB_USER);
Ok(tide::Redirect::new("/").into())
}
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
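// Generates a random access token, stores only its scrypt hash, and keeps at most
// MAX_ACCESS_TOKENS_TO_STORE hashes per user.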
pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
let access_token = zed_auth::random_token();
let access_token_hash =
hash_access_token(&access_token).context("failed to hash access token")?;
db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
.await?;
Ok(access_token)
}
fn hash_access_token(token: &str) -> tide::Result<String> {
// Avoid slow hashing in debug mode.
let params = if cfg!(debug_assertions) {
scrypt::Params::new(1, 1, 1).unwrap()
} else {
scrypt::Params::recommended()
};
Ok(Scrypt
.hash_password(
token.as_bytes(),
None,
params,
&SaltString::generate(thread_rng()),
)?
.to_string())
}
pub fn verify_access_token(token: &str, hash: &str) -> tide::Result<bool> {
let hash = PasswordHash::new(hash)?;
Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
}

@@ -0,0 +1,20 @@
use anyhow::anyhow;
use std::fs;
fn main() -> anyhow::Result<()> {
let env: toml::map::Map<String, toml::Value> = toml::de::from_str(
&fs::read_to_string("./.env.toml").map_err(|_| anyhow!("no .env.toml file found"))?,
)?;
for (key, value) in env {
let value = match value {
toml::Value::String(value) => value,
toml::Value::Integer(value) => value.to_string(),
toml::Value::Float(value) => value.to_string(),
_ => panic!("unsupported TOML value in .env.toml for key {}", key),
};
println!("export {}=\"{}\"", key, value);
}
Ok(())
}

@@ -0,0 +1,95 @@
use db::{Db, UserId};
use rand::prelude::*;
use tide::log;
use time::{Duration, OffsetDateTime};
#[allow(unused)]
#[path = "../db.rs"]
mod db;
#[path = "../env.rs"]
mod env;
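// Seeds a local database with a few Zed user accounts, a "Zed" org, and a
// "General" channel filled with randomly timestamped lorem-ipsum messages.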
#[async_std::main]
async fn main() {
if let Err(error) = env::load_dotenv() {
log::error!(
"error loading .env.toml (this is expected in production): {}",
error
);
}
let mut rng = StdRng::from_entropy();
let database_url = std::env::var("DATABASE_URL").expect("missing DATABASE_URL env var");
let db = Db::new(&database_url, 5)
.await
.expect("failed to connect to postgres database");
let zed_users = ["nathansobo", "maxbrunsfeld", "as-cii", "iamnbutler"];
let mut zed_user_ids = Vec::<UserId>::new();
for zed_user in zed_users {
if let Some(user) = db
.get_user_by_github_login(zed_user)
.await
.expect("failed to fetch user")
{
zed_user_ids.push(user.id);
} else {
zed_user_ids.push(
db.create_user(zed_user, true)
.await
.expect("failed to insert user"),
);
}
}
let zed_org_id = if let Some(org) = db
.find_org_by_slug("zed")
.await
.expect("failed to fetch org")
{
org.id
} else {
db.create_org("Zed", "zed")
.await
.expect("failed to insert org")
};
let general_channel_id = if let Some(channel) = db
.get_org_channels(zed_org_id)
.await
.expect("failed to fetch channels")
.iter()
.find(|c| c.name == "General")
{
channel.id
} else {
let channel_id = db
.create_org_channel(zed_org_id, "General")
.await
.expect("failed to insert channel");
let now = OffsetDateTime::now_utc();
let max_seconds = Duration::days(100).as_seconds_f64();
let mut timestamps = (0..1000)
.map(|_| now - Duration::seconds_f64(rng.gen_range(0_f64..=max_seconds)))
.collect::<Vec<_>>();
timestamps.sort();
for timestamp in timestamps {
let sender_id = *zed_user_ids.choose(&mut rng).unwrap();
let body = lipsum::lipsum_words(rng.gen_range(1..=50));
db.create_channel_message(channel_id, sender_id, &body, timestamp, rng.gen())
.await
.expect("failed to insert message");
}
channel_id
};
for user_id in zed_user_ids {
db.add_org_member(zed_org_id, user_id, true)
.await
.expect("failed to insert org membership");
db.add_channel_member(general_channel_id, user_id, true)
.await
.expect("failed to insert channel membership");
}
}

@@ -0,0 +1,15 @@
use crate::{AppState, Request, RequestExt};
use std::sync::Arc;
use tide::http::mime;
pub fn add_routes(community: &mut tide::Server<Arc<AppState>>) {
community.at("/community").get(get_community);
}
async fn get_community(mut request: Request) -> tide::Result {
let data = request.layout_data().await?;
Ok(tide::Response::builder(200)
.body(request.state().render_template("community.hbs", &data)?)
.content_type(mime::HTML)
.build())
}

crates/server/src/db.rs Normal file

@@ -0,0 +1,710 @@
use anyhow::Context;
use async_std::task::{block_on, yield_now};
use serde::Serialize;
use sqlx::{types::Uuid, FromRow, Result};
use time::OffsetDateTime;
pub use async_sqlx_session::PostgresSessionStore as SessionStore;
pub use sqlx::postgres::PgPoolOptions as DbOptions;
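// Wraps each query body so that, in test mode, it yields once and then blocks on
// the query; outside of tests the future is awaited normally.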
macro_rules! test_support {
($self:ident, { $($token:tt)* }) => {{
let body = async {
$($token)*
};
if $self.test_mode {
yield_now().await;
block_on(body)
} else {
body.await
}
}};
}
#[derive(Clone)]
pub struct Db {
pool: sqlx::PgPool,
test_mode: bool,
}
impl Db {
pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
let pool = DbOptions::new()
.max_connections(max_connections)
.connect(url)
.await
.context("failed to connect to postgres database")?;
Ok(Self {
pool,
test_mode: false,
})
}
// signups
pub async fn create_signup(
&self,
github_login: &str,
email_address: &str,
about: &str,
wants_releases: bool,
wants_updates: bool,
wants_community: bool,
) -> Result<SignupId> {
test_support!(self, {
let query = "
INSERT INTO signups (
github_login,
email_address,
about,
wants_releases,
wants_updates,
wants_community
)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id
";
sqlx::query_scalar(query)
.bind(github_login)
.bind(email_address)
.bind(about)
.bind(wants_releases)
.bind(wants_updates)
.bind(wants_community)
.fetch_one(&self.pool)
.await
.map(SignupId)
})
}
pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
test_support!(self, {
let query = "SELECT * FROM signups ORDER BY github_login ASC";
sqlx::query_as(query).fetch_all(&self.pool).await
})
}
pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
test_support!(self, {
let query = "DELETE FROM signups WHERE id = $1";
sqlx::query(query)
.bind(id.0)
.execute(&self.pool)
.await
.map(drop)
})
}
// users
pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
test_support!(self, {
let query = "
INSERT INTO users (github_login, admin)
VALUES ($1, $2)
ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
RETURNING id
";
sqlx::query_scalar(query)
.bind(github_login)
.bind(admin)
.fetch_one(&self.pool)
.await
.map(UserId)
})
}
pub async fn get_all_users(&self) -> Result<Vec<User>> {
test_support!(self, {
let query = "SELECT * FROM users ORDER BY github_login ASC";
sqlx::query_as(query).fetch_all(&self.pool).await
})
}
pub async fn get_users_by_ids(
&self,
ids: impl IntoIterator<Item = UserId>,
) -> Result<Vec<User>> {
let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
test_support!(self, {
let query = "
SELECT users.*
FROM users
WHERE users.id = ANY ($1)
";
sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
})
}
pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
test_support!(self, {
let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
sqlx::query_as(query)
.bind(github_login)
.fetch_optional(&self.pool)
.await
})
}
pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
test_support!(self, {
let query = "UPDATE users SET admin = $1 WHERE id = $2";
sqlx::query(query)
.bind(is_admin)
.bind(id.0)
.execute(&self.pool)
.await
.map(drop)
})
}
pub async fn delete_user(&self, id: UserId) -> Result<()> {
test_support!(self, {
let query = "DELETE FROM users WHERE id = $1;";
sqlx::query(query)
.bind(id.0)
.execute(&self.pool)
.await
.map(drop)
})
}
// access tokens
pub async fn create_access_token_hash(
&self,
user_id: UserId,
access_token_hash: &str,
max_access_token_count: usize,
) -> Result<()> {
test_support!(self, {
let insert_query = "
INSERT INTO access_tokens (user_id, hash)
VALUES ($1, $2);
";
let cleanup_query = "
DELETE FROM access_tokens
WHERE id IN (
SELECT id from access_tokens
WHERE user_id = $1
ORDER BY id DESC
OFFSET $3
)
";
let mut tx = self.pool.begin().await?;
sqlx::query(insert_query)
.bind(user_id.0)
.bind(access_token_hash)
.execute(&mut tx)
.await?;
sqlx::query(cleanup_query)
.bind(user_id.0)
.bind(access_token_hash)
.bind(max_access_token_count as u32)
.execute(&mut tx)
.await?;
tx.commit().await
})
}
pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
test_support!(self, {
let query = "
SELECT hash
FROM access_tokens
WHERE user_id = $1
ORDER BY id DESC
";
sqlx::query_scalar(query)
.bind(user_id.0)
.fetch_all(&self.pool)
.await
})
}
// orgs
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
test_support!(self, {
let query = "
SELECT *
FROM orgs
WHERE slug = $1
";
sqlx::query_as(query)
.bind(slug)
.fetch_optional(&self.pool)
.await
})
}
#[cfg(any(test, feature = "seed-support"))]
pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
test_support!(self, {
let query = "
INSERT INTO orgs (name, slug)
VALUES ($1, $2)
RETURNING id
";
sqlx::query_scalar(query)
.bind(name)
.bind(slug)
.fetch_one(&self.pool)
.await
.map(OrgId)
})
}
#[cfg(any(test, feature = "seed-support"))]
pub async fn add_org_member(
&self,
org_id: OrgId,
user_id: UserId,
is_admin: bool,
) -> Result<()> {
test_support!(self, {
let query = "
INSERT INTO org_memberships (org_id, user_id, admin)
VALUES ($1, $2, $3)
ON CONFLICT DO NOTHING
";
sqlx::query(query)
.bind(org_id.0)
.bind(user_id.0)
.bind(is_admin)
.execute(&self.pool)
.await
.map(drop)
})
}
// channels
#[cfg(any(test, feature = "seed-support"))]
pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
test_support!(self, {
let query = "
INSERT INTO channels (owner_id, owner_is_user, name)
VALUES ($1, false, $2)
RETURNING id
";
sqlx::query_scalar(query)
.bind(org_id.0)
.bind(name)
.fetch_one(&self.pool)
.await
.map(ChannelId)
})
}
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
test_support!(self, {
let query = "
SELECT *
FROM channels
WHERE
channels.owner_is_user = false AND
channels.owner_id = $1
";
sqlx::query_as(query)
.bind(org_id.0)
.fetch_all(&self.pool)
.await
})
}
pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
test_support!(self, {
let query = "
SELECT
channels.id, channels.name
FROM
channel_memberships, channels
WHERE
channel_memberships.user_id = $1 AND
channel_memberships.channel_id = channels.id
";
sqlx::query_as(query)
.bind(user_id.0)
.fetch_all(&self.pool)
.await
})
}
pub async fn can_user_access_channel(
&self,
user_id: UserId,
channel_id: ChannelId,
) -> Result<bool> {
test_support!(self, {
let query = "
SELECT id
FROM channel_memberships
WHERE user_id = $1 AND channel_id = $2
LIMIT 1
";
sqlx::query_scalar::<_, i32>(query)
.bind(user_id.0)
.bind(channel_id.0)
.fetch_optional(&self.pool)
.await
.map(|e| e.is_some())
})
}
#[cfg(any(test, feature = "seed-support"))]
pub async fn add_channel_member(
&self,
channel_id: ChannelId,
user_id: UserId,
is_admin: bool,
) -> Result<()> {
test_support!(self, {
let query = "
INSERT INTO channel_memberships (channel_id, user_id, admin)
VALUES ($1, $2, $3)
ON CONFLICT DO NOTHING
";
sqlx::query(query)
.bind(channel_id.0)
.bind(user_id.0)
.bind(is_admin)
.execute(&self.pool)
.await
.map(drop)
})
}
// messages
pub async fn create_channel_message(
&self,
channel_id: ChannelId,
sender_id: UserId,
body: &str,
timestamp: OffsetDateTime,
nonce: u128,
) -> Result<MessageId> {
test_support!(self, {
let query = "
INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
RETURNING id
";
sqlx::query_scalar(query)
.bind(channel_id.0)
.bind(sender_id.0)
.bind(body)
.bind(timestamp)
.bind(Uuid::from_u128(nonce))
.fetch_one(&self.pool)
.await
.map(MessageId)
})
}
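// Returns up to `count` messages with ids below `before_id` (newest first in the
// inner query), re-sorted into ascending id order before returning.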
pub async fn get_channel_messages(
&self,
channel_id: ChannelId,
count: usize,
before_id: Option<MessageId>,
) -> Result<Vec<ChannelMessage>> {
test_support!(self, {
let query = r#"
SELECT * FROM (
SELECT
id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
FROM
channel_messages
WHERE
channel_id = $1 AND
id < $2
ORDER BY id DESC
LIMIT $3
) as recent_messages
ORDER BY id ASC
"#;
sqlx::query_as(query)
.bind(channel_id.0)
.bind(before_id.unwrap_or(MessageId::MAX))
.bind(count as i64)
.fetch_all(&self.pool)
.await
})
}
}
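// Declares a transparent newtype over an i32 database id, with helpers for
// converting to and from the u64 ids used in the proto messages.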
macro_rules! id_type {
($name:ident) => {
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
#[sqlx(transparent)]
#[serde(transparent)]
pub struct $name(pub i32);
impl $name {
#[allow(unused)]
pub const MAX: Self = Self(i32::MAX);
#[allow(unused)]
pub fn from_proto(value: u64) -> Self {
Self(value as i32)
}
#[allow(unused)]
pub fn to_proto(&self) -> u64 {
self.0 as u64
}
}
};
}
id_type!(UserId);
#[derive(Debug, FromRow, Serialize, PartialEq)]
pub struct User {
pub id: UserId,
pub github_login: String,
pub admin: bool,
}
id_type!(OrgId);
#[derive(FromRow)]
pub struct Org {
pub id: OrgId,
pub name: String,
pub slug: String,
}
id_type!(SignupId);
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
pub id: SignupId,
pub github_login: String,
pub email_address: String,
pub about: String,
pub wants_releases: Option<bool>,
pub wants_updates: Option<bool>,
pub wants_community: Option<bool>,
}
id_type!(ChannelId);
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
pub id: ChannelId,
pub name: String,
}
id_type!(MessageId);
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
pub id: MessageId,
pub sender_id: UserId,
pub body: String,
pub sent_at: OffsetDateTime,
pub nonce: Uuid,
}
#[cfg(test)]
pub mod tests {
use super::*;
use rand::prelude::*;
use sqlx::{
migrate::{MigrateDatabase, Migrator},
Postgres,
};
use std::path::Path;
pub struct TestDb {
pub db: Db,
pub name: String,
pub url: String,
}
impl TestDb {
pub fn new() -> Self {
// Enable tests to run in parallel by serializing the creation of each test database.
lazy_static::lazy_static! {
static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
}
let mut rng = StdRng::from_entropy();
let name = format!("zed-test-{}", rng.gen::<u128>());
let url = format!("postgres://postgres@localhost/{}", name);
let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
let db = block_on(async {
{
let _lock = DB_CREATION.lock();
Postgres::create_database(&url)
.await
.expect("failed to create test db");
}
let mut db = Db::new(&url, 5).await.unwrap();
db.test_mode = true;
let migrator = Migrator::new(migrations_path).await.unwrap();
migrator.run(&db.pool).await.unwrap();
db
});
Self { db, name, url }
}
pub fn db(&self) -> &Db {
&self.db
}
}
impl Drop for TestDb {
fn drop(&mut self) {
block_on(async {
let query = "
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid();
";
sqlx::query(query)
.bind(&self.name)
.execute(&self.db.pool)
.await
.unwrap();
self.db.pool.close().await;
Postgres::drop_database(&self.url).await.unwrap();
});
}
}
#[gpui::test]
async fn test_get_users_by_ids() {
let test_db = TestDb::new();
let db = test_db.db();
let user = db.create_user("user", false).await.unwrap();
let friend1 = db.create_user("friend-1", false).await.unwrap();
let friend2 = db.create_user("friend-2", false).await.unwrap();
let friend3 = db.create_user("friend-3", false).await.unwrap();
assert_eq!(
db.get_users_by_ids([user, friend1, friend2, friend3])
.await
.unwrap(),
vec![
User {
id: user,
github_login: "user".to_string(),
admin: false,
},
User {
id: friend1,
github_login: "friend-1".to_string(),
admin: false,
},
User {
id: friend2,
github_login: "friend-2".to_string(),
admin: false,
},
User {
id: friend3,
github_login: "friend-3".to_string(),
admin: false,
}
]
);
}
#[gpui::test]
async fn test_recent_channel_messages() {
let test_db = TestDb::new();
let db = test_db.db();
let user = db.create_user("user", false).await.unwrap();
let org = db.create_org("org", "org").await.unwrap();
let channel = db.create_org_channel(org, "channel").await.unwrap();
for i in 0..10 {
db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
.await
.unwrap();
}
let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
assert_eq!(
messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
["5", "6", "7", "8", "9"]
);
let prev_messages = db
.get_channel_messages(channel, 4, Some(messages[0].id))
.await
.unwrap();
assert_eq!(
prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
["1", "2", "3", "4"]
);
}
#[gpui::test]
async fn test_channel_message_nonces() {
let test_db = TestDb::new();
let db = test_db.db();
let user = db.create_user("user", false).await.unwrap();
let org = db.create_org("org", "org").await.unwrap();
let channel = db.create_org_channel(org, "channel").await.unwrap();
let msg1_id = db
.create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
.await
.unwrap();
let msg2_id = db
.create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
.await
.unwrap();
let msg3_id = db
.create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
.await
.unwrap();
let msg4_id = db
.create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
.await
.unwrap();
assert_ne!(msg1_id, msg2_id);
assert_eq!(msg1_id, msg3_id);
assert_eq!(msg2_id, msg4_id);
}
#[gpui::test]
async fn test_create_access_tokens() {
let test_db = TestDb::new();
let db = test_db.db();
let user = db.create_user("the-user", false).await.unwrap();
db.create_access_token_hash(user, "h1", 3).await.unwrap();
db.create_access_token_hash(user, "h2", 3).await.unwrap();
assert_eq!(
db.get_access_token_hashes(user).await.unwrap(),
&["h2".to_string(), "h1".to_string()]
);
db.create_access_token_hash(user, "h3", 3).await.unwrap();
assert_eq!(
db.get_access_token_hashes(user).await.unwrap(),
&["h3".to_string(), "h2".to_string(), "h1".to_string(),]
);
db.create_access_token_hash(user, "h4", 3).await.unwrap();
assert_eq!(
db.get_access_token_hashes(user).await.unwrap(),
&["h4".to_string(), "h3".to_string(), "h2".to_string(),]
);
db.create_access_token_hash(user, "h5", 3).await.unwrap();
assert_eq!(
db.get_access_token_hashes(user).await.unwrap(),
&["h5".to_string(), "h4".to_string(), "h3".to_string()]
);
}
}

crates/server/src/env.rs Normal file

@@ -0,0 +1,20 @@
use anyhow::anyhow;
use std::fs;
pub fn load_dotenv() -> anyhow::Result<()> {
let env: toml::map::Map<String, toml::Value> = toml::de::from_str(
&fs::read_to_string("./.env.toml").map_err(|_| anyhow!("no .env.toml file found"))?,
)?;
for (key, value) in env {
let value = match value {
toml::Value::String(value) => value,
toml::Value::Integer(value) => value.to_string(),
toml::Value::Float(value) => value.to_string(),
_ => panic!("unsupported TOML value in .env.toml for key {}", key),
};
std::env::set_var(key, value);
}
Ok(())
}

@@ -0,0 +1,73 @@
use crate::{AppState, LayoutData, Request, RequestExt};
use async_trait::async_trait;
use serde::Serialize;
use std::sync::Arc;
use tide::http::mime;
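// Replaces the body of any non-success response with the rendered error.hbs
// template, passing along the status code and its canonical reason.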
pub struct Middleware;
#[async_trait]
impl tide::Middleware<Arc<AppState>> for Middleware {
async fn handle(
&self,
mut request: Request,
next: tide::Next<'_, Arc<AppState>>,
) -> tide::Result {
let app = request.state().clone();
let layout_data = request.layout_data().await?;
let mut response = next.run(request).await;
#[derive(Serialize)]
struct ErrorData {
#[serde(flatten)]
layout: Arc<LayoutData>,
status: u16,
reason: &'static str,
}
if !response.status().is_success() {
response.set_body(app.render_template(
"error.hbs",
&ErrorData {
layout: layout_data,
status: response.status().into(),
reason: response.status().canonical_reason(),
},
)?);
response.set_content_type(mime::HTML);
}
Ok(response)
}
}
// Allow tide Results to accept context like other Results do when
// using anyhow.
pub trait TideResultExt {
fn context<C>(self, cx: C) -> Self
where
C: std::fmt::Display + Send + Sync + 'static;
fn with_context<C, F>(self, f: F) -> Self
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C;
}
impl<T> TideResultExt for tide::Result<T> {
fn context<C>(self, cx: C) -> Self
where
C: std::fmt::Display + Send + Sync + 'static,
{
self.map_err(|e| tide::Error::new(e.status(), e.into_inner().context(cx)))
}
fn with_context<C, F>(self, f: F) -> Self
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C,
{
self.map_err(|e| tide::Error::new(e.status(), e.into_inner().context(f())))
}
}

@@ -0,0 +1,43 @@
use std::{future::Future, time::Instant};
use async_std::sync::Mutex;
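// Caches a value together with an expiration instant; get_or_refresh recomputes
// the value via the supplied future once that instant has passed.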
#[derive(Default)]
pub struct Expiring<T>(Mutex<Option<ExpiringState<T>>>);
pub struct ExpiringState<T> {
value: T,
expires_at: Instant,
}
impl<T: Clone> Expiring<T> {
pub async fn get_or_refresh<F, G>(&self, f: F) -> tide::Result<T>
where
F: FnOnce() -> G,
G: Future<Output = tide::Result<(T, Instant)>>,
{
let mut state = self.0.lock().await;
if let Some(state) = state.as_mut() {
if Instant::now() >= state.expires_at {
let (value, expires_at) = f().await?;
state.value = value.clone();
state.expires_at = expires_at;
Ok(value)
} else {
Ok(state.value.clone())
}
} else {
let (value, expires_at) = f().await?;
*state = Some(ExpiringState {
value: value.clone(),
expires_at,
});
Ok(value)
}
}
pub async fn clear(&self) {
self.0.lock().await.take();
}
}

crates/server/src/github.rs Normal file

@@ -0,0 +1,265 @@
use crate::expiring::Expiring;
use anyhow::{anyhow, Context};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
future::Future,
sync::Arc,
time::{Duration, Instant},
};
use surf::{http::Method, RequestBuilder, Url};
#[derive(Debug, Deserialize, Serialize)]
pub struct Release {
pub tag_name: String,
pub name: String,
pub body: String,
pub draft: bool,
pub assets: Vec<Asset>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct Asset {
pub name: String,
pub url: String,
}
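// GitHub App client: signs short-lived JWTs with the app's private key and trades
// them for per-installation tokens when calling the REST API.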
pub struct AppClient {
id: usize,
private_key: String,
jwt_bearer_header: Expiring<String>,
}
#[derive(Deserialize)]
struct Installation {
#[allow(unused)]
id: usize,
}
impl AppClient {
#[cfg(test)]
pub fn test() -> Arc<Self> {
Arc::new(Self {
id: Default::default(),
private_key: Default::default(),
jwt_bearer_header: Default::default(),
})
}
pub fn new(id: usize, private_key: String) -> Arc<Self> {
Arc::new(Self {
id,
private_key,
jwt_bearer_header: Default::default(),
})
}
pub async fn repo(self: &Arc<Self>, nwo: String) -> tide::Result<RepoClient> {
let installation: Installation = self
.request(
Method::Get,
&format!("/repos/{}/installation", &nwo),
|refresh| self.bearer_header(refresh),
)
.await?;
Ok(RepoClient {
app: self.clone(),
nwo,
installation_id: installation.id,
installation_token_header: Default::default(),
})
}
pub fn user(self: &Arc<Self>, access_token: String) -> UserClient {
UserClient {
app: self.clone(),
access_token,
}
}
async fn request<T, F, G>(
&self,
method: Method,
path: &str,
get_auth_header: F,
) -> tide::Result<T>
where
T: DeserializeOwned,
F: Fn(bool) -> G,
G: Future<Output = tide::Result<String>>,
{
let mut retried = false;
loop {
let response = RequestBuilder::new(
method,
Url::parse(&format!("https://api.github.com{}", path))?,
)
.header("Accept", "application/vnd.github.v3+json")
.header("Authorization", get_auth_header(retried).await?)
.recv_json()
.await;
if let Err(error) = response.as_ref() {
if error.status() == 401 && !retried {
retried = true;
continue;
}
}
return response;
}
}
async fn bearer_header(&self, refresh: bool) -> tide::Result<String> {
if refresh {
self.jwt_bearer_header.clear().await;
}
self.jwt_bearer_header
.get_or_refresh(|| async {
use jwt_simple::{algorithms::RS256KeyPair, prelude::*};
use std::time;
let key_pair = RS256KeyPair::from_pem(&self.private_key)
.with_context(|| format!("invalid private key {:?}", self.private_key))?;
let mut claims = Claims::create(Duration::from_mins(10));
claims.issued_at = Some(Clock::now_since_epoch() - Duration::from_mins(1));
claims.issuer = Some(self.id.to_string());
let token = key_pair.sign(claims).context("failed to sign claims")?;
let expires_at = time::Instant::now() + time::Duration::from_secs(9 * 60);
Ok((format!("Bearer {}", token), expires_at))
})
.await
}
async fn installation_token_header(
&self,
header: &Expiring<String>,
installation_id: usize,
refresh: bool,
) -> tide::Result<String> {
if refresh {
header.clear().await;
}
header
.get_or_refresh(|| async {
#[derive(Debug, Deserialize)]
struct AccessToken {
token: String,
}
let access_token: AccessToken = self
.request(
Method::Post,
&format!("/app/installations/{}/access_tokens", installation_id),
|refresh| self.bearer_header(refresh),
)
.await?;
let header = format!("Token {}", access_token.token);
let expires_at = Instant::now() + Duration::from_secs(60 * 30);
Ok((header, expires_at))
})
.await
}
}
pub struct RepoClient {
app: Arc<AppClient>,
nwo: String,
installation_id: usize,
installation_token_header: Expiring<String>,
}
impl RepoClient {
#[cfg(test)]
pub fn test(app_client: &Arc<AppClient>) -> Self {
Self {
app: app_client.clone(),
nwo: String::new(),
installation_id: 0,
installation_token_header: Default::default(),
}
}
pub async fn releases(&self) -> tide::Result<Vec<Release>> {
self.get(&format!("/repos/{}/releases?per_page=100", self.nwo))
.await
}
pub async fn release_asset(&self, tag: &str, name: &str) -> tide::Result<surf::Body> {
let release: Release = self
.get(&format!("/repos/{}/releases/tags/{}", self.nwo, tag))
.await?;
let asset = release
.assets
.iter()
.find(|asset| asset.name == name)
.ok_or_else(|| anyhow!("no asset found with name {}", name))?;
let request = surf::get(&asset.url)
.header("Accept", "application/octet-stream'")
.header(
"Authorization",
self.installation_token_header(false).await?,
);
let client = surf::client().with(surf::middleware::Redirect::new(5));
let mut response = client.send(request).await?;
Ok(response.take_body())
}
async fn get<T: DeserializeOwned>(&self, path: &str) -> tide::Result<T> {
self.request::<T>(Method::Get, path).await
}
async fn request<T: DeserializeOwned>(&self, method: Method, path: &str) -> tide::Result<T> {
Ok(self
.app
.request(method, path, |refresh| {
self.installation_token_header(refresh)
})
.await?)
}
async fn installation_token_header(&self, refresh: bool) -> tide::Result<String> {
self.app
.installation_token_header(
&self.installation_token_header,
self.installation_id,
refresh,
)
.await
}
}
pub struct UserClient {
app: Arc<AppClient>,
access_token: String,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct User {
pub login: String,
pub avatar_url: String,
}
impl UserClient {
pub async fn details(&self) -> tide::Result<User> {
Ok(self
.app
.request(Method::Get, "/user", |_| async {
Ok(self.access_token_header())
})
.await?)
}
fn access_token_header(&self) -> String {
format!("Token {}", self.access_token)
}
}

crates/server/src/home.rs Normal file

@@ -0,0 +1,79 @@
use crate::{AppState, Request, RequestExt as _};
use serde::Deserialize;
use std::sync::Arc;
use tide::{http::mime, log, Server};
pub fn add_routes(app: &mut Server<Arc<AppState>>) {
app.at("/").get(get_home);
app.at("/signups").post(post_signup);
app.at("/releases/:tag_name/:name").get(get_release_asset);
}
async fn get_home(mut request: Request) -> tide::Result {
let data = request.layout_data().await?;
Ok(tide::Response::builder(200)
.body(request.state().render_template("home.hbs", &data)?)
.content_type(mime::HTML)
.build())
}
async fn post_signup(mut request: Request) -> tide::Result {
#[derive(Debug, Deserialize)]
struct Form {
github_login: String,
email_address: String,
about: String,
#[serde(default)]
wants_releases: bool,
#[serde(default)]
wants_updates: bool,
#[serde(default)]
wants_community: bool,
}
let mut form: Form = request.body_form().await?;
form.github_login = form
.github_login
.strip_prefix("@")
.map(str::to_string)
.unwrap_or(form.github_login);
log::info!("Signup submitted: {:?}", form);
// Save signup in the database
request
.db()
.create_signup(
&form.github_login,
&form.email_address,
&form.about,
form.wants_releases,
form.wants_updates,
form.wants_community,
)
.await?;
let layout_data = request.layout_data().await?;
Ok(tide::Response::builder(200)
.body(
request
.state()
.render_template("signup.hbs", &layout_data)?,
)
.content_type(mime::HTML)
.build())
}
async fn get_release_asset(request: Request) -> tide::Result {
let body = request
.state()
.repo_client
.release_asset(request.param("tag_name")?, request.param("name")?)
.await?;
Ok(tide::Response::builder(200)
.header("Cache-Control", "no-transform")
.content_type(mime::BYTE_STREAM)
.body(body)
.build())
}

crates/server/src/main.rs Normal file

@@ -0,0 +1,196 @@
mod admin;
mod assets;
mod auth;
mod community;
mod db;
mod env;
mod errors;
mod expiring;
mod github;
mod home;
mod releases;
mod rpc;
mod team;
use self::errors::TideResultExt as _;
use anyhow::Result;
use async_std::net::TcpListener;
use async_trait::async_trait;
use auth::RequestExt as _;
use db::Db;
use handlebars::{Handlebars, TemplateRenderError};
use parking_lot::RwLock;
use rust_embed::RustEmbed;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use surf::http::cookies::SameSite;
use tide::{log, sessions::SessionMiddleware};
use tide_compress::CompressMiddleware;
use zrpc::Peer;
type Request = tide::Request<Arc<AppState>>;
#[derive(RustEmbed)]
#[folder = "templates"]
struct Templates;
#[derive(Default, Deserialize)]
pub struct Config {
pub http_port: u16,
pub database_url: String,
pub session_secret: String,
pub github_app_id: usize,
pub github_client_id: String,
pub github_client_secret: String,
pub github_private_key: String,
}
pub struct AppState {
db: Db,
handlebars: RwLock<Handlebars<'static>>,
auth_client: auth::Client,
github_client: Arc<github::AppClient>,
repo_client: github::RepoClient,
config: Config,
}
impl AppState {
async fn new(config: Config) -> tide::Result<Arc<Self>> {
let db = Db::new(&config.database_url, 5).await?;
let github_client =
github::AppClient::new(config.github_app_id, config.github_private_key.clone());
let repo_client = github_client
.repo("zed-industries/zed".into())
.await
.context("failed to initialize github client")?;
let this = Self {
db,
handlebars: Default::default(),
auth_client: auth::build_client(&config.github_client_id, &config.github_client_secret),
github_client,
repo_client,
config,
};
this.register_partials();
Ok(Arc::new(this))
}
fn register_partials(&self) {
for path in Templates::iter() {
if let Some(partial_name) = path
.strip_prefix("partials/")
.and_then(|path| path.strip_suffix(".hbs"))
{
let partial = Templates::get(path.as_ref()).unwrap();
self.handlebars
.write()
.register_partial(partial_name, std::str::from_utf8(&partial.data).unwrap())
.unwrap()
}
}
}
fn render_template(
&self,
path: &'static str,
data: &impl Serialize,
) -> Result<String, TemplateRenderError> {
#[cfg(debug_assertions)]
self.register_partials();
self.handlebars.read().render_template(
std::str::from_utf8(&Templates::get(path).unwrap().data).unwrap(),
data,
)
}
}
#[async_trait]
trait RequestExt {
async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>>;
fn db(&self) -> &Db;
}
#[async_trait]
impl RequestExt for Request {
async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>> {
if self.ext::<Arc<LayoutData>>().is_none() {
self.set_ext(Arc::new(LayoutData {
current_user: self.current_user().await?,
}));
}
Ok(self.ext::<Arc<LayoutData>>().unwrap().clone())
}
fn db(&self) -> &Db {
&self.state().db
}
}
#[derive(Serialize)]
struct LayoutData {
current_user: Option<auth::User>,
}
#[async_std::main]
async fn main() -> tide::Result<()> {
log::start();
if let Err(error) = env::load_dotenv() {
log::error!(
"error loading .env.toml (this is expected in production): {}",
error
);
}
let config = envy::from_env::<Config>().expect("error loading config");
let state = AppState::new(config).await?;
let rpc = Peer::new();
run_server(
state.clone(),
rpc,
TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port)).await?,
)
.await?;
Ok(())
}
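// Composes the app: session-backed HTML routes are nested under "/", compressed
// static assets under "/static", and the rpc routes are registered on the root app.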
pub async fn run_server(
state: Arc<AppState>,
rpc: Arc<Peer>,
listener: TcpListener,
) -> tide::Result<()> {
let mut web = tide::with_state(state.clone());
web.with(CompressMiddleware::new());
web.with(
SessionMiddleware::new(
db::SessionStore::new_with_table_name(&state.config.database_url, "sessions")
.await
.unwrap(),
state.config.session_secret.as_bytes(),
)
.with_same_site_policy(SameSite::Lax), // Required to obtain our session in /auth_callback
);
web.with(errors::Middleware);
home::add_routes(&mut web);
team::add_routes(&mut web);
releases::add_routes(&mut web);
community::add_routes(&mut web);
admin::add_routes(&mut web);
auth::add_routes(&mut web);
let mut assets = tide::new();
assets.with(CompressMiddleware::new());
assets::add_routes(&mut assets);
let mut app = tide::with_state(state.clone());
rpc::add_routes(&mut app, &rpc);
app.at("/").nest(web);
app.at("/static").nest(assets);
app.listen(listener).await?;
Ok(())
}

@@ -0,0 +1,55 @@
use crate::{
auth::RequestExt as _, github::Release, AppState, LayoutData, Request, RequestExt as _,
};
use comrak::ComrakOptions;
use serde::Serialize;
use std::sync::Arc;
use tide::http::mime;
pub fn add_routes(releases: &mut tide::Server<Arc<AppState>>) {
releases.at("/releases").get(get_releases);
}
async fn get_releases(mut request: Request) -> tide::Result {
#[derive(Serialize)]
struct ReleasesData {
#[serde(flatten)]
layout: Arc<LayoutData>,
releases: Option<Vec<Release>>,
}
let mut data = ReleasesData {
layout: request.layout_data().await?,
releases: None,
};
if let Some(user) = request.current_user().await? {
if user.is_insider {
data.releases = Some(
request
.state()
.repo_client
.releases()
.await?
.into_iter()
.filter_map(|mut release| {
if release.draft {
None
} else {
let mut options = ComrakOptions::default();
options.render.unsafe_ = true; // Allow raw HTML in the markup. We control these release notes anyway.
release.body = comrak::markdown_to_html(&release.body, &options);
Some(release)
}
})
.collect(),
);
}
}
Ok(tide::Response::builder(200)
.body(request.state().render_template("releases.hbs", &data)?)
.content_type(mime::HTML)
.build())
}

crates/server/src/rpc.rs Normal file (2307 lines)

File diff suppressed because it is too large

@@ -0,0 +1,615 @@
use crate::db::{ChannelId, UserId};
use anyhow::anyhow;
use std::collections::{hash_map, HashMap, HashSet};
use zrpc::{proto, ConnectionId};
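// In-memory state tracking rpc connections, worktrees and their shares, and chat
// channel membership, indexed for lookup in both directions.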
#[derive(Default)]
pub struct Store {
connections: HashMap<ConnectionId, ConnectionState>,
connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
worktrees: HashMap<u64, Worktree>,
visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
channels: HashMap<ChannelId, Channel>,
next_worktree_id: u64,
}
struct ConnectionState {
user_id: UserId,
worktrees: HashSet<u64>,
channels: HashSet<ChannelId>,
}
pub struct Worktree {
pub host_connection_id: ConnectionId,
pub collaborator_user_ids: Vec<UserId>,
pub root_name: String,
pub share: Option<WorktreeShare>,
}
pub struct WorktreeShare {
pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
pub active_replica_ids: HashSet<ReplicaId>,
pub entries: HashMap<u64, proto::Entry>,
}
#[derive(Default)]
pub struct Channel {
pub connection_ids: HashSet<ConnectionId>,
}
pub type ReplicaId = u16;
#[derive(Default)]
pub struct RemovedConnectionState {
pub hosted_worktrees: HashMap<u64, Worktree>,
pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
pub collaborator_ids: HashSet<UserId>,
}
pub struct JoinedWorktree<'a> {
pub replica_id: ReplicaId,
pub worktree: &'a Worktree,
}
pub struct UnsharedWorktree {
pub connection_ids: Vec<ConnectionId>,
pub collaborator_ids: Vec<UserId>,
}
pub struct LeftWorktree {
pub connection_ids: Vec<ConnectionId>,
pub collaborator_ids: Vec<UserId>,
}
impl Store {
pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
self.connections.insert(
connection_id,
ConnectionState {
user_id,
worktrees: Default::default(),
channels: Default::default(),
},
);
self.connections_by_user_id
.entry(user_id)
.or_default()
.insert(connection_id);
}
pub fn remove_connection(
&mut self,
connection_id: ConnectionId,
) -> tide::Result<RemovedConnectionState> {
let connection = if let Some(connection) = self.connections.remove(&connection_id) {
connection
} else {
return Err(anyhow!("no such connection"))?;
};
for channel_id in &connection.channels {
if let Some(channel) = self.channels.get_mut(&channel_id) {
channel.connection_ids.remove(&connection_id);
}
}
let user_connections = self
.connections_by_user_id
.get_mut(&connection.user_id)
.unwrap();
user_connections.remove(&connection_id);
if user_connections.is_empty() {
self.connections_by_user_id.remove(&connection.user_id);
}
let mut result = RemovedConnectionState::default();
for worktree_id in connection.worktrees.clone() {
if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
result
.collaborator_ids
.extend(worktree.collaborator_user_ids.iter().copied());
result.hosted_worktrees.insert(worktree_id, worktree);
} else if let Some(worktree) = self.leave_worktree(connection_id, worktree_id) {
result
.guest_worktree_ids
.insert(worktree_id, worktree.connection_ids);
result.collaborator_ids.extend(worktree.collaborator_ids);
}
}
#[cfg(test)]
self.check_invariants();
Ok(result)
}
#[cfg(test)]
pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
self.channels.get(&id)
}
pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.channels.insert(channel_id);
self.channels
.entry(channel_id)
.or_default()
.connection_ids
.insert(connection_id);
}
}
pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.channels.remove(&channel_id);
if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
entry.get_mut().connection_ids.remove(&connection_id);
if entry.get_mut().connection_ids.is_empty() {
entry.remove();
}
}
}
}
pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
Ok(self
.connections
.get(&connection_id)
.ok_or_else(|| anyhow!("unknown connection"))?
.user_id)
}
pub fn connection_ids_for_user<'a>(
&'a self,
user_id: UserId,
) -> impl 'a + Iterator<Item = ConnectionId> {
self.connections_by_user_id
.get(&user_id)
.into_iter()
.flatten()
.copied()
}
pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
let mut collaborators = HashMap::new();
for worktree_id in self
.visible_worktrees_by_user_id
.get(&user_id)
.unwrap_or(&HashSet::new())
{
let worktree = &self.worktrees[worktree_id];
let mut guests = HashSet::new();
if let Ok(share) = worktree.share() {
for guest_connection_id in share.guest_connection_ids.keys() {
if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
guests.insert(user_id.to_proto());
}
}
}
if let Ok(host_user_id) = self.user_id_for_connection(worktree.host_connection_id) {
collaborators
.entry(host_user_id)
.or_insert_with(|| proto::Collaborator {
user_id: host_user_id.to_proto(),
worktrees: Vec::new(),
})
.worktrees
.push(proto::WorktreeMetadata {
id: *worktree_id,
root_name: worktree.root_name.clone(),
is_shared: worktree.share.is_some(),
guests: guests.into_iter().collect(),
});
}
}
collaborators.into_values().collect()
}
pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
let worktree_id = self.next_worktree_id;
for collaborator_user_id in &worktree.collaborator_user_ids {
self.visible_worktrees_by_user_id
.entry(*collaborator_user_id)
.or_default()
.insert(worktree_id);
}
self.next_worktree_id += 1;
if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
connection.worktrees.insert(worktree_id);
}
self.worktrees.insert(worktree_id, worktree);
#[cfg(test)]
self.check_invariants();
worktree_id
}
pub fn remove_worktree(
&mut self,
worktree_id: u64,
acting_connection_id: ConnectionId,
) -> tide::Result<Worktree> {
let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
if e.get().host_connection_id != acting_connection_id {
Err(anyhow!("not your worktree"))?;
}
e.remove()
} else {
return Err(anyhow!("no such worktree"))?;
};
if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
connection.worktrees.remove(&worktree_id);
}
if let Some(share) = &worktree.share {
for connection_id in share.guest_connection_ids.keys() {
if let Some(connection) = self.connections.get_mut(connection_id) {
connection.worktrees.remove(&worktree_id);
}
}
}
for collaborator_user_id in &worktree.collaborator_user_ids {
if let Some(visible_worktrees) = self
.visible_worktrees_by_user_id
.get_mut(&collaborator_user_id)
{
visible_worktrees.remove(&worktree_id);
}
}
#[cfg(test)]
self.check_invariants();
Ok(worktree)
}
pub fn share_worktree(
&mut self,
worktree_id: u64,
connection_id: ConnectionId,
entries: HashMap<u64, proto::Entry>,
) -> Option<Vec<UserId>> {
if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
if worktree.host_connection_id == connection_id {
worktree.share = Some(WorktreeShare {
guest_connection_ids: Default::default(),
active_replica_ids: Default::default(),
entries,
});
return Some(worktree.collaborator_user_ids.clone());
}
}
None
}
pub fn unshare_worktree(
&mut self,
worktree_id: u64,
acting_connection_id: ConnectionId,
) -> tide::Result<UnsharedWorktree> {
let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
worktree
} else {
return Err(anyhow!("no such worktree"))?;
};
if worktree.host_connection_id != acting_connection_id {
return Err(anyhow!("not your worktree"))?;
}
let connection_ids = worktree.connection_ids();
let collaborator_ids = worktree.collaborator_user_ids.clone();
if let Some(share) = worktree.share.take() {
for connection_id in share.guest_connection_ids.into_keys() {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.worktrees.remove(&worktree_id);
}
}
#[cfg(test)]
self.check_invariants();
Ok(UnsharedWorktree {
connection_ids,
collaborator_ids,
})
} else {
Err(anyhow!("worktree is not shared"))?
}
}
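// Guests are assigned the lowest replica id not currently in use, starting at 1.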
pub fn join_worktree(
&mut self,
connection_id: ConnectionId,
user_id: UserId,
worktree_id: u64,
) -> tide::Result<JoinedWorktree> {
let connection = self
.connections
.get_mut(&connection_id)
.ok_or_else(|| anyhow!("no such connection"))?;
let worktree = self
.worktrees
.get_mut(&worktree_id)
.and_then(|worktree| {
if worktree.collaborator_user_ids.contains(&user_id) {
Some(worktree)
} else {
None
}
})
.ok_or_else(|| anyhow!("no such worktree"))?;
let share = worktree.share_mut()?;
connection.worktrees.insert(worktree_id);
let mut replica_id = 1;
while share.active_replica_ids.contains(&replica_id) {
replica_id += 1;
}
share.active_replica_ids.insert(replica_id);
share.guest_connection_ids.insert(connection_id, replica_id);
#[cfg(test)]
self.check_invariants();
Ok(JoinedWorktree {
replica_id,
worktree: &self.worktrees[&worktree_id],
})
}
pub fn leave_worktree(
&mut self,
connection_id: ConnectionId,
worktree_id: u64,
) -> Option<LeftWorktree> {
let worktree = self.worktrees.get_mut(&worktree_id)?;
let share = worktree.share.as_mut()?;
let replica_id = share.guest_connection_ids.remove(&connection_id)?;
share.active_replica_ids.remove(&replica_id);
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.worktrees.remove(&worktree_id);
}
let connection_ids = worktree.connection_ids();
let collaborator_ids = worktree.collaborator_user_ids.clone();
#[cfg(test)]
self.check_invariants();
Some(LeftWorktree {
connection_ids,
collaborator_ids,
})
}
pub fn update_worktree(
&mut self,
connection_id: ConnectionId,
worktree_id: u64,
removed_entries: &[u64],
updated_entries: &[proto::Entry],
) -> tide::Result<Vec<ConnectionId>> {
let worktree = self.write_worktree(worktree_id, connection_id)?;
let share = worktree.share_mut()?;
for entry_id in removed_entries {
share.entries.remove(&entry_id);
}
for entry in updated_entries {
share.entries.insert(entry.id, entry.clone());
}
Ok(worktree.connection_ids())
}
pub fn worktree_host_connection_id(
&self,
connection_id: ConnectionId,
worktree_id: u64,
) -> tide::Result<ConnectionId> {
Ok(self
.read_worktree(worktree_id, connection_id)?
.host_connection_id)
}
pub fn worktree_guest_connection_ids(
&self,
connection_id: ConnectionId,
worktree_id: u64,
) -> tide::Result<Vec<ConnectionId>> {
Ok(self
.read_worktree(worktree_id, connection_id)?
.share()?
.guest_connection_ids
.keys()
.copied()
.collect())
}
pub fn worktree_connection_ids(
&self,
connection_id: ConnectionId,
worktree_id: u64,
) -> tide::Result<Vec<ConnectionId>> {
Ok(self
.read_worktree(worktree_id, connection_id)?
.connection_ids())
}
pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
Some(self.channels.get(&channel_id)?.connection_ids())
}
fn read_worktree(
&self,
worktree_id: u64,
connection_id: ConnectionId,
) -> tide::Result<&Worktree> {
let worktree = self
.worktrees
.get(&worktree_id)
.ok_or_else(|| anyhow!("worktree not found"))?;
if worktree.host_connection_id == connection_id
|| worktree
.share()?
.guest_connection_ids
.contains_key(&connection_id)
{
Ok(worktree)
} else {
Err(anyhow!(
"{} is not a member of worktree {}",
connection_id,
worktree_id
))?
}
}
fn write_worktree(
&mut self,
worktree_id: u64,
connection_id: ConnectionId,
) -> tide::Result<&mut Worktree> {
let worktree = self
.worktrees
.get_mut(&worktree_id)
.ok_or_else(|| anyhow!("worktree not found"))?;
if worktree.host_connection_id == connection_id
|| worktree.share.as_ref().map_or(false, |share| {
share.guest_connection_ids.contains_key(&connection_id)
})
{
Ok(worktree)
} else {
Err(anyhow!(
"{} is not a member of worktree {}",
connection_id,
worktree_id
))?
}
}
#[cfg(test)]
fn check_invariants(&self) {
for (connection_id, connection) in &self.connections {
for worktree_id in &connection.worktrees {
let worktree = &self.worktrees.get(&worktree_id).unwrap();
if worktree.host_connection_id != *connection_id {
assert!(worktree
.share()
.unwrap()
.guest_connection_ids
.contains_key(connection_id));
}
}
for channel_id in &connection.channels {
let channel = self.channels.get(channel_id).unwrap();
assert!(channel.connection_ids.contains(connection_id));
}
assert!(self
.connections_by_user_id
.get(&connection.user_id)
.unwrap()
.contains(connection_id));
}
for (user_id, connection_ids) in &self.connections_by_user_id {
for connection_id in connection_ids {
assert_eq!(
self.connections.get(connection_id).unwrap().user_id,
*user_id
);
}
}
for (worktree_id, worktree) in &self.worktrees {
let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
assert!(host_connection.worktrees.contains(worktree_id));
for collaborator_id in &worktree.collaborator_user_ids {
let visible_worktree_ids = self
.visible_worktrees_by_user_id
.get(collaborator_id)
.unwrap();
assert!(visible_worktree_ids.contains(worktree_id));
}
if let Some(share) = &worktree.share {
for guest_connection_id in share.guest_connection_ids.keys() {
let guest_connection = self.connections.get(guest_connection_id).unwrap();
assert!(guest_connection.worktrees.contains(worktree_id));
}
assert_eq!(
share.active_replica_ids.len(),
share.guest_connection_ids.len(),
);
assert_eq!(
share.active_replica_ids,
share
.guest_connection_ids
.values()
.copied()
.collect::<HashSet<_>>(),
);
}
}
for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
for worktree_id in visible_worktree_ids {
let worktree = self.worktrees.get(worktree_id).unwrap();
assert!(worktree.collaborator_user_ids.contains(user_id));
}
}
for (channel_id, channel) in &self.channels {
for connection_id in &channel.connection_ids {
let connection = self.connections.get(connection_id).unwrap();
assert!(connection.channels.contains(channel_id));
}
}
}
}
impl Worktree {
pub fn connection_ids(&self) -> Vec<ConnectionId> {
if let Some(share) = &self.share {
share
.guest_connection_ids
.keys()
.copied()
.chain(Some(self.host_connection_id))
.collect()
} else {
vec![self.host_connection_id]
}
}
pub fn share(&self) -> tide::Result<&WorktreeShare> {
Ok(self
.share
.as_ref()
.ok_or_else(|| anyhow!("worktree is not shared"))?)
}
fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
Ok(self
.share
.as_mut()
.ok_or_else(|| anyhow!("worktree is not shared"))?)
}
}
impl Channel {
fn connection_ids(&self) -> Vec<ConnectionId> {
self.connection_ids.iter().copied().collect()
}
}

crates/server/src/team.rs Normal file

@@ -0,0 +1,15 @@
use crate::{AppState, Request, RequestExt};
use std::sync::Arc;
use tide::http::mime;
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>) {
app.at("/team").get(get_team);
}
async fn get_team(mut request: Request) -> tide::Result {
let data = request.layout_data().await?;
Ok(tide::Response::builder(200)
.body(request.state().render_template("team.hbs", &data)?)
.content_type(mime::HTML)
.build())
}