Refactor and write a simple unit test to verify reconnection logic
This commit is contained in:
parent
6baa9fe37b
commit
ad7631de9f
3 changed files with 121 additions and 77 deletions
|
@ -1693,20 +1693,46 @@ mod tests {
|
||||||
cx: &mut TestAppContext,
|
cx: &mut TestAppContext,
|
||||||
name: &str,
|
name: &str,
|
||||||
) -> (UserId, Arc<Client>) {
|
) -> (UserId, Arc<Client>) {
|
||||||
let user_id = self.app_state.db.create_user(name, false).await.unwrap();
|
let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
|
||||||
let client = Client::new();
|
let client_name = name.to_string();
|
||||||
let (client_conn, server_conn) = Conn::in_memory();
|
let mut client = Client::new();
|
||||||
cx.background()
|
let server = self.server.clone();
|
||||||
.spawn(
|
Arc::get_mut(&mut client)
|
||||||
self.server
|
.unwrap()
|
||||||
.handle_connection(server_conn, name.to_string(), user_id),
|
.set_login_and_connect_callbacks(
|
||||||
)
|
move |cx| {
|
||||||
.detach();
|
cx.spawn(|_| async move {
|
||||||
|
let access_token = "the-token".to_string();
|
||||||
|
Ok((client_user_id.0 as u64, access_token))
|
||||||
|
})
|
||||||
|
},
|
||||||
|
{
|
||||||
|
move |user_id, access_token, cx| {
|
||||||
|
assert_eq!(user_id, client_user_id.0 as u64);
|
||||||
|
assert_eq!(access_token, "the-token");
|
||||||
|
|
||||||
|
let server = server.clone();
|
||||||
|
let client_name = client_name.clone();
|
||||||
|
cx.spawn(move |cx| async move {
|
||||||
|
let (client_conn, server_conn) = Conn::in_memory();
|
||||||
|
cx.background()
|
||||||
|
.spawn(server.handle_connection(
|
||||||
|
server_conn,
|
||||||
|
client_name,
|
||||||
|
client_user_id,
|
||||||
|
))
|
||||||
|
.detach();
|
||||||
|
Ok(client_conn)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
client
|
client
|
||||||
.set_connection(user_id.to_proto(), client_conn, &cx.to_async())
|
.authenticate_and_connect(&cx.to_async())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
(user_id, client)
|
(client_user_id, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
|
async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
|
||||||
|
|
112
zed/src/rpc.rs
112
zed/src/rpc.rs
|
@ -49,7 +49,10 @@ pub enum Status {
|
||||||
user_id: u64,
|
user_id: u64,
|
||||||
},
|
},
|
||||||
ConnectionLost,
|
ConnectionLost,
|
||||||
Reconnecting,
|
Reauthenticating,
|
||||||
|
Reconnecting {
|
||||||
|
user_id: u64,
|
||||||
|
},
|
||||||
ReconnectionError {
|
ReconnectionError {
|
||||||
next_reconnection: Instant,
|
next_reconnection: Instant,
|
||||||
},
|
},
|
||||||
|
@ -164,9 +167,10 @@ impl Client {
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
_ => {
|
Status::Disconnected => {
|
||||||
state._maintain_connection.take();
|
state._maintain_connection.take();
|
||||||
}
|
}
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -227,14 +231,20 @@ impl Client {
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
if matches!(
|
let was_disconnected = match *self.status().borrow() {
|
||||||
*self.status().borrow(),
|
Status::Disconnected => true,
|
||||||
Status::Authenticating | Status::Connecting { .. } | Status::Connected { .. }
|
Status::Connected { .. }
|
||||||
) {
|
| Status::Connecting { .. }
|
||||||
return Ok(());
|
| Status::Reconnecting { .. }
|
||||||
}
|
| Status::Reauthenticating => return Ok(()),
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
|
||||||
self.set_status(Status::Authenticating, cx);
|
if was_disconnected {
|
||||||
|
self.set_status(Status::Authenticating, cx);
|
||||||
|
} else {
|
||||||
|
self.set_status(Status::Reauthenticating, cx)
|
||||||
|
}
|
||||||
|
|
||||||
let (user_id, access_token) = match self.authenticate(&cx).await {
|
let (user_id, access_token) = match self.authenticate(&cx).await {
|
||||||
Ok(result) => result,
|
Ok(result) => result,
|
||||||
|
@ -244,27 +254,25 @@ impl Client {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
self.set_status(Status::Connecting { user_id }, cx);
|
if was_disconnected {
|
||||||
|
self.set_status(Status::Connecting { user_id }, cx);
|
||||||
let conn = match self.connect(user_id, &access_token, cx).await {
|
} else {
|
||||||
Ok(conn) => conn,
|
self.set_status(Status::Reconnecting { user_id }, cx);
|
||||||
|
}
|
||||||
|
match self.connect(user_id, &access_token, cx).await {
|
||||||
|
Ok(conn) => {
|
||||||
|
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
|
||||||
|
self.set_connection(user_id, conn, cx).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
self.set_status(Status::ConnectionError, cx);
|
self.set_status(Status::ConnectionError, cx);
|
||||||
return Err(err);
|
Err(err)
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
self.set_connection(user_id, conn, cx).await?;
|
|
||||||
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn set_connection(
|
async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
|
||||||
self: &Arc<Self>,
|
|
||||||
user_id: u64,
|
|
||||||
conn: Conn,
|
|
||||||
cx: &AsyncAppContext,
|
|
||||||
) -> Result<()> {
|
|
||||||
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
|
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
|
||||||
cx.foreground()
|
cx.foreground()
|
||||||
.spawn({
|
.spawn({
|
||||||
|
@ -321,7 +329,6 @@ impl Client {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
|
fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
|
||||||
|
@ -489,35 +496,6 @@ impl Client {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
|
|
||||||
type Output: 'a + Future<Output = anyhow::Result<()>>;
|
|
||||||
|
|
||||||
fn handle(
|
|
||||||
&self,
|
|
||||||
message: TypedEnvelope<M>,
|
|
||||||
rpc: &'a Client,
|
|
||||||
cx: &'a mut gpui::AsyncAppContext,
|
|
||||||
) -> Self::Output;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, M, F, Fut> MessageHandler<'a, M> for F
|
|
||||||
where
|
|
||||||
M: proto::EnvelopedMessage,
|
|
||||||
F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
|
|
||||||
Fut: 'a + Future<Output = anyhow::Result<()>>,
|
|
||||||
{
|
|
||||||
type Output = Fut;
|
|
||||||
|
|
||||||
fn handle(
|
|
||||||
&self,
|
|
||||||
message: TypedEnvelope<M>,
|
|
||||||
rpc: &'a Client,
|
|
||||||
cx: &'a mut gpui::AsyncAppContext,
|
|
||||||
) -> Self::Output {
|
|
||||||
(self)(message, rpc, cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
|
const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
|
||||||
|
|
||||||
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
|
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
|
||||||
|
@ -550,6 +528,8 @@ mod tests {
|
||||||
|
|
||||||
#[gpui::test(iterations = 10)]
|
#[gpui::test(iterations = 10)]
|
||||||
async fn test_heartbeat(cx: TestAppContext) {
|
async fn test_heartbeat(cx: TestAppContext) {
|
||||||
|
cx.foreground().forbid_parking();
|
||||||
|
|
||||||
let user_id = 5;
|
let user_id = 5;
|
||||||
let mut client = Client::new();
|
let mut client = Client::new();
|
||||||
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
|
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
|
||||||
|
@ -568,6 +548,28 @@ mod tests {
|
||||||
assert!(server.receive::<proto::Ping>().await.is_err());
|
assert!(server.receive::<proto::Ping>().await.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test(iterations = 10)]
|
||||||
|
async fn test_reconnection(cx: TestAppContext) {
|
||||||
|
cx.foreground().forbid_parking();
|
||||||
|
|
||||||
|
let user_id = 5;
|
||||||
|
let mut client = Client::new();
|
||||||
|
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
|
||||||
|
let mut status = client.status();
|
||||||
|
assert!(matches!(
|
||||||
|
status.recv().await,
|
||||||
|
Some(Status::Connected { .. })
|
||||||
|
));
|
||||||
|
|
||||||
|
server.forbid_connections();
|
||||||
|
server.disconnect().await;
|
||||||
|
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
|
||||||
|
|
||||||
|
server.allow_connections();
|
||||||
|
cx.foreground().advance_clock(Duration::from_secs(10));
|
||||||
|
while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_encode_and_decode_worktree_url() {
|
fn test_encode_and_decode_worktree_url() {
|
||||||
let url = encode_worktree_url(5, "deadbeef");
|
let url = encode_worktree_url(5, "deadbeef");
|
||||||
|
|
|
@ -17,7 +17,10 @@ use smol::channel;
|
||||||
use std::{
|
use std::{
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
sync::Arc,
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering::SeqCst},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use tempdir::TempDir;
|
use tempdir::TempDir;
|
||||||
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
|
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
|
||||||
|
@ -200,6 +203,7 @@ pub struct FakeServer {
|
||||||
peer: Arc<Peer>,
|
peer: Arc<Peer>,
|
||||||
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
|
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
|
||||||
connection_id: Mutex<Option<ConnectionId>>,
|
connection_id: Mutex<Option<ConnectionId>>,
|
||||||
|
forbid_connections: AtomicBool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FakeServer {
|
impl FakeServer {
|
||||||
|
@ -212,6 +216,7 @@ impl FakeServer {
|
||||||
peer: Peer::new(),
|
peer: Peer::new(),
|
||||||
incoming: Default::default(),
|
incoming: Default::default(),
|
||||||
connection_id: Default::default(),
|
connection_id: Default::default(),
|
||||||
|
forbid_connections: Default::default(),
|
||||||
});
|
});
|
||||||
|
|
||||||
Arc::get_mut(client)
|
Arc::get_mut(client)
|
||||||
|
@ -230,15 +235,14 @@ impl FakeServer {
|
||||||
assert_eq!(access_token, "the-token");
|
assert_eq!(access_token, "the-token");
|
||||||
cx.spawn({
|
cx.spawn({
|
||||||
let server = server.clone();
|
let server = server.clone();
|
||||||
move |cx| async move { Ok(server.connect(&cx).await) }
|
move |cx| async move { server.connect(&cx).await }
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
let conn = result.connect(&cx.to_async()).await;
|
|
||||||
client
|
client
|
||||||
.set_connection(client_user_id, conn, &cx.to_async())
|
.authenticate_and_connect(&cx.to_async())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
result
|
result
|
||||||
|
@ -250,13 +254,25 @@ impl FakeServer {
|
||||||
self.incoming.lock().take();
|
self.incoming.lock().take();
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn connect(&self, cx: &AsyncAppContext) -> Conn {
|
async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
|
||||||
let (client_conn, server_conn) = Conn::in_memory();
|
if self.forbid_connections.load(SeqCst) {
|
||||||
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
|
Err(anyhow!("server is forbidding connections"))
|
||||||
cx.background().spawn(io).detach();
|
} else {
|
||||||
*self.incoming.lock() = Some(incoming);
|
let (client_conn, server_conn) = Conn::in_memory();
|
||||||
*self.connection_id.lock() = Some(connection_id);
|
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
|
||||||
client_conn
|
cx.background().spawn(io).detach();
|
||||||
|
*self.incoming.lock() = Some(incoming);
|
||||||
|
*self.connection_id.lock() = Some(connection_id);
|
||||||
|
Ok(client_conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forbid_connections(&self) {
|
||||||
|
self.forbid_connections.store(true, SeqCst);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn allow_connections(&self) {
|
||||||
|
self.forbid_connections.store(false, SeqCst);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
|
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue