diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 9f00c02918..cd92287b39 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -22,6 +22,8 @@ pub struct Db { pool: sqlx::Pool, #[cfg(test)] background: Option>, + #[cfg(test)] + runtime: Option, } macro_rules! test_support { @@ -35,7 +37,8 @@ macro_rules! test_support { if let Some(background) = $self.background.as_ref() { background.simulate_random_delay().await; } - tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build().unwrap().block_on(body) + #[cfg(test)] + $self.runtime.as_ref().unwrap().block_on(body) } else { body.await } @@ -60,17 +63,29 @@ impl RowsAffected for sqlx::postgres::PgQueryResult { #[cfg(test)] impl Db { + const MIGRATIONS_PATH: &'static str = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"); + pub async fn new(url: &str, max_connections: u32) -> Result { + use std::str::FromStr as _; + let options = sqlx::sqlite::SqliteConnectOptions::from_str(url) + .unwrap() + .create_if_missing(true) + .shared_cache(true); let pool = sqlx::sqlite::SqlitePoolOptions::new() + .min_connections(2) .max_connections(max_connections) - .connect(url) + .connect_with(options) .await?; Ok(Self { pool, background: None, + runtime: None, }) } + #[cfg(test)] + pub fn teardown(&self, _url: &str) {} + pub async fn get_user_metrics_id(&self, id: UserId) -> Result { test_support!(self, { let query = " @@ -143,6 +158,8 @@ impl Db { } impl Db { + const MIGRATIONS_PATH: &'static str = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); + pub async fn new(url: &str, max_connections: u32) -> Result { let pool = sqlx::postgres::PgPoolOptions::new() .max_connections(max_connections) @@ -152,6 +169,25 @@ impl Db { pool, #[cfg(test)] background: None, + #[cfg(test)] + runtime: None, + }) + } + + #[cfg(test)] + pub fn teardown(&self, url: &str) { + self.runtime.as_ref().unwrap().block_on(async { + use util::ResultExt; + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid(); + "; + sqlx::query(query).execute(&self.pool).await.log_err(); + self.pool.close().await; + ::drop_database(url) + .await + .log_err(); }) } @@ -1295,33 +1331,39 @@ mod test { use super::*; use gpui::executor::Background; use rand::prelude::*; - use sqlx::migrate::MigrateDatabase; use std::sync::Arc; pub struct TestDb { pub db: Option>, + pub conn: sqlx::sqlite::SqliteConnection, pub url: String, } impl TestDb { pub fn new(background: Arc) -> Self { let mut rng = StdRng::from_entropy(); - let url = format!("/tmp/zed-test-{}", rng.gen::()); - let db = tokio::runtime::Builder::new_current_thread() + let url = format!("file:zed-test-{}?mode=memory", rng.gen::()); + let runtime = tokio::runtime::Builder::new_current_thread() .enable_io() .enable_time() .build() - .unwrap() - .block_on(async { - sqlx::Sqlite::create_database(&url).await.unwrap(); - let mut db = DefaultDb::new(&url, 5).await.unwrap(); - db.background = Some(background); - let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"); - db.migrate(Path::new(migrations_path), false).await.unwrap(); - db - }); + .unwrap(); + + let (mut db, conn) = runtime.block_on(async { + let db = DefaultDb::new(&url, 5).await.unwrap(); + db.migrate(Path::new(DefaultDb::MIGRATIONS_PATH), false) + .await + .unwrap(); + let conn = db.pool.acquire().await.unwrap().detach(); + (db, conn) + }); + + db.background = Some(background); + db.runtime = Some(runtime); + Self { db: Some(Arc::new(db)), + conn, url, } } @@ -1333,7 +1375,8 @@ mod test { impl Drop for TestDb { fn drop(&mut self) { - std::fs::remove_file(&self.url).ok(); + let db = self.db.take().unwrap(); + db.teardown(&self.url); } } }