Rename zed-server to collab

Over time, I think we may end up having multiple services, so it seems like a good opportunity to name this one more specifically while the cost is low. It just seems like naming it "zed" and "zed-server" leaves it a bit open ended.
This commit is contained in:
Nathan Sobo 2022-04-09 08:30:42 -06:00
parent 17195e615e
commit ab8204368c
124 changed files with 71 additions and 113 deletions

117
crates/collab/src/admin.rs Normal file
View file

@ -0,0 +1,117 @@
use crate::{auth::RequestExt as _, db, AppState, LayoutData, Request, RequestExt as _};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use surf::http::mime;
#[async_trait]
pub trait RequestExt {
async fn require_admin(&self) -> tide::Result<()>;
}
#[async_trait]
impl RequestExt for Request {
async fn require_admin(&self) -> tide::Result<()> {
let current_user = self
.current_user()
.await?
.ok_or_else(|| tide::Error::from_str(401, "not logged in"))?;
if current_user.is_admin {
Ok(())
} else {
Err(tide::Error::from_str(
403,
"authenticated user is not an admin",
))
}
}
}
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>) {
app.at("/admin").get(get_admin_page);
app.at("/admin/users").post(post_user);
app.at("/admin/users/:id").put(put_user);
app.at("/admin/users/:id/delete").post(delete_user);
app.at("/admin/signups/:id/delete").post(delete_signup);
}
#[derive(Serialize)]
struct AdminData {
#[serde(flatten)]
layout: Arc<LayoutData>,
users: Vec<db::User>,
signups: Vec<db::Signup>,
}
async fn get_admin_page(mut request: Request) -> tide::Result {
request.require_admin().await?;
let data = AdminData {
layout: request.layout_data().await?,
users: request.db().get_all_users().await?,
signups: request.db().get_all_signups().await?,
};
Ok(tide::Response::builder(200)
.body(request.state().render_template("admin.hbs", &data)?)
.content_type(mime::HTML)
.build())
}
async fn post_user(mut request: Request) -> tide::Result {
request.require_admin().await?;
#[derive(Deserialize)]
struct Form {
github_login: String,
#[serde(default)]
admin: bool,
}
let form = request.body_form::<Form>().await?;
let github_login = form
.github_login
.strip_prefix("@")
.unwrap_or(&form.github_login);
if !github_login.is_empty() {
request.db().create_user(github_login, form.admin).await?;
}
Ok(tide::Redirect::new("/admin").into())
}
async fn put_user(mut request: Request) -> tide::Result {
request.require_admin().await?;
let user_id = request.param("id")?.parse()?;
#[derive(Deserialize)]
struct Body {
admin: bool,
}
let body: Body = request.body_json().await?;
request
.db()
.set_user_is_admin(db::UserId(user_id), body.admin)
.await?;
Ok(tide::Response::builder(200).build())
}
async fn delete_user(request: Request) -> tide::Result {
request.require_admin().await?;
let user_id = db::UserId(request.param("id")?.parse()?);
request.db().destroy_user(user_id).await?;
Ok(tide::Redirect::new("/admin").into())
}
async fn delete_signup(request: Request) -> tide::Result {
request.require_admin().await?;
let signup_id = db::SignupId(request.param("id")?.parse()?);
request.db().destroy_signup(signup_id).await?;
Ok(tide::Redirect::new("/admin").into())
}

179
crates/collab/src/api.rs Normal file
View file

@ -0,0 +1,179 @@
use crate::{auth, db::UserId, AppState, Request, RequestExt as _};
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::json;
use std::sync::Arc;
use surf::StatusCode;
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>) {
app.at("/users").get(get_users);
app.at("/users").post(create_user);
app.at("/users/:id").put(update_user);
app.at("/users/:id").delete(destroy_user);
app.at("/users/:github_login").get(get_user);
app.at("/users/:github_login/access_tokens")
.post(create_access_token);
}
async fn get_user(request: Request) -> tide::Result {
request.require_token().await?;
let user = request
.db()
.get_user_by_github_login(request.param("github_login")?)
.await?
.ok_or_else(|| surf::Error::from_str(404, "user not found"))?;
Ok(tide::Response::builder(StatusCode::Ok)
.body(tide::Body::from_json(&user)?)
.build())
}
async fn get_users(request: Request) -> tide::Result {
request.require_token().await?;
let users = request.db().get_all_users().await?;
Ok(tide::Response::builder(StatusCode::Ok)
.body(tide::Body::from_json(&users)?)
.build())
}
async fn create_user(mut request: Request) -> tide::Result {
request.require_token().await?;
#[derive(Deserialize)]
struct Params {
github_login: String,
admin: bool,
}
let params = request.body_json::<Params>().await?;
let user_id = request
.db()
.create_user(&params.github_login, params.admin)
.await?;
let user = request.db().get_user_by_id(user_id).await?.ok_or_else(|| {
surf::Error::from_str(
StatusCode::InternalServerError,
"couldn't find the user we just created",
)
})?;
Ok(tide::Response::builder(StatusCode::Ok)
.body(tide::Body::from_json(&user)?)
.build())
}
async fn update_user(mut request: Request) -> tide::Result {
request.require_token().await?;
#[derive(Deserialize)]
struct Params {
admin: bool,
}
let user_id = UserId(
request
.param("id")?
.parse::<i32>()
.map_err(|error| surf::Error::from_str(StatusCode::BadRequest, error.to_string()))?,
);
let params = request.body_json::<Params>().await?;
request
.db()
.set_user_is_admin(user_id, params.admin)
.await?;
Ok(tide::Response::builder(StatusCode::Ok).build())
}
async fn destroy_user(request: Request) -> tide::Result {
request.require_token().await?;
let user_id = UserId(
request
.param("id")?
.parse::<i32>()
.map_err(|error| surf::Error::from_str(StatusCode::BadRequest, error.to_string()))?,
);
request.db().destroy_user(user_id).await?;
Ok(tide::Response::builder(StatusCode::Ok).build())
}
async fn create_access_token(request: Request) -> tide::Result {
request.require_token().await?;
let user = request
.db()
.get_user_by_github_login(request.param("github_login")?)
.await?
.ok_or_else(|| surf::Error::from_str(StatusCode::NotFound, "user not found"))?;
#[derive(Deserialize)]
struct QueryParams {
public_key: String,
impersonate: Option<String>,
}
let query_params: QueryParams = request.query().map_err(|_| {
surf::Error::from_str(StatusCode::UnprocessableEntity, "invalid query params")
})?;
let mut user_id = user.id;
if let Some(impersonate) = query_params.impersonate {
if user.admin {
if let Some(impersonated_user) =
request.db().get_user_by_github_login(&impersonate).await?
{
user_id = impersonated_user.id;
} else {
return Ok(tide::Response::builder(StatusCode::UnprocessableEntity)
.body(format!(
"Can't impersonate non-existent user {}",
impersonate
))
.build());
}
} else {
return Ok(tide::Response::builder(StatusCode::Unauthorized)
.body(format!(
"Can't impersonate user {} because the real user isn't an admin",
impersonate
))
.build());
}
}
let access_token = auth::create_access_token(request.db().as_ref(), user_id).await?;
let encrypted_access_token =
auth::encrypt_access_token(&access_token, query_params.public_key.clone())?;
Ok(tide::Response::builder(StatusCode::Ok)
.body(json!({"user_id": user_id, "encrypted_access_token": encrypted_access_token}))
.build())
}
#[async_trait]
pub trait RequestExt {
async fn require_token(&self) -> tide::Result<()>;
}
#[async_trait]
impl RequestExt for Request {
async fn require_token(&self) -> tide::Result<()> {
let token = self
.header("Authorization")
.and_then(|header| header.get(0))
.and_then(|header| header.as_str().strip_prefix("token "))
.ok_or_else(|| surf::Error::from_str(403, "invalid authorization header"))?;
if token == self.state().config.api_token {
Ok(())
} else {
Err(tide::Error::from_str(403, "invalid authorization token"))
}
}
}

