diff --git a/Cargo.lock b/Cargo.lock index 64f13c6540..d89a9022d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4308,6 +4308,7 @@ dependencies = [ "rsa", "serde 1.0.125", "smol", + "smol-timeout", "tempdir", "zstd", ] @@ -4867,6 +4868,16 @@ dependencies = [ "once_cell", ] +[[package]] +name = "smol-timeout" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "847d777e2c6c166bad26264479e80a9820f3d364fcb4a0e23cd57bbfa8e94961" +dependencies = [ + "async-io", + "pin-project-lite 0.1.12", +] + [[package]] name = "socket2" version = "0.3.19" diff --git a/crates/rpc/Cargo.toml b/crates/rpc/Cargo.toml index 8f43a28a3b..f8c1e7d39c 100644 --- a/crates/rpc/Cargo.toml +++ b/crates/rpc/Cargo.toml @@ -20,6 +20,7 @@ prost = "0.8" rand = "0.8" rsa = "0.4" serde = { version = "1", features = ["derive"] } +smol-timeout = "0.6" zstd = "0.9" [build-dependencies] diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 1a407e512f..7752fc231a 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -7,6 +7,7 @@ use postage::{ mpsc, prelude::{Sink as _, Stream as _}, }; +use smol_timeout::TimeoutExt as _; use std::{ collections::HashMap, fmt, @@ -16,6 +17,7 @@ use std::{ atomic::{self, AtomicU32}, Arc, }, + time::Duration, }; #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] @@ -90,6 +92,8 @@ struct ConnectionState { response_channels: Arc>>>>, } +const WRITE_TIMEOUT: Duration = Duration::from_secs(10); + impl Peer { pub fn new() -> Arc { Arc::new(Self { @@ -155,8 +159,10 @@ impl Peer { }, outgoing = outgoing_rx.recv().fuse() => match outgoing { Some(outgoing) => { - if let Err(result) = writer.write_message(&outgoing).await { - break 'outer Err(result).context("failed to write RPC message") + match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await { + None => break 'outer Err(anyhow!("timed out writing RPC message")), + Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"), + _ => {} } } None => break 'outer Ok(()), diff --git a/crates/server/src/rpc.rs b/crates/server/src/rpc.rs index 0c60ba3cbd..64463f2723 100644 --- a/crates/server/src/rpc.rs +++ b/crates/server/src/rpc.rs @@ -6,9 +6,10 @@ use super::{ AppState, }; use anyhow::anyhow; -use async_std::{sync::RwLock, task}; +use async_std::task; use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use futures::{future::BoxFuture, FutureExt}; +use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use postage::{mpsc, prelude::Sink as _, prelude::Stream as _}; use rpc::{ proto::{self, AnyTypedEnvelope, EnvelopedMessage}, @@ -23,7 +24,7 @@ use std::{ sync::Arc, time::Instant, }; -use store::{JoinedWorktree, Store, Worktree}; +use store::{Store, Worktree}; use surf::StatusCode; use tide::log; use tide::{ @@ -116,9 +117,7 @@ impl Server { async move { let (connection_id, handle_io, mut incoming_rx) = this.peer.add_connection(connection).await; - this.state_mut() - .await - .add_connection(connection_id, user_id); + this.state_mut().add_connection(connection_id, user_id); if let Err(err) = this.update_collaborators_for_users(&[user_id]).await { log::error!("error updating collaborators for {:?}: {}", user_id, err); } @@ -168,7 +167,7 @@ impl Server { async fn sign_out(self: &mut Arc, connection_id: ConnectionId) -> tide::Result<()> { self.peer.disconnect(connection_id).await; - let removed_connection = self.state_mut().await.remove_connection(connection_id)?; + let removed_connection = self.state_mut().remove_connection(connection_id)?; for (worktree_id, worktree) in removed_connection.hosted_worktrees { if let Some(share) = worktree.share { @@ -213,10 +212,7 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let receipt = request.receipt(); - let host_user_id = self - .state() - .await - .user_id_for_connection(request.sender_id)?; + let host_user_id = self.state().user_id_for_connection(request.sender_id)?; let mut collaborator_user_ids = HashSet::new(); collaborator_user_ids.insert(host_user_id); @@ -236,7 +232,7 @@ impl Server { } let collaborator_user_ids = collaborator_user_ids.into_iter().collect::>(); - let worktree_id = self.state_mut().await.add_worktree(Worktree { + let worktree_id = self.state_mut().add_worktree(Worktree { host_connection_id: request.sender_id, collaborator_user_ids: collaborator_user_ids.clone(), root_name: request.payload.root_name, @@ -259,7 +255,6 @@ impl Server { let worktree_id = request.payload.worktree_id; let worktree = self .state_mut() - .await .remove_worktree(worktree_id, request.sender_id)?; if let Some(share) = worktree.share { @@ -294,7 +289,6 @@ impl Server { let collaborator_user_ids = self.state_mut() - .await .share_worktree(worktree.id, request.sender_id, entries); if let Some(collaborator_user_ids) = collaborator_user_ids { self.peer @@ -322,7 +316,6 @@ impl Server { let worktree_id = request.payload.worktree_id; let worktree = self .state_mut() - .await .unshare_worktree(worktree_id, request.sender_id)?; broadcast(request.sender_id, worktree.connection_ids, |conn_id| { @@ -341,22 +334,17 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let worktree_id = request.payload.worktree_id; - let user_id = self - .state() - .await - .user_id_for_connection(request.sender_id)?; - let mut state = self.state_mut().await; - match state.join_worktree(request.sender_id, user_id, worktree_id) { - Ok(JoinedWorktree { - replica_id, - worktree, - }) => { - let share = worktree.share()?; + let user_id = self.state().user_id_for_connection(request.sender_id)?; + let response_data = self + .state_mut() + .join_worktree(request.sender_id, user_id, worktree_id) + .and_then(|joined| { + let share = joined.worktree.share()?; let peer_count = share.guest_connection_ids.len(); let mut peers = Vec::with_capacity(peer_count); peers.push(proto::Peer { - peer_id: worktree.host_connection_id.0, + peer_id: joined.worktree.host_connection_id.0, replica_id: 0, }); for (peer_conn_id, peer_replica_id) in &share.guest_connection_ids { @@ -370,16 +358,19 @@ impl Server { let response = proto::JoinWorktreeResponse { worktree: Some(proto::Worktree { id: worktree_id, - root_name: worktree.root_name.clone(), + root_name: joined.worktree.root_name.clone(), entries: share.entries.values().cloned().collect(), }), - replica_id: replica_id as u32, + replica_id: joined.replica_id as u32, peers, }; - let connection_ids = worktree.connection_ids(); - let collaborator_user_ids = worktree.collaborator_user_ids.clone(); - drop(state); + let connection_ids = joined.worktree.connection_ids(); + let collaborator_user_ids = joined.worktree.collaborator_user_ids.clone(); + Ok((response, connection_ids, collaborator_user_ids)) + }); + match response_data { + Ok((response, connection_ids, collaborator_user_ids)) => { broadcast(request.sender_id, connection_ids, |conn_id| { self.peer.send( conn_id, @@ -398,7 +389,6 @@ impl Server { .await?; } Err(error) => { - drop(state); self.peer .respond_with_error( request.receipt(), @@ -419,10 +409,7 @@ impl Server { ) -> tide::Result<()> { let sender_id = request.sender_id; let worktree_id = request.payload.worktree_id; - let worktree = self - .state_mut() - .await - .leave_worktree(sender_id, worktree_id); + let worktree = self.state_mut().leave_worktree(sender_id, worktree_id); if let Some(worktree) = worktree { broadcast(sender_id, worktree.connection_ids, |conn_id| { self.peer.send( @@ -444,7 +431,7 @@ impl Server { mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - let connection_ids = self.state_mut().await.update_worktree( + let connection_ids = self.state_mut().update_worktree( request.sender_id, request.payload.worktree_id, &request.payload.removed_entries, @@ -467,7 +454,6 @@ impl Server { let receipt = request.receipt(); let host_connection_id = self .state() - .await .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?; let response = self .peer @@ -483,7 +469,6 @@ impl Server { ) -> tide::Result<()> { let host_connection_id = self .state() - .await .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?; self.peer .forward_send(request.sender_id, host_connection_id, request.payload) @@ -498,7 +483,7 @@ impl Server { let host; let guests; { - let state = self.state().await; + let state = self.state(); host = state .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?; guests = state @@ -532,16 +517,13 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - broadcast( - request.sender_id, - self.state() - .await - .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?, - |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }, - ) + let receiver_ids = self + .state() + .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?; + broadcast(request.sender_id, receiver_ids, |connection_id| { + self.peer + .forward_send(request.sender_id, connection_id, request.payload.clone()) + }) .await?; self.peer.respond(request.receipt(), proto::Ack {}).await?; Ok(()) @@ -551,17 +533,13 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - broadcast( - request.sender_id, - self.store - .read() - .await - .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?, - |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }, - ) + let receiver_ids = self + .state() + .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?; + broadcast(request.sender_id, receiver_ids, |connection_id| { + self.peer + .forward_send(request.sender_id, connection_id, request.payload.clone()) + }) .await?; Ok(()) } @@ -570,10 +548,7 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - let user_id = self - .state() - .await - .user_id_for_connection(request.sender_id)?; + let user_id = self.state().user_id_for_connection(request.sender_id)?; let channels = self.app_state.db.get_accessible_channels(user_id).await?; self.peer .respond( @@ -622,20 +597,20 @@ impl Server { ) -> tide::Result<()> { let mut send_futures = Vec::new(); - let state = self.state().await; - for user_id in user_ids { - let collaborators = state.collaborators_for_user(*user_id); - for connection_id in state.connection_ids_for_user(*user_id) { - send_futures.push(self.peer.send( - connection_id, - proto::UpdateCollaborators { - collaborators: collaborators.clone(), - }, - )); + { + let state = self.state(); + for user_id in user_ids { + let collaborators = state.collaborators_for_user(*user_id); + for connection_id in state.connection_ids_for_user(*user_id) { + send_futures.push(self.peer.send( + connection_id, + proto::UpdateCollaborators { + collaborators: collaborators.clone(), + }, + )); + } } } - - drop(state); futures::future::try_join_all(send_futures).await?; Ok(()) @@ -645,10 +620,7 @@ impl Server { mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - let user_id = self - .state() - .await - .user_id_for_connection(request.sender_id)?; + let user_id = self.state().user_id_for_connection(request.sender_id)?; let channel_id = ChannelId::from_proto(request.payload.channel_id); if !self .app_state @@ -659,9 +631,7 @@ impl Server { Err(anyhow!("access denied"))?; } - self.state_mut() - .await - .join_channel(request.sender_id, channel_id); + self.state_mut().join_channel(request.sender_id, channel_id); let messages = self .app_state .db @@ -692,10 +662,7 @@ impl Server { mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - let user_id = self - .state() - .await - .user_id_for_connection(request.sender_id)?; + let user_id = self.state().user_id_for_connection(request.sender_id)?; let channel_id = ChannelId::from_proto(request.payload.channel_id); if !self .app_state @@ -707,7 +674,6 @@ impl Server { } self.state_mut() - .await .leave_channel(request.sender_id, channel_id); Ok(()) @@ -722,7 +688,7 @@ impl Server { let user_id; let connection_ids; { - let state = self.state().await; + let state = self.state(); user_id = state.user_id_for_connection(request.sender_id)?; if let Some(ids) = state.channel_connection_ids(channel_id) { connection_ids = ids; @@ -809,10 +775,7 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - let user_id = self - .state() - .await - .user_id_for_connection(request.sender_id)?; + let user_id = self.state().user_id_for_connection(request.sender_id)?; let channel_id = ChannelId::from_proto(request.payload.channel_id); if !self .app_state @@ -853,15 +816,11 @@ impl Server { Ok(()) } - fn state<'a>( - self: &'a Arc, - ) -> impl Future> { + fn state<'a>(self: &'a Arc) -> RwLockReadGuard<'a, Store> { self.store.read() } - fn state_mut<'a>( - self: &'a mut Arc, - ) -> impl Future> { + fn state_mut<'a>(self: &'a mut Arc) -> RwLockWriteGuard<'a, Store> { self.store.write() } } @@ -961,7 +920,7 @@ mod tests { github, AppState, Config, }; use ::rpc::Peer; - use async_std::{sync::RwLockReadGuard, task}; + use async_std::task; use gpui::{ModelHandle, TestAppContext}; use parking_lot::Mutex; use postage::{mpsc, watch}; @@ -2372,7 +2331,7 @@ mod tests { } async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> { - self.server.store.read().await + self.server.store.read() } async fn condition(&mut self, mut predicate: F) @@ -2380,7 +2339,7 @@ mod tests { F: FnMut(&Store) -> bool, { async_std::future::timeout(Duration::from_millis(500), async { - while !(predicate)(&*self.server.store.read().await) { + while !(predicate)(&*self.server.store.read()) { self.notifications.recv().await; } })