Return project collaborators and connection IDs in a RoomGuard

This commit is contained in:
Antonio Scandurra 2022-12-05 18:37:01 +01:00
parent be3fb1e985
commit 5443d9cffe
2 changed files with 57 additions and 44 deletions

View file

@ -1981,8 +1981,12 @@ impl Database {
&self, &self,
project_id: ProjectId, project_id: ProjectId,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> Result<Vec<project_collaborator::Model>> { ) -> Result<RoomGuard<Vec<project_collaborator::Model>>> {
self.transaction(|tx| async move { self.room_transaction(|tx| async move {
let project = project::Entity::find_by_id(project_id)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such project"))?;
let collaborators = project_collaborator::Entity::find() let collaborators = project_collaborator::Entity::find()
.filter(project_collaborator::Column::ProjectId.eq(project_id)) .filter(project_collaborator::Column::ProjectId.eq(project_id))
.all(&*tx) .all(&*tx)
@ -1992,7 +1996,7 @@ impl Database {
.iter() .iter()
.any(|collaborator| collaborator.connection_id == connection_id.0 as i32) .any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
{ {
Ok(collaborators) Ok((project.room_id, collaborators))
} else { } else {
Err(anyhow!("no such project"))? Err(anyhow!("no such project"))?
} }
@ -2004,13 +2008,17 @@ impl Database {
&self, &self,
project_id: ProjectId, project_id: ProjectId,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> Result<HashSet<ConnectionId>> { ) -> Result<RoomGuard<HashSet<ConnectionId>>> {
self.transaction(|tx| async move { self.room_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryAs { enum QueryAs {
ConnectionId, ConnectionId,
} }
let project = project::Entity::find_by_id(project_id)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such project"))?;
let mut db_connection_ids = project_collaborator::Entity::find() let mut db_connection_ids = project_collaborator::Entity::find()
.select_only() .select_only()
.column_as( .column_as(
@ -2028,7 +2036,7 @@ impl Database {
} }
if connection_ids.contains(&connection_id) { if connection_ids.contains(&connection_id) {
Ok(connection_ids) Ok((project.room_id, connection_ids))
} else { } else {
Err(anyhow!("no such project"))? Err(anyhow!("no such project"))?
} }

View file

@ -1245,7 +1245,7 @@ async fn update_language_server(
.await?; .await?;
broadcast( broadcast(
session.connection_id, session.connection_id,
project_connection_ids, project_connection_ids.iter().copied(),
|connection_id| { |connection_id| {
session session
.peer .peer
@ -1264,23 +1264,24 @@ where
T: EntityMessage + RequestMessage, T: EntityMessage + RequestMessage,
{ {
let project_id = ProjectId::from_proto(request.remote_entity_id()); let project_id = ProjectId::from_proto(request.remote_entity_id());
let collaborators = session let host_connection_id = {
.db() let collaborators = session
.await .db()
.project_collaborators(project_id, session.connection_id) .await
.await?; .project_collaborators(project_id, session.connection_id)
let host = collaborators .await?;
.iter() ConnectionId(
.find(|collaborator| collaborator.is_host) collaborators
.ok_or_else(|| anyhow!("host not found"))?; .iter()
.find(|collaborator| collaborator.is_host)
.ok_or_else(|| anyhow!("host not found"))?
.connection_id as u32,
)
};
let payload = session let payload = session
.peer .peer
.forward_request( .forward_request(session.connection_id, host_connection_id, request)
session.connection_id,
ConnectionId(host.connection_id as u32),
request,
)
.await?; .await?;
response.send(payload)?; response.send(payload)?;
@ -1293,16 +1294,18 @@ async fn save_buffer(
session: Session, session: Session,
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id); let project_id = ProjectId::from_proto(request.project_id);
let collaborators = session let host_connection_id = {
.db() let collaborators = session
.await .db()
.project_collaborators(project_id, session.connection_id) .await
.await?; .project_collaborators(project_id, session.connection_id)
let host = collaborators .await?;
.into_iter() let host = collaborators
.find(|collaborator| collaborator.is_host) .iter()
.ok_or_else(|| anyhow!("host not found"))?; .find(|collaborator| collaborator.is_host)
let host_connection_id = ConnectionId(host.connection_id as u32); .ok_or_else(|| anyhow!("host not found"))?;
ConnectionId(host.connection_id as u32)
};
let response_payload = session let response_payload = session
.peer .peer
.forward_request(session.connection_id, host_connection_id, request.clone()) .forward_request(session.connection_id, host_connection_id, request.clone())
@ -1316,7 +1319,7 @@ async fn save_buffer(
collaborators collaborators
.retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32); .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
let project_connection_ids = collaborators let project_connection_ids = collaborators
.into_iter() .iter()
.map(|collaborator| ConnectionId(collaborator.connection_id as u32)); .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
broadcast(host_connection_id, project_connection_ids, |conn_id| { broadcast(host_connection_id, project_connection_ids, |conn_id| {
session session
@ -1353,7 +1356,7 @@ async fn update_buffer(
broadcast( broadcast(
session.connection_id, session.connection_id,
project_connection_ids, project_connection_ids.iter().copied(),
|connection_id| { |connection_id| {
session session
.peer .peer
@ -1374,7 +1377,7 @@ async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session)
broadcast( broadcast(
session.connection_id, session.connection_id,
project_connection_ids, project_connection_ids.iter().copied(),
|connection_id| { |connection_id| {
session session
.peer .peer
@ -1393,7 +1396,7 @@ async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Re
.await?; .await?;
broadcast( broadcast(
session.connection_id, session.connection_id,
project_connection_ids, project_connection_ids.iter().copied(),
|connection_id| { |connection_id| {
session session
.peer .peer
@ -1412,7 +1415,7 @@ async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<(
.await?; .await?;
broadcast( broadcast(
session.connection_id, session.connection_id,
project_connection_ids, project_connection_ids.iter().copied(),
|connection_id| { |connection_id| {
session session
.peer .peer
@ -1430,14 +1433,16 @@ async fn follow(
let project_id = ProjectId::from_proto(request.project_id); let project_id = ProjectId::from_proto(request.project_id);
let leader_id = ConnectionId(request.leader_id); let leader_id = ConnectionId(request.leader_id);
let follower_id = session.connection_id; let follower_id = session.connection_id;
let project_connection_ids = session {
.db() let project_connection_ids = session
.await .db()
.project_connection_ids(project_id, session.connection_id) .await
.await?; .project_connection_ids(project_id, session.connection_id)
.await?;
if !project_connection_ids.contains(&leader_id) { if !project_connection_ids.contains(&leader_id) {
Err(anyhow!("no such peer"))?; Err(anyhow!("no such peer"))?;
}
} }
let mut response_payload = session let mut response_payload = session
@ -1691,7 +1696,7 @@ async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> R
.await?; .await?;
broadcast( broadcast(
session.connection_id, session.connection_id,
project_connection_ids, project_connection_ids.iter().copied(),
|connection_id| { |connection_id| {
session session
.peer .peer