View file

@ -0,0 +1,29 @@
use anyhow::anyhow;
use rust_embed::RustEmbed;
use tide::{http::mime, Server};
#[derive(RustEmbed)]
#[folder = "static"]
struct Static;
pub fn add_routes(app: &mut Server<()>) {
app.at("/*path").get(get_static_asset);
}
async fn get_static_asset(request: tide::Request<()>) -> tide::Result {
let path = request.param("path").unwrap();
let content = Static::get(path).ok_or_else(|| anyhow!("asset not found at {}", path))?;
let content_type = if path.starts_with("svg") {
mime::SVG
} else if path.starts_with("styles") {
mime::CSS
} else {
mime::BYTE_STREAM
};
Ok(tide::Response::builder(200)
.content_type(content_type)
.body(content.data.as_ref())
.build())
}

309
crates/collab/src/auth.rs Normal file
View file

@ -0,0 +1,309 @@
use super::{
db::{self, UserId},
errors::TideResultExt,
};
use crate::{github, AppState, Request, RequestExt as _};
use anyhow::{anyhow, Context};
use async_trait::async_trait;
pub use oauth2::basic::BasicClient as Client;
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl,
TokenResponse as _, TokenUrl,
};
use rand::thread_rng;
use rpc::auth as zed_auth;
use scrypt::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Scrypt,
};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, convert::TryFrom, sync::Arc};
use surf::{StatusCode, Url};
use tide::{log, Error, Server};
static CURRENT_GITHUB_USER: &'static str = "current_github_user";
static GITHUB_AUTH_URL: &'static str = "https://github.com/login/oauth/authorize";
static GITHUB_TOKEN_URL: &'static str = "https://github.com/login/oauth/access_token";
#[derive(Serialize)]
pub struct User {
pub github_login: String,
pub avatar_url: String,
pub is_insider: bool,
pub is_admin: bool,
}
pub async fn process_auth_header(request: &Request) -> tide::Result<UserId> {
let mut auth_header = request
.header("Authorization")
.ok_or_else(|| {
Error::new(
StatusCode::BadRequest,
anyhow!("missing authorization header"),
)
})?
.last()
.as_str()
.split_whitespace();
let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
Error::new(
StatusCode::BadRequest,
anyhow!("missing user id in authorization header"),
)
})?);
let access_token = auth_header.next().ok_or_else(|| {
Error::new(
StatusCode::BadRequest,
anyhow!("missing access token in authorization header"),
)
})?;
let state = request.state().clone();
let mut credentials_valid = false;
for password_hash in state.db.get_access_token_hashes(user_id).await? {
if verify_access_token(&access_token, &password_hash)? {
credentials_valid = true;
break;
}
}
if !credentials_valid {
Err(Error::new(
StatusCode::Unauthorized,
anyhow!("invalid credentials"),
))?;
}
Ok(user_id)
}
#[async_trait]
pub trait RequestExt {
async fn current_user(&self) -> tide::Result<Option<User>>;
}
#[async_trait]
impl RequestExt for Request {
async fn current_user(&self) -> tide::Result<Option<User>> {
if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
let user = self.db().get_user_by_github_login(&details.login).await?;
Ok(Some(User {
github_login: details.login,
avatar_url: details.avatar_url,
is_insider: user.is_some(),
is_admin: user.map_or(false, |user| user.admin),
}))
} else {
Ok(None)
}
}
}
pub fn build_client(client_id: &str, client_secret: &str) -> Client {
Client::new(
ClientId::new(client_id.to_string()),
Some(oauth2::ClientSecret::new(client_secret.to_string())),
AuthUrl::new(GITHUB_AUTH_URL.into()).unwrap(),
Some(TokenUrl::new(GITHUB_TOKEN_URL.into()).unwrap()),
)
}
pub fn add_routes(app: &mut Server<Arc<AppState>>) {
app.at("/sign_in").get(get_sign_in);
app.at("/sign_out").post(post_sign_out);
app.at("/auth_callback").get(get_auth_callback);
app.at("/native_app_signin").get(get_sign_in);
app.at("/native_app_signin_succeeded")
.get(get_app_signin_success);
}
#[derive(Debug, Deserialize)]
struct NativeAppSignInParams {
native_app_port: String,
native_app_public_key: String,
impersonate: Option<String>,
}
async fn get_sign_in(mut request: Request) -> tide::Result {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
request
.session_mut()
.insert("pkce_verifier", pkce_verifier)?;
let mut redirect_url = Url::parse(&format!(
"{}://{}/auth_callback",
request
.header("X-Forwarded-Proto")
.and_then(|values| values.get(0))
.map(|value| value.as_str())
.unwrap_or("http"),
request.host().unwrap()
))?;
let app_sign_in_params: Option<NativeAppSignInParams> = request.query().ok();
if let Some(query) = app_sign_in_params {
let mut redirect_query = redirect_url.query_pairs_mut();
redirect_query
.clear()
.append_pair("native_app_port", &query.native_app_port)
.append_pair("native_app_public_key", &query.native_app_public_key);
if let Some(impersonate) = &query.impersonate {
redirect_query.append_pair("impersonate", impersonate);
}
}
let (auth_url, csrf_token) = request
.state()
.auth_client
.authorize_url(CsrfToken::new_random)
.set_redirect_uri(Cow::Owned(RedirectUrl::from_url(redirect_url)))
.set_pkce_challenge(pkce_challenge)
.url();
request
.session_mut()
.insert("auth_csrf_token", csrf_token)?;
Ok(tide::Redirect::new(auth_url).into())
}
async fn get_app_signin_success(_: Request) -> tide::Result {
Ok(tide::Redirect::new("/").into())
}
async fn get_auth_callback(mut request: Request) -> tide::Result {
#[derive(Debug, Deserialize)]
struct Query {
code: String,
state: String,
#[serde(flatten)]
native_app_sign_in_params: Option<NativeAppSignInParams>,
}
let query: Query = request.query()?;
let pkce_verifier = request
.session()
.get("pkce_verifier")
.ok_or_else(|| anyhow!("could not retrieve pkce_verifier from session"))?;
let csrf_token = request
.session()
.get::<CsrfToken>("auth_csrf_token")
.ok_or_else(|| anyhow!("could not retrieve auth_csrf_token from session"))?;
if &query.state != csrf_token.secret() {
return Err(anyhow!("csrf token does not match").into());
}
let github_access_token = request
.state()
.auth_client
.exchange_code(AuthorizationCode::new(query.code))
.set_pkce_verifier(pkce_verifier)
.request_async(oauth2_surf::http_client)
.await
.context("failed to exchange oauth code")?
.access_token()
.secret()
.clone();
let user_details = request
.state()
.github_client
.user(github_access_token)
.details()
.await
.context("failed to fetch user")?;
let user = request
.db()
.get_user_by_github_login(&user_details.login)
.await?;
request
.session_mut()
.insert(CURRENT_GITHUB_USER, user_details.clone())?;
// When signing in from the native app, generate a new access token for the current user. Return
// a redirect so that the user's browser sends this access token to the locally-running app.
if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
let mut user_id = user.id;
if let Some(impersonated_login) = app_sign_in_params.impersonate {
log::info!("attempting to impersonate user @{}", impersonated_login);
if let Some(user) = request.db().get_users_by_ids(vec![user_id]).await?.first() {
if user.admin {
user_id = request.db().create_user(&impersonated_login, false).await?;
log::info!("impersonating user {}", user_id.0);
} else {
log::info!("refusing to impersonate user");
}
}
}
let access_token = create_access_token(request.db().as_ref(), user_id).await?;
let encrypted_access_token = encrypt_access_token(
&access_token,
app_sign_in_params.native_app_public_key.clone(),
)?;
return Ok(tide::Redirect::new(&format!(
"http://127.0.0.1:{}?user_id={}&access_token={}",
app_sign_in_params.native_app_port, user_id.0, encrypted_access_token,
))
.into());
}
Ok(tide::Redirect::new("/").into())
}
async fn post_sign_out(mut request: Request) -> tide::Result {
request.session_mut().remove(CURRENT_GITHUB_USER);
Ok(tide::Redirect::new("/").into())
}
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> tide::Result<String> {
let access_token = zed_auth::random_token();
let access_token_hash =
hash_access_token(&access_token).context("failed to hash access token")?;
db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
.await?;
Ok(access_token)
}
fn hash_access_token(token: &str) -> tide::Result<String> {
// Avoid slow hashing in debug mode.
let params = if cfg!(debug_assertions) {
scrypt::Params::new(1, 1, 1).unwrap()
} else {
scrypt::Params::recommended()
};
Ok(Scrypt
.hash_password(
token.as_bytes(),
None,
params,
&SaltString::generate(thread_rng()),
)?
.to_string())
}
pub fn encrypt_access_token(access_token: &str, public_key: String) -> tide::Result<String> {
let native_app_public_key =
zed_auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
let encrypted_access_token = native_app_public_key
.encrypt_string(&access_token)
.context("failed to encrypt access token with public key")?;
Ok(encrypted_access_token)
}
pub fn verify_access_token(token: &str, hash: &str) -> tide::Result<bool> {
let hash = PasswordHash::new(hash)?;
Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
}

