Introduce per-room lock acquired before committing a transaction

This commit is contained in:
Antonio Scandurra 2022-11-28 17:00:47 +01:00
parent 2a0ddd99d2
commit cd0b663f62
4 changed files with 327 additions and 223 deletions

14
Cargo.lock generated
View file

@ -1041,6 +1041,7 @@ dependencies = [
"client", "client",
"collections", "collections",
"ctor", "ctor",
"dashmap",
"editor", "editor",
"env_logger", "env_logger",
"envy", "envy",
@ -1536,6 +1537,19 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "dashmap"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
dependencies = [
"cfg-if 1.0.0",
"hashbrown 0.12.3",
"lock_api",
"once_cell",
"parking_lot_core 0.9.4",
]
[[package]] [[package]]
name = "data-url" name = "data-url"
version = "0.1.1" version = "0.1.1"

View file

@ -24,6 +24,7 @@ axum = { version = "0.5", features = ["json", "headers", "ws"] }
axum-extra = { version = "0.3", features = ["erased-json"] } axum-extra = { version = "0.3", features = ["erased-json"] }
base64 = "0.13" base64 = "0.13"
clap = { version = "3.1", features = ["derive"], optional = true } clap = { version = "3.1", features = ["derive"], optional = true }
dashmap = "5.4"
envy = "0.4.2" envy = "0.4.2"
futures = "0.3" futures = "0.3"
hyper = "0.14" hyper = "0.14"

View file

@ -2,6 +2,7 @@ use crate::{Error, Result};
use anyhow::anyhow; use anyhow::anyhow;
use axum::http::StatusCode; use axum::http::StatusCode;
use collections::{BTreeMap, HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use dashmap::DashMap;
use futures::{future::BoxFuture, FutureExt, StreamExt}; use futures::{future::BoxFuture, FutureExt, StreamExt};
use rpc::{proto, ConnectionId}; use rpc::{proto, ConnectionId};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -10,8 +11,17 @@ use sqlx::{
types::Uuid, types::Uuid,
FromRow, FromRow,
}; };
use std::{future::Future, path::Path, time::Duration}; use std::{
future::Future,
marker::PhantomData,
ops::{Deref, DerefMut},
path::Path,
rc::Rc,
sync::Arc,
time::Duration,
};
use time::{OffsetDateTime, PrimitiveDateTime}; use time::{OffsetDateTime, PrimitiveDateTime};
use tokio::sync::{Mutex, OwnedMutexGuard};
#[cfg(test)] #[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>; pub type DefaultDb = Db<sqlx::Sqlite>;
@ -21,12 +31,33 @@ pub type DefaultDb = Db<sqlx::Postgres>;
pub struct Db<D: sqlx::Database> { pub struct Db<D: sqlx::Database> {
pool: sqlx::Pool<D>, pool: sqlx::Pool<D>,
rooms: DashMap<RoomId, Arc<Mutex<()>>>,
#[cfg(test)] #[cfg(test)]
background: Option<std::sync::Arc<gpui::executor::Background>>, background: Option<std::sync::Arc<gpui::executor::Background>>,
#[cfg(test)] #[cfg(test)]
runtime: Option<tokio::runtime::Runtime>, runtime: Option<tokio::runtime::Runtime>,
} }
pub struct RoomGuard<T> {
data: T,
_guard: OwnedMutexGuard<()>,
_not_send: PhantomData<Rc<()>>,
}
impl<T> Deref for RoomGuard<T> {
type Target = T;
fn deref(&self) -> &T {
&self.data
}
}
impl<T> DerefMut for RoomGuard<T> {
fn deref_mut(&mut self) -> &mut T {
&mut self.data
}
}
pub trait BeginTransaction: Send + Sync { pub trait BeginTransaction: Send + Sync {
type Database: sqlx::Database; type Database: sqlx::Database;
@ -90,6 +121,7 @@ impl Db<sqlx::Sqlite> {
.await?; .await?;
Ok(Self { Ok(Self {
pool, pool,
rooms: Default::default(),
background: None, background: None,
runtime: None, runtime: None,
}) })
@ -197,6 +229,7 @@ impl Db<sqlx::Postgres> {
.await?; .await?;
Ok(Self { Ok(Self {
pool, pool,
rooms: DashMap::with_capacity(16384),
#[cfg(test)] #[cfg(test)]
background: None, background: None,
#[cfg(test)] #[cfg(test)]
@ -922,13 +955,29 @@ where
.await .await
} }
async fn commit_room_transaction<'a, T>(
&'a self,
room_id: RoomId,
tx: sqlx::Transaction<'static, D>,
data: T,
) -> Result<RoomGuard<T>> {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
tx.commit().await?;
Ok(RoomGuard {
data,
_guard,
_not_send: PhantomData,
})
}
pub async fn create_room( pub async fn create_room(
&self, &self,
user_id: UserId, user_id: UserId,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> Result<proto::Room> { live_kit_room: &str,
) -> Result<RoomGuard<proto::Room>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
let live_kit_room = nanoid::nanoid!(30);
let room_id = sqlx::query_scalar( let room_id = sqlx::query_scalar(
" "
INSERT INTO rooms (live_kit_room) INSERT INTO rooms (live_kit_room)
@ -956,8 +1005,7 @@ where
.await?; .await?;
let room = self.get_room(room_id, &mut tx).await?; let room = self.get_room(room_id, &mut tx).await?;
tx.commit().await?; self.commit_room_transaction(room_id, tx, room).await
Ok(room)
}).await }).await
} }
@ -968,11 +1016,17 @@ where
calling_connection_id: ConnectionId, calling_connection_id: ConnectionId,
called_user_id: UserId, called_user_id: UserId,
initial_project_id: Option<ProjectId>, initial_project_id: Option<ProjectId>,
) -> Result<(proto::Room, proto::IncomingCall)> { ) -> Result<RoomGuard<(proto::Room, proto::IncomingCall)>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
sqlx::query( sqlx::query(
" "
INSERT INTO room_participants (room_id, user_id, calling_user_id, calling_connection_id, initial_project_id) INSERT INTO room_participants (
room_id,
user_id,
calling_user_id,
calling_connection_id,
initial_project_id
)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
", ",
) )
@ -985,12 +1039,12 @@ where
.await?; .await?;
let room = self.get_room(room_id, &mut tx).await?; let room = self.get_room(room_id, &mut tx).await?;
tx.commit().await?;
let incoming_call = Self::build_incoming_call(&room, called_user_id) let incoming_call = Self::build_incoming_call(&room, called_user_id)
.ok_or_else(|| anyhow!("failed to build incoming call"))?; .ok_or_else(|| anyhow!("failed to build incoming call"))?;
Ok((room, incoming_call)) self.commit_room_transaction(room_id, tx, (room, incoming_call))
}).await .await
})
.await
} }
pub async fn incoming_call_for_user( pub async fn incoming_call_for_user(
@ -1051,7 +1105,7 @@ where
&self, &self,
room_id: RoomId, room_id: RoomId,
called_user_id: UserId, called_user_id: UserId,
) -> Result<proto::Room> { ) -> Result<RoomGuard<proto::Room>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
sqlx::query( sqlx::query(
" "
@ -1065,8 +1119,7 @@ where
.await?; .await?;
let room = self.get_room(room_id, &mut tx).await?; let room = self.get_room(room_id, &mut tx).await?;
tx.commit().await?; self.commit_room_transaction(room_id, tx, room).await
Ok(room)
}) })
.await .await
} }
@ -1075,7 +1128,7 @@ where
&self, &self,
expected_room_id: Option<RoomId>, expected_room_id: Option<RoomId>,
user_id: UserId, user_id: UserId,
) -> Result<proto::Room> { ) -> Result<RoomGuard<proto::Room>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
let room_id = sqlx::query_scalar( let room_id = sqlx::query_scalar(
" "
@ -1092,8 +1145,7 @@ where
} }
let room = self.get_room(room_id, &mut tx).await?; let room = self.get_room(room_id, &mut tx).await?;
tx.commit().await?; self.commit_room_transaction(room_id, tx, room).await
Ok(room)
}) })
.await .await
} }
@ -1103,7 +1155,7 @@ where
expected_room_id: Option<RoomId>, expected_room_id: Option<RoomId>,
calling_connection_id: ConnectionId, calling_connection_id: ConnectionId,
called_user_id: UserId, called_user_id: UserId,
) -> Result<proto::Room> { ) -> Result<RoomGuard<proto::Room>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
let room_id = sqlx::query_scalar( let room_id = sqlx::query_scalar(
" "
@ -1121,8 +1173,7 @@ where
} }
let room = self.get_room(room_id, &mut tx).await?; let room = self.get_room(room_id, &mut tx).await?;
tx.commit().await?; self.commit_room_transaction(room_id, tx, room).await
Ok(room)
}).await }).await
} }
@ -1131,7 +1182,7 @@ where
room_id: RoomId, room_id: RoomId,
user_id: UserId, user_id: UserId,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> Result<proto::Room> { ) -> Result<RoomGuard<proto::Room>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
sqlx::query( sqlx::query(
" "
@ -1148,13 +1199,15 @@ where
.await?; .await?;
let room = self.get_room(room_id, &mut tx).await?; let room = self.get_room(room_id, &mut tx).await?;
tx.commit().await?; self.commit_room_transaction(room_id, tx, room).await
Ok(room)
}) })
.await .await
} }
pub async fn leave_room(&self, connection_id: ConnectionId) -> Result<Option<LeftRoom>> { pub async fn leave_room(
&self,
connection_id: ConnectionId,
) -> Result<Option<RoomGuard<LeftRoom>>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
// Leave room. // Leave room.
let room_id = sqlx::query_scalar::<_, RoomId>( let room_id = sqlx::query_scalar::<_, RoomId>(
@ -1258,13 +1311,18 @@ where
.await?; .await?;
let room = self.get_room(room_id, &mut tx).await?; let room = self.get_room(room_id, &mut tx).await?;
tx.commit().await?; Ok(Some(
self.commit_room_transaction(
Ok(Some(LeftRoom { room_id,
tx,
LeftRoom {
room, room,
left_projects, left_projects,
canceled_calls_to_user_ids, canceled_calls_to_user_ids,
})) },
)
.await?,
))
} else { } else {
Ok(None) Ok(None)
} }
@ -1277,7 +1335,7 @@ where
room_id: RoomId, room_id: RoomId,
connection_id: ConnectionId, connection_id: ConnectionId,
location: proto::ParticipantLocation, location: proto::ParticipantLocation,
) -> Result<proto::Room> { ) -> Result<RoomGuard<proto::Room>> {
self.transact(|tx| async { self.transact(|tx| async {
let mut tx = tx; let mut tx = tx;
let location_kind; let location_kind;
@ -1317,8 +1375,7 @@ where
.await?; .await?;
let room = self.get_room(room_id, &mut tx).await?; let room = self.get_room(room_id, &mut tx).await?;
tx.commit().await?; self.commit_room_transaction(room_id, tx, room).await
Ok(room)
}) })
.await .await
} }
@ -1478,7 +1535,7 @@ where
expected_room_id: RoomId, expected_room_id: RoomId,
connection_id: ConnectionId, connection_id: ConnectionId,
worktrees: &[proto::WorktreeMetadata], worktrees: &[proto::WorktreeMetadata],
) -> Result<(ProjectId, proto::Room)> { ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
" "
@ -1560,9 +1617,8 @@ where
.await?; .await?;
let room = self.get_room(room_id, &mut tx).await?; let room = self.get_room(room_id, &mut tx).await?;
tx.commit().await?; self.commit_room_transaction(room_id, tx, (project_id, room))
.await
Ok((project_id, room))
}) })
.await .await
} }
@ -1571,7 +1627,7 @@ where
&self, &self,
project_id: ProjectId, project_id: ProjectId,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> Result<(proto::Room, Vec<ConnectionId>)> { ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
let room_id: RoomId = sqlx::query_scalar( let room_id: RoomId = sqlx::query_scalar(
@ -1586,9 +1642,8 @@ where
.fetch_one(&mut tx) .fetch_one(&mut tx)
.await?; .await?;
let room = self.get_room(room_id, &mut tx).await?; let room = self.get_room(room_id, &mut tx).await?;
tx.commit().await?; self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
.await
Ok((room, guest_connection_ids))
}) })
.await .await
} }
@ -1598,7 +1653,7 @@ where
project_id: ProjectId, project_id: ProjectId,
connection_id: ConnectionId, connection_id: ConnectionId,
worktrees: &[proto::WorktreeMetadata], worktrees: &[proto::WorktreeMetadata],
) -> Result<(proto::Room, Vec<ConnectionId>)> { ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
let room_id: RoomId = sqlx::query_scalar( let room_id: RoomId = sqlx::query_scalar(
" "
@ -1664,9 +1719,8 @@ where
let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
let room = self.get_room(room_id, &mut tx).await?; let room = self.get_room(room_id, &mut tx).await?;
tx.commit().await?; self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
.await
Ok((room, guest_connection_ids))
}) })
.await .await
} }
@ -1675,15 +1729,15 @@ where
&self, &self,
update: &proto::UpdateWorktree, update: &proto::UpdateWorktree,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> Result<Vec<ConnectionId>> { ) -> Result<RoomGuard<Vec<ConnectionId>>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
let project_id = ProjectId::from_proto(update.project_id); let project_id = ProjectId::from_proto(update.project_id);
let worktree_id = WorktreeId::from_proto(update.worktree_id); let worktree_id = WorktreeId::from_proto(update.worktree_id);
// Ensure the update comes from the host. // Ensure the update comes from the host.
sqlx::query( let room_id: RoomId = sqlx::query_scalar(
" "
SELECT 1 SELECT room_id
FROM projects FROM projects
WHERE id = $1 AND host_connection_id = $2 WHERE id = $1 AND host_connection_id = $2
", ",
@ -1781,8 +1835,8 @@ where
} }
let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
tx.commit().await?; self.commit_room_transaction(room_id, tx, connection_ids)
Ok(connection_ids) .await
}) })
.await .await
} }
@ -1791,7 +1845,7 @@ where
&self, &self,
update: &proto::UpdateDiagnosticSummary, update: &proto::UpdateDiagnosticSummary,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> Result<Vec<ConnectionId>> { ) -> Result<RoomGuard<Vec<ConnectionId>>> {
self.transact(|mut tx| async { self.transact(|mut tx| async {
let project_id = ProjectId::from_proto(update.project_id); let project_id = ProjectId::from_proto(update.project_id);
let worktree_id = WorktreeId::from_proto(update.worktree_id); let worktree_id = WorktreeId::from_proto(update.worktree_id);
@ -1801,9 +1855,9 @@ where
.ok_or_else(|| anyhow!("invalid summary"))?; .ok_or_else(|| anyhow!("invalid summary"))?;
// Ensure the update comes from the host. // Ensure the update comes from the host.
sqlx::query( let room_id: RoomId = sqlx::query_scalar(
" "
SELECT 1 SELECT room_id
FROM projects FROM projects
WHERE id = $1 AND host_connection_id = $2 WHERE id = $1 AND host_connection_id = $2
", ",
@ -1841,8 +1895,8 @@ where
.await?; .await?;
let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
tx.commit().await?; self.commit_room_transaction(room_id, tx, connection_ids)
Ok(connection_ids) .await
}) })
.await .await
} }
@ -1851,7 +1905,7 @@ where
&self, &self,
update: &proto::StartLanguageServer, update: &proto::StartLanguageServer,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> Result<Vec<ConnectionId>> { ) -> Result<RoomGuard<Vec<ConnectionId>>> {
self.transact(|mut tx| async { self.transact(|mut tx| async {
let project_id = ProjectId::from_proto(update.project_id); let project_id = ProjectId::from_proto(update.project_id);
let server = update let server = update
@ -1860,9 +1914,9 @@ where
.ok_or_else(|| anyhow!("invalid language server"))?; .ok_or_else(|| anyhow!("invalid language server"))?;
// Ensure the update comes from the host. // Ensure the update comes from the host.
sqlx::query( let room_id: RoomId = sqlx::query_scalar(
" "
SELECT 1 SELECT room_id
FROM projects FROM projects
WHERE id = $1 AND host_connection_id = $2 WHERE id = $1 AND host_connection_id = $2
", ",
@ -1888,8 +1942,8 @@ where
.await?; .await?;
let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
tx.commit().await?; self.commit_room_transaction(room_id, tx, connection_ids)
Ok(connection_ids) .await
}) })
.await .await
} }
@ -1898,7 +1952,7 @@ where
&self, &self,
project_id: ProjectId, project_id: ProjectId,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> Result<(Project, ReplicaId)> { ) -> Result<RoomGuard<(Project, ReplicaId)>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
" "
@ -2068,8 +2122,10 @@ where
.fetch_all(&mut tx) .fetch_all(&mut tx)
.await?; .await?;
tx.commit().await?; self.commit_room_transaction(
Ok(( room_id,
tx,
(
Project { Project {
collaborators, collaborators,
worktrees, worktrees,
@ -2082,7 +2138,9 @@ where
.collect(), .collect(),
}, },
replica_id as ReplicaId, replica_id as ReplicaId,
)) ),
)
.await
}) })
.await .await
} }
@ -2091,7 +2149,7 @@ where
&self, &self,
project_id: ProjectId, project_id: ProjectId,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> Result<LeftProject> { ) -> Result<RoomGuard<LeftProject>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
let result = sqlx::query( let result = sqlx::query(
" "
@ -2122,9 +2180,10 @@ where
.map(|id| ConnectionId(id as u32)) .map(|id| ConnectionId(id as u32))
.collect(); .collect();
let (host_user_id, host_connection_id) = sqlx::query_as::<_, (i32, i32)>( let (room_id, host_user_id, host_connection_id) =
sqlx::query_as::<_, (RoomId, i32, i32)>(
" "
SELECT host_user_id, host_connection_id SELECT room_id, host_user_id, host_connection_id
FROM projects FROM projects
WHERE id = $1 WHERE id = $1
", ",
@ -2133,14 +2192,17 @@ where
.fetch_one(&mut tx) .fetch_one(&mut tx)
.await?; .await?;
tx.commit().await?; self.commit_room_transaction(
room_id,
Ok(LeftProject { tx,
LeftProject {
id: project_id, id: project_id,
host_user_id: UserId(host_user_id), host_user_id: UserId(host_user_id),
host_connection_id: ConnectionId(host_connection_id as u32), host_connection_id: ConnectionId(host_connection_id as u32),
connection_ids, connection_ids,
}) },
)
.await
}) })
.await .await
} }
@ -2538,9 +2600,9 @@ where
let result = self.runtime.as_ref().unwrap().block_on(body); let result = self.runtime.as_ref().unwrap().block_on(body);
if let Some(background) = self.background.as_ref() { // if let Some(background) = self.background.as_ref() {
background.simulate_random_delay().await; // background.simulate_random_delay().await;
} // }
result result
} }

