diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index cd92287b39..239a16cfee 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -33,10 +33,14 @@ macro_rules! test_support { }; if cfg!(test) { + #[cfg(not(test))] + unreachable!(); + #[cfg(test)] if let Some(background) = $self.background.as_ref() { background.simulate_random_delay().await; } + #[cfg(test)] $self.runtime.as_ref().unwrap().block_on(body) } else { @@ -63,8 +67,6 @@ impl RowsAffected for sqlx::postgres::PgQueryResult { #[cfg(test)] impl Db { - const MIGRATIONS_PATH: &'static str = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"); - pub async fn new(url: &str, max_connections: u32) -> Result { use std::str::FromStr as _; let options = sqlx::sqlite::SqliteConnectOptions::from_str(url) @@ -83,8 +85,19 @@ impl Db { }) } - #[cfg(test)] - pub fn teardown(&self, _url: &str) {} + pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { + test_support!(self, { + let query = " + SELECT users.* + FROM users + WHERE users.id IN (SELECT value from json_each($1)) + "; + Ok(sqlx::query_as(query) + .bind(&serde_json::json!(ids)) + .fetch_all(&self.pool) + .await?) + }) + } pub async fn get_user_metrics_id(&self, id: UserId) -> Result { test_support!(self, { @@ -155,11 +168,13 @@ impl Db { ) -> Result { unimplemented!() } + + pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> { + unimplemented!() + } } impl Db { - const MIGRATIONS_PATH: &'static str = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); - pub async fn new(url: &str, max_connections: u32) -> Result { let pool = sqlx::postgres::PgPoolOptions::new() .max_connections(max_connections) @@ -210,6 +225,20 @@ impl Db { }) } + pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { + test_support!(self, { + let query = " + SELECT users.* + FROM users + WHERE users.id = ANY ($1) + "; + Ok(sqlx::query_as(query) + .bind(&ids.into_iter().map(|id| id.0).collect::>()) + .fetch_all(&self.pool) + .await?) + }) + } + pub async fn get_user_metrics_id(&self, id: UserId) -> Result { test_support!(self, { let query = " @@ -382,7 +411,7 @@ impl Db { device_id ) VALUES - ($1, $2, FALSE, $3, $4, $5, FALSE, $6) + ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8) RETURNING id ", ) @@ -486,6 +515,26 @@ impl Db { }) }) } + + pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { + test_support!(self, { + let emails = invites + .iter() + .map(|s| s.email_address.as_str()) + .collect::>(); + sqlx::query( + " + UPDATE signups + SET email_confirmation_sent = TRUE + WHERE email_address = ANY ($1) + ", + ) + .bind(&emails) + .execute(&self.pool) + .await?; + Ok(()) + }) + } } impl Db @@ -600,20 +649,6 @@ where }) } - pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - test_support!(self, { - let query = " - SELECT users.* - FROM users - WHERE users.id IN (SELECT value from json_each($1)) - "; - Ok(sqlx::query_as(query) - .bind(&serde_json::json!(ids)) - .fetch_all(&self.pool) - .await?) - }) - } - pub async fn get_users_with_no_invites( &self, invited_by_another_user: bool, @@ -770,26 +805,6 @@ where }) } - pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { - test_support!(self, { - let emails = invites - .iter() - .map(|s| s.email_address.as_str()) - .collect::>(); - sqlx::query( - " - UPDATE signups - SET email_confirmation_sent = TRUE - WHERE email_address IN (SELECT value from json_each($1)) - ", - ) - .bind(&serde_json::json!(emails)) - .execute(&self.pool) - .await?; - Ok(()) - }) - } - // invite codes pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { @@ -1330,16 +1345,23 @@ pub use test::*; mod test { use super::*; use gpui::executor::Background; + use lazy_static::lazy_static; + use parking_lot::Mutex; use rand::prelude::*; + use sqlx::migrate::MigrateDatabase; use std::sync::Arc; - pub struct TestDb { - pub db: Option>, + pub struct SqliteTestDb { + pub db: Option>>, pub conn: sqlx::sqlite::SqliteConnection, + } + + pub struct PostgresTestDb { + pub db: Option>>, pub url: String, } - impl TestDb { + impl SqliteTestDb { pub fn new(background: Arc) -> Self { let mut rng = StdRng::from_entropy(); let url = format!("file:zed-test-{}?mode=memory", rng.gen::()); @@ -1350,10 +1372,9 @@ mod test { .unwrap(); let (mut db, conn) = runtime.block_on(async { - let db = DefaultDb::new(&url, 5).await.unwrap(); - db.migrate(Path::new(DefaultDb::MIGRATIONS_PATH), false) - .await - .unwrap(); + let db = Db::::new(&url, 5).await.unwrap(); + let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"); + db.migrate(migrations_path.as_ref(), false).await.unwrap(); let conn = db.pool.acquire().await.unwrap().detach(); (db, conn) }); @@ -1364,16 +1385,57 @@ mod test { Self { db: Some(Arc::new(db)), conn, - url, } } - pub fn db(&self) -> &Arc { + pub fn db(&self) -> &Arc> { self.db.as_ref().unwrap() } } - impl Drop for TestDb { + impl PostgresTestDb { + pub fn new(background: Arc) -> Self { + lazy_static! { + static ref LOCK: Mutex<()> = Mutex::new(()); + } + + let _guard = LOCK.lock(); + let mut rng = StdRng::from_entropy(); + let url = format!( + "postgres://postgres@localhost/zed-test-{}", + rng.gen::() + ); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + let mut db = runtime.block_on(async { + sqlx::Postgres::create_database(&url) + .await + .expect("failed to create test db"); + let db = Db::::new(&url, 5).await.unwrap(); + let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); + db.migrate(Path::new(migrations_path), false).await.unwrap(); + db + }); + + db.background = Some(background); + db.runtime = Some(runtime); + + Self { + db: Some(Arc::new(db)), + url, + } + } + + pub fn db(&self) -> &Arc> { + self.db.as_ref().unwrap() + } + } + + impl Drop for PostgresTestDb { fn drop(&mut self) { let db = self.db.take().unwrap(); db.teardown(&self.url); diff --git a/crates/collab/src/db_tests.rs b/crates/collab/src/db_tests.rs index b6a785e9f1..8eda7d34e2 100644 --- a/crates/collab/src/db_tests.rs +++ b/crates/collab/src/db_tests.rs @@ -2,228 +2,192 @@ use super::db::*; use gpui::executor::{Background, Deterministic}; use std::sync::Arc; -#[gpui::test] -async fn test_get_users_by_ids() { - let test_db = TestDb::new(build_background_executor()); - let db = test_db.db(); +macro_rules! test_both_dbs { + ($postgres_test_name:ident, $sqlite_test_name:ident, $db:ident, $body:block) => { + #[gpui::test] + async fn $postgres_test_name() { + let test_db = PostgresTestDb::new(Deterministic::new(0).build_background()); + let $db = test_db.db(); + $body + } - let mut user_ids = Vec::new(); - for i in 1..=4 { - user_ids.push( - db.create_user( - &format!("user{i}@example.com"), + #[gpui::test] + async fn $sqlite_test_name() { + let test_db = SqliteTestDb::new(Deterministic::new(0).build_background()); + let $db = test_db.db(); + $body + } + }; +} + +test_both_dbs!( + test_get_users_by_ids_postgres, + test_get_users_by_ids_sqlite, + db, + { + let mut user_ids = Vec::new(); + for i in 1..=4 { + user_ids.push( + db.create_user( + &format!("user{i}@example.com"), + false, + NewUserParams { + github_login: format!("user{i}"), + github_user_id: i, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id, + ); + } + + assert_eq!( + db.get_users_by_ids(user_ids.clone()).await.unwrap(), + vec![ + User { + id: user_ids[0], + github_login: "user1".to_string(), + github_user_id: Some(1), + email_address: Some("user1@example.com".to_string()), + admin: false, + ..Default::default() + }, + User { + id: user_ids[1], + github_login: "user2".to_string(), + github_user_id: Some(2), + email_address: Some("user2@example.com".to_string()), + admin: false, + ..Default::default() + }, + User { + id: user_ids[2], + github_login: "user3".to_string(), + github_user_id: Some(3), + email_address: Some("user3@example.com".to_string()), + admin: false, + ..Default::default() + }, + User { + id: user_ids[3], + github_login: "user4".to_string(), + github_user_id: Some(4), + email_address: Some("user4@example.com".to_string()), + admin: false, + ..Default::default() + } + ] + ); + } +); + +test_both_dbs!( + test_get_user_by_github_account_postgres, + test_get_user_by_github_account_sqlite, + db, + { + let user_id1 = db + .create_user( + "user1@example.com", false, NewUserParams { - github_login: format!("user{i}"), - github_user_id: i, + github_login: "login1".into(), + github_user_id: 101, invite_count: 0, }, ) .await .unwrap() - .user_id, - ); - } - - assert_eq!( - db.get_users_by_ids(user_ids.clone()).await.unwrap(), - vec![ - User { - id: user_ids[0], - github_login: "user1".to_string(), - github_user_id: Some(1), - email_address: Some("user1@example.com".to_string()), - admin: false, - ..Default::default() - }, - User { - id: user_ids[1], - github_login: "user2".to_string(), - github_user_id: Some(2), - email_address: Some("user2@example.com".to_string()), - admin: false, - ..Default::default() - }, - User { - id: user_ids[2], - github_login: "user3".to_string(), - github_user_id: Some(3), - email_address: Some("user3@example.com".to_string()), - admin: false, - ..Default::default() - }, - User { - id: user_ids[3], - github_login: "user4".to_string(), - github_user_id: Some(4), - email_address: Some("user4@example.com".to_string()), - admin: false, - ..Default::default() - } - ] - ); -} - -#[gpui::test] -async fn test_get_user_by_github_account() { - let test_db = TestDb::new(build_background_executor()); - let db = test_db.db(); - let user_id1 = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: "login1".into(), - github_user_id: 101, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let user_id2 = db - .create_user( - "user2@example.com", - false, - NewUserParams { - github_login: "login2".into(), - github_user_id: 102, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - let user = db - .get_user_by_github_account("login1", None) - .await - .unwrap() - .unwrap(); - assert_eq!(user.id, user_id1); - assert_eq!(&user.github_login, "login1"); - assert_eq!(user.github_user_id, Some(101)); - - assert!(db - .get_user_by_github_account("non-existent-login", None) - .await - .unwrap() - .is_none()); - - let user = db - .get_user_by_github_account("the-new-login2", Some(102)) - .await - .unwrap() - .unwrap(); - assert_eq!(user.id, user_id2); - assert_eq!(&user.github_login, "the-new-login2"); - assert_eq!(user.github_user_id, Some(102)); -} - -#[gpui::test] -async fn test_create_access_tokens() { - let test_db = TestDb::new(build_background_executor()); - let db = test_db.db(); - let user = db - .create_user( - "u1@example.com", - false, - NewUserParams { - github_login: "u1".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - db.create_access_token_hash(user, "h1", 3).await.unwrap(); - db.create_access_token_hash(user, "h2", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h2".to_string(), "h1".to_string()] - ); - - db.create_access_token_hash(user, "h3", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h3".to_string(), "h2".to_string(), "h1".to_string(),] - ); - - db.create_access_token_hash(user, "h4", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h4".to_string(), "h3".to_string(), "h2".to_string(),] - ); - - db.create_access_token_hash(user, "h5", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h5".to_string(), "h4".to_string(), "h3".to_string()] - ); -} - -#[test] -fn test_fuzzy_like_string() { - assert_eq!(DefaultDb::fuzzy_like_string("abcd"), "%a%b%c%d%"); - assert_eq!(DefaultDb::fuzzy_like_string("x y"), "%x%y%"); - assert_eq!(DefaultDb::fuzzy_like_string(" z "), "%z%"); -} - -#[gpui::test] -async fn test_fuzzy_search_users() { - let test_db = TestDb::new(build_background_executor()); - let db = test_db.db(); - for (i, github_login) in [ - "California", - "colorado", - "oregon", - "washington", - "florida", - "delaware", - "rhode-island", - ] - .into_iter() - .enumerate() - { - db.create_user( - &format!("{github_login}@example.com"), - false, - NewUserParams { - github_login: github_login.into(), - github_user_id: i as i32, - invite_count: 0, - }, - ) - .await - .unwrap(); - } - - assert_eq!( - fuzzy_search_user_names(db, "clr").await, - &["colorado", "California"] - ); - assert_eq!( - fuzzy_search_user_names(db, "ro").await, - &["rhode-island", "colorado", "oregon"], - ); - - async fn fuzzy_search_user_names(db: &DefaultDb, query: &str) -> Vec { - db.fuzzy_search_users(query, 10) + .user_id; + let user_id2 = db + .create_user( + "user2@example.com", + false, + NewUserParams { + github_login: "login2".into(), + github_user_id: 102, + invite_count: 0, + }, + ) .await .unwrap() - .into_iter() - .map(|user| user.github_login) - .collect::>() + .user_id; + + let user = db + .get_user_by_github_account("login1", None) + .await + .unwrap() + .unwrap(); + assert_eq!(user.id, user_id1); + assert_eq!(&user.github_login, "login1"); + assert_eq!(user.github_user_id, Some(101)); + + assert!(db + .get_user_by_github_account("non-existent-login", None) + .await + .unwrap() + .is_none()); + + let user = db + .get_user_by_github_account("the-new-login2", Some(102)) + .await + .unwrap() + .unwrap(); + assert_eq!(user.id, user_id2); + assert_eq!(&user.github_login, "the-new-login2"); + assert_eq!(user.github_user_id, Some(102)); } -} +); -#[gpui::test] -async fn test_add_contacts() { - let test_db = TestDb::new(build_background_executor()); - let db = test_db.db(); +test_both_dbs!( + test_create_access_tokens_postgres, + test_create_access_tokens_sqlite, + db, + { + let user = db + .create_user( + "u1@example.com", + false, + NewUserParams { + github_login: "u1".into(), + github_user_id: 1, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; + db.create_access_token_hash(user, "h1", 3).await.unwrap(); + db.create_access_token_hash(user, "h2", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h2".to_string(), "h1".to_string()] + ); + + db.create_access_token_hash(user, "h3", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h3".to_string(), "h2".to_string(), "h1".to_string(),] + ); + + db.create_access_token_hash(user, "h4", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h4".to_string(), "h3".to_string(), "h2".to_string(),] + ); + + db.create_access_token_hash(user, "h5", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h5".to_string(), "h4".to_string(), "h3".to_string()] + ); + } +); + +test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { let mut user_ids = Vec::new(); for i in 0..3 { user_ids.push( @@ -381,12 +345,109 @@ async fn test_add_contacts() { should_notify: false }], ); +}); + +test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { + let NewUserResult { + user_id: user1, + metrics_id: metrics_id1, + .. + } = db + .create_user( + "person1@example.com", + false, + NewUserParams { + github_login: "person1".into(), + github_user_id: 101, + invite_count: 5, + }, + ) + .await + .unwrap(); + let NewUserResult { + user_id: user2, + metrics_id: metrics_id2, + .. + } = db + .create_user( + "person2@example.com", + false, + NewUserParams { + github_login: "person2".into(), + github_user_id: 102, + invite_count: 5, + }, + ) + .await + .unwrap(); + + assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1); + assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2); + assert_eq!(metrics_id1.len(), 36); + assert_eq!(metrics_id2.len(), 36); + assert_ne!(metrics_id1, metrics_id2); +}); + +#[test] +fn test_fuzzy_like_string() { + assert_eq!(DefaultDb::fuzzy_like_string("abcd"), "%a%b%c%d%"); + assert_eq!(DefaultDb::fuzzy_like_string("x y"), "%x%y%"); + assert_eq!(DefaultDb::fuzzy_like_string(" z "), "%z%"); +} + +#[gpui::test] +async fn test_fuzzy_search_users() { + let test_db = PostgresTestDb::new(build_background_executor()); + let db = test_db.db(); + for (i, github_login) in [ + "California", + "colorado", + "oregon", + "washington", + "florida", + "delaware", + "rhode-island", + ] + .into_iter() + .enumerate() + { + db.create_user( + &format!("{github_login}@example.com"), + false, + NewUserParams { + github_login: github_login.into(), + github_user_id: i as i32, + invite_count: 0, + }, + ) + .await + .unwrap(); + } + + assert_eq!( + fuzzy_search_user_names(db, "clr").await, + &["colorado", "California"] + ); + assert_eq!( + fuzzy_search_user_names(db, "ro").await, + &["rhode-island", "colorado", "oregon"], + ); + + async fn fuzzy_search_user_names(db: &Db, query: &str) -> Vec { + db.fuzzy_search_users(query, 10) + .await + .unwrap() + .into_iter() + .map(|user| user.github_login) + .collect::>() + } } #[gpui::test] async fn test_invite_codes() { - let test_db = TestDb::new(build_background_executor()); + let test_db = PostgresTestDb::new(build_background_executor()); let db = test_db.db(); + let NewUserResult { user_id: user1, .. } = db .create_user( "user1@example.com", @@ -580,7 +641,7 @@ async fn test_invite_codes() { #[gpui::test] async fn test_signups() { - let test_db = TestDb::new(build_background_executor()); + let test_db = PostgresTestDb::new(build_background_executor()); let db = test_db.db(); // people sign up on the waitlist @@ -724,51 +785,6 @@ async fn test_signups() { .unwrap_err(); } -#[gpui::test] -async fn test_metrics_id() { - let test_db = TestDb::new(build_background_executor()); - let db = test_db.db(); - - let NewUserResult { - user_id: user1, - metrics_id: metrics_id1, - .. - } = db - .create_user( - "person1@example.com", - false, - NewUserParams { - github_login: "person1".into(), - github_user_id: 101, - invite_count: 5, - }, - ) - .await - .unwrap(); - let NewUserResult { - user_id: user2, - metrics_id: metrics_id2, - .. - } = db - .create_user( - "person2@example.com", - false, - NewUserParams { - github_login: "person2".into(), - github_user_id: 102, - invite_count: 5, - }, - ) - .await - .unwrap(); - - assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1); - assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2); - assert_eq!(metrics_id1.len(), 36); - assert_eq!(metrics_id2.len(), 36); - assert_ne!(metrics_id1, metrics_id2); -} - fn build_background_executor() -> Arc { Deterministic::new(0).build_background() } diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 906424f9c9..a77345270b 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -1,5 +1,5 @@ use crate::{ - db::{NewUserParams, ProjectId, TestDb, UserId}, + db::{NewUserParams, ProjectId, SqliteTestDb as TestDb, UserId}, rpc::{Executor, Server}, AppState, };