Refactor and write a simple unit test to verify reconnection logic
This commit is contained in:
parent
6baa9fe37b
commit
ad7631de9f
3 changed files with 121 additions and 77 deletions
|
@ -1693,20 +1693,46 @@ mod tests {
|
|||
cx: &mut TestAppContext,
|
||||
name: &str,
|
||||
) -> (UserId, Arc<Client>) {
|
||||
let user_id = self.app_state.db.create_user(name, false).await.unwrap();
|
||||
let client = Client::new();
|
||||
let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
|
||||
let client_name = name.to_string();
|
||||
let mut client = Client::new();
|
||||
let server = self.server.clone();
|
||||
Arc::get_mut(&mut client)
|
||||
.unwrap()
|
||||
.set_login_and_connect_callbacks(
|
||||
move |cx| {
|
||||
cx.spawn(|_| async move {
|
||||
let access_token = "the-token".to_string();
|
||||
Ok((client_user_id.0 as u64, access_token))
|
||||
})
|
||||
},
|
||||
{
|
||||
move |user_id, access_token, cx| {
|
||||
assert_eq!(user_id, client_user_id.0 as u64);
|
||||
assert_eq!(access_token, "the-token");
|
||||
|
||||
let server = server.clone();
|
||||
let client_name = client_name.clone();
|
||||
cx.spawn(move |cx| async move {
|
||||
let (client_conn, server_conn) = Conn::in_memory();
|
||||
cx.background()
|
||||
.spawn(
|
||||
self.server
|
||||
.handle_connection(server_conn, name.to_string(), user_id),
|
||||
)
|
||||
.spawn(server.handle_connection(
|
||||
server_conn,
|
||||
client_name,
|
||||
client_user_id,
|
||||
))
|
||||
.detach();
|
||||
Ok(client_conn)
|
||||
})
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
client
|
||||
.set_connection(user_id.to_proto(), client_conn, &cx.to_async())
|
||||
.authenticate_and_connect(&cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
(user_id, client)
|
||||
(client_user_id, client)
|
||||
}
|
||||
|
||||
async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
|
||||
|
|
108
zed/src/rpc.rs
108
zed/src/rpc.rs
|
@ -49,7 +49,10 @@ pub enum Status {
|
|||
user_id: u64,
|
||||
},
|
||||
ConnectionLost,
|
||||
Reconnecting,
|
||||
Reauthenticating,
|
||||
Reconnecting {
|
||||
user_id: u64,
|
||||
},
|
||||
ReconnectionError {
|
||||
next_reconnection: Instant,
|
||||
},
|
||||
|
@ -164,9 +167,10 @@ impl Client {
|
|||
}
|
||||
}));
|
||||
}
|
||||
_ => {
|
||||
Status::Disconnected => {
|
||||
state._maintain_connection.take();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -227,14 +231,20 @@ impl Client {
|
|||
self: &Arc<Self>,
|
||||
cx: &AsyncAppContext,
|
||||
) -> anyhow::Result<()> {
|
||||
if matches!(
|
||||
*self.status().borrow(),
|
||||
Status::Authenticating | Status::Connecting { .. } | Status::Connected { .. }
|
||||
) {
|
||||
return Ok(());
|
||||
}
|
||||
let was_disconnected = match *self.status().borrow() {
|
||||
Status::Disconnected => true,
|
||||
Status::Connected { .. }
|
||||
| Status::Connecting { .. }
|
||||
| Status::Reconnecting { .. }
|
||||
| Status::Reauthenticating => return Ok(()),
|
||||
_ => false,
|
||||
};
|
||||
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Authenticating, cx);
|
||||
} else {
|
||||
self.set_status(Status::Reauthenticating, cx)
|
||||
}
|
||||
|
||||
let (user_id, access_token) = match self.authenticate(&cx).await {
|
||||
Ok(result) => result,
|
||||
|
@ -244,27 +254,25 @@ impl Client {
|
|||
}
|
||||
};
|
||||
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Connecting { user_id }, cx);
|
||||
|
||||
let conn = match self.connect(user_id, &access_token, cx).await {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
return Err(err);
|
||||
} else {
|
||||
self.set_status(Status::Reconnecting { user_id }, cx);
|
||||
}
|
||||
};
|
||||
|
||||
self.set_connection(user_id, conn, cx).await?;
|
||||
match self.connect(user_id, &access_token, cx).await {
|
||||
Ok(conn) => {
|
||||
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
|
||||
self.set_connection(user_id, conn, cx).await;
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn set_connection(
|
||||
self: &Arc<Self>,
|
||||
user_id: u64,
|
||||
conn: Conn,
|
||||
cx: &AsyncAppContext,
|
||||
) -> Result<()> {
|
||||
async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
|
||||
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
|
||||
cx.foreground()
|
||||
.spawn({
|
||||
|
@ -321,7 +329,6 @@ impl Client {
|
|||
}
|
||||
})
|
||||
.detach();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
|
||||
|
@ -489,35 +496,6 @@ impl Client {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
|
||||
type Output: 'a + Future<Output = anyhow::Result<()>>;
|
||||
|
||||
fn handle(
|
||||
&self,
|
||||
message: TypedEnvelope<M>,
|
||||
rpc: &'a Client,
|
||||
cx: &'a mut gpui::AsyncAppContext,
|
||||
) -> Self::Output;
|
||||
}
|
||||
|
||||
impl<'a, M, F, Fut> MessageHandler<'a, M> for F
|
||||
where
|
||||
M: proto::EnvelopedMessage,
|
||||
F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
|
||||
Fut: 'a + Future<Output = anyhow::Result<()>>,
|
||||
{
|
||||
type Output = Fut;
|
||||
|
||||
fn handle(
|
||||
&self,
|
||||
message: TypedEnvelope<M>,
|
||||
rpc: &'a Client,
|
||||
cx: &'a mut gpui::AsyncAppContext,
|
||||
) -> Self::Output {
|
||||
(self)(message, rpc, cx)
|
||||
}
|
||||
}
|
||||
|
||||
const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
|
||||
|
||||
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
|
||||
|
@ -550,6 +528,8 @@ mod tests {
|
|||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_heartbeat(cx: TestAppContext) {
|
||||
cx.foreground().forbid_parking();
|
||||
|
||||
let user_id = 5;
|
||||
let mut client = Client::new();
|
||||
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
|
||||
|
@ -568,6 +548,28 @@ mod tests {
|
|||
assert!(server.receive::<proto::Ping>().await.is_err());
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_reconnection(cx: TestAppContext) {
|
||||
cx.foreground().forbid_parking();
|
||||
|
||||
let user_id = 5;
|
||||
let mut client = Client::new();
|
||||
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
|
||||
let mut status = client.status();
|
||||
assert!(matches!(
|
||||
status.recv().await,
|
||||
Some(Status::Connected { .. })
|
||||
));
|
||||
|
||||
server.forbid_connections();
|
||||
server.disconnect().await;
|
||||
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
|
||||
|
||||
server.allow_connections();
|
||||
cx.foreground().advance_clock(Duration::from_secs(10));
|
||||
while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_and_decode_worktree_url() {
|
||||
let url = encode_worktree_url(5, "deadbeef");
|
||||
|
|
|
@ -17,7 +17,10 @@ use smol::channel;
|
|||
use std::{
|
||||
marker::PhantomData,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use tempdir::TempDir;
|
||||
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
|
||||
|
@ -200,6 +203,7 @@ pub struct FakeServer {
|
|||
peer: Arc<Peer>,
|
||||
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
|
||||
connection_id: Mutex<Option<ConnectionId>>,
|
||||
forbid_connections: AtomicBool,
|
||||
}
|
||||
|
||||
impl FakeServer {
|
||||
|
@ -212,6 +216,7 @@ impl FakeServer {
|
|||
peer: Peer::new(),
|
||||
incoming: Default::default(),
|
||||
connection_id: Default::default(),
|
||||
forbid_connections: Default::default(),
|
||||
});
|
||||
|
||||
Arc::get_mut(client)
|
||||
|
@ -230,15 +235,14 @@ impl FakeServer {
|
|||
assert_eq!(access_token, "the-token");
|
||||
cx.spawn({
|
||||
let server = server.clone();
|
||||
move |cx| async move { Ok(server.connect(&cx).await) }
|
||||
move |cx| async move { server.connect(&cx).await }
|
||||
})
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let conn = result.connect(&cx.to_async()).await;
|
||||
client
|
||||
.set_connection(client_user_id, conn, &cx.to_async())
|
||||
.authenticate_and_connect(&cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
result
|
||||
|
@ -250,13 +254,25 @@ impl FakeServer {
|
|||
self.incoming.lock().take();
|
||||
}
|
||||
|
||||
async fn connect(&self, cx: &AsyncAppContext) -> Conn {
|
||||
async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
|
||||
if self.forbid_connections.load(SeqCst) {
|
||||
Err(anyhow!("server is forbidding connections"))
|
||||
} else {
|
||||
let (client_conn, server_conn) = Conn::in_memory();
|
||||
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
|
||||
cx.background().spawn(io).detach();
|
||||
*self.incoming.lock() = Some(incoming);
|
||||
*self.connection_id.lock() = Some(connection_id);
|
||||
client_conn
|
||||
Ok(client_conn)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forbid_connections(&self) {
|
||||
self.forbid_connections.store(true, SeqCst);
|
||||
}
|
||||
|
||||
pub fn allow_connections(&self) {
|
||||
self.forbid_connections.store(false, SeqCst);
|
||||
}
|
||||
|
||||
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue