Refactor and write a simple unit test to verify reconnection logic

This commit is contained in:
Antonio Scandurra 2021-09-09 11:00:43 +02:00
parent 6baa9fe37b
commit ad7631de9f
3 changed files with 121 additions and 77 deletions

View file

@ -1693,20 +1693,46 @@ mod tests {
cx: &mut TestAppContext, cx: &mut TestAppContext,
name: &str, name: &str,
) -> (UserId, Arc<Client>) { ) -> (UserId, Arc<Client>) {
let user_id = self.app_state.db.create_user(name, false).await.unwrap(); let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
let client = Client::new(); let client_name = name.to_string();
let (client_conn, server_conn) = Conn::in_memory(); let mut client = Client::new();
cx.background() let server = self.server.clone();
.spawn( Arc::get_mut(&mut client)
self.server .unwrap()
.handle_connection(server_conn, name.to_string(), user_id), .set_login_and_connect_callbacks(
) move |cx| {
.detach(); cx.spawn(|_| async move {
let access_token = "the-token".to_string();
Ok((client_user_id.0 as u64, access_token))
})
},
{
move |user_id, access_token, cx| {
assert_eq!(user_id, client_user_id.0 as u64);
assert_eq!(access_token, "the-token");
let server = server.clone();
let client_name = client_name.clone();
cx.spawn(move |cx| async move {
let (client_conn, server_conn) = Conn::in_memory();
cx.background()
.spawn(server.handle_connection(
server_conn,
client_name,
client_user_id,
))
.detach();
Ok(client_conn)
})
}
},
);
client client
.set_connection(user_id.to_proto(), client_conn, &cx.to_async()) .authenticate_and_connect(&cx.to_async())
.await .await
.unwrap(); .unwrap();
(user_id, client) (client_user_id, client)
} }
async fn build_app_state(test_db: &TestDb) -> Arc<AppState> { async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {

View file

@ -49,7 +49,10 @@ pub enum Status {
user_id: u64, user_id: u64,
}, },
ConnectionLost, ConnectionLost,
Reconnecting, Reauthenticating,
Reconnecting {
user_id: u64,
},
ReconnectionError { ReconnectionError {
next_reconnection: Instant, next_reconnection: Instant,
}, },
@ -164,9 +167,10 @@ impl Client {
} }
})); }));
} }
_ => { Status::Disconnected => {
state._maintain_connection.take(); state._maintain_connection.take();
} }
_ => {}
} }
} }
@ -227,14 +231,20 @@ impl Client {
self: &Arc<Self>, self: &Arc<Self>,
cx: &AsyncAppContext, cx: &AsyncAppContext,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
if matches!( let was_disconnected = match *self.status().borrow() {
*self.status().borrow(), Status::Disconnected => true,
Status::Authenticating | Status::Connecting { .. } | Status::Connected { .. } Status::Connected { .. }
) { | Status::Connecting { .. }
return Ok(()); | Status::Reconnecting { .. }
} | Status::Reauthenticating => return Ok(()),
_ => false,
};
self.set_status(Status::Authenticating, cx); if was_disconnected {
self.set_status(Status::Authenticating, cx);
} else {
self.set_status(Status::Reauthenticating, cx)
}
let (user_id, access_token) = match self.authenticate(&cx).await { let (user_id, access_token) = match self.authenticate(&cx).await {
Ok(result) => result, Ok(result) => result,
@ -244,27 +254,25 @@ impl Client {
} }
}; };
self.set_status(Status::Connecting { user_id }, cx); if was_disconnected {
self.set_status(Status::Connecting { user_id }, cx);
let conn = match self.connect(user_id, &access_token, cx).await { } else {
Ok(conn) => conn, self.set_status(Status::Reconnecting { user_id }, cx);
}
match self.connect(user_id, &access_token, cx).await {
Ok(conn) => {
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
self.set_connection(user_id, conn, cx).await;
Ok(())
}
Err(err) => { Err(err) => {
self.set_status(Status::ConnectionError, cx); self.set_status(Status::ConnectionError, cx);
return Err(err); Err(err)
} }
}; }
self.set_connection(user_id, conn, cx).await?;
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
Ok(())
} }
pub async fn set_connection( async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
self: &Arc<Self>,
user_id: u64,
conn: Conn,
cx: &AsyncAppContext,
) -> Result<()> {
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await; let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
cx.foreground() cx.foreground()
.spawn({ .spawn({
@ -321,7 +329,6 @@ impl Client {
} }
}) })
.detach(); .detach();
Ok(())
} }
fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> { fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
@ -489,35 +496,6 @@ impl Client {
} }
} }
pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
type Output: 'a + Future<Output = anyhow::Result<()>>;
fn handle(
&self,
message: TypedEnvelope<M>,
rpc: &'a Client,
cx: &'a mut gpui::AsyncAppContext,
) -> Self::Output;
}
impl<'a, M, F, Fut> MessageHandler<'a, M> for F
where
M: proto::EnvelopedMessage,
F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
Fut: 'a + Future<Output = anyhow::Result<()>>,
{
type Output = Fut;
fn handle(
&self,
message: TypedEnvelope<M>,
rpc: &'a Client,
cx: &'a mut gpui::AsyncAppContext,
) -> Self::Output {
(self)(message, rpc, cx)
}
}
const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/"; const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
pub fn encode_worktree_url(id: u64, access_token: &str) -> String { pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
@ -550,6 +528,8 @@ mod tests {
#[gpui::test(iterations = 10)] #[gpui::test(iterations = 10)]
async fn test_heartbeat(cx: TestAppContext) { async fn test_heartbeat(cx: TestAppContext) {
cx.foreground().forbid_parking();
let user_id = 5; let user_id = 5;
let mut client = Client::new(); let mut client = Client::new();
let server = FakeServer::for_client(user_id, &mut client, &cx).await; let server = FakeServer::for_client(user_id, &mut client, &cx).await;
@ -568,6 +548,28 @@ mod tests {
assert!(server.receive::<proto::Ping>().await.is_err()); assert!(server.receive::<proto::Ping>().await.is_err());
} }
#[gpui::test(iterations = 10)]
async fn test_reconnection(cx: TestAppContext) {
cx.foreground().forbid_parking();
let user_id = 5;
let mut client = Client::new();
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
let mut status = client.status();
assert!(matches!(
status.recv().await,
Some(Status::Connected { .. })
));
server.forbid_connections();
server.disconnect().await;
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
server.allow_connections();
cx.foreground().advance_clock(Duration::from_secs(10));
while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
}
#[test] #[test]
fn test_encode_and_decode_worktree_url() { fn test_encode_and_decode_worktree_url() {
let url = encode_worktree_url(5, "deadbeef"); let url = encode_worktree_url(5, "deadbeef");

View file

@ -17,7 +17,10 @@ use smol::channel;
use std::{ use std::{
marker::PhantomData, marker::PhantomData,
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::Arc, sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
}; };
use tempdir::TempDir; use tempdir::TempDir;
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope}; use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
@ -200,6 +203,7 @@ pub struct FakeServer {
peer: Arc<Peer>, peer: Arc<Peer>,
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>, incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
connection_id: Mutex<Option<ConnectionId>>, connection_id: Mutex<Option<ConnectionId>>,
forbid_connections: AtomicBool,
} }
impl FakeServer { impl FakeServer {
@ -212,6 +216,7 @@ impl FakeServer {
peer: Peer::new(), peer: Peer::new(),
incoming: Default::default(), incoming: Default::default(),
connection_id: Default::default(), connection_id: Default::default(),
forbid_connections: Default::default(),
}); });
Arc::get_mut(client) Arc::get_mut(client)
@ -230,15 +235,14 @@ impl FakeServer {
assert_eq!(access_token, "the-token"); assert_eq!(access_token, "the-token");
cx.spawn({ cx.spawn({
let server = server.clone(); let server = server.clone();
move |cx| async move { Ok(server.connect(&cx).await) } move |cx| async move { server.connect(&cx).await }
}) })
} }
}, },
); );
let conn = result.connect(&cx.to_async()).await;
client client
.set_connection(client_user_id, conn, &cx.to_async()) .authenticate_and_connect(&cx.to_async())
.await .await
.unwrap(); .unwrap();
result result
@ -250,13 +254,25 @@ impl FakeServer {
self.incoming.lock().take(); self.incoming.lock().take();
} }
async fn connect(&self, cx: &AsyncAppContext) -> Conn { async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
let (client_conn, server_conn) = Conn::in_memory(); if self.forbid_connections.load(SeqCst) {
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await; Err(anyhow!("server is forbidding connections"))
cx.background().spawn(io).detach(); } else {
*self.incoming.lock() = Some(incoming); let (client_conn, server_conn) = Conn::in_memory();
*self.connection_id.lock() = Some(connection_id); let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
client_conn cx.background().spawn(io).detach();
*self.incoming.lock() = Some(incoming);
*self.connection_id.lock() = Some(connection_id);
Ok(client_conn)
}
}
pub fn forbid_connections(&self) {
self.forbid_connections.store(true, SeqCst);
}
pub fn allow_connections(&self) {
self.forbid_connections.store(false, SeqCst);
} }
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) { pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {