Refactor and write a simple unit test to verify reconnection logic

This commit is contained in:
Antonio Scandurra 2021-09-09 11:00:43 +02:00
parent 6baa9fe37b
commit ad7631de9f
3 changed files with 121 additions and 77 deletions

View file

@ -17,7 +17,10 @@ use smol::channel;
use std::{
marker::PhantomData,
path::{Path, PathBuf},
sync::Arc,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
use tempdir::TempDir;
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
@ -200,6 +203,7 @@ pub struct FakeServer {
peer: Arc<Peer>,
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
connection_id: Mutex<Option<ConnectionId>>,
forbid_connections: AtomicBool,
}
impl FakeServer {
@ -212,6 +216,7 @@ impl FakeServer {
peer: Peer::new(),
incoming: Default::default(),
connection_id: Default::default(),
forbid_connections: Default::default(),
});
Arc::get_mut(client)
@ -230,15 +235,14 @@ impl FakeServer {
assert_eq!(access_token, "the-token");
cx.spawn({
let server = server.clone();
move |cx| async move { Ok(server.connect(&cx).await) }
move |cx| async move { server.connect(&cx).await }
})
}
},
);
let conn = result.connect(&cx.to_async()).await;
client
.set_connection(client_user_id, conn, &cx.to_async())
.authenticate_and_connect(&cx.to_async())
.await
.unwrap();
result
@ -250,13 +254,25 @@ impl FakeServer {
self.incoming.lock().take();
}
async fn connect(&self, cx: &AsyncAppContext) -> Conn {
let (client_conn, server_conn) = Conn::in_memory();
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
cx.background().spawn(io).detach();
*self.incoming.lock() = Some(incoming);
*self.connection_id.lock() = Some(connection_id);
client_conn
async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
if self.forbid_connections.load(SeqCst) {
Err(anyhow!("server is forbidding connections"))
} else {
let (client_conn, server_conn) = Conn::in_memory();
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
cx.background().spawn(io).detach();
*self.incoming.lock() = Some(incoming);
*self.connection_id.lock() = Some(connection_id);
Ok(client_conn)
}
}
pub fn forbid_connections(&self) {
self.forbid_connections.store(true, SeqCst);
}
pub fn allow_connections(&self) {
self.forbid_connections.store(false, SeqCst);
}
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {