Rename zrpc to rpc
Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
parent
fdfed3d7db
commit
d5b60ad124
26 changed files with 62 additions and 66 deletions
136
crates/rpc/src/auth.rs
Normal file
136
crates/rpc/src/auth.rs
Normal file
|
@ -0,0 +1,136 @@
|
|||
use anyhow::{Context, Result};
|
||||
use rand::{thread_rng, Rng as _};
|
||||
use rsa::{PublicKey as _, PublicKeyEncoding, RSAPrivateKey, RSAPublicKey};
|
||||
use std::convert::TryFrom;
|
||||
|
||||
pub struct PublicKey(RSAPublicKey);
|
||||
|
||||
pub struct PrivateKey(RSAPrivateKey);
|
||||
|
||||
/// Generate a public and private key for asymmetric encryption.
|
||||
pub fn keypair() -> Result<(PublicKey, PrivateKey)> {
|
||||
let mut rng = thread_rng();
|
||||
let bits = 1024;
|
||||
let private_key = RSAPrivateKey::new(&mut rng, bits)?;
|
||||
let public_key = RSAPublicKey::from(&private_key);
|
||||
Ok((PublicKey(public_key), PrivateKey(private_key)))
|
||||
}
|
||||
|
||||
/// Generate a random 64-character base64 string.
|
||||
pub fn random_token() -> String {
|
||||
let mut rng = thread_rng();
|
||||
let mut token_bytes = [0; 48];
|
||||
for byte in token_bytes.iter_mut() {
|
||||
*byte = rng.gen();
|
||||
}
|
||||
base64::encode_config(&token_bytes, base64::URL_SAFE)
|
||||
}
|
||||
|
||||
impl PublicKey {
|
||||
/// Convert a string to a base64-encoded string that can only be decoded with the corresponding
|
||||
/// private key.
|
||||
pub fn encrypt_string(&self, string: &str) -> Result<String> {
|
||||
let mut rng = thread_rng();
|
||||
let bytes = string.as_bytes();
|
||||
let encrypted_bytes = self
|
||||
.0
|
||||
.encrypt(&mut rng, PADDING_SCHEME, bytes)
|
||||
.context("failed to encrypt string with public key")?;
|
||||
let encrypted_string = base64::encode_config(&encrypted_bytes, base64::URL_SAFE);
|
||||
Ok(encrypted_string)
|
||||
}
|
||||
}
|
||||
|
||||
impl PrivateKey {
|
||||
/// Decrypt a base64-encoded string that was encrypted by the correspoding public key.
|
||||
pub fn decrypt_string(&self, encrypted_string: &str) -> Result<String> {
|
||||
let encrypted_bytes = base64::decode_config(encrypted_string, base64::URL_SAFE)
|
||||
.context("failed to base64-decode encrypted string")?;
|
||||
let bytes = self
|
||||
.0
|
||||
.decrypt(PADDING_SCHEME, &encrypted_bytes)
|
||||
.context("failed to decrypt string with private key")?;
|
||||
let string = String::from_utf8(bytes).context("decrypted content was not valid utf8")?;
|
||||
Ok(string)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<PublicKey> for String {
|
||||
type Error = anyhow::Error;
|
||||
fn try_from(key: PublicKey) -> Result<Self> {
|
||||
let bytes = key.0.to_pkcs1().context("failed to serialize public key")?;
|
||||
let string = base64::encode_config(&bytes, base64::URL_SAFE);
|
||||
Ok(string)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<String> for PublicKey {
|
||||
type Error = anyhow::Error;
|
||||
fn try_from(value: String) -> Result<Self> {
|
||||
let bytes = base64::decode_config(&value, base64::URL_SAFE)
|
||||
.context("failed to base64-decode public key string")?;
|
||||
let key = Self(RSAPublicKey::from_pkcs1(&bytes).context("failed to parse public key")?);
|
||||
Ok(key)
|
||||
}
|
||||
}
|
||||
|
||||
const PADDING_SCHEME: rsa::PaddingScheme = rsa::PaddingScheme::PKCS1v15Encrypt;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_generate_encrypt_and_decrypt_token() {
|
||||
// CLIENT:
|
||||
// * generate a keypair for asymmetric encryption
|
||||
// * serialize the public key to send it to the server.
|
||||
let (public, private) = keypair().unwrap();
|
||||
let public_string = String::try_from(public).unwrap();
|
||||
assert_printable(&public_string);
|
||||
|
||||
// SERVER:
|
||||
// * parse the public key
|
||||
// * generate a random token.
|
||||
// * encrypt the token using the public key.
|
||||
let public = PublicKey::try_from(public_string).unwrap();
|
||||
let token = random_token();
|
||||
let encrypted_token = public.encrypt_string(&token).unwrap();
|
||||
assert_eq!(token.len(), 64);
|
||||
assert_ne!(encrypted_token, token);
|
||||
assert_printable(&token);
|
||||
assert_printable(&encrypted_token);
|
||||
|
||||
// CLIENT:
|
||||
// * decrypt the token using the private key.
|
||||
let decrypted_token = private.decrypt_string(&encrypted_token).unwrap();
|
||||
assert_eq!(decrypted_token, token);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tokens_are_always_url_safe() {
|
||||
for _ in 0..5 {
|
||||
let token = random_token();
|
||||
let (public_key, _) = keypair().unwrap();
|
||||
let encrypted_token = public_key.encrypt_string(&token).unwrap();
|
||||
let public_key_str = String::try_from(public_key).unwrap();
|
||||
|
||||
assert_printable(&token);
|
||||
assert_printable(&public_key_str);
|
||||
assert_printable(&encrypted_token);
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_printable(token: &str) {
|
||||
for c in token.chars() {
|
||||
assert!(
|
||||
c.is_ascii_graphic(),
|
||||
"token {:?} has non-printable char {}",
|
||||
token,
|
||||
c
|
||||
);
|
||||
assert_ne!(c, '/', "token {:?} is not URL-safe", token);
|
||||
assert_ne!(c, '&', "token {:?} is not URL-safe", token);
|
||||
}
|
||||
}
|
||||
}
|
101
crates/rpc/src/conn.rs
Normal file
101
crates/rpc/src/conn.rs
Normal file
|
@ -0,0 +1,101 @@
|
|||
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
|
||||
use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
|
||||
use std::{io, task::Poll};
|
||||
|
||||
pub struct Connection {
|
||||
pub(crate) tx:
|
||||
Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
|
||||
pub(crate) rx: Box<
|
||||
dyn 'static
|
||||
+ Send
|
||||
+ Unpin
|
||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
|
||||
>,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub fn new<S>(stream: S) -> Self
|
||||
where
|
||||
S: 'static
|
||||
+ Send
|
||||
+ Unpin
|
||||
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
|
||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
|
||||
{
|
||||
let (tx, rx) = stream.split();
|
||||
Self {
|
||||
tx: Box::new(tx),
|
||||
rx: Box::new(rx),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> {
|
||||
self.tx.send(message).await
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn in_memory() -> (Self, Self, postage::watch::Sender<Option<()>>) {
|
||||
let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
|
||||
postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
|
||||
|
||||
let (a_tx, a_rx) = Self::channel(kill_rx.clone());
|
||||
let (b_tx, b_rx) = Self::channel(kill_rx);
|
||||
(
|
||||
Self { tx: a_tx, rx: b_rx },
|
||||
Self { tx: b_tx, rx: a_rx },
|
||||
kill_tx,
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
fn channel(
|
||||
kill_rx: postage::watch::Receiver<Option<()>>,
|
||||
) -> (
|
||||
Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
|
||||
Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
|
||||
) {
|
||||
use futures::{future, SinkExt as _};
|
||||
use io::{Error, ErrorKind};
|
||||
|
||||
let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
|
||||
let tx = tx
|
||||
.sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
|
||||
.with({
|
||||
let kill_rx = kill_rx.clone();
|
||||
move |msg| {
|
||||
if kill_rx.borrow().is_none() {
|
||||
future::ready(Ok(msg))
|
||||
} else {
|
||||
future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into()))
|
||||
}
|
||||
}
|
||||
});
|
||||
let rx = KillableReceiver { kill_rx, rx };
|
||||
|
||||
(Box::new(tx), Box::new(rx))
|
||||
}
|
||||
}
|
||||
|
||||
struct KillableReceiver {
|
||||
rx: mpsc::UnboundedReceiver<WebSocketMessage>,
|
||||
kill_rx: postage::watch::Receiver<Option<()>>,
|
||||
}
|
||||
|
||||
impl Stream for KillableReceiver {
|
||||
type Item = Result<WebSocketMessage, WebSocketError>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"connection killed",
|
||||
)
|
||||
.into())))
|
||||
} else {
|
||||
self.rx.poll_next_unpin(cx).map(|value| value.map(Ok))
|
||||
}
|
||||
}
|
||||
}
|
8
crates/rpc/src/lib.rs
Normal file
8
crates/rpc/src/lib.rs
Normal file
|
@ -0,0 +1,8 @@
|
|||
pub mod auth;
|
||||
mod conn;
|
||||
mod peer;
|
||||
pub mod proto;
|
||||
pub use conn::Connection;
|
||||
pub use peer::*;
|
||||
|
||||
pub const PROTOCOL_VERSION: u32 = 1;
|
537
crates/rpc/src/peer.rs
Normal file
537
crates/rpc/src/peer.rs
Normal file
|
@ -0,0 +1,537 @@
|
|||
use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
|
||||
use super::Connection;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_lock::{Mutex, RwLock};
|
||||
use futures::FutureExt as _;
|
||||
use postage::{
|
||||
mpsc,
|
||||
prelude::{Sink as _, Stream as _},
|
||||
};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fmt,
|
||||
future::Future,
|
||||
marker::PhantomData,
|
||||
sync::{
|
||||
atomic::{self, AtomicU32},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct ConnectionId(pub u32);
|
||||
|
||||
impl fmt::Display for ConnectionId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct PeerId(pub u32);
|
||||
|
||||
impl fmt::Display for PeerId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Receipt<T> {
|
||||
pub sender_id: ConnectionId,
|
||||
pub message_id: u32,
|
||||
payload_type: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> Clone for Receipt<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
sender_id: self.sender_id,
|
||||
message_id: self.message_id,
|
||||
payload_type: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Copy for Receipt<T> {}
|
||||
|
||||
pub struct TypedEnvelope<T> {
|
||||
pub sender_id: ConnectionId,
|
||||
pub original_sender_id: Option<PeerId>,
|
||||
pub message_id: u32,
|
||||
pub payload: T,
|
||||
}
|
||||
|
||||
impl<T> TypedEnvelope<T> {
|
||||
pub fn original_sender_id(&self) -> Result<PeerId> {
|
||||
self.original_sender_id
|
||||
.ok_or_else(|| anyhow!("missing original_sender_id"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RequestMessage> TypedEnvelope<T> {
|
||||
pub fn receipt(&self) -> Receipt<T> {
|
||||
Receipt {
|
||||
sender_id: self.sender_id,
|
||||
message_id: self.message_id,
|
||||
payload_type: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Peer {
|
||||
connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
|
||||
next_connection_id: AtomicU32,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConnectionState {
|
||||
outgoing_tx: mpsc::Sender<proto::Envelope>,
|
||||
next_message_id: Arc<AtomicU32>,
|
||||
response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
|
||||
}
|
||||
|
||||
impl Peer {
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
connections: Default::default(),
|
||||
next_connection_id: Default::default(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn add_connection(
|
||||
self: &Arc<Self>,
|
||||
connection: Connection,
|
||||
) -> (
|
||||
ConnectionId,
|
||||
impl Future<Output = anyhow::Result<()>> + Send,
|
||||
mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
|
||||
) {
|
||||
let connection_id = ConnectionId(
|
||||
self.next_connection_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst),
|
||||
);
|
||||
let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
|
||||
let connection_state = ConnectionState {
|
||||
outgoing_tx,
|
||||
next_message_id: Default::default(),
|
||||
response_channels: Arc::new(Mutex::new(Some(Default::default()))),
|
||||
};
|
||||
let mut writer = MessageStream::new(connection.tx);
|
||||
let mut reader = MessageStream::new(connection.rx);
|
||||
|
||||
let this = self.clone();
|
||||
let response_channels = connection_state.response_channels.clone();
|
||||
let handle_io = async move {
|
||||
let result = 'outer: loop {
|
||||
let read_message = reader.read_message().fuse();
|
||||
futures::pin_mut!(read_message);
|
||||
loop {
|
||||
futures::select_biased! {
|
||||
incoming = read_message => match incoming {
|
||||
Ok(incoming) => {
|
||||
if let Some(responding_to) = incoming.responding_to {
|
||||
let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
|
||||
if let Some(mut tx) = channel {
|
||||
tx.send(incoming).await.ok();
|
||||
} else {
|
||||
log::warn!("received RPC response to unknown request {}", responding_to);
|
||||
}
|
||||
} else {
|
||||
if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
|
||||
if incoming_tx.send(envelope).await.is_err() {
|
||||
break 'outer Ok(())
|
||||
}
|
||||
} else {
|
||||
log::error!("unable to construct a typed envelope");
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
Err(error) => {
|
||||
break 'outer Err(error).context("received invalid RPC message")
|
||||
}
|
||||
},
|
||||
outgoing = outgoing_rx.recv().fuse() => match outgoing {
|
||||
Some(outgoing) => {
|
||||
if let Err(result) = writer.write_message(&outgoing).await {
|
||||
break 'outer Err(result).context("failed to write RPC message")
|
||||
}
|
||||
}
|
||||
None => break 'outer Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
response_channels.lock().await.take();
|
||||
this.connections.write().await.remove(&connection_id);
|
||||
result
|
||||
};
|
||||
|
||||
self.connections
|
||||
.write()
|
||||
.await
|
||||
.insert(connection_id, connection_state);
|
||||
|
||||
(connection_id, handle_io, incoming_rx)
|
||||
}
|
||||
|
||||
pub async fn disconnect(&self, connection_id: ConnectionId) {
|
||||
self.connections.write().await.remove(&connection_id);
|
||||
}
|
||||
|
||||
pub async fn reset(&self) {
|
||||
self.connections.write().await.clear();
|
||||
}
|
||||
|
||||
pub fn request<T: RequestMessage>(
|
||||
self: &Arc<Self>,
|
||||
receiver_id: ConnectionId,
|
||||
request: T,
|
||||
) -> impl Future<Output = Result<T::Response>> {
|
||||
self.request_internal(None, receiver_id, request)
|
||||
}
|
||||
|
||||
pub fn forward_request<T: RequestMessage>(
|
||||
self: &Arc<Self>,
|
||||
sender_id: ConnectionId,
|
||||
receiver_id: ConnectionId,
|
||||
request: T,
|
||||
) -> impl Future<Output = Result<T::Response>> {
|
||||
self.request_internal(Some(sender_id), receiver_id, request)
|
||||
}
|
||||
|
||||
pub fn request_internal<T: RequestMessage>(
|
||||
self: &Arc<Self>,
|
||||
original_sender_id: Option<ConnectionId>,
|
||||
receiver_id: ConnectionId,
|
||||
request: T,
|
||||
) -> impl Future<Output = Result<T::Response>> {
|
||||
let this = self.clone();
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
async move {
|
||||
let mut connection = this.connection_state(receiver_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.response_channels
|
||||
.lock()
|
||||
.await
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow!("connection was closed"))?
|
||||
.insert(message_id, tx);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
|
||||
.await
|
||||
.map_err(|_| anyhow!("connection was closed"))?;
|
||||
let response = rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("connection was closed"))?;
|
||||
if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
|
||||
Err(anyhow!("request failed").context(error.message.clone()))
|
||||
} else {
|
||||
T::Response::from_envelope(response)
|
||||
.ok_or_else(|| anyhow!("received response of the wrong type"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send<T: EnvelopedMessage>(
|
||||
self: &Arc<Self>,
|
||||
receiver_id: ConnectionId,
|
||||
message: T,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let mut connection = this.connection_state(receiver_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.send(message.into_envelope(message_id, None, None))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_send<T: EnvelopedMessage>(
|
||||
self: &Arc<Self>,
|
||||
sender_id: ConnectionId,
|
||||
receiver_id: ConnectionId,
|
||||
message: T,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let mut connection = this.connection_state(receiver_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.send(message.into_envelope(message_id, None, Some(sender_id.0)))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn respond<T: RequestMessage>(
|
||||
self: &Arc<Self>,
|
||||
receipt: Receipt<T>,
|
||||
response: T::Response,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let mut connection = this.connection_state(receipt.sender_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.send(response.into_envelope(message_id, Some(receipt.message_id), None))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn respond_with_error<T: RequestMessage>(
|
||||
self: &Arc<Self>,
|
||||
receipt: Receipt<T>,
|
||||
response: proto::Error,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let mut connection = this.connection_state(receipt.sender_id).await?;
|
||||
let message_id = connection
|
||||
.next_message_id
|
||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
||||
connection
|
||||
.outgoing_tx
|
||||
.send(response.into_envelope(message_id, Some(receipt.message_id), None))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn connection_state(
|
||||
self: &Arc<Self>,
|
||||
connection_id: ConnectionId,
|
||||
) -> impl Future<Output = Result<ConnectionState>> {
|
||||
let this = self.clone();
|
||||
async move {
|
||||
let connections = this.connections.read().await;
|
||||
let connection = connections
|
||||
.get(&connection_id)
|
||||
.ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
|
||||
Ok(connection.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::TypedEnvelope;
|
||||
use async_tungstenite::tungstenite::Message as WebSocketMessage;
|
||||
use futures::StreamExt as _;
|
||||
|
||||
#[test]
|
||||
fn test_request_response() {
|
||||
smol::block_on(async move {
|
||||
// create 2 clients connected to 1 server
|
||||
let server = Peer::new();
|
||||
let client1 = Peer::new();
|
||||
let client2 = Peer::new();
|
||||
|
||||
let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
|
||||
let (client1_conn_id, io_task1, _) =
|
||||
client1.add_connection(client1_to_server_conn).await;
|
||||
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
|
||||
|
||||
let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
|
||||
let (client2_conn_id, io_task3, _) =
|
||||
client2.add_connection(client2_to_server_conn).await;
|
||||
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
|
||||
|
||||
smol::spawn(io_task1).detach();
|
||||
smol::spawn(io_task2).detach();
|
||||
smol::spawn(io_task3).detach();
|
||||
smol::spawn(io_task4).detach();
|
||||
smol::spawn(handle_messages(incoming1, server.clone())).detach();
|
||||
smol::spawn(handle_messages(incoming2, server.clone())).detach();
|
||||
|
||||
assert_eq!(
|
||||
client1
|
||||
.request(client1_conn_id, proto::Ping {},)
|
||||
.await
|
||||
.unwrap(),
|
||||
proto::Ack {}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client2
|
||||
.request(client2_conn_id, proto::Ping {},)
|
||||
.await
|
||||
.unwrap(),
|
||||
proto::Ack {}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client1
|
||||
.request(
|
||||
client1_conn_id,
|
||||
proto::OpenBuffer {
|
||||
worktree_id: 1,
|
||||
path: "path/one".to_string(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
proto::OpenBufferResponse {
|
||||
buffer: Some(proto::Buffer {
|
||||
id: 101,
|
||||
content: "path/one content".to_string(),
|
||||
history: vec![],
|
||||
selections: vec![],
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client2
|
||||
.request(
|
||||
client2_conn_id,
|
||||
proto::OpenBuffer {
|
||||
worktree_id: 2,
|
||||
path: "path/two".to_string(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
proto::OpenBufferResponse {
|
||||
buffer: Some(proto::Buffer {
|
||||
id: 102,
|
||||
content: "path/two content".to_string(),
|
||||
history: vec![],
|
||||
selections: vec![],
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
client1.disconnect(client1_conn_id).await;
|
||||
client2.disconnect(client1_conn_id).await;
|
||||
|
||||
async fn handle_messages(
|
||||
mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
|
||||
peer: Arc<Peer>,
|
||||
) -> Result<()> {
|
||||
while let Some(envelope) = messages.next().await {
|
||||
let envelope = envelope.into_any();
|
||||
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
|
||||
let receipt = envelope.receipt();
|
||||
peer.respond(receipt, proto::Ack {}).await?
|
||||
} else if let Some(envelope) =
|
||||
envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
|
||||
{
|
||||
let message = &envelope.payload;
|
||||
let receipt = envelope.receipt();
|
||||
let response = match message.path.as_str() {
|
||||
"path/one" => {
|
||||
assert_eq!(message.worktree_id, 1);
|
||||
proto::OpenBufferResponse {
|
||||
buffer: Some(proto::Buffer {
|
||||
id: 101,
|
||||
content: "path/one content".to_string(),
|
||||
history: vec![],
|
||||
selections: vec![],
|
||||
}),
|
||||
}
|
||||
}
|
||||
"path/two" => {
|
||||
assert_eq!(message.worktree_id, 2);
|
||||
proto::OpenBufferResponse {
|
||||
buffer: Some(proto::Buffer {
|
||||
id: 102,
|
||||
content: "path/two content".to_string(),
|
||||
history: vec![],
|
||||
selections: vec![],
|
||||
}),
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
panic!("unexpected path {}", message.path);
|
||||
}
|
||||
};
|
||||
|
||||
peer.respond(receipt, response).await?
|
||||
} else {
|
||||
panic!("unknown message type");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_disconnect() {
|
||||
smol::block_on(async move {
|
||||
let (client_conn, mut server_conn, _) = Connection::in_memory();
|
||||
|
||||
let client = Peer::new();
|
||||
let (connection_id, io_handler, mut incoming) =
|
||||
client.add_connection(client_conn).await;
|
||||
|
||||
let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
|
||||
smol::spawn(async move {
|
||||
io_handler.await.ok();
|
||||
io_ended_tx.send(()).await.unwrap();
|
||||
})
|
||||
.detach();
|
||||
|
||||
let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
|
||||
smol::spawn(async move {
|
||||
incoming.next().await;
|
||||
messages_ended_tx.send(()).await.unwrap();
|
||||
})
|
||||
.detach();
|
||||
|
||||
client.disconnect(connection_id).await;
|
||||
|
||||
io_ended_rx.recv().await;
|
||||
messages_ended_rx.recv().await;
|
||||
assert!(server_conn
|
||||
.send(WebSocketMessage::Binary(vec![]))
|
||||
.await
|
||||
.is_err());
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_io_error() {
|
||||
smol::block_on(async move {
|
||||
let (client_conn, mut server_conn, _) = Connection::in_memory();
|
||||
|
||||
let client = Peer::new();
|
||||
let (connection_id, io_handler, mut incoming) =
|
||||
client.add_connection(client_conn).await;
|
||||
smol::spawn(io_handler).detach();
|
||||
smol::spawn(async move { incoming.next().await }).detach();
|
||||
|
||||
let response = smol::spawn(client.request(connection_id, proto::Ping {}));
|
||||
let _request = server_conn.rx.next().await.unwrap().unwrap();
|
||||
|
||||
drop(server_conn);
|
||||
assert_eq!(
|
||||
response.await.unwrap_err().to_string(),
|
||||
"connection was closed"
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
286
crates/rpc/src/proto.rs
Normal file
286
crates/rpc/src/proto.rs
Normal file
|
@ -0,0 +1,286 @@
|
|||
use super::{ConnectionId, PeerId, TypedEnvelope};
|
||||
use anyhow::Result;
|
||||
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
|
||||
use futures::{SinkExt as _, StreamExt as _};
|
||||
use prost::Message;
|
||||
use std::any::{Any, TypeId};
|
||||
use std::{
|
||||
io,
|
||||
time::{Duration, SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
|
||||
|
||||
pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
|
||||
const NAME: &'static str;
|
||||
fn into_envelope(
|
||||
self,
|
||||
id: u32,
|
||||
responding_to: Option<u32>,
|
||||
original_sender_id: Option<u32>,
|
||||
) -> Envelope;
|
||||
fn from_envelope(envelope: Envelope) -> Option<Self>;
|
||||
}
|
||||
|
||||
pub trait EntityMessage: EnvelopedMessage {
|
||||
fn remote_entity_id(&self) -> u64;
|
||||
}
|
||||
|
||||
pub trait RequestMessage: EnvelopedMessage {
|
||||
type Response: EnvelopedMessage;
|
||||
}
|
||||
|
||||
pub trait AnyTypedEnvelope: 'static + Send + Sync {
|
||||
fn payload_type_id(&self) -> TypeId;
|
||||
fn payload_type_name(&self) -> &'static str;
|
||||
fn as_any(&self) -> &dyn Any;
|
||||
fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
|
||||
}
|
||||
|
||||
impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
|
||||
fn payload_type_id(&self) -> TypeId {
|
||||
TypeId::of::<T>()
|
||||
}
|
||||
|
||||
fn payload_type_name(&self) -> &'static str {
|
||||
T::NAME
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! messages {
|
||||
($($name:ident),* $(,)?) => {
|
||||
pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
|
||||
match envelope.payload {
|
||||
$(Some(envelope::Payload::$name(payload)) => {
|
||||
Some(Box::new(TypedEnvelope {
|
||||
sender_id,
|
||||
original_sender_id: envelope.original_sender_id.map(PeerId),
|
||||
message_id: envelope.id,
|
||||
payload,
|
||||
}))
|
||||
}, )*
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
|
||||
$(
|
||||
impl EnvelopedMessage for $name {
|
||||
const NAME: &'static str = std::stringify!($name);
|
||||
|
||||
fn into_envelope(
|
||||
self,
|
||||
id: u32,
|
||||
responding_to: Option<u32>,
|
||||
original_sender_id: Option<u32>,
|
||||
) -> Envelope {
|
||||
Envelope {
|
||||
id,
|
||||
responding_to,
|
||||
original_sender_id,
|
||||
payload: Some(envelope::Payload::$name(self)),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_envelope(envelope: Envelope) -> Option<Self> {
|
||||
if let Some(envelope::Payload::$name(msg)) = envelope.payload {
|
||||
Some(msg)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! request_messages {
|
||||
($(($request_name:ident, $response_name:ident)),* $(,)?) => {
|
||||
$(impl RequestMessage for $request_name {
|
||||
type Response = $response_name;
|
||||
})*
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! entity_messages {
|
||||
($id_field:ident, $($name:ident),* $(,)?) => {
|
||||
$(impl EntityMessage for $name {
|
||||
fn remote_entity_id(&self) -> u64 {
|
||||
self.$id_field
|
||||
}
|
||||
})*
|
||||
};
|
||||
}
|
||||
|
||||
messages!(
|
||||
Ack,
|
||||
AddPeer,
|
||||
BufferSaved,
|
||||
ChannelMessageSent,
|
||||
CloseBuffer,
|
||||
CloseWorktree,
|
||||
Error,
|
||||
GetChannelMessages,
|
||||
GetChannelMessagesResponse,
|
||||
GetChannels,
|
||||
GetChannelsResponse,
|
||||
UpdateCollaborators,
|
||||
GetUsers,
|
||||
GetUsersResponse,
|
||||
JoinChannel,
|
||||
JoinChannelResponse,
|
||||
JoinWorktree,
|
||||
JoinWorktreeResponse,
|
||||
LeaveChannel,
|
||||
LeaveWorktree,
|
||||
OpenBuffer,
|
||||
OpenBufferResponse,
|
||||
OpenWorktree,
|
||||
OpenWorktreeResponse,
|
||||
Ping,
|
||||
RemovePeer,
|
||||
SaveBuffer,
|
||||
SendChannelMessage,
|
||||
SendChannelMessageResponse,
|
||||
ShareWorktree,
|
||||
ShareWorktreeResponse,
|
||||
UnshareWorktree,
|
||||
UpdateBuffer,
|
||||
UpdateWorktree,
|
||||
);
|
||||
|
||||
request_messages!(
|
||||
(GetChannels, GetChannelsResponse),
|
||||
(GetUsers, GetUsersResponse),
|
||||
(JoinChannel, JoinChannelResponse),
|
||||
(OpenBuffer, OpenBufferResponse),
|
||||
(JoinWorktree, JoinWorktreeResponse),
|
||||
(OpenWorktree, OpenWorktreeResponse),
|
||||
(Ping, Ack),
|
||||
(SaveBuffer, BufferSaved),
|
||||
(UpdateBuffer, Ack),
|
||||
(ShareWorktree, ShareWorktreeResponse),
|
||||
(UnshareWorktree, Ack),
|
||||
(SendChannelMessage, SendChannelMessageResponse),
|
||||
(GetChannelMessages, GetChannelMessagesResponse),
|
||||
);
|
||||
|
||||
entity_messages!(
|
||||
worktree_id,
|
||||
AddPeer,
|
||||
BufferSaved,
|
||||
CloseBuffer,
|
||||
CloseWorktree,
|
||||
OpenBuffer,
|
||||
JoinWorktree,
|
||||
RemovePeer,
|
||||
SaveBuffer,
|
||||
UnshareWorktree,
|
||||
UpdateBuffer,
|
||||
UpdateWorktree,
|
||||
);
|
||||
|
||||
entity_messages!(channel_id, ChannelMessageSent);
|
||||
|
||||
/// A stream of protobuf messages.
|
||||
pub struct MessageStream<S> {
|
||||
stream: S,
|
||||
encoding_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<S> MessageStream<S> {
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
encoding_buffer: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inner_mut(&mut self) -> &mut S {
|
||||
&mut self.stream
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> MessageStream<S>
|
||||
where
|
||||
S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
|
||||
{
|
||||
/// Write a given protobuf message to the stream.
|
||||
pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
|
||||
self.encoding_buffer.resize(message.encoded_len(), 0);
|
||||
self.encoding_buffer.clear();
|
||||
message
|
||||
.encode(&mut self.encoding_buffer)
|
||||
.map_err(|err| io::Error::from(err))?;
|
||||
let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
|
||||
self.stream.send(WebSocketMessage::Binary(buffer)).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> MessageStream<S>
|
||||
where
|
||||
S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
|
||||
{
|
||||
/// Read a protobuf message of the given type from the stream.
|
||||
pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
|
||||
while let Some(bytes) = self.stream.next().await {
|
||||
match bytes? {
|
||||
WebSocketMessage::Binary(bytes) => {
|
||||
self.encoding_buffer.clear();
|
||||
zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
|
||||
let envelope = Envelope::decode(self.encoding_buffer.as_slice())
|
||||
.map_err(io::Error::from)?;
|
||||
return Ok(envelope);
|
||||
}
|
||||
WebSocketMessage::Close(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Err(WebSocketError::ConnectionClosed)
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<SystemTime> for Timestamp {
|
||||
fn into(self) -> SystemTime {
|
||||
UNIX_EPOCH
|
||||
.checked_add(Duration::new(self.seconds, self.nanos))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SystemTime> for Timestamp {
|
||||
fn from(time: SystemTime) -> Self {
|
||||
let duration = time.duration_since(UNIX_EPOCH).unwrap();
|
||||
Self {
|
||||
seconds: duration.as_secs(),
|
||||
nanos: duration.subsec_nanos(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u128> for Nonce {
|
||||
fn from(nonce: u128) -> Self {
|
||||
let upper_half = (nonce >> 64) as u64;
|
||||
let lower_half = nonce as u64;
|
||||
Self {
|
||||
upper_half,
|
||||
lower_half,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Nonce> for u128 {
|
||||
fn from(nonce: Nonce) -> Self {
|
||||
let upper_half = (nonce.upper_half as u128) << 64;
|
||||
let lower_half = nonce.lower_half as u128;
|
||||
upper_half | lower_half
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue