diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 44cc382ee0..78b6547ef2 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1171,44 +1171,68 @@ where .fetch_all(&mut tx) .await?; - let mut project_collaborators = sqlx::query_as::<_, ProjectCollaborator>( + let project_ids = sqlx::query_scalar::<_, ProjectId>( " - SELECT project_collaborators.* - FROM projects, project_collaborators - WHERE - projects.room_id = $1 AND - projects.id = project_collaborators.project_id AND - project_collaborators.connection_id = $2 + SELECT project_id + FROM project_collaborators + WHERE connection_id = $1 ", ) - .bind(room_id) .bind(connection_id.0 as i32) - .fetch(&mut tx); + .fetch_all(&mut tx) + .await?; + // Leave projects. let mut left_projects = HashMap::default(); - while let Some(collaborator) = project_collaborators.next().await { - let collaborator = collaborator?; - let left_project = - left_projects - .entry(collaborator.project_id) - .or_insert(LeftProject { - id: collaborator.project_id, - host_user_id: Default::default(), - connection_ids: Default::default(), - }); - - let collaborator_connection_id = - ConnectionId(collaborator.connection_id as u32); - if collaborator_connection_id != connection_id || collaborator.is_host { - left_project.connection_ids.push(collaborator_connection_id); + if !project_ids.is_empty() { + let mut params = "?,".repeat(project_ids.len()); + params.pop(); + let query = format!( + " + SELECT * + FROM project_collaborators + WHERE project_id IN ({params}) + " + ); + let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query); + for project_id in project_ids { + query = query.bind(project_id); } - if collaborator.is_host { - left_project.host_user_id = collaborator.user_id; + let mut project_collaborators = query.fetch(&mut tx); + while let Some(collaborator) = project_collaborators.next().await { + let collaborator = collaborator?; + let left_project = + left_projects + .entry(collaborator.project_id) + .or_insert(LeftProject { + id: collaborator.project_id, + host_user_id: Default::default(), + connection_ids: Default::default(), + }); + + let collaborator_connection_id = + ConnectionId(collaborator.connection_id as u32); + if collaborator_connection_id != connection_id { + left_project.connection_ids.push(collaborator_connection_id); + } + + if collaborator.is_host { + left_project.host_user_id = collaborator.user_id; + } } } - drop(project_collaborators); + sqlx::query( + " + DELETE FROM project_collaborators + WHERE connection_id = $1 + ", + ) + .bind(connection_id.0 as i32) + .execute(&mut tx) + .await?; + // Unshare projects. sqlx::query( " DELETE FROM projects @@ -1265,15 +1289,16 @@ where sqlx::query( " UPDATE room_participants - SET location_kind = $1 AND location_project_id = $2 + SET location_kind = $1, location_project_id = $2 WHERE room_id = $3 AND answering_connection_id = $4 + RETURNING 1 ", ) .bind(location_kind) .bind(location_project_id) .bind(room_id) .bind(connection_id.0 as i32) - .execute(&mut tx) + .fetch_one(&mut tx) .await?; self.commit_room_transaction(room_id, tx).await @@ -1335,21 +1360,32 @@ where let ( user_id, answering_connection_id, - _location_kind, - _location_project_id, + location_kind, + location_project_id, calling_user_id, initial_project_id, ) = participant?; if let Some(answering_connection_id) = answering_connection_id { + let location = match (location_kind, location_project_id) { + (Some(0), Some(project_id)) => { + Some(proto::participant_location::Variant::SharedProject( + proto::participant_location::SharedProject { + id: project_id.to_proto(), + }, + )) + } + (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject( + Default::default(), + )), + _ => Some(proto::participant_location::Variant::External( + Default::default(), + )), + }; participants.push(proto::Participant { user_id: user_id.to_proto(), peer_id: answering_connection_id as u32, projects: Default::default(), - location: Some(proto::ParticipantLocation { - variant: Some(proto::participant_location::Variant::External( - Default::default(), - )), - }), + location: Some(proto::ParticipantLocation { variant: location }), }); } else { pending_participants.push(proto::PendingParticipant { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index f0116f04f9..9f7d21a1a9 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -624,19 +624,19 @@ impl Server { async fn leave_room_for_connection( self: &Arc, - connection_id: ConnectionId, - user_id: UserId, + leaving_connection_id: ConnectionId, + leaving_user_id: UserId, ) -> Result<()> { let mut contacts_to_update = HashSet::default(); - let Some(left_room) = self.app_state.db.leave_room_for_connection(connection_id).await? else { + let Some(left_room) = self.app_state.db.leave_room_for_connection(leaving_connection_id).await? else { return Err(anyhow!("no room to leave"))?; }; - contacts_to_update.insert(user_id); + contacts_to_update.insert(leaving_user_id); for project in left_room.left_projects.into_values() { - if project.host_user_id == user_id { - for connection_id in project.connection_ids { + for connection_id in project.connection_ids { + if project.host_user_id == leaving_user_id { self.peer .send( connection_id, @@ -645,29 +645,27 @@ impl Server { }, ) .trace_err(); - } - } else { - for connection_id in project.connection_ids { + } else { self.peer .send( connection_id, proto::RemoveProjectCollaborator { project_id: project.id.to_proto(), - peer_id: connection_id.0, + peer_id: leaving_connection_id.0, }, ) .trace_err(); } - - self.peer - .send( - connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - ) - .trace_err(); } + + self.peer + .send( + leaving_connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); } self.room_updated(&left_room.room); @@ -691,7 +689,7 @@ impl Server { live_kit .remove_participant( left_room.room.live_kit_room.clone(), - connection_id.to_string(), + leaving_connection_id.to_string(), ) .await .trace_err(); @@ -941,6 +939,9 @@ impl Server { let collaborators = project .collaborators .iter() + .filter(|collaborator| { + collaborator.connection_id != request.sender_connection_id.0 as i32 + }) .map(|collaborator| proto::Collaborator { peer_id: collaborator.connection_id as u32, replica_id: collaborator.replica_id.0 as u32, @@ -958,23 +959,20 @@ impl Server { }) .collect::>(); - for collaborator in &project.collaborators { - let connection_id = ConnectionId(collaborator.connection_id as u32); - if connection_id != request.sender_connection_id { - self.peer - .send( - connection_id, - proto::AddProjectCollaborator { - project_id: project_id.to_proto(), - collaborator: Some(proto::Collaborator { - peer_id: request.sender_connection_id.0, - replica_id: replica_id.0 as u32, - user_id: guest_user_id.to_proto(), - }), - }, - ) - .trace_err(); - } + for collaborator in &collaborators { + self.peer + .send( + ConnectionId(collaborator.peer_id), + proto::AddProjectCollaborator { + project_id: project_id.to_proto(), + collaborator: Some(proto::Collaborator { + peer_id: request.sender_connection_id.0, + replica_id: replica_id.0 as u32, + user_id: guest_user_id.to_proto(), + }), + }, + ) + .trace_err(); } // First, we send the metadata associated with each worktree.