diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index d901bd060c..4d23c00d42 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -50,7 +50,6 @@ use std::{ time::Duration, }; use theme::ThemeRegistry; -use tokio::sync::RwLockReadGuard; use workspace::{Item, SplitDirection, ToggleFollow, Workspace}; #[ctor::ctor] @@ -589,7 +588,7 @@ async fn test_offline_projects( deterministic.run_until_parked(); assert!(server .store - .read() + .lock() .await .project_metadata_for_user(user_a) .is_empty()); @@ -620,7 +619,7 @@ async fn test_offline_projects( cx_a.foreground().advance_clock(rpc::RECEIVE_TIMEOUT); assert!(server .store - .read() + .lock() .await .project_metadata_for_user(user_a) .is_empty()); @@ -1446,7 +1445,7 @@ async fn test_collaborating_with_diagnostics( // Wait for server to see the diagnostics update. deterministic.run_until_parked(); { - let store = server.store.read().await; + let store = server.store.lock().await; let project = store.project(ProjectId::from_proto(project_id)).unwrap(); let worktree = project.worktrees.get(&worktree_id.to_proto()).unwrap(); assert!(!worktree.diagnostic_summaries.is_empty()); @@ -3172,7 +3171,7 @@ async fn test_basic_chat(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { assert_eq!( server - .state() + .store() .await .channel(channel_id) .unwrap() @@ -4660,7 +4659,7 @@ async fn test_random_collaboration( .unwrap(); let contacts = server .store - .read() + .lock() .await .build_initial_contacts_update(contacts) .contacts; @@ -4745,7 +4744,7 @@ async fn test_random_collaboration( let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap(); let contacts = server .store - .read() + .lock() .await .build_initial_contacts_update(contacts) .contacts; @@ -5077,10 +5076,6 @@ impl TestServer { }) } - async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> { - self.server.store.read().await - } - async fn condition(&mut self, mut predicate: F) where F: FnMut(&Store) -> bool, @@ -5089,7 +5084,7 @@ impl TestServer { self.foreground.parking_forbidden(), "you must call forbid_parking to use server conditions so we don't block indefinitely" ); - while !(predicate)(&*self.server.store.read().await) { + while !(predicate)(&*self.server.store.lock().await) { self.foreground.start_waiting(); self.notifications.next().await; self.foreground.finish_waiting(); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index b3dc965ff3..b7b0e00f2d 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -51,7 +51,7 @@ use std::{ }; use time::OffsetDateTime; use tokio::{ - sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}, + sync::{Mutex, MutexGuard}, time::Sleep, }; use tower::ServiceBuilder; @@ -97,7 +97,7 @@ impl Response { pub struct Server { peer: Arc, - pub(crate) store: RwLock, + pub(crate) store: Mutex, app_state: Arc, handlers: HashMap, notifications: Option>, @@ -115,13 +115,8 @@ pub struct RealExecutor; const MESSAGE_COUNT_PER_PAGE: usize = 100; const MAX_MESSAGE_LEN: usize = 1024; -struct StoreReadGuard<'a> { - guard: RwLockReadGuard<'a, Store>, - _not_send: PhantomData>, -} - -struct StoreWriteGuard<'a> { - guard: RwLockWriteGuard<'a, Store>, +pub(crate) struct StoreGuard<'a> { + guard: MutexGuard<'a, Store>, _not_send: PhantomData>, } @@ -129,7 +124,7 @@ struct StoreWriteGuard<'a> { pub struct ServerSnapshot<'a> { peer: &'a Peer, #[serde(serialize_with = "serialize_deref")] - store: RwLockReadGuard<'a, Store>, + store: StoreGuard<'a>, } pub fn serialize_deref(value: &T, serializer: S) -> Result @@ -384,7 +379,7 @@ impl Server { ).await?; { - let mut store = this.store_mut().await; + let mut store = this.store().await; store.add_connection(connection_id, user_id, user.admin); this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?; @@ -471,7 +466,7 @@ impl Server { let mut projects_to_unregister = Vec::new(); let removed_user_id; { - let mut store = self.store_mut().await; + let mut store = self.store().await; let removed_connection = store.remove_connection(connection_id)?; for (project_id, project) in removed_connection.hosted_projects { @@ -605,7 +600,7 @@ impl Server { .await .user_id_for_connection(request.sender_id)?; let project_id = self.app_state.db.register_project(user_id).await?; - self.store_mut() + self.store() .await .register_project(request.sender_id, project_id)?; @@ -623,7 +618,7 @@ impl Server { ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let (user_id, project) = { - let mut state = self.store_mut().await; + let mut state = self.store().await; let project = state.unregister_project(project_id, request.sender_id)?; (state.user_id_for_connection(request.sender_id)?, project) }; @@ -725,7 +720,7 @@ impl Server { return Err(anyhow!("no such project"))?; } - self.store_mut().await.request_join_project( + self.store().await.request_join_project( guest_user_id, project_id, response.into_receipt(), @@ -747,7 +742,7 @@ impl Server { let host_user_id; { - let mut state = self.store_mut().await; + let mut state = self.store().await; let project_id = ProjectId::from_proto(request.payload.project_id); let project = state.project(project_id)?; if project.host_connection_id != request.sender_id { @@ -897,7 +892,7 @@ impl Server { let project_id = ProjectId::from_proto(request.payload.project_id); let project; { - let mut store = self.store_mut().await; + let mut store = self.store().await; project = store.leave_project(sender_id, project_id)?; tracing::info!( %project_id, @@ -948,7 +943,7 @@ impl Server { let project_id = ProjectId::from_proto(request.payload.project_id); let user_id; { - let mut state = self.store_mut().await; + let mut state = self.store().await; user_id = state.user_id_for_connection(request.sender_id)?; let guest_connection_ids = state .read_project(project_id, request.sender_id)? @@ -967,7 +962,7 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> Result<()> { - self.store_mut().await.register_project_activity( + self.store().await.register_project_activity( ProjectId::from_proto(request.payload.project_id), request.sender_id, )?; @@ -982,7 +977,7 @@ impl Server { let project_id = ProjectId::from_proto(request.payload.project_id); let worktree_id = request.payload.worktree_id; let (connection_ids, metadata_changed, extension_counts) = { - let mut store = self.store_mut().await; + let mut store = self.store().await; let (connection_ids, metadata_changed, extension_counts) = store.update_worktree( request.sender_id, project_id, @@ -1024,7 +1019,7 @@ impl Server { .summary .clone() .ok_or_else(|| anyhow!("invalid summary"))?; - let receiver_ids = self.store_mut().await.update_diagnostic_summary( + let receiver_ids = self.store().await.update_diagnostic_summary( ProjectId::from_proto(request.payload.project_id), request.payload.worktree_id, request.sender_id, @@ -1042,7 +1037,7 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> Result<()> { - let receiver_ids = self.store_mut().await.start_language_server( + let receiver_ids = self.store().await.start_language_server( ProjectId::from_proto(request.payload.project_id), request.sender_id, request @@ -1081,20 +1076,23 @@ impl Server { where T: EntityMessage + RequestMessage, { + let project_id = ProjectId::from_proto(request.payload.remote_entity_id()); let host_connection_id = self .store() .await - .read_project( - ProjectId::from_proto(request.payload.remote_entity_id()), - request.sender_id, - )? + .read_project(project_id, request.sender_id)? .host_connection_id; + let payload = self + .peer + .forward_request(request.sender_id, host_connection_id, request.payload) + .await?; - response.send( - self.peer - .forward_request(request.sender_id, host_connection_id, request.payload) - .await?, - )?; + // Ensure project still exists by the time we get the response from the host. + self.store() + .await + .read_project(project_id, request.sender_id)?; + + response.send(payload)?; Ok(()) } @@ -1135,7 +1133,7 @@ impl Server { ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let receiver_ids = { - let mut store = self.store_mut().await; + let mut store = self.store().await; store.register_project_activity(project_id, request.sender_id)?; store.project_connection_ids(project_id, request.sender_id)? }; @@ -1202,7 +1200,7 @@ impl Server { let leader_id = ConnectionId(request.payload.leader_id); let follower_id = request.sender_id; { - let mut store = self.store_mut().await; + let mut store = self.store().await; if !store .project_connection_ids(project_id, follower_id)? .contains(&leader_id) @@ -1227,7 +1225,7 @@ impl Server { async fn unfollow(self: Arc, request: TypedEnvelope) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let leader_id = ConnectionId(request.payload.leader_id); - let mut store = self.store_mut().await; + let mut store = self.store().await; if !store .project_connection_ids(project_id, request.sender_id)? .contains(&leader_id) @@ -1245,7 +1243,7 @@ impl Server { request: TypedEnvelope, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); - let mut store = self.store_mut().await; + let mut store = self.store().await; store.register_project_activity(project_id, request.sender_id)?; let connection_ids = store.project_connection_ids(project_id, request.sender_id)?; let leader_id = request @@ -1503,7 +1501,7 @@ impl Server { Err(anyhow!("access denied"))?; } - self.store_mut() + self.store() .await .join_channel(request.sender_id, channel_id); let messages = self @@ -1545,7 +1543,7 @@ impl Server { Err(anyhow!("access denied"))?; } - self.store_mut() + self.store() .await .leave_channel(request.sender_id, channel_id); @@ -1653,25 +1651,13 @@ impl Server { Ok(()) } - async fn store<'a>(self: &'a Arc) -> StoreReadGuard<'a> { + pub(crate) async fn store<'a>(&'a self) -> StoreGuard<'a> { #[cfg(test)] tokio::task::yield_now().await; - let guard = self.store.read().await; + let guard = self.store.lock().await; #[cfg(test)] tokio::task::yield_now().await; - StoreReadGuard { - guard, - _not_send: PhantomData, - } - } - - async fn store_mut<'a>(self: &'a Arc) -> StoreWriteGuard<'a> { - #[cfg(test)] - tokio::task::yield_now().await; - let guard = self.store.write().await; - #[cfg(test)] - tokio::task::yield_now().await; - StoreWriteGuard { + StoreGuard { guard, _not_send: PhantomData, } @@ -1679,13 +1665,13 @@ impl Server { pub async fn snapshot<'a>(self: &'a Arc) -> ServerSnapshot<'a> { ServerSnapshot { - store: self.store.read().await, + store: self.store().await, peer: &self.peer, } } } -impl<'a> Deref for StoreReadGuard<'a> { +impl<'a> Deref for StoreGuard<'a> { type Target = Store; fn deref(&self) -> &Self::Target { @@ -1693,21 +1679,13 @@ impl<'a> Deref for StoreReadGuard<'a> { } } -impl<'a> Deref for StoreWriteGuard<'a> { - type Target = Store; - - fn deref(&self) -> &Self::Target { - &*self.guard - } -} - -impl<'a> DerefMut for StoreWriteGuard<'a> { +impl<'a> DerefMut for StoreGuard<'a> { fn deref_mut(&mut self) -> &mut Self::Target { &mut *self.guard } } -impl<'a> Drop for StoreWriteGuard<'a> { +impl<'a> Drop for StoreGuard<'a> { fn drop(&mut self) { #[cfg(test)] self.check_invariants();