View file

@ -0,0 +1,20 @@
use anyhow::anyhow;
use std::fs;
fn main() -> anyhow::Result<()> {
let env: toml::map::Map<String, toml::Value> = toml::de::from_str(
&fs::read_to_string("./.env.toml").map_err(|_| anyhow!("no .env.toml file found"))?,
)?;
for (key, value) in env {
let value = match value {
toml::Value::String(value) => value,
toml::Value::Integer(value) => value.to_string(),
toml::Value::Float(value) => value.to_string(),
_ => panic!("unsupported TOML value in .env.toml for key {}", key),
};
println!("export {}=\"{}\"", key, value);
}
Ok(())
}

View file

@ -0,0 +1,85 @@
use db::{Db, UserId};
use rand::prelude::*;
use time::{Duration, OffsetDateTime};
#[allow(unused)]
#[path = "../db.rs"]
mod db;
#[async_std::main]
async fn main() {
let mut rng = StdRng::from_entropy();
let database_url = std::env::var("DATABASE_URL").expect("missing DATABASE_URL env var");
let db = Db::new(&database_url, 5)
.await
.expect("failed to connect to postgres database");
let zed_users = ["nathansobo", "maxbrunsfeld", "as-cii", "iamnbutler"];
let mut zed_user_ids = Vec::<UserId>::new();
for zed_user in zed_users {
if let Some(user) = db
.get_user_by_github_login(zed_user)
.await
.expect("failed to fetch user")
{
zed_user_ids.push(user.id);
} else {
zed_user_ids.push(
db.create_user(zed_user, true)
.await
.expect("failed to insert user"),
);
}
}
let zed_org_id = if let Some(org) = db
.find_org_by_slug("zed")
.await
.expect("failed to fetch org")
{
org.id
} else {
db.create_org("Zed", "zed")
.await
.expect("failed to insert org")
};
let general_channel_id = if let Some(channel) = db
.get_org_channels(zed_org_id)
.await
.expect("failed to fetch channels")
.iter()
.find(|c| c.name == "General")
{
channel.id
} else {
let channel_id = db
.create_org_channel(zed_org_id, "General")
.await
.expect("failed to insert channel");
let now = OffsetDateTime::now_utc();
let max_seconds = Duration::days(100).as_seconds_f64();
let mut timestamps = (0..1000)
.map(|_| now - Duration::seconds_f64(rng.gen_range(0_f64..=max_seconds)))
.collect::<Vec<_>>();
timestamps.sort();
for timestamp in timestamps {
let sender_id = *zed_user_ids.choose(&mut rng).unwrap();
let body = lipsum::lipsum_words(rng.gen_range(1..=50));
db.create_channel_message(channel_id, sender_id, &body, timestamp, rng.gen())
.await
.expect("failed to insert message");
}
channel_id
};
for user_id in zed_user_ids {
db.add_org_member(zed_org_id, user_id, true)
.await
.expect("failed to insert org membership");
db.add_channel_member(general_channel_id, user_id, true)
.await
.expect("failed to insert channel membership");
}
}

View file

@ -0,0 +1,15 @@
use crate::{AppState, Request, RequestExt};
use std::sync::Arc;
use tide::http::mime;
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>) {
app.at("/careers").get(get_careers);
}
async fn get_careers(mut request: Request) -> tide::Result {
let data = request.layout_data().await?;
Ok(tide::Response::builder(200)
.body(request.state().render_template("careers.hbs", &data)?)
.content_type(mime::HTML)
.build())
}

View file

@ -0,0 +1,15 @@
use crate::{AppState, Request, RequestExt};
use std::sync::Arc;
use tide::http::mime;
pub fn add_routes(community: &mut tide::Server<Arc<AppState>>) {
community.at("/community").get(get_community);
}
async fn get_community(mut request: Request) -> tide::Result {
let data = request.layout_data().await?;
Ok(tide::Response::builder(200)
.body(request.state().render_template("community.hbs", &data)?)
.content_type(mime::HTML)
.build())
}

1121
crates/collab/src/db.rs Normal file

File diff suppressed because it is too large Load diff

20
crates/collab/src/env.rs Normal file
View file

@ -0,0 +1,20 @@
use anyhow::anyhow;
use std::fs;
pub fn load_dotenv() -> anyhow::Result<()> {
let env: toml::map::Map<String, toml::Value> = toml::de::from_str(
&fs::read_to_string("./.env.toml").map_err(|_| anyhow!("no .env.toml file found"))?,
)?;
for (key, value) in env {
let value = match value {
toml::Value::String(value) => value,
toml::Value::Integer(value) => value.to_string(),
toml::Value::Float(value) => value.to_string(),
_ => panic!("unsupported TOML value in .env.toml for key {}", key),
};
std::env::set_var(key, value);
}
Ok(())
}

View file

@ -0,0 +1,73 @@
use crate::{AppState, LayoutData, Request, RequestExt};
use async_trait::async_trait;
use serde::Serialize;
use std::sync::Arc;
use tide::http::mime;
pub struct Middleware;
#[async_trait]
impl tide::Middleware<Arc<AppState>> for Middleware {
async fn handle(
&self,
mut request: Request,
next: tide::Next<'_, Arc<AppState>>,
) -> tide::Result {
let app = request.state().clone();
let layout_data = request.layout_data().await?;
let mut response = next.run(request).await;
#[derive(Serialize)]
struct ErrorData {
#[serde(flatten)]
layout: Arc<LayoutData>,
status: u16,
reason: &'static str,
}
if !response.status().is_success() {
response.set_body(app.render_template(
"error.hbs",
&ErrorData {
layout: layout_data,
status: response.status().into(),
reason: response.status().canonical_reason(),
},
)?);
response.set_content_type(mime::HTML);
}
Ok(response)
}
}
// Allow tide Results to accept context like other Results do when
// using anyhow.
pub trait TideResultExt {
fn context<C>(self, cx: C) -> Self
where
C: std::fmt::Display + Send + Sync + 'static;
fn with_context<C, F>(self, f: F) -> Self
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C;
}
impl<T> TideResultExt for tide::Result<T> {
fn context<C>(self, cx: C) -> Self
where
C: std::fmt::Display + Send + Sync + 'static,
{
self.map_err(|e| tide::Error::new(e.status(), e.into_inner().context(cx)))
}
fn with_context<C, F>(self, f: F) -> Self
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C,
{
self.map_err(|e| tide::Error::new(e.status(), e.into_inner().context(f())))
}
}

View file

@ -0,0 +1,43 @@
use std::{future::Future, time::Instant};
use async_std::sync::Mutex;
#[derive(Default)]
pub struct Expiring<T>(Mutex<Option<ExpiringState<T>>>);
pub struct ExpiringState<T> {
value: T,
expires_at: Instant,
}
impl<T: Clone> Expiring<T> {
pub async fn get_or_refresh<F, G>(&self, f: F) -> tide::Result<T>
where
F: FnOnce() -> G,
G: Future<Output = tide::Result<(T, Instant)>>,
{
let mut state = self.0.lock().await;
if let Some(state) = state.as_mut() {
if Instant::now() >= state.expires_at {
let (value, expires_at) = f().await?;
state.value = value.clone();
state.expires_at = expires_at;
Ok(value)
} else {
Ok(state.value.clone())
}
} else {
let (value, expires_at) = f().await?;
*state = Some(ExpiringState {
value: value.clone(),
expires_at,
});
Ok(value)
}
}
pub async fn clear(&self) {
self.0.lock().await.take();
}
}

281
crates/collab/src/github.rs Normal file
View file

