Introduce a concrete Conn
type for peer's websocket connection
This is mostly to set us up to test the rpc::Client's reconnect logic. There are multiple ways that the `rpc::Client` may establish its websocket connection: (SSL in production, plain TCP during local development, and using an in-memory connection for tests). Now we can represent all of those connections using a common type. Also, several long methods no longer need to be generic, which is good for compile time.
This commit is contained in:
parent
c3e29e0a2d
commit
b6eac57f63
9 changed files with 196 additions and 228 deletions
|
@ -5,10 +5,7 @@ use super::{
|
|||
};
|
||||
use anyhow::anyhow;
|
||||
use async_std::{sync::RwLock, task};
|
||||
use async_tungstenite::{
|
||||
tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
|
||||
WebSocketStream,
|
||||
};
|
||||
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
|
||||
use futures::{future::BoxFuture, FutureExt};
|
||||
use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
|
||||
use sha1::{Digest as _, Sha1};
|
||||
|
@ -30,7 +27,7 @@ use time::OffsetDateTime;
|
|||
use zrpc::{
|
||||
auth::random_token,
|
||||
proto::{self, AnyTypedEnvelope, EnvelopedMessage},
|
||||
ConnectionId, Peer, TypedEnvelope,
|
||||
Conn, ConnectionId, Peer, TypedEnvelope,
|
||||
};
|
||||
|
||||
type ReplicaId = u16;
|
||||
|
@ -133,19 +130,12 @@ impl Server {
|
|||
self
|
||||
}
|
||||
|
||||
pub fn handle_connection<Conn>(
|
||||
pub fn handle_connection(
|
||||
self: &Arc<Self>,
|
||||
connection: Conn,
|
||||
addr: String,
|
||||
user_id: UserId,
|
||||
) -> impl Future<Output = ()>
|
||||
where
|
||||
Conn: 'static
|
||||
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
|
||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
|
||||
+ Send
|
||||
+ Unpin,
|
||||
{
|
||||
) -> impl Future<Output = ()> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let (connection_id, handle_io, mut incoming_rx) =
|
||||
|
@ -974,8 +964,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
|
|||
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
|
||||
task::spawn(async move {
|
||||
if let Some(stream) = upgrade_receiver.await {
|
||||
let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
|
||||
server.handle_connection(stream, addr, user_id).await;
|
||||
server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -1019,7 +1008,7 @@ mod tests {
|
|||
fs::{FakeFs, Fs as _},
|
||||
language::LanguageRegistry,
|
||||
rpc::Client,
|
||||
settings, test,
|
||||
settings,
|
||||
user::UserStore,
|
||||
worktree::Worktree,
|
||||
};
|
||||
|
@ -1706,7 +1695,7 @@ mod tests {
|
|||
) -> (UserId, Arc<Client>) {
|
||||
let user_id = self.app_state.db.create_user(name, false).await.unwrap();
|
||||
let client = Client::new();
|
||||
let (client_conn, server_conn) = test::Channel::bidirectional();
|
||||
let (client_conn, server_conn) = Conn::in_memory();
|
||||
cx.background()
|
||||
.spawn(
|
||||
self.server
|
||||
|
|
|
@ -445,12 +445,13 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::test::FakeServer;
|
||||
use gpui::TestAppContext;
|
||||
use std::time::Duration;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_channel_messages(mut cx: TestAppContext) {
|
||||
let user_id = 5;
|
||||
let client = Client::new();
|
||||
let mut server = FakeServer::for_client(user_id, &client, &cx).await;
|
||||
let server = FakeServer::for_client(user_id, &client, &cx).await;
|
||||
let user_store = Arc::new(UserStore::new(client.clone()));
|
||||
|
||||
let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
|
||||
|
|
137
zed/src/rpc.rs
137
zed/src/rpc.rs
|
@ -1,8 +1,6 @@
|
|||
use crate::util::ResultExt;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_tungstenite::tungstenite::{
|
||||
http::Request, Error as WebSocketError, Message as WebSocketMessage,
|
||||
};
|
||||
use async_tungstenite::tungstenite::http::Request;
|
||||
use gpui::{AsyncAppContext, Entity, ModelContext, Task};
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::RwLock;
|
||||
|
@ -19,7 +17,7 @@ use surf::Url;
|
|||
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
|
||||
use zrpc::{
|
||||
proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
|
||||
Peer, Receipt,
|
||||
Conn, Peer, Receipt,
|
||||
};
|
||||
|
||||
lazy_static! {
|
||||
|
@ -106,6 +104,7 @@ impl Client {
|
|||
fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
|
||||
let mut state = self.state.write();
|
||||
*state.status.0.borrow_mut() = status;
|
||||
|
||||
match status {
|
||||
Status::Connected { .. } => {
|
||||
let heartbeat_interval = state.heartbeat_interval;
|
||||
|
@ -193,75 +192,46 @@ impl Client {
|
|||
) -> anyhow::Result<()> {
|
||||
if matches!(
|
||||
*self.status().borrow(),
|
||||
Status::Connecting | Status::Connected { .. }
|
||||
Status::Connecting { .. } | Status::Connected { .. }
|
||||
) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
|
||||
let user_id = user_id.parse::<u64>()?;
|
||||
|
||||
self.set_status(Status::Connecting, cx);
|
||||
match self.connect(user_id, &access_token, cx).await {
|
||||
Ok(()) => {
|
||||
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
|
||||
Ok(())
|
||||
}
|
||||
let (user_id, access_token) = match self.authenticate(&cx).await {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
Err(err)
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
self.set_status(Status::Connecting, cx);
|
||||
|
||||
let conn = match self.connect(user_id, &access_token, cx).await {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
self.set_connection(user_id, conn, cx).await?;
|
||||
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
self: &Arc<Self>,
|
||||
user_id: u64,
|
||||
access_token: &str,
|
||||
cx: &AsyncAppContext,
|
||||
) -> Result<()> {
|
||||
let request =
|
||||
Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
|
||||
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
|
||||
let stream = smol::net::TcpStream::connect(host).await?;
|
||||
let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
|
||||
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
|
||||
.await
|
||||
.context("websocket handshake")?;
|
||||
self.set_connection(user_id, stream, cx).await?;
|
||||
Ok(())
|
||||
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
|
||||
let stream = smol::net::TcpStream::connect(host).await?;
|
||||
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
|
||||
let (stream, _) = async_tungstenite::client_async(request, stream)
|
||||
.await
|
||||
.context("websocket handshake")?;
|
||||
self.set_connection(user_id, stream, cx).await?;
|
||||
Ok(())
|
||||
} else {
|
||||
return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL));
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn set_connection<Conn>(
|
||||
pub async fn set_connection(
|
||||
self: &Arc<Self>,
|
||||
user_id: u64,
|
||||
conn: Conn,
|
||||
cx: &AsyncAppContext,
|
||||
) -> Result<()>
|
||||
where
|
||||
Conn: 'static
|
||||
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
|
||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
|
||||
+ Unpin
|
||||
+ Send,
|
||||
{
|
||||
) -> Result<()> {
|
||||
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
|
||||
{
|
||||
let mut cx = cx.clone();
|
||||
let this = self.clone();
|
||||
cx.foreground()
|
||||
.spawn(async move {
|
||||
cx.foreground()
|
||||
.spawn({
|
||||
let mut cx = cx.clone();
|
||||
let this = self.clone();
|
||||
async move {
|
||||
while let Some(message) = incoming.recv().await {
|
||||
let mut state = this.state.write();
|
||||
if let Some(extract_entity_id) =
|
||||
|
@ -286,9 +256,9 @@ impl Client {
|
|||
log::info!("unhandled message {}", message.payload_type_name());
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
self.set_status(
|
||||
Status::Connected {
|
||||
|
@ -315,11 +285,38 @@ impl Client {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn login(
|
||||
platform: Arc<dyn gpui::Platform>,
|
||||
executor: &Arc<gpui::executor::Background>,
|
||||
) -> Task<Result<(String, String)>> {
|
||||
let executor = executor.clone();
|
||||
fn connect(
|
||||
self: &Arc<Self>,
|
||||
user_id: u64,
|
||||
access_token: &str,
|
||||
cx: &AsyncAppContext,
|
||||
) -> Task<Result<Conn>> {
|
||||
let request =
|
||||
Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
|
||||
cx.background().spawn(async move {
|
||||
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
|
||||
let stream = smol::net::TcpStream::connect(host).await?;
|
||||
let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
|
||||
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
|
||||
.await
|
||||
.context("websocket handshake")?;
|
||||
Ok(Conn::new(stream))
|
||||
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
|
||||
let stream = smol::net::TcpStream::connect(host).await?;
|
||||
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
|
||||
let (stream, _) = async_tungstenite::client_async(request, stream)
|
||||
.await
|
||||
.context("websocket handshake")?;
|
||||
Ok(Conn::new(stream))
|
||||
} else {
|
||||
Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
|
||||
let platform = cx.platform();
|
||||
let executor = cx.background();
|
||||
executor.clone().spawn(async move {
|
||||
if let Some((user_id, access_token)) = platform
|
||||
.read_credentials(&ZED_SERVER_URL)
|
||||
|
@ -327,7 +324,7 @@ impl Client {
|
|||
.flatten()
|
||||
{
|
||||
log::info!("already signed in. user_id: {}", user_id);
|
||||
return Ok((user_id, String::from_utf8(access_token).unwrap()));
|
||||
return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap()));
|
||||
}
|
||||
|
||||
// Generate a pair of asymmetric encryption keys. The public key will be used by the
|
||||
|
@ -393,7 +390,7 @@ impl Client {
|
|||
platform
|
||||
.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
|
||||
.log_err();
|
||||
Ok((user_id.to_string(), access_token))
|
||||
Ok((user_id.parse()?, access_token))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -492,7 +489,7 @@ mod tests {
|
|||
async fn test_heartbeat(cx: TestAppContext) {
|
||||
let user_id = 5;
|
||||
let client = Client::new();
|
||||
let mut server = FakeServer::for_client(user_id, &client, &cx).await;
|
||||
let server = FakeServer::for_client(user_id, &client, &cx).await;
|
||||
|
||||
cx.foreground().advance_clock(Duration::from_secs(10));
|
||||
let ping = server.receive::<proto::Ping>().await.unwrap();
|
||||
|
|
|
@ -10,7 +10,7 @@ use crate::{
|
|||
AppState,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{Entity, ModelHandle, MutableAppContext, TestAppContext};
|
||||
use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext};
|
||||
use parking_lot::Mutex;
|
||||
use postage::{mpsc, prelude::Stream as _};
|
||||
use smol::channel;
|
||||
|
@ -20,10 +20,7 @@ use std::{
|
|||
sync::Arc,
|
||||
};
|
||||
use tempdir::TempDir;
|
||||
use zrpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
pub use zrpc::test::Channel;
|
||||
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
|
||||
|
||||
#[cfg(test)]
|
||||
#[ctor::ctor]
|
||||
|
@ -201,40 +198,64 @@ impl<T: Entity> Observer<T> {
|
|||
|
||||
pub struct FakeServer {
|
||||
peer: Arc<Peer>,
|
||||
incoming: mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>,
|
||||
connection_id: ConnectionId,
|
||||
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
|
||||
connection_id: Mutex<Option<ConnectionId>>,
|
||||
}
|
||||
|
||||
impl FakeServer {
|
||||
pub async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
|
||||
let (client_conn, server_conn) = zrpc::test::Channel::bidirectional();
|
||||
let peer = Peer::new();
|
||||
let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
|
||||
cx.background().spawn(io).detach();
|
||||
pub async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Arc<Self> {
|
||||
let result = Arc::new(Self {
|
||||
peer: Peer::new(),
|
||||
incoming: Default::default(),
|
||||
connection_id: Default::default(),
|
||||
});
|
||||
|
||||
let conn = result.connect(&cx.to_async()).await;
|
||||
client
|
||||
.set_connection(user_id, client_conn, &cx.to_async())
|
||||
.set_connection(user_id, conn, &cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
result
|
||||
}
|
||||
|
||||
Self {
|
||||
peer,
|
||||
incoming,
|
||||
connection_id,
|
||||
}
|
||||
pub async fn disconnect(&self) {
|
||||
self.peer.disconnect(self.connection_id()).await;
|
||||
self.connection_id.lock().take();
|
||||
self.incoming.lock().take();
|
||||
}
|
||||
|
||||
async fn connect(&self, cx: &AsyncAppContext) -> Conn {
|
||||
let (client_conn, server_conn) = Conn::in_memory();
|
||||
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
|
||||
cx.background().spawn(io).detach();
|
||||
*self.incoming.lock() = Some(incoming);
|
||||
*self.connection_id.lock() = Some(connection_id);
|
||||
client_conn
|
||||
}
|
||||
|
||||
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
|
||||
self.peer.send(self.connection_id, message).await.unwrap();
|
||||
self.peer.send(self.connection_id(), message).await.unwrap();
|
||||
}
|
||||
|
||||
pub async fn receive<M: proto::EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
|
||||
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
|
||||
let message = self
|
||||
.incoming
|
||||
.lock()
|
||||
.as_mut()
|
||||
.expect("not connected")
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("other half hung up"))?;
|
||||
Ok(*message.into_any().downcast::<TypedEnvelope<M>>().unwrap())
|
||||
let type_name = message.payload_type_name();
|
||||
Ok(*message
|
||||
.into_any()
|
||||
.downcast::<TypedEnvelope<M>>()
|
||||
.unwrap_or_else(|_| {
|
||||
panic!(
|
||||
"fake server received unexpected message type: {:?}",
|
||||
type_name
|
||||
);
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn respond<T: proto::RequestMessage>(
|
||||
|
@ -244,4 +265,8 @@ impl FakeServer {
|
|||
) {
|
||||
self.peer.respond(receipt, response).await.unwrap()
|
||||
}
|
||||
|
||||
fn connection_id(&self) -> ConnectionId {
|
||||
self.connection_id.lock().expect("not connected")
|
||||
}
|
||||
}
|
||||
|
|
54
zrpc/src/conn.rs
Normal file
54
zrpc/src/conn.rs
Normal file
|
@ -0,0 +1,54 @@
|
|||
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
|
||||
use futures::{SinkExt as _, StreamExt as _};
|
||||
|
||||
pub struct Conn {
|
||||
pub(crate) tx:
|
||||
Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
|
||||
pub(crate) rx: Box<
|
||||
dyn 'static
|
||||
+ Send
|
||||
+ Unpin
|
||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
|
||||
>,
|
||||
}
|
||||
|
||||
impl Conn {
|
||||
pub fn new<S>(stream: S) -> Self
|
||||
where
|
||||
S: 'static
|
||||
+ Send
|
||||
+ Unpin
|
||||
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
|
||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
|
||||
{
|
||||
let (tx, rx) = stream.split();
|
||||
Self {
|
||||
tx: Box::new(tx),
|
||||
rx: Box::new(rx),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> {
|
||||
self.tx.send(message).await
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn in_memory() -> (Self, Self) {
|
||||
use futures::SinkExt as _;
|
||||
use futures::StreamExt as _;
|
||||
use std::io::{Error, ErrorKind};
|
||||
|
||||
let (a_tx, a_rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
|
||||
let (b_tx, b_rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
|
||||
(
|
||||
Self {
|
||||
tx: Box::new(a_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())),
|
||||
rx: Box::new(b_rx.map(Ok)),
|
||||
},
|
||||
Self {
|
||||
tx: Box::new(b_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())),
|
||||
rx: Box::new(a_rx.map(Ok)),
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
pub mod auth;
|
||||
mod conn;
|
||||
mod peer;
|
||||
pub mod proto;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod test;
|
||||
|
||||
pub use conn::Conn;
|
||||
pub use peer::*;
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
|
||||
use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
|
||||
use super::Conn;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_lock::{Mutex, RwLock};
|
||||
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
|
||||
use futures::{FutureExt, StreamExt};
|
||||
use futures::FutureExt as _;
|
||||
use postage::{
|
||||
mpsc,
|
||||
prelude::{Sink as _, Stream as _},
|
||||
|
@ -98,21 +98,14 @@ impl Peer {
|
|||
})
|
||||
}
|
||||
|
||||
pub async fn add_connection<Conn>(
|
||||
pub async fn add_connection(
|
||||
self: &Arc<Self>,
|
||||
conn: Conn,
|
||||
) -> (
|
||||
ConnectionId,
|
||||
impl Future<Output = anyhow::Result<()>> + Send,
|
||||
mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
|
||||
)
|
||||
where
|
||||
Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
|
||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
|
||||
+ Send
|
||||
+ Unpin,
|
||||
{
|
||||
let (tx, rx) = conn.split();
|
||||
) {
|
||||
let connection_id = ConnectionId(
|
||||
self.next_connection_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst),
|
||||
|
@ -124,8 +117,8 @@ impl Peer {
|
|||
next_message_id: Default::default(),
|
||||
response_channels: Default::default(),
|
||||
};
|
||||
let mut writer = MessageStream::new(tx);
|
||||
let mut reader = MessageStream::new(rx);
|
||||
let mut writer = MessageStream::new(conn.tx);
|
||||
let mut reader = MessageStream::new(conn.rx);
|
||||
|
||||
let this = self.clone();
|
||||
let response_channels = connection.response_channels.clone();
|
||||
|
@ -347,7 +340,9 @@ impl Peer {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{test, TypedEnvelope};
|
||||
use crate::TypedEnvelope;
|
||||
use async_tungstenite::tungstenite::Message as WebSocketMessage;
|
||||
use futures::StreamExt as _;
|
||||
|
||||
#[test]
|
||||
fn test_request_response() {
|
||||
|
@ -357,12 +352,12 @@ mod tests {
|
|||
let client1 = Peer::new();
|
||||
let client2 = Peer::new();
|
||||
|
||||
let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
|
||||
let (client1_to_server_conn, server_to_client_1_conn) = Conn::in_memory();
|
||||
let (client1_conn_id, io_task1, _) =
|
||||
client1.add_connection(client1_to_server_conn).await;
|
||||
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
|
||||
|
||||
let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
|
||||
let (client2_to_server_conn, server_to_client_2_conn) = Conn::in_memory();
|
||||
let (client2_conn_id, io_task3, _) =
|
||||
client2.add_connection(client2_to_server_conn).await;
|
||||
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
|
||||
|
@ -497,7 +492,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_disconnect() {
|
||||
smol::block_on(async move {
|
||||
let (client_conn, mut server_conn) = test::Channel::bidirectional();
|
||||
let (client_conn, mut server_conn) = Conn::in_memory();
|
||||
|
||||
let client = Peer::new();
|
||||
let (connection_id, io_handler, mut incoming) =
|
||||
|
@ -521,18 +516,17 @@ mod tests {
|
|||
|
||||
io_ended_rx.recv().await;
|
||||
messages_ended_rx.recv().await;
|
||||
assert!(
|
||||
futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
assert!(server_conn
|
||||
.send(WebSocketMessage::Binary(vec![]))
|
||||
.await
|
||||
.is_err());
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_io_error() {
|
||||
smol::block_on(async move {
|
||||
let (client_conn, server_conn) = test::Channel::bidirectional();
|
||||
let (client_conn, server_conn) = Conn::in_memory();
|
||||
drop(server_conn);
|
||||
|
||||
let client = Peer::new();
|
||||
|
|
|
@ -247,30 +247,3 @@ impl From<SystemTime> for Timestamp {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::test;
|
||||
|
||||
#[test]
|
||||
fn test_round_trip_message() {
|
||||
smol::block_on(async {
|
||||
let stream = test::Channel::new();
|
||||
let message1 = Ping { id: 5 }.into_envelope(3, None, None);
|
||||
let message2 = OpenBuffer {
|
||||
worktree_id: 0,
|
||||
path: "some/path".to_string(),
|
||||
}
|
||||
.into_envelope(5, None, None);
|
||||
|
||||
let mut message_stream = MessageStream::new(stream);
|
||||
message_stream.write_message(&message1).await.unwrap();
|
||||
message_stream.write_message(&message2).await.unwrap();
|
||||
let decoded_message1 = message_stream.read_message().await.unwrap();
|
||||
let decoded_message2 = message_stream.read_message().await.unwrap();
|
||||
assert_eq!(decoded_message1, message1);
|
||||
assert_eq!(decoded_message2, message2);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,64 +0,0 @@
|
|||
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
|
||||
use std::{
|
||||
io,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
pub struct Channel {
|
||||
tx: futures::channel::mpsc::UnboundedSender<WebSocketMessage>,
|
||||
rx: futures::channel::mpsc::UnboundedReceiver<WebSocketMessage>,
|
||||
}
|
||||
|
||||
impl Channel {
|
||||
pub fn new() -> Self {
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded();
|
||||
Self { tx, rx }
|
||||
}
|
||||
|
||||
pub fn bidirectional() -> (Self, Self) {
|
||||
let (a_tx, a_rx) = futures::channel::mpsc::unbounded();
|
||||
let (b_tx, b_rx) = futures::channel::mpsc::unbounded();
|
||||
let a = Self { tx: a_tx, rx: b_rx };
|
||||
let b = Self { tx: b_tx, rx: a_rx };
|
||||
(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
impl futures::Sink<WebSocketMessage> for Channel {
|
||||
type Error = WebSocketError;
|
||||
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.tx)
|
||||
.poll_ready(cx)
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, item: WebSocketMessage) -> Result<(), Self::Error> {
|
||||
Pin::new(&mut self.tx)
|
||||
.start_send(item)
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.tx)
|
||||
.poll_flush(cx)
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.tx)
|
||||
.poll_close(cx)
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl futures::Stream for Channel {
|
||||
type Item = Result<WebSocketMessage, WebSocketError>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
Pin::new(&mut self.rx)
|
||||
.poll_next(cx)
|
||||
.map(|i| i.map(|i| Ok(i)))
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue