diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index 62289cdeaa..c53f60872d 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -349,6 +349,17 @@ impl Database { .await } + pub async fn stale_room_connection(&self, user_id: UserId) -> Result> { + self.transaction(|tx| async move { + let participant = room_participant::Entity::find() + .filter(room_participant::Column::UserId.eq(user_id)) + .one(&*tx) + .await?; + Ok(participant.and_then(|p| p.answering_connection())) + }) + .await + } + async fn get_next_participant_index_internal( &self, room_id: RoomId, @@ -403,39 +414,50 @@ impl Database { .get_next_participant_index_internal(room_id, tx) .await?; - room_participant::Entity::insert_many([room_participant::ActiveModel { - room_id: ActiveValue::set(room_id), - user_id: ActiveValue::set(user_id), - answering_connection_id: ActiveValue::set(Some(connection.id as i32)), - answering_connection_server_id: ActiveValue::set(Some(ServerId( - connection.owner_id as i32, - ))), - answering_connection_lost: ActiveValue::set(false), - calling_user_id: ActiveValue::set(user_id), - calling_connection_id: ActiveValue::set(connection.id as i32), - calling_connection_server_id: ActiveValue::set(Some(ServerId( - connection.owner_id as i32, - ))), - participant_index: ActiveValue::Set(Some(participant_index)), - role: ActiveValue::set(Some(role)), - id: ActiveValue::NotSet, - location_kind: ActiveValue::NotSet, - location_project_id: ActiveValue::NotSet, - initial_project_id: ActiveValue::NotSet, - }]) - .on_conflict( - OnConflict::columns([room_participant::Column::UserId]) - .update_columns([ - room_participant::Column::AnsweringConnectionId, - room_participant::Column::AnsweringConnectionServerId, - room_participant::Column::AnsweringConnectionLost, - room_participant::Column::ParticipantIndex, - room_participant::Column::Role, - ]) - .to_owned(), - ) - .exec(tx) - .await?; + // If someone has been invited into the room, accept the invite instead of inserting + let result = room_participant::Entity::update_many() + .filter( + Condition::all() + .add(room_participant::Column::RoomId.eq(room_id)) + .add(room_participant::Column::UserId.eq(user_id)) + .add(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .set(room_participant::ActiveModel { + participant_index: ActiveValue::Set(Some(participant_index)), + answering_connection_id: ActiveValue::set(Some(connection.id as i32)), + answering_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + answering_connection_lost: ActiveValue::set(false), + ..Default::default() + }) + .exec(tx) + .await?; + + if result.rows_affected == 0 { + room_participant::Entity::insert(room_participant::ActiveModel { + room_id: ActiveValue::set(room_id), + user_id: ActiveValue::set(user_id), + answering_connection_id: ActiveValue::set(Some(connection.id as i32)), + answering_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + answering_connection_lost: ActiveValue::set(false), + calling_user_id: ActiveValue::set(user_id), + calling_connection_id: ActiveValue::set(connection.id as i32), + calling_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + participant_index: ActiveValue::Set(Some(participant_index)), + role: ActiveValue::set(Some(role)), + id: ActiveValue::NotSet, + location_kind: ActiveValue::NotSet, + location_project_id: ActiveValue::NotSet, + initial_project_id: ActiveValue::NotSet, + }) + .exec(tx) + .await?; + } let (channel, room) = self.get_channel_room(room_id, &tx).await?; let channel = channel.ok_or_else(|| anyhow!("no channel for room"))?; diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index c14a8fe8b8..4055b7ca8d 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1203,7 +1203,7 @@ async fn connection_lost( _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => { if let Some(session) = session.for_user() { log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id); - leave_room_for_session(&session).await.trace_err(); + leave_room_for_session(&session, session.connection_id).await.trace_err(); leave_channel_buffers_for_session(&session) .await .trace_err(); @@ -1539,7 +1539,7 @@ async fn leave_room( response: Response, session: UserSession, ) -> Result<()> { - leave_room_for_session(&session).await?; + leave_room_for_session(&session, session.connection_id).await?; response.send(proto::Ack {})?; Ok(()) } @@ -3023,8 +3023,19 @@ async fn join_channel_internal( session: UserSession, ) -> Result<()> { let joined_room = { - leave_room_for_session(&session).await?; - let db = session.db().await; + let mut db = session.db().await; + // If zed quits without leaving the room, and the user re-opens zed before the + // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous + // room they were in. + if let Some(connection) = db.stale_room_connection(session.user_id()).await? { + tracing::info!( + stale_connection_id = %connection, + "cleaning up stale connection", + ); + drop(db); + leave_room_for_session(&session, connection).await?; + db = session.db().await; + } let (joined_room, membership_updated, role) = db .join_channel(channel_id, session.user_id(), session.connection_id) @@ -4199,7 +4210,7 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> Ok(()) } -async fn leave_room_for_session(session: &UserSession) -> Result<()> { +async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> { let mut contacts_to_update = HashSet::default(); let room_id; @@ -4209,7 +4220,7 @@ async fn leave_room_for_session(session: &UserSession) -> Result<()> { let room; let channel; - if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? { + if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? { contacts_to_update.insert(session.user_id()); for project in left_room.left_projects.values() { diff --git a/crates/collab/src/tests/following_tests.rs b/crates/collab/src/tests/following_tests.rs index 7c1179c69b..1756b450ca 100644 --- a/crates/collab/src/tests/following_tests.rs +++ b/crates/collab/src/tests/following_tests.rs @@ -2007,7 +2007,7 @@ async fn test_following_to_channel_notes_without_a_shared_project( }); } -async fn join_channel( +pub(crate) async fn join_channel( channel_id: ChannelId, client: &TestClient, cx: &mut TestAppContext, diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index fe130e68e4..598d755e92 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -1,6 +1,9 @@ use crate::{ rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, - tests::{channel_id, room_participants, rust_lang, RoomParticipants, TestClient, TestServer}, + tests::{ + channel_id, following_tests::join_channel, room_participants, rust_lang, RoomParticipants, + TestClient, TestServer, + }, }; use call::{room, ActiveCall, ParticipantLocation, Room}; use client::{User, RECEIVE_TIMEOUT}; @@ -5914,7 +5917,7 @@ async fn test_right_click_menu_behind_collab_panel(cx: &mut TestAppContext) { #[gpui::test] async fn test_cmd_k_left(cx: &mut TestAppContext) { - let client = TestServer::start1(cx).await; + let (_, client) = TestServer::start1(cx).await; let (workspace, cx) = client.build_test_workspace(cx).await; cx.simulate_keystrokes("cmd-n"); @@ -5934,3 +5937,16 @@ async fn test_cmd_k_left(cx: &mut TestAppContext) { assert!(workspace.items(cx).collect::>().len() == 2); }); } + +#[gpui::test] +async fn test_join_after_restart(cx1: &mut TestAppContext, cx2: &mut TestAppContext) { + let (mut server, client) = TestServer::start1(cx1).await; + let channel1 = server.make_public_channel("channel1", &client, cx1).await; + let channel2 = server.make_public_channel("channel2", &client, cx1).await; + + join_channel(channel1, &client, cx1).await.unwrap(); + drop(client); + + let client2 = server.create_client(cx2, "user_a").await; + join_channel(channel2, &client2, cx2).await.unwrap(); +} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index f1364cdc66..78323bde76 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -135,9 +135,10 @@ impl TestServer { (server, client_a, client_b, channel_id) } - pub async fn start1(cx: &mut TestAppContext) -> TestClient { + pub async fn start1(cx: &mut TestAppContext) -> (TestServer, TestClient) { let mut server = Self::start(cx.executor().clone()).await; - server.create_client(cx, "user_a").await + let client = server.create_client(cx, "user_a").await; + (server, client) } pub async fn reset(&self) { diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs index 6bc64b0be7..841fc5b19e 100644 --- a/crates/gpui/src/executor.rs +++ b/crates/gpui/src/executor.rs @@ -219,11 +219,17 @@ impl BackgroundExecutor { if let Some(test) = self.dispatcher.as_test() { if !test.parking_allowed() { let mut backtrace_message = String::new(); + let mut waiting_message = String::new(); if let Some(backtrace) = test.waiting_backtrace() { backtrace_message = format!("\nbacktrace of waiting future:\n{:?}", backtrace); } - panic!("parked with nothing left to run\n{:?}", backtrace_message) + if let Some(waiting_hint) = test.waiting_hint() { + waiting_message = format!("\n waiting on: {}\n", waiting_hint); + } + panic!( + "parked with nothing left to run{waiting_message}{backtrace_message}", + ) } } @@ -354,6 +360,12 @@ impl BackgroundExecutor { self.dispatcher.as_test().unwrap().forbid_parking(); } + /// adds detail to the "parked with nothing let to run" message. + #[cfg(any(test, feature = "test-support"))] + pub fn set_waiting_hint(&self, msg: Option) { + self.dispatcher.as_test().unwrap().set_waiting_hint(msg); + } + /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable #[cfg(any(test, feature = "test-support"))] pub fn rng(&self) -> StdRng { diff --git a/crates/gpui/src/platform/test/dispatcher.rs b/crates/gpui/src/platform/test/dispatcher.rs index 51fc65bc97..24850187fc 100644 --- a/crates/gpui/src/platform/test/dispatcher.rs +++ b/crates/gpui/src/platform/test/dispatcher.rs @@ -36,6 +36,7 @@ struct TestDispatcherState { is_main_thread: bool, next_id: TestDispatcherId, allow_parking: bool, + waiting_hint: Option, waiting_backtrace: Option, deprioritized_task_labels: HashSet, block_on_ticks: RangeInclusive, @@ -54,6 +55,7 @@ impl TestDispatcher { is_main_thread: true, next_id: TestDispatcherId(1), allow_parking: false, + waiting_hint: None, waiting_backtrace: None, deprioritized_task_labels: Default::default(), block_on_ticks: 0..=1000, @@ -132,6 +134,14 @@ impl TestDispatcher { self.state.lock().allow_parking = false } + pub fn set_waiting_hint(&self, msg: Option) { + self.state.lock().waiting_hint = msg + } + + pub fn waiting_hint(&self) -> Option { + self.state.lock().waiting_hint.clone() + } + pub fn start_waiting(&self) { self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved()); } diff --git a/crates/gpui/src/platform/test/platform.rs b/crates/gpui/src/platform/test/platform.rs index 0d673b36aa..dde8cf1db6 100644 --- a/crates/gpui/src/platform/test/platform.rs +++ b/crates/gpui/src/platform/test/platform.rs @@ -69,6 +69,7 @@ impl TestPlatform { .multiple_choice .pop_front() .expect("no pending multiple choice prompt"); + self.background_executor().set_waiting_hint(None); tx.send(response_ix).ok(); } @@ -76,8 +77,10 @@ impl TestPlatform { !self.prompts.borrow().multiple_choice.is_empty() } - pub(crate) fn prompt(&self) -> oneshot::Receiver { + pub(crate) fn prompt(&self, msg: &str, detail: Option<&str>) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); + self.background_executor() + .set_waiting_hint(Some(format!("PROMPT: {:?} {:?}", msg, detail))); self.prompts.borrow_mut().multiple_choice.push_back(tx); rx } diff --git a/crates/gpui/src/platform/test/window.rs b/crates/gpui/src/platform/test/window.rs index 75c42dafcd..14c72251a3 100644 --- a/crates/gpui/src/platform/test/window.rs +++ b/crates/gpui/src/platform/test/window.rs @@ -159,8 +159,8 @@ impl PlatformWindow for TestWindow { fn prompt( &self, _level: crate::PromptLevel, - _msg: &str, - _detail: Option<&str>, + msg: &str, + detail: Option<&str>, _answers: &[&str], ) -> Option> { Some( @@ -169,7 +169,7 @@ impl PlatformWindow for TestWindow { .platform .upgrade() .expect("platform dropped") - .prompt(), + .prompt(msg, detail), ) }