Introduce a concrete Conn type for peer's websocket connection

This is mostly to set us up to test the rpc::Client's reconnect
logic.

There are multiple ways that the `rpc::Client` may establish
its websocket connection: (SSL in production, plain TCP during
local development, and using an in-memory connection for tests).
Now we can represent all of those connections using a common type.

Also, several long methods no longer need to be generic, which
is good for compile time.
This commit is contained in:
Max Brunsfeld 2021-09-08 17:49:07 -07:00
parent c3e29e0a2d
commit b6eac57f63
9 changed files with 196 additions and 228 deletions

View file

@ -5,10 +5,7 @@ use super::{
};
use anyhow::anyhow;
use async_std::{sync::RwLock, task};
use async_tungstenite::{
tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
WebSocketStream,
};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use futures::{future::BoxFuture, FutureExt};
use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
use sha1::{Digest as _, Sha1};
@ -30,7 +27,7 @@ use time::OffsetDateTime;
use zrpc::{
auth::random_token,
proto::{self, AnyTypedEnvelope, EnvelopedMessage},
ConnectionId, Peer, TypedEnvelope,
Conn, ConnectionId, Peer, TypedEnvelope,
};
type ReplicaId = u16;
@ -133,19 +130,12 @@ impl Server {
self
}
pub fn handle_connection<Conn>(
pub fn handle_connection(
self: &Arc<Self>,
connection: Conn,
addr: String,
user_id: UserId,
) -> impl Future<Output = ()>
where
Conn: 'static
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+ Send
+ Unpin,
{
) -> impl Future<Output = ()> {
let this = self.clone();
async move {
let (connection_id, handle_io, mut incoming_rx) =
@ -974,8 +964,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
task::spawn(async move {
if let Some(stream) = upgrade_receiver.await {
let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
server.handle_connection(stream, addr, user_id).await;
server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
}
});
@ -1019,7 +1008,7 @@ mod tests {
fs::{FakeFs, Fs as _},
language::LanguageRegistry,
rpc::Client,
settings, test,
settings,
user::UserStore,
worktree::Worktree,
};
@ -1706,7 +1695,7 @@ mod tests {
) -> (UserId, Arc<Client>) {
let user_id = self.app_state.db.create_user(name, false).await.unwrap();
let client = Client::new();
let (client_conn, server_conn) = test::Channel::bidirectional();
let (client_conn, server_conn) = Conn::in_memory();
cx.background()
.spawn(
self.server

View file

@ -445,12 +445,13 @@ mod tests {
use super::*;
use crate::test::FakeServer;
use gpui::TestAppContext;
use std::time::Duration;
#[gpui::test]
async fn test_channel_messages(mut cx: TestAppContext) {
let user_id = 5;
let client = Client::new();
let mut server = FakeServer::for_client(user_id, &client, &cx).await;
let server = FakeServer::for_client(user_id, &client, &cx).await;
let user_store = Arc::new(UserStore::new(client.clone()));
let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));

View file

@ -1,8 +1,6 @@
use crate::util::ResultExt;
use anyhow::{anyhow, Context, Result};
use async_tungstenite::tungstenite::{
http::Request, Error as WebSocketError, Message as WebSocketMessage,
};
use async_tungstenite::tungstenite::http::Request;
use gpui::{AsyncAppContext, Entity, ModelContext, Task};
use lazy_static::lazy_static;
use parking_lot::RwLock;
@ -19,7 +17,7 @@ use surf::Url;
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
use zrpc::{
proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
Peer, Receipt,
Conn, Peer, Receipt,
};
lazy_static! {
@ -106,6 +104,7 @@ impl Client {
fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
let mut state = self.state.write();
*state.status.0.borrow_mut() = status;
match status {
Status::Connected { .. } => {
let heartbeat_interval = state.heartbeat_interval;
@ -193,75 +192,46 @@ impl Client {
) -> anyhow::Result<()> {
if matches!(
*self.status().borrow(),
Status::Connecting | Status::Connected { .. }
Status::Connecting { .. } | Status::Connected { .. }
) {
return Ok(());
}
let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
let user_id = user_id.parse::<u64>()?;
let (user_id, access_token) = match self.authenticate(&cx).await {
Ok(result) => result,
Err(err) => {
self.set_status(Status::ConnectionError, cx);
return Err(err);
}
};
self.set_status(Status::Connecting, cx);
match self.connect(user_id, &access_token, cx).await {
Ok(()) => {
let conn = match self.connect(user_id, &access_token, cx).await {
Ok(conn) => conn,
Err(err) => {
self.set_status(Status::ConnectionError, cx);
return Err(err);
}
};
self.set_connection(user_id, conn, cx).await?;
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
Ok(())
}
Err(err) => {
self.set_status(Status::ConnectionError, cx);
Err(err)
}
}
}
async fn connect(
self: &Arc<Self>,
user_id: u64,
access_token: &str,
cx: &AsyncAppContext,
) -> Result<()> {
let request =
Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
.await
.context("websocket handshake")?;
self.set_connection(user_id, stream, cx).await?;
Ok(())
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
let (stream, _) = async_tungstenite::client_async(request, stream)
.await
.context("websocket handshake")?;
self.set_connection(user_id, stream, cx).await?;
Ok(())
} else {
return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL));
}
}
pub async fn set_connection<Conn>(
pub async fn set_connection(
self: &Arc<Self>,
user_id: u64,
conn: Conn,
cx: &AsyncAppContext,
) -> Result<()>
where
Conn: 'static
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+ Unpin
+ Send,
{
) -> Result<()> {
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
{
cx.foreground()
.spawn({
let mut cx = cx.clone();
let this = self.clone();
cx.foreground()
.spawn(async move {
async move {
while let Some(message) = incoming.recv().await {
let mut state = this.state.write();
if let Some(extract_entity_id) =
@ -286,9 +256,9 @@ impl Client {
log::info!("unhandled message {}", message.payload_type_name());
}
}
}
})
.detach();
}
self.set_status(
Status::Connected {
@ -315,11 +285,38 @@ impl Client {
Ok(())
}
pub fn login(
platform: Arc<dyn gpui::Platform>,
executor: &Arc<gpui::executor::Background>,
) -> Task<Result<(String, String)>> {
let executor = executor.clone();
fn connect(
self: &Arc<Self>,
user_id: u64,
access_token: &str,
cx: &AsyncAppContext,
) -> Task<Result<Conn>> {
let request =
Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
cx.background().spawn(async move {
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
.await
.context("websocket handshake")?;
Ok(Conn::new(stream))
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
let (stream, _) = async_tungstenite::client_async(request, stream)
.await
.context("websocket handshake")?;
Ok(Conn::new(stream))
} else {
Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))
}
})
}
pub fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
let platform = cx.platform();
let executor = cx.background();
executor.clone().spawn(async move {
if let Some((user_id, access_token)) = platform
.read_credentials(&ZED_SERVER_URL)
@ -327,7 +324,7 @@ impl Client {
.flatten()
{
log::info!("already signed in. user_id: {}", user_id);
return Ok((user_id, String::from_utf8(access_token).unwrap()));
return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap()));
}
// Generate a pair of asymmetric encryption keys. The public key will be used by the
@ -393,7 +390,7 @@ impl Client {
platform
.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
.log_err();
Ok((user_id.to_string(), access_token))
Ok((user_id.parse()?, access_token))
})
}
@ -492,7 +489,7 @@ mod tests {
async fn test_heartbeat(cx: TestAppContext) {
let user_id = 5;
let client = Client::new();
let mut server = FakeServer::for_client(user_id, &client, &cx).await;
let server = FakeServer::for_client(user_id, &client, &cx).await;
cx.foreground().advance_clock(Duration::from_secs(10));
let ping = server.receive::<proto::Ping>().await.unwrap();

View file

@ -10,7 +10,7 @@ use crate::{
AppState,
};
use anyhow::{anyhow, Result};
use gpui::{Entity, ModelHandle, MutableAppContext, TestAppContext};
use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext};
use parking_lot::Mutex;
use postage::{mpsc, prelude::Stream as _};
use smol::channel;
@ -20,10 +20,7 @@ use std::{
sync::Arc,
};
use tempdir::TempDir;
use zrpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
#[cfg(feature = "test-support")]
pub use zrpc::test::Channel;
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
#[cfg(test)]
#[ctor::ctor]
@ -201,40 +198,64 @@ impl<T: Entity> Observer<T> {
pub struct FakeServer {
peer: Arc<Peer>,
incoming: mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>,
connection_id: ConnectionId,
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
connection_id: Mutex<Option<ConnectionId>>,
}
impl FakeServer {
pub async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
let (client_conn, server_conn) = zrpc::test::Channel::bidirectional();
let peer = Peer::new();
let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
cx.background().spawn(io).detach();
pub async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Arc<Self> {
let result = Arc::new(Self {
peer: Peer::new(),
incoming: Default::default(),
connection_id: Default::default(),
});
let conn = result.connect(&cx.to_async()).await;
client
.set_connection(user_id, client_conn, &cx.to_async())
.set_connection(user_id, conn, &cx.to_async())
.await
.unwrap();
Self {
peer,
incoming,
connection_id,
result
}
pub async fn disconnect(&self) {
self.peer.disconnect(self.connection_id()).await;
self.connection_id.lock().take();
self.incoming.lock().take();
}
async fn connect(&self, cx: &AsyncAppContext) -> Conn {
let (client_conn, server_conn) = Conn::in_memory();
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
cx.background().spawn(io).detach();
*self.incoming.lock() = Some(incoming);
*self.connection_id.lock() = Some(connection_id);
client_conn
}
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
self.peer.send(self.connection_id, message).await.unwrap();
self.peer.send(self.connection_id(), message).await.unwrap();
}
pub async fn receive<M: proto::EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
let message = self
.incoming
.lock()
.as_mut()
.expect("not connected")
.recv()
.await
.ok_or_else(|| anyhow!("other half hung up"))?;
Ok(*message.into_any().downcast::<TypedEnvelope<M>>().unwrap())
let type_name = message.payload_type_name();
Ok(*message
.into_any()
.downcast::<TypedEnvelope<M>>()
.unwrap_or_else(|_| {
panic!(
"fake server received unexpected message type: {:?}",
type_name
);
}))
}
pub async fn respond<T: proto::RequestMessage>(
@ -244,4 +265,8 @@ impl FakeServer {
) {
self.peer.respond(receipt, response).await.unwrap()
}
fn connection_id(&self) -> ConnectionId {
self.connection_id.lock().expect("not connected")
}
}

54
zrpc/src/conn.rs Normal file
View file

@ -0,0 +1,54 @@
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{SinkExt as _, StreamExt as _};
pub struct Conn {
pub(crate) tx:
Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
pub(crate) rx: Box<
dyn 'static
+ Send
+ Unpin
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
>,
}
impl Conn {
pub fn new<S>(stream: S) -> Self
where
S: 'static
+ Send
+ Unpin
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
{
let (tx, rx) = stream.split();
Self {
tx: Box::new(tx),
rx: Box::new(rx),
}
}
pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> {
self.tx.send(message).await
}
#[cfg(any(test, feature = "test-support"))]
pub fn in_memory() -> (Self, Self) {
use futures::SinkExt as _;
use futures::StreamExt as _;
use std::io::{Error, ErrorKind};
let (a_tx, a_rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
let (b_tx, b_rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
(
Self {
tx: Box::new(a_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())),
rx: Box::new(b_rx.map(Ok)),
},
Self {
tx: Box::new(b_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())),
rx: Box::new(a_rx.map(Ok)),
},
)
}
}

View file

@ -1,7 +1,6 @@
pub mod auth;
mod conn;
mod peer;
pub mod proto;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
pub use conn::Conn;
pub use peer::*;

View file

@ -1,8 +1,8 @@
use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
use super::Conn;
use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{FutureExt, StreamExt};
use futures::FutureExt as _;
use postage::{
mpsc,
prelude::{Sink as _, Stream as _},
@ -98,21 +98,14 @@ impl Peer {
})
}
pub async fn add_connection<Conn>(
pub async fn add_connection(
self: &Arc<Self>,
conn: Conn,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
)
where
Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+ Send
+ Unpin,
{
let (tx, rx) = conn.split();
) {
let connection_id = ConnectionId(
self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst),
@ -124,8 +117,8 @@ impl Peer {
next_message_id: Default::default(),
response_channels: Default::default(),
};
let mut writer = MessageStream::new(tx);
let mut reader = MessageStream::new(rx);
let mut writer = MessageStream::new(conn.tx);
let mut reader = MessageStream::new(conn.rx);
let this = self.clone();
let response_channels = connection.response_channels.clone();
@ -347,7 +340,9 @@ impl Peer {
#[cfg(test)]
mod tests {
use super::*;
use crate::{test, TypedEnvelope};
use crate::TypedEnvelope;
use async_tungstenite::tungstenite::Message as WebSocketMessage;
use futures::StreamExt as _;
#[test]
fn test_request_response() {
@ -357,12 +352,12 @@ mod tests {
let client1 = Peer::new();
let client2 = Peer::new();
let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
let (client1_to_server_conn, server_to_client_1_conn) = Conn::in_memory();
let (client1_conn_id, io_task1, _) =
client1.add_connection(client1_to_server_conn).await;
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
let (client2_to_server_conn, server_to_client_2_conn) = Conn::in_memory();
let (client2_conn_id, io_task3, _) =
client2.add_connection(client2_to_server_conn).await;
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
@ -497,7 +492,7 @@ mod tests {
#[test]
fn test_disconnect() {
smol::block_on(async move {
let (client_conn, mut server_conn) = test::Channel::bidirectional();
let (client_conn, mut server_conn) = Conn::in_memory();
let client = Peer::new();
let (connection_id, io_handler, mut incoming) =
@ -521,18 +516,17 @@ mod tests {
io_ended_rx.recv().await;
messages_ended_rx.recv().await;
assert!(
futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
assert!(server_conn
.send(WebSocketMessage::Binary(vec![]))
.await
.is_err()
);
.is_err());
});
}
#[test]
fn test_io_error() {
smol::block_on(async move {
let (client_conn, server_conn) = test::Channel::bidirectional();
let (client_conn, server_conn) = Conn::in_memory();
drop(server_conn);
let client = Peer::new();

View file

@ -247,30 +247,3 @@ impl From<SystemTime> for Timestamp {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test;
#[test]
fn test_round_trip_message() {
smol::block_on(async {
let stream = test::Channel::new();
let message1 = Ping { id: 5 }.into_envelope(3, None, None);
let message2 = OpenBuffer {
worktree_id: 0,
path: "some/path".to_string(),
}
.into_envelope(5, None, None);
let mut message_stream = MessageStream::new(stream);
message_stream.write_message(&message1).await.unwrap();
message_stream.write_message(&message2).await.unwrap();
let decoded_message1 = message_stream.read_message().await.unwrap();
let decoded_message2 = message_stream.read_message().await.unwrap();
assert_eq!(decoded_message1, message1);
assert_eq!(decoded_message2, message2);
});
}
}

View file

@ -1,64 +0,0 @@
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use std::{
io,
pin::Pin,
task::{Context, Poll},
};
pub struct Channel {
tx: futures::channel::mpsc::UnboundedSender<WebSocketMessage>,
rx: futures::channel::mpsc::UnboundedReceiver<WebSocketMessage>,
}
impl Channel {
pub fn new() -> Self {
let (tx, rx) = futures::channel::mpsc::unbounded();
Self { tx, rx }
}
pub fn bidirectional() -> (Self, Self) {
let (a_tx, a_rx) = futures::channel::mpsc::unbounded();
let (b_tx, b_rx) = futures::channel::mpsc::unbounded();
let a = Self { tx: a_tx, rx: b_rx };
let b = Self { tx: b_tx, rx: a_rx };
(a, b)
}
}
impl futures::Sink<WebSocketMessage> for Channel {
type Error = WebSocketError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.tx)
.poll_ready(cx)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
}
fn start_send(mut self: Pin<&mut Self>, item: WebSocketMessage) -> Result<(), Self::Error> {
Pin::new(&mut self.tx)
.start_send(item)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.tx)
.poll_flush(cx)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.tx)
.poll_close(cx)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
}
}
impl futures::Stream for Channel {
type Item = Result<WebSocketMessage, WebSocketError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.rx)
.poll_next(cx)
.map(|i| i.map(|i| Ok(i)))
}
}