fix rejoin after quit (#10100)

Release Notes:

- collab: Fixed rejoining channels quickly after a restart
This commit is contained in:
Conrad Irwin 2024-04-02 20:35:14 -06:00 committed by GitHub
parent 8958c9e10f
commit fe7b12c444
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 124 additions and 49 deletions

View file

@ -349,6 +349,17 @@ impl Database {
.await .await
} }
pub async fn stale_room_connection(&self, user_id: UserId) -> Result<Option<ConnectionId>> {
self.transaction(|tx| async move {
let participant = room_participant::Entity::find()
.filter(room_participant::Column::UserId.eq(user_id))
.one(&*tx)
.await?;
Ok(participant.and_then(|p| p.answering_connection()))
})
.await
}
async fn get_next_participant_index_internal( async fn get_next_participant_index_internal(
&self, &self,
room_id: RoomId, room_id: RoomId,
@ -403,7 +414,28 @@ impl Database {
.get_next_participant_index_internal(room_id, tx) .get_next_participant_index_internal(room_id, tx)
.await?; .await?;
room_participant::Entity::insert_many([room_participant::ActiveModel { // If someone has been invited into the room, accept the invite instead of inserting
let result = room_participant::Entity::update_many()
.filter(
Condition::all()
.add(room_participant::Column::RoomId.eq(room_id))
.add(room_participant::Column::UserId.eq(user_id))
.add(room_participant::Column::AnsweringConnectionId.is_null()),
)
.set(room_participant::ActiveModel {
participant_index: ActiveValue::Set(Some(participant_index)),
answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
answering_connection_server_id: ActiveValue::set(Some(ServerId(
connection.owner_id as i32,
))),
answering_connection_lost: ActiveValue::set(false),
..Default::default()
})
.exec(tx)
.await?;
if result.rows_affected == 0 {
room_participant::Entity::insert(room_participant::ActiveModel {
room_id: ActiveValue::set(room_id), room_id: ActiveValue::set(room_id),
user_id: ActiveValue::set(user_id), user_id: ActiveValue::set(user_id),
answering_connection_id: ActiveValue::set(Some(connection.id as i32)), answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
@ -422,20 +454,10 @@ impl Database {
location_kind: ActiveValue::NotSet, location_kind: ActiveValue::NotSet,
location_project_id: ActiveValue::NotSet, location_project_id: ActiveValue::NotSet,
initial_project_id: ActiveValue::NotSet, initial_project_id: ActiveValue::NotSet,
}]) })
.on_conflict(
OnConflict::columns([room_participant::Column::UserId])
.update_columns([
room_participant::Column::AnsweringConnectionId,
room_participant::Column::AnsweringConnectionServerId,
room_participant::Column::AnsweringConnectionLost,
room_participant::Column::ParticipantIndex,
room_participant::Column::Role,
])
.to_owned(),
)
.exec(tx) .exec(tx)
.await?; .await?;
}
let (channel, room) = self.get_channel_room(room_id, &tx).await?; let (channel, room) = self.get_channel_room(room_id, &tx).await?;
let channel = channel.ok_or_else(|| anyhow!("no channel for room"))?; let channel = channel.ok_or_else(|| anyhow!("no channel for room"))?;

View file

@ -1203,7 +1203,7 @@ async fn connection_lost(
_ = executor.sleep(RECONNECT_TIMEOUT).fuse() => { _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
if let Some(session) = session.for_user() { if let Some(session) = session.for_user() {
log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id); log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
leave_room_for_session(&session).await.trace_err(); leave_room_for_session(&session, session.connection_id).await.trace_err();
leave_channel_buffers_for_session(&session) leave_channel_buffers_for_session(&session)
.await .await
.trace_err(); .trace_err();
@ -1539,7 +1539,7 @@ async fn leave_room(
response: Response<proto::LeaveRoom>, response: Response<proto::LeaveRoom>,
session: UserSession, session: UserSession,
) -> Result<()> { ) -> Result<()> {
leave_room_for_session(&session).await?; leave_room_for_session(&session, session.connection_id).await?;
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
Ok(()) Ok(())
} }
@ -3023,8 +3023,19 @@ async fn join_channel_internal(
session: UserSession, session: UserSession,
) -> Result<()> { ) -> Result<()> {
let joined_room = { let joined_room = {
leave_room_for_session(&session).await?; let mut db = session.db().await;
let db = session.db().await; // If zed quits without leaving the room, and the user re-opens zed before the
// RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
// room they were in.
if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
tracing::info!(
stale_connection_id = %connection,
"cleaning up stale connection",
);
drop(db);
leave_room_for_session(&session, connection).await?;
db = session.db().await;
}
let (joined_room, membership_updated, role) = db let (joined_room, membership_updated, role) = db
.join_channel(channel_id, session.user_id(), session.connection_id) .join_channel(channel_id, session.user_id(), session.connection_id)
@ -4199,7 +4210,7 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()>
Ok(()) Ok(())
} }
async fn leave_room_for_session(session: &UserSession) -> Result<()> { async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
let mut contacts_to_update = HashSet::default(); let mut contacts_to_update = HashSet::default();
let room_id; let room_id;
@ -4209,7 +4220,7 @@ async fn leave_room_for_session(session: &UserSession) -> Result<()> {
let room; let room;
let channel; let channel;
if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? { if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
contacts_to_update.insert(session.user_id()); contacts_to_update.insert(session.user_id());
for project in left_room.left_projects.values() { for project in left_room.left_projects.values() {

View file

@ -2007,7 +2007,7 @@ async fn test_following_to_channel_notes_without_a_shared_project(
}); });
} }
async fn join_channel( pub(crate) async fn join_channel(
channel_id: ChannelId, channel_id: ChannelId,
client: &TestClient, client: &TestClient,
cx: &mut TestAppContext, cx: &mut TestAppContext,

View file

@ -1,6 +1,9 @@
use crate::{ use crate::{
rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
tests::{channel_id, room_participants, rust_lang, RoomParticipants, TestClient, TestServer}, tests::{
channel_id, following_tests::join_channel, room_participants, rust_lang, RoomParticipants,
TestClient, TestServer,
},
}; };
use call::{room, ActiveCall, ParticipantLocation, Room}; use call::{room, ActiveCall, ParticipantLocation, Room};
use client::{User, RECEIVE_TIMEOUT}; use client::{User, RECEIVE_TIMEOUT};
@ -5914,7 +5917,7 @@ async fn test_right_click_menu_behind_collab_panel(cx: &mut TestAppContext) {
#[gpui::test] #[gpui::test]
async fn test_cmd_k_left(cx: &mut TestAppContext) { async fn test_cmd_k_left(cx: &mut TestAppContext) {
let client = TestServer::start1(cx).await; let (_, client) = TestServer::start1(cx).await;
let (workspace, cx) = client.build_test_workspace(cx).await; let (workspace, cx) = client.build_test_workspace(cx).await;
cx.simulate_keystrokes("cmd-n"); cx.simulate_keystrokes("cmd-n");
@ -5934,3 +5937,16 @@ async fn test_cmd_k_left(cx: &mut TestAppContext) {
assert!(workspace.items(cx).collect::<Vec<_>>().len() == 2); assert!(workspace.items(cx).collect::<Vec<_>>().len() == 2);
}); });
} }
#[gpui::test]
async fn test_join_after_restart(cx1: &mut TestAppContext, cx2: &mut TestAppContext) {
let (mut server, client) = TestServer::start1(cx1).await;
let channel1 = server.make_public_channel("channel1", &client, cx1).await;
let channel2 = server.make_public_channel("channel2", &client, cx1).await;
join_channel(channel1, &client, cx1).await.unwrap();
drop(client);
let client2 = server.create_client(cx2, "user_a").await;
join_channel(channel2, &client2, cx2).await.unwrap();
}

View file

@ -135,9 +135,10 @@ impl TestServer {
(server, client_a, client_b, channel_id) (server, client_a, client_b, channel_id)
} }
pub async fn start1(cx: &mut TestAppContext) -> TestClient { pub async fn start1(cx: &mut TestAppContext) -> (TestServer, TestClient) {
let mut server = Self::start(cx.executor().clone()).await; let mut server = Self::start(cx.executor().clone()).await;
server.create_client(cx, "user_a").await let client = server.create_client(cx, "user_a").await;
(server, client)
} }
pub async fn reset(&self) { pub async fn reset(&self) {

View file

@ -219,11 +219,17 @@ impl BackgroundExecutor {
if let Some(test) = self.dispatcher.as_test() { if let Some(test) = self.dispatcher.as_test() {
if !test.parking_allowed() { if !test.parking_allowed() {
let mut backtrace_message = String::new(); let mut backtrace_message = String::new();
let mut waiting_message = String::new();
if let Some(backtrace) = test.waiting_backtrace() { if let Some(backtrace) = test.waiting_backtrace() {
backtrace_message = backtrace_message =
format!("\nbacktrace of waiting future:\n{:?}", backtrace); format!("\nbacktrace of waiting future:\n{:?}", backtrace);
} }
panic!("parked with nothing left to run\n{:?}", backtrace_message) if let Some(waiting_hint) = test.waiting_hint() {
waiting_message = format!("\n waiting on: {}\n", waiting_hint);
}
panic!(
"parked with nothing left to run{waiting_message}{backtrace_message}",
)
} }
} }
@ -354,6 +360,12 @@ impl BackgroundExecutor {
self.dispatcher.as_test().unwrap().forbid_parking(); self.dispatcher.as_test().unwrap().forbid_parking();
} }
/// adds detail to the "parked with nothing let to run" message.
#[cfg(any(test, feature = "test-support"))]
pub fn set_waiting_hint(&self, msg: Option<String>) {
self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
}
/// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
pub fn rng(&self) -> StdRng { pub fn rng(&self) -> StdRng {

View file

@ -36,6 +36,7 @@ struct TestDispatcherState {
is_main_thread: bool, is_main_thread: bool,
next_id: TestDispatcherId, next_id: TestDispatcherId,
allow_parking: bool, allow_parking: bool,
waiting_hint: Option<String>,
waiting_backtrace: Option<Backtrace>, waiting_backtrace: Option<Backtrace>,
deprioritized_task_labels: HashSet<TaskLabel>, deprioritized_task_labels: HashSet<TaskLabel>,
block_on_ticks: RangeInclusive<usize>, block_on_ticks: RangeInclusive<usize>,
@ -54,6 +55,7 @@ impl TestDispatcher {
is_main_thread: true, is_main_thread: true,
next_id: TestDispatcherId(1), next_id: TestDispatcherId(1),
allow_parking: false, allow_parking: false,
waiting_hint: None,
waiting_backtrace: None, waiting_backtrace: None,
deprioritized_task_labels: Default::default(), deprioritized_task_labels: Default::default(),
block_on_ticks: 0..=1000, block_on_ticks: 0..=1000,
@ -132,6 +134,14 @@ impl TestDispatcher {
self.state.lock().allow_parking = false self.state.lock().allow_parking = false
} }
pub fn set_waiting_hint(&self, msg: Option<String>) {
self.state.lock().waiting_hint = msg
}
pub fn waiting_hint(&self) -> Option<String> {
self.state.lock().waiting_hint.clone()
}
pub fn start_waiting(&self) { pub fn start_waiting(&self) {
self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved()); self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
} }

View file

@ -69,6 +69,7 @@ impl TestPlatform {
.multiple_choice .multiple_choice
.pop_front() .pop_front()
.expect("no pending multiple choice prompt"); .expect("no pending multiple choice prompt");
self.background_executor().set_waiting_hint(None);
tx.send(response_ix).ok(); tx.send(response_ix).ok();
} }
@ -76,8 +77,10 @@ impl TestPlatform {
!self.prompts.borrow().multiple_choice.is_empty() !self.prompts.borrow().multiple_choice.is_empty()
} }
pub(crate) fn prompt(&self) -> oneshot::Receiver<usize> { pub(crate) fn prompt(&self, msg: &str, detail: Option<&str>) -> oneshot::Receiver<usize> {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
self.background_executor()
.set_waiting_hint(Some(format!("PROMPT: {:?} {:?}", msg, detail)));
self.prompts.borrow_mut().multiple_choice.push_back(tx); self.prompts.borrow_mut().multiple_choice.push_back(tx);
rx rx
} }

View file

@ -159,8 +159,8 @@ impl PlatformWindow for TestWindow {
fn prompt( fn prompt(
&self, &self,
_level: crate::PromptLevel, _level: crate::PromptLevel,
_msg: &str, msg: &str,
_detail: Option<&str>, detail: Option<&str>,
_answers: &[&str], _answers: &[&str],
) -> Option<futures::channel::oneshot::Receiver<usize>> { ) -> Option<futures::channel::oneshot::Receiver<usize>> {
Some( Some(
@ -169,7 +169,7 @@ impl PlatformWindow for TestWindow {
.platform .platform
.upgrade() .upgrade()
.expect("platform dropped") .expect("platform dropped")
.prompt(), .prompt(msg, detail),
) )
} }