Refactor and write a simple unit test to verify reconnection logic
This commit is contained in:
parent
6baa9fe37b
commit
ad7631de9f
3 changed files with 121 additions and 77 deletions
|
@ -17,7 +17,10 @@ use smol::channel;
|
|||
use std::{
|
||||
marker::PhantomData,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use tempdir::TempDir;
|
||||
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
|
||||
|
@ -200,6 +203,7 @@ pub struct FakeServer {
|
|||
peer: Arc<Peer>,
|
||||
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
|
||||
connection_id: Mutex<Option<ConnectionId>>,
|
||||
forbid_connections: AtomicBool,
|
||||
}
|
||||
|
||||
impl FakeServer {
|
||||
|
@ -212,6 +216,7 @@ impl FakeServer {
|
|||
peer: Peer::new(),
|
||||
incoming: Default::default(),
|
||||
connection_id: Default::default(),
|
||||
forbid_connections: Default::default(),
|
||||
});
|
||||
|
||||
Arc::get_mut(client)
|
||||
|
@ -230,15 +235,14 @@ impl FakeServer {
|
|||
assert_eq!(access_token, "the-token");
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
move |cx| async move { Ok(server.connect(&cx).await) }
|
||||
move |cx| async move { server.connect(&cx).await }
|
||||
})
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let conn = result.connect(&cx.to_async()).await;
|
||||
client
|
||||
.set_connection(client_user_id, conn, &cx.to_async())
|
||||
.authenticate_and_connect(&cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
result
|
||||
|
@ -250,13 +254,25 @@ impl FakeServer {
|
|||
self.incoming.lock().take();
|
||||
}
|
||||
|
||||
async fn connect(&self, cx: &AsyncAppContext) -> Conn {
|
||||
let (client_conn, server_conn) = Conn::in_memory();
|
||||
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
|
||||
cx.background().spawn(io).detach();
|
||||
*self.incoming.lock() = Some(incoming);
|
||||
*self.connection_id.lock() = Some(connection_id);
|
||||
client_conn
|
||||
async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
|
||||
if self.forbid_connections.load(SeqCst) {
|
||||
Err(anyhow!("server is forbidding connections"))
|
||||
} else {
|
||||
let (client_conn, server_conn) = Conn::in_memory();
|
||||
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
|
||||
cx.background().spawn(io).detach();
|
||||
*self.incoming.lock() = Some(incoming);
|
||||
*self.connection_id.lock() = Some(connection_id);
|
||||
Ok(client_conn)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forbid_connections(&self) {
|
||||
self.forbid_connections.store(true, SeqCst);
|
||||
}
|
||||
|
||||
pub fn allow_connections(&self) {
|
||||
self.forbid_connections.store(false, SeqCst);
|
||||
}
|
||||
|
||||
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue