Ensure client reconnects after erroring during the handshake (#31278)

Release Notes:

- Fixed a bug that prevented Zed from reconnecting after erroring during
the initial handshake with the server.
This commit is contained in:
Antonio Scandurra 2025-05-23 15:46:30 +02:00 committed by GitHub
parent 03ac3fb91a
commit 9dba8e5b0d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 79 additions and 14 deletions

View file

@ -905,7 +905,15 @@ impl Client {
}
futures::select_biased! {
result = self.set_connection(conn, cx).fuse() => ConnectionResult::Result(result.context("client auth and connect")),
result = self.set_connection(conn, cx).fuse() => {
match result.context("client auth and connect") {
Ok(()) => ConnectionResult::Result(Ok(())),
Err(err) => {
self.set_status(Status::ConnectionError, cx);
ConnectionResult::Result(Err(err))
},
}
},
_ = timeout => {
self.set_status(Status::ConnectionError, cx);
ConnectionResult::Timeout

View file

@ -56,6 +56,12 @@ pub use sea_orm::ConnectOptions;
pub use tables::user::Model as User;
pub use tables::*;
/// Test-only settings attached to a `Database` (stored in its `test_options`
/// field, which test setup fills in after constructing the database).
#[cfg(test)]
pub struct DatabaseTestOptions {
/// Tokio runtime that the database uses to drive its futures to completion
/// via `block_on` in test builds.
pub runtime: tokio::runtime::Runtime,
/// Chance that a query is aborted with a "simulated query failure" error
/// when running on the deterministic executor. The value is handed to
/// `gen_bool`, so it is presumably expected to lie in `[0.0, 1.0]`.
pub query_failure_probability: parking_lot::Mutex<f64>,
}
/// Database gives you a handle that lets you access the database.
/// It handles pooling internally.
pub struct Database {
@ -68,7 +74,7 @@ pub struct Database {
notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
notification_kinds_by_name: HashMap<String, NotificationKindId>,
#[cfg(test)]
runtime: Option<tokio::runtime::Runtime>,
test_options: Option<DatabaseTestOptions>,
}
// The `Database` type has so many methods that its impl blocks are split into
@ -87,7 +93,7 @@ impl Database {
notification_kinds_by_name: HashMap::default(),
executor,
#[cfg(test)]
runtime: None,
test_options: None,
})
}
@ -355,11 +361,16 @@ impl Database {
{
#[cfg(test)]
{
let test_options = self.test_options.as_ref().unwrap();
if let Executor::Deterministic(executor) = &self.executor {
executor.simulate_random_delay().await;
let fail_probability = *test_options.query_failure_probability.lock();
if executor.rng().gen_bool(fail_probability) {
return Err(anyhow!("simulated query failure"))?;
}
}
self.runtime.as_ref().unwrap().block_on(future)
test_options.runtime.block_on(future)
}
#[cfg(not(test))]

View file

@ -30,7 +30,7 @@ pub struct TestDb {
}
impl TestDb {
pub fn sqlite(background: BackgroundExecutor) -> Self {
pub fn sqlite(executor: BackgroundExecutor) -> Self {
let url = "sqlite::memory:";
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
@ -41,7 +41,7 @@ impl TestDb {
let mut db = runtime.block_on(async {
let mut options = ConnectOptions::new(url);
options.max_connections(5);
let mut db = Database::new(options, Executor::Deterministic(background))
let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
.await
.unwrap();
let sql = include_str!(concat!(
@ -59,7 +59,10 @@ impl TestDb {
db
});
db.runtime = Some(runtime);
db.test_options = Some(DatabaseTestOptions {
runtime,
query_failure_probability: parking_lot::Mutex::new(0.0),
});
Self {
db: Some(Arc::new(db)),
@ -67,7 +70,7 @@ impl TestDb {
}
}
pub fn postgres(background: BackgroundExecutor) -> Self {
pub fn postgres(executor: BackgroundExecutor) -> Self {
static LOCK: Mutex<()> = Mutex::new(());
let _guard = LOCK.lock();
@ -90,7 +93,7 @@ impl TestDb {
options
.max_connections(5)
.idle_timeout(Duration::from_secs(0));
let mut db = Database::new(options, Executor::Deterministic(background))
let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
.await
.unwrap();
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
@ -101,7 +104,10 @@ impl TestDb {
db
});
db.runtime = Some(runtime);
db.test_options = Some(DatabaseTestOptions {
runtime,
query_failure_probability: parking_lot::Mutex::new(0.0),
});
Self {
db: Some(Arc::new(db)),
@ -112,6 +118,12 @@ impl TestDb {
/// Returns the shared handle to the underlying test database.
///
/// # Panics
///
/// Panics if the handle has already been taken out of `self.db` (as `Drop`
/// does while tearing the database down).
pub fn db(&self) -> &Arc<Database> {
    let handle: Option<&Arc<Database>> = self.db.as_ref();
    handle.unwrap()
}
/// Overwrites the probability with which the database simulates a query
/// failure; `0.0` turns simulated failures off.
///
/// # Panics
///
/// Panics if the database handle or its test options are not set.
pub fn set_query_failure_probability(&self, probability: f64) {
    let mut current = self
        .db
        .as_ref()
        .unwrap()
        .test_options
        .as_ref()
        .unwrap()
        .query_failure_probability
        .lock();
    *current = probability;
}
}
#[macro_export]
@ -136,7 +148,7 @@ impl Drop for TestDb {
fn drop(&mut self) {
let db = self.db.take().unwrap();
if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
db.runtime.as_ref().unwrap().block_on(async {
db.test_options.as_ref().unwrap().runtime.block_on(async {
use util::ResultExt;
let query = "
SELECT pg_terminate_backend(pg_stat_activity.pid)

View file

@ -61,6 +61,35 @@ fn init_logger() {
}
}
// Regression test: a client whose reconnection attempt fails because of a
// database error must still be able to reconnect once the database recovers.
//
// The test makes 30% of queries fail, repeatedly kills the client's connection
// until one reconnection attempt leaves it disconnected, then restores a
// healthy database and checks that the client ends up `Connected` again.
#[gpui::test(iterations = 10)]
async fn test_database_failure_during_client_reconnection(
executor: BackgroundExecutor,
cx: &mut TestAppContext,
) {
let mut server = TestServer::start(executor.clone()).await;
let client = server.create_client(cx, "user_a").await;
// Keep disconnecting the client until a database failure prevents it from
// reconnecting.
server.test_db.set_query_failure_probability(0.3);
loop {
server.disconnect_client(client.peer_id().unwrap());
// Advance the clock by both timeouts before inspecting the client's status
// (presumably long enough to notice the drop and attempt a reconnect).
executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
if !client.status().borrow().is_connected() {
break;
}
}
// Make the database healthy again and ensure the client can finally connect.
server.test_db.set_query_failure_probability(0.);
executor.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
assert!(
matches!(*client.status().borrow(), client::Status::Connected { .. }),
"status was {:?}",
*client.status().borrow()
);
}
#[gpui::test(iterations = 10)]
async fn test_basic_calls(
executor: BackgroundExecutor,

View file

@ -52,11 +52,11 @@ use livekit_client::test::TestServer as LivekitTestServer;
pub struct TestServer {
pub app_state: Arc<AppState>,
pub test_livekit_server: Arc<LivekitTestServer>,
pub test_db: TestDb,
server: Arc<Server>,
next_github_user_id: i32,
connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
forbid_connections: Arc<AtomicBool>,
_test_db: TestDb,
}
pub struct TestClient {
@ -117,7 +117,7 @@ impl TestServer {
connection_killers: Default::default(),
forbid_connections: Default::default(),
next_github_user_id: 0,
_test_db: test_db,
test_db,
test_livekit_server: livekit_server,
}
}
@ -241,7 +241,12 @@ impl TestServer {
let user = db
.get_user_by_id(user_id)
.await
.expect("retrieving user failed")
.map_err(|e| {
EstablishConnectionError::Other(anyhow!(
"retrieving user failed: {}",
e
))
})?
.unwrap();
cx.background_spawn(server.handle_connection(
server_conn,