From af6e931da7ab88df95fc6d010f1ad87da6e5ce6b Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 8 Sep 2021 15:58:16 +0200 Subject: [PATCH 01/22] Start on a `Client::status` method that can be observed --- server/src/rpc.rs | 2 +- zed/src/channel.rs | 21 ++++----- zed/src/chat_panel.rs | 30 +++++++----- zed/src/rpc.rs | 105 +++++++++++++++++++++++++++++++----------- zrpc/src/peer.rs | 5 ++ 5 files changed, 111 insertions(+), 52 deletions(-) diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 34f5f378d9..7562bdf74b 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1714,7 +1714,7 @@ mod tests { ) .detach(); client - .add_connection(user_id.to_proto(), client_conn, &cx.to_async()) + .set_connection(user_id.to_proto(), client_conn, &cx.to_async()) .await .unwrap(); (user_id, client) diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 24997d4964..38329c70d4 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -90,18 +90,15 @@ impl ChannelList { let _task = cx.spawn(|this, mut cx| { let rpc = rpc.clone(); async move { - let mut user_id = rpc.user_id(); + let mut status = rpc.status(); loop { - let available_channels = if user_id.recv().await.unwrap().is_some() { - Some( - rpc.request(proto::GetChannels {}) - .await - .context("failed to fetch available channels")? - .channels - .into_iter() - .map(Into::into) - .collect(), - ) + let status = status.recv().await.unwrap(); + let available_channels = if matches!(status, rpc::Status::Connected { .. }) { + let response = rpc + .request(proto::GetChannels {}) + .await + .context("failed to fetch available channels")?; + Some(response.channels.into_iter().map(Into::into).collect()) } else { None }; @@ -671,7 +668,7 @@ mod tests { cx.background().spawn(io).detach(); client - .add_connection(user_id, client_conn, &cx.to_async()) + .set_connection(user_id, client_conn, &cx.to_async()) .await .unwrap(); diff --git a/zed/src/chat_panel.rs b/zed/src/chat_panel.rs index 18b737b2d8..ab38f64169 100644 --- a/zed/src/chat_panel.rs +++ b/zed/src/chat_panel.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use crate::{ channel::{Channel, ChannelEvent, ChannelList, ChannelMessage}, editor::Editor, - rpc::Client, + rpc::{self, Client}, theme, util::{ResultExt, TryFutureExt}, Settings, @@ -14,10 +14,10 @@ use gpui::{ keymap::Binding, platform::CursorStyle, views::{ItemType, Select, SelectStyle}, - AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, View, + AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, Task, View, ViewContext, ViewHandle, }; -use postage::watch; +use postage::{prelude::Stream, watch}; use time::{OffsetDateTime, UtcOffset}; const MESSAGE_LOADING_THRESHOLD: usize = 50; @@ -31,6 +31,7 @@ pub struct ChatPanel { channel_select: ViewHandle, settings: watch::Receiver, local_timezone: UtcOffset, - _status_observer: Task<()>, + _observe_status: Task<()>, } pub enum Event {} @@ -99,7 +99,7 @@ impl ChatPanel { cx.dispatch_action(LoadMoreMessages); } }); - let _status_observer = cx.spawn(|this, mut cx| { + let _observe_status = cx.spawn(|this, mut cx| { let mut status = rpc.status(); async move { while let Some(_) = status.recv().await { @@ -117,7 +117,7 @@ impl ChatPanel { channel_select, settings, local_timezone: cx.platform().local_timezone(), - _status_observer, + _observe_status, }; this.init_active_channel(cx); diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 4daff984e8..0427e3ad92 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -6,7 +6,6 @@ use gpui::{AsyncAppContext, Entity, ModelContext, Task}; use lazy_static::lazy_static; use parking_lot::RwLock; use postage::prelude::Stream; -use postage::sink::Sink; use postage::watch; use std::any::TypeId; use std::collections::HashMap; @@ -94,10 +93,9 @@ impl Client { self.state.read().status.1.clone() } - async fn set_status(&self, status: Status) -> Result<()> { + fn set_status(&self, status: Status) { let mut state = self.state.write(); - state.status.0.send(status).await?; - Ok(()) + *state.status.0.borrow_mut() = status; } pub fn subscribe_from_model( @@ -167,14 +165,14 @@ impl Client { let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?; let user_id = user_id.parse::()?; - self.set_status(Status::Connecting).await?; + self.set_status(Status::Connecting); match self.connect(user_id, &access_token, cx).await { Ok(()) => { log::info!("connected to rpc address {}", *ZED_SERVER_URL); Ok(()) } Err(err) => { - self.set_status(Status::ConnectionError).await?; + self.set_status(Status::ConnectionError); Err(err) } } @@ -259,20 +257,17 @@ impl Client { self.set_status(Status::Connected { connection_id, user_id, - }) - .await?; + }); let handle_io = cx.background().spawn(handle_io); let this = self.clone(); cx.foreground() .spawn(async move { match handle_io.await { - Ok(()) => { - let _ = this.set_status(Status::Disconnected).await; - } + Ok(()) => this.set_status(Status::Disconnected), Err(err) => { log::error!("connection error: {:?}", err); - let _ = this.set_status(Status::ConnectionLost).await; + this.set_status(Status::ConnectionLost); } } }) @@ -365,7 +360,7 @@ impl Client { pub async fn disconnect(&self) -> Result<()> { let conn_id = self.connection_id()?; self.peer.disconnect(conn_id).await; - self.set_status(Status::Disconnected).await?; + self.set_status(Status::Disconnected); Ok(()) } From a3bbf71390fd0b01ab1bdfa88c52c6c6cd7ec186 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 8 Sep 2021 17:10:48 +0200 Subject: [PATCH 03/22] :art: --- zed/src/rpc.rs | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 0427e3ad92..f3df1e7b88 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -1,22 +1,24 @@ use crate::util::ResultExt; use anyhow::{anyhow, Context, Result}; -use async_tungstenite::tungstenite::http::Request; -use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; +use async_tungstenite::tungstenite::{ + http::Request, Error as WebSocketError, Message as WebSocketMessage, +}; use gpui::{AsyncAppContext, Entity, ModelContext, Task}; use lazy_static::lazy_static; use parking_lot::RwLock; -use postage::prelude::Stream; -use postage::watch; -use std::any::TypeId; -use std::collections::HashMap; -use std::sync::Weak; -use std::time::{Duration, Instant}; -use std::{convert::TryFrom, future::Future, sync::Arc}; +use postage::{prelude::Stream, watch}; +use std::{ + any::TypeId, + collections::HashMap, + convert::TryFrom, + future::Future, + sync::{Arc, Weak}, + time::{Duration, Instant}, +}; use surf::Url; -use zrpc::proto::{AnyTypedEnvelope, EntityMessage}; pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope}; use zrpc::{ - proto::{EnvelopedMessage, RequestMessage}, + proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage}, Peer, Receipt, }; From 900010160f1c9ffc909bcd8c8788cf93c495d3e6 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 8 Sep 2021 18:58:59 +0200 Subject: [PATCH 04/22] WIP Co-Authored-By: Max Brunsfeld --- gpui/src/executor.rs | 38 ++++++++++++++- server/src/rpc.rs | 2 +- zed/src/channel.rs | 62 +++---------------------- zed/src/rpc.rs | 107 +++++++++++++++++++++++++++++++++++-------- zed/src/test.rs | 54 +++++++++++++++++++++- 5 files changed, 182 insertions(+), 81 deletions(-) diff --git a/gpui/src/executor.rs b/gpui/src/executor.rs index b135f5034d..7a223a96d2 100644 --- a/gpui/src/executor.rs +++ b/gpui/src/executor.rs @@ -3,8 +3,9 @@ use async_task::Runnable; pub use async_task::Task; use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString}; use parking_lot::Mutex; +use postage::{barrier, prelude::Stream as _}; use rand::prelude::*; -use smol::{channel, prelude::*, Executor}; +use smol::{channel, prelude::*, Executor, Timer}; use std::{ fmt::{self, Debug}, marker::PhantomData, @@ -18,7 +19,7 @@ use std::{ }, task::{Context, Poll}, thread, - time::Duration, + time::{Duration, Instant}, }; use waker_fn::waker_fn; @@ -49,6 +50,8 @@ struct DeterministicState { spawned_from_foreground: Vec<(Runnable, Backtrace)>, forbid_parking: bool, block_on_ticks: RangeInclusive, + now: Instant, + pending_sleeps: Vec<(Instant, barrier::Sender)>, } pub struct Deterministic { @@ -67,6 +70,8 @@ impl Deterministic { spawned_from_foreground: Default::default(), forbid_parking: false, block_on_ticks: 0..=1000, + now: Instant::now(), + pending_sleeps: Default::default(), })), parker: Default::default(), } @@ -407,6 +412,35 @@ impl Foreground { } } + pub async fn sleep(&self, duration: Duration) { + match self { + Self::Deterministic(executor) => { + let (tx, mut rx) = barrier::channel(); + { + let mut state = executor.state.lock(); + let wakeup_at = state.now + duration; + state.pending_sleeps.push((wakeup_at, tx)); + } + rx.recv().await; + } + _ => { + Timer::after(duration).await; + } + } + } + + pub fn advance_clock(&self, duration: Duration) { + match self { + Self::Deterministic(executor) => { + let mut state = executor.state.lock(); + state.now += duration; + let now = state.now; + state.pending_sleeps.retain(|(wakeup, _)| *wakeup > now); + } + _ => panic!("this method can only be called on a deterministic executor"), + } + } + pub fn set_block_on_ticks(&self, range: RangeInclusive) { match self { Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range, diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 7562bdf74b..e1b1bce058 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1469,7 +1469,7 @@ mod tests { .await; // Drop client B's connection and ensure client A observes client B leaving the worktree. - client_b.disconnect().await.unwrap(); + client_b.disconnect(&cx_b.to_async()).await.unwrap(); worktree_a .condition(&cx_a, |tree, _| tree.peers().len() == 0) .await; diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 38329c70d4..234e3e1e5f 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -443,9 +443,8 @@ impl<'a> sum_tree::SeekDimension<'a, ChannelMessageSummary> for Count { #[cfg(test)] mod tests { use super::*; + use crate::test::FakeServer; use gpui::TestAppContext; - use postage::mpsc::Receiver; - use zrpc::{test::Channel, ConnectionId, Peer, Receipt}; #[gpui::test] async fn test_channel_messages(mut cx: TestAppContext) { @@ -458,7 +457,7 @@ mod tests { channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None)); // Get the available channels. - let get_channels = server.receive::().await; + let get_channels = server.receive::().await.unwrap(); server .respond( get_channels.receipt(), @@ -489,7 +488,7 @@ mod tests { }) .unwrap(); channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty())); - let join_channel = server.receive::().await; + let join_channel = server.receive::().await.unwrap(); server .respond( join_channel.receipt(), @@ -514,7 +513,7 @@ mod tests { .await; // Client requests all users for the received messages - let mut get_users = server.receive::().await; + let mut get_users = server.receive::().await.unwrap(); get_users.payload.user_ids.sort(); assert_eq!(get_users.payload.user_ids, vec![5, 6]); server @@ -571,7 +570,7 @@ mod tests { .await; // Client requests user for message since they haven't seen them yet - let get_users = server.receive::().await; + let get_users = server.receive::().await.unwrap(); assert_eq!(get_users.payload.user_ids, vec![7]); server .respond( @@ -607,7 +606,7 @@ mod tests { channel.update(&mut cx, |channel, cx| { assert!(channel.load_more_messages(cx)); }); - let get_messages = server.receive::().await; + let get_messages = server.receive::().await.unwrap(); assert_eq!(get_messages.payload.channel_id, 5); assert_eq!(get_messages.payload.before_message_id, 10); server @@ -653,53 +652,4 @@ mod tests { ); }); } - - struct FakeServer { - peer: Arc, - incoming: Receiver>, - connection_id: ConnectionId, - } - - impl FakeServer { - async fn for_client(user_id: u64, client: &Arc, cx: &TestAppContext) -> Self { - let (client_conn, server_conn) = Channel::bidirectional(); - let peer = Peer::new(); - let (connection_id, io, incoming) = peer.add_connection(server_conn).await; - cx.background().spawn(io).detach(); - - client - .set_connection(user_id, client_conn, &cx.to_async()) - .await - .unwrap(); - - Self { - peer, - incoming, - connection_id, - } - } - - async fn send(&self, message: T) { - self.peer.send(self.connection_id, message).await.unwrap(); - } - - async fn receive(&mut self) -> TypedEnvelope { - *self - .incoming - .recv() - .await - .unwrap() - .into_any() - .downcast::>() - .unwrap() - } - - async fn respond( - &self, - receipt: Receipt, - response: T::Response, - ) { - self.peer.respond(receipt, response).await.unwrap() - } - } } diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index f3df1e7b88..b36ec4d376 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -3,10 +3,12 @@ use anyhow::{anyhow, Context, Result}; use async_tungstenite::tungstenite::{ http::Request, Error as WebSocketError, Message as WebSocketMessage, }; +use futures::StreamExt as _; use gpui::{AsyncAppContext, Entity, ModelContext, Task}; use lazy_static::lazy_static; use parking_lot::RwLock; use postage::{prelude::Stream, watch}; +use smol::Timer; use std::{ any::TypeId, collections::HashMap, @@ -42,6 +44,10 @@ pub enum Status { user_id: u64, }, ConnectionLost, + Reconnecting, + ReconnectionError { + next_reconnection: Instant, + }, } struct ClientState { @@ -51,6 +57,8 @@ struct ClientState { (TypeId, u64), Box, &mut AsyncAppContext)>, >, + _maintain_connection: Option>, + heartbeat_interval: Duration, } impl Default for ClientState { @@ -59,6 +67,8 @@ impl Default for ClientState { status: watch::channel_with(Status::Disconnected), entity_id_extractors: Default::default(), model_handlers: Default::default(), + _maintain_connection: None, + heartbeat_interval: Duration::from_secs(5), } } } @@ -95,9 +105,35 @@ impl Client { self.state.read().status.1.clone() } - fn set_status(&self, status: Status) { + fn set_status(self: &Arc, status: Status, cx: &AsyncAppContext) { let mut state = self.state.write(); *state.status.0.borrow_mut() = status; + match status { + Status::Connected { .. } => { + let heartbeat_interval = state.heartbeat_interval; + let this = self.clone(); + let foreground = cx.foreground(); + state._maintain_connection = Some(cx.foreground().spawn(async move { + let mut next_ping_id = 0; + loop { + foreground.sleep(heartbeat_interval).await; + this.request(proto::Ping { id: next_ping_id }) + .await + .unwrap(); + next_ping_id += 1; + } + })); + } + Status::ConnectionLost => { + state._maintain_connection = Some(cx.foreground().spawn(async move { + // TODO: try to reconnect + })); + } + Status::Disconnected => { + state._maintain_connection.take(); + } + _ => {} + } } pub fn subscribe_from_model( @@ -167,14 +203,14 @@ impl Client { let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?; let user_id = user_id.parse::()?; - self.set_status(Status::Connecting); + self.set_status(Status::Connecting, cx); match self.connect(user_id, &access_token, cx).await { Ok(()) => { log::info!("connected to rpc address {}", *ZED_SERVER_URL); Ok(()) } Err(err) => { - self.set_status(Status::ConnectionError); + self.set_status(Status::ConnectionError, cx); Err(err) } } @@ -256,20 +292,24 @@ impl Client { .detach(); } - self.set_status(Status::Connected { - connection_id, - user_id, - }); + self.set_status( + Status::Connected { + connection_id, + user_id, + }, + cx, + ); let handle_io = cx.background().spawn(handle_io); let this = self.clone(); + let cx = cx.clone(); cx.foreground() .spawn(async move { match handle_io.await { - Ok(()) => this.set_status(Status::Disconnected), + Ok(()) => this.set_status(Status::Disconnected, &cx), Err(err) => { log::error!("connection error: {:?}", err); - this.set_status(Status::ConnectionLost); + this.set_status(Status::ConnectionLost, &cx); } } }) @@ -359,10 +399,10 @@ impl Client { }) } - pub async fn disconnect(&self) -> Result<()> { + pub async fn disconnect(self: &Arc, cx: &AsyncAppContext) -> Result<()> { let conn_id = self.connection_id()?; self.peer.disconnect(conn_id).await; - self.set_status(Status::Disconnected); + self.set_status(Status::Disconnected, cx); Ok(()) } @@ -444,13 +484,40 @@ const LOGIN_RESPONSE: &'static str = " "; -#[test] -fn test_encode_and_decode_worktree_url() { - let url = encode_worktree_url(5, "deadbeef"); - assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string()))); - assert_eq!( - decode_worktree_url(&format!("\n {}\t", url)), - Some((5, "deadbeef".to_string())) - ); - assert_eq!(decode_worktree_url("not://the-right-format"), None); +#[cfg(test)] +mod tests { + use super::*; + use crate::test::FakeServer; + use gpui::TestAppContext; + + #[gpui::test(iterations = 1000)] + async fn test_heartbeat(cx: TestAppContext) { + let user_id = 5; + let client = Client::new(); + + client.state.write().heartbeat_interval = Duration::from_millis(1); + let mut server = FakeServer::for_client(user_id, &client, &cx).await; + + let ping = server.receive::().await.unwrap(); + assert_eq!(ping.payload.id, 0); + server.respond(ping.receipt(), proto::Pong { id: 0 }).await; + + let ping = server.receive::().await.unwrap(); + assert_eq!(ping.payload.id, 1); + server.respond(ping.receipt(), proto::Pong { id: 1 }).await; + + client.disconnect(&cx.to_async()).await.unwrap(); + assert!(server.receive::().await.is_err()); + } + + #[test] + fn test_encode_and_decode_worktree_url() { + let url = encode_worktree_url(5, "deadbeef"); + assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string()))); + assert_eq!( + decode_worktree_url(&format!("\n {}\t", url)), + Some((5, "deadbeef".to_string())) + ); + assert_eq!(decode_worktree_url("not://the-right-format"), None); + } } diff --git a/zed/src/test.rs b/zed/src/test.rs index b917e428f6..f34ff55014 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -3,14 +3,16 @@ use crate::{ channel::ChannelList, fs::RealFs, language::LanguageRegistry, - rpc, + rpc::{self, Client}, settings::{self, ThemeRegistry}, time::ReplicaId, user::UserStore, AppState, }; -use gpui::{Entity, ModelHandle, MutableAppContext}; +use anyhow::{anyhow, Result}; +use gpui::{Entity, ModelHandle, MutableAppContext, TestAppContext}; use parking_lot::Mutex; +use postage::{mpsc, prelude::Stream as _}; use smol::channel; use std::{ marker::PhantomData, @@ -18,6 +20,7 @@ use std::{ sync::Arc, }; use tempdir::TempDir; +use zrpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope}; #[cfg(feature = "test-support")] pub use zrpc::test::Channel; @@ -195,3 +198,50 @@ impl Observer { (observer, notify_rx) } } + +pub struct FakeServer { + peer: Arc, + incoming: mpsc::Receiver>, + connection_id: ConnectionId, +} + +impl FakeServer { + pub async fn for_client(user_id: u64, client: &Arc, cx: &TestAppContext) -> Self { + let (client_conn, server_conn) = zrpc::test::Channel::bidirectional(); + let peer = Peer::new(); + let (connection_id, io, incoming) = peer.add_connection(server_conn).await; + cx.background().spawn(io).detach(); + + client + .set_connection(user_id, client_conn, &cx.to_async()) + .await + .unwrap(); + + Self { + peer, + incoming, + connection_id, + } + } + + pub async fn send(&self, message: T) { + self.peer.send(self.connection_id, message).await.unwrap(); + } + + pub async fn receive(&mut self) -> Result> { + let message = self + .incoming + .recv() + .await + .ok_or_else(|| anyhow!("other half hung up"))?; + Ok(*message.into_any().downcast::>().unwrap()) + } + + pub async fn respond( + &self, + receipt: Receipt, + response: T::Response, + ) { + self.peer.respond(receipt, response).await.unwrap() + } +} From c3e29e0a2dbad232834d575d080923f4bd9f2763 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 8 Sep 2021 11:24:27 -0700 Subject: [PATCH 05/22] Finish implementing DeterministicExecutor::advance_clock * Start by running all non-timer futures to completion, to ensure that timers have a chance to be registered. * Release executor's state lock before waking any timers --- gpui/src/executor.rs | 64 +++++++++++++++++++++++++++++--------------- zed/src/rpc.rs | 8 +++--- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/gpui/src/executor.rs b/gpui/src/executor.rs index 7a223a96d2..5849a86902 100644 --- a/gpui/src/executor.rs +++ b/gpui/src/executor.rs @@ -124,17 +124,39 @@ impl Deterministic { T: 'static, F: Future + 'static, { - smol::pin!(future); - - let unparker = self.parker.lock().unparker(); let woken = Arc::new(AtomicBool::new(false)); - let waker = { - let woken = woken.clone(); - waker_fn(move || { - woken.store(true, SeqCst); - unparker.unpark(); - }) - }; + let mut future = Box::pin(future); + loop { + if let Some(result) = self.run_internal(woken.clone(), &mut future) { + return result; + } + + if !woken.load(SeqCst) && self.state.lock().forbid_parking { + panic!("deterministic executor parked after a call to forbid_parking"); + } + + woken.store(false, SeqCst); + self.parker.lock().park(); + } + } + + fn run_until_parked(&self) { + let woken = Arc::new(AtomicBool::new(false)); + let future = std::future::pending::<()>(); + smol::pin!(future); + self.run_internal(woken, future); + } + + pub fn run_internal(&self, woken: Arc, mut future: F) -> Option + where + T: 'static, + F: Future + Unpin, + { + let unparker = self.parker.lock().unparker(); + let waker = waker_fn(move || { + woken.store(true, SeqCst); + unparker.unpark(); + }); let mut cx = Context::from_waker(&waker); let mut trace = Trace::default(); @@ -168,23 +190,17 @@ impl Deterministic { runnable.run(); } else { drop(state); - if let Poll::Ready(result) = future.as_mut().poll(&mut cx) { - return result; + if let Poll::Ready(result) = future.poll(&mut cx) { + return Some(result); } + let state = self.state.lock(); if state.scheduled_from_foreground.is_empty() && state.scheduled_from_background.is_empty() && state.spawned_from_foreground.is_empty() { - if state.forbid_parking && !woken.load(SeqCst) { - panic!("deterministic executor parked after a call to forbid_parking"); - } - drop(state); - woken.store(false, SeqCst); - self.parker.lock().park(); + return None; } - - continue; } } } @@ -432,10 +448,16 @@ impl Foreground { pub fn advance_clock(&self, duration: Duration) { match self { Self::Deterministic(executor) => { + executor.run_until_parked(); + let mut state = executor.state.lock(); state.now += duration; let now = state.now; - state.pending_sleeps.retain(|(wakeup, _)| *wakeup > now); + let mut pending_sleeps = mem::take(&mut state.pending_sleeps); + drop(state); + + pending_sleeps.retain(|(wakeup, _)| *wakeup > now); + executor.state.lock().pending_sleeps.extend(pending_sleeps); } _ => panic!("this method can only be called on a deterministic executor"), } diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index b36ec4d376..d0c04f5872 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -3,12 +3,10 @@ use anyhow::{anyhow, Context, Result}; use async_tungstenite::tungstenite::{ http::Request, Error as WebSocketError, Message as WebSocketMessage, }; -use futures::StreamExt as _; use gpui::{AsyncAppContext, Entity, ModelContext, Task}; use lazy_static::lazy_static; use parking_lot::RwLock; use postage::{prelude::Stream, watch}; -use smol::Timer; use std::{ any::TypeId, collections::HashMap, @@ -490,18 +488,18 @@ mod tests { use crate::test::FakeServer; use gpui::TestAppContext; - #[gpui::test(iterations = 1000)] + #[gpui::test(iterations = 10)] async fn test_heartbeat(cx: TestAppContext) { let user_id = 5; let client = Client::new(); - - client.state.write().heartbeat_interval = Duration::from_millis(1); let mut server = FakeServer::for_client(user_id, &client, &cx).await; + cx.foreground().advance_clock(Duration::from_secs(10)); let ping = server.receive::().await.unwrap(); assert_eq!(ping.payload.id, 0); server.respond(ping.receipt(), proto::Pong { id: 0 }).await; + cx.foreground().advance_clock(Duration::from_secs(10)); let ping = server.receive::().await.unwrap(); assert_eq!(ping.payload.id, 1); server.respond(ping.receipt(), proto::Pong { id: 1 }).await; From b6eac57f6385296248ca811d6fa56804b0c98c97 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 8 Sep 2021 17:49:07 -0700 Subject: [PATCH 06/22] Introduce a concrete `Conn` type for peer's websocket connection This is mostly to set us up to test the rpc::Client's reconnect logic. There are multiple ways that the `rpc::Client` may establish its websocket connection: (SSL in production, plain TCP during local development, and using an in-memory connection for tests). Now we can represent all of those connections using a common type. Also, several long methods no longer need to be generic, which is good for compile time. --- server/src/rpc.rs | 25 +++------ zed/src/channel.rs | 3 +- zed/src/rpc.rs | 137 ++++++++++++++++++++++----------------------- zed/src/test.rs | 67 +++++++++++++++------- zrpc/src/conn.rs | 54 ++++++++++++++++++ zrpc/src/lib.rs | 5 +- zrpc/src/peer.rs | 42 ++++++-------- zrpc/src/proto.rs | 27 --------- zrpc/src/test.rs | 64 --------------------- 9 files changed, 196 insertions(+), 228 deletions(-) create mode 100644 zrpc/src/conn.rs delete mode 100644 zrpc/src/test.rs diff --git a/server/src/rpc.rs b/server/src/rpc.rs index e1b1bce058..c234967444 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -5,10 +5,7 @@ use super::{ }; use anyhow::anyhow; use async_std::{sync::RwLock, task}; -use async_tungstenite::{ - tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage}, - WebSocketStream, -}; +use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use futures::{future::BoxFuture, FutureExt}; use postage::{mpsc, prelude::Sink as _, prelude::Stream as _}; use sha1::{Digest as _, Sha1}; @@ -30,7 +27,7 @@ use time::OffsetDateTime; use zrpc::{ auth::random_token, proto::{self, AnyTypedEnvelope, EnvelopedMessage}, - ConnectionId, Peer, TypedEnvelope, + Conn, ConnectionId, Peer, TypedEnvelope, }; type ReplicaId = u16; @@ -133,19 +130,12 @@ impl Server { self } - pub fn handle_connection( + pub fn handle_connection( self: &Arc, connection: Conn, addr: String, user_id: UserId, - ) -> impl Future - where - Conn: 'static - + futures::Sink - + futures::Stream> - + Send - + Unpin, - { + ) -> impl Future { let this = self.clone(); async move { let (connection_id, handle_io, mut incoming_rx) = @@ -974,8 +964,7 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?; task::spawn(async move { if let Some(stream) = upgrade_receiver.await { - let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await; - server.handle_connection(stream, addr, user_id).await; + server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await; } }); @@ -1019,7 +1008,7 @@ mod tests { fs::{FakeFs, Fs as _}, language::LanguageRegistry, rpc::Client, - settings, test, + settings, user::UserStore, worktree::Worktree, }; @@ -1706,7 +1695,7 @@ mod tests { ) -> (UserId, Arc) { let user_id = self.app_state.db.create_user(name, false).await.unwrap(); let client = Client::new(); - let (client_conn, server_conn) = test::Channel::bidirectional(); + let (client_conn, server_conn) = Conn::in_memory(); cx.background() .spawn( self.server diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 234e3e1e5f..68f5299a8b 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -445,12 +445,13 @@ mod tests { use super::*; use crate::test::FakeServer; use gpui::TestAppContext; + use std::time::Duration; #[gpui::test] async fn test_channel_messages(mut cx: TestAppContext) { let user_id = 5; let client = Client::new(); - let mut server = FakeServer::for_client(user_id, &client, &cx).await; + let server = FakeServer::for_client(user_id, &client, &cx).await; let user_store = Arc::new(UserStore::new(client.clone())); let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx)); diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index d0c04f5872..64fc8a56ea 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -1,8 +1,6 @@ use crate::util::ResultExt; use anyhow::{anyhow, Context, Result}; -use async_tungstenite::tungstenite::{ - http::Request, Error as WebSocketError, Message as WebSocketMessage, -}; +use async_tungstenite::tungstenite::http::Request; use gpui::{AsyncAppContext, Entity, ModelContext, Task}; use lazy_static::lazy_static; use parking_lot::RwLock; @@ -19,7 +17,7 @@ use surf::Url; pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope}; use zrpc::{ proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage}, - Peer, Receipt, + Conn, Peer, Receipt, }; lazy_static! { @@ -106,6 +104,7 @@ impl Client { fn set_status(self: &Arc, status: Status, cx: &AsyncAppContext) { let mut state = self.state.write(); *state.status.0.borrow_mut() = status; + match status { Status::Connected { .. } => { let heartbeat_interval = state.heartbeat_interval; @@ -193,75 +192,46 @@ impl Client { ) -> anyhow::Result<()> { if matches!( *self.status().borrow(), - Status::Connecting | Status::Connected { .. } + Status::Connecting { .. } | Status::Connected { .. } ) { return Ok(()); } - let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?; - let user_id = user_id.parse::()?; - - self.set_status(Status::Connecting, cx); - match self.connect(user_id, &access_token, cx).await { - Ok(()) => { - log::info!("connected to rpc address {}", *ZED_SERVER_URL); - Ok(()) - } + let (user_id, access_token) = match self.authenticate(&cx).await { + Ok(result) => result, Err(err) => { self.set_status(Status::ConnectionError, cx); - Err(err) + return Err(err); } - } + }; + + self.set_status(Status::Connecting, cx); + + let conn = match self.connect(user_id, &access_token, cx).await { + Ok(conn) => conn, + Err(err) => { + self.set_status(Status::ConnectionError, cx); + return Err(err); + } + }; + + self.set_connection(user_id, conn, cx).await?; + log::info!("connected to rpc address {}", *ZED_SERVER_URL); + Ok(()) } - async fn connect( - self: &Arc, - user_id: u64, - access_token: &str, - cx: &AsyncAppContext, - ) -> Result<()> { - let request = - Request::builder().header("Authorization", format!("{} {}", user_id, access_token)); - if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") { - let stream = smol::net::TcpStream::connect(host).await?; - let request = request.uri(format!("wss://{}/rpc", host)).body(())?; - let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream) - .await - .context("websocket handshake")?; - self.set_connection(user_id, stream, cx).await?; - Ok(()) - } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") { - let stream = smol::net::TcpStream::connect(host).await?; - let request = request.uri(format!("ws://{}/rpc", host)).body(())?; - let (stream, _) = async_tungstenite::client_async(request, stream) - .await - .context("websocket handshake")?; - self.set_connection(user_id, stream, cx).await?; - Ok(()) - } else { - return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL)); - } - } - - pub async fn set_connection( + pub async fn set_connection( self: &Arc, user_id: u64, conn: Conn, cx: &AsyncAppContext, - ) -> Result<()> - where - Conn: 'static - + futures::Sink - + futures::Stream> - + Unpin - + Send, - { + ) -> Result<()> { let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await; - { - let mut cx = cx.clone(); - let this = self.clone(); - cx.foreground() - .spawn(async move { + cx.foreground() + .spawn({ + let mut cx = cx.clone(); + let this = self.clone(); + async move { while let Some(message) = incoming.recv().await { let mut state = this.state.write(); if let Some(extract_entity_id) = @@ -286,9 +256,9 @@ impl Client { log::info!("unhandled message {}", message.payload_type_name()); } } - }) - .detach(); - } + } + }) + .detach(); self.set_status( Status::Connected { @@ -315,11 +285,38 @@ impl Client { Ok(()) } - pub fn login( - platform: Arc, - executor: &Arc, - ) -> Task> { - let executor = executor.clone(); + fn connect( + self: &Arc, + user_id: u64, + access_token: &str, + cx: &AsyncAppContext, + ) -> Task> { + let request = + Request::builder().header("Authorization", format!("{} {}", user_id, access_token)); + cx.background().spawn(async move { + if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") { + let stream = smol::net::TcpStream::connect(host).await?; + let request = request.uri(format!("wss://{}/rpc", host)).body(())?; + let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream) + .await + .context("websocket handshake")?; + Ok(Conn::new(stream)) + } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") { + let stream = smol::net::TcpStream::connect(host).await?; + let request = request.uri(format!("ws://{}/rpc", host)).body(())?; + let (stream, _) = async_tungstenite::client_async(request, stream) + .await + .context("websocket handshake")?; + Ok(Conn::new(stream)) + } else { + Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL)) + } + }) + } + + pub fn authenticate(self: &Arc, cx: &AsyncAppContext) -> Task> { + let platform = cx.platform(); + let executor = cx.background(); executor.clone().spawn(async move { if let Some((user_id, access_token)) = platform .read_credentials(&ZED_SERVER_URL) @@ -327,7 +324,7 @@ impl Client { .flatten() { log::info!("already signed in. user_id: {}", user_id); - return Ok((user_id, String::from_utf8(access_token).unwrap())); + return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap())); } // Generate a pair of asymmetric encryption keys. The public key will be used by the @@ -393,7 +390,7 @@ impl Client { platform .write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes()) .log_err(); - Ok((user_id.to_string(), access_token)) + Ok((user_id.parse()?, access_token)) }) } @@ -492,7 +489,7 @@ mod tests { async fn test_heartbeat(cx: TestAppContext) { let user_id = 5; let client = Client::new(); - let mut server = FakeServer::for_client(user_id, &client, &cx).await; + let server = FakeServer::for_client(user_id, &client, &cx).await; cx.foreground().advance_clock(Duration::from_secs(10)); let ping = server.receive::().await.unwrap(); diff --git a/zed/src/test.rs b/zed/src/test.rs index f34ff55014..e5169ecb69 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -10,7 +10,7 @@ use crate::{ AppState, }; use anyhow::{anyhow, Result}; -use gpui::{Entity, ModelHandle, MutableAppContext, TestAppContext}; +use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext}; use parking_lot::Mutex; use postage::{mpsc, prelude::Stream as _}; use smol::channel; @@ -20,10 +20,7 @@ use std::{ sync::Arc, }; use tempdir::TempDir; -use zrpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope}; - -#[cfg(feature = "test-support")] -pub use zrpc::test::Channel; +use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope}; #[cfg(test)] #[ctor::ctor] @@ -201,40 +198,64 @@ impl Observer { pub struct FakeServer { peer: Arc, - incoming: mpsc::Receiver>, - connection_id: ConnectionId, + incoming: Mutex>>>, + connection_id: Mutex>, } impl FakeServer { - pub async fn for_client(user_id: u64, client: &Arc, cx: &TestAppContext) -> Self { - let (client_conn, server_conn) = zrpc::test::Channel::bidirectional(); - let peer = Peer::new(); - let (connection_id, io, incoming) = peer.add_connection(server_conn).await; - cx.background().spawn(io).detach(); + pub async fn for_client(user_id: u64, client: &Arc, cx: &TestAppContext) -> Arc { + let result = Arc::new(Self { + peer: Peer::new(), + incoming: Default::default(), + connection_id: Default::default(), + }); + let conn = result.connect(&cx.to_async()).await; client - .set_connection(user_id, client_conn, &cx.to_async()) + .set_connection(user_id, conn, &cx.to_async()) .await .unwrap(); + result + } - Self { - peer, - incoming, - connection_id, - } + pub async fn disconnect(&self) { + self.peer.disconnect(self.connection_id()).await; + self.connection_id.lock().take(); + self.incoming.lock().take(); + } + + async fn connect(&self, cx: &AsyncAppContext) -> Conn { + let (client_conn, server_conn) = Conn::in_memory(); + let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await; + cx.background().spawn(io).detach(); + *self.incoming.lock() = Some(incoming); + *self.connection_id.lock() = Some(connection_id); + client_conn } pub async fn send(&self, message: T) { - self.peer.send(self.connection_id, message).await.unwrap(); + self.peer.send(self.connection_id(), message).await.unwrap(); } - pub async fn receive(&mut self) -> Result> { + pub async fn receive(&self) -> Result> { let message = self .incoming + .lock() + .as_mut() + .expect("not connected") .recv() .await .ok_or_else(|| anyhow!("other half hung up"))?; - Ok(*message.into_any().downcast::>().unwrap()) + let type_name = message.payload_type_name(); + Ok(*message + .into_any() + .downcast::>() + .unwrap_or_else(|_| { + panic!( + "fake server received unexpected message type: {:?}", + type_name + ); + })) } pub async fn respond( @@ -244,4 +265,8 @@ impl FakeServer { ) { self.peer.respond(receipt, response).await.unwrap() } + + fn connection_id(&self) -> ConnectionId { + self.connection_id.lock().expect("not connected") + } } diff --git a/zrpc/src/conn.rs b/zrpc/src/conn.rs new file mode 100644 index 0000000000..06dbcee077 --- /dev/null +++ b/zrpc/src/conn.rs @@ -0,0 +1,54 @@ +use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; +use futures::{SinkExt as _, StreamExt as _}; + +pub struct Conn { + pub(crate) tx: + Box>, + pub(crate) rx: Box< + dyn 'static + + Send + + Unpin + + futures::Stream>, + >, +} + +impl Conn { + pub fn new(stream: S) -> Self + where + S: 'static + + Send + + Unpin + + futures::Sink + + futures::Stream>, + { + let (tx, rx) = stream.split(); + Self { + tx: Box::new(tx), + rx: Box::new(rx), + } + } + + pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> { + self.tx.send(message).await + } + + #[cfg(any(test, feature = "test-support"))] + pub fn in_memory() -> (Self, Self) { + use futures::SinkExt as _; + use futures::StreamExt as _; + use std::io::{Error, ErrorKind}; + + let (a_tx, a_rx) = futures::channel::mpsc::unbounded::(); + let (b_tx, b_rx) = futures::channel::mpsc::unbounded::(); + ( + Self { + tx: Box::new(a_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())), + rx: Box::new(b_rx.map(Ok)), + }, + Self { + tx: Box::new(b_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())), + rx: Box::new(a_rx.map(Ok)), + }, + ) + } +} diff --git a/zrpc/src/lib.rs b/zrpc/src/lib.rs index 8cafad9f1f..b3973cae19 100644 --- a/zrpc/src/lib.rs +++ b/zrpc/src/lib.rs @@ -1,7 +1,6 @@ pub mod auth; +mod conn; mod peer; pub mod proto; -#[cfg(any(test, feature = "test-support"))] -pub mod test; - +pub use conn::Conn; pub use peer::*; diff --git a/zrpc/src/peer.rs b/zrpc/src/peer.rs index 5b6ae8655a..d50ee50ec3 100644 --- a/zrpc/src/peer.rs +++ b/zrpc/src/peer.rs @@ -1,8 +1,8 @@ -use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage}; +use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage}; +use super::Conn; use anyhow::{anyhow, Context, Result}; use async_lock::{Mutex, RwLock}; -use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; -use futures::{FutureExt, StreamExt}; +use futures::FutureExt as _; use postage::{ mpsc, prelude::{Sink as _, Stream as _}, @@ -98,21 +98,14 @@ impl Peer { }) } - pub async fn add_connection( + pub async fn add_connection( self: &Arc, conn: Conn, ) -> ( ConnectionId, impl Future> + Send, mpsc::Receiver>, - ) - where - Conn: futures::Sink - + futures::Stream> - + Send - + Unpin, - { - let (tx, rx) = conn.split(); + ) { let connection_id = ConnectionId( self.next_connection_id .fetch_add(1, atomic::Ordering::SeqCst), @@ -124,8 +117,8 @@ impl Peer { next_message_id: Default::default(), response_channels: Default::default(), }; - let mut writer = MessageStream::new(tx); - let mut reader = MessageStream::new(rx); + let mut writer = MessageStream::new(conn.tx); + let mut reader = MessageStream::new(conn.rx); let this = self.clone(); let response_channels = connection.response_channels.clone(); @@ -347,7 +340,9 @@ impl Peer { #[cfg(test)] mod tests { use super::*; - use crate::{test, TypedEnvelope}; + use crate::TypedEnvelope; + use async_tungstenite::tungstenite::Message as WebSocketMessage; + use futures::StreamExt as _; #[test] fn test_request_response() { @@ -357,12 +352,12 @@ mod tests { let client1 = Peer::new(); let client2 = Peer::new(); - let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional(); + let (client1_to_server_conn, server_to_client_1_conn) = Conn::in_memory(); let (client1_conn_id, io_task1, _) = client1.add_connection(client1_to_server_conn).await; let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await; - let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional(); + let (client2_to_server_conn, server_to_client_2_conn) = Conn::in_memory(); let (client2_conn_id, io_task3, _) = client2.add_connection(client2_to_server_conn).await; let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await; @@ -497,7 +492,7 @@ mod tests { #[test] fn test_disconnect() { smol::block_on(async move { - let (client_conn, mut server_conn) = test::Channel::bidirectional(); + let (client_conn, mut server_conn) = Conn::in_memory(); let client = Peer::new(); let (connection_id, io_handler, mut incoming) = @@ -521,18 +516,17 @@ mod tests { io_ended_rx.recv().await; messages_ended_rx.recv().await; - assert!( - futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![])) - .await - .is_err() - ); + assert!(server_conn + .send(WebSocketMessage::Binary(vec![])) + .await + .is_err()); }); } #[test] fn test_io_error() { smol::block_on(async move { - let (client_conn, server_conn) = test::Channel::bidirectional(); + let (client_conn, server_conn) = Conn::in_memory(); drop(server_conn); let client = Peer::new(); diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index 002c5bc840..ded7fdc1cd 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -247,30 +247,3 @@ impl From for Timestamp { } } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::test; - - #[test] - fn test_round_trip_message() { - smol::block_on(async { - let stream = test::Channel::new(); - let message1 = Ping { id: 5 }.into_envelope(3, None, None); - let message2 = OpenBuffer { - worktree_id: 0, - path: "some/path".to_string(), - } - .into_envelope(5, None, None); - - let mut message_stream = MessageStream::new(stream); - message_stream.write_message(&message1).await.unwrap(); - message_stream.write_message(&message2).await.unwrap(); - let decoded_message1 = message_stream.read_message().await.unwrap(); - let decoded_message2 = message_stream.read_message().await.unwrap(); - assert_eq!(decoded_message1, message1); - assert_eq!(decoded_message2, message2); - }); - } -} diff --git a/zrpc/src/test.rs b/zrpc/src/test.rs deleted file mode 100644 index ad698a4094..0000000000 --- a/zrpc/src/test.rs +++ /dev/null @@ -1,64 +0,0 @@ -use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; -use std::{ - io, - pin::Pin, - task::{Context, Poll}, -}; - -pub struct Channel { - tx: futures::channel::mpsc::UnboundedSender, - rx: futures::channel::mpsc::UnboundedReceiver, -} - -impl Channel { - pub fn new() -> Self { - let (tx, rx) = futures::channel::mpsc::unbounded(); - Self { tx, rx } - } - - pub fn bidirectional() -> (Self, Self) { - let (a_tx, a_rx) = futures::channel::mpsc::unbounded(); - let (b_tx, b_rx) = futures::channel::mpsc::unbounded(); - let a = Self { tx: a_tx, rx: b_rx }; - let b = Self { tx: b_tx, rx: a_rx }; - (a, b) - } -} - -impl futures::Sink for Channel { - type Error = WebSocketError; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.tx) - .poll_ready(cx) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into()) - } - - fn start_send(mut self: Pin<&mut Self>, item: WebSocketMessage) -> Result<(), Self::Error> { - Pin::new(&mut self.tx) - .start_send(item) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into()) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.tx) - .poll_flush(cx) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into()) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.tx) - .poll_close(cx) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into()) - } -} - -impl futures::Stream for Channel { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.rx) - .poll_next(cx) - .map(|i| i.map(|i| Ok(i))) - } -} From edbd424b75e4d6ce5ebfd0a67736b8885498945c Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 8 Sep 2021 18:19:59 -0700 Subject: [PATCH 07/22] Introduce test-only APIs for configuring how Client reconnects --- zed/src/channel.rs | 4 +-- zed/src/rpc.rs | 62 +++++++++++++++++++++++++++++++++++++++++----- zed/src/test.rs | 30 ++++++++++++++++++++-- 3 files changed, 86 insertions(+), 10 deletions(-) diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 68f5299a8b..e0d65e8969 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -450,8 +450,8 @@ mod tests { #[gpui::test] async fn test_channel_messages(mut cx: TestAppContext) { let user_id = 5; - let client = Client::new(); - let server = FakeServer::for_client(user_id, &client, &cx).await; + let mut client = Client::new(); + let server = FakeServer::for_client(user_id, &mut client, &cx).await; let user_store = Arc::new(UserStore::new(client.clone())); let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx)); diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 64fc8a56ea..81cc10395f 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -28,12 +28,21 @@ lazy_static! { pub struct Client { peer: Arc, state: RwLock, + auth_callback: Option< + Box Task>>, + >, + connect_callback: Option< + Box Task>>, + >, } #[derive(Copy, Clone, Debug)] pub enum Status { Disconnected, - Connecting, + Authenticating, + Connecting { + user_id: u64, + }, ConnectionError, Connected { connection_id: ConnectionId, @@ -94,9 +103,24 @@ impl Client { Arc::new(Self { peer: Peer::new(), state: Default::default(), + auth_callback: None, + connect_callback: None, }) } + #[cfg(any(test, feature = "test-support"))] + pub fn set_login_and_connect_callbacks( + &mut self, + login: Login, + connect: Connect, + ) where + Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task>, + Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task>, + { + self.auth_callback = Some(Box::new(login)); + self.connect_callback = Some(Box::new(connect)); + } + pub fn status(&self) -> watch::Receiver { self.state.read().status.1.clone() } @@ -192,11 +216,13 @@ impl Client { ) -> anyhow::Result<()> { if matches!( *self.status().borrow(), - Status::Connecting { .. } | Status::Connected { .. } + Status::Authenticating | Status::Connecting { .. } | Status::Connected { .. } ) { return Ok(()); } + self.set_status(Status::Authenticating, cx); + let (user_id, access_token) = match self.authenticate(&cx).await { Ok(result) => result, Err(err) => { @@ -205,7 +231,7 @@ impl Client { } }; - self.set_status(Status::Connecting, cx); + self.set_status(Status::Connecting { user_id }, cx); let conn = match self.connect(user_id, &access_token, cx).await { Ok(conn) => conn, @@ -285,11 +311,32 @@ impl Client { Ok(()) } + fn authenticate(self: &Arc, cx: &AsyncAppContext) -> Task> { + if let Some(callback) = self.auth_callback.as_ref() { + callback(cx) + } else { + self.authenticate_with_browser(cx) + } + } + fn connect( self: &Arc, user_id: u64, access_token: &str, cx: &AsyncAppContext, + ) -> Task> { + if let Some(callback) = self.connect_callback.as_ref() { + callback(user_id, access_token, cx) + } else { + self.connect_with_websocket(user_id, access_token, cx) + } + } + + fn connect_with_websocket( + self: &Arc, + user_id: u64, + access_token: &str, + cx: &AsyncAppContext, ) -> Task> { let request = Request::builder().header("Authorization", format!("{} {}", user_id, access_token)); @@ -314,7 +361,10 @@ impl Client { }) } - pub fn authenticate(self: &Arc, cx: &AsyncAppContext) -> Task> { + pub fn authenticate_with_browser( + self: &Arc, + cx: &AsyncAppContext, + ) -> Task> { let platform = cx.platform(); let executor = cx.background(); executor.clone().spawn(async move { @@ -488,8 +538,8 @@ mod tests { #[gpui::test(iterations = 10)] async fn test_heartbeat(cx: TestAppContext) { let user_id = 5; - let client = Client::new(); - let server = FakeServer::for_client(user_id, &client, &cx).await; + let mut client = Client::new(); + let server = FakeServer::for_client(user_id, &mut client, &cx).await; cx.foreground().advance_clock(Duration::from_secs(10)); let ping = server.receive::().await.unwrap(); diff --git a/zed/src/test.rs b/zed/src/test.rs index e5169ecb69..bee1537b9d 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -203,16 +203,42 @@ pub struct FakeServer { } impl FakeServer { - pub async fn for_client(user_id: u64, client: &Arc, cx: &TestAppContext) -> Arc { + pub async fn for_client( + client_user_id: u64, + client: &mut Arc, + cx: &TestAppContext, + ) -> Arc { let result = Arc::new(Self { peer: Peer::new(), incoming: Default::default(), connection_id: Default::default(), }); + Arc::get_mut(client) + .unwrap() + .set_login_and_connect_callbacks( + move |cx| { + cx.spawn(|_| async move { + let access_token = "the-token".to_string(); + Ok((client_user_id, access_token)) + }) + }, + { + let server = result.clone(); + move |user_id, access_token, cx| { + assert_eq!(user_id, client_user_id); + assert_eq!(access_token, "the-token"); + cx.spawn({ + let server = server.clone(); + move |cx| async move { Ok(server.connect(&cx).await) } + }) + } + }, + ); + let conn = result.connect(&cx.to_async()).await; client - .set_connection(user_id, conn, &cx.to_async()) + .set_connection(client_user_id, conn, &cx.to_async()) .await .unwrap(); result From 6baa9fe37b60c46bffe0a6d89e6a0dd632273d32 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 8 Sep 2021 18:20:32 -0700 Subject: [PATCH 08/22] WIP - Start work on reconnect logic --- zed/src/channel.rs | 65 ++++++++++++++++++++++++++++++++++++++++++++++ zed/src/rpc.rs | 21 ++++++++++++--- 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/zed/src/channel.rs b/zed/src/channel.rs index e0d65e8969..601e47355e 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -653,4 +653,69 @@ mod tests { ); }); } + + #[gpui::test] + async fn test_channel_reconnect(mut cx: TestAppContext) { + let user_id = 5; + let mut client = Client::new(); + let server = FakeServer::for_client(user_id, &mut client, &cx).await; + + let user_store = Arc::new(UserStore::new(client.clone())); + + let channel = cx.add_model(|cx| { + Channel::new( + ChannelDetails { + id: 1, + name: "general".into(), + }, + user_store, + client.clone(), + cx, + ) + }); + + let join_channel = server.receive::().await.unwrap(); + server + .respond( + join_channel.receipt(), + proto::JoinChannelResponse { + messages: vec![ + proto::ChannelMessage { + id: 10, + body: "a".into(), + timestamp: 1000, + sender_id: 5, + }, + proto::ChannelMessage { + id: 11, + body: "b".into(), + timestamp: 1001, + sender_id: 5, + }, + ], + done: false, + }, + ) + .await; + + let get_users = server.receive::().await.unwrap(); + assert_eq!(get_users.payload.user_ids, vec![5]); + server + .respond( + get_users.receipt(), + proto::GetUsersResponse { + users: vec![proto::User { + id: 5, + github_login: "nathansobo".into(), + avatar_url: "http://avatar.com/nathansobo".into(), + }], + }, + ) + .await; + + // Disconnect, wait for the client to reconnect. + server.disconnect().await; + cx.foreground().advance_clock(Duration::from_secs(10)); + let get_messages = server.receive::().await.unwrap(); + } } diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 81cc10395f..518e291704 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -146,14 +146,27 @@ impl Client { })); } Status::ConnectionLost => { - state._maintain_connection = Some(cx.foreground().spawn(async move { - // TODO: try to reconnect + let this = self.clone(); + let foreground = cx.foreground(); + state._maintain_connection = Some(cx.spawn(|cx| async move { + let mut delay_seconds = 5; + while let Err(error) = this.authenticate_and_connect(&cx).await { + log::error!("failed to connect {}", error); + let delay = Duration::from_secs(delay_seconds); + this.set_status( + Status::ReconnectionError { + next_reconnection: Instant::now() + delay, + }, + &cx, + ); + foreground.sleep(delay).await; + delay_seconds = (delay_seconds * 2).min(300); + } })); } - Status::Disconnected => { + _ => { state._maintain_connection.take(); } - _ => {} } } From ad7631de9f7687463020d5f401008ed950965e08 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 9 Sep 2021 11:00:43 +0200 Subject: [PATCH 09/22] Refactor and write a simple unit test to verify reconnection logic --- server/src/rpc.rs | 48 +++++++++++++++----- zed/src/rpc.rs | 112 +++++++++++++++++++++++----------------------- zed/src/test.rs | 38 +++++++++++----- 3 files changed, 121 insertions(+), 77 deletions(-) diff --git a/server/src/rpc.rs b/server/src/rpc.rs index c234967444..f623f29649 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1693,20 +1693,46 @@ mod tests { cx: &mut TestAppContext, name: &str, ) -> (UserId, Arc) { - let user_id = self.app_state.db.create_user(name, false).await.unwrap(); - let client = Client::new(); - let (client_conn, server_conn) = Conn::in_memory(); - cx.background() - .spawn( - self.server - .handle_connection(server_conn, name.to_string(), user_id), - ) - .detach(); + let client_user_id = self.app_state.db.create_user(name, false).await.unwrap(); + let client_name = name.to_string(); + let mut client = Client::new(); + let server = self.server.clone(); + Arc::get_mut(&mut client) + .unwrap() + .set_login_and_connect_callbacks( + move |cx| { + cx.spawn(|_| async move { + let access_token = "the-token".to_string(); + Ok((client_user_id.0 as u64, access_token)) + }) + }, + { + move |user_id, access_token, cx| { + assert_eq!(user_id, client_user_id.0 as u64); + assert_eq!(access_token, "the-token"); + + let server = server.clone(); + let client_name = client_name.clone(); + cx.spawn(move |cx| async move { + let (client_conn, server_conn) = Conn::in_memory(); + cx.background() + .spawn(server.handle_connection( + server_conn, + client_name, + client_user_id, + )) + .detach(); + Ok(client_conn) + }) + } + }, + ); + client - .set_connection(user_id.to_proto(), client_conn, &cx.to_async()) + .authenticate_and_connect(&cx.to_async()) .await .unwrap(); - (user_id, client) + (client_user_id, client) } async fn build_app_state(test_db: &TestDb) -> Arc { diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 518e291704..3399c1cf41 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -49,7 +49,10 @@ pub enum Status { user_id: u64, }, ConnectionLost, - Reconnecting, + Reauthenticating, + Reconnecting { + user_id: u64, + }, ReconnectionError { next_reconnection: Instant, }, @@ -164,9 +167,10 @@ impl Client { } })); } - _ => { + Status::Disconnected => { state._maintain_connection.take(); } + _ => {} } } @@ -227,14 +231,20 @@ impl Client { self: &Arc, cx: &AsyncAppContext, ) -> anyhow::Result<()> { - if matches!( - *self.status().borrow(), - Status::Authenticating | Status::Connecting { .. } | Status::Connected { .. } - ) { - return Ok(()); - } + let was_disconnected = match *self.status().borrow() { + Status::Disconnected => true, + Status::Connected { .. } + | Status::Connecting { .. } + | Status::Reconnecting { .. } + | Status::Reauthenticating => return Ok(()), + _ => false, + }; - self.set_status(Status::Authenticating, cx); + if was_disconnected { + self.set_status(Status::Authenticating, cx); + } else { + self.set_status(Status::Reauthenticating, cx) + } let (user_id, access_token) = match self.authenticate(&cx).await { Ok(result) => result, @@ -244,27 +254,25 @@ impl Client { } }; - self.set_status(Status::Connecting { user_id }, cx); - - let conn = match self.connect(user_id, &access_token, cx).await { - Ok(conn) => conn, + if was_disconnected { + self.set_status(Status::Connecting { user_id }, cx); + } else { + self.set_status(Status::Reconnecting { user_id }, cx); + } + match self.connect(user_id, &access_token, cx).await { + Ok(conn) => { + log::info!("connected to rpc address {}", *ZED_SERVER_URL); + self.set_connection(user_id, conn, cx).await; + Ok(()) + } Err(err) => { self.set_status(Status::ConnectionError, cx); - return Err(err); + Err(err) } - }; - - self.set_connection(user_id, conn, cx).await?; - log::info!("connected to rpc address {}", *ZED_SERVER_URL); - Ok(()) + } } - pub async fn set_connection( - self: &Arc, - user_id: u64, - conn: Conn, - cx: &AsyncAppContext, - ) -> Result<()> { + async fn set_connection(self: &Arc, user_id: u64, conn: Conn, cx: &AsyncAppContext) { let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await; cx.foreground() .spawn({ @@ -321,7 +329,6 @@ impl Client { } }) .detach(); - Ok(()) } fn authenticate(self: &Arc, cx: &AsyncAppContext) -> Task> { @@ -489,35 +496,6 @@ impl Client { } } -pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone { - type Output: 'a + Future>; - - fn handle( - &self, - message: TypedEnvelope, - rpc: &'a Client, - cx: &'a mut gpui::AsyncAppContext, - ) -> Self::Output; -} - -impl<'a, M, F, Fut> MessageHandler<'a, M> for F -where - M: proto::EnvelopedMessage, - F: Clone + Fn(TypedEnvelope, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut, - Fut: 'a + Future>, -{ - type Output = Fut; - - fn handle( - &self, - message: TypedEnvelope, - rpc: &'a Client, - cx: &'a mut gpui::AsyncAppContext, - ) -> Self::Output { - (self)(message, rpc, cx) - } -} - const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/"; pub fn encode_worktree_url(id: u64, access_token: &str) -> String { @@ -550,6 +528,8 @@ mod tests { #[gpui::test(iterations = 10)] async fn test_heartbeat(cx: TestAppContext) { + cx.foreground().forbid_parking(); + let user_id = 5; let mut client = Client::new(); let server = FakeServer::for_client(user_id, &mut client, &cx).await; @@ -568,6 +548,28 @@ mod tests { assert!(server.receive::().await.is_err()); } + #[gpui::test(iterations = 10)] + async fn test_reconnection(cx: TestAppContext) { + cx.foreground().forbid_parking(); + + let user_id = 5; + let mut client = Client::new(); + let server = FakeServer::for_client(user_id, &mut client, &cx).await; + let mut status = client.status(); + assert!(matches!( + status.recv().await, + Some(Status::Connected { .. }) + )); + + server.forbid_connections(); + server.disconnect().await; + while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {} + + server.allow_connections(); + cx.foreground().advance_clock(Duration::from_secs(10)); + while !matches!(status.recv().await, Some(Status::Connected { .. })) {} + } + #[test] fn test_encode_and_decode_worktree_url() { let url = encode_worktree_url(5, "deadbeef"); diff --git a/zed/src/test.rs b/zed/src/test.rs index bee1537b9d..cf1fbfd9e8 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -17,7 +17,10 @@ use smol::channel; use std::{ marker::PhantomData, path::{Path, PathBuf}, - sync::Arc, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, }; use tempdir::TempDir; use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope}; @@ -200,6 +203,7 @@ pub struct FakeServer { peer: Arc, incoming: Mutex>>>, connection_id: Mutex>, + forbid_connections: AtomicBool, } impl FakeServer { @@ -212,6 +216,7 @@ impl FakeServer { peer: Peer::new(), incoming: Default::default(), connection_id: Default::default(), + forbid_connections: Default::default(), }); Arc::get_mut(client) @@ -230,15 +235,14 @@ impl FakeServer { assert_eq!(access_token, "the-token"); cx.spawn({ let server = server.clone(); - move |cx| async move { Ok(server.connect(&cx).await) } + move |cx| async move { server.connect(&cx).await } }) } }, ); - let conn = result.connect(&cx.to_async()).await; client - .set_connection(client_user_id, conn, &cx.to_async()) + .authenticate_and_connect(&cx.to_async()) .await .unwrap(); result @@ -250,13 +254,25 @@ impl FakeServer { self.incoming.lock().take(); } - async fn connect(&self, cx: &AsyncAppContext) -> Conn { - let (client_conn, server_conn) = Conn::in_memory(); - let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await; - cx.background().spawn(io).detach(); - *self.incoming.lock() = Some(incoming); - *self.connection_id.lock() = Some(connection_id); - client_conn + async fn connect(&self, cx: &AsyncAppContext) -> Result { + if self.forbid_connections.load(SeqCst) { + Err(anyhow!("server is forbidding connections")) + } else { + let (client_conn, server_conn) = Conn::in_memory(); + let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await; + cx.background().spawn(io).detach(); + *self.incoming.lock() = Some(incoming); + *self.connection_id.lock() = Some(connection_id); + Ok(client_conn) + } + } + + pub fn forbid_connections(&self) { + self.forbid_connections.store(true, SeqCst); + } + + pub fn allow_connections(&self) { + self.forbid_connections.store(false, SeqCst); } pub async fn send(&self, message: T) { From 34d8f997149f1ada7ce2d430a837ef5831afdefa Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 9 Sep 2021 11:08:41 +0200 Subject: [PATCH 10/22] Respond to RPC pings in the server --- server/src/rpc.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/server/src/rpc.rs b/server/src/rpc.rs index f623f29649..f4d16af996 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -92,6 +92,7 @@ impl Server { }; server + .add_handler(Server::ping) .add_handler(Server::share_worktree) .add_handler(Server::join_worktree) .add_handler(Server::update_worktree) @@ -244,6 +245,18 @@ impl Server { worktree_ids } + async fn ping(self: Arc, request: TypedEnvelope) -> tide::Result<()> { + self.peer + .respond( + request.receipt(), + proto::Pong { + id: request.payload.id, + }, + ) + .await?; + Ok(()) + } + async fn share_worktree( self: Arc, mut request: TypedEnvelope, From 156fd4ba576bb201065e34de195f5a736d698332 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 9 Sep 2021 13:27:44 +0200 Subject: [PATCH 11/22] Add integration test simulating killing a connection while chatting --- server/src/rpc.rs | 214 +++++++++++++++++++++++++++++++++++++++++++--- zed/src/test.rs | 2 +- zrpc/src/conn.rs | 61 ++++++++++--- zrpc/src/peer.rs | 8 +- 4 files changed, 253 insertions(+), 32 deletions(-) diff --git a/server/src/rpc.rs b/server/src/rpc.rs index f4d16af996..cc0d35d097 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1011,16 +1011,24 @@ mod tests { }; use async_std::{sync::RwLockReadGuard, task}; use gpui::TestAppContext; - use postage::mpsc; + use parking_lot::Mutex; + use postage::{mpsc, watch}; use serde_json::json; use sqlx::types::time::OffsetDateTime; - use std::{path::Path, sync::Arc, time::Duration}; + use std::{ + path::Path, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, + time::Duration, + }; use zed::{ channel::{Channel, ChannelDetails, ChannelList}, editor::{Editor, Insert}, fs::{FakeFs, Fs as _}, language::LanguageRegistry, - rpc::Client, + rpc::{self, Client}, settings, user::UserStore, worktree::Worktree, @@ -1677,11 +1685,168 @@ mod tests { ); } + #[gpui::test] + async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { + cx_a.foreground().forbid_parking(); + + // Connect to a server as 2 clients. + let mut server = TestServer::start().await; + let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await; + let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await; + let mut status_b = client_b.status(); + + // Create an org that includes these 2 users. + let db = &server.app_state.db; + let org_id = db.create_org("Test Org", "test-org").await.unwrap(); + db.add_org_member(org_id, user_id_a, false).await.unwrap(); + db.add_org_member(org_id, user_id_b, false).await.unwrap(); + + // Create a channel that includes all the users. + let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); + db.add_channel_member(channel_id, user_id_a, false) + .await + .unwrap(); + db.add_channel_member(channel_id, user_id_b, false) + .await + .unwrap(); + db.create_channel_message( + channel_id, + user_id_b, + "hello A, it's B.", + OffsetDateTime::now_utc(), + ) + .await + .unwrap(); + + let user_store_a = Arc::new(UserStore::new(client_a.clone())); + let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx)); + channels_a + .condition(&mut cx_a, |list, _| list.available_channels().is_some()) + .await; + + channels_a.read_with(&cx_a, |list, _| { + assert_eq!( + list.available_channels().unwrap(), + &[ChannelDetails { + id: channel_id.to_proto(), + name: "test-channel".to_string() + }] + ) + }); + let channel_a = channels_a.update(&mut cx_a, |this, cx| { + this.get_channel(channel_id.to_proto(), cx).unwrap() + }); + channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty())); + channel_a + .condition(&cx_a, |channel, _| { + channel_messages(channel) + == [("user_b".to_string(), "hello A, it's B.".to_string())] + }) + .await; + + let user_store_b = Arc::new(UserStore::new(client_b.clone())); + let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx)); + channels_b + .condition(&mut cx_b, |list, _| list.available_channels().is_some()) + .await; + channels_b.read_with(&cx_b, |list, _| { + assert_eq!( + list.available_channels().unwrap(), + &[ChannelDetails { + id: channel_id.to_proto(), + name: "test-channel".to_string() + }] + ) + }); + + let channel_b = channels_b.update(&mut cx_b, |this, cx| { + this.get_channel(channel_id.to_proto(), cx).unwrap() + }); + channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty())); + channel_b + .condition(&cx_b, |channel, _| { + channel_messages(channel) + == [("user_b".to_string(), "hello A, it's B.".to_string())] + }) + .await; + + // Disconnect client B, ensuring we can still access its cached channel data. + server.forbid_connections(); + server.disconnect_client(user_id_b); + while !matches!( + status_b.recv().await, + Some(rpc::Status::ReconnectionError { .. }) + ) {} + + channels_b.read_with(&cx_b, |channels, _| { + assert_eq!( + channels.available_channels().unwrap(), + [ChannelDetails { + id: channel_id.to_proto(), + name: "test-channel".to_string() + }] + ) + }); + channel_b.read_with(&cx_b, |channel, _| { + assert_eq!( + channel_messages(channel), + [("user_b".to_string(), "hello A, it's B.".to_string())] + ) + }); + + // Send a message from client A while B is disconnected. + channel_a + .update(&mut cx_a, |channel, cx| { + channel + .send_message("oh, hi B.".to_string(), cx) + .unwrap() + .detach(); + let task = channel.send_message("sup".to_string(), cx).unwrap(); + assert_eq!( + channel + .pending_messages() + .iter() + .map(|m| &m.body) + .collect::>(), + &["oh, hi B.", "sup"] + ); + task + }) + .await + .unwrap(); + + // Give client B a chance to reconnect. + server.allow_connections(); + cx_b.foreground().advance_clock(Duration::from_secs(10)); + + // Verify that B sees the new messages upon reconnection. + channel_b + .condition(&cx_b, |channel, _| { + channel_messages(channel) + == [ + ("user_b".to_string(), "hello A, it's B.".to_string()), + ("user_a".to_string(), "oh, hi B.".to_string()), + ("user_a".to_string(), "sup".to_string()), + ] + }) + .await; + + fn channel_messages(channel: &Channel) -> Vec<(String, String)> { + channel + .messages() + .cursor::<(), ()>() + .map(|m| (m.sender.github_login.clone(), m.body.clone())) + .collect() + } + } + struct TestServer { peer: Arc, app_state: Arc, server: Arc, notifications: mpsc::Receiver<()>, + connection_killers: Arc>>>>, + forbid_connections: Arc, _test_db: TestDb, } @@ -1697,6 +1862,8 @@ mod tests { app_state, server, notifications: notifications.1, + connection_killers: Default::default(), + forbid_connections: Default::default(), _test_db: test_db, } } @@ -1710,6 +1877,8 @@ mod tests { let client_name = name.to_string(); let mut client = Client::new(); let server = self.server.clone(); + let connection_killers = self.connection_killers.clone(); + let forbid_connections = self.forbid_connections.clone(); Arc::get_mut(&mut client) .unwrap() .set_login_and_connect_callbacks( @@ -1719,15 +1888,20 @@ mod tests { Ok((client_user_id.0 as u64, access_token)) }) }, - { - move |user_id, access_token, cx| { - assert_eq!(user_id, client_user_id.0 as u64); - assert_eq!(access_token, "the-token"); + move |user_id, access_token, cx| { + assert_eq!(user_id, client_user_id.0 as u64); + assert_eq!(access_token, "the-token"); - let server = server.clone(); - let client_name = client_name.clone(); - cx.spawn(move |cx| async move { - let (client_conn, server_conn) = Conn::in_memory(); + let server = server.clone(); + let connection_killers = connection_killers.clone(); + let forbid_connections = forbid_connections.clone(); + let client_name = client_name.clone(); + cx.spawn(move |cx| async move { + if forbid_connections.load(SeqCst) { + Err(anyhow!("server is forbidding connections")) + } else { + let (client_conn, server_conn, kill_conn) = Conn::in_memory(); + connection_killers.lock().insert(client_user_id, kill_conn); cx.background() .spawn(server.handle_connection( server_conn, @@ -1736,8 +1910,8 @@ mod tests { )) .detach(); Ok(client_conn) - }) - } + } + }) }, ); @@ -1748,6 +1922,20 @@ mod tests { (client_user_id, client) } + fn disconnect_client(&self, user_id: UserId) { + if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) { + let _ = kill_conn.try_send(Some(())); + } + } + + fn forbid_connections(&self) { + self.forbid_connections.store(true, SeqCst); + } + + fn allow_connections(&self) { + self.forbid_connections.store(false, SeqCst); + } + async fn build_app_state(test_db: &TestDb) -> Arc { let mut config = Config::default(); config.session_secret = "a".repeat(32); diff --git a/zed/src/test.rs b/zed/src/test.rs index cf1fbfd9e8..ce865bbfe5 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -258,7 +258,7 @@ impl FakeServer { if self.forbid_connections.load(SeqCst) { Err(anyhow!("server is forbidding connections")) } else { - let (client_conn, server_conn) = Conn::in_memory(); + let (client_conn, server_conn, _) = Conn::in_memory(); let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await; cx.background().spawn(io).detach(); *self.incoming.lock() = Some(incoming); diff --git a/zrpc/src/conn.rs b/zrpc/src/conn.rs index 06dbcee077..f25fb12f01 100644 --- a/zrpc/src/conn.rs +++ b/zrpc/src/conn.rs @@ -33,22 +33,55 @@ impl Conn { } #[cfg(any(test, feature = "test-support"))] - pub fn in_memory() -> (Self, Self) { - use futures::SinkExt as _; - use futures::StreamExt as _; - use std::io::{Error, ErrorKind}; + pub fn in_memory() -> (Self, Self, postage::watch::Sender>) { + let (kill_tx, mut kill_rx) = postage::watch::channel_with(None); + postage::stream::Stream::try_recv(&mut kill_rx).unwrap(); - let (a_tx, a_rx) = futures::channel::mpsc::unbounded::(); - let (b_tx, b_rx) = futures::channel::mpsc::unbounded::(); + let (a_tx, a_rx) = Self::channel(kill_rx.clone()); + let (b_tx, b_rx) = Self::channel(kill_rx); ( - Self { - tx: Box::new(a_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())), - rx: Box::new(b_rx.map(Ok)), - }, - Self { - tx: Box::new(b_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())), - rx: Box::new(a_rx.map(Ok)), - }, + Self { tx: a_tx, rx: b_rx }, + Self { tx: b_tx, rx: a_rx }, + kill_tx, ) } + + #[cfg(any(test, feature = "test-support"))] + fn channel( + kill_rx: postage::watch::Receiver>, + ) -> ( + Box>, + Box>>, + ) { + use futures::{future, stream, SinkExt as _, StreamExt as _}; + use std::io::{Error, ErrorKind}; + + let (tx, rx) = futures::channel::mpsc::unbounded::(); + let tx = tx + .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e))) + .with({ + let kill_rx = kill_rx.clone(); + move |msg| { + if kill_rx.borrow().is_none() { + future::ready(Ok(msg)) + } else { + future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into())) + } + } + }); + let rx = stream::select( + rx.map(Ok), + kill_rx.filter_map(|kill| { + if let Some(_) = kill { + future::ready(Some(Err( + Error::new(ErrorKind::Other, "connection killed").into() + ))) + } else { + future::ready(None) + } + }), + ); + + (Box::new(tx), Box::new(rx)) + } } diff --git a/zrpc/src/peer.rs b/zrpc/src/peer.rs index d50ee50ec3..93c413acbc 100644 --- a/zrpc/src/peer.rs +++ b/zrpc/src/peer.rs @@ -352,12 +352,12 @@ mod tests { let client1 = Peer::new(); let client2 = Peer::new(); - let (client1_to_server_conn, server_to_client_1_conn) = Conn::in_memory(); + let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory(); let (client1_conn_id, io_task1, _) = client1.add_connection(client1_to_server_conn).await; let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await; - let (client2_to_server_conn, server_to_client_2_conn) = Conn::in_memory(); + let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory(); let (client2_conn_id, io_task3, _) = client2.add_connection(client2_to_server_conn).await; let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await; @@ -492,7 +492,7 @@ mod tests { #[test] fn test_disconnect() { smol::block_on(async move { - let (client_conn, mut server_conn) = Conn::in_memory(); + let (client_conn, mut server_conn, _) = Conn::in_memory(); let client = Peer::new(); let (connection_id, io_handler, mut incoming) = @@ -526,7 +526,7 @@ mod tests { #[test] fn test_io_error() { smol::block_on(async move { - let (client_conn, server_conn) = Conn::in_memory(); + let (client_conn, server_conn, _) = Conn::in_memory(); drop(server_conn); let client = Peer::new(); From 8de18b5a84cdd1ccbf7de8f7823c176bb89b3c13 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 9 Sep 2021 15:34:46 +0200 Subject: [PATCH 12/22] Re-join channel when reconnecting Co-Authored-By: Nathan Sobo --- gpui/src/executor.rs | 14 +++---- zed/src/channel.rs | 98 +++++++++++++++++++++++++++++++++---------- zed/src/chat_panel.rs | 2 +- zed/src/rpc.rs | 9 ++-- zrpc/src/conn.rs | 6 +-- 5 files changed, 92 insertions(+), 37 deletions(-) diff --git a/gpui/src/executor.rs b/gpui/src/executor.rs index 5849a86902..9c7681e19e 100644 --- a/gpui/src/executor.rs +++ b/gpui/src/executor.rs @@ -51,7 +51,7 @@ struct DeterministicState { forbid_parking: bool, block_on_ticks: RangeInclusive, now: Instant, - pending_sleeps: Vec<(Instant, barrier::Sender)>, + pending_timers: Vec<(Instant, barrier::Sender)>, } pub struct Deterministic { @@ -71,7 +71,7 @@ impl Deterministic { forbid_parking: false, block_on_ticks: 0..=1000, now: Instant::now(), - pending_sleeps: Default::default(), + pending_timers: Default::default(), })), parker: Default::default(), } @@ -428,14 +428,14 @@ impl Foreground { } } - pub async fn sleep(&self, duration: Duration) { + pub async fn timer(&self, duration: Duration) { match self { Self::Deterministic(executor) => { let (tx, mut rx) = barrier::channel(); { let mut state = executor.state.lock(); let wakeup_at = state.now + duration; - state.pending_sleeps.push((wakeup_at, tx)); + state.pending_timers.push((wakeup_at, tx)); } rx.recv().await; } @@ -453,11 +453,11 @@ impl Foreground { let mut state = executor.state.lock(); state.now += duration; let now = state.now; - let mut pending_sleeps = mem::take(&mut state.pending_sleeps); + let mut pending_timers = mem::take(&mut state.pending_timers); drop(state); - pending_sleeps.retain(|(wakeup, _)| *wakeup > now); - executor.state.lock().pending_sleeps.extend(pending_sleeps); + pending_timers.retain(|(wakeup, _)| *wakeup > now); + executor.state.lock().pending_timers.extend(pending_timers); } _ => panic!("this method can only be called on a deterministic executor"), } diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 601e47355e..4425174b26 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -11,6 +11,7 @@ use gpui::{ use postage::prelude::Stream; use std::{ collections::{HashMap, HashSet}, + mem, ops::Range, sync::Arc, }; @@ -71,7 +72,7 @@ pub enum ChannelListEvent {} #[derive(Clone, Debug, PartialEq)] pub enum ChannelEvent { - MessagesAdded { + MessagesUpdated { old_range: Range, new_count: usize, }, @@ -93,26 +94,40 @@ impl ChannelList { let mut status = rpc.status(); loop { let status = status.recv().await.unwrap(); - let available_channels = if matches!(status, rpc::Status::Connected { .. }) { - let response = rpc - .request(proto::GetChannels {}) - .await - .context("failed to fetch available channels")?; - Some(response.channels.into_iter().map(Into::into).collect()) - } else { - None - }; + match status { + rpc::Status::Connected { .. } => { + let response = rpc + .request(proto::GetChannels {}) + .await + .context("failed to fetch available channels")?; + this.update(&mut cx, |this, cx| { + this.available_channels = + Some(response.channels.into_iter().map(Into::into).collect()); - this.update(&mut cx, |this, cx| { - if available_channels.is_none() { - if this.available_channels.is_none() { - return; - } - this.channels.clear(); + let mut to_remove = Vec::new(); + for (channel_id, channel) in &this.channels { + if let Some(channel) = channel.upgrade(cx) { + channel.update(cx, |channel, cx| channel.rejoin(cx)) + } else { + to_remove.push(*channel_id); + } + } + + for channel_id in to_remove { + this.channels.remove(&channel_id); + } + cx.notify(); + }); } - this.available_channels = available_channels; - cx.notify(); - }); + rpc::Status::Disconnected { .. } => { + this.update(&mut cx, |this, cx| { + this.available_channels = None; + this.channels.clear(); + cx.notify(); + }); + } + _ => {} + } } } .log_err() @@ -282,6 +297,43 @@ impl Channel { false } + pub fn rejoin(&mut self, cx: &mut ModelContext) { + let user_store = self.user_store.clone(); + let rpc = self.rpc.clone(); + let channel_id = self.details.id; + cx.spawn(|channel, mut cx| { + async move { + let response = rpc.request(proto::JoinChannel { channel_id }).await?; + let messages = messages_from_proto(response.messages, &user_store).await?; + let loaded_all_messages = response.done; + + channel.update(&mut cx, |channel, cx| { + if let Some((first_new_message, last_old_message)) = + messages.first().zip(channel.messages.last()) + { + if first_new_message.id > last_old_message.id { + let old_messages = mem::take(&mut channel.messages); + cx.emit(ChannelEvent::MessagesUpdated { + old_range: 0..old_messages.summary().count, + new_count: 0, + }); + channel.loaded_all_messages = loaded_all_messages; + } + } + + channel.insert_messages(messages, cx); + if loaded_all_messages { + channel.loaded_all_messages = loaded_all_messages; + } + }); + + Ok(()) + } + .log_err() + }) + .detach(); + } + pub fn message_count(&self) -> usize { self.messages.summary().count } @@ -347,7 +399,7 @@ impl Channel { drop(old_cursor); self.messages = new_messages; - cx.emit(ChannelEvent::MessagesAdded { + cx.emit(ChannelEvent::MessagesUpdated { old_range: start_ix..end_ix, new_count, }); @@ -539,7 +591,7 @@ mod tests { assert_eq!( channel.next_event(&cx).await, - ChannelEvent::MessagesAdded { + ChannelEvent::MessagesUpdated { old_range: 0..0, new_count: 2, } @@ -588,7 +640,7 @@ mod tests { assert_eq!( channel.next_event(&cx).await, - ChannelEvent::MessagesAdded { + ChannelEvent::MessagesUpdated { old_range: 2..2, new_count: 1, } @@ -635,7 +687,7 @@ mod tests { assert_eq!( channel.next_event(&cx).await, - ChannelEvent::MessagesAdded { + ChannelEvent::MessagesUpdated { old_range: 0..0, new_count: 2, } diff --git a/zed/src/chat_panel.rs b/zed/src/chat_panel.rs index a4b7eb5938..200a35fcca 100644 --- a/zed/src/chat_panel.rs +++ b/zed/src/chat_panel.rs @@ -194,7 +194,7 @@ impl ChatPanel { cx: &mut ViewContext, ) { match event { - ChannelEvent::MessagesAdded { + ChannelEvent::MessagesUpdated { old_range, new_count, } => { diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 3399c1cf41..07eb22f137 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -140,7 +140,7 @@ impl Client { state._maintain_connection = Some(cx.foreground().spawn(async move { let mut next_ping_id = 0; loop { - foreground.sleep(heartbeat_interval).await; + foreground.timer(heartbeat_interval).await; this.request(proto::Ping { id: next_ping_id }) .await .unwrap(); @@ -162,7 +162,7 @@ impl Client { }, &cx, ); - foreground.sleep(delay).await; + foreground.timer(delay).await; delay_seconds = (delay_seconds * 2).min(300); } })); @@ -233,11 +233,14 @@ impl Client { ) -> anyhow::Result<()> { let was_disconnected = match *self.status().borrow() { Status::Disconnected => true, + Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => { + false + } Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } + | Status::Authenticating | Status::Reauthenticating => return Ok(()), - _ => false, }; if was_disconnected { diff --git a/zrpc/src/conn.rs b/zrpc/src/conn.rs index f25fb12f01..4b2356c872 100644 --- a/zrpc/src/conn.rs +++ b/zrpc/src/conn.rs @@ -72,12 +72,12 @@ impl Conn { let rx = stream::select( rx.map(Ok), kill_rx.filter_map(|kill| { - if let Some(_) = kill { + if kill.is_none() { + future::ready(None) + } else { future::ready(Some(Err( Error::new(ErrorKind::Other, "connection killed").into() ))) - } else { - future::ready(None) } }), ); From 3c61a3e82650250bc8483872221b8d834d1d80b1 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 9 Sep 2021 15:40:35 +0200 Subject: [PATCH 13/22] Ensure client A and B can communicate after reconnection Co-Authored-By: Nathan Sobo --- server/src/rpc.rs | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/server/src/rpc.rs b/server/src/rpc.rs index cc0d35d097..cd2409250e 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1831,6 +1831,44 @@ mod tests { }) .await; + // Ensure client A and B can communicate normally after reconnection. + channel_a + .update(&mut cx_a, |channel, cx| { + channel.send_message("you online?".to_string(), cx).unwrap() + }) + .await + .unwrap(); + channel_b + .condition(&cx_b, |channel, _| { + channel_messages(channel) + == [ + ("user_b".to_string(), "hello A, it's B.".to_string()), + ("user_a".to_string(), "oh, hi B.".to_string()), + ("user_a".to_string(), "sup".to_string()), + ("user_a".to_string(), "you online?".to_string()), + ] + }) + .await; + + channel_b + .update(&mut cx_b, |channel, cx| { + channel.send_message("yep".to_string(), cx).unwrap() + }) + .await + .unwrap(); + channel_a + .condition(&cx_a, |channel, _| { + channel_messages(channel) + == [ + ("user_b".to_string(), "hello A, it's B.".to_string()), + ("user_a".to_string(), "oh, hi B.".to_string()), + ("user_a".to_string(), "sup".to_string()), + ("user_a".to_string(), "you online?".to_string()), + ("user_b".to_string(), "yep".to_string()), + ] + }) + .await; + fn channel_messages(channel: &Channel) -> Vec<(String, String)> { channel .messages() From e2b56e8764fbf408e2cbd1ba39a263c58e3f67c6 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Thu, 9 Sep 2021 08:25:58 -0600 Subject: [PATCH 14/22] If a test connection has been killed, never return a message --- zrpc/src/conn.rs | 46 ++++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/zrpc/src/conn.rs b/zrpc/src/conn.rs index 4b2356c872..e67b4fa587 100644 --- a/zrpc/src/conn.rs +++ b/zrpc/src/conn.rs @@ -1,5 +1,6 @@ use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; -use futures::{SinkExt as _, StreamExt as _}; +use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _}; +use std::{io, task::Poll}; pub struct Conn { pub(crate) tx: @@ -53,10 +54,10 @@ impl Conn { Box>, Box>>, ) { - use futures::{future, stream, SinkExt as _, StreamExt as _}; - use std::io::{Error, ErrorKind}; + use futures::{future, SinkExt as _}; + use io::{Error, ErrorKind}; - let (tx, rx) = futures::channel::mpsc::unbounded::(); + let (tx, rx) = mpsc::unbounded::(); let tx = tx .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e))) .with({ @@ -69,19 +70,32 @@ impl Conn { } } }); - let rx = stream::select( - rx.map(Ok), - kill_rx.filter_map(|kill| { - if kill.is_none() { - future::ready(None) - } else { - future::ready(Some(Err( - Error::new(ErrorKind::Other, "connection killed").into() - ))) - } - }), - ); + let rx = KillableReceiver { kill_rx, rx }; (Box::new(tx), Box::new(rx)) } } + +struct KillableReceiver { + rx: mpsc::UnboundedReceiver, + kill_rx: postage::watch::Receiver>, +} + +impl Stream for KillableReceiver { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) { + Poll::Ready(Some(Err(io::Error::new( + io::ErrorKind::Other, + "connection killed", + ) + .into()))) + } else { + self.rx.poll_next_unpin(cx).map(|value| value.map(Ok)) + } + } +} From d08ec8bd53c79196849029789441489bd105c95e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 9 Sep 2021 17:00:18 +0200 Subject: [PATCH 15/22] Reduce backoff and add some jitter to avoid thundering herd issues Co-Authored-By: Nathan Sobo --- zed/src/rpc.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 07eb22f137..088f30e1d1 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -5,6 +5,7 @@ use gpui::{AsyncAppContext, Entity, ModelContext, Task}; use lazy_static::lazy_static; use parking_lot::RwLock; use postage::{prelude::Stream, watch}; +use rand::prelude::*; use std::{ any::TypeId, collections::HashMap, @@ -151,11 +152,12 @@ impl Client { Status::ConnectionLost => { let this = self.clone(); let foreground = cx.foreground(); + let heartbeat_interval = state.heartbeat_interval; state._maintain_connection = Some(cx.spawn(|cx| async move { - let mut delay_seconds = 5; + let mut rng = StdRng::from_entropy(); + let mut delay = Duration::from_millis(100); while let Err(error) = this.authenticate_and_connect(&cx).await { log::error!("failed to connect {}", error); - let delay = Duration::from_secs(delay_seconds); this.set_status( Status::ReconnectionError { next_reconnection: Instant::now() + delay, @@ -163,7 +165,9 @@ impl Client { &cx, ); foreground.timer(delay).await; - delay_seconds = (delay_seconds * 2).min(300); + delay = delay + .mul_f32(rng.gen_range(1.0..=2.0)) + .min(heartbeat_interval); } })); } From 8fb58e09d8e9b76ab39716ab3debe8edd2837302 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 9 Sep 2021 17:00:55 +0200 Subject: [PATCH 16/22] Remove channel disconnection unit test ...as that's already covered by the integration test. --- zed/src/channel.rs | 65 ---------------------------------------------- 1 file changed, 65 deletions(-) diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 4425174b26..8a10655e0e 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -705,69 +705,4 @@ mod tests { ); }); } - - #[gpui::test] - async fn test_channel_reconnect(mut cx: TestAppContext) { - let user_id = 5; - let mut client = Client::new(); - let server = FakeServer::for_client(user_id, &mut client, &cx).await; - - let user_store = Arc::new(UserStore::new(client.clone())); - - let channel = cx.add_model(|cx| { - Channel::new( - ChannelDetails { - id: 1, - name: "general".into(), - }, - user_store, - client.clone(), - cx, - ) - }); - - let join_channel = server.receive::().await.unwrap(); - server - .respond( - join_channel.receipt(), - proto::JoinChannelResponse { - messages: vec![ - proto::ChannelMessage { - id: 10, - body: "a".into(), - timestamp: 1000, - sender_id: 5, - }, - proto::ChannelMessage { - id: 11, - body: "b".into(), - timestamp: 1001, - sender_id: 5, - }, - ], - done: false, - }, - ) - .await; - - let get_users = server.receive::().await.unwrap(); - assert_eq!(get_users.payload.user_ids, vec![5]); - server - .respond( - get_users.receipt(), - proto::GetUsersResponse { - users: vec![proto::User { - id: 5, - github_login: "nathansobo".into(), - avatar_url: "http://avatar.com/nathansobo".into(), - }], - }, - ) - .await; - - // Disconnect, wait for the client to reconnect. - server.disconnect().await; - cx.foreground().advance_clock(Duration::from_secs(10)); - let get_messages = server.receive::().await.unwrap(); - } } From b3aad5d3335863cd0502045cd26aeac6f25de245 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 9 Sep 2021 17:45:33 +0200 Subject: [PATCH 17/22] :lipstick: --- zed/src/channel.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 8a10655e0e..48ede386f8 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -497,7 +497,6 @@ mod tests { use super::*; use crate::test::FakeServer; use gpui::TestAppContext; - use std::time::Duration; #[gpui::test] async fn test_channel_messages(mut cx: TestAppContext) { From 000305472aa1b298c6b75f44cf72fd27cf1f9cbd Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 9 Sep 2021 19:51:26 +0200 Subject: [PATCH 18/22] Minor stylistic changes --- zed/src/channel.rs | 3 +-- zed/src/worktree.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 48ede386f8..5c7a2aaa1a 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -93,8 +93,7 @@ impl ChannelList { async move { let mut status = rpc.status(); loop { - let status = status.recv().await.unwrap(); - match status { + match status.recv().await.unwrap() { rpc::Status::Connected { .. } => { let response = rpc .request(proto::GetChannels {}) diff --git a/zed/src/worktree.rs b/zed/src/worktree.rs index 7b0e212d8d..d290454f46 100644 --- a/zed/src/worktree.rs +++ b/zed/src/worktree.rs @@ -1582,7 +1582,7 @@ impl File { .await { log::error!("error closing remote buffer: {}", error); - }; + } }) .detach(); } From 38bfaba135cc826d88f40439d2398f7190c2fc17 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 9 Sep 2021 11:24:16 -0700 Subject: [PATCH 19/22] Add a generic `Ack` message, use it instead of `Pong` Remove the `id` field from `Ping`, because it isn't used. There is already an id on the message envelope. --- server/src/rpc.rs | 9 +-------- zed/src/rpc.rs | 12 +++--------- zrpc/proto/zed.proto | 14 +++++--------- zrpc/src/peer.rs | 18 ++++++------------ zrpc/src/proto.rs | 4 ++-- 5 files changed, 17 insertions(+), 40 deletions(-) diff --git a/server/src/rpc.rs b/server/src/rpc.rs index cd2409250e..bbc3f23822 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -246,14 +246,7 @@ impl Server { } async fn ping(self: Arc, request: TypedEnvelope) -> tide::Result<()> { - self.peer - .respond( - request.receipt(), - proto::Pong { - id: request.payload.id, - }, - ) - .await?; + self.peer.respond(request.receipt(), proto::Ack {}).await?; Ok(()) } diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 088f30e1d1..1ac9c40161 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -139,13 +139,9 @@ impl Client { let this = self.clone(); let foreground = cx.foreground(); state._maintain_connection = Some(cx.foreground().spawn(async move { - let mut next_ping_id = 0; loop { foreground.timer(heartbeat_interval).await; - this.request(proto::Ping { id: next_ping_id }) - .await - .unwrap(); - next_ping_id += 1; + this.request(proto::Ping {}).await.unwrap(); } })); } @@ -543,13 +539,11 @@ mod tests { cx.foreground().advance_clock(Duration::from_secs(10)); let ping = server.receive::().await.unwrap(); - assert_eq!(ping.payload.id, 0); - server.respond(ping.receipt(), proto::Pong { id: 0 }).await; + server.respond(ping.receipt(), proto::Ack {}).await; cx.foreground().advance_clock(Duration::from_secs(10)); let ping = server.receive::().await.unwrap(); - assert_eq!(ping.payload.id, 1); - server.respond(ping.receipt(), proto::Pong { id: 1 }).await; + server.respond(ping.receipt(), proto::Ack {}).await; client.disconnect(&cx.to_async()).await.unwrap(); assert!(server.receive::().await.is_err()); diff --git a/zrpc/proto/zed.proto b/zrpc/proto/zed.proto index 66b3c8d226..c9f1dc0f80 100644 --- a/zrpc/proto/zed.proto +++ b/zrpc/proto/zed.proto @@ -6,9 +6,9 @@ message Envelope { optional uint32 responding_to = 2; optional uint32 original_sender_id = 3; oneof payload { - Error error = 4; - Ping ping = 5; - Pong pong = 6; + Ack ack = 4; + Error error = 5; + Ping ping = 6; ShareWorktree share_worktree = 7; ShareWorktreeResponse share_worktree_response = 8; OpenWorktree open_worktree = 9; @@ -40,13 +40,9 @@ message Envelope { // Messages -message Ping { - int32 id = 1; -} +message Ping {} -message Pong { - int32 id = 2; -} +message Ack {} message Error { string message = 1; diff --git a/zrpc/src/peer.rs b/zrpc/src/peer.rs index 93c413acbc..75db257f55 100644 --- a/zrpc/src/peer.rs +++ b/zrpc/src/peer.rs @@ -371,18 +371,18 @@ mod tests { assert_eq!( client1 - .request(client1_conn_id, proto::Ping { id: 1 },) + .request(client1_conn_id, proto::Ping {},) .await .unwrap(), - proto::Pong { id: 1 } + proto::Ack {} ); assert_eq!( client2 - .request(client2_conn_id, proto::Ping { id: 2 },) + .request(client2_conn_id, proto::Ping {},) .await .unwrap(), - proto::Pong { id: 2 } + proto::Ack {} ); assert_eq!( @@ -438,13 +438,7 @@ mod tests { let envelope = envelope.into_any(); if let Some(envelope) = envelope.downcast_ref::>() { let receipt = envelope.receipt(); - peer.respond( - receipt, - proto::Pong { - id: envelope.payload.id, - }, - ) - .await? + peer.respond(receipt, proto::Ack {}).await? } else if let Some(envelope) = envelope.downcast_ref::>() { @@ -536,7 +530,7 @@ mod tests { smol::spawn(async move { incoming.next().await }).detach(); let err = client - .request(connection_id, proto::Ping { id: 42 }) + .request(connection_id, proto::Ping {}) .await .unwrap_err(); assert_eq!(err.to_string(), "connection was closed"); diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index ded7fdc1cd..6274c187e1 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -120,6 +120,7 @@ macro_rules! entity_messages { } messages!( + Ack, AddPeer, BufferSaved, ChannelMessageSent, @@ -140,7 +141,6 @@ messages!( OpenWorktree, OpenWorktreeResponse, Ping, - Pong, RemovePeer, SaveBuffer, SendChannelMessage, @@ -157,7 +157,7 @@ request_messages!( (JoinChannel, JoinChannelResponse), (OpenBuffer, OpenBufferResponse), (OpenWorktree, OpenWorktreeResponse), - (Ping, Pong), + (Ping, Ack), (SaveBuffer, BufferSaved), (ShareWorktree, ShareWorktreeResponse), (SendChannelMessage, SendChannelMessageResponse), From c58e335b8742cfc6bc3726b89eeef42ee33b70d0 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 9 Sep 2021 11:26:06 -0700 Subject: [PATCH 20/22] Make `UpdateBuffer` a request, store unsent operations on worktree --- server/src/rpc.rs | 4 +++- zed/src/worktree.rs | 19 +++++++++++++++---- zrpc/src/proto.rs | 1 + 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/server/src/rpc.rs b/server/src/rpc.rs index bbc3f23822..2bd0eac625 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -499,7 +499,9 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { self.broadcast_in_worktree(request.payload.worktree_id, &request) - .await + .await?; + self.peer.respond(request.receipt(), proto::Ack {}).await?; + Ok(()) } async fn buffer_saved( diff --git a/zed/src/worktree.rs b/zed/src/worktree.rs index d290454f46..7b2fea9175 100644 --- a/zed/src/worktree.rs +++ b/zed/src/worktree.rs @@ -234,6 +234,7 @@ impl Worktree { .into_iter() .map(|p| (PeerId(p.peer_id), p.replica_id as ReplicaId)) .collect(), + queued_operations: Default::default(), languages, _subscriptions, }) @@ -656,6 +657,7 @@ pub struct LocalWorktree { shared_buffers: HashMap>>, peers: HashMap, languages: Arc, + queued_operations: Vec<(u64, Operation)>, fs: Arc, } @@ -711,6 +713,7 @@ impl LocalWorktree { poll_task: None, open_buffers: Default::default(), shared_buffers: Default::default(), + queued_operations: Default::default(), peers: Default::default(), languages, fs, @@ -1091,6 +1094,7 @@ pub struct RemoteWorktree { open_buffers: HashMap, peers: HashMap, languages: Arc, + queued_operations: Vec<(u64, Operation)>, _subscriptions: Vec, } @@ -1550,16 +1554,23 @@ impl File { .map(|share| (share.rpc.clone(), share.remote_id)), Worktree::Remote(worktree) => Some((worktree.rpc.clone(), worktree.remote_id)), } { - cx.spawn(|_, _| async move { + cx.spawn(|worktree, mut cx| async move { if let Err(error) = rpc - .send(proto::UpdateBuffer { + .request(proto::UpdateBuffer { worktree_id: remote_id, buffer_id, - operations: Some(operation).iter().map(Into::into).collect(), + operations: vec![(&operation).into()], }) .await { - log::error!("error sending buffer operation: {}", error); + worktree.update(&mut cx, |worktree, _| { + log::error!("error sending buffer operation: {}", error); + match worktree { + Worktree::Local(t) => &mut t.queued_operations, + Worktree::Remote(t) => &mut t.queued_operations, + } + .push((buffer_id, operation)); + }); } }) .detach(); diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index 6274c187e1..af9dbf3abc 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -159,6 +159,7 @@ request_messages!( (OpenWorktree, OpenWorktreeResponse), (Ping, Ack), (SaveBuffer, BufferSaved), + (UpdateBuffer, Ack), (ShareWorktree, ShareWorktreeResponse), (SendChannelMessage, SendChannelMessageResponse), (GetChannelMessages, GetChannelMessagesResponse), From 5a4ba7f551e5e755113d53a2ea290c69cd7bf9ef Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 9 Sep 2021 17:56:32 -0700 Subject: [PATCH 21/22] :lipstick: Use time::Global::into in Anchor::into --- zed/src/editor/buffer.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/zed/src/editor/buffer.rs b/zed/src/editor/buffer.rs index 123b3f2db1..97e0202cec 100644 --- a/zed/src/editor/buffer.rs +++ b/zed/src/editor/buffer.rs @@ -2695,14 +2695,7 @@ impl<'a> Into for &'a EditOperation { impl<'a> Into for &'a Anchor { fn into(self) -> proto::Anchor { proto::Anchor { - version: self - .version - .iter() - .map(|entry| proto::VectorClockEntry { - replica_id: entry.replica_id as u32, - timestamp: entry.value, - }) - .collect(), + version: (&self.version).into(), offset: self.offset as u64, bias: match self.bias { Bias::Left => proto::anchor::Bias::Left as i32, From 680b86b17c63b67f768bc5da5f34e5ccf056a0ce Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 9 Sep 2021 17:57:06 -0700 Subject: [PATCH 22/22] Avoid holding strong handle to Channel in long-lived task --- zed/src/channel.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 5c7a2aaa1a..aa182c0540 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -88,12 +88,12 @@ impl ChannelList { rpc: Arc, cx: &mut ModelContext, ) -> Self { - let _task = cx.spawn(|this, mut cx| { + let _task = cx.spawn_weak(|this, mut cx| { let rpc = rpc.clone(); async move { let mut status = rpc.status(); - loop { - match status.recv().await.unwrap() { + while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) { + match status { rpc::Status::Connected { .. } => { let response = rpc .request(proto::GetChannels {}) @@ -128,6 +128,7 @@ impl ChannelList { _ => {} } } + Ok(()) } .log_err() });