Ensure worktrees have been sent before responding with definitions

Changing the frequency at which we update worktrees highlighted a
problem in the randomized tests that was causing clients to receive
a definition to a worktree *before* observing the registration of
the worktree itself. This was most likely caused by #1224 because
the scenario that pull request enabled was the following:

- Guest requests a definition pointing to a non-existant worktree
- Server forwards the request to the host
- Host sends an `UpdateProject` message
- Host sends a response to the definition request
- Server observes the `UpdateProject` message and tries to acquire
  the store
- Given that we're waiting, the server goes ahead to process the
  response for the definition request, responding *before*
  `UpdateProject` is forwarded
- Server finally forwards `UpdateProject` to the guest

This commit ensures that, after forwarding a project request and getting a
response, we acquire a lock to the store again to ensure the project still
exists. This has the effect of ordering the forwarded request *after* any
message that was received prior to the response and for which we are still
waiting to acquire a lock to the store.
This commit is contained in:
Antonio Scandurra 2022-07-01 11:45:30 +02:00
parent 8a105bf12f
commit d36a4888db
2 changed files with 49 additions and 76 deletions

View file

@ -50,7 +50,6 @@ use std::{
time::Duration, time::Duration,
}; };
use theme::ThemeRegistry; use theme::ThemeRegistry;
use tokio::sync::RwLockReadGuard;
use workspace::{Item, SplitDirection, ToggleFollow, Workspace}; use workspace::{Item, SplitDirection, ToggleFollow, Workspace};
#[ctor::ctor] #[ctor::ctor]
@ -589,7 +588,7 @@ async fn test_offline_projects(
deterministic.run_until_parked(); deterministic.run_until_parked();
assert!(server assert!(server
.store .store
.read() .lock()
.await .await
.project_metadata_for_user(user_a) .project_metadata_for_user(user_a)
.is_empty()); .is_empty());
@ -620,7 +619,7 @@ async fn test_offline_projects(
cx_a.foreground().advance_clock(rpc::RECEIVE_TIMEOUT); cx_a.foreground().advance_clock(rpc::RECEIVE_TIMEOUT);
assert!(server assert!(server
.store .store
.read() .lock()
.await .await
.project_metadata_for_user(user_a) .project_metadata_for_user(user_a)
.is_empty()); .is_empty());
@ -1446,7 +1445,7 @@ async fn test_collaborating_with_diagnostics(
// Wait for server to see the diagnostics update. // Wait for server to see the diagnostics update.
deterministic.run_until_parked(); deterministic.run_until_parked();
{ {
let store = server.store.read().await; let store = server.store.lock().await;
let project = store.project(ProjectId::from_proto(project_id)).unwrap(); let project = store.project(ProjectId::from_proto(project_id)).unwrap();
let worktree = project.worktrees.get(&worktree_id.to_proto()).unwrap(); let worktree = project.worktrees.get(&worktree_id.to_proto()).unwrap();
assert!(!worktree.diagnostic_summaries.is_empty()); assert!(!worktree.diagnostic_summaries.is_empty());
@ -3172,7 +3171,7 @@ async fn test_basic_chat(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
assert_eq!( assert_eq!(
server server
.state() .store()
.await .await
.channel(channel_id) .channel(channel_id)
.unwrap() .unwrap()
@ -4660,7 +4659,7 @@ async fn test_random_collaboration(
.unwrap(); .unwrap();
let contacts = server let contacts = server
.store .store
.read() .lock()
.await .await
.build_initial_contacts_update(contacts) .build_initial_contacts_update(contacts)
.contacts; .contacts;
@ -4745,7 +4744,7 @@ async fn test_random_collaboration(
let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap(); let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap();
let contacts = server let contacts = server
.store .store
.read() .lock()
.await .await
.build_initial_contacts_update(contacts) .build_initial_contacts_update(contacts)
.contacts; .contacts;
@ -5077,10 +5076,6 @@ impl TestServer {
}) })
} }
async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> {
self.server.store.read().await
}
async fn condition<F>(&mut self, mut predicate: F) async fn condition<F>(&mut self, mut predicate: F)
where where
F: FnMut(&Store) -> bool, F: FnMut(&Store) -> bool,
@ -5089,7 +5084,7 @@ impl TestServer {
self.foreground.parking_forbidden(), self.foreground.parking_forbidden(),
"you must call forbid_parking to use server conditions so we don't block indefinitely" "you must call forbid_parking to use server conditions so we don't block indefinitely"
); );
while !(predicate)(&*self.server.store.read().await) { while !(predicate)(&*self.server.store.lock().await) {
self.foreground.start_waiting(); self.foreground.start_waiting();
self.notifications.next().await; self.notifications.next().await;
self.foreground.finish_waiting(); self.foreground.finish_waiting();

View file

@ -51,7 +51,7 @@ use std::{
}; };
use time::OffsetDateTime; use time::OffsetDateTime;
use tokio::{ use tokio::{
sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}, sync::{Mutex, MutexGuard},
time::Sleep, time::Sleep,
}; };
use tower::ServiceBuilder; use tower::ServiceBuilder;
@ -97,7 +97,7 @@ impl<R: RequestMessage> Response<R> {
pub struct Server { pub struct Server {
peer: Arc<Peer>, peer: Arc<Peer>,
pub(crate) store: RwLock<Store>, pub(crate) store: Mutex<Store>,
app_state: Arc<AppState>, app_state: Arc<AppState>,
handlers: HashMap<TypeId, MessageHandler>, handlers: HashMap<TypeId, MessageHandler>,
notifications: Option<mpsc::UnboundedSender<()>>, notifications: Option<mpsc::UnboundedSender<()>>,
@ -115,13 +115,8 @@ pub struct RealExecutor;
const MESSAGE_COUNT_PER_PAGE: usize = 100; const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024; const MAX_MESSAGE_LEN: usize = 1024;
struct StoreReadGuard<'a> { pub(crate) struct StoreGuard<'a> {
guard: RwLockReadGuard<'a, Store>, guard: MutexGuard<'a, Store>,
_not_send: PhantomData<Rc<()>>,
}
struct StoreWriteGuard<'a> {
guard: RwLockWriteGuard<'a, Store>,
_not_send: PhantomData<Rc<()>>, _not_send: PhantomData<Rc<()>>,
} }
@ -129,7 +124,7 @@ struct StoreWriteGuard<'a> {
pub struct ServerSnapshot<'a> { pub struct ServerSnapshot<'a> {
peer: &'a Peer, peer: &'a Peer,
#[serde(serialize_with = "serialize_deref")] #[serde(serialize_with = "serialize_deref")]
store: RwLockReadGuard<'a, Store>, store: StoreGuard<'a>,
} }
pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error> pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
@ -384,7 +379,7 @@ impl Server {
).await?; ).await?;
{ {
let mut store = this.store_mut().await; let mut store = this.store().await;
store.add_connection(connection_id, user_id, user.admin); store.add_connection(connection_id, user_id, user.admin);
this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?; this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
@ -471,7 +466,7 @@ impl Server {
let mut projects_to_unregister = Vec::new(); let mut projects_to_unregister = Vec::new();
let removed_user_id; let removed_user_id;
{ {
let mut store = self.store_mut().await; let mut store = self.store().await;
let removed_connection = store.remove_connection(connection_id)?; let removed_connection = store.remove_connection(connection_id)?;
for (project_id, project) in removed_connection.hosted_projects { for (project_id, project) in removed_connection.hosted_projects {
@ -605,7 +600,7 @@ impl Server {
.await .await
.user_id_for_connection(request.sender_id)?; .user_id_for_connection(request.sender_id)?;
let project_id = self.app_state.db.register_project(user_id).await?; let project_id = self.app_state.db.register_project(user_id).await?;
self.store_mut() self.store()
.await .await
.register_project(request.sender_id, project_id)?; .register_project(request.sender_id, project_id)?;
@ -623,7 +618,7 @@ impl Server {
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let (user_id, project) = { let (user_id, project) = {
let mut state = self.store_mut().await; let mut state = self.store().await;
let project = state.unregister_project(project_id, request.sender_id)?; let project = state.unregister_project(project_id, request.sender_id)?;
(state.user_id_for_connection(request.sender_id)?, project) (state.user_id_for_connection(request.sender_id)?, project)
}; };
@ -725,7 +720,7 @@ impl Server {
return Err(anyhow!("no such project"))?; return Err(anyhow!("no such project"))?;
} }
self.store_mut().await.request_join_project( self.store().await.request_join_project(
guest_user_id, guest_user_id,
project_id, project_id,
response.into_receipt(), response.into_receipt(),
@ -747,7 +742,7 @@ impl Server {
let host_user_id; let host_user_id;
{ {
let mut state = self.store_mut().await; let mut state = self.store().await;
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let project = state.project(project_id)?; let project = state.project(project_id)?;
if project.host_connection_id != request.sender_id { if project.host_connection_id != request.sender_id {
@ -897,7 +892,7 @@ impl Server {
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let project; let project;
{ {
let mut store = self.store_mut().await; let mut store = self.store().await;
project = store.leave_project(sender_id, project_id)?; project = store.leave_project(sender_id, project_id)?;
tracing::info!( tracing::info!(
%project_id, %project_id,
@ -948,7 +943,7 @@ impl Server {
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let user_id; let user_id;
{ {
let mut state = self.store_mut().await; let mut state = self.store().await;
user_id = state.user_id_for_connection(request.sender_id)?; user_id = state.user_id_for_connection(request.sender_id)?;
let guest_connection_ids = state let guest_connection_ids = state
.read_project(project_id, request.sender_id)? .read_project(project_id, request.sender_id)?
@ -967,7 +962,7 @@ impl Server {
self: Arc<Server>, self: Arc<Server>,
request: TypedEnvelope<proto::RegisterProjectActivity>, request: TypedEnvelope<proto::RegisterProjectActivity>,
) -> Result<()> { ) -> Result<()> {
self.store_mut().await.register_project_activity( self.store().await.register_project_activity(
ProjectId::from_proto(request.payload.project_id), ProjectId::from_proto(request.payload.project_id),
request.sender_id, request.sender_id,
)?; )?;
@ -982,7 +977,7 @@ impl Server {
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let worktree_id = request.payload.worktree_id; let worktree_id = request.payload.worktree_id;
let (connection_ids, metadata_changed, extension_counts) = { let (connection_ids, metadata_changed, extension_counts) = {
let mut store = self.store_mut().await; let mut store = self.store().await;
let (connection_ids, metadata_changed, extension_counts) = store.update_worktree( let (connection_ids, metadata_changed, extension_counts) = store.update_worktree(
request.sender_id, request.sender_id,
project_id, project_id,
@ -1024,7 +1019,7 @@ impl Server {
.summary .summary
.clone() .clone()
.ok_or_else(|| anyhow!("invalid summary"))?; .ok_or_else(|| anyhow!("invalid summary"))?;
let receiver_ids = self.store_mut().await.update_diagnostic_summary( let receiver_ids = self.store().await.update_diagnostic_summary(
ProjectId::from_proto(request.payload.project_id), ProjectId::from_proto(request.payload.project_id),
request.payload.worktree_id, request.payload.worktree_id,
request.sender_id, request.sender_id,
@ -1042,7 +1037,7 @@ impl Server {
self: Arc<Server>, self: Arc<Server>,
request: TypedEnvelope<proto::StartLanguageServer>, request: TypedEnvelope<proto::StartLanguageServer>,
) -> Result<()> { ) -> Result<()> {
let receiver_ids = self.store_mut().await.start_language_server( let receiver_ids = self.store().await.start_language_server(
ProjectId::from_proto(request.payload.project_id), ProjectId::from_proto(request.payload.project_id),
request.sender_id, request.sender_id,
request request
@ -1081,20 +1076,23 @@ impl Server {
where where
T: EntityMessage + RequestMessage, T: EntityMessage + RequestMessage,
{ {
let project_id = ProjectId::from_proto(request.payload.remote_entity_id());
let host_connection_id = self let host_connection_id = self
.store() .store()
.await .await
.read_project( .read_project(project_id, request.sender_id)?
ProjectId::from_proto(request.payload.remote_entity_id()),
request.sender_id,
)?
.host_connection_id; .host_connection_id;
let payload = self
response.send( .peer
self.peer
.forward_request(request.sender_id, host_connection_id, request.payload) .forward_request(request.sender_id, host_connection_id, request.payload)
.await?, .await?;
)?;
// Ensure project still exists by the time we get the response from the host.
self.store()
.await
.read_project(project_id, request.sender_id)?;
response.send(payload)?;
Ok(()) Ok(())
} }
@ -1135,7 +1133,7 @@ impl Server {
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let receiver_ids = { let receiver_ids = {
let mut store = self.store_mut().await; let mut store = self.store().await;
store.register_project_activity(project_id, request.sender_id)?; store.register_project_activity(project_id, request.sender_id)?;
store.project_connection_ids(project_id, request.sender_id)? store.project_connection_ids(project_id, request.sender_id)?
}; };
@ -1202,7 +1200,7 @@ impl Server {
let leader_id = ConnectionId(request.payload.leader_id); let leader_id = ConnectionId(request.payload.leader_id);
let follower_id = request.sender_id; let follower_id = request.sender_id;
{ {
let mut store = self.store_mut().await; let mut store = self.store().await;
if !store if !store
.project_connection_ids(project_id, follower_id)? .project_connection_ids(project_id, follower_id)?
.contains(&leader_id) .contains(&leader_id)
@ -1227,7 +1225,7 @@ impl Server {
async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> { async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let leader_id = ConnectionId(request.payload.leader_id); let leader_id = ConnectionId(request.payload.leader_id);
let mut store = self.store_mut().await; let mut store = self.store().await;
if !store if !store
.project_connection_ids(project_id, request.sender_id)? .project_connection_ids(project_id, request.sender_id)?
.contains(&leader_id) .contains(&leader_id)
@ -1245,7 +1243,7 @@ impl Server {
request: TypedEnvelope<proto::UpdateFollowers>, request: TypedEnvelope<proto::UpdateFollowers>,
) -> Result<()> { ) -> Result<()> {
let project_id = ProjectId::from_proto(request.payload.project_id); let project_id = ProjectId::from_proto(request.payload.project_id);
let mut store = self.store_mut().await; let mut store = self.store().await;
store.register_project_activity(project_id, request.sender_id)?; store.register_project_activity(project_id, request.sender_id)?;
let connection_ids = store.project_connection_ids(project_id, request.sender_id)?; let connection_ids = store.project_connection_ids(project_id, request.sender_id)?;
let leader_id = request let leader_id = request
@ -1503,7 +1501,7 @@ impl Server {
Err(anyhow!("access denied"))?; Err(anyhow!("access denied"))?;
} }
self.store_mut() self.store()
.await .await
.join_channel(request.sender_id, channel_id); .join_channel(request.sender_id, channel_id);
let messages = self let messages = self
@ -1545,7 +1543,7 @@ impl Server {
Err(anyhow!("access denied"))?; Err(anyhow!("access denied"))?;
} }
self.store_mut() self.store()
.await .await
.leave_channel(request.sender_id, channel_id); .leave_channel(request.sender_id, channel_id);
@ -1653,25 +1651,13 @@ impl Server {
Ok(()) Ok(())
} }
async fn store<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> { pub(crate) async fn store<'a>(&'a self) -> StoreGuard<'a> {
#[cfg(test)] #[cfg(test)]
tokio::task::yield_now().await; tokio::task::yield_now().await;
let guard = self.store.read().await; let guard = self.store.lock().await;
#[cfg(test)] #[cfg(test)]
tokio::task::yield_now().await; tokio::task::yield_now().await;
StoreReadGuard { StoreGuard {
guard,
_not_send: PhantomData,
}
}
async fn store_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
#[cfg(test)]
tokio::task::yield_now().await;
let guard = self.store.write().await;
#[cfg(test)]
tokio::task::yield_now().await;
StoreWriteGuard {
guard, guard,
_not_send: PhantomData, _not_send: PhantomData,
} }
@ -1679,13 +1665,13 @@ impl Server {
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> { pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
ServerSnapshot { ServerSnapshot {
store: self.store.read().await, store: self.store().await,
peer: &self.peer, peer: &self.peer,
} }
} }
} }
impl<'a> Deref for StoreReadGuard<'a> { impl<'a> Deref for StoreGuard<'a> {
type Target = Store; type Target = Store;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
@ -1693,21 +1679,13 @@ impl<'a> Deref for StoreReadGuard<'a> {
} }
} }
impl<'a> Deref for StoreWriteGuard<'a> { impl<'a> DerefMut for StoreGuard<'a> {
type Target = Store;
fn deref(&self) -> &Self::Target {
&*self.guard
}
}
impl<'a> DerefMut for StoreWriteGuard<'a> {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
&mut *self.guard &mut *self.guard
} }
} }
impl<'a> Drop for StoreWriteGuard<'a> { impl<'a> Drop for StoreGuard<'a> {
fn drop(&mut self) { fn drop(&mut self) {
#[cfg(test)] #[cfg(test)]
self.check_invariants(); self.check_invariants();