Refactor and write a simple unit test to verify reconnection logic

This commit is contained in:
Antonio Scandurra 2021-09-09 11:00:43 +02:00
parent 6baa9fe37b
commit ad7631de9f
3 changed files with 121 additions and 77 deletions

View file

@ -1693,20 +1693,46 @@ mod tests {
cx: &mut TestAppContext,
name: &str,
) -> (UserId, Arc<Client>) {
let user_id = self.app_state.db.create_user(name, false).await.unwrap();
let client = Client::new();
let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
let client_name = name.to_string();
let mut client = Client::new();
let server = self.server.clone();
Arc::get_mut(&mut client)
.unwrap()
.set_login_and_connect_callbacks(
move |cx| {
cx.spawn(|_| async move {
let access_token = "the-token".to_string();
Ok((client_user_id.0 as u64, access_token))
})
},
{
move |user_id, access_token, cx| {
assert_eq!(user_id, client_user_id.0 as u64);
assert_eq!(access_token, "the-token");
let server = server.clone();
let client_name = client_name.clone();
cx.spawn(move |cx| async move {
let (client_conn, server_conn) = Conn::in_memory();
cx.background()
.spawn(
self.server
.handle_connection(server_conn, name.to_string(), user_id),
)
.spawn(server.handle_connection(
server_conn,
client_name,
client_user_id,
))
.detach();
Ok(client_conn)
})
}
},
);
client
.set_connection(user_id.to_proto(), client_conn, &cx.to_async())
.authenticate_and_connect(&cx.to_async())
.await
.unwrap();
(user_id, client)
(client_user_id, client)
}
async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {

View file

@ -49,7 +49,10 @@ pub enum Status {
user_id: u64,
},
ConnectionLost,
Reconnecting,
Reauthenticating,
Reconnecting {
user_id: u64,
},
ReconnectionError {
next_reconnection: Instant,
},
@ -164,9 +167,10 @@ impl Client {
}
}));
}
_ => {
Status::Disconnected => {
state._maintain_connection.take();
}
_ => {}
}
}
@ -227,14 +231,20 @@ impl Client {
self: &Arc<Self>,
cx: &AsyncAppContext,
) -> anyhow::Result<()> {
if matches!(
*self.status().borrow(),
Status::Authenticating | Status::Connecting { .. } | Status::Connected { .. }
) {
return Ok(());
}
let was_disconnected = match *self.status().borrow() {
Status::Disconnected => true,
Status::Connected { .. }
| Status::Connecting { .. }
| Status::Reconnecting { .. }
| Status::Reauthenticating => return Ok(()),
_ => false,
};
if was_disconnected {
self.set_status(Status::Authenticating, cx);
} else {
self.set_status(Status::Reauthenticating, cx)
}
let (user_id, access_token) = match self.authenticate(&cx).await {
Ok(result) => result,
@ -244,27 +254,25 @@ impl Client {
}
};
if was_disconnected {
self.set_status(Status::Connecting { user_id }, cx);
let conn = match self.connect(user_id, &access_token, cx).await {
Ok(conn) => conn,
Err(err) => {
self.set_status(Status::ConnectionError, cx);
return Err(err);
} else {
self.set_status(Status::Reconnecting { user_id }, cx);
}
};
self.set_connection(user_id, conn, cx).await?;
match self.connect(user_id, &access_token, cx).await {
Ok(conn) => {
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
self.set_connection(user_id, conn, cx).await;
Ok(())
}
Err(err) => {
self.set_status(Status::ConnectionError, cx);
Err(err)
}
}
}
pub async fn set_connection(
self: &Arc<Self>,
user_id: u64,
conn: Conn,
cx: &AsyncAppContext,
) -> Result<()> {
async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
cx.foreground()
.spawn({
@ -321,7 +329,6 @@ impl Client {
}
})
.detach();
Ok(())
}
fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
@ -489,35 +496,6 @@ impl Client {
}
}
pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
type Output: 'a + Future<Output = anyhow::Result<()>>;
fn handle(
&self,
message: TypedEnvelope<M>,
rpc: &'a Client,
cx: &'a mut gpui::AsyncAppContext,
) -> Self::Output;
}
impl<'a, M, F, Fut> MessageHandler<'a, M> for F
where
M: proto::EnvelopedMessage,
F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
Fut: 'a + Future<Output = anyhow::Result<()>>,
{
type Output = Fut;
fn handle(
&self,
message: TypedEnvelope<M>,
rpc: &'a Client,
cx: &'a mut gpui::AsyncAppContext,
) -> Self::Output {
(self)(message, rpc, cx)
}
}
const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
@ -550,6 +528,8 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_heartbeat(cx: TestAppContext) {
cx.foreground().forbid_parking();
let user_id = 5;
let mut client = Client::new();
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
@ -568,6 +548,28 @@ mod tests {
assert!(server.receive::<proto::Ping>().await.is_err());
}
#[gpui::test(iterations = 10)]
async fn test_reconnection(cx: TestAppContext) {
cx.foreground().forbid_parking();
let user_id = 5;
let mut client = Client::new();
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
let mut status = client.status();
assert!(matches!(
status.recv().await,
Some(Status::Connected { .. })
));
server.forbid_connections();
server.disconnect().await;
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
server.allow_connections();
cx.foreground().advance_clock(Duration::from_secs(10));
while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
}
#[test]
fn test_encode_and_decode_worktree_url() {
let url = encode_worktree_url(5, "deadbeef");

View file

@ -17,7 +17,10 @@ use smol::channel;
use std::{
marker::PhantomData,
path::{Path, PathBuf},
sync::Arc,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
use tempdir::TempDir;
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
@ -200,6 +203,7 @@ pub struct FakeServer {
peer: Arc<Peer>,
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
connection_id: Mutex<Option<ConnectionId>>,
forbid_connections: AtomicBool,
}
impl FakeServer {
@ -212,6 +216,7 @@ impl FakeServer {
peer: Peer::new(),
incoming: Default::default(),
connection_id: Default::default(),
forbid_connections: Default::default(),
});
Arc::get_mut(client)
@ -230,15 +235,14 @@ impl FakeServer {
assert_eq!(access_token, "the-token");
cx.spawn({
let server = server.clone();
move |cx| async move { Ok(server.connect(&cx).await) }
move |cx| async move { server.connect(&cx).await }
})
}
},
);
let conn = result.connect(&cx.to_async()).await;
client
.set_connection(client_user_id, conn, &cx.to_async())
.authenticate_and_connect(&cx.to_async())
.await
.unwrap();
result
@ -250,13 +254,25 @@ impl FakeServer {
self.incoming.lock().take();
}
async fn connect(&self, cx: &AsyncAppContext) -> Conn {
async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
if self.forbid_connections.load(SeqCst) {
Err(anyhow!("server is forbidding connections"))
} else {
let (client_conn, server_conn) = Conn::in_memory();
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
cx.background().spawn(io).detach();
*self.incoming.lock() = Some(incoming);
*self.connection_id.lock() = Some(connection_id);
client_conn
Ok(client_conn)
}
}
pub fn forbid_connections(&self) {
self.forbid_connections.store(true, SeqCst);
}
pub fn allow_connections(&self) {
self.forbid_connections.store(false, SeqCst);
}
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {