diff --git a/crates/server/src/api.rs b/crates/server/src/api.rs index 0999a28d90..69b60fe9ec 100644 --- a/crates/server/src/api.rs +++ b/crates/server/src/api.rs @@ -111,7 +111,7 @@ async fn create_access_token(request: Request) -> tide::Result { .get_user_by_github_login(request.param("github_login")?) .await? .ok_or_else(|| surf::Error::from_str(StatusCode::NotFound, "user not found"))?; - let access_token = auth::create_access_token(request.db(), user.id).await?; + let access_token = auth::create_access_token(request.db().as_ref(), user.id).await?; #[derive(Deserialize)] struct QueryParams { diff --git a/crates/server/src/auth.rs b/crates/server/src/auth.rs index 1fbd137d12..91136b46d0 100644 --- a/crates/server/src/auth.rs +++ b/crates/server/src/auth.rs @@ -234,7 +234,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result { let mut user_id = user.id; if let Some(impersonated_login) = app_sign_in_params.impersonate { log::info!("attempting to impersonate user @{}", impersonated_login); - if let Some(user) = request.db().get_users_by_ids([user_id]).await?.first() { + if let Some(user) = request.db().get_users_by_ids(vec![user_id]).await?.first() { if user.admin { user_id = request.db().create_user(&impersonated_login, false).await?; log::info!("impersonating user {}", user_id.0); @@ -244,7 +244,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result { } } - let access_token = create_access_token(request.db(), user_id).await?; + let access_token = create_access_token(request.db().as_ref(), user_id).await?; let encrypted_access_token = encrypt_access_token( &access_token, app_sign_in_params.native_app_public_key.clone(), @@ -267,7 +267,7 @@ async fn post_sign_out(mut request: Request) -> tide::Result { const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; -pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result { +pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> tide::Result { let access_token = zed_auth::random_token(); let access_token_hash = hash_access_token(&access_token).context("failed to hash access token")?; diff --git a/crates/server/src/db.rs b/crates/server/src/db.rs index d9f7fc85b6..37e35be5f8 100644 --- a/crates/server/src/db.rs +++ b/crates/server/src/db.rs @@ -1,11 +1,12 @@ use anyhow::Context; -use async_std::task::{block_on, yield_now}; -use serde::Serialize; -use sqlx::{types::Uuid, FromRow, Result}; -use time::OffsetDateTime; - +use anyhow::Result; pub use async_sqlx_session::PostgresSessionStore as SessionStore; +use async_std::task::{block_on, yield_now}; +use async_trait::async_trait; +use serde::Serialize; pub use sqlx::postgres::PgPoolOptions as DbOptions; +use sqlx::{types::Uuid, FromRow}; +use time::OffsetDateTime; macro_rules! test_support { ($self:ident, { $($token:tt)* }) => {{ @@ -21,13 +22,77 @@ macro_rules! test_support { }}; } -#[derive(Clone)] -pub struct Db { +#[async_trait] +pub trait Db: Send + Sync { + async fn create_signup( + &self, + github_login: &str, + email_address: &str, + about: &str, + wants_releases: bool, + wants_updates: bool, + wants_community: bool, + ) -> Result; + async fn get_all_signups(&self) -> Result>; + async fn destroy_signup(&self, id: SignupId) -> Result<()>; + async fn create_user(&self, github_login: &str, admin: bool) -> Result; + async fn get_all_users(&self) -> Result>; + async fn get_user_by_id(&self, id: UserId) -> Result>; + async fn get_users_by_ids(&self, ids: Vec) -> Result>; + async fn get_user_by_github_login(&self, github_login: &str) -> Result>; + async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>; + async fn destroy_user(&self, id: UserId) -> Result<()>; + async fn create_access_token_hash( + &self, + user_id: UserId, + access_token_hash: &str, + max_access_token_count: usize, + ) -> Result<()>; + async fn get_access_token_hashes(&self, user_id: UserId) -> Result>; + #[cfg(any(test, feature = "seed-support"))] + async fn find_org_by_slug(&self, slug: &str) -> Result>; + #[cfg(any(test, feature = "seed-support"))] + async fn create_org(&self, name: &str, slug: &str) -> Result; + #[cfg(any(test, feature = "seed-support"))] + async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>; + #[cfg(any(test, feature = "seed-support"))] + async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result; + #[cfg(any(test, feature = "seed-support"))] + async fn get_org_channels(&self, org_id: OrgId) -> Result>; + async fn get_accessible_channels(&self, user_id: UserId) -> Result>; + async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId) + -> Result; + #[cfg(any(test, feature = "seed-support"))] + async fn add_channel_member( + &self, + channel_id: ChannelId, + user_id: UserId, + is_admin: bool, + ) -> Result<()>; + async fn create_channel_message( + &self, + channel_id: ChannelId, + sender_id: UserId, + body: &str, + timestamp: OffsetDateTime, + nonce: u128, + ) -> Result; + async fn get_channel_messages( + &self, + channel_id: ChannelId, + count: usize, + before_id: Option, + ) -> Result>; + #[cfg(test)] + async fn teardown(&self, name: &str, url: &str); +} + +pub struct PostgresDb { pool: sqlx::PgPool, test_mode: bool, } -impl Db { +impl PostgresDb { pub async fn new(url: &str, max_connections: u32) -> tide::Result { let pool = DbOptions::new() .max_connections(max_connections) @@ -39,10 +104,12 @@ impl Db { test_mode: false, }) } +} +#[async_trait] +impl Db for PostgresDb { // signups - - pub async fn create_signup( + async fn create_signup( &self, github_login: &str, email_address: &str, @@ -64,7 +131,7 @@ impl Db { VALUES ($1, $2, $3, $4, $5, $6) RETURNING id "; - sqlx::query_scalar(query) + Ok(sqlx::query_scalar(query) .bind(github_login) .bind(email_address) .bind(about) @@ -73,31 +140,31 @@ impl Db { .bind(wants_community) .fetch_one(&self.pool) .await - .map(SignupId) + .map(SignupId)?) }) } - pub async fn get_all_signups(&self) -> Result> { + async fn get_all_signups(&self) -> Result> { test_support!(self, { let query = "SELECT * FROM signups ORDER BY github_login ASC"; - sqlx::query_as(query).fetch_all(&self.pool).await + Ok(sqlx::query_as(query).fetch_all(&self.pool).await?) }) } - pub async fn destroy_signup(&self, id: SignupId) -> Result<()> { + async fn destroy_signup(&self, id: SignupId) -> Result<()> { test_support!(self, { let query = "DELETE FROM signups WHERE id = $1"; - sqlx::query(query) + Ok(sqlx::query(query) .bind(id.0) .execute(&self.pool) .await - .map(drop) + .map(drop)?) }) } // users - pub async fn create_user(&self, github_login: &str, admin: bool) -> Result { + async fn create_user(&self, github_login: &str, admin: bool) -> Result { test_support!(self, { let query = " INSERT INTO users (github_login, admin) @@ -105,31 +172,28 @@ impl Db { ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login RETURNING id "; - sqlx::query_scalar(query) + Ok(sqlx::query_scalar(query) .bind(github_login) .bind(admin) .fetch_one(&self.pool) .await - .map(UserId) + .map(UserId)?) }) } - pub async fn get_all_users(&self) -> Result> { + async fn get_all_users(&self) -> Result> { test_support!(self, { let query = "SELECT * FROM users ORDER BY github_login ASC"; - sqlx::query_as(query).fetch_all(&self.pool).await + Ok(sqlx::query_as(query).fetch_all(&self.pool).await?) }) } - pub async fn get_user_by_id(&self, id: UserId) -> Result> { - let users = self.get_users_by_ids([id]).await?; + async fn get_user_by_id(&self, id: UserId) -> Result> { + let users = self.get_users_by_ids(vec![id]).await?; Ok(users.into_iter().next()) } - pub async fn get_users_by_ids( - &self, - ids: impl IntoIterator, - ) -> Result> { + async fn get_users_by_ids(&self, ids: Vec) -> Result> { let ids = ids.into_iter().map(|id| id.0).collect::>(); test_support!(self, { let query = " @@ -138,33 +202,36 @@ impl Db { WHERE users.id = ANY ($1) "; - sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await + Ok(sqlx::query_as(query) + .bind(&ids) + .fetch_all(&self.pool) + .await?) }) } - pub async fn get_user_by_github_login(&self, github_login: &str) -> Result> { + async fn get_user_by_github_login(&self, github_login: &str) -> Result> { test_support!(self, { let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1"; - sqlx::query_as(query) + Ok(sqlx::query_as(query) .bind(github_login) .fetch_optional(&self.pool) - .await + .await?) }) } - pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { + async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { test_support!(self, { let query = "UPDATE users SET admin = $1 WHERE id = $2"; - sqlx::query(query) + Ok(sqlx::query(query) .bind(is_admin) .bind(id.0) .execute(&self.pool) .await - .map(drop) + .map(drop)?) }) } - pub async fn destroy_user(&self, id: UserId) -> Result<()> { + async fn destroy_user(&self, id: UserId) -> Result<()> { test_support!(self, { let query = "DELETE FROM access_tokens WHERE user_id = $1;"; sqlx::query(query) @@ -173,17 +240,17 @@ impl Db { .await .map(drop)?; let query = "DELETE FROM users WHERE id = $1;"; - sqlx::query(query) + Ok(sqlx::query(query) .bind(id.0) .execute(&self.pool) .await - .map(drop) + .map(drop)?) }) } // access tokens - pub async fn create_access_token_hash( + async fn create_access_token_hash( &self, user_id: UserId, access_token_hash: &str, @@ -216,11 +283,11 @@ impl Db { .bind(max_access_token_count as u32) .execute(&mut tx) .await?; - tx.commit().await + Ok(tx.commit().await?) }) } - pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { + async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { test_support!(self, { let query = " SELECT hash @@ -228,10 +295,10 @@ impl Db { WHERE user_id = $1 ORDER BY id DESC "; - sqlx::query_scalar(query) + Ok(sqlx::query_scalar(query) .bind(user_id.0) .fetch_all(&self.pool) - .await + .await?) }) } @@ -239,82 +306,77 @@ impl Db { #[allow(unused)] // Help rust-analyzer #[cfg(any(test, feature = "seed-support"))] - pub async fn find_org_by_slug(&self, slug: &str) -> Result> { + async fn find_org_by_slug(&self, slug: &str) -> Result> { test_support!(self, { let query = " SELECT * FROM orgs WHERE slug = $1 "; - sqlx::query_as(query) + Ok(sqlx::query_as(query) .bind(slug) .fetch_optional(&self.pool) - .await + .await?) }) } #[cfg(any(test, feature = "seed-support"))] - pub async fn create_org(&self, name: &str, slug: &str) -> Result { + async fn create_org(&self, name: &str, slug: &str) -> Result { test_support!(self, { let query = " INSERT INTO orgs (name, slug) VALUES ($1, $2) RETURNING id "; - sqlx::query_scalar(query) + Ok(sqlx::query_scalar(query) .bind(name) .bind(slug) .fetch_one(&self.pool) .await - .map(OrgId) + .map(OrgId)?) }) } #[cfg(any(test, feature = "seed-support"))] - pub async fn add_org_member( - &self, - org_id: OrgId, - user_id: UserId, - is_admin: bool, - ) -> Result<()> { + async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> { test_support!(self, { let query = " INSERT INTO org_memberships (org_id, user_id, admin) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING "; - sqlx::query(query) + Ok(sqlx::query(query) .bind(org_id.0) .bind(user_id.0) .bind(is_admin) .execute(&self.pool) .await - .map(drop) + .map(drop)?) }) } // channels #[cfg(any(test, feature = "seed-support"))] - pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { + async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { test_support!(self, { let query = " INSERT INTO channels (owner_id, owner_is_user, name) VALUES ($1, false, $2) RETURNING id "; - sqlx::query_scalar(query) + Ok(sqlx::query_scalar(query) .bind(org_id.0) .bind(name) .fetch_one(&self.pool) .await - .map(ChannelId) + .map(ChannelId)?) }) } #[allow(unused)] // Help rust-analyzer #[cfg(any(test, feature = "seed-support"))] - pub async fn get_org_channels(&self, org_id: OrgId) -> Result> { + async fn get_org_channels(&self, org_id: OrgId) -> Result> { test_support!(self, { let query = " SELECT * @@ -323,32 +385,32 @@ impl Db { channels.owner_is_user = false AND channels.owner_id = $1 "; - sqlx::query_as(query) + Ok(sqlx::query_as(query) .bind(org_id.0) .fetch_all(&self.pool) - .await + .await?) }) } - pub async fn get_accessible_channels(&self, user_id: UserId) -> Result> { + async fn get_accessible_channels(&self, user_id: UserId) -> Result> { test_support!(self, { let query = " SELECT - channels.id, channels.name + channels.* FROM channel_memberships, channels WHERE channel_memberships.user_id = $1 AND channel_memberships.channel_id = channels.id "; - sqlx::query_as(query) + Ok(sqlx::query_as(query) .bind(user_id.0) .fetch_all(&self.pool) - .await + .await?) }) } - pub async fn can_user_access_channel( + async fn can_user_access_channel( &self, user_id: UserId, channel_id: ChannelId, @@ -360,17 +422,17 @@ impl Db { WHERE user_id = $1 AND channel_id = $2 LIMIT 1 "; - sqlx::query_scalar::<_, i32>(query) + Ok(sqlx::query_scalar::<_, i32>(query) .bind(user_id.0) .bind(channel_id.0) .fetch_optional(&self.pool) .await - .map(|e| e.is_some()) + .map(|e| e.is_some())?) }) } #[cfg(any(test, feature = "seed-support"))] - pub async fn add_channel_member( + async fn add_channel_member( &self, channel_id: ChannelId, user_id: UserId, @@ -382,19 +444,19 @@ impl Db { VALUES ($1, $2, $3) ON CONFLICT DO NOTHING "; - sqlx::query(query) + Ok(sqlx::query(query) .bind(channel_id.0) .bind(user_id.0) .bind(is_admin) .execute(&self.pool) .await - .map(drop) + .map(drop)?) }) } // messages - pub async fn create_channel_message( + async fn create_channel_message( &self, channel_id: ChannelId, sender_id: UserId, @@ -409,7 +471,7 @@ impl Db { ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce RETURNING id "; - sqlx::query_scalar(query) + Ok(sqlx::query_scalar(query) .bind(channel_id.0) .bind(sender_id.0) .bind(body) @@ -417,11 +479,11 @@ impl Db { .bind(Uuid::from_u128(nonce)) .fetch_one(&self.pool) .await - .map(MessageId) + .map(MessageId)?) }) } - pub async fn get_channel_messages( + async fn get_channel_messages( &self, channel_id: ChannelId, count: usize, @@ -431,7 +493,7 @@ impl Db { let query = r#" SELECT * FROM ( SELECT - id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce + id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce FROM channel_messages WHERE @@ -442,12 +504,34 @@ impl Db { ) as recent_messages ORDER BY id ASC "#; - sqlx::query_as(query) + Ok(sqlx::query_as(query) .bind(channel_id.0) .bind(before_id.unwrap_or(MessageId::MAX)) .bind(count as i64) .fetch_all(&self.pool) + .await?) + }) + } + + #[cfg(test)] + async fn teardown(&self, name: &str, url: &str) { + use util::ResultExt; + + test_support!(self, { + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid(); + "; + sqlx::query(query) + .bind(name) + .execute(&self.pool) .await + .log_err(); + self.pool.close().await; + ::drop_database(url) + .await + .log_err(); }) } } @@ -479,7 +563,7 @@ macro_rules! id_type { } id_type!(UserId); -#[derive(Debug, FromRow, Serialize, PartialEq)] +#[derive(Clone, Debug, FromRow, Serialize, PartialEq)] pub struct User { pub id: UserId, pub github_login: String, @@ -507,16 +591,19 @@ pub struct Signup { } id_type!(ChannelId); -#[derive(Debug, FromRow, Serialize)] +#[derive(Clone, Debug, FromRow, Serialize)] pub struct Channel { pub id: ChannelId, pub name: String, + pub owner_id: i32, + pub owner_is_user: bool, } id_type!(MessageId); -#[derive(Debug, FromRow)] +#[derive(Clone, Debug, FromRow)] pub struct ChannelMessage { pub id: MessageId, + pub channel_id: ChannelId, pub sender_id: UserId, pub body: String, pub sent_at: OffsetDateTime, @@ -526,6 +613,9 @@ pub struct ChannelMessage { #[cfg(test)] pub mod tests { use super::*; + use anyhow::anyhow; + use collections::BTreeMap; + use gpui::{executor::Background, TestAppContext}; use lazy_static::lazy_static; use parking_lot::Mutex; use rand::prelude::*; @@ -533,227 +623,119 @@ pub mod tests { migrate::{MigrateDatabase, Migrator}, Postgres, }; - use std::{ - mem, - path::Path, - sync::atomic::{AtomicUsize, Ordering::SeqCst}, - thread, - }; - use util::ResultExt as _; + use std::{path::Path, sync::Arc}; + use util::post_inc; - pub struct TestDb { - pub db: Option, - pub name: String, - pub url: String, - clean_pool_on_drop: bool, - } + #[gpui::test] + async fn test_get_users_by_ids(cx: TestAppContext) { + for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] { + let db = test_db.db(); - lazy_static! { - static ref DB_POOL: Mutex> = Default::default(); - static ref DB_COUNT: AtomicUsize = Default::default(); - } + let user = db.create_user("user", false).await.unwrap(); + let friend1 = db.create_user("friend-1", false).await.unwrap(); + let friend2 = db.create_user("friend-2", false).await.unwrap(); + let friend3 = db.create_user("friend-3", false).await.unwrap(); - impl TestDb { - pub fn new() -> Self { - DB_COUNT.fetch_add(1, SeqCst); - let mut pool = DB_POOL.lock(); - if let Some(db) = pool.pop() { - db.truncate(); - db - } else { - let mut rng = StdRng::from_entropy(); - let name = format!("zed-test-{}", rng.gen::()); - let url = format!("postgres://postgres@localhost/{}", name); - let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")); - let db = block_on(async { - Postgres::create_database(&url) - .await - .expect("failed to create test db"); - let mut db = Db::new(&url, 5).await.unwrap(); - db.test_mode = true; - let migrator = Migrator::new(migrations_path).await.unwrap(); - migrator.run(&db.pool).await.unwrap(); - db - }); - - Self { - db: Some(db), - name, - url, - clean_pool_on_drop: false, - } - } - } - - pub fn set_clean_pool_on_drop(&mut self, delete_on_drop: bool) { - self.clean_pool_on_drop = delete_on_drop; - } - - pub fn db(&self) -> &Db { - self.db.as_ref().unwrap() - } - - fn truncate(&self) { - block_on(async { - let query = " - SELECT tablename FROM pg_tables - WHERE schemaname = 'public'; - "; - let table_names = sqlx::query_scalar::<_, String>(query) - .fetch_all(&self.db().pool) + assert_eq!( + db.get_users_by_ids(vec![user, friend1, friend2, friend3]) .await - .unwrap(); - sqlx::query(&format!( - "TRUNCATE TABLE {} RESTART IDENTITY", - table_names.join(", ") - )) - .execute(&self.db().pool) - .await - .unwrap(); - }) - } - - async fn teardown(mut self) -> Result<()> { - let db = self.db.take().unwrap(); - let query = " - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid(); - "; - sqlx::query(query) - .bind(&self.name) - .execute(&db.pool) - .await?; - db.pool.close().await; - Postgres::drop_database(&self.url).await?; - Ok(()) + .unwrap(), + vec![ + User { + id: user, + github_login: "user".to_string(), + admin: false, + }, + User { + id: friend1, + github_login: "friend-1".to_string(), + admin: false, + }, + User { + id: friend2, + github_login: "friend-2".to_string(), + admin: false, + }, + User { + id: friend3, + github_login: "friend-3".to_string(), + admin: false, + } + ] + ); } } - impl Drop for TestDb { - fn drop(&mut self) { - if let Some(db) = self.db.take() { - DB_POOL.lock().push(TestDb { - db: Some(db), - name: mem::take(&mut self.name), - url: mem::take(&mut self.url), - clean_pool_on_drop: true, - }); - if DB_COUNT.fetch_sub(1, SeqCst) == 1 - && (self.clean_pool_on_drop || thread::panicking()) - { - block_on(async move { - let mut pool = DB_POOL.lock(); - for db in pool.drain(..) { - db.teardown().await.log_err(); - } - }); - } + #[gpui::test] + async fn test_recent_channel_messages(cx: TestAppContext) { + for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] { + let db = test_db.db(); + let user = db.create_user("user", false).await.unwrap(); + let org = db.create_org("org", "org").await.unwrap(); + let channel = db.create_org_channel(org, "channel").await.unwrap(); + for i in 0..10 { + db.create_channel_message( + channel, + user, + &i.to_string(), + OffsetDateTime::now_utc(), + i, + ) + .await + .unwrap(); } - } - } - #[gpui::test] - async fn test_get_users_by_ids() { - let test_db = TestDb::new(); - let db = test_db.db(); + let messages = db.get_channel_messages(channel, 5, None).await.unwrap(); + assert_eq!( + messages.iter().map(|m| &m.body).collect::>(), + ["5", "6", "7", "8", "9"] + ); - let user = db.create_user("user", false).await.unwrap(); - let friend1 = db.create_user("friend-1", false).await.unwrap(); - let friend2 = db.create_user("friend-2", false).await.unwrap(); - let friend3 = db.create_user("friend-3", false).await.unwrap(); - - assert_eq!( - db.get_users_by_ids([user, friend1, friend2, friend3]) - .await - .unwrap(), - vec![ - User { - id: user, - github_login: "user".to_string(), - admin: false, - }, - User { - id: friend1, - github_login: "friend-1".to_string(), - admin: false, - }, - User { - id: friend2, - github_login: "friend-2".to_string(), - admin: false, - }, - User { - id: friend3, - github_login: "friend-3".to_string(), - admin: false, - } - ] - ); - } - - #[gpui::test] - async fn test_recent_channel_messages() { - let test_db = TestDb::new(); - let db = test_db.db(); - let user = db.create_user("user", false).await.unwrap(); - let org = db.create_org("org", "org").await.unwrap(); - let channel = db.create_org_channel(org, "channel").await.unwrap(); - for i in 0..10 { - db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i) + let prev_messages = db + .get_channel_messages(channel, 4, Some(messages[0].id)) .await .unwrap(); + assert_eq!( + prev_messages.iter().map(|m| &m.body).collect::>(), + ["1", "2", "3", "4"] + ); } - - let messages = db.get_channel_messages(channel, 5, None).await.unwrap(); - assert_eq!( - messages.iter().map(|m| &m.body).collect::>(), - ["5", "6", "7", "8", "9"] - ); - - let prev_messages = db - .get_channel_messages(channel, 4, Some(messages[0].id)) - .await - .unwrap(); - assert_eq!( - prev_messages.iter().map(|m| &m.body).collect::>(), - ["1", "2", "3", "4"] - ); } #[gpui::test] - async fn test_channel_message_nonces() { - let test_db = TestDb::new(); - let db = test_db.db(); - let user = db.create_user("user", false).await.unwrap(); - let org = db.create_org("org", "org").await.unwrap(); - let channel = db.create_org_channel(org, "channel").await.unwrap(); + async fn test_channel_message_nonces(cx: TestAppContext) { + for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] { + let db = test_db.db(); + let user = db.create_user("user", false).await.unwrap(); + let org = db.create_org("org", "org").await.unwrap(); + let channel = db.create_org_channel(org, "channel").await.unwrap(); - let msg1_id = db - .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) - .await - .unwrap(); - let msg2_id = db - .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); - let msg3_id = db - .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) - .await - .unwrap(); - let msg4_id = db - .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); + let msg1_id = db + .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) + .await + .unwrap(); + let msg2_id = db + .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) + .await + .unwrap(); + let msg3_id = db + .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) + .await + .unwrap(); + let msg4_id = db + .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) + .await + .unwrap(); - assert_ne!(msg1_id, msg2_id); - assert_eq!(msg1_id, msg3_id); - assert_eq!(msg2_id, msg4_id); + assert_ne!(msg1_id, msg2_id); + assert_eq!(msg1_id, msg3_id); + assert_eq!(msg2_id, msg4_id); + } } #[gpui::test] async fn test_create_access_tokens() { - let test_db = TestDb::new(); + let test_db = TestDb::postgres(); let db = test_db.db(); let user = db.create_user("the-user", false).await.unwrap(); @@ -782,4 +764,359 @@ pub mod tests { &["h5".to_string(), "h4".to_string(), "h3".to_string()] ); } + + pub struct TestDb { + pub db: Option>, + pub name: String, + pub url: String, + } + + impl TestDb { + pub fn postgres() -> Self { + lazy_static! { + static ref LOCK: Mutex<()> = Mutex::new(()); + } + + let _guard = LOCK.lock(); + let mut rng = StdRng::from_entropy(); + let name = format!("zed-test-{}", rng.gen::()); + let url = format!("postgres://postgres@localhost/{}", name); + let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")); + let db = block_on(async { + Postgres::create_database(&url) + .await + .expect("failed to create test db"); + let mut db = PostgresDb::new(&url, 5).await.unwrap(); + db.test_mode = true; + let migrator = Migrator::new(migrations_path).await.unwrap(); + migrator.run(&db.pool).await.unwrap(); + db + }); + Self { + db: Some(Arc::new(db)), + name, + url, + } + } + + pub fn fake(background: Arc) -> Self { + Self { + db: Some(Arc::new(FakeDb::new(background))), + name: "fake".to_string(), + url: "fake".to_string(), + } + } + + pub fn db(&self) -> &Arc { + self.db.as_ref().unwrap() + } + } + + impl Drop for TestDb { + fn drop(&mut self) { + if let Some(db) = self.db.take() { + block_on(db.teardown(&self.name, &self.url)); + } + } + } + + pub struct FakeDb { + background: Arc, + users: Mutex>, + next_user_id: Mutex, + orgs: Mutex>, + next_org_id: Mutex, + org_memberships: Mutex>, + channels: Mutex>, + next_channel_id: Mutex, + channel_memberships: Mutex>, + channel_messages: Mutex>, + next_channel_message_id: Mutex, + } + + impl FakeDb { + pub fn new(background: Arc) -> Self { + Self { + background, + users: Default::default(), + next_user_id: Mutex::new(1), + orgs: Default::default(), + next_org_id: Mutex::new(1), + org_memberships: Default::default(), + channels: Default::default(), + next_channel_id: Mutex::new(1), + channel_memberships: Default::default(), + channel_messages: Default::default(), + next_channel_message_id: Mutex::new(1), + } + } + } + + #[async_trait] + impl Db for FakeDb { + async fn create_signup( + &self, + _github_login: &str, + _email_address: &str, + _about: &str, + _wants_releases: bool, + _wants_updates: bool, + _wants_community: bool, + ) -> Result { + unimplemented!() + } + + async fn get_all_signups(&self) -> Result> { + unimplemented!() + } + + async fn destroy_signup(&self, _id: SignupId) -> Result<()> { + unimplemented!() + } + + async fn create_user(&self, github_login: &str, admin: bool) -> Result { + self.background.simulate_random_delay().await; + + let mut users = self.users.lock(); + if let Some(user) = users + .values() + .find(|user| user.github_login == github_login) + { + Ok(user.id) + } else { + let user_id = UserId(post_inc(&mut *self.next_user_id.lock())); + users.insert( + user_id, + User { + id: user_id, + github_login: github_login.to_string(), + admin, + }, + ); + Ok(user_id) + } + } + + async fn get_all_users(&self) -> Result> { + unimplemented!() + } + + async fn get_user_by_id(&self, id: UserId) -> Result> { + Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next()) + } + + async fn get_users_by_ids(&self, ids: Vec) -> Result> { + self.background.simulate_random_delay().await; + let users = self.users.lock(); + Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect()) + } + + async fn get_user_by_github_login(&self, _github_login: &str) -> Result> { + unimplemented!() + } + + async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> { + unimplemented!() + } + + async fn destroy_user(&self, _id: UserId) -> Result<()> { + unimplemented!() + } + + async fn create_access_token_hash( + &self, + _user_id: UserId, + _access_token_hash: &str, + _max_access_token_count: usize, + ) -> Result<()> { + unimplemented!() + } + + async fn get_access_token_hashes(&self, _user_id: UserId) -> Result> { + unimplemented!() + } + + async fn find_org_by_slug(&self, _slug: &str) -> Result> { + unimplemented!() + } + + async fn create_org(&self, name: &str, slug: &str) -> Result { + self.background.simulate_random_delay().await; + let mut orgs = self.orgs.lock(); + if orgs.values().any(|org| org.slug == slug) { + Err(anyhow!("org already exists")) + } else { + let org_id = OrgId(post_inc(&mut *self.next_org_id.lock())); + orgs.insert( + org_id, + Org { + id: org_id, + name: name.to_string(), + slug: slug.to_string(), + }, + ); + Ok(org_id) + } + } + + async fn add_org_member( + &self, + org_id: OrgId, + user_id: UserId, + is_admin: bool, + ) -> Result<()> { + self.background.simulate_random_delay().await; + if !self.orgs.lock().contains_key(&org_id) { + return Err(anyhow!("org does not exist")); + } + if !self.users.lock().contains_key(&user_id) { + return Err(anyhow!("user does not exist")); + } + + self.org_memberships + .lock() + .entry((org_id, user_id)) + .or_insert(is_admin); + Ok(()) + } + + async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { + self.background.simulate_random_delay().await; + if !self.orgs.lock().contains_key(&org_id) { + return Err(anyhow!("org does not exist")); + } + + let mut channels = self.channels.lock(); + let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock())); + channels.insert( + channel_id, + Channel { + id: channel_id, + name: name.to_string(), + owner_id: org_id.0, + owner_is_user: false, + }, + ); + Ok(channel_id) + } + + async fn get_org_channels(&self, org_id: OrgId) -> Result> { + self.background.simulate_random_delay().await; + Ok(self + .channels + .lock() + .values() + .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0) + .cloned() + .collect()) + } + + async fn get_accessible_channels(&self, user_id: UserId) -> Result> { + self.background.simulate_random_delay().await; + let channels = self.channels.lock(); + let memberships = self.channel_memberships.lock(); + Ok(channels + .values() + .filter(|channel| memberships.contains_key(&(channel.id, user_id))) + .cloned() + .collect()) + } + + async fn can_user_access_channel( + &self, + user_id: UserId, + channel_id: ChannelId, + ) -> Result { + self.background.simulate_random_delay().await; + Ok(self + .channel_memberships + .lock() + .contains_key(&(channel_id, user_id))) + } + + async fn add_channel_member( + &self, + channel_id: ChannelId, + user_id: UserId, + is_admin: bool, + ) -> Result<()> { + self.background.simulate_random_delay().await; + if !self.channels.lock().contains_key(&channel_id) { + return Err(anyhow!("channel does not exist")); + } + if !self.users.lock().contains_key(&user_id) { + return Err(anyhow!("user does not exist")); + } + + self.channel_memberships + .lock() + .entry((channel_id, user_id)) + .or_insert(is_admin); + Ok(()) + } + + async fn create_channel_message( + &self, + channel_id: ChannelId, + sender_id: UserId, + body: &str, + timestamp: OffsetDateTime, + nonce: u128, + ) -> Result { + self.background.simulate_random_delay().await; + if !self.channels.lock().contains_key(&channel_id) { + return Err(anyhow!("channel does not exist")); + } + if !self.users.lock().contains_key(&sender_id) { + return Err(anyhow!("user does not exist")); + } + + let mut messages = self.channel_messages.lock(); + if let Some(message) = messages + .values() + .find(|message| message.nonce.as_u128() == nonce) + { + Ok(message.id) + } else { + let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock())); + messages.insert( + message_id, + ChannelMessage { + id: message_id, + channel_id, + sender_id, + body: body.to_string(), + sent_at: timestamp, + nonce: Uuid::from_u128(nonce), + }, + ); + Ok(message_id) + } + } + + async fn get_channel_messages( + &self, + channel_id: ChannelId, + count: usize, + before_id: Option, + ) -> Result> { + let mut messages = self + .channel_messages + .lock() + .values() + .rev() + .filter(|message| { + message.channel_id == channel_id + && message.id < before_id.unwrap_or(MessageId::MAX) + }) + .take(count) + .cloned() + .collect::>(); + dbg!(count, before_id, &messages); + messages.sort_unstable_by_key(|message| message.id); + Ok(messages) + } + + async fn teardown(&self, _name: &str, _url: &str) {} + } } diff --git a/crates/server/src/main.rs b/crates/server/src/main.rs index 3301fb24a9..47c8c82190 100644 --- a/crates/server/src/main.rs +++ b/crates/server/src/main.rs @@ -20,7 +20,7 @@ use anyhow::Result; use async_std::net::TcpListener; use async_trait::async_trait; use auth::RequestExt as _; -use db::Db; +use db::{Db, PostgresDb}; use handlebars::{Handlebars, TemplateRenderError}; use parking_lot::RwLock; use rust_embed::RustEmbed; @@ -49,7 +49,7 @@ pub struct Config { } pub struct AppState { - db: Db, + db: Arc, handlebars: RwLock>, auth_client: auth::Client, github_client: Arc, @@ -59,7 +59,7 @@ pub struct AppState { impl AppState { async fn new(config: Config) -> tide::Result> { - let db = Db::new(&config.database_url, 5).await?; + let db = PostgresDb::new(&config.database_url, 5).await?; let github_client = github::AppClient::new(config.github_app_id, config.github_private_key.clone()); let repo_client = github_client @@ -68,7 +68,7 @@ impl AppState { .context("failed to initialize github client")?; let this = Self { - db, + db: Arc::new(db), handlebars: Default::default(), auth_client: auth::build_client(&config.github_client_id, &config.github_client_secret), github_client, @@ -112,7 +112,7 @@ impl AppState { #[async_trait] trait RequestExt { async fn layout_data(&mut self) -> tide::Result>; - fn db(&self) -> &Db; + fn db(&self) -> &Arc; } #[async_trait] @@ -126,7 +126,7 @@ impl RequestExt for Request { Ok(self.ext::>().unwrap().clone()) } - fn db(&self) -> &Db { + fn db(&self) -> &Arc { &self.state().db } } diff --git a/crates/server/src/rpc.rs b/crates/server/src/rpc.rs index cf38889260..adb0592df5 100644 --- a/crates/server/src/rpc.rs +++ b/crates/server/src/rpc.rs @@ -785,7 +785,12 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> tide::Result { - let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto); + let user_ids = request + .payload + .user_ids + .into_iter() + .map(UserId::from_proto) + .collect(); let users = self .app_state .db @@ -1139,18 +1144,14 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_share_project( - mut cx_a: TestAppContext, - mut cx_b: TestAppContext, - last_iteration: bool, - ) { + async fn test_share_project(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { let (window_b, _) = cx_b.add_window(|_| EmptyView); let lang_registry = Arc::new(LanguageRegistry::new()); let fs = Arc::new(FakeFs::new(cx_a.background())); cx_a.foreground().forbid_parking(); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1282,17 +1283,13 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_unshare_project( - mut cx_a: TestAppContext, - mut cx_b: TestAppContext, - last_iteration: bool, - ) { + async fn test_unshare_project(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { let lang_registry = Arc::new(LanguageRegistry::new()); let fs = Arc::new(FakeFs::new(cx_a.background())); cx_a.foreground().forbid_parking(); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1387,14 +1384,13 @@ mod tests { mut cx_a: TestAppContext, mut cx_b: TestAppContext, mut cx_c: TestAppContext, - last_iteration: bool, ) { let lang_registry = Arc::new(LanguageRegistry::new()); let fs = Arc::new(FakeFs::new(cx_a.background())); cx_a.foreground().forbid_parking(); // Connect to a server as 3 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; let client_c = server.create_client(&mut cx_c, "user_c").await; @@ -1566,17 +1562,13 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_buffer_conflict_after_save( - mut cx_a: TestAppContext, - mut cx_b: TestAppContext, - last_iteration: bool, - ) { + async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { cx_a.foreground().forbid_parking(); let lang_registry = Arc::new(LanguageRegistry::new()); let fs = Arc::new(FakeFs::new(cx_a.background())); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1658,17 +1650,13 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_buffer_reloading( - mut cx_a: TestAppContext, - mut cx_b: TestAppContext, - last_iteration: bool, - ) { + async fn test_buffer_reloading(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { cx_a.foreground().forbid_parking(); let lang_registry = Arc::new(LanguageRegistry::new()); let fs = Arc::new(FakeFs::new(cx_a.background())); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1747,14 +1735,13 @@ mod tests { async fn test_editing_while_guest_opens_buffer( mut cx_a: TestAppContext, mut cx_b: TestAppContext, - last_iteration: bool, ) { cx_a.foreground().forbid_parking(); let lang_registry = Arc::new(LanguageRegistry::new()); let fs = Arc::new(FakeFs::new(cx_a.background())); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1830,14 +1817,13 @@ mod tests { async fn test_leaving_worktree_while_opening_buffer( mut cx_a: TestAppContext, mut cx_b: TestAppContext, - last_iteration: bool, ) { cx_a.foreground().forbid_parking(); let lang_registry = Arc::new(LanguageRegistry::new()); let fs = Arc::new(FakeFs::new(cx_a.background())); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1906,17 +1892,13 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_peer_disconnection( - mut cx_a: TestAppContext, - mut cx_b: TestAppContext, - last_iteration: bool, - ) { + async fn test_peer_disconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { cx_a.foreground().forbid_parking(); let lang_registry = Arc::new(LanguageRegistry::new()); let fs = Arc::new(FakeFs::new(cx_a.background())); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1984,7 +1966,6 @@ mod tests { async fn test_collaborating_with_diagnostics( mut cx_a: TestAppContext, mut cx_b: TestAppContext, - last_iteration: bool, ) { cx_a.foreground().forbid_parking(); let mut lang_registry = Arc::new(LanguageRegistry::new()); @@ -2005,7 +1986,7 @@ mod tests { ))); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -2209,7 +2190,6 @@ mod tests { async fn test_collaborating_with_completion( mut cx_a: TestAppContext, mut cx_b: TestAppContext, - last_iteration: bool, ) { cx_a.foreground().forbid_parking(); let mut lang_registry = Arc::new(LanguageRegistry::new()); @@ -2237,7 +2217,7 @@ mod tests { ))); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -2419,11 +2399,7 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_formatting_buffer( - mut cx_a: TestAppContext, - mut cx_b: TestAppContext, - last_iteration: bool, - ) { + async fn test_formatting_buffer(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { cx_a.foreground().forbid_parking(); let mut lang_registry = Arc::new(LanguageRegistry::new()); let fs = Arc::new(FakeFs::new(cx_a.background())); @@ -2443,7 +2419,7 @@ mod tests { ))); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -2525,11 +2501,7 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_definition( - mut cx_a: TestAppContext, - mut cx_b: TestAppContext, - last_iteration: bool, - ) { + async fn test_definition(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { cx_a.foreground().forbid_parking(); let mut lang_registry = Arc::new(LanguageRegistry::new()); let fs = Arc::new(FakeFs::new(cx_a.background())); @@ -2564,7 +2536,7 @@ mod tests { ))); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -2682,7 +2654,6 @@ mod tests { mut cx_a: TestAppContext, mut cx_b: TestAppContext, mut rng: StdRng, - last_iteration: bool, ) { cx_a.foreground().forbid_parking(); let mut lang_registry = Arc::new(LanguageRegistry::new()); @@ -2713,7 +2684,7 @@ mod tests { ))); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -2792,7 +2763,6 @@ mod tests { async fn test_collaborating_with_code_actions( mut cx_a: TestAppContext, mut cx_b: TestAppContext, - last_iteration: bool, ) { cx_a.foreground().forbid_parking(); let mut lang_registry = Arc::new(LanguageRegistry::new()); @@ -2815,7 +2785,7 @@ mod tests { ))); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -3032,15 +3002,11 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_basic_chat( - mut cx_a: TestAppContext, - mut cx_b: TestAppContext, - last_iteration: bool, - ) { + async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { cx_a.foreground().forbid_parking(); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -3176,10 +3142,10 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_chat_message_validation(mut cx_a: TestAppContext, last_iteration: bool) { + async fn test_chat_message_validation(mut cx_a: TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let db = &server.app_state.db; @@ -3236,15 +3202,11 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_chat_reconnection( - mut cx_a: TestAppContext, - mut cx_b: TestAppContext, - last_iteration: bool, - ) { + async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { cx_a.foreground().forbid_parking(); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; let mut status_b = client_b.status(); @@ -3456,14 +3418,13 @@ mod tests { mut cx_a: TestAppContext, mut cx_b: TestAppContext, mut cx_c: TestAppContext, - last_iteration: bool, ) { cx_a.foreground().forbid_parking(); let lang_registry = Arc::new(LanguageRegistry::new()); let fs = Arc::new(FakeFs::new(cx_a.background())); // Connect to a server as 3 clients. - let mut server = TestServer::start(cx_a.foreground(), last_iteration).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; let client_c = server.create_client(&mut cx_c, "user_c").await; @@ -3595,7 +3556,7 @@ mod tests { } #[gpui::test(iterations = 100)] - async fn test_random_collaboration(cx: TestAppContext, rng: StdRng, last_iteration: bool) { + async fn test_random_collaboration(cx: TestAppContext, rng: StdRng) { cx.foreground().forbid_parking(); let max_peers = env::var("MAX_PEERS") .map(|i| i.parse().expect("invalid `MAX_PEERS` variable")) @@ -3654,7 +3615,7 @@ mod tests { .await; let operations = Rc::new(Cell::new(0)); - let mut server = TestServer::start(cx.foreground(), last_iteration).await; + let mut server = TestServer::start(cx.foreground(), cx.background()).await; let mut clients = Vec::new(); let mut next_entity_id = 100000; @@ -3849,9 +3810,11 @@ mod tests { } impl TestServer { - async fn start(foreground: Rc, clean_db_pool_on_drop: bool) -> Self { - let mut test_db = TestDb::new(); - test_db.set_clean_pool_on_drop(clean_db_pool_on_drop); + async fn start( + foreground: Rc, + background: Arc, + ) -> Self { + let test_db = TestDb::fake(background); let app_state = Self::build_app_state(&test_db).await; let peer = Peer::new(); let notifications = mpsc::unbounded();