Introduce per-room lock acquired before committing a transaction
This commit is contained in:
parent
2a0ddd99d2
commit
cd0b663f62
4 changed files with 327 additions and 223 deletions
|
@ -2,6 +2,7 @@ use crate::{Error, Result};
|
|||
use anyhow::anyhow;
|
||||
use axum::http::StatusCode;
|
||||
use collections::{BTreeMap, HashMap, HashSet};
|
||||
use dashmap::DashMap;
|
||||
use futures::{future::BoxFuture, FutureExt, StreamExt};
|
||||
use rpc::{proto, ConnectionId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -10,8 +11,17 @@ use sqlx::{
|
|||
types::Uuid,
|
||||
FromRow,
|
||||
};
|
||||
use std::{future::Future, path::Path, time::Duration};
|
||||
use std::{
|
||||
future::Future,
|
||||
marker::PhantomData,
|
||||
ops::{Deref, DerefMut},
|
||||
path::Path,
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use time::{OffsetDateTime, PrimitiveDateTime};
|
||||
use tokio::sync::{Mutex, OwnedMutexGuard};
|
||||
|
||||
#[cfg(test)]
|
||||
pub type DefaultDb = Db<sqlx::Sqlite>;
|
||||
|
@ -21,12 +31,33 @@ pub type DefaultDb = Db<sqlx::Postgres>;
|
|||
|
||||
pub struct Db<D: sqlx::Database> {
|
||||
pool: sqlx::Pool<D>,
|
||||
rooms: DashMap<RoomId, Arc<Mutex<()>>>,
|
||||
#[cfg(test)]
|
||||
background: Option<std::sync::Arc<gpui::executor::Background>>,
|
||||
#[cfg(test)]
|
||||
runtime: Option<tokio::runtime::Runtime>,
|
||||
}
|
||||
|
||||
pub struct RoomGuard<T> {
|
||||
data: T,
|
||||
_guard: OwnedMutexGuard<()>,
|
||||
_not_send: PhantomData<Rc<()>>,
|
||||
}
|
||||
|
||||
impl<T> Deref for RoomGuard<T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &T {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> DerefMut for RoomGuard<T> {
|
||||
fn deref_mut(&mut self) -> &mut T {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
pub trait BeginTransaction: Send + Sync {
|
||||
type Database: sqlx::Database;
|
||||
|
||||
|
@ -90,6 +121,7 @@ impl Db<sqlx::Sqlite> {
|
|||
.await?;
|
||||
Ok(Self {
|
||||
pool,
|
||||
rooms: Default::default(),
|
||||
background: None,
|
||||
runtime: None,
|
||||
})
|
||||
|
@ -197,6 +229,7 @@ impl Db<sqlx::Postgres> {
|
|||
.await?;
|
||||
Ok(Self {
|
||||
pool,
|
||||
rooms: DashMap::with_capacity(16384),
|
||||
#[cfg(test)]
|
||||
background: None,
|
||||
#[cfg(test)]
|
||||
|
@ -922,13 +955,29 @@ where
|
|||
.await
|
||||
}
|
||||
|
||||
async fn commit_room_transaction<'a, T>(
|
||||
&'a self,
|
||||
room_id: RoomId,
|
||||
tx: sqlx::Transaction<'static, D>,
|
||||
data: T,
|
||||
) -> Result<RoomGuard<T>> {
|
||||
let lock = self.rooms.entry(room_id).or_default().clone();
|
||||
let _guard = lock.lock_owned().await;
|
||||
tx.commit().await?;
|
||||
Ok(RoomGuard {
|
||||
data,
|
||||
_guard,
|
||||
_not_send: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn create_room(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<proto::Room> {
|
||||
live_kit_room: &str,
|
||||
) -> Result<RoomGuard<proto::Room>> {
|
||||
self.transact(|mut tx| async move {
|
||||
let live_kit_room = nanoid::nanoid!(30);
|
||||
let room_id = sqlx::query_scalar(
|
||||
"
|
||||
INSERT INTO rooms (live_kit_room)
|
||||
|
@ -956,8 +1005,7 @@ where
|
|||
.await?;
|
||||
|
||||
let room = self.get_room(room_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
Ok(room)
|
||||
self.commit_room_transaction(room_id, tx, room).await
|
||||
}).await
|
||||
}
|
||||
|
||||
|
@ -968,11 +1016,17 @@ where
|
|||
calling_connection_id: ConnectionId,
|
||||
called_user_id: UserId,
|
||||
initial_project_id: Option<ProjectId>,
|
||||
) -> Result<(proto::Room, proto::IncomingCall)> {
|
||||
) -> Result<RoomGuard<(proto::Room, proto::IncomingCall)>> {
|
||||
self.transact(|mut tx| async move {
|
||||
sqlx::query(
|
||||
"
|
||||
INSERT INTO room_participants (room_id, user_id, calling_user_id, calling_connection_id, initial_project_id)
|
||||
INSERT INTO room_participants (
|
||||
room_id,
|
||||
user_id,
|
||||
calling_user_id,
|
||||
calling_connection_id,
|
||||
initial_project_id
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
",
|
||||
)
|
||||
|
@ -985,12 +1039,12 @@ where
|
|||
.await?;
|
||||
|
||||
let room = self.get_room(room_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
|
||||
let incoming_call = Self::build_incoming_call(&room, called_user_id)
|
||||
.ok_or_else(|| anyhow!("failed to build incoming call"))?;
|
||||
Ok((room, incoming_call))
|
||||
}).await
|
||||
self.commit_room_transaction(room_id, tx, (room, incoming_call))
|
||||
.await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn incoming_call_for_user(
|
||||
|
@ -1051,7 +1105,7 @@ where
|
|||
&self,
|
||||
room_id: RoomId,
|
||||
called_user_id: UserId,
|
||||
) -> Result<proto::Room> {
|
||||
) -> Result<RoomGuard<proto::Room>> {
|
||||
self.transact(|mut tx| async move {
|
||||
sqlx::query(
|
||||
"
|
||||
|
@ -1065,8 +1119,7 @@ where
|
|||
.await?;
|
||||
|
||||
let room = self.get_room(room_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
Ok(room)
|
||||
self.commit_room_transaction(room_id, tx, room).await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
@ -1075,7 +1128,7 @@ where
|
|||
&self,
|
||||
expected_room_id: Option<RoomId>,
|
||||
user_id: UserId,
|
||||
) -> Result<proto::Room> {
|
||||
) -> Result<RoomGuard<proto::Room>> {
|
||||
self.transact(|mut tx| async move {
|
||||
let room_id = sqlx::query_scalar(
|
||||
"
|
||||
|
@ -1092,8 +1145,7 @@ where
|
|||
}
|
||||
|
||||
let room = self.get_room(room_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
Ok(room)
|
||||
self.commit_room_transaction(room_id, tx, room).await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
@ -1103,7 +1155,7 @@ where
|
|||
expected_room_id: Option<RoomId>,
|
||||
calling_connection_id: ConnectionId,
|
||||
called_user_id: UserId,
|
||||
) -> Result<proto::Room> {
|
||||
) -> Result<RoomGuard<proto::Room>> {
|
||||
self.transact(|mut tx| async move {
|
||||
let room_id = sqlx::query_scalar(
|
||||
"
|
||||
|
@ -1121,8 +1173,7 @@ where
|
|||
}
|
||||
|
||||
let room = self.get_room(room_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
Ok(room)
|
||||
self.commit_room_transaction(room_id, tx, room).await
|
||||
}).await
|
||||
}
|
||||
|
||||
|
@ -1131,7 +1182,7 @@ where
|
|||
room_id: RoomId,
|
||||
user_id: UserId,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<proto::Room> {
|
||||
) -> Result<RoomGuard<proto::Room>> {
|
||||
self.transact(|mut tx| async move {
|
||||
sqlx::query(
|
||||
"
|
||||
|
@ -1148,13 +1199,15 @@ where
|
|||
.await?;
|
||||
|
||||
let room = self.get_room(room_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
Ok(room)
|
||||
self.commit_room_transaction(room_id, tx, room).await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn leave_room(&self, connection_id: ConnectionId) -> Result<Option<LeftRoom>> {
|
||||
pub async fn leave_room(
|
||||
&self,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<Option<RoomGuard<LeftRoom>>> {
|
||||
self.transact(|mut tx| async move {
|
||||
// Leave room.
|
||||
let room_id = sqlx::query_scalar::<_, RoomId>(
|
||||
|
@ -1258,13 +1311,18 @@ where
|
|||
.await?;
|
||||
|
||||
let room = self.get_room(room_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
|
||||
Ok(Some(LeftRoom {
|
||||
room,
|
||||
left_projects,
|
||||
canceled_calls_to_user_ids,
|
||||
}))
|
||||
Ok(Some(
|
||||
self.commit_room_transaction(
|
||||
room_id,
|
||||
tx,
|
||||
LeftRoom {
|
||||
room,
|
||||
left_projects,
|
||||
canceled_calls_to_user_ids,
|
||||
},
|
||||
)
|
||||
.await?,
|
||||
))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
|
@ -1277,7 +1335,7 @@ where
|
|||
room_id: RoomId,
|
||||
connection_id: ConnectionId,
|
||||
location: proto::ParticipantLocation,
|
||||
) -> Result<proto::Room> {
|
||||
) -> Result<RoomGuard<proto::Room>> {
|
||||
self.transact(|tx| async {
|
||||
let mut tx = tx;
|
||||
let location_kind;
|
||||
|
@ -1317,8 +1375,7 @@ where
|
|||
.await?;
|
||||
|
||||
let room = self.get_room(room_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
Ok(room)
|
||||
self.commit_room_transaction(room_id, tx, room).await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
@ -1478,7 +1535,7 @@ where
|
|||
expected_room_id: RoomId,
|
||||
connection_id: ConnectionId,
|
||||
worktrees: &[proto::WorktreeMetadata],
|
||||
) -> Result<(ProjectId, proto::Room)> {
|
||||
) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
|
||||
self.transact(|mut tx| async move {
|
||||
let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
|
||||
"
|
||||
|
@ -1560,9 +1617,8 @@ where
|
|||
.await?;
|
||||
|
||||
let room = self.get_room(room_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
|
||||
Ok((project_id, room))
|
||||
self.commit_room_transaction(room_id, tx, (project_id, room))
|
||||
.await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
@ -1571,7 +1627,7 @@ where
|
|||
&self,
|
||||
project_id: ProjectId,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<(proto::Room, Vec<ConnectionId>)> {
|
||||
) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
|
||||
self.transact(|mut tx| async move {
|
||||
let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
|
||||
let room_id: RoomId = sqlx::query_scalar(
|
||||
|
@ -1586,9 +1642,8 @@ where
|
|||
.fetch_one(&mut tx)
|
||||
.await?;
|
||||
let room = self.get_room(room_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
|
||||
Ok((room, guest_connection_ids))
|
||||
self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
|
||||
.await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
@ -1598,7 +1653,7 @@ where
|
|||
project_id: ProjectId,
|
||||
connection_id: ConnectionId,
|
||||
worktrees: &[proto::WorktreeMetadata],
|
||||
) -> Result<(proto::Room, Vec<ConnectionId>)> {
|
||||
) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
|
||||
self.transact(|mut tx| async move {
|
||||
let room_id: RoomId = sqlx::query_scalar(
|
||||
"
|
||||
|
@ -1664,9 +1719,8 @@ where
|
|||
|
||||
let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
|
||||
let room = self.get_room(room_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
|
||||
Ok((room, guest_connection_ids))
|
||||
self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
|
||||
.await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
@ -1675,15 +1729,15 @@ where
|
|||
&self,
|
||||
update: &proto::UpdateWorktree,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<Vec<ConnectionId>> {
|
||||
) -> Result<RoomGuard<Vec<ConnectionId>>> {
|
||||
self.transact(|mut tx| async move {
|
||||
let project_id = ProjectId::from_proto(update.project_id);
|
||||
let worktree_id = WorktreeId::from_proto(update.worktree_id);
|
||||
|
||||
// Ensure the update comes from the host.
|
||||
sqlx::query(
|
||||
let room_id: RoomId = sqlx::query_scalar(
|
||||
"
|
||||
SELECT 1
|
||||
SELECT room_id
|
||||
FROM projects
|
||||
WHERE id = $1 AND host_connection_id = $2
|
||||
",
|
||||
|
@ -1781,8 +1835,8 @@ where
|
|||
}
|
||||
|
||||
let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
Ok(connection_ids)
|
||||
self.commit_room_transaction(room_id, tx, connection_ids)
|
||||
.await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
@ -1791,7 +1845,7 @@ where
|
|||
&self,
|
||||
update: &proto::UpdateDiagnosticSummary,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<Vec<ConnectionId>> {
|
||||
) -> Result<RoomGuard<Vec<ConnectionId>>> {
|
||||
self.transact(|mut tx| async {
|
||||
let project_id = ProjectId::from_proto(update.project_id);
|
||||
let worktree_id = WorktreeId::from_proto(update.worktree_id);
|
||||
|
@ -1801,9 +1855,9 @@ where
|
|||
.ok_or_else(|| anyhow!("invalid summary"))?;
|
||||
|
||||
// Ensure the update comes from the host.
|
||||
sqlx::query(
|
||||
let room_id: RoomId = sqlx::query_scalar(
|
||||
"
|
||||
SELECT 1
|
||||
SELECT room_id
|
||||
FROM projects
|
||||
WHERE id = $1 AND host_connection_id = $2
|
||||
",
|
||||
|
@ -1841,8 +1895,8 @@ where
|
|||
.await?;
|
||||
|
||||
let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
Ok(connection_ids)
|
||||
self.commit_room_transaction(room_id, tx, connection_ids)
|
||||
.await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
@ -1851,7 +1905,7 @@ where
|
|||
&self,
|
||||
update: &proto::StartLanguageServer,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<Vec<ConnectionId>> {
|
||||
) -> Result<RoomGuard<Vec<ConnectionId>>> {
|
||||
self.transact(|mut tx| async {
|
||||
let project_id = ProjectId::from_proto(update.project_id);
|
||||
let server = update
|
||||
|
@ -1860,9 +1914,9 @@ where
|
|||
.ok_or_else(|| anyhow!("invalid language server"))?;
|
||||
|
||||
// Ensure the update comes from the host.
|
||||
sqlx::query(
|
||||
let room_id: RoomId = sqlx::query_scalar(
|
||||
"
|
||||
SELECT 1
|
||||
SELECT room_id
|
||||
FROM projects
|
||||
WHERE id = $1 AND host_connection_id = $2
|
||||
",
|
||||
|
@ -1888,8 +1942,8 @@ where
|
|||
.await?;
|
||||
|
||||
let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
Ok(connection_ids)
|
||||
self.commit_room_transaction(room_id, tx, connection_ids)
|
||||
.await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
@ -1898,7 +1952,7 @@ where
|
|||
&self,
|
||||
project_id: ProjectId,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<(Project, ReplicaId)> {
|
||||
) -> Result<RoomGuard<(Project, ReplicaId)>> {
|
||||
self.transact(|mut tx| async move {
|
||||
let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
|
||||
"
|
||||
|
@ -2068,21 +2122,25 @@ where
|
|||
.fetch_all(&mut tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await?;
|
||||
Ok((
|
||||
Project {
|
||||
collaborators,
|
||||
worktrees,
|
||||
language_servers: language_servers
|
||||
.into_iter()
|
||||
.map(|language_server| proto::LanguageServer {
|
||||
id: language_server.id.to_proto(),
|
||||
name: language_server.name,
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
replica_id as ReplicaId,
|
||||
))
|
||||
self.commit_room_transaction(
|
||||
room_id,
|
||||
tx,
|
||||
(
|
||||
Project {
|
||||
collaborators,
|
||||
worktrees,
|
||||
language_servers: language_servers
|
||||
.into_iter()
|
||||
.map(|language_server| proto::LanguageServer {
|
||||
id: language_server.id.to_proto(),
|
||||
name: language_server.name,
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
replica_id as ReplicaId,
|
||||
),
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
@ -2091,7 +2149,7 @@ where
|
|||
&self,
|
||||
project_id: ProjectId,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<LeftProject> {
|
||||
) -> Result<RoomGuard<LeftProject>> {
|
||||
self.transact(|mut tx| async move {
|
||||
let result = sqlx::query(
|
||||
"
|
||||
|
@ -2122,25 +2180,29 @@ where
|
|||
.map(|id| ConnectionId(id as u32))
|
||||
.collect();
|
||||
|
||||
let (host_user_id, host_connection_id) = sqlx::query_as::<_, (i32, i32)>(
|
||||
"
|
||||
SELECT host_user_id, host_connection_id
|
||||
let (room_id, host_user_id, host_connection_id) =
|
||||
sqlx::query_as::<_, (RoomId, i32, i32)>(
|
||||
"
|
||||
SELECT room_id, host_user_id, host_connection_id
|
||||
FROM projects
|
||||
WHERE id = $1
|
||||
",
|
||||
)
|
||||
.bind(project_id)
|
||||
.fetch_one(&mut tx)
|
||||
.await?;
|
||||
|
||||
self.commit_room_transaction(
|
||||
room_id,
|
||||
tx,
|
||||
LeftProject {
|
||||
id: project_id,
|
||||
host_user_id: UserId(host_user_id),
|
||||
host_connection_id: ConnectionId(host_connection_id as u32),
|
||||
connection_ids,
|
||||
},
|
||||
)
|
||||
.bind(project_id)
|
||||
.fetch_one(&mut tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
Ok(LeftProject {
|
||||
id: project_id,
|
||||
host_user_id: UserId(host_user_id),
|
||||
host_connection_id: ConnectionId(host_connection_id as u32),
|
||||
connection_ids,
|
||||
})
|
||||
.await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
@ -2538,9 +2600,9 @@ where
|
|||
|
||||
let result = self.runtime.as_ref().unwrap().block_on(body);
|
||||
|
||||
if let Some(background) = self.background.as_ref() {
|
||||
background.simulate_random_delay().await;
|
||||
}
|
||||
// if let Some(background) = self.background.as_ref() {
|
||||
// background.simulate_random_delay().await;
|
||||
// }
|
||||
|
||||
result
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue