Use a synchronous mutex for ConnectionPool
This commit is contained in:
parent
a594ba8f8a
commit
d4c8fa3090
3 changed files with 20 additions and 20 deletions
|
@ -6062,7 +6062,6 @@ async fn test_random_collaboration(
|
||||||
let user_connection_ids = server
|
let user_connection_ids = server
|
||||||
.connection_pool
|
.connection_pool
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
|
||||||
.user_connection_ids(removed_guest_id)
|
.user_connection_ids(removed_guest_id)
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
assert_eq!(user_connection_ids.len(), 1);
|
assert_eq!(user_connection_ids.len(), 1);
|
||||||
|
@ -6083,7 +6082,7 @@ async fn test_random_collaboration(
|
||||||
}
|
}
|
||||||
for user_id in &user_ids {
|
for user_id in &user_ids {
|
||||||
let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap();
|
let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap();
|
||||||
let pool = server.connection_pool.lock().await;
|
let pool = server.connection_pool.lock();
|
||||||
for contact in contacts {
|
for contact in contacts {
|
||||||
if let db::Contact::Accepted { user_id, .. } = contact {
|
if let db::Contact::Accepted { user_id, .. } = contact {
|
||||||
if pool.is_user_online(user_id) {
|
if pool.is_user_online(user_id) {
|
||||||
|
@ -6112,7 +6111,6 @@ async fn test_random_collaboration(
|
||||||
let user_connection_ids = server
|
let user_connection_ids = server
|
||||||
.connection_pool
|
.connection_pool
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
|
||||||
.user_connection_ids(user_id)
|
.user_connection_ids(user_id)
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
assert_eq!(user_connection_ids.len(), 1);
|
assert_eq!(user_connection_ids.len(), 1);
|
||||||
|
|
|
@ -53,7 +53,7 @@ use std::{
|
||||||
},
|
},
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
use tokio::sync::{watch, Mutex, MutexGuard};
|
use tokio::sync::watch;
|
||||||
use tower::ServiceBuilder;
|
use tower::ServiceBuilder;
|
||||||
use tracing::{info_span, instrument, Instrument};
|
use tracing::{info_span, instrument, Instrument};
|
||||||
|
|
||||||
|
@ -90,14 +90,14 @@ impl<R: RequestMessage> Response<R> {
|
||||||
struct Session {
|
struct Session {
|
||||||
user_id: UserId,
|
user_id: UserId,
|
||||||
connection_id: ConnectionId,
|
connection_id: ConnectionId,
|
||||||
db: Arc<Mutex<DbHandle>>,
|
db: Arc<tokio::sync::Mutex<DbHandle>>,
|
||||||
peer: Arc<Peer>,
|
peer: Arc<Peer>,
|
||||||
connection_pool: Arc<Mutex<ConnectionPool>>,
|
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
|
||||||
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
async fn db(&self) -> MutexGuard<DbHandle> {
|
async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
tokio::task::yield_now().await;
|
tokio::task::yield_now().await;
|
||||||
let guard = self.db.lock().await;
|
let guard = self.db.lock().await;
|
||||||
|
@ -109,9 +109,7 @@ impl Session {
|
||||||
async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
|
async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
tokio::task::yield_now().await;
|
tokio::task::yield_now().await;
|
||||||
let guard = self.connection_pool.lock().await;
|
let guard = self.connection_pool.lock();
|
||||||
#[cfg(test)]
|
|
||||||
tokio::task::yield_now().await;
|
|
||||||
ConnectionPoolGuard {
|
ConnectionPoolGuard {
|
||||||
guard,
|
guard,
|
||||||
_not_send: PhantomData,
|
_not_send: PhantomData,
|
||||||
|
@ -140,7 +138,7 @@ impl Deref for DbHandle {
|
||||||
|
|
||||||
pub struct Server {
|
pub struct Server {
|
||||||
peer: Arc<Peer>,
|
peer: Arc<Peer>,
|
||||||
pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
|
pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
|
||||||
app_state: Arc<AppState>,
|
app_state: Arc<AppState>,
|
||||||
executor: Executor,
|
executor: Executor,
|
||||||
handlers: HashMap<TypeId, MessageHandler>,
|
handlers: HashMap<TypeId, MessageHandler>,
|
||||||
|
@ -148,7 +146,7 @@ pub struct Server {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct ConnectionPoolGuard<'a> {
|
pub(crate) struct ConnectionPoolGuard<'a> {
|
||||||
guard: MutexGuard<'a, ConnectionPool>,
|
guard: parking_lot::MutexGuard<'a, ConnectionPool>,
|
||||||
_not_send: PhantomData<Rc<()>>,
|
_not_send: PhantomData<Rc<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -268,7 +266,7 @@ impl Server {
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
let pool = pool.lock().await;
|
let pool = pool.lock();
|
||||||
for canceled_user_id in canceled_calls_to_user_ids {
|
for canceled_user_id in canceled_calls_to_user_ids {
|
||||||
for connection_id in pool.user_connection_ids(canceled_user_id) {
|
for connection_id in pool.user_connection_ids(canceled_user_id) {
|
||||||
peer.send(
|
peer.send(
|
||||||
|
@ -286,7 +284,7 @@ impl Server {
|
||||||
let busy = db.is_user_busy(user_id).await.trace_err();
|
let busy = db.is_user_busy(user_id).await.trace_err();
|
||||||
let contacts = db.get_contacts(user_id).await.trace_err();
|
let contacts = db.get_contacts(user_id).await.trace_err();
|
||||||
if let Some((busy, contacts)) = busy.zip(contacts) {
|
if let Some((busy, contacts)) = busy.zip(contacts) {
|
||||||
let pool = pool.lock().await;
|
let pool = pool.lock();
|
||||||
let updated_contact = contact_for_user(user_id, false, busy, &pool);
|
let updated_contact = contact_for_user(user_id, false, busy, &pool);
|
||||||
for contact in contacts {
|
for contact in contacts {
|
||||||
if let db::Contact::Accepted {
|
if let db::Contact::Accepted {
|
||||||
|
@ -456,7 +454,7 @@ impl Server {
|
||||||
).await?;
|
).await?;
|
||||||
|
|
||||||
{
|
{
|
||||||
let mut pool = this.connection_pool.lock().await;
|
let mut pool = this.connection_pool.lock();
|
||||||
pool.add_connection(connection_id, user_id, user.admin);
|
pool.add_connection(connection_id, user_id, user.admin);
|
||||||
this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
|
this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
|
||||||
|
|
||||||
|
@ -475,7 +473,7 @@ impl Server {
|
||||||
let session = Session {
|
let session = Session {
|
||||||
user_id,
|
user_id,
|
||||||
connection_id,
|
connection_id,
|
||||||
db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
|
db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
|
||||||
peer: this.peer.clone(),
|
peer: this.peer.clone(),
|
||||||
connection_pool: this.connection_pool.clone(),
|
connection_pool: this.connection_pool.clone(),
|
||||||
live_kit_client: this.app_state.live_kit_client.clone()
|
live_kit_client: this.app_state.live_kit_client.clone()
|
||||||
|
@ -550,7 +548,7 @@ impl Server {
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
|
if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
|
||||||
if let Some(code) = &user.invite_code {
|
if let Some(code) = &user.invite_code {
|
||||||
let pool = self.connection_pool.lock().await;
|
let pool = self.connection_pool.lock();
|
||||||
let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
|
let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
|
||||||
for connection_id in pool.user_connection_ids(inviter_id) {
|
for connection_id in pool.user_connection_ids(inviter_id) {
|
||||||
self.peer.send(
|
self.peer.send(
|
||||||
|
@ -576,7 +574,7 @@ impl Server {
|
||||||
pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
|
pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
|
||||||
if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
|
if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
|
||||||
if let Some(invite_code) = &user.invite_code {
|
if let Some(invite_code) = &user.invite_code {
|
||||||
let pool = self.connection_pool.lock().await;
|
let pool = self.connection_pool.lock();
|
||||||
for connection_id in pool.user_connection_ids(user_id) {
|
for connection_id in pool.user_connection_ids(user_id) {
|
||||||
self.peer.send(
|
self.peer.send(
|
||||||
connection_id,
|
connection_id,
|
||||||
|
@ -597,7 +595,7 @@ impl Server {
|
||||||
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
|
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
|
||||||
ServerSnapshot {
|
ServerSnapshot {
|
||||||
connection_pool: ConnectionPoolGuard {
|
connection_pool: ConnectionPoolGuard {
|
||||||
guard: self.connection_pool.lock().await,
|
guard: self.connection_pool.lock(),
|
||||||
_not_send: PhantomData,
|
_not_send: PhantomData,
|
||||||
},
|
},
|
||||||
peer: &self.peer,
|
peer: &self.peer,
|
||||||
|
@ -718,7 +716,6 @@ pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result
|
||||||
let connections = server
|
let connections = server
|
||||||
.connection_pool
|
.connection_pool
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
|
||||||
.connections()
|
.connections()
|
||||||
.filter(|connection| !connection.admin)
|
.filter(|connection| !connection.admin)
|
||||||
.count();
|
.count();
|
||||||
|
|
|
@ -23,6 +23,11 @@ pub struct Connection {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ConnectionPool {
|
impl ConnectionPool {
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.connections.clear();
|
||||||
|
self.connected_users.clear();
|
||||||
|
}
|
||||||
|
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) {
|
pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) {
|
||||||
self.connections
|
self.connections
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue