diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 8250a8354f..915acb00eb 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1981,8 +1981,12 @@ impl Database { &self, project_id: ProjectId, connection_id: ConnectionId, - ) -> Result> { - self.transaction(|tx| async move { + ) -> Result>> { + self.room_transaction(|tx| async move { + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; let collaborators = project_collaborator::Entity::find() .filter(project_collaborator::Column::ProjectId.eq(project_id)) .all(&*tx) @@ -1992,7 +1996,7 @@ impl Database { .iter() .any(|collaborator| collaborator.connection_id == connection_id.0 as i32) { - Ok(collaborators) + Ok((project.room_id, collaborators)) } else { Err(anyhow!("no such project"))? } @@ -2004,13 +2008,17 @@ impl Database { &self, project_id: ProjectId, connection_id: ConnectionId, - ) -> Result> { - self.transaction(|tx| async move { + ) -> Result>> { + self.room_transaction(|tx| async move { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryAs { ConnectionId, } + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; let mut db_connection_ids = project_collaborator::Entity::find() .select_only() .column_as( @@ -2028,7 +2036,7 @@ impl Database { } if connection_ids.contains(&connection_id) { - Ok(connection_ids) + Ok((project.room_id, connection_ids)) } else { Err(anyhow!("no such project"))? } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 7f404feffe..79544de6fb 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1245,7 +1245,7 @@ async fn update_language_server( .await?; broadcast( session.connection_id, - project_connection_ids, + project_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1264,23 +1264,24 @@ where T: EntityMessage + RequestMessage, { let project_id = ProjectId::from_proto(request.remote_entity_id()); - let collaborators = session - .db() - .await - .project_collaborators(project_id, session.connection_id) - .await?; - let host = collaborators - .iter() - .find(|collaborator| collaborator.is_host) - .ok_or_else(|| anyhow!("host not found"))?; + let host_connection_id = { + let collaborators = session + .db() + .await + .project_collaborators(project_id, session.connection_id) + .await?; + ConnectionId( + collaborators + .iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))? + .connection_id as u32, + ) + }; let payload = session .peer - .forward_request( - session.connection_id, - ConnectionId(host.connection_id as u32), - request, - ) + .forward_request(session.connection_id, host_connection_id, request) .await?; response.send(payload)?; @@ -1293,16 +1294,18 @@ async fn save_buffer( session: Session, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); - let collaborators = session - .db() - .await - .project_collaborators(project_id, session.connection_id) - .await?; - let host = collaborators - .into_iter() - .find(|collaborator| collaborator.is_host) - .ok_or_else(|| anyhow!("host not found"))?; - let host_connection_id = ConnectionId(host.connection_id as u32); + let host_connection_id = { + let collaborators = session + .db() + .await + .project_collaborators(project_id, session.connection_id) + .await?; + let host = collaborators + .iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))?; + ConnectionId(host.connection_id as u32) + }; let response_payload = session .peer .forward_request(session.connection_id, host_connection_id, request.clone()) @@ -1316,7 +1319,7 @@ async fn save_buffer( collaborators .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32); let project_connection_ids = collaborators - .into_iter() + .iter() .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); broadcast(host_connection_id, project_connection_ids, |conn_id| { session @@ -1353,7 +1356,7 @@ async fn update_buffer( broadcast( session.connection_id, - project_connection_ids, + project_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1374,7 +1377,7 @@ async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) broadcast( session.connection_id, - project_connection_ids, + project_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1393,7 +1396,7 @@ async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Re .await?; broadcast( session.connection_id, - project_connection_ids, + project_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1412,7 +1415,7 @@ async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<( .await?; broadcast( session.connection_id, - project_connection_ids, + project_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1430,14 +1433,16 @@ async fn follow( let project_id = ProjectId::from_proto(request.project_id); let leader_id = ConnectionId(request.leader_id); let follower_id = session.connection_id; - let project_connection_ids = session - .db() - .await - .project_connection_ids(project_id, session.connection_id) - .await?; + { + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; - if !project_connection_ids.contains(&leader_id) { - Err(anyhow!("no such peer"))?; + if !project_connection_ids.contains(&leader_id) { + Err(anyhow!("no such peer"))?; + } } let mut response_payload = session @@ -1691,7 +1696,7 @@ async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> R .await?; broadcast( session.connection_id, - project_connection_ids, + project_connection_ids.iter().copied(), |connection_id| { session .peer