@ -0,0 +1,281 @@
use crate::expiring::Expiring;
use anyhow::{anyhow, Context};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
future::Future,
sync::Arc,
time::{Duration, Instant},
};
use surf::{http::Method, RequestBuilder, Url};
#[derive(Debug, Deserialize, Serialize)]
pub struct Release {
pub tag_name: String,
pub name: String,
pub body: String,
pub draft: bool,
pub assets: Vec<Asset>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct Asset {
pub name: String,
pub url: String,
}
pub struct AppClient {
id: usize,
private_key: String,
jwt_bearer_header: Expiring<String>,
}
#[derive(Deserialize)]
struct Installation {
#[allow(unused)]
id: usize,
}
impl AppClient {
#[cfg(test)]
pub fn test() -> Arc<Self> {
Arc::new(Self {
id: Default::default(),
private_key: Default::default(),
jwt_bearer_header: Default::default(),
})
}
pub fn new(id: usize, private_key: String) -> Arc<Self> {
Arc::new(Self {
id,
private_key,
jwt_bearer_header: Default::default(),
})
}
pub async fn repo(self: &Arc<Self>, nwo: String) -> tide::Result<RepoClient> {
let installation: Installation = self
.request(
Method::Get,
&format!("/repos/{}/installation", &nwo),
|refresh| self.bearer_header(refresh),
)
.await?;
Ok(RepoClient {
app: self.clone(),
nwo,
installation_id: installation.id,
installation_token_header: Default::default(),
})
}
pub fn user(self: &Arc<Self>, access_token: String) -> UserClient {
UserClient {
app: self.clone(),
access_token,
}
}
async fn request<T, F, G>(
&self,
method: Method,
path: &str,
get_auth_header: F,
) -> tide::Result<T>
where
T: DeserializeOwned,
F: Fn(bool) -> G,
G: Future<Output = tide::Result<String>>,
{
let mut retried = false;
loop {
let response = RequestBuilder::new(
method,
Url::parse(&format!("https://api.github.com{}", path))?,
)
.header("Accept", "application/vnd.github.v3+json")
.header("Authorization", get_auth_header(retried).await?)
.recv_json()
.await;
if let Err(error) = response.as_ref() {
if error.status() == 401 && !retried {
retried = true;
continue;
}
}
return response;
}
}
async fn bearer_header(&self, refresh: bool) -> tide::Result<String> {
if refresh {
self.jwt_bearer_header.clear().await;
}
self.jwt_bearer_header
.get_or_refresh(|| async {
use jwt_simple::{algorithms::RS256KeyPair, prelude::*};
use std::time;
let key_pair = RS256KeyPair::from_pem(&self.private_key)
.with_context(|| format!("invalid private key {:?}", self.private_key))?;
let mut claims = Claims::create(Duration::from_mins(10));
claims.issued_at = Some(Clock::now_since_epoch() - Duration::from_mins(1));
claims.issuer = Some(self.id.to_string());
let token = key_pair.sign(claims).context("failed to sign claims")?;
let expires_at = time::Instant::now() + time::Duration::from_secs(9 * 60);
Ok((format!("Bearer {}", token), expires_at))
})
.await
}
async fn installation_token_header(
&self,
header: &Expiring<String>,
installation_id: usize,
refresh: bool,
) -> tide::Result<String> {
if refresh {
header.clear().await;
}
header
.get_or_refresh(|| async {
#[derive(Debug, Deserialize)]
struct AccessToken {
token: String,
}
let access_token: AccessToken = self
.request(
Method::Post,
&format!("/app/installations/{}/access_tokens", installation_id),
|refresh| self.bearer_header(refresh),
)
.await?;
let header = format!("Token {}", access_token.token);
let expires_at = Instant::now() + Duration::from_secs(60 * 30);
Ok((header, expires_at))
})
.await
}
}
pub struct RepoClient {
app: Arc<AppClient>,
nwo: String,
installation_id: usize,
installation_token_header: Expiring<String>,
}
impl RepoClient {
#[cfg(test)]
pub fn test(app_client: &Arc<AppClient>) -> Self {
Self {
app: app_client.clone(),
nwo: String::new(),
installation_id: 0,
installation_token_header: Default::default(),
}
}
pub async fn releases(&self) -> tide::Result<Vec<Release>> {
self.get(&format!("/repos/{}/releases?per_page=100", self.nwo))
.await
}
pub async fn release_asset(&self, tag: &str, name: &str) -> tide::Result<surf::Body> {
let release: Release = self
.get(&format!("/repos/{}/releases/tags/{}", self.nwo, tag))
.await?;
let asset = release
.assets
.iter()
.find(|asset| asset.name == name)
.ok_or_else(|| anyhow!("no asset found with name {}", name))?;
let request = surf::get(&asset.url)
.header("Accept", "application/octet-stream'")
.header(
"Authorization",
self.installation_token_header(false).await?,
);
let client = surf::client();
let mut response = client.send(request).await?;
// Avoid using `surf::middleware::Redirect` because that type forwards
// the original request headers to the redirect URI. In this case, the
// redirect will be to S3, which forbids us from supplying an
// `Authorization` header.
if response.status().is_redirection() {
if let Some(url) = response.header("location") {
let request = surf::get(url.as_str()).header("Accept", "application/octet-stream");
response = client.send(request).await?;
}
}
if !response.status().is_success() {
Err(anyhow!("failed to fetch release asset {} {}", tag, name))?;
}
Ok(response.take_body())
}
async fn get<T: DeserializeOwned>(&self, path: &str) -> tide::Result<T> {
self.request::<T>(Method::Get, path).await
}
async fn request<T: DeserializeOwned>(&self, method: Method, path: &str) -> tide::Result<T> {
Ok(self
.app
.request(method, path, |refresh| {
self.installation_token_header(refresh)
})
.await?)
}
async fn installation_token_header(&self, refresh: bool) -> tide::Result<String> {
self.app
.installation_token_header(
&self.installation_token_header,
self.installation_id,
refresh,
)
.await
}
}
pub struct UserClient {
app: Arc<AppClient>,
access_token: String,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct User {
pub login: String,
pub avatar_url: String,
}
impl UserClient {
pub async fn details(&self) -> tide::Result<User> {
Ok(self
.app
.request(Method::Get, "/user", |_| async {
Ok(self.access_token_header())
})
.await?)
}
fn access_token_header(&self) -> String {
format!("Token {}", self.access_token)
}
}

80
crates/collab/src/home.rs Normal file
View file

@ -0,0 +1,80 @@
use crate::{AppState, Request, RequestExt as _};
use log::as_serde;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tide::{http::mime, Server};
pub fn add_routes(app: &mut Server<Arc<AppState>>) {
app.at("/").get(get_home);
app.at("/signups").post(post_signup);
app.at("/releases/:tag_name/:name").get(get_release_asset);
}
async fn get_home(mut request: Request) -> tide::Result {
let data = request.layout_data().await?;
Ok(tide::Response::builder(200)
.body(request.state().render_template("home.hbs", &data)?)
.content_type(mime::HTML)
.build())
}
async fn post_signup(mut request: Request) -> tide::Result {
#[derive(Debug, Deserialize, Serialize)]
struct Form {
github_login: String,
email_address: String,
about: String,
#[serde(default)]
wants_releases: bool,
#[serde(default)]
wants_updates: bool,
#[serde(default)]
wants_community: bool,
}
let mut form: Form = request.body_form().await?;
form.github_login = form
.github_login
.strip_prefix("@")
.map(str::to_string)
.unwrap_or(form.github_login);
log::info!(form = as_serde!(form); "signup submitted");
// Save signup in the database
request
.db()
.create_signup(
&form.github_login,
&form.email_address,
&form.about,
form.wants_releases,
form.wants_updates,
form.wants_community,
)
.await?;
let layout_data = request.layout_data().await?;
Ok(tide::Response::builder(200)
.body(
request
.state()
.render_template("signup.hbs", &layout_data)?,
)
.content_type(mime::HTML)
.build())
}
async fn get_release_asset(request: Request) -> tide::Result {
let body = request
.state()
.repo_client
.release_asset(request.param("tag_name")?, request.param("name")?)
.await?;
Ok(tide::Response::builder(200)
.header("Cache-Control", "no-transform")
.content_type(mime::BYTE_STREAM)
.body(body)
.build())
}

205
crates/collab/src/main.rs Normal file
View file

@ -0,0 +1,205 @@
mod admin;
mod api;
mod assets;
mod auth;
mod careers;
mod community;
mod db;
mod env;
mod errors;
mod expiring;
mod github;
mod home;
mod releases;
mod rpc;
mod team;
use self::errors::TideResultExt as _;
use ::rpc::Peer;
use anyhow::Result;
use async_std::net::TcpListener;
use async_trait::async_trait;
use auth::RequestExt as _;
use db::{Db, PostgresDb};
use handlebars::{Handlebars, TemplateRenderError};
use parking_lot::RwLock;
use rust_embed::RustEmbed;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use surf::http::cookies::SameSite;
use tide::sessions::SessionMiddleware;
use tide_compress::CompressMiddleware;
type Request = tide::Request<Arc<AppState>>;
#[derive(RustEmbed)]
#[folder = "templates"]
struct Templates;
#[derive(Default, Deserialize)]
pub struct Config {
pub http_port: u16,
pub database_url: String,
pub session_secret: String,
pub github_app_id: usize,
pub github_client_id: String,
pub github_client_secret: String,
pub github_private_key: String,
pub api_token: String,
}
pub struct AppState {
db: Arc<dyn Db>,
handlebars: RwLock<Handlebars<'static>>,
auth_client: auth::Client,
github_client: Arc<github::AppClient>,
repo_client: github::RepoClient,
config: Config,
}
impl AppState {
async fn new(config: Config) -> tide::Result<Arc<Self>> {
let db = PostgresDb::new(&config.database_url, 5).await?;
let github_client =
github::AppClient::new(config.github_app_id, config.github_private_key.clone());
let repo_client = github_client
.repo("zed-industries/zed".into())
.await
.context("failed to initialize github client")?;
let this = Self {
db: Arc::new(db),
handlebars: Default::default(),
auth_client: auth::build_client(&config.github_client_id, &config.github_client_secret),
github_client,
repo_client,
config,
};
this.register_partials();
Ok(Arc::new(this))
}
fn register_partials(&self) {
for path in Templates::iter() {
if let Some(partial_name) = path
.strip_prefix("partials/")
.and_then(|path| path.strip_suffix(".hbs"))
{
let partial = Templates::get(path.as_ref()).unwrap();
self.handlebars
.write()
.register_partial(partial_name, std::str::from_utf8(&partial.data).unwrap())
.unwrap()
}
}
}
fn render_template(
&self,
path: &'static str,
data: &impl Serialize,
) -> Result<String, TemplateRenderError> {
#[cfg(debug_assertions)]
self.register_partials();
self.handlebars.read().render_template(
std::str::from_utf8(&Templates::get(path).unwrap().data).unwrap(),
data,
)
}
}
#[async_trait]
trait RequestExt {
async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>>;
fn db(&self) -> &Arc<dyn Db>;
}
#[async_trait]
impl RequestExt for Request {
async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>> {
if self.ext::<Arc<LayoutData>>().is_none() {
self.set_ext(Arc::new(LayoutData {
current_user: self.current_user().await?,
}));
}
Ok(self.ext::<Arc<LayoutData>>().unwrap().clone())
}
fn db(&self) -> &Arc<dyn Db> {
&self.state().db
}
}
#[derive(Serialize)]
struct LayoutData {
current_user: Option<auth::User>,
}
#[async_std::main]
async fn main() -> tide::Result<()> {
if std::env::var("LOG_JSON").is_ok() {
json_env_logger::init();
} else {
tide::log::start();
}
if let Err(error) = env::load_dotenv() {
log::error!(
"error loading .env.toml (this is expected in production): {}",
error
);
}
let config = envy::from_env::<Config>().expect("error loading config");
let state = AppState::new(config).await?;
let rpc = Peer::new();
run_server(
state.clone(),
rpc,
TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port)).await?,
)
.await?;
Ok(())
}
pub async fn run_server(
state: Arc<AppState>,
rpc: Arc<Peer>,
listener: TcpListener,
) -> tide::Result<()> {
let mut web = tide::with_state(state.clone());
web.with(CompressMiddleware::new());
web.with(
SessionMiddleware::new(
db::SessionStore::new_with_table_name(&state.config.database_url, "sessions")
.await
.unwrap(),
state.config.session_secret.as_bytes(),
)
.with_same_site_policy(SameSite::Lax), // Required obtain our session in /auth_callback
);
web.with(errors::Middleware);
api::add_routes(&mut web);
home::add_routes(&mut web);
team::add_routes(&mut web);
careers::add_routes(&mut web);
releases::add_routes(&mut web);
community::add_routes(&mut web);
admin::add_routes(&mut web);
auth::add_routes(&mut web);
let mut assets = tide::new();
assets.with(CompressMiddleware::new());
assets::add_routes(&mut assets);
let mut app = tide::with_state(state.clone());
rpc::add_routes(&mut app, &rpc);
app.at("/").nest(web);
app.at("/static").nest(assets);
app.listen(listener).await?;
Ok(())
}

