diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 6d204a32bd..6c8c702c0f 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -905,7 +905,15 @@ impl Client { } futures::select_biased! { - result = self.set_connection(conn, cx).fuse() => ConnectionResult::Result(result.context("client auth and connect")), + result = self.set_connection(conn, cx).fuse() => { + match result.context("client auth and connect") { + Ok(()) => ConnectionResult::Result(Ok(())), + Err(err) => { + self.set_status(Status::ConnectionError, cx); + ConnectionResult::Result(Err(err)) + }, + } + }, _ = timeout => { self.set_status(Status::ConnectionError, cx); ConnectionResult::Timeout diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index e4c3e55a3d..93ccc1ba03 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -56,6 +56,12 @@ pub use sea_orm::ConnectOptions; pub use tables::user::Model as User; pub use tables::*; +#[cfg(test)] +pub struct DatabaseTestOptions { + pub runtime: tokio::runtime::Runtime, + pub query_failure_probability: parking_lot::Mutex, +} + /// Database gives you a handle that lets you access the database. /// It handles pooling internally. pub struct Database { @@ -68,7 +74,7 @@ pub struct Database { notification_kinds_by_id: HashMap, notification_kinds_by_name: HashMap, #[cfg(test)] - runtime: Option, + test_options: Option, } // The `Database` type has so many methods that its impl blocks are split into @@ -87,7 +93,7 @@ impl Database { notification_kinds_by_name: HashMap::default(), executor, #[cfg(test)] - runtime: None, + test_options: None, }) } @@ -355,11 +361,16 @@ impl Database { { #[cfg(test)] { + let test_options = self.test_options.as_ref().unwrap(); if let Executor::Deterministic(executor) = &self.executor { executor.simulate_random_delay().await; + let fail_probability = *test_options.query_failure_probability.lock(); + if executor.rng().gen_bool(fail_probability) { + return Err(anyhow!("simulated query failure"))?; + } } - self.runtime.as_ref().unwrap().block_on(future) + test_options.runtime.block_on(future) } #[cfg(not(test))] diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index cb27e15d6b..d7967fac98 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -30,7 +30,7 @@ pub struct TestDb { } impl TestDb { - pub fn sqlite(background: BackgroundExecutor) -> Self { + pub fn sqlite(executor: BackgroundExecutor) -> Self { let url = "sqlite::memory:"; let runtime = tokio::runtime::Builder::new_current_thread() .enable_io() @@ -41,7 +41,7 @@ impl TestDb { let mut db = runtime.block_on(async { let mut options = ConnectOptions::new(url); options.max_connections(5); - let mut db = Database::new(options, Executor::Deterministic(background)) + let mut db = Database::new(options, Executor::Deterministic(executor.clone())) .await .unwrap(); let sql = include_str!(concat!( @@ -59,7 +59,10 @@ impl TestDb { db }); - db.runtime = Some(runtime); + db.test_options = Some(DatabaseTestOptions { + runtime, + query_failure_probability: parking_lot::Mutex::new(0.0), + }); Self { db: Some(Arc::new(db)), @@ -67,7 +70,7 @@ impl TestDb { } } - pub fn postgres(background: BackgroundExecutor) -> Self { + pub fn postgres(executor: BackgroundExecutor) -> Self { static LOCK: Mutex<()> = Mutex::new(()); let _guard = LOCK.lock(); @@ -90,7 +93,7 @@ impl TestDb { options .max_connections(5) .idle_timeout(Duration::from_secs(0)); - let mut db = Database::new(options, Executor::Deterministic(background)) + let mut db = Database::new(options, Executor::Deterministic(executor.clone())) .await .unwrap(); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); @@ -101,7 +104,10 @@ impl TestDb { db }); - db.runtime = Some(runtime); + db.test_options = Some(DatabaseTestOptions { + runtime, + query_failure_probability: parking_lot::Mutex::new(0.0), + }); Self { db: Some(Arc::new(db)), @@ -112,6 +118,12 @@ impl TestDb { pub fn db(&self) -> &Arc { self.db.as_ref().unwrap() } + + pub fn set_query_failure_probability(&self, probability: f64) { + let database = self.db.as_ref().unwrap(); + let test_options = database.test_options.as_ref().unwrap(); + *test_options.query_failure_probability.lock() = probability; + } } #[macro_export] @@ -136,7 +148,7 @@ impl Drop for TestDb { fn drop(&mut self) { let db = self.db.take().unwrap(); if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() { - db.runtime.as_ref().unwrap().block_on(async { + db.test_options.as_ref().unwrap().runtime.block_on(async { use util::ResultExt; let query = " SELECT pg_terminate_backend(pg_stat_activity.pid) diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index af8ea38265..20429c7038 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -61,6 +61,35 @@ fn init_logger() { } } +#[gpui::test(iterations = 10)] +async fn test_database_failure_during_client_reconnection( + executor: BackgroundExecutor, + cx: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client = server.create_client(cx, "user_a").await; + + // Keep disconnecting the client until a database failure prevents it from + // reconnecting. + server.test_db.set_query_failure_probability(0.3); + loop { + server.disconnect_client(client.peer_id().unwrap()); + executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + if !client.status().borrow().is_connected() { + break; + } + } + + // Make the database healthy again and ensure the client can finally connect. + server.test_db.set_query_failure_probability(0.); + executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); + assert!( + matches!(*client.status().borrow(), client::Status::Connected { .. }), + "status was {:?}", + *client.status().borrow() + ); +} + #[gpui::test(iterations = 10)] async fn test_basic_calls( executor: BackgroundExecutor, diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 683bd97e0f..2397ab1c00 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -52,11 +52,11 @@ use livekit_client::test::TestServer as LivekitTestServer; pub struct TestServer { pub app_state: Arc, pub test_livekit_server: Arc, + pub test_db: TestDb, server: Arc, next_github_user_id: i32, connection_killers: Arc>>>, forbid_connections: Arc, - _test_db: TestDb, } pub struct TestClient { @@ -117,7 +117,7 @@ impl TestServer { connection_killers: Default::default(), forbid_connections: Default::default(), next_github_user_id: 0, - _test_db: test_db, + test_db, test_livekit_server: livekit_server, } } @@ -241,7 +241,12 @@ impl TestServer { let user = db .get_user_by_id(user_id) .await - .expect("retrieving user failed") + .map_err(|e| { + EstablishConnectionError::Other(anyhow!( + "retrieving user failed: {}", + e + )) + })? .unwrap(); cx.background_spawn(server.handle_connection( server_conn,