diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index b7aa30645f..c573b7e7c0 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -68,7 +68,7 @@ use std::{ rc::Rc, sync::{ Arc, OnceLock, - atomic::{AtomicBool, Ordering::SeqCst}, + atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst}, }, time::{Duration, Instant}, }; @@ -89,10 +89,36 @@ pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15); const MESSAGE_COUNT_PER_PAGE: usize = 100; const MAX_MESSAGE_LEN: usize = 1024; const NOTIFICATION_COUNT_PER_PAGE: usize = 50; +const MAX_CONCURRENT_CONNECTIONS: usize = 512; + +static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0); type MessageHandler = Box, Session) -> BoxFuture<'static, ()>>; +pub struct ConnectionGuard; + +impl ConnectionGuard { + pub fn try_acquire() -> Result { + let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst); + if current_connections >= MAX_CONCURRENT_CONNECTIONS { + CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst); + tracing::error!( + "too many concurrent connections: {}", + current_connections + 1 + ); + return Err(()); + } + Ok(ConnectionGuard) + } +} + +impl Drop for ConnectionGuard { + fn drop(&mut self) { + CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst); + } +} + struct Response { peer: Arc, receipt: Receipt, @@ -725,6 +751,7 @@ impl Server { system_id: Option, send_connection_id: Option>, executor: Executor, + connection_guard: Option, ) -> impl Future + use<> { let this = self.clone(); let span = info_span!("handle connection", %address, @@ -745,6 +772,7 @@ impl Server { tracing::error!("server is tearing down"); return } + let (connection_id, handle_io, mut incoming_rx) = this .peer .add_connection(connection, { @@ -786,6 +814,7 @@ impl Server { tracing::error!(?error, "failed to send initial client update"); return; } + drop(connection_guard); let handle_io = handle_io.fuse(); futures::pin_mut!(handle_io); @@ -1157,6 +1186,19 @@ pub async fn handle_websocket_request( } let socket_address = socket_address.to_string(); + + // Acquire connection guard before WebSocket upgrade + let connection_guard = match ConnectionGuard::try_acquire() { + Ok(guard) => guard, + Err(()) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + "Too many concurrent connections", + ) + .into_response(); + } + }; + ws.on_upgrade(move |socket| { let socket = socket .map_ok(to_tungstenite_message) @@ -1174,6 +1216,7 @@ pub async fn handle_websocket_request( system_id_header.map(|header| header.to_string()), None, Executor::Production, + Some(connection_guard), ) .await; } diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 83792c631a..c133e46ecd 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -258,6 +258,7 @@ impl TestServer { None, Some(connection_id_tx), Executor::Deterministic(cx.background_executor().clone()), + None, )) .detach(); let connection_id = connection_id_rx.await.map_err(|e| {