View file

@ -0,0 +1,54 @@
use crate::{
auth::RequestExt as _, github::Release, AppState, LayoutData, Request, RequestExt as _,
};
use comrak::ComrakOptions;
use serde::Serialize;
use std::sync::Arc;
use tide::http::mime;
pub fn add_routes(releases: &mut tide::Server<Arc<AppState>>) {
releases.at("/releases").get(get_releases);
}
async fn get_releases(mut request: Request) -> tide::Result {
#[derive(Serialize)]
struct ReleasesData {
#[serde(flatten)]
layout: Arc<LayoutData>,
releases: Option<Vec<Release>>,
}
let mut data = ReleasesData {
layout: request.layout_data().await?,
releases: None,
};
if let Some(user) = request.current_user().await? {
if user.is_insider {
data.releases = Some(
request
.state()
.repo_client
.releases()
.await?
.into_iter()
.filter_map(|mut release| {
if release.draft {
None
} else {
let mut options = ComrakOptions::default();
options.render.unsafe_ = true; // Allow raw HTML in the markup. We control these release notes anyway.
release.body = comrak::markdown_to_html(&release.body, &options);
Some(release)
}
})
.collect(),
);
}
}
Ok(tide::Response::builder(200)
.body(request.state().render_template("releases.hbs", &data)?)
.content_type(mime::HTML)
.build())
}

