diff --git a/.zed.toml b/.zed.toml index 6e8c8fe428..0cbe5c59a5 100644 --- a/.zed.toml +++ b/.zed.toml @@ -1 +1 @@ -collaborators = ["nathansobo", "as-cii", "maxbrunsfeld", "iamnbutler"] +collaborators = ["nathansobo", "as-cii", "maxbrunsfeld", "iamnbutler", "Kethku"] diff --git a/Cargo.lock b/Cargo.lock index 92867463c4..5e22f78551 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -998,7 +998,6 @@ dependencies = [ name = "clock" version = "0.1.0" dependencies = [ - "rpc", "smallvec", ] @@ -2236,6 +2235,7 @@ dependencies = [ "tiny-skia", "tree-sitter", "usvg", + "util", "waker-fn", ] @@ -3959,6 +3959,7 @@ dependencies = [ "async-lock", "async-tungstenite", "base64 0.13.0", + "clock", "futures", "gpui", "log", @@ -3972,6 +3973,7 @@ dependencies = [ "smol", "smol-timeout", "tempdir", + "util", "zstd", ] @@ -5574,7 +5576,6 @@ name = "util" version = "0.1.0" dependencies = [ "anyhow", - "clock", "futures", "log", "rand 0.8.3", @@ -5959,6 +5960,7 @@ name = "zed-server" version = "0.1.0" dependencies = [ "anyhow", + "async-io", "async-sqlx-session", "async-std", "async-trait", diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index c40b78987c..62d2c6fb31 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -137,8 +137,8 @@ struct ClientState { credentials: Option, status: (watch::Sender, watch::Receiver), entity_id_extractors: HashMap u64>>, - _maintain_connection: Option>, - heartbeat_interval: Duration, + _reconnect_task: Option>, + reconnect_interval: Duration, models_by_entity_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakModelHandle>, models_by_message_type: HashMap, model_types_by_message_type: HashMap, @@ -168,8 +168,8 @@ impl Default for ClientState { credentials: None, status: watch::channel_with(Status::SignedOut), entity_id_extractors: Default::default(), - _maintain_connection: None, - heartbeat_interval: Duration::from_secs(5), + _reconnect_task: None, + reconnect_interval: Duration::from_secs(5), models_by_message_type: Default::default(), models_by_entity_type_and_remote_id: Default::default(), model_types_by_message_type: Default::default(), @@ -236,7 +236,7 @@ impl Client { #[cfg(any(test, feature = "test-support"))] pub fn tear_down(&self) { let mut state = self.state.write(); - state._maintain_connection.take(); + state._reconnect_task.take(); state.message_handlers.clear(); state.models_by_message_type.clear(); state.models_by_entity_type_and_remote_id.clear(); @@ -283,21 +283,12 @@ impl Client { match status { Status::Connected { .. } => { - let heartbeat_interval = state.heartbeat_interval; - let this = self.clone(); - let foreground = cx.foreground(); - state._maintain_connection = Some(cx.foreground().spawn(async move { - loop { - foreground.timer(heartbeat_interval).await; - let _ = this.request(proto::Ping {}).await; - } - })); + state._reconnect_task = None; } Status::ConnectionLost => { let this = self.clone(); - let foreground = cx.foreground(); - let heartbeat_interval = state.heartbeat_interval; - state._maintain_connection = Some(cx.spawn(|cx| async move { + let reconnect_interval = state.reconnect_interval; + state._reconnect_task = Some(cx.spawn(|cx| async move { let mut rng = StdRng::from_entropy(); let mut delay = Duration::from_millis(100); while let Err(error) = this.authenticate_and_connect(&cx).await { @@ -308,15 +299,15 @@ impl Client { }, &cx, ); - foreground.timer(delay).await; + cx.background().timer(delay).await; delay = delay .mul_f32(rng.gen_range(1.0..=2.0)) - .min(heartbeat_interval); + .min(reconnect_interval); } })); } Status::SignedOut | Status::UpgradeRequired => { - state._maintain_connection.take(); + state._reconnect_task.take(); } _ => {} } @@ -548,7 +539,11 @@ impl Client { } async fn set_connection(self: &Arc, conn: Connection, cx: &AsyncAppContext) { - let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await; + let executor = cx.background(); + let (connection_id, handle_io, mut incoming) = self + .peer + .add_connection(conn, move |duration| executor.timer(duration)) + .await; cx.foreground() .spawn({ let cx = cx.clone(); @@ -940,26 +935,6 @@ mod tests { use crate::test::{FakeHttpClient, FakeServer}; use gpui::TestAppContext; - #[gpui::test(iterations = 10)] - async fn test_heartbeat(cx: &mut TestAppContext) { - cx.foreground().forbid_parking(); - - let user_id = 5; - let mut client = Client::new(FakeHttpClient::with_404_response()); - let server = FakeServer::for_client(user_id, &mut client, &cx).await; - - cx.foreground().advance_clock(Duration::from_secs(10)); - let ping = server.receive::().await.unwrap(); - server.respond(ping.receipt(), proto::Ack {}).await; - - cx.foreground().advance_clock(Duration::from_secs(10)); - let ping = server.receive::().await.unwrap(); - server.respond(ping.receipt(), proto::Ack {}).await; - - client.disconnect(&cx.to_async()).unwrap(); - assert!(server.receive::().await.is_err()); - } - #[gpui::test(iterations = 10)] async fn test_reconnection(cx: &mut TestAppContext) { cx.foreground().forbid_parking(); @@ -991,8 +966,6 @@ mod tests { server.roll_access_token(); server.allow_connections(); cx.foreground().advance_clock(Duration::from_secs(10)); - assert_eq!(server.auth_count(), 1); - cx.foreground().advance_clock(Duration::from_secs(10)); while !matches!(status.next().await, Some(Status::Connected { .. })) {} assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token } diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 697bf3860c..f630d9c0ee 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -6,6 +6,7 @@ use anyhow::{anyhow, Result}; use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt}; use gpui::{executor, ModelHandle, TestAppContext}; use parking_lot::Mutex; +use postage::barrier; use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope}; use std::{fmt, rc::Rc, sync::Arc}; @@ -22,6 +23,7 @@ struct FakeServerState { connection_id: Option, forbid_connections: bool, auth_count: usize, + connection_killer: Option, access_token: usize, } @@ -74,12 +76,15 @@ impl FakeServer { Err(EstablishConnectionError::Unauthorized)? } - let (client_conn, server_conn, _) = Connection::in_memory(cx.background()); - let (connection_id, io, incoming) = peer.add_connection(server_conn).await; + let (client_conn, server_conn, kill) = + Connection::in_memory(cx.background()); + let (connection_id, io, incoming) = + peer.add_test_connection(server_conn, cx.background()).await; cx.background().spawn(io).detach(); let mut state = state.lock(); state.connection_id = Some(connection_id); state.incoming = Some(incoming); + state.connection_killer = Some(kill); Ok(client_conn) }) } diff --git a/crates/clock/Cargo.toml b/crates/clock/Cargo.toml index 0b2aa2fabf..8e17e15e5e 100644 --- a/crates/clock/Cargo.toml +++ b/crates/clock/Cargo.toml @@ -9,4 +9,3 @@ doctest = false [dependencies] smallvec = { version = "1.6", features = ["union"] } -rpc = { path = "../rpc" } diff --git a/crates/clock/src/clock.rs b/crates/clock/src/clock.rs index 0fdeda0b99..e122a8b96a 100644 --- a/crates/clock/src/clock.rs +++ b/crates/clock/src/clock.rs @@ -69,37 +69,6 @@ impl<'a> AddAssign<&'a Local> for Local { #[derive(Clone, Default, Hash, Eq, PartialEq)] pub struct Global(SmallVec<[u32; 8]>); -impl From> for Global { - fn from(message: Vec) -> Self { - let mut version = Self::new(); - for entry in message { - version.observe(Local { - replica_id: entry.replica_id as ReplicaId, - value: entry.timestamp, - }); - } - version - } -} - -impl<'a> From<&'a Global> for Vec { - fn from(version: &'a Global) -> Self { - version - .iter() - .map(|entry| rpc::proto::VectorClockEntry { - replica_id: entry.replica_id as u32, - timestamp: entry.value, - }) - .collect() - } -} - -impl From for Vec { - fn from(version: Global) -> Self { - (&version).into() - } -} - impl Global { pub fn new() -> Self { Self::default() diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 8f884259b7..9973ac6549 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -14,6 +14,7 @@ test-support = ["backtrace", "dhat", "env_logger", "collections/test-support"] [dependencies] collections = { path = "../collections" } gpui_macros = { path = "../gpui_macros" } +util = { path = "../util" } sum_tree = { path = "../sum_tree" } async-task = "4.0.3" backtrace = { version = "0.3", optional = true } diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs index 9c4b9e90e0..2089b954fb 100644 --- a/crates/gpui/src/executor.rs +++ b/crates/gpui/src/executor.rs @@ -1,6 +1,6 @@ use anyhow::{anyhow, Result}; use async_task::Runnable; -use smol::{channel, prelude::*, Executor, Timer}; +use smol::{channel, prelude::*, Executor}; use std::{ any::Any, fmt::{self, Display}, @@ -86,6 +86,19 @@ pub struct Deterministic { parker: parking_lot::Mutex, } +pub enum Timer { + Production(smol::Timer), + #[cfg(any(test, feature = "test-support"))] + Deterministic(DeterministicTimer), +} + +#[cfg(any(test, feature = "test-support"))] +pub struct DeterministicTimer { + rx: postage::barrier::Receiver, + id: usize, + state: Arc>, +} + #[cfg(any(test, feature = "test-support"))] impl Deterministic { pub fn new(seed: u64) -> Arc { @@ -306,15 +319,82 @@ impl Deterministic { None } - pub fn advance_clock(&self, duration: Duration) { + pub fn timer(&self, duration: Duration) -> Timer { + let (tx, rx) = postage::barrier::channel(); let mut state = self.state.lock(); - state.now += duration; - let now = state.now; - let mut pending_timers = mem::take(&mut state.pending_timers); - drop(state); + let wakeup_at = state.now + duration; + let id = util::post_inc(&mut state.next_timer_id); + state.pending_timers.push((id, wakeup_at, tx)); + let state = self.state.clone(); + Timer::Deterministic(DeterministicTimer { rx, id, state }) + } - pending_timers.retain(|(_, wakeup, _)| *wakeup > now); - self.state.lock().pending_timers.extend(pending_timers); + pub fn advance_clock(&self, duration: Duration) { + let new_now = self.state.lock().now + duration; + loop { + self.run_until_parked(); + let mut state = self.state.lock(); + + if let Some((_, wakeup_time, _)) = state.pending_timers.first() { + let wakeup_time = *wakeup_time; + if wakeup_time < new_now { + let timer_count = state + .pending_timers + .iter() + .take_while(|(_, t, _)| *t == wakeup_time) + .count(); + state.now = wakeup_time; + let timers_to_wake = state + .pending_timers + .drain(0..timer_count) + .collect::>(); + drop(state); + drop(timers_to_wake); + continue; + } + } + + break; + } + + self.state.lock().now = new_now; + } +} + +impl Drop for Timer { + fn drop(&mut self) { + #[cfg(any(test, feature = "test-support"))] + if let Timer::Deterministic(DeterministicTimer { state, id, .. }) = self { + state + .lock() + .pending_timers + .retain(|(timer_id, _, _)| timer_id != id) + } + } +} + +impl Future for Timer { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match &mut *self { + #[cfg(any(test, feature = "test-support"))] + Self::Deterministic(DeterministicTimer { rx, .. }) => { + use postage::stream::{PollRecv, Stream as _}; + smol::pin!(rx); + match rx.poll_recv(&mut postage::Context::from_waker(cx.waker())) { + PollRecv::Ready(()) | PollRecv::Closed => Poll::Ready(()), + PollRecv::Pending => Poll::Pending, + } + } + Self::Production(timer) => { + smol::pin!(timer); + match timer.poll(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } + } } } @@ -438,46 +518,6 @@ impl Foreground { } } - pub async fn timer(&self, duration: Duration) { - match self { - #[cfg(any(test, feature = "test-support"))] - Self::Deterministic { executor, .. } => { - use postage::prelude::Stream as _; - - let (tx, mut rx) = postage::barrier::channel(); - let timer_id; - { - let mut state = executor.state.lock(); - let wakeup_at = state.now + duration; - timer_id = util::post_inc(&mut state.next_timer_id); - state.pending_timers.push((timer_id, wakeup_at, tx)); - } - - struct DropTimer<'a>(usize, &'a Foreground); - impl<'a> Drop for DropTimer<'a> { - fn drop(&mut self) { - match self.1 { - Foreground::Deterministic { executor, .. } => { - executor - .state - .lock() - .pending_timers - .retain(|(timer_id, _, _)| *timer_id != self.0); - } - _ => unreachable!(), - } - } - } - - let _guard = DropTimer(timer_id, self); - rx.recv().await; - } - _ => { - Timer::after(duration).await; - } - } - } - #[cfg(any(test, feature = "test-support"))] pub fn advance_clock(&self, duration: Duration) { match self { @@ -600,6 +640,14 @@ impl Background { } } + pub fn timer(&self, duration: Duration) -> Timer { + match self { + Background::Production { .. } => Timer::Production(smol::Timer::after(duration)), + #[cfg(any(test, feature = "test-support"))] + Background::Deterministic { executor } => executor.timer(duration), + } + } + #[cfg(any(test, feature = "test-support"))] pub async fn simulate_random_delay(&self) { use rand::prelude::*; @@ -612,9 +660,6 @@ impl Background { for _ in 0..yields { yield_now().await; } - - let delay = Duration::from_millis(executor.state.lock().rng.gen_range(0..100)); - executor.advance_clock(delay); } } _ => panic!("this method can only be called on a deterministic executor"), diff --git a/crates/gpui/src/util.rs b/crates/gpui/src/util.rs index 9e59c387e8..dc857b4c66 100644 --- a/crates/gpui/src/util.rs +++ b/crates/gpui/src/util.rs @@ -1,5 +1,6 @@ use smol::future::FutureExt; use std::{future::Future, time::Duration}; +pub use util::*; pub fn post_inc(value: &mut usize) -> usize { let prev = *value; diff --git a/crates/language/src/proto.rs b/crates/language/src/proto.rs index 09e1a4e350..f5e797bca7 100644 --- a/crates/language/src/proto.rs +++ b/crates/language/src/proto.rs @@ -25,13 +25,13 @@ pub fn serialize_operation(operation: &Operation) -> proto::Operation { replica_id: undo.id.replica_id as u32, local_timestamp: undo.id.value, lamport_timestamp: lamport_timestamp.value, - version: From::from(&undo.version), + version: serialize_version(&undo.version), transaction_ranges: undo .transaction_ranges .iter() .map(serialize_range) .collect(), - transaction_version: From::from(&undo.transaction_version), + transaction_version: serialize_version(&undo.transaction_version), counts: undo .counts .iter() @@ -77,7 +77,7 @@ pub fn serialize_edit_operation(operation: &EditOperation) -> proto::operation:: replica_id: operation.timestamp.replica_id as u32, local_timestamp: operation.timestamp.local, lamport_timestamp: operation.timestamp.lamport, - version: From::from(&operation.version), + version: serialize_version(&operation.version), ranges: operation.ranges.iter().map(serialize_range).collect(), new_text: operation.new_text.clone(), } @@ -116,7 +116,7 @@ pub fn serialize_buffer_fragment(fragment: &text::Fragment) -> proto::BufferFrag timestamp: clock.value, }) .collect(), - max_undos: From::from(&fragment.max_undos), + max_undos: serialize_version(&fragment.max_undos), } } @@ -188,7 +188,7 @@ pub fn deserialize_operation(message: proto::Operation) -> Result { replica_id: undo.replica_id as ReplicaId, value: undo.local_timestamp, }, - version: undo.version.into(), + version: deserialize_version(undo.version), counts: undo .counts .into_iter() @@ -207,7 +207,7 @@ pub fn deserialize_operation(message: proto::Operation) -> Result { .into_iter() .map(deserialize_range) .collect(), - transaction_version: undo.transaction_version.into(), + transaction_version: deserialize_version(undo.transaction_version), }, }), proto::operation::Variant::UpdateSelections(message) => { @@ -260,7 +260,7 @@ pub fn deserialize_edit_operation(edit: proto::operation::Edit) -> EditOperation local: edit.local_timestamp, lamport: edit.lamport_timestamp, }, - version: edit.version.into(), + version: deserialize_version(edit.version), ranges: edit.ranges.into_iter().map(deserialize_range).collect(), new_text: edit.new_text, } @@ -309,7 +309,7 @@ pub fn deserialize_buffer_fragment( replica_id: entry.replica_id as ReplicaId, value: entry.timestamp, })), - max_undos: From::from(message.max_undos), + max_undos: deserialize_version(message.max_undos), } } @@ -472,8 +472,8 @@ pub fn serialize_transaction(transaction: &Transaction) -> proto::Transaction { .copied() .map(serialize_local_timestamp) .collect(), - start: (&transaction.start).into(), - end: (&transaction.end).into(), + start: serialize_version(&transaction.start), + end: serialize_version(&transaction.end), ranges: transaction.ranges.iter().map(serialize_range).collect(), } } @@ -490,8 +490,8 @@ pub fn deserialize_transaction(transaction: proto::Transaction) -> Result) -> proto::Range { pub fn deserialize_range(range: proto::Range) -> Range { FullOffset(range.start as usize)..FullOffset(range.end as usize) } + +pub fn deserialize_version(message: Vec) -> clock::Global { + let mut version = clock::Global::new(); + for entry in message { + version.observe(clock::Local { + replica_id: entry.replica_id as ReplicaId, + value: entry.timestamp, + }); + } + version +} + +pub fn serialize_version(version: &clock::Global) -> Vec { + version + .iter() + .map(|entry| proto::VectorClockEntry { + replica_id: entry.replica_id as u32, + timestamp: entry.value, + }) + .collect() +} diff --git a/crates/language/src/tests.rs b/crates/language/src/tests.rs index 34062ee601..3783f1e66d 100644 --- a/crates/language/src/tests.rs +++ b/crates/language/src/tests.rs @@ -11,8 +11,9 @@ use std::{ rc::Rc, time::{Duration, Instant}, }; +use text::network::Network; use unindent::Unindent as _; -use util::{post_inc, test::Network}; +use util::post_inc; #[cfg(test)] #[ctor::ctor] diff --git a/crates/project/src/lsp_command.rs b/crates/project/src/lsp_command.rs index 55af622a2f..4b2a7d89c1 100644 --- a/crates/project/src/lsp_command.rs +++ b/crates/project/src/lsp_command.rs @@ -5,7 +5,7 @@ use client::{proto, PeerId}; use gpui::{AppContext, AsyncAppContext, ModelHandle}; use language::{ point_from_lsp, - proto::{deserialize_anchor, serialize_anchor}, + proto::{deserialize_anchor, deserialize_version, serialize_anchor, serialize_version}, range_from_lsp, Anchor, Bias, Buffer, PointUtf16, ToLspPosition, ToPointUtf16, }; use lsp::{DocumentHighlightKind, ServerCapabilities}; @@ -126,7 +126,7 @@ impl LspCommand for PrepareRename { position: Some(language::proto::serialize_anchor( &buffer.anchor_before(self.position), )), - version: (&buffer.version()).into(), + version: serialize_version(&buffer.version()), } } @@ -142,7 +142,7 @@ impl LspCommand for PrepareRename { .ok_or_else(|| anyhow!("invalid position"))?; buffer .update(&mut cx, |buffer, _| { - buffer.wait_for_version(message.version.into()) + buffer.wait_for_version(deserialize_version(message.version)) }) .await; @@ -166,7 +166,7 @@ impl LspCommand for PrepareRename { end: range .as_ref() .map(|range| language::proto::serialize_anchor(&range.end)), - version: buffer_version.into(), + version: serialize_version(buffer_version), } } @@ -180,7 +180,7 @@ impl LspCommand for PrepareRename { if message.can_rename { buffer .update(&mut cx, |buffer, _| { - buffer.wait_for_version(message.version.into()) + buffer.wait_for_version(deserialize_version(message.version)) }) .await; let start = message.start.and_then(deserialize_anchor); @@ -255,7 +255,7 @@ impl LspCommand for PerformRename { &buffer.anchor_before(self.position), )), new_name: self.new_name.clone(), - version: (&buffer.version()).into(), + version: serialize_version(&buffer.version()), } } @@ -271,7 +271,7 @@ impl LspCommand for PerformRename { .ok_or_else(|| anyhow!("invalid position"))?; buffer .update(&mut cx, |buffer, _| { - buffer.wait_for_version(message.version.into()) + buffer.wait_for_version(deserialize_version(message.version)) }) .await; Ok(Self { @@ -407,7 +407,7 @@ impl LspCommand for GetDefinition { position: Some(language::proto::serialize_anchor( &buffer.anchor_before(self.position), )), - version: (&buffer.version()).into(), + version: serialize_version(&buffer.version()), } } @@ -423,7 +423,7 @@ impl LspCommand for GetDefinition { .ok_or_else(|| anyhow!("invalid position"))?; buffer .update(&mut cx, |buffer, _| { - buffer.wait_for_version(message.version.into()) + buffer.wait_for_version(deserialize_version(message.version)) }) .await; Ok(Self { @@ -566,7 +566,7 @@ impl LspCommand for GetReferences { position: Some(language::proto::serialize_anchor( &buffer.anchor_before(self.position), )), - version: (&buffer.version()).into(), + version: serialize_version(&buffer.version()), } } @@ -582,7 +582,7 @@ impl LspCommand for GetReferences { .ok_or_else(|| anyhow!("invalid position"))?; buffer .update(&mut cx, |buffer, _| { - buffer.wait_for_version(message.version.into()) + buffer.wait_for_version(deserialize_version(message.version)) }) .await; Ok(Self { @@ -706,7 +706,7 @@ impl LspCommand for GetDocumentHighlights { position: Some(language::proto::serialize_anchor( &buffer.anchor_before(self.position), )), - version: (&buffer.version()).into(), + version: serialize_version(&buffer.version()), } } @@ -722,7 +722,7 @@ impl LspCommand for GetDocumentHighlights { .ok_or_else(|| anyhow!("invalid position"))?; buffer .update(&mut cx, |buffer, _| { - buffer.wait_for_version(message.version.into()) + buffer.wait_for_version(deserialize_version(message.version)) }) .await; Ok(Self { diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 6dda03c216..945061a3b9 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -15,7 +15,7 @@ use gpui::{ UpgradeModelHandle, WeakModelHandle, }; use language::{ - proto::{deserialize_anchor, serialize_anchor}, + proto::{deserialize_anchor, deserialize_version, serialize_anchor, serialize_version}, range_from_lsp, Anchor, AnchorRangeExt, Bias, Buffer, CodeAction, CodeLabel, Completion, Diagnostic, DiagnosticEntry, File as _, Language, LanguageRegistry, Operation, PointUtf16, ToLspPosition, ToOffset, ToPointUtf16, Transaction, @@ -1713,14 +1713,14 @@ impl Project { project_id, buffer_id, position: Some(language::proto::serialize_anchor(&anchor)), - version: (&source_buffer.version()).into(), + version: serialize_version(&source_buffer.version()), }; cx.spawn_weak(|_, mut cx| async move { let response = rpc.request(message).await?; source_buffer_handle .update(&mut cx, |buffer, _| { - buffer.wait_for_version(response.version.into()) + buffer.wait_for_version(deserialize_version(response.version)) }) .await; @@ -1910,13 +1910,13 @@ impl Project { buffer_id, start: Some(language::proto::serialize_anchor(&range.start)), end: Some(language::proto::serialize_anchor(&range.end)), - version: (&version).into(), + version: serialize_version(&version), }) .await?; buffer_handle .update(&mut cx, |buffer, _| { - buffer.wait_for_version(response.version.into()) + buffer.wait_for_version(deserialize_version(response.version)) }) .await; @@ -2915,7 +2915,7 @@ impl Project { mut cx: AsyncAppContext, ) -> Result { let buffer_id = envelope.payload.buffer_id; - let requested_version = envelope.payload.version.try_into()?; + let requested_version = deserialize_version(envelope.payload.version); let (project_id, buffer) = this.update(&mut cx, |this, cx| { let project_id = this.remote_id().ok_or_else(|| anyhow!("not connected"))?; @@ -2936,7 +2936,7 @@ impl Project { Ok(proto::BufferSaved { project_id, buffer_id, - version: (&saved_version).into(), + version: serialize_version(&saved_version), mtime: Some(mtime.into()), }) } @@ -2981,7 +2981,7 @@ impl Project { .position .and_then(language::proto::deserialize_anchor) .ok_or_else(|| anyhow!("invalid position"))?; - let version = clock::Global::from(envelope.payload.version); + let version = deserialize_version(envelope.payload.version); let buffer = this.read_with(&cx, |this, cx| { this.opened_buffers .get(&envelope.payload.buffer_id) @@ -3001,7 +3001,7 @@ impl Project { .iter() .map(language::proto::serialize_completion) .collect(), - version: (&version).into(), + version: serialize_version(&version), }) } @@ -3062,7 +3062,7 @@ impl Project { })?; buffer .update(&mut cx, |buffer, _| { - buffer.wait_for_version(envelope.payload.version.into()) + buffer.wait_for_version(deserialize_version(envelope.payload.version)) }) .await; @@ -3077,7 +3077,7 @@ impl Project { .iter() .map(language::proto::serialize_code_action) .collect(), - version: (&version).into(), + version: serialize_version(&version), }) } @@ -3445,7 +3445,7 @@ impl Project { _: Arc, mut cx: AsyncAppContext, ) -> Result<()> { - let version = envelope.payload.version.try_into()?; + let version = deserialize_version(envelope.payload.version); let mtime = envelope .payload .mtime @@ -3473,7 +3473,7 @@ impl Project { mut cx: AsyncAppContext, ) -> Result<()> { let payload = envelope.payload.clone(); - let version = payload.version.try_into()?; + let version = deserialize_version(payload.version); let mtime = payload .mtime .ok_or_else(|| anyhow!("missing mtime"))? diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index 99b0a0b298..78dac23681 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -17,7 +17,10 @@ use gpui::{ executor, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, }; -use language::{Buffer, DiagnosticEntry, Operation, PointUtf16, Rope}; +use language::{ + proto::{deserialize_version, serialize_version}, + Buffer, DiagnosticEntry, Operation, PointUtf16, Rope, +}; use lazy_static::lazy_static; use parking_lot::Mutex; use postage::{ @@ -30,7 +33,7 @@ use smol::channel::{self, Sender}; use std::{ any::Any, cmp::{self, Ordering}, - convert::{TryFrom, TryInto}, + convert::TryFrom, ffi::{OsStr, OsString}, fmt, future::Future, @@ -1423,7 +1426,7 @@ impl language::File for File { rpc.send(proto::BufferSaved { project_id, buffer_id, - version: (&version).into(), + version: serialize_version(&version), mtime: Some(entry.mtime.into()), })?; } @@ -1438,10 +1441,10 @@ impl language::File for File { .request(proto::SaveBuffer { project_id, buffer_id, - version: (&version).into(), + version: serialize_version(&version), }) .await?; - let version = response.version.try_into()?; + let version = deserialize_version(response.version); let mtime = response .mtime .ok_or_else(|| anyhow!("missing mtime"))? @@ -1518,7 +1521,7 @@ impl language::LocalFile for File { .send(proto::BufferReloaded { project_id, buffer_id, - version: version.into(), + version: serialize_version(&version), mtime: Some(mtime.into()), }) .log_err(); diff --git a/crates/rpc/Cargo.toml b/crates/rpc/Cargo.toml index 255906ab85..e773b3f0ba 100644 --- a/crates/rpc/Cargo.toml +++ b/crates/rpc/Cargo.toml @@ -26,7 +26,9 @@ rsa = "0.4" serde = { version = "1", features = ["derive"] } smol-timeout = "0.6" zstd = "0.9" +clock = { path = "../clock" } gpui = { path = "../gpui", optional = true } +util = { path = "../util" } [build-dependencies] prost-build = "0.8" diff --git a/crates/rpc/src/conn.rs b/crates/rpc/src/conn.rs index fb91b72d9f..a97797fc9d 100644 --- a/crates/rpc/src/conn.rs +++ b/crates/rpc/src/conn.rs @@ -1,6 +1,5 @@ use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; -use futures::{SinkExt as _, Stream, StreamExt as _}; -use std::{io, task::Poll}; +use futures::{SinkExt as _, StreamExt as _}; pub struct Connection { pub(crate) tx: @@ -36,87 +35,82 @@ impl Connection { #[cfg(any(test, feature = "test-support"))] pub fn in_memory( executor: std::sync::Arc, - ) -> (Self, Self, postage::watch::Sender>) { - let (kill_tx, mut kill_rx) = postage::watch::channel_with(None); - postage::stream::Stream::try_recv(&mut kill_rx).unwrap(); + ) -> (Self, Self, postage::barrier::Sender) { + use postage::prelude::Stream; - let (a_tx, a_rx) = Self::channel(kill_rx.clone(), executor.clone()); - let (b_tx, b_rx) = Self::channel(kill_rx, executor); - ( + let (kill_tx, kill_rx) = postage::barrier::channel(); + let (a_tx, a_rx) = channel(kill_rx.clone(), executor.clone()); + let (b_tx, b_rx) = channel(kill_rx, executor); + return ( Self { tx: a_tx, rx: b_rx }, Self { tx: b_tx, rx: a_rx }, kill_tx, - ) - } + ); - #[cfg(any(test, feature = "test-support"))] - fn channel( - kill_rx: postage::watch::Receiver>, - executor: std::sync::Arc, - ) -> ( - Box>, - Box>>, - ) { - use futures::channel::mpsc; - use io::{Error, ErrorKind}; - use std::sync::Arc; + fn channel( + kill_rx: postage::barrier::Receiver, + executor: std::sync::Arc, + ) -> ( + Box>, + Box< + dyn Send + Unpin + futures::Stream>, + >, + ) { + use futures::channel::mpsc; + use std::{ + io::{Error, ErrorKind}, + sync::Arc, + }; - let (tx, rx) = mpsc::unbounded::(); - let tx = tx - .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e))) - .with({ - let executor = Arc::downgrade(&executor); - let kill_rx = kill_rx.clone(); - move |msg| { + let (tx, rx) = mpsc::unbounded::(); + + let tx = tx + .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e))) + .with({ let kill_rx = kill_rx.clone(); + let executor = Arc::downgrade(&executor); + move |msg| { + let mut kill_rx = kill_rx.clone(); + let executor = executor.clone(); + Box::pin(async move { + if let Some(executor) = executor.upgrade() { + executor.simulate_random_delay().await; + } + + // Writes to a half-open TCP connection will error. + if kill_rx.try_recv().is_ok() { + std::io::Result::Err( + Error::new(ErrorKind::Other, "connection lost").into(), + )?; + } + + Ok(msg) + }) + } + }); + + let rx = rx.then({ + let kill_rx = kill_rx.clone(); + let executor = Arc::downgrade(&executor); + move |msg| { + let mut kill_rx = kill_rx.clone(); let executor = executor.clone(); Box::pin(async move { if let Some(executor) = executor.upgrade() { executor.simulate_random_delay().await; } - if kill_rx.borrow().is_none() { - Ok(msg) - } else { - Err(Error::new(ErrorKind::Other, "connection killed").into()) + + // Reads from a half-open TCP connection will hang. + if kill_rx.try_recv().is_ok() { + futures::future::pending::<()>().await; } + + Ok(msg) }) } }); - let rx = rx.then(move |msg| { - let executor = Arc::downgrade(&executor); - Box::pin(async move { - if let Some(executor) = executor.upgrade() { - executor.simulate_random_delay().await; - } - msg - }) - }); - let rx = KillableReceiver { kill_rx, rx }; - (Box::new(tx), Box::new(rx)) - } -} - -struct KillableReceiver { - rx: S, - kill_rx: postage::watch::Receiver>, -} - -impl> Stream for KillableReceiver { - type Item = Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) { - Poll::Ready(Some(Err(io::Error::new( - io::ErrorKind::Other, - "connection killed", - ) - .into()))) - } else { - self.rx.poll_next_unpin(cx).map(|value| value.map(Ok)) + (Box::new(tx), Box::new(rx)) } } } diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 0a00f6d801..8f1d66e47a 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -88,13 +88,14 @@ pub struct Peer { #[derive(Clone)] pub struct ConnectionState { - outgoing_tx: futures::channel::mpsc::UnboundedSender, + outgoing_tx: futures::channel::mpsc::UnboundedSender, next_message_id: Arc, response_channels: Arc>>>>, } -const WRITE_TIMEOUT: Duration = Duration::from_secs(10); +const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1); +const WRITE_TIMEOUT: Duration = Duration::from_secs(2); impl Peer { pub fn new() -> Arc { @@ -104,14 +105,20 @@ impl Peer { }) } - pub async fn add_connection( + pub async fn add_connection( self: &Arc, connection: Connection, + create_timer: F, ) -> ( ConnectionId, impl Future> + Send, BoxStream<'static, Box>, - ) { + ) + where + F: Send + Fn(Duration) -> Fut, + Fut: Send + Future, + Out: Send, + { // For outgoing messages, use an unbounded channel so that application code // can always send messages without yielding. For incoming messages, use a // bounded channel so that other peers will receive backpressure if they send @@ -121,7 +128,7 @@ impl Peer { let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst)); let connection_state = ConnectionState { - outgoing_tx, + outgoing_tx: outgoing_tx.clone(), next_message_id: Default::default(), response_channels: Arc::new(Mutex::new(Some(Default::default()))), }; @@ -131,39 +138,59 @@ impl Peer { let this = self.clone(); let response_channels = connection_state.response_channels.clone(); let handle_io = async move { - let result = 'outer: loop { - let read_message = reader.read_message().fuse(); + let _end_connection = util::defer(|| { + response_channels.lock().take(); + this.connections.write().remove(&connection_id); + }); + + // Send messages on this frequency so the connection isn't closed. + let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse(); + futures::pin_mut!(keepalive_timer); + + loop { + let read_message = reader.read().fuse(); futures::pin_mut!(read_message); + + // Disconnect if we don't receive messages at least this frequently. + let receive_timeout = create_timer(3 * KEEPALIVE_INTERVAL).fuse(); + futures::pin_mut!(receive_timeout); + loop { futures::select_biased! { outgoing = outgoing_rx.next().fuse() => match outgoing { Some(outgoing) => { - match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await { - None => break 'outer Err(anyhow!("timed out writing RPC message")), - Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"), - _ => {} + if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await { + result.context("failed to write RPC message")?; + keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse()); + } else { + Err(anyhow!("timed out writing message"))?; } } - None => break 'outer Ok(()), + None => return Ok(()), }, - incoming = read_message => match incoming { - Ok(incoming) => { + incoming = read_message => { + let incoming = incoming.context("received invalid RPC message")?; + if let proto::Message::Envelope(incoming) = incoming { if incoming_tx.send(incoming).await.is_err() { - break 'outer Ok(()); + return Ok(()); } - break; - } - Err(error) => { - break 'outer Err(error).context("received invalid RPC message") } + break; }, + _ = keepalive_timer => { + if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await { + result.context("failed to send keepalive")?; + keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse()); + } else { + Err(anyhow!("timed out sending keepalive"))?; + } + } + _ = receive_timeout => { + Err(anyhow!("delay between messages too long"))? + } } } - }; - - response_channels.lock().take(); - this.connections.write().remove(&connection_id); - result + } }; let response_channels = connection_state.response_channels.clone(); @@ -191,18 +218,31 @@ impl Peer { None } else { - if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) { - Some(envelope) - } else { + proto::build_typed_envelope(connection_id, incoming).or_else(|| { log::error!("unable to construct a typed envelope"); None - } + }) } } }); (connection_id, handle_io, incoming_rx.boxed()) } + #[cfg(any(test, feature = "test-support"))] + pub async fn add_test_connection( + self: &Arc, + connection: Connection, + executor: Arc, + ) -> ( + ConnectionId, + impl Future> + Send, + BoxStream<'static, Box>, + ) { + let executor = executor.clone(); + self.add_connection(connection, move |duration| executor.timer(duration)) + .await + } + pub fn disconnect(&self, connection_id: ConnectionId) { self.connections.write().remove(&connection_id); } @@ -245,11 +285,11 @@ impl Peer { .insert(message_id, tx); connection .outgoing_tx - .unbounded_send(request.into_envelope( + .unbounded_send(proto::Message::Envelope(request.into_envelope( message_id, None, original_sender_id.map(|id| id.0), - )) + ))) .map_err(|_| anyhow!("connection was closed"))?; Ok(()) }); @@ -272,7 +312,9 @@ impl Peer { .fetch_add(1, atomic::Ordering::SeqCst); connection .outgoing_tx - .unbounded_send(message.into_envelope(message_id, None, None))?; + .unbounded_send(proto::Message::Envelope( + message.into_envelope(message_id, None, None), + ))?; Ok(()) } @@ -288,7 +330,11 @@ impl Peer { .fetch_add(1, atomic::Ordering::SeqCst); connection .outgoing_tx - .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?; + .unbounded_send(proto::Message::Envelope(message.into_envelope( + message_id, + None, + Some(sender_id.0), + )))?; Ok(()) } @@ -303,7 +349,11 @@ impl Peer { .fetch_add(1, atomic::Ordering::SeqCst); connection .outgoing_tx - .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?; + .unbounded_send(proto::Message::Envelope(response.into_envelope( + message_id, + Some(receipt.message_id), + None, + )))?; Ok(()) } @@ -318,7 +368,11 @@ impl Peer { .fetch_add(1, atomic::Ordering::SeqCst); connection .outgoing_tx - .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?; + .unbounded_send(proto::Message::Envelope(response.into_envelope( + message_id, + Some(receipt.message_id), + None, + )))?; Ok(()) } @@ -347,17 +401,23 @@ mod tests { let client1 = Peer::new(); let client2 = Peer::new(); - let (client1_to_server_conn, server_to_client_1_conn, _) = + let (client1_to_server_conn, server_to_client_1_conn, _kill) = Connection::in_memory(cx.background()); - let (client1_conn_id, io_task1, client1_incoming) = - client1.add_connection(client1_to_server_conn).await; - let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await; + let (client1_conn_id, io_task1, client1_incoming) = client1 + .add_test_connection(client1_to_server_conn, cx.background()) + .await; + let (_, io_task2, server_incoming1) = server + .add_test_connection(server_to_client_1_conn, cx.background()) + .await; - let (client2_to_server_conn, server_to_client_2_conn, _) = + let (client2_to_server_conn, server_to_client_2_conn, _kill) = Connection::in_memory(cx.background()); - let (client2_conn_id, io_task3, client2_incoming) = - client2.add_connection(client2_to_server_conn).await; - let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await; + let (client2_conn_id, io_task3, client2_incoming) = client2 + .add_test_connection(client2_to_server_conn, cx.background()) + .await; + let (_, io_task4, server_incoming2) = server + .add_test_connection(server_to_client_2_conn, cx.background()) + .await; executor.spawn(io_task1).detach(); executor.spawn(io_task2).detach(); @@ -438,12 +498,14 @@ mod tests { let server = Peer::new(); let client = Peer::new(); - let (client_to_server_conn, server_to_client_conn, _) = + let (client_to_server_conn, server_to_client_conn, _kill) = Connection::in_memory(cx.background()); - let (client_to_server_conn_id, io_task1, mut client_incoming) = - client.add_connection(client_to_server_conn).await; - let (server_to_client_conn_id, io_task2, mut server_incoming) = - server.add_connection(server_to_client_conn).await; + let (client_to_server_conn_id, io_task1, mut client_incoming) = client + .add_test_connection(client_to_server_conn, cx.background()) + .await; + let (server_to_client_conn_id, io_task2, mut server_incoming) = server + .add_test_connection(server_to_client_conn, cx.background()) + .await; executor.spawn(io_task1).detach(); executor.spawn(io_task2).detach(); @@ -536,12 +598,14 @@ mod tests { let server = Peer::new(); let client = Peer::new(); - let (client_to_server_conn, server_to_client_conn, _) = + let (client_to_server_conn, server_to_client_conn, _kill) = Connection::in_memory(cx.background()); - let (client_to_server_conn_id, io_task1, mut client_incoming) = - client.add_connection(client_to_server_conn).await; - let (server_to_client_conn_id, io_task2, mut server_incoming) = - server.add_connection(server_to_client_conn).await; + let (client_to_server_conn_id, io_task1, mut client_incoming) = client + .add_test_connection(client_to_server_conn, cx.background()) + .await; + let (server_to_client_conn_id, io_task2, mut server_incoming) = server + .add_test_connection(server_to_client_conn, cx.background()) + .await; executor.spawn(io_task1).detach(); executor.spawn(io_task2).detach(); @@ -646,10 +710,12 @@ mod tests { async fn test_disconnect(cx: &mut TestAppContext) { let executor = cx.foreground(); - let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background()); + let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background()); let client = Peer::new(); - let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await; + let (connection_id, io_handler, mut incoming) = client + .add_test_connection(client_conn, cx.background()) + .await; let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel(); executor @@ -680,10 +746,12 @@ mod tests { #[gpui::test(iterations = 50)] async fn test_io_error(cx: &mut TestAppContext) { let executor = cx.foreground(); - let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background()); + let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background()); let client = Peer::new(); - let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await; + let (connection_id, io_handler, mut incoming) = client + .add_test_connection(client_conn, cx.background()) + .await; executor.spawn(io_handler).detach(); executor .spawn(async move { incoming.next().await }) diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 0729dbc76a..ffb74f4939 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -2,7 +2,7 @@ use super::{ConnectionId, PeerId, TypedEnvelope}; use anyhow::Result; use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; use futures::{SinkExt as _, StreamExt as _}; -use prost::Message; +use prost::Message as _; use std::any::{Any, TypeId}; use std::{ io, @@ -283,6 +283,13 @@ pub struct MessageStream { encoding_buffer: Vec, } +#[derive(Debug)] +pub enum Message { + Envelope(Envelope), + Ping, + Pong, +} + impl MessageStream { pub fn new(stream: S) -> Self { Self { @@ -300,22 +307,37 @@ impl MessageStream where S: futures::Sink + Unpin, { - /// Write a given protobuf message to the stream. - pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> { + pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> { #[cfg(any(test, feature = "test-support"))] const COMPRESSION_LEVEL: i32 = -7; #[cfg(not(any(test, feature = "test-support")))] const COMPRESSION_LEVEL: i32 = 4; - self.encoding_buffer.resize(message.encoded_len(), 0); - self.encoding_buffer.clear(); - message - .encode(&mut self.encoding_buffer) - .map_err(|err| io::Error::from(err))?; - let buffer = - zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL).unwrap(); - self.stream.send(WebSocketMessage::Binary(buffer)).await?; + match message { + Message::Envelope(message) => { + self.encoding_buffer.resize(message.encoded_len(), 0); + self.encoding_buffer.clear(); + message + .encode(&mut self.encoding_buffer) + .map_err(|err| io::Error::from(err))?; + let buffer = + zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL) + .unwrap(); + self.stream.send(WebSocketMessage::Binary(buffer)).await?; + } + Message::Ping => { + self.stream + .send(WebSocketMessage::Ping(Default::default())) + .await?; + } + Message::Pong => { + self.stream + .send(WebSocketMessage::Pong(Default::default())) + .await?; + } + } + Ok(()) } } @@ -324,8 +346,7 @@ impl MessageStream where S: futures::Stream> + Unpin, { - /// Read a protobuf message of the given type from the stream. - pub async fn read_message(&mut self) -> Result { + pub async fn read(&mut self) -> Result { while let Some(bytes) = self.stream.next().await { match bytes? { WebSocketMessage::Binary(bytes) => { @@ -333,8 +354,10 @@ where zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap(); let envelope = Envelope::decode(self.encoding_buffer.as_slice()) .map_err(io::Error::from)?; - return Ok(envelope); + return Ok(Message::Envelope(envelope)); } + WebSocketMessage::Ping(_) => return Ok(Message::Ping), + WebSocketMessage::Pong(_) => return Ok(Message::Pong), WebSocketMessage::Close(_) => break, _ => {} } diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml index d7984dad04..c39fb2f10b 100644 --- a/crates/server/Cargo.toml +++ b/crates/server/Cargo.toml @@ -16,6 +16,7 @@ required-features = ["seed-support"] collections = { path = "../collections" } rpc = { path = "../rpc" } anyhow = "1.0.40" +async-io = "1.3" async-std = { version = "1.8.0", features = ["attributes"] } async-trait = "0.1.50" async-tungstenite = "0.16" diff --git a/crates/server/src/rpc.rs b/crates/server/src/rpc.rs index 9a8b4ee161..241217fe63 100644 --- a/crates/server/src/rpc.rs +++ b/crates/server/src/rpc.rs @@ -6,6 +6,7 @@ use super::{ AppState, }; use anyhow::anyhow; +use async_io::Timer; use async_std::task; use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use collections::{HashMap, HashSet}; @@ -16,7 +17,12 @@ use rpc::{ Connection, ConnectionId, Peer, TypedEnvelope, }; use sha1::{Digest as _, Sha1}; -use std::{any::TypeId, future::Future, sync::Arc, time::Instant}; +use std::{ + any::TypeId, + future::Future, + sync::Arc, + time::{Duration, Instant}, +}; use store::{Store, Worktree}; use surf::StatusCode; use tide::log; @@ -40,10 +46,13 @@ pub struct Server { notifications: Option>, } -pub trait Executor { +pub trait Executor: Send + Clone { + type Timer: Send + Future; fn spawn_detached>(&self, future: F); + fn timer(&self, duration: Duration) -> Self::Timer; } +#[derive(Clone)] pub struct RealExecutor; const MESSAGE_COUNT_PER_PAGE: usize = 100; @@ -167,8 +176,18 @@ impl Server { ) -> impl Future { let mut this = self.clone(); async move { - let (connection_id, handle_io, mut incoming_rx) = - this.peer.add_connection(connection).await; + let (connection_id, handle_io, mut incoming_rx) = this + .peer + .add_connection(connection, { + let executor = executor.clone(); + move |duration| { + let timer = executor.timer(duration); + async move { + timer.await; + } + } + }) + .await; if let Some(send_connection_id) = send_connection_id.as_mut() { let _ = send_connection_id.send(connection_id).await; @@ -883,9 +902,15 @@ impl Server { } impl Executor for RealExecutor { + type Timer = Timer; + fn spawn_detached>(&self, future: F) { task::spawn(future); } + + fn timer(&self, duration: Duration) -> Self::Timer { + Timer::after(duration) + } } fn broadcast( @@ -1005,7 +1030,7 @@ mod tests { }; use lsp; use parking_lot::Mutex; - use postage::{sink::Sink, watch}; + use postage::{barrier, watch}; use project::{ fs::{FakeFs, Fs as _}, search::SearchQuery, @@ -1759,7 +1784,7 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_peer_disconnection(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { + async fn test_leaving_project(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); let lang_registry = Arc::new(LanguageRegistry::new()); let fs = FakeFs::new(cx_a.background()); @@ -1817,16 +1842,40 @@ mod tests { .await .unwrap(); - // See that a guest has joined as client A. + // Client A sees that a guest has joined. project_a .condition(&cx_a, |p, _| p.collaborators().len() == 1) .await; - // Drop client B's connection and ensure client A observes client B leaving the worktree. + // Drop client B's connection and ensure client A observes client B leaving the project. client_b.disconnect(&cx_b.to_async()).unwrap(); project_a .condition(&cx_a, |p, _| p.collaborators().len() == 0) .await; + + // Rejoin the project as client B + let _project_b = Project::remote( + project_id, + client_b.clone(), + client_b.user_store.clone(), + lang_registry.clone(), + fs.clone(), + &mut cx_b.to_async(), + ) + .await + .unwrap(); + + // Client A sees that a guest has re-joined. + project_a + .condition(&cx_a, |p, _| p.collaborators().len() == 1) + .await; + + // Simulate connection loss for client B and ensure client A observes client B leaving the project. + server.disconnect_client(client_b.current_user_id(cx_b)); + cx_a.foreground().advance_clock(Duration::from_secs(3)); + project_a + .condition(&cx_a, |p, _| p.collaborators().len() == 0) + .await; } #[gpui::test(iterations = 10)] @@ -2683,8 +2732,6 @@ mod tests { .read_with(cx_a, |tree, _| tree.as_local().unwrap().scan_complete()) .await; - eprintln!("sharing"); - project_a.update(cx_a, |p, cx| p.share(cx)).await.unwrap(); // Join the worktree as client B. @@ -3850,6 +3897,7 @@ mod tests { // Disconnect client B, ensuring we can still access its cached channel data. server.forbid_connections(); server.disconnect_client(client_b.current_user_id(&cx_b)); + cx_b.foreground().advance_clock(Duration::from_secs(3)); while !matches!( status_b.next().await, Some(client::Status::ReconnectionError { .. }) @@ -4340,7 +4388,7 @@ mod tests { server: Arc, foreground: Rc, notifications: mpsc::UnboundedReceiver<()>, - connection_killers: Arc>>>>, + connection_killers: Arc>>, forbid_connections: Arc, _test_db: TestDb, } @@ -4444,9 +4492,7 @@ mod tests { } fn disconnect_client(&self, user_id: UserId) { - if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) { - let _ = kill_conn.try_send(Some(())); - } + self.connection_killers.lock().remove(&user_id); } fn forbid_connections(&self) { @@ -5031,9 +5077,15 @@ mod tests { } impl Executor for Arc { + type Timer = gpui::executor::Timer; + fn spawn_detached>(&self, future: F) { self.spawn(future).detach(); } + + fn timer(&self, duration: Duration) -> Self::Timer { + self.as_ref().timer(duration) + } } fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> { diff --git a/crates/text/src/network.rs b/crates/text/src/network.rs new file mode 100644 index 0000000000..2f49756ca3 --- /dev/null +++ b/crates/text/src/network.rs @@ -0,0 +1,69 @@ +use clock::ReplicaId; + +pub struct Network { + inboxes: std::collections::BTreeMap>>, + all_messages: Vec, + rng: R, +} + +#[derive(Clone)] +struct Envelope { + message: T, +} + +impl Network { + pub fn new(rng: R) -> Self { + Network { + inboxes: Default::default(), + all_messages: Vec::new(), + rng, + } + } + + pub fn add_peer(&mut self, id: ReplicaId) { + self.inboxes.insert(id, Vec::new()); + } + + pub fn replicate(&mut self, old_replica_id: ReplicaId, new_replica_id: ReplicaId) { + self.inboxes + .insert(new_replica_id, self.inboxes[&old_replica_id].clone()); + } + + pub fn is_idle(&self) -> bool { + self.inboxes.values().all(|i| i.is_empty()) + } + + pub fn broadcast(&mut self, sender: ReplicaId, messages: Vec) { + for (replica, inbox) in self.inboxes.iter_mut() { + if *replica != sender { + for message in &messages { + // Insert one or more duplicates of this message, potentially *before* the previous + // message sent by this peer to simulate out-of-order delivery. + for _ in 0..self.rng.gen_range(1..4) { + let insertion_index = self.rng.gen_range(0..inbox.len() + 1); + inbox.insert( + insertion_index, + Envelope { + message: message.clone(), + }, + ); + } + } + } + } + self.all_messages.extend(messages); + } + + pub fn has_unreceived(&self, receiver: ReplicaId) -> bool { + !self.inboxes[&receiver].is_empty() + } + + pub fn receive(&mut self, receiver: ReplicaId) -> Vec { + let inbox = self.inboxes.get_mut(&receiver).unwrap(); + let count = self.rng.gen_range(0..inbox.len() + 1); + inbox + .drain(0..count) + .map(|envelope| envelope.message) + .collect() + } +} diff --git a/crates/text/src/tests.rs b/crates/text/src/tests.rs index 4f5e6effb6..05cf0af6ec 100644 --- a/crates/text/src/tests.rs +++ b/crates/text/src/tests.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{network::Network, *}; use clock::ReplicaId; use rand::prelude::*; use std::{ @@ -7,7 +7,6 @@ use std::{ iter::Iterator, time::{Duration, Instant}, }; -use util::test::Network; #[cfg(test)] #[ctor::ctor] diff --git a/crates/text/src/text.rs b/crates/text/src/text.rs index 89943777e0..849d6326f2 100644 --- a/crates/text/src/text.rs +++ b/crates/text/src/text.rs @@ -1,5 +1,7 @@ mod anchor; pub mod locator; +#[cfg(any(test, feature = "test-support"))] +pub mod network; pub mod operation_queue; mod patch; mod point; diff --git a/crates/util/Cargo.toml b/crates/util/Cargo.toml index 634e031aee..9d39fb04e2 100644 --- a/crates/util/Cargo.toml +++ b/crates/util/Cargo.toml @@ -7,10 +7,9 @@ edition = "2021" doctest = false [features] -test-support = ["clock", "rand", "serde_json", "tempdir"] +test-support = ["rand", "serde_json", "tempdir"] [dependencies] -clock = { path = "../clock", optional = true } anyhow = "1.0.38" futures = "0.3" log = "0.4" diff --git a/crates/util/src/lib.rs b/crates/util/src/lib.rs index 919fecf8f9..adefa732e7 100644 --- a/crates/util/src/lib.rs +++ b/crates/util/src/lib.rs @@ -123,6 +123,18 @@ where } } +struct Defer(Option); + +impl Drop for Defer { + fn drop(&mut self) { + self.0.take().map(|f| f()); + } +} + +pub fn defer(f: F) -> impl Drop { + Defer(Some(f)) +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/util/src/test.rs b/crates/util/src/test.rs index 73b5461261..71b847df69 100644 --- a/crates/util/src/test.rs +++ b/crates/util/src/test.rs @@ -1,75 +1,6 @@ -use clock::ReplicaId; use std::path::{Path, PathBuf}; use tempdir::TempDir; -#[derive(Clone)] -struct Envelope { - message: T, -} - -pub struct Network { - inboxes: std::collections::BTreeMap>>, - all_messages: Vec, - rng: R, -} - -impl Network { - pub fn new(rng: R) -> Self { - Network { - inboxes: Default::default(), - all_messages: Vec::new(), - rng, - } - } - - pub fn add_peer(&mut self, id: ReplicaId) { - self.inboxes.insert(id, Vec::new()); - } - - pub fn replicate(&mut self, old_replica_id: ReplicaId, new_replica_id: ReplicaId) { - self.inboxes - .insert(new_replica_id, self.inboxes[&old_replica_id].clone()); - } - - pub fn is_idle(&self) -> bool { - self.inboxes.values().all(|i| i.is_empty()) - } - - pub fn broadcast(&mut self, sender: ReplicaId, messages: Vec) { - for (replica, inbox) in self.inboxes.iter_mut() { - if *replica != sender { - for message in &messages { - // Insert one or more duplicates of this message, potentially *before* the previous - // message sent by this peer to simulate out-of-order delivery. - for _ in 0..self.rng.gen_range(1..4) { - let insertion_index = self.rng.gen_range(0..inbox.len() + 1); - inbox.insert( - insertion_index, - Envelope { - message: message.clone(), - }, - ); - } - } - } - } - self.all_messages.extend(messages); - } - - pub fn has_unreceived(&self, receiver: ReplicaId) -> bool { - !self.inboxes[&receiver].is_empty() - } - - pub fn receive(&mut self, receiver: ReplicaId) -> Vec { - let inbox = self.inboxes.get_mut(&receiver).unwrap(); - let count = self.rng.gen_range(0..inbox.len() + 1); - inbox - .drain(0..count) - .map(|envelope| envelope.message) - .collect() - } -} - pub fn temp_tree(tree: serde_json::Value) -> TempDir { let dir = TempDir::new("").unwrap(); write_tree(dir.path(), tree);