View file

@ -42,6 +42,7 @@ use std::{
fmt, fmt,
future::Future, future::Future,
marker::PhantomData, marker::PhantomData,
mem,
net::SocketAddr, net::SocketAddr,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
rc::Rc, rc::Rc,
@ -702,20 +703,15 @@ async fn create_room(
response: Response<proto::CreateRoom>, response: Response<proto::CreateRoom>,
session: Session, session: Session,
) -> Result<()> { ) -> Result<()> {
let room = session let live_kit_room = nanoid::nanoid!(30);
.db()
.await
.create_room(session.user_id, session.connection_id)
.await?;
let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() { let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
if let Some(_) = live_kit if let Some(_) = live_kit
.create_room(room.live_kit_room.clone()) .create_room(live_kit_room.clone())
.await .await
.trace_err() .trace_err()
{ {
if let Some(token) = live_kit if let Some(token) = live_kit
.room_token(&room.live_kit_room, &session.connection_id.to_string()) .room_token(&live_kit_room, &session.connection_id.to_string())
.trace_err() .trace_err()
{ {
Some(proto::LiveKitConnectionInfo { Some(proto::LiveKitConnectionInfo {
@ -732,10 +728,19 @@ async fn create_room(
None None
}; };
{
let room = session
.db()
.await
.create_room(session.user_id, session.connection_id, &live_kit_room)
.await?;
response.send(proto::CreateRoomResponse { response.send(proto::CreateRoomResponse {
room: Some(room), room: Some(room.clone()),
live_kit_connection_info, live_kit_connection_info,
})?; })?;
}
update_user_contacts(session.user_id, &session).await?; update_user_contacts(session.user_id, &session).await?;
Ok(()) Ok(())
} }
@ -745,6 +750,7 @@ async fn join_room(
response: Response<proto::JoinRoom>, response: Response<proto::JoinRoom>,
session: Session, session: Session,
) -> Result<()> { ) -> Result<()> {
let room = {
let room = session let room = session
.db() .db()
.await .await
@ -754,6 +760,10 @@ async fn join_room(
session.connection_id, session.connection_id,
) )
.await?; .await?;
room_updated(&room, &session);
room.clone()
};
for connection_id in session for connection_id in session
.connection_pool() .connection_pool()
.await .await
@ -781,7 +791,6 @@ async fn join_room(
None None
}; };
room_updated(&room, &session);
response.send(proto::JoinRoomResponse { response.send(proto::JoinRoomResponse {
room: Some(room), room: Some(room),
live_kit_connection_info, live_kit_connection_info,
@ -814,7 +823,8 @@ async fn call(
return Err(anyhow!("cannot call a user who isn't a contact"))?; return Err(anyhow!("cannot call a user who isn't a contact"))?;
} }
let (room, incoming_call) = session let incoming_call = {
let (room, incoming_call) = &mut *session
.db() .db()
.await .await
.call( .call(
@ -826,6 +836,8 @@ async fn call(
) )
.await?; .await?;
room_updated(&room, &session); room_updated(&room, &session);
mem::take(incoming_call)
};
update_user_contacts(called_user_id, &session).await?; update_user_contacts(called_user_id, &session).await?;
let mut calls = session let mut calls = session
@ -847,12 +859,14 @@ async fn call(
} }
} }
{
let room = session let room = session
.db() .db()
.await .await
.call_failed(room_id, called_user_id) .call_failed(room_id, called_user_id)
.await?; .await?;
room_updated(&room, &session); room_updated(&room, &session);
}
update_user_contacts(called_user_id, &session).await?; update_user_contacts(called_user_id, &session).await?;
Err(anyhow!("failed to ring user"))? Err(anyhow!("failed to ring user"))?
@ -865,11 +879,15 @@ async fn cancel_call(
) -> Result<()> { ) -> Result<()> {
let called_user_id = UserId::from_proto(request.called_user_id); let called_user_id = UserId::from_proto(request.called_user_id);
let room_id = RoomId::from_proto(request.room_id); let room_id = RoomId::from_proto(request.room_id);
{
let room = session let room = session
.db() .db()
.await .await
.cancel_call(Some(room_id), session.connection_id, called_user_id) .cancel_call(Some(room_id), session.connection_id, called_user_id)
.await?; .await?;
room_updated(&room, &session);
}
for connection_id in session for connection_id in session
.connection_pool() .connection_pool()
.await .await
@ -880,7 +898,6 @@ async fn cancel_call(
.send(connection_id, proto::CallCanceled {}) .send(connection_id, proto::CallCanceled {})
.trace_err(); .trace_err();
} }
room_updated(&room, &session);
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
update_user_contacts(called_user_id, &session).await?; update_user_contacts(called_user_id, &session).await?;
@ -889,11 +906,15 @@ async fn cancel_call(
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> { async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
let room_id = RoomId::from_proto(message.room_id); let room_id = RoomId::from_proto(message.room_id);
{
let room = session let room = session
.db() .db()
.await .await
.decline_call(Some(room_id), session.user_id) .decline_call(Some(room_id), session.user_id)
.await?; .await?;
room_updated(&room, &session);
}
for connection_id in session for connection_id in session
.connection_pool() .connection_pool()
.await .await
@ -904,7 +925,6 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<(
.send(connection_id, proto::CallCanceled {}) .send(connection_id, proto::CallCanceled {})
.trace_err(); .trace_err();
} }
room_updated(&room, &session);
update_user_contacts(session.user_id, &session).await?; update_user_contacts(session.user_id, &session).await?;
Ok(()) Ok(())
} }
@ -933,7 +953,7 @@ async fn share_project(
response: Response<proto::ShareProject>, response: Response<proto::ShareProject>,
session: Session, session: Session,
) -> Result<()> { ) -> Result<()> {
let (project_id, room) = session let (project_id, room) = &*session
.db() .db()
.await .await
.share_project( .share_project(
@ -953,15 +973,17 @@ async fn share_project(
async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> { async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
let project_id = ProjectId::from_proto(message.project_id); let project_id = ProjectId::from_proto(message.project_id);
let (room, guest_connection_ids) = session let (room, guest_connection_ids) = &*session
.db() .db()
.await .await
.unshare_project(project_id, session.connection_id) .unshare_project(project_id, session.connection_id)
.await?; .await?;
broadcast(session.connection_id, guest_connection_ids, |conn_id| { broadcast(
session.peer.send(conn_id, message.clone()) session.connection_id,
}); guest_connection_ids.iter().copied(),
|conn_id| session.peer.send(conn_id, message.clone()),
);
room_updated(&room, &session); room_updated(&room, &session);
Ok(()) Ok(())
@ -977,7 +999,7 @@ async fn join_project(
tracing::info!(%project_id, "join project"); tracing::info!(%project_id, "join project");
let (project, replica_id) = session let (project, replica_id) = &mut *session
.db() .db()
.await .await
.join_project(project_id, session.connection_id) .join_project(project_id, session.connection_id)
@ -1029,7 +1051,7 @@ async fn join_project(
language_servers: project.language_servers.clone(), language_servers: project.language_servers.clone(),
})?; })?;
for (worktree_id, worktree) in project.worktrees { for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
const MAX_CHUNK_SIZE: usize = 2; const MAX_CHUNK_SIZE: usize = 2;
#[cfg(not(any(test, feature = "test-support")))] #[cfg(not(any(test, feature = "test-support")))]
@ -1084,9 +1106,8 @@ async fn join_project(
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> { async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
let sender_id = session.connection_id; let sender_id = session.connection_id;
let project_id = ProjectId::from_proto(request.project_id); let project_id = ProjectId::from_proto(request.project_id);
let project;
{ let project = session
project = session
.db() .db()
.await .await
.leave_project(project_id, sender_id) .leave_project(project_id, sender_id)
@ -1098,7 +1119,10 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result
"leave project" "leave project"
); );
broadcast(sender_id, project.connection_ids, |conn_id| { broadcast(
sender_id,
project.connection_ids.iter().copied(),
|conn_id| {
session.peer.send( session.peer.send(
conn_id, conn_id,
proto::RemoveProjectCollaborator { proto::RemoveProjectCollaborator {
@ -1106,8 +1130,8 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result
peer_id: sender_id.0, peer_id: sender_id.0,
}, },
) )
}); },
} );
Ok(()) Ok(())
} }
@ -1118,14 +1142,14 @@ async fn update_project(
session: Session, session: Session,
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id); let project_id = ProjectId::from_proto(request.project_id);
let (room, guest_connection_ids) = session let (room, guest_connection_ids) = &*session
.db() .db()
.await .await
.update_project(project_id, session.connection_id, &request.worktrees) .update_project(project_id, session.connection_id, &request.worktrees)
.await?; .await?;
broadcast( broadcast(
session.connection_id, session.connection_id,
guest_connection_ids, guest_connection_ids.iter().copied(),
|connection_id| { |connection_id| {
session session
.peer .peer
@ -1151,7 +1175,7 @@ async fn update_worktree(
broadcast( broadcast(
session.connection_id, session.connection_id,
guest_connection_ids, guest_connection_ids.iter().copied(),
|connection_id| { |connection_id| {
session session
.peer .peer
@ -1175,7 +1199,7 @@ async fn update_diagnostic_summary(
broadcast( broadcast(
session.connection_id, session.connection_id,
guest_connection_ids, guest_connection_ids.iter().copied(),
|connection_id| { |connection_id| {
session session
.peer .peer
@ -1199,7 +1223,7 @@ async fn start_language_server(
broadcast( broadcast(
session.connection_id, session.connection_id,
guest_connection_ids, guest_connection_ids.iter().copied(),
|connection_id| { |connection_id| {
session session
.peer .peer
@ -1826,18 +1850,22 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()>
async fn leave_room_for_session(session: &Session) -> Result<()> { async fn leave_room_for_session(session: &Session) -> Result<()> {
let mut contacts_to_update = HashSet::default(); let mut contacts_to_update = HashSet::default();
let Some(left_room) = session.db().await.leave_room(session.connection_id).await? else { let canceled_calls_to_user_ids;
let live_kit_room;
let delete_live_kit_room;
{
let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? else {
return Err(anyhow!("no room to leave"))?; return Err(anyhow!("no room to leave"))?;
}; };
contacts_to_update.insert(session.user_id); contacts_to_update.insert(session.user_id);
for project in left_room.left_projects.into_values() { for project in left_room.left_projects.values() {
for connection_id in project.connection_ids { for connection_id in &project.connection_ids {
if project.host_user_id == session.user_id { if project.host_user_id == session.user_id {
session session
.peer .peer
.send( .send(
connection_id, *connection_id,
proto::UnshareProject { proto::UnshareProject {
project_id: project.id.to_proto(), project_id: project.id.to_proto(),
}, },
@ -1847,7 +1875,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> {
session session
.peer .peer
.send( .send(
connection_id, *connection_id,
proto::RemoveProjectCollaborator { proto::RemoveProjectCollaborator {
project_id: project.id.to_proto(), project_id: project.id.to_proto(),
peer_id: session.connection_id.0, peer_id: session.connection_id.0,
@ -1869,9 +1897,14 @@ async fn leave_room_for_session(session: &Session) -> Result<()> {
} }
room_updated(&left_room.room, &session); room_updated(&left_room.room, &session);
canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
live_kit_room = mem::take(&mut left_room.room.live_kit_room);
delete_live_kit_room = left_room.room.participants.is_empty();
}
{ {
let pool = session.connection_pool().await; let pool = session.connection_pool().await;
for canceled_user_id in left_room.canceled_calls_to_user_ids { for canceled_user_id in canceled_calls_to_user_ids {
for connection_id in pool.user_connection_ids(canceled_user_id) { for connection_id in pool.user_connection_ids(canceled_user_id) {
session session
.peer .peer
@ -1888,18 +1921,12 @@ async fn leave_room_for_session(session: &Session) -> Result<()> {
if let Some(live_kit) = session.live_kit_client.as_ref() { if let Some(live_kit) = session.live_kit_client.as_ref() {
live_kit live_kit
.remove_participant( .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
left_room.room.live_kit_room.clone(),
session.connection_id.to_string(),
)
.await .await
.trace_err(); .trace_err();
if left_room.room.participants.is_empty() { if delete_live_kit_room {
live_kit live_kit.delete_room(live_kit_room).await.trace_err();
.delete_room(left_room.room.live_kit_room)
.await
.trace_err();
} }
} }