6137
crates/collab/src/rpc.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,790 @@
use crate::db::{ChannelId, UserId};
use anyhow::anyhow;
use collections::{BTreeMap, HashMap, HashSet};
use rpc::{proto, ConnectionId};
use std::{collections::hash_map, path::PathBuf};
#[derive(Default)]
pub struct Store {
connections: HashMap<ConnectionId, ConnectionState>,
connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
projects: HashMap<u64, Project>,
visible_projects_by_user_id: HashMap<UserId, HashSet<u64>>,
channels: HashMap<ChannelId, Channel>,
next_project_id: u64,
}
struct ConnectionState {
user_id: UserId,
projects: HashSet<u64>,
channels: HashSet<ChannelId>,
}
pub struct Project {
pub host_connection_id: ConnectionId,
pub host_user_id: UserId,
pub share: Option<ProjectShare>,
pub worktrees: HashMap<u64, Worktree>,
pub language_servers: Vec<proto::LanguageServer>,
}
pub struct Worktree {
pub authorized_user_ids: Vec<UserId>,
pub root_name: String,
pub visible: bool,
}
#[derive(Default)]
pub struct ProjectShare {
pub guests: HashMap<ConnectionId, (ReplicaId, UserId)>,
pub active_replica_ids: HashSet<ReplicaId>,
pub worktrees: HashMap<u64, WorktreeShare>,
}
#[derive(Default)]
pub struct WorktreeShare {
pub entries: HashMap<u64, proto::Entry>,
pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
}
#[derive(Default)]
pub struct Channel {
pub connection_ids: HashSet<ConnectionId>,
}
pub type ReplicaId = u16;
#[derive(Default)]
pub struct RemovedConnectionState {
pub hosted_projects: HashMap<u64, Project>,
pub guest_project_ids: HashMap<u64, Vec<ConnectionId>>,
pub contact_ids: HashSet<UserId>,
}
pub struct JoinedProject<'a> {
pub replica_id: ReplicaId,
pub project: &'a Project,
}
pub struct UnsharedProject {
pub connection_ids: Vec<ConnectionId>,
pub authorized_user_ids: Vec<UserId>,
}
pub struct LeftProject {
pub connection_ids: Vec<ConnectionId>,
pub authorized_user_ids: Vec<UserId>,
}
impl Store {
pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
self.connections.insert(
connection_id,
ConnectionState {
user_id,
projects: Default::default(),
channels: Default::default(),
},
);
self.connections_by_user_id
.entry(user_id)
.or_default()
.insert(connection_id);
}
pub fn remove_connection(
&mut self,
connection_id: ConnectionId,
) -> tide::Result<RemovedConnectionState> {
let connection = if let Some(connection) = self.connections.remove(&connection_id) {
connection
} else {
return Err(anyhow!("no such connection"))?;
};
for channel_id in &connection.channels {
if let Some(channel) = self.channels.get_mut(&channel_id) {
channel.connection_ids.remove(&connection_id);
}
}
let user_connections = self
.connections_by_user_id
.get_mut(&connection.user_id)
.unwrap();
user_connections.remove(&connection_id);
if user_connections.is_empty() {
self.connections_by_user_id.remove(&connection.user_id);
}
let mut result = RemovedConnectionState::default();
for project_id in connection.projects.clone() {
if let Ok(project) = self.unregister_project(project_id, connection_id) {
result.contact_ids.extend(project.authorized_user_ids());
result.hosted_projects.insert(project_id, project);
} else if let Ok(project) = self.leave_project(connection_id, project_id) {
result
.guest_project_ids
.insert(project_id, project.connection_ids);
result.contact_ids.extend(project.authorized_user_ids);
}
}
#[cfg(test)]
self.check_invariants();
Ok(result)
}
#[cfg(test)]
pub fn channel(&self, id: ChannelId) -> Option<&Channel> {
self.channels.get(&id)
}
pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.channels.insert(channel_id);
self.channels
.entry(channel_id)
.or_default()
.connection_ids
.insert(connection_id);
}
}
pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.channels.remove(&channel_id);
if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
entry.get_mut().connection_ids.remove(&connection_id);
if entry.get_mut().connection_ids.is_empty() {
entry.remove();
}
}
}
}
pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
Ok(self
.connections
.get(&connection_id)
.ok_or_else(|| anyhow!("unknown connection"))?
.user_id)
}
pub fn connection_ids_for_user<'a>(
&'a self,
user_id: UserId,
) -> impl 'a + Iterator<Item = ConnectionId> {
self.connections_by_user_id
.get(&user_id)
.into_iter()
.flatten()
.copied()
}
pub fn contacts_for_user(&self, user_id: UserId) -> Vec<proto::Contact> {
let mut contacts = HashMap::default();
for project_id in self
.visible_projects_by_user_id
.get(&user_id)
.unwrap_or(&HashSet::default())
{
let project = &self.projects[project_id];
let mut guests = HashSet::default();
if let Ok(share) = project.share() {
for guest_connection_id in share.guests.keys() {
if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
guests.insert(user_id.to_proto());
}
}
}
if let Ok(host_user_id) = self.user_id_for_connection(project.host_connection_id) {
let mut worktree_root_names = project
.worktrees
.values()
.filter(|worktree| worktree.visible)
.map(|worktree| worktree.root_name.clone())
.collect::<Vec<_>>();
worktree_root_names.sort_unstable();
contacts
.entry(host_user_id)
.or_insert_with(|| proto::Contact {
user_id: host_user_id.to_proto(),
projects: Vec::new(),
})
.projects
.push(proto::ProjectMetadata {
id: *project_id,
worktree_root_names,
is_shared: project.share.is_some(),
guests: guests.into_iter().collect(),
});
}
}
contacts.into_values().collect()
}
pub fn register_project(
&mut self,
host_connection_id: ConnectionId,
host_user_id: UserId,
) -> u64 {
let project_id = self.next_project_id;
self.projects.insert(
project_id,
Project {
host_connection_id,
host_user_id,
share: None,
worktrees: Default::default(),
language_servers: Default::default(),
},
);
if let Some(connection) = self.connections.get_mut(&host_connection_id) {
connection.projects.insert(project_id);
}
self.next_project_id += 1;
project_id
}
pub fn register_worktree(
&mut self,
project_id: u64,
worktree_id: u64,
connection_id: ConnectionId,
worktree: Worktree,
) -> tide::Result<()> {
let project = self
.projects
.get_mut(&project_id)
.ok_or_else(|| anyhow!("no such project"))?;
if project.host_connection_id == connection_id {
for authorized_user_id in &worktree.authorized_user_ids {
self.visible_projects_by_user_id
.entry(*authorized_user_id)
.or_default()
.insert(project_id);
}
project.worktrees.insert(worktree_id, worktree);
if let Ok(share) = project.share_mut() {
share.worktrees.insert(worktree_id, Default::default());
}
#[cfg(test)]
self.check_invariants();
Ok(())
} else {
Err(anyhow!("no such project"))?
}
}
pub fn unregister_project(
&mut self,
project_id: u64,
connection_id: ConnectionId,
) -> tide::Result<Project> {
match self.projects.entry(project_id) {
hash_map::Entry::Occupied(e) => {
if e.get().host_connection_id == connection_id {
for user_id in e.get().authorized_user_ids() {
if let hash_map::Entry::Occupied(mut projects) =
self.visible_projects_by_user_id.entry(user_id)
{
projects.get_mut().remove(&project_id);
}
}
let project = e.remove();
if let Some(host_connection) = self.connections.get_mut(&connection_id) {
host_connection.projects.remove(&project_id);
}
if let Some(share) = &project.share {
for guest_connection in share.guests.keys() {
if let Some(connection) = self.connections.get_mut(&guest_connection) {
connection.projects.remove(&project_id);
}
}
}
#[cfg(test)]
self.check_invariants();
Ok(project)
} else {
Err(anyhow!("no such project"))?
}
}
hash_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?,
}
}
pub fn unregister_worktree(
&mut self,
project_id: u64,
worktree_id: u64,
acting_connection_id: ConnectionId,
) -> tide::Result<(Worktree, Vec<ConnectionId>)> {
let project = self
.projects
.get_mut(&project_id)
.ok_or_else(|| anyhow!("no such project"))?;
if project.host_connection_id != acting_connection_id {
Err(anyhow!("not your worktree"))?;
}
let worktree = project
.worktrees
.remove(&worktree_id)
.ok_or_else(|| anyhow!("no such worktree"))?;
let mut guest_connection_ids = Vec::new();
if let Ok(share) = project.share_mut() {
guest_connection_ids.extend(share.guests.keys());
share.worktrees.remove(&worktree_id);
}
for authorized_user_id in &worktree.authorized_user_ids {
if let Some(visible_projects) =
self.visible_projects_by_user_id.get_mut(authorized_user_id)
{
if !project.has_authorized_user_id(*authorized_user_id) {
visible_projects.remove(&project_id);
}
}
}
#[cfg(test)]
self.check_invariants();
Ok((worktree, guest_connection_ids))
}
pub fn share_project(&mut self, project_id: u64, connection_id: ConnectionId) -> bool {
if let Some(project) = self.projects.get_mut(&project_id) {
if project.host_connection_id == connection_id {
let mut share = ProjectShare::default();
for worktree_id in project.worktrees.keys() {
share.worktrees.insert(*worktree_id, Default::default());
}
project.share = Some(share);
return true;
}
}
false
}
pub fn unshare_project(
&mut self,
project_id: u64,
acting_connection_id: ConnectionId,
) -> tide::Result<UnsharedProject> {
let project = if let Some(project) = self.projects.get_mut(&project_id) {
project
} else {
return Err(anyhow!("no such project"))?;
};
if project.host_connection_id != acting_connection_id {
return Err(anyhow!("not your project"))?;
}
let connection_ids = project.connection_ids();
let authorized_user_ids = project.authorized_user_ids();
if let Some(share) = project.share.take() {
for connection_id in share.guests.into_keys() {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.projects.remove(&project_id);
}
}
#[cfg(test)]
self.check_invariants();
Ok(UnsharedProject {
connection_ids,
authorized_user_ids,
})
} else {
Err(anyhow!("project is not shared"))?
}
}
pub fn update_diagnostic_summary(
&mut self,
project_id: u64,
worktree_id: u64,
connection_id: ConnectionId,
summary: proto::DiagnosticSummary,
) -> tide::Result<Vec<ConnectionId>> {
let project = self
.projects
.get_mut(&project_id)
.ok_or_else(|| anyhow!("no such project"))?;
if project.host_connection_id == connection_id {
let worktree = project
.share_mut()?
.worktrees
.get_mut(&worktree_id)
.ok_or_else(|| anyhow!("no such worktree"))?;
worktree
.diagnostic_summaries
.insert(summary.path.clone().into(), summary);
return Ok(project.connection_ids());
}
Err(anyhow!("no such worktree"))?
}
pub fn start_language_server(
&mut self,
project_id: u64,
connection_id: ConnectionId,
language_server: proto::LanguageServer,
) -> tide::Result<Vec<ConnectionId>> {
let project = self
.projects
.get_mut(&project_id)
.ok_or_else(|| anyhow!("no such project"))?;
if project.host_connection_id == connection_id {
project.language_servers.push(language_server);
return Ok(project.connection_ids());
}
Err(anyhow!("no such project"))?
}
pub fn join_project(
&mut self,
connection_id: ConnectionId,
user_id: UserId,
project_id: u64,
) -> tide::Result<JoinedProject> {
let connection = self
.connections
.get_mut(&connection_id)
.ok_or_else(|| anyhow!("no such connection"))?;
let project = self
.projects
.get_mut(&project_id)
.and_then(|project| {
if project.has_authorized_user_id(user_id) {
Some(project)
} else {
None
}
})
.ok_or_else(|| anyhow!("no such project"))?;
let share = project.share_mut()?;
connection.projects.insert(project_id);
let mut replica_id = 1;
while share.active_replica_ids.contains(&replica_id) {
replica_id += 1;
}
share.active_replica_ids.insert(replica_id);
share.guests.insert(connection_id, (replica_id, user_id));
#[cfg(test)]
self.check_invariants();
Ok(JoinedProject {
replica_id,
project: &self.projects[&project_id],
})
}
pub fn leave_project(
&mut self,
connection_id: ConnectionId,
project_id: u64,
) -> tide::Result<LeftProject> {
let project = self
.projects
.get_mut(&project_id)
.ok_or_else(|| anyhow!("no such project"))?;
let share = project
.share
.as_mut()
.ok_or_else(|| anyhow!("project is not shared"))?;
let (replica_id, _) = share
.guests
.remove(&connection_id)
.ok_or_else(|| anyhow!("cannot leave a project before joining it"))?;
share.active_replica_ids.remove(&replica_id);
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.projects.remove(&project_id);
}
let connection_ids = project.connection_ids();
let authorized_user_ids = project.authorized_user_ids();
#[cfg(test)]
self.check_invariants();
Ok(LeftProject {
connection_ids,
authorized_user_ids,
})
}
pub fn update_worktree(
&mut self,
connection_id: ConnectionId,
project_id: u64,
worktree_id: u64,
removed_entries: &[u64],
updated_entries: &[proto::Entry],
) -> tide::Result<Vec<ConnectionId>> {
let project = self.write_project(project_id, connection_id)?;
let worktree = project
.share_mut()?
.worktrees
.get_mut(&worktree_id)
.ok_or_else(|| anyhow!("no such worktree"))?;
for entry_id in removed_entries {
worktree.entries.remove(&entry_id);
}
for entry in updated_entries {
worktree.entries.insert(entry.id, entry.clone());
}
let connection_ids = project.connection_ids();
#[cfg(test)]
self.check_invariants();
Ok(connection_ids)
}
pub fn project_connection_ids(
&self,
project_id: u64,
acting_connection_id: ConnectionId,
) -> tide::Result<Vec<ConnectionId>> {
Ok(self
.read_project(project_id, acting_connection_id)?
.connection_ids())
}
pub fn channel_connection_ids(&self, channel_id: ChannelId) -> tide::Result<Vec<ConnectionId>> {
Ok(self
.channels
.get(&channel_id)
.ok_or_else(|| anyhow!("no such channel"))?
.connection_ids())
}
#[cfg(test)]
pub fn project(&self, project_id: u64) -> Option<&Project> {
self.projects.get(&project_id)
}
pub fn read_project(
&self,
project_id: u64,
connection_id: ConnectionId,
) -> tide::Result<&Project> {
let project = self
.projects
.get(&project_id)
.ok_or_else(|| anyhow!("no such project"))?;
if project.host_connection_id == connection_id
|| project
.share
.as_ref()
.ok_or_else(|| anyhow!("project is not shared"))?
.guests
.contains_key(&connection_id)
{
Ok(project)
} else {
Err(anyhow!("no such project"))?
}
}
fn write_project(
&mut self,
project_id: u64,
connection_id: ConnectionId,
) -> tide::Result<&mut Project> {
let project = self
.projects
.get_mut(&project_id)
.ok_or_else(|| anyhow!("no such project"))?;
if project.host_connection_id == connection_id
|| project
.share
.as_ref()
.ok_or_else(|| anyhow!("project is not shared"))?
.guests
.contains_key(&connection_id)
{
Ok(project)
} else {
Err(anyhow!("no such project"))?
}
}
#[cfg(test)]
fn check_invariants(&self) {
for (connection_id, connection) in &self.connections {
for project_id in &connection.projects {
let project = &self.projects.get(&project_id).unwrap();
if project.host_connection_id != *connection_id {
assert!(project
.share
.as_ref()
.unwrap()
.guests
.contains_key(connection_id));
}
if let Some(share) = project.share.as_ref() {
for (worktree_id, worktree) in share.worktrees.iter() {
let mut paths = HashMap::default();
for entry in worktree.entries.values() {
let prev_entry = paths.insert(&entry.path, entry);
assert_eq!(
prev_entry,
None,
"worktree {:?}, duplicate path for entries {:?} and {:?}",
worktree_id,
prev_entry.unwrap(),
entry
);
}
}
}
}
for channel_id in &connection.channels {
let channel = self.channels.get(channel_id).unwrap();
assert!(channel.connection_ids.contains(connection_id));
}
assert!(self
.connections_by_user_id
.get(&connection.user_id)
.unwrap()
.contains(connection_id));
}
for (user_id, connection_ids) in &self.connections_by_user_id {
for connection_id in connection_ids {
assert_eq!(
self.connections.get(connection_id).unwrap().user_id,
*user_id
);
}
}
for (project_id, project) in &self.projects {
let host_connection = self.connections.get(&project.host_connection_id).unwrap();
assert!(host_connection.projects.contains(project_id));
for authorized_user_ids in project.authorized_user_ids() {
let visible_project_ids = self
.visible_projects_by_user_id
.get(&authorized_user_ids)
.unwrap();
assert!(visible_project_ids.contains(project_id));
}
if let Some(share) = &project.share {
for guest_connection_id in share.guests.keys() {
let guest_connection = self.connections.get(guest_connection_id).unwrap();
assert!(guest_connection.projects.contains(project_id));
}
assert_eq!(share.active_replica_ids.len(), share.guests.len(),);
assert_eq!(
share.active_replica_ids,
share
.guests
.values()
.map(|(replica_id, _)| *replica_id)
.collect::<HashSet<_>>(),
);
}
}
for (user_id, visible_project_ids) in &self.visible_projects_by_user_id {
for project_id in visible_project_ids {
let project = self.projects.get(project_id).unwrap();
assert!(project.authorized_user_ids().contains(user_id));
}
}
for (channel_id, channel) in &self.channels {
for connection_id in &channel.connection_ids {
let connection = self.connections.get(connection_id).unwrap();
assert!(connection.channels.contains(channel_id));
}
}
}
}
impl Project {
pub fn has_authorized_user_id(&self, user_id: UserId) -> bool {
self.worktrees
.values()
.any(|worktree| worktree.authorized_user_ids.contains(&user_id))
}
pub fn authorized_user_ids(&self) -> Vec<UserId> {
let mut ids = self
.worktrees
.values()
.flat_map(|worktree| worktree.authorized_user_ids.iter())
.copied()
.collect::<Vec<_>>();
ids.sort_unstable();
ids.dedup();
ids
}
pub fn guest_connection_ids(&self) -> Vec<ConnectionId> {
if let Some(share) = &self.share {
share.guests.keys().copied().collect()
} else {
Vec::new()
}
}
pub fn connection_ids(&self) -> Vec<ConnectionId> {
if let Some(share) = &self.share {
share
.guests
.keys()
.copied()
.chain(Some(self.host_connection_id))
.collect()
} else {
vec![self.host_connection_id]
}
}
pub fn share(&self) -> tide::Result<&ProjectShare> {
Ok(self
.share
.as_ref()
.ok_or_else(|| anyhow!("worktree is not shared"))?)
}
fn share_mut(&mut self) -> tide::Result<&mut ProjectShare> {
Ok(self
.share
.as_mut()
.ok_or_else(|| anyhow!("worktree is not shared"))?)
}
}
impl Channel {
fn connection_ids(&self) -> Vec<ConnectionId> {
self.connection_ids.iter().copied().collect()
}
}

15
crates/collab/src/team.rs Normal file
View file

@ -0,0 +1,15 @@
use crate::{AppState, Request, RequestExt};
use std::sync::Arc;
use tide::http::mime;
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>) {
app.at("/team").get(get_team);
}
async fn get_team(mut request: Request) -> tide::Result {
let data = request.layout_data().await?;
Ok(tide::Response::builder(200)
.body(request.state().render_template("team.hbs", &data)?)
.content_type(mime::HTML)
.build())
}