diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 2a3f6ff62c..13f484910b 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -6062,7 +6062,6 @@ async fn test_random_collaboration( let user_connection_ids = server .connection_pool .lock() - .await .user_connection_ids(removed_guest_id) .collect::>(); assert_eq!(user_connection_ids.len(), 1); @@ -6083,7 +6082,7 @@ async fn test_random_collaboration( } for user_id in &user_ids { let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap(); - let pool = server.connection_pool.lock().await; + let pool = server.connection_pool.lock(); for contact in contacts { if let db::Contact::Accepted { user_id, .. } = contact { if pool.is_user_online(user_id) { @@ -6112,7 +6111,6 @@ async fn test_random_collaboration( let user_connection_ids = server .connection_pool .lock() - .await .user_connection_ids(user_id) .collect::>(); assert_eq!(user_connection_ids.len(), 1); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 2c1733aeb3..861a02fd31 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -53,7 +53,7 @@ use std::{ }, time::Duration, }; -use tokio::sync::{watch, Mutex, MutexGuard}; +use tokio::sync::watch; use tower::ServiceBuilder; use tracing::{info_span, instrument, Instrument}; @@ -90,14 +90,14 @@ impl Response { struct Session { user_id: UserId, connection_id: ConnectionId, - db: Arc>, + db: Arc>, peer: Arc, - connection_pool: Arc>, + connection_pool: Arc>, live_kit_client: Option>, } impl Session { - async fn db(&self) -> MutexGuard { + async fn db(&self) -> tokio::sync::MutexGuard { #[cfg(test)] tokio::task::yield_now().await; let guard = self.db.lock().await; @@ -109,9 +109,7 @@ impl Session { async fn connection_pool(&self) -> ConnectionPoolGuard<'_> { #[cfg(test)] tokio::task::yield_now().await; - let guard = self.connection_pool.lock().await; - #[cfg(test)] - tokio::task::yield_now().await; + let guard = self.connection_pool.lock(); ConnectionPoolGuard { guard, _not_send: PhantomData, @@ -140,7 +138,7 @@ impl Deref for DbHandle { pub struct Server { peer: Arc, - pub(crate) connection_pool: Arc>, + pub(crate) connection_pool: Arc>, app_state: Arc, executor: Executor, handlers: HashMap, @@ -148,7 +146,7 @@ pub struct Server { } pub(crate) struct ConnectionPoolGuard<'a> { - guard: MutexGuard<'a, ConnectionPool>, + guard: parking_lot::MutexGuard<'a, ConnectionPool>, _not_send: PhantomData>, } @@ -268,7 +266,7 @@ impl Server { } { - let pool = pool.lock().await; + let pool = pool.lock(); for canceled_user_id in canceled_calls_to_user_ids { for connection_id in pool.user_connection_ids(canceled_user_id) { peer.send( @@ -286,7 +284,7 @@ impl Server { let busy = db.is_user_busy(user_id).await.trace_err(); let contacts = db.get_contacts(user_id).await.trace_err(); if let Some((busy, contacts)) = busy.zip(contacts) { - let pool = pool.lock().await; + let pool = pool.lock(); let updated_contact = contact_for_user(user_id, false, busy, &pool); for contact in contacts { if let db::Contact::Accepted { @@ -456,7 +454,7 @@ impl Server { ).await?; { - let mut pool = this.connection_pool.lock().await; + let mut pool = this.connection_pool.lock(); pool.add_connection(connection_id, user_id, user.admin); this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?; @@ -475,7 +473,7 @@ impl Server { let session = Session { user_id, connection_id, - db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))), + db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))), peer: this.peer.clone(), connection_pool: this.connection_pool.clone(), live_kit_client: this.app_state.live_kit_client.clone() @@ -550,7 +548,7 @@ impl Server { ) -> Result<()> { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? { if let Some(code) = &user.invite_code { - let pool = self.connection_pool.lock().await; + let pool = self.connection_pool.lock(); let invitee_contact = contact_for_user(invitee_id, true, false, &pool); for connection_id in pool.user_connection_ids(inviter_id) { self.peer.send( @@ -576,7 +574,7 @@ impl Server { pub async fn invite_count_updated(self: &Arc, user_id: UserId) -> Result<()> { if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? { if let Some(invite_code) = &user.invite_code { - let pool = self.connection_pool.lock().await; + let pool = self.connection_pool.lock(); for connection_id in pool.user_connection_ids(user_id) { self.peer.send( connection_id, @@ -597,7 +595,7 @@ impl Server { pub async fn snapshot<'a>(self: &'a Arc) -> ServerSnapshot<'a> { ServerSnapshot { connection_pool: ConnectionPoolGuard { - guard: self.connection_pool.lock().await, + guard: self.connection_pool.lock(), _not_send: PhantomData, }, peer: &self.peer, @@ -718,7 +716,6 @@ pub async fn handle_metrics(Extension(server): Extension>) -> Result let connections = server .connection_pool .lock() - .await .connections() .filter(|connection| !connection.admin) .count(); diff --git a/crates/collab/src/rpc/connection_pool.rs b/crates/collab/src/rpc/connection_pool.rs index ac7632f7da..30c4e144ed 100644 --- a/crates/collab/src/rpc/connection_pool.rs +++ b/crates/collab/src/rpc/connection_pool.rs @@ -23,6 +23,11 @@ pub struct Connection { } impl ConnectionPool { + pub fn reset(&mut self) { + self.connections.clear(); + self.connected_users.clear(); + } + #[instrument(skip(self))] pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) { self.connections