Delete old Peer
implementation and adapt previous code paths
This still doesn't compile but should be close.
This commit is contained in:
parent
5dee7ecf5b
commit
d6412fdbde
10 changed files with 171 additions and 997 deletions
|
@ -137,34 +137,6 @@ impl PeerExt for Peer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl PeerExt for zrpc::peer2::Peer {
|
|
||||||
async fn sign_out(
|
|
||||||
self: &Arc<Self>,
|
|
||||||
connection_id: zrpc::ConnectionId,
|
|
||||||
state: &AppState,
|
|
||||||
) -> tide::Result<()> {
|
|
||||||
self.disconnect(connection_id).await;
|
|
||||||
let worktree_ids = state.rpc.write().await.remove_connection(connection_id);
|
|
||||||
for worktree_id in worktree_ids {
|
|
||||||
let state = state.rpc.read().await;
|
|
||||||
if let Some(worktree) = state.worktrees.get(&worktree_id) {
|
|
||||||
rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| {
|
|
||||||
self.send(
|
|
||||||
conn_id,
|
|
||||||
proto::RemovePeer {
|
|
||||||
worktree_id,
|
|
||||||
peer_id: connection_id.0,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn build_client(client_id: &str, client_secret: &str) -> Client {
|
pub fn build_client(client_id: &str, client_secret: &str) -> Client {
|
||||||
Client::new(
|
Client::new(
|
||||||
ClientId::new(client_id.to_string()),
|
ClientId::new(client_id.to_string()),
|
||||||
|
|
|
@ -30,13 +30,15 @@ use time::OffsetDateTime;
|
||||||
use zrpc::{
|
use zrpc::{
|
||||||
auth::random_token,
|
auth::random_token,
|
||||||
proto::{self, EnvelopedMessage},
|
proto::{self, EnvelopedMessage},
|
||||||
ConnectionId, Peer, Router, TypedEnvelope,
|
ConnectionId, Peer, TypedEnvelope,
|
||||||
};
|
};
|
||||||
|
|
||||||
type ReplicaId = u16;
|
type ReplicaId = u16;
|
||||||
|
|
||||||
type Handler = Box<
|
type Handler = Box<
|
||||||
dyn Fn(&mut Option<Box<dyn Any + Send + Sync>>, Arc<Server>) -> Option<BoxFuture<'static, ()>>,
|
dyn Send
|
||||||
|
+ Sync
|
||||||
|
+ Fn(&mut Option<Box<dyn Any + Send + Sync>>, Arc<Server>) -> Option<BoxFuture<'static, ()>>,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
|
@ -48,7 +50,7 @@ struct ServerBuilder {
|
||||||
impl ServerBuilder {
|
impl ServerBuilder {
|
||||||
pub fn on_message<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
pub fn on_message<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
||||||
where
|
where
|
||||||
F: 'static + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
|
F: 'static + Send + Sync + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
|
||||||
Fut: 'static + Send + Future<Output = ()>,
|
Fut: 'static + Send + Future<Output = ()>,
|
||||||
M: EnvelopedMessage,
|
M: EnvelopedMessage,
|
||||||
{
|
{
|
||||||
|
@ -73,23 +75,23 @@ impl ServerBuilder {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build(self, rpc: Arc<zrpc::peer2::Peer>, state: Arc<AppState>) -> Arc<Server> {
|
pub fn build(self, rpc: &Arc<Peer>, state: &Arc<AppState>) -> Arc<Server> {
|
||||||
Arc::new(Server {
|
Arc::new(Server {
|
||||||
rpc,
|
rpc: rpc.clone(),
|
||||||
state,
|
state: state.clone(),
|
||||||
handlers: self.handlers,
|
handlers: self.handlers,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Server {
|
pub struct Server {
|
||||||
rpc: Arc<zrpc::peer2::Peer>,
|
rpc: Arc<Peer>,
|
||||||
state: Arc<AppState>,
|
state: Arc<AppState>,
|
||||||
handlers: Vec<Handler>,
|
handlers: Vec<Handler>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Server {
|
impl Server {
|
||||||
pub async fn add_connection<Conn>(
|
pub async fn handle_connection<Conn>(
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
connection: Conn,
|
connection: Conn,
|
||||||
addr: String,
|
addr: String,
|
||||||
|
@ -332,99 +334,31 @@ impl State {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trait MessageHandler<'a, M: proto::EnvelopedMessage> {
|
pub fn build_server(state: &Arc<AppState>, rpc: &Arc<Peer>) -> Arc<Server> {
|
||||||
type Output: 'a + Send + Future<Output = tide::Result<()>>;
|
ServerBuilder::default()
|
||||||
|
// .on_message(share_worktree)
|
||||||
fn handle(
|
// .on_message(join_worktree)
|
||||||
&self,
|
// .on_message(update_worktree)
|
||||||
message: TypedEnvelope<M>,
|
// .on_message(close_worktree)
|
||||||
rpc: &'a Arc<Peer>,
|
// .on_message(open_buffer)
|
||||||
app_state: &'a Arc<AppState>,
|
// .on_message(close_buffer)
|
||||||
) -> Self::Output;
|
// .on_message(update_buffer)
|
||||||
}
|
// .on_message(buffer_saved)
|
||||||
|
// .on_message(save_buffer)
|
||||||
impl<'a, M, F, Fut> MessageHandler<'a, M> for F
|
// .on_message(get_channels)
|
||||||
where
|
// .on_message(get_users)
|
||||||
M: proto::EnvelopedMessage,
|
// .on_message(join_channel)
|
||||||
F: Fn(TypedEnvelope<M>, &'a Arc<Peer>, &'a Arc<AppState>) -> Fut,
|
// .on_message(send_channel_message)
|
||||||
Fut: 'a + Send + Future<Output = tide::Result<()>>,
|
.build(rpc, state)
|
||||||
{
|
|
||||||
type Output = Fut;
|
|
||||||
|
|
||||||
fn handle(
|
|
||||||
&self,
|
|
||||||
message: TypedEnvelope<M>,
|
|
||||||
rpc: &'a Arc<Peer>,
|
|
||||||
app_state: &'a Arc<AppState>,
|
|
||||||
) -> Self::Output {
|
|
||||||
(self)(message, rpc, app_state)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn on_message<M, H>(router: &mut Router, rpc: &Arc<Peer>, app_state: &Arc<AppState>, handler: H)
|
|
||||||
where
|
|
||||||
M: EnvelopedMessage,
|
|
||||||
H: 'static + Clone + Send + Sync + for<'a> MessageHandler<'a, M>,
|
|
||||||
{
|
|
||||||
let rpc = rpc.clone();
|
|
||||||
let handler = handler.clone();
|
|
||||||
let app_state = app_state.clone();
|
|
||||||
router.add_message_handler(move |message| {
|
|
||||||
let rpc = rpc.clone();
|
|
||||||
let handler = handler.clone();
|
|
||||||
let app_state = app_state.clone();
|
|
||||||
async move {
|
|
||||||
let sender_id = message.sender_id;
|
|
||||||
let message_id = message.message_id;
|
|
||||||
let start_time = Instant::now();
|
|
||||||
log::info!(
|
|
||||||
"RPC message received. id: {}.{}, type:{}",
|
|
||||||
sender_id,
|
|
||||||
message_id,
|
|
||||||
M::NAME
|
|
||||||
);
|
|
||||||
if let Err(err) = handler.handle(message, &rpc, &app_state).await {
|
|
||||||
log::error!("error handling message: {:?}", err);
|
|
||||||
} else {
|
|
||||||
log::info!(
|
|
||||||
"RPC message handled. id:{}.{}, duration:{:?}",
|
|
||||||
sender_id,
|
|
||||||
message_id,
|
|
||||||
start_time.elapsed()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer>) {
|
|
||||||
on_message(router, rpc, state, share_worktree);
|
|
||||||
on_message(router, rpc, state, join_worktree);
|
|
||||||
on_message(router, rpc, state, update_worktree);
|
|
||||||
on_message(router, rpc, state, close_worktree);
|
|
||||||
on_message(router, rpc, state, open_buffer);
|
|
||||||
on_message(router, rpc, state, close_buffer);
|
|
||||||
on_message(router, rpc, state, update_buffer);
|
|
||||||
on_message(router, rpc, state, buffer_saved);
|
|
||||||
on_message(router, rpc, state, save_buffer);
|
|
||||||
on_message(router, rpc, state, get_channels);
|
|
||||||
on_message(router, rpc, state, get_users);
|
|
||||||
on_message(router, rpc, state, join_channel);
|
|
||||||
on_message(router, rpc, state, send_channel_message);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
|
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
|
||||||
let mut router = Router::new();
|
let server = build_server(app.state(), rpc);
|
||||||
add_rpc_routes(&mut router, app.state(), rpc);
|
|
||||||
let router = Arc::new(router);
|
|
||||||
|
|
||||||
let rpc = rpc.clone();
|
let rpc = rpc.clone();
|
||||||
app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
|
app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
|
||||||
let user_id = request.ext::<UserId>().copied();
|
let user_id = request.ext::<UserId>().copied();
|
||||||
let rpc = rpc.clone();
|
let server = server.clone();
|
||||||
let router = router.clone();
|
|
||||||
async move {
|
async move {
|
||||||
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
||||||
|
|
||||||
|
@ -451,12 +385,11 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
|
||||||
let http_res: &mut tide::http::Response = response.as_mut();
|
let http_res: &mut tide::http::Response = response.as_mut();
|
||||||
let upgrade_receiver = http_res.recv_upgrade().await;
|
let upgrade_receiver = http_res.recv_upgrade().await;
|
||||||
let addr = request.remote().unwrap_or("unknown").to_string();
|
let addr = request.remote().unwrap_or("unknown").to_string();
|
||||||
let state = request.state().clone();
|
|
||||||
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
|
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
|
||||||
task::spawn(async move {
|
task::spawn(async move {
|
||||||
if let Some(stream) = upgrade_receiver.await {
|
if let Some(stream) = upgrade_receiver.await {
|
||||||
let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
|
let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
|
||||||
handle_connection(rpc, router, state, addr, stream, user_id).await;
|
server.handle_connection(stream, addr, user_id).await;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -465,43 +398,6 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_connection<Conn>(
|
|
||||||
rpc: Arc<Peer>,
|
|
||||||
router: Arc<Router>,
|
|
||||||
state: Arc<AppState>,
|
|
||||||
addr: String,
|
|
||||||
stream: Conn,
|
|
||||||
user_id: UserId,
|
|
||||||
) where
|
|
||||||
Conn: 'static
|
|
||||||
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
|
|
||||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
|
|
||||||
+ Send
|
|
||||||
+ Unpin,
|
|
||||||
{
|
|
||||||
log::info!("accepted rpc connection: {:?}", addr);
|
|
||||||
let (connection_id, handle_io, handle_messages) = rpc.add_connection(stream, router).await;
|
|
||||||
state
|
|
||||||
.rpc
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.add_connection(connection_id, user_id);
|
|
||||||
|
|
||||||
let handle_messages = async move {
|
|
||||||
handle_messages.await;
|
|
||||||
Ok(())
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(e) = futures::try_join!(handle_messages, handle_io) {
|
|
||||||
log::error!("error handling rpc connection {:?} - {:?}", addr, e);
|
|
||||||
}
|
|
||||||
|
|
||||||
log::info!("closing connection to {:?}", addr);
|
|
||||||
if let Err(e) = rpc.sign_out(connection_id, &state).await {
|
|
||||||
log::error!("error signing out connection {:?} - {:?}", addr, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn share_worktree(
|
async fn share_worktree(
|
||||||
mut request: TypedEnvelope<proto::ShareWorktree>,
|
mut request: TypedEnvelope<proto::ShareWorktree>,
|
||||||
rpc: &Arc<Peer>,
|
rpc: &Arc<Peer>,
|
||||||
|
|
|
@ -2,7 +2,7 @@ use crate::{
|
||||||
auth,
|
auth,
|
||||||
db::{self, UserId},
|
db::{self, UserId},
|
||||||
github,
|
github,
|
||||||
rpc::{self, add_rpc_routes},
|
rpc::{self, build_server},
|
||||||
AppState, Config,
|
AppState, Config,
|
||||||
};
|
};
|
||||||
use async_std::task;
|
use async_std::task;
|
||||||
|
@ -24,7 +24,7 @@ use zed::{
|
||||||
test::Channel,
|
test::Channel,
|
||||||
worktree::Worktree,
|
worktree::Worktree,
|
||||||
};
|
};
|
||||||
use zrpc::{ForegroundRouter, Peer, Router};
|
use zrpc::Peer;
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
|
async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
|
||||||
|
@ -541,7 +541,7 @@ impl TestServer {
|
||||||
let app_state = Self::build_app_state(&db_name).await;
|
let app_state = Self::build_app_state(&db_name).await;
|
||||||
let peer = Peer::new();
|
let peer = Peer::new();
|
||||||
let mut router = Router::new();
|
let mut router = Router::new();
|
||||||
add_rpc_routes(&mut router, &app_state, &peer);
|
build_server(&mut router, &app_state, &peer);
|
||||||
Self {
|
Self {
|
||||||
peer,
|
peer,
|
||||||
router: Arc::new(router),
|
router: Arc::new(router),
|
||||||
|
|
|
@ -24,14 +24,12 @@ pub use settings::Settings;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use postage::watch;
|
use postage::watch;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use zrpc::ForegroundRouter;
|
|
||||||
|
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub settings_tx: Arc<Mutex<watch::Sender<Settings>>>,
|
pub settings_tx: Arc<Mutex<watch::Sender<Settings>>>,
|
||||||
pub settings: watch::Receiver<Settings>,
|
pub settings: watch::Receiver<Settings>,
|
||||||
pub languages: Arc<language::LanguageRegistry>,
|
pub languages: Arc<language::LanguageRegistry>,
|
||||||
pub themes: Arc<settings::ThemeRegistry>,
|
pub themes: Arc<settings::ThemeRegistry>,
|
||||||
pub rpc_router: Arc<ForegroundRouter>,
|
|
||||||
pub rpc: rpc::Client,
|
pub rpc: rpc::Client,
|
||||||
pub fs: Arc<dyn fs::Fs>,
|
pub fs: Arc<dyn fs::Fs>,
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ use zrpc::proto::EntityMessage;
|
||||||
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
|
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
|
||||||
use zrpc::{
|
use zrpc::{
|
||||||
proto::{EnvelopedMessage, RequestMessage},
|
proto::{EnvelopedMessage, RequestMessage},
|
||||||
ForegroundRouter, Peer, Receipt,
|
Peer, Receipt,
|
||||||
};
|
};
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
|
@ -43,25 +43,6 @@ impl Client {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn on_message<H, M>(
|
|
||||||
&self,
|
|
||||||
router: &mut ForegroundRouter,
|
|
||||||
handler: H,
|
|
||||||
cx: &mut gpui::MutableAppContext,
|
|
||||||
) where
|
|
||||||
H: 'static + Clone + for<'a> MessageHandler<'a, M>,
|
|
||||||
M: proto::EnvelopedMessage,
|
|
||||||
{
|
|
||||||
let this = self.clone();
|
|
||||||
let cx = cx.to_async();
|
|
||||||
router.add_message_handler(move |message| {
|
|
||||||
let this = this.clone();
|
|
||||||
let mut cx = cx.clone();
|
|
||||||
let handler = handler.clone();
|
|
||||||
async move { handler.handle(message, &this, &mut cx).await }
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn subscribe_from_model<T, M, F>(
|
pub fn subscribe_from_model<T, M, F>(
|
||||||
&self,
|
&self,
|
||||||
remote_id: u64,
|
remote_id: u64,
|
||||||
|
@ -90,11 +71,7 @@ impl Client {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn log_in_and_connect(
|
pub async fn log_in_and_connect(&self, cx: AsyncAppContext) -> surf::Result<()> {
|
||||||
&self,
|
|
||||||
router: Arc<ForegroundRouter>,
|
|
||||||
cx: AsyncAppContext,
|
|
||||||
) -> surf::Result<()> {
|
|
||||||
if self.state.read().await.connection_id.is_some() {
|
if self.state.read().await.connection_id.is_some() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
@ -111,13 +88,13 @@ impl Client {
|
||||||
.await
|
.await
|
||||||
.context("websocket handshake")?;
|
.context("websocket handshake")?;
|
||||||
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
|
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
|
||||||
self.add_connection(stream, router, cx).await?;
|
self.add_connection(stream, cx).await?;
|
||||||
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
|
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
|
||||||
let stream = smol::net::TcpStream::connect(host).await?;
|
let stream = smol::net::TcpStream::connect(host).await?;
|
||||||
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
|
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
|
||||||
let (stream, _) = async_tungstenite::client_async(request, stream).await?;
|
let (stream, _) = async_tungstenite::client_async(request, stream).await?;
|
||||||
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
|
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
|
||||||
self.add_connection(stream, router, cx).await?;
|
self.add_connection(stream, cx).await?;
|
||||||
} else {
|
} else {
|
||||||
return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
|
return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
|
||||||
};
|
};
|
||||||
|
@ -125,12 +102,7 @@ impl Client {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn add_connection<Conn>(
|
pub async fn add_connection<Conn>(&self, conn: Conn, cx: AsyncAppContext) -> surf::Result<()>
|
||||||
&self,
|
|
||||||
conn: Conn,
|
|
||||||
router: Arc<ForegroundRouter>,
|
|
||||||
cx: AsyncAppContext,
|
|
||||||
) -> surf::Result<()>
|
|
||||||
where
|
where
|
||||||
Conn: 'static
|
Conn: 'static
|
||||||
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
|
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
|
||||||
|
@ -138,8 +110,7 @@ impl Client {
|
||||||
+ Unpin
|
+ Unpin
|
||||||
+ Send,
|
+ Send,
|
||||||
{
|
{
|
||||||
let (connection_id, handle_io, handle_messages) =
|
let (connection_id, handle_io, handle_messages) = self.peer.add_connection(conn).await;
|
||||||
self.peer.add_connection(conn, router).await;
|
|
||||||
cx.foreground().spawn(handle_messages).detach();
|
cx.foreground().spawn(handle_messages).detach();
|
||||||
cx.background()
|
cx.background()
|
||||||
.spawn(async move {
|
.spawn(async move {
|
||||||
|
|
|
@ -15,7 +15,6 @@ use std::{
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
use tempdir::TempDir;
|
use tempdir::TempDir;
|
||||||
use zrpc::ForegroundRouter;
|
|
||||||
|
|
||||||
#[cfg(feature = "test-support")]
|
#[cfg(feature = "test-support")]
|
||||||
pub use zrpc::test::Channel;
|
pub use zrpc::test::Channel;
|
||||||
|
@ -163,7 +162,6 @@ pub fn build_app_state(cx: &AppContext) -> Arc<AppState> {
|
||||||
settings,
|
settings,
|
||||||
themes,
|
themes,
|
||||||
languages: languages.clone(),
|
languages: languages.clone(),
|
||||||
rpc_router: Arc::new(ForegroundRouter::new()),
|
|
||||||
rpc: rpc::Client::new(languages),
|
rpc: rpc::Client::new(languages),
|
||||||
fs: Arc::new(RealFs),
|
fs: Arc::new(RealFs),
|
||||||
})
|
})
|
||||||
|
|
|
@ -728,10 +728,9 @@ impl Workspace {
|
||||||
fn share_worktree(&mut self, app_state: &Arc<AppState>, cx: &mut ViewContext<Self>) {
|
fn share_worktree(&mut self, app_state: &Arc<AppState>, cx: &mut ViewContext<Self>) {
|
||||||
let rpc = self.rpc.clone();
|
let rpc = self.rpc.clone();
|
||||||
let platform = cx.platform();
|
let platform = cx.platform();
|
||||||
let router = app_state.rpc_router.clone();
|
|
||||||
|
|
||||||
let task = cx.spawn(|this, mut cx| async move {
|
let task = cx.spawn(|this, mut cx| async move {
|
||||||
rpc.log_in_and_connect(router, cx.clone()).await?;
|
rpc.log_in_and_connect(cx.clone()).await?;
|
||||||
|
|
||||||
let share_task = this.update(&mut cx, |this, cx| {
|
let share_task = this.update(&mut cx, |this, cx| {
|
||||||
let worktree = this.worktrees.iter().next()?;
|
let worktree = this.worktrees.iter().next()?;
|
||||||
|
@ -761,10 +760,9 @@ impl Workspace {
|
||||||
fn join_worktree(&mut self, app_state: &Arc<AppState>, cx: &mut ViewContext<Self>) {
|
fn join_worktree(&mut self, app_state: &Arc<AppState>, cx: &mut ViewContext<Self>) {
|
||||||
let rpc = self.rpc.clone();
|
let rpc = self.rpc.clone();
|
||||||
let languages = self.languages.clone();
|
let languages = self.languages.clone();
|
||||||
let router = app_state.rpc_router.clone();
|
|
||||||
|
|
||||||
let task = cx.spawn(|this, mut cx| async move {
|
let task = cx.spawn(|this, mut cx| async move {
|
||||||
rpc.log_in_and_connect(router, cx.clone()).await?;
|
rpc.log_in_and_connect(cx.clone()).await?;
|
||||||
|
|
||||||
let worktree_url = cx
|
let worktree_url = cx
|
||||||
.platform()
|
.platform()
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
mod peer;
|
mod peer;
|
||||||
pub mod peer2;
|
|
||||||
pub mod proto;
|
pub mod proto;
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
pub mod test;
|
pub mod test;
|
||||||
|
|
454
zrpc/src/peer.rs
454
zrpc/src/peer.rs
|
@ -2,17 +2,14 @@ use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use async_lock::{Mutex, RwLock};
|
use async_lock::{Mutex, RwLock};
|
||||||
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
|
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
|
||||||
use futures::{
|
use futures::{FutureExt, StreamExt};
|
||||||
future::{self, BoxFuture, LocalBoxFuture},
|
|
||||||
FutureExt, Stream, StreamExt,
|
|
||||||
};
|
|
||||||
use postage::{
|
use postage::{
|
||||||
broadcast, mpsc,
|
mpsc,
|
||||||
prelude::{Sink as _, Stream as _},
|
prelude::{Sink as _, Stream as _},
|
||||||
};
|
};
|
||||||
use std::{
|
use std::{
|
||||||
any::{Any, TypeId},
|
any::Any,
|
||||||
collections::{HashMap, HashSet},
|
collections::HashMap,
|
||||||
fmt,
|
fmt,
|
||||||
future::Future,
|
future::Future,
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
|
@ -25,17 +22,20 @@ use std::{
|
||||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||||
pub struct ConnectionId(pub u32);
|
pub struct ConnectionId(pub u32);
|
||||||
|
|
||||||
|
impl fmt::Display for ConnectionId {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
self.0.fmt(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||||
pub struct PeerId(pub u32);
|
pub struct PeerId(pub u32);
|
||||||
|
|
||||||
type MessageHandler = Box<
|
impl fmt::Display for PeerId {
|
||||||
dyn Send
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
+ Sync
|
self.0.fmt(f)
|
||||||
+ Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
|
}
|
||||||
>;
|
}
|
||||||
|
|
||||||
type ForegroundMessageHandler =
|
|
||||||
Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
|
|
||||||
|
|
||||||
pub struct Receipt<T> {
|
pub struct Receipt<T> {
|
||||||
pub sender_id: ConnectionId,
|
pub sender_id: ConnectionId,
|
||||||
|
@ -43,6 +43,18 @@ pub struct Receipt<T> {
|
||||||
payload_type: PhantomData<T>,
|
payload_type: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> Clone for Receipt<T> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
sender_id: self.sender_id,
|
||||||
|
message_id: self.message_id,
|
||||||
|
payload_type: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Copy for Receipt<T> {}
|
||||||
|
|
||||||
pub struct TypedEnvelope<T> {
|
pub struct TypedEnvelope<T> {
|
||||||
pub sender_id: ConnectionId,
|
pub sender_id: ConnectionId,
|
||||||
pub original_sender_id: Option<PeerId>,
|
pub original_sender_id: Option<PeerId>,
|
||||||
|
@ -67,17 +79,9 @@ impl<T: RequestMessage> TypedEnvelope<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Router = RouterInternal<MessageHandler>;
|
|
||||||
pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
|
|
||||||
pub struct RouterInternal<H> {
|
|
||||||
message_handlers: Vec<H>,
|
|
||||||
handler_types: HashSet<TypeId>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Peer {
|
pub struct Peer {
|
||||||
connections: RwLock<HashMap<ConnectionId, Connection>>,
|
connections: RwLock<HashMap<ConnectionId, Connection>>,
|
||||||
next_connection_id: AtomicU32,
|
next_connection_id: AtomicU32,
|
||||||
incoming_messages: broadcast::Sender<Arc<dyn Any + Send + Sync>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
@ -92,22 +96,18 @@ impl Peer {
|
||||||
Arc::new(Self {
|
Arc::new(Self {
|
||||||
connections: Default::default(),
|
connections: Default::default(),
|
||||||
next_connection_id: Default::default(),
|
next_connection_id: Default::default(),
|
||||||
incoming_messages: broadcast::channel(256).0,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn add_connection<Conn, H, Fut>(
|
pub async fn add_connection<Conn>(
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
conn: Conn,
|
conn: Conn,
|
||||||
router: Arc<RouterInternal<H>>,
|
|
||||||
) -> (
|
) -> (
|
||||||
ConnectionId,
|
ConnectionId,
|
||||||
impl Future<Output = anyhow::Result<()>> + Send,
|
impl Future<Output = anyhow::Result<()>> + Send,
|
||||||
impl Future<Output = ()>,
|
mpsc::Receiver<Box<dyn Any + Sync + Send>>,
|
||||||
)
|
)
|
||||||
where
|
where
|
||||||
H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
|
|
||||||
Fut: Future<Output = ()>,
|
|
||||||
Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
|
Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
|
||||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
|
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
|
||||||
+ Send
|
+ Send
|
||||||
|
@ -118,7 +118,7 @@ impl Peer {
|
||||||
self.next_connection_id
|
self.next_connection_id
|
||||||
.fetch_add(1, atomic::Ordering::SeqCst),
|
.fetch_add(1, atomic::Ordering::SeqCst),
|
||||||
);
|
);
|
||||||
let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64);
|
let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
|
||||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
|
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
|
||||||
let connection = Connection {
|
let connection = Connection {
|
||||||
outgoing_tx,
|
outgoing_tx,
|
||||||
|
@ -128,6 +128,7 @@ impl Peer {
|
||||||
let mut writer = MessageStream::new(tx);
|
let mut writer = MessageStream::new(tx);
|
||||||
let mut reader = MessageStream::new(rx);
|
let mut reader = MessageStream::new(rx);
|
||||||
|
|
||||||
|
let response_channels = connection.response_channels.clone();
|
||||||
let handle_io = async move {
|
let handle_io = async move {
|
||||||
loop {
|
loop {
|
||||||
let read_message = reader.read_message().fuse();
|
let read_message = reader.read_message().fuse();
|
||||||
|
@ -136,49 +137,46 @@ impl Peer {
|
||||||
futures::select_biased! {
|
futures::select_biased! {
|
||||||
incoming = read_message => match incoming {
|
incoming = read_message => match incoming {
|
||||||
Ok(incoming) => {
|
Ok(incoming) => {
|
||||||
if incoming_tx.send(incoming).await.is_err() {
|
if let Some(responding_to) = incoming.responding_to {
|
||||||
return Ok(());
|
let channel = response_channels.lock().await.remove(&responding_to);
|
||||||
|
if let Some(mut tx) = channel {
|
||||||
|
tx.send(incoming).await.ok();
|
||||||
|
} else {
|
||||||
|
log::warn!("received RPC response to unknown request {}", responding_to);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
|
||||||
|
if incoming_tx.send(envelope).await.is_err() {
|
||||||
|
response_channels.lock().await.clear();
|
||||||
|
return Ok(())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log::error!("unable to construct a typed envelope");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
|
response_channels.lock().await.clear();
|
||||||
Err(error).context("received invalid RPC message")?;
|
Err(error).context("received invalid RPC message")?;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
outgoing = outgoing_rx.recv().fuse() => match outgoing {
|
outgoing = outgoing_rx.recv().fuse() => match outgoing {
|
||||||
Some(outgoing) => {
|
Some(outgoing) => {
|
||||||
if let Err(result) = writer.write_message(&outgoing).await {
|
if let Err(result) = writer.write_message(&outgoing).await {
|
||||||
|
response_channels.lock().await.clear();
|
||||||
Err(result).context("failed to write RPC message")?;
|
Err(result).context("failed to write RPC message")?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => return Ok(()),
|
None => {
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut broadcast_incoming_messages = self.incoming_messages.clone();
|
|
||||||
let response_channels = connection.response_channels.clone();
|
|
||||||
let handle_messages = async move {
|
|
||||||
while let Some(envelope) = incoming_rx.recv().await {
|
|
||||||
if let Some(responding_to) = envelope.responding_to {
|
|
||||||
let channel = response_channels.lock().await.remove(&responding_to);
|
|
||||||
if let Some(mut tx) = channel {
|
|
||||||
tx.send(envelope).await.ok();
|
|
||||||
} else {
|
|
||||||
log::warn!("received RPC response to unknown request {}", responding_to);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
router.handle(connection_id, envelope.clone()).await;
|
|
||||||
if let Some(envelope) = proto::build_typed_envelope(connection_id, envelope) {
|
|
||||||
broadcast_incoming_messages.send(Arc::from(envelope)).await.ok();
|
|
||||||
} else {
|
|
||||||
log::error!("unable to construct a typed envelope");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
response_channels.lock().await.clear();
|
response_channels.lock().await.clear();
|
||||||
|
return Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
self.connections
|
self.connections
|
||||||
|
@ -186,7 +184,7 @@ impl Peer {
|
||||||
.await
|
.await
|
||||||
.insert(connection_id, connection);
|
.insert(connection_id, connection);
|
||||||
|
|
||||||
(connection_id, handle_io, handle_messages)
|
(connection_id, handle_io, incoming_rx)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn disconnect(&self, connection_id: ConnectionId) {
|
pub async fn disconnect(&self, connection_id: ConnectionId) {
|
||||||
|
@ -197,12 +195,6 @@ impl Peer {
|
||||||
self.connections.write().await.clear();
|
self.connections.write().await.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn subscribe<T: EnvelopedMessage>(&self) -> impl Stream<Item = Arc<TypedEnvelope<T>>> {
|
|
||||||
self.incoming_messages
|
|
||||||
.subscribe()
|
|
||||||
.filter_map(|envelope| future::ready(Arc::downcast(envelope).ok()))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn request<T: RequestMessage>(
|
pub fn request<T: RequestMessage>(
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
receiver_id: ConnectionId,
|
receiver_id: ConnectionId,
|
||||||
|
@ -325,142 +317,10 @@ impl Peer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<H, Fut> RouterInternal<H>
|
|
||||||
where
|
|
||||||
H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
|
|
||||||
Fut: Future<Output = ()>,
|
|
||||||
{
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
message_handlers: Default::default(),
|
|
||||||
handler_types: Default::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
|
|
||||||
let mut envelope = Some(message);
|
|
||||||
for handler in self.message_handlers.iter() {
|
|
||||||
if let Some(future) = handler(&mut envelope, connection_id) {
|
|
||||||
future.await;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Router {
|
|
||||||
pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
|
|
||||||
where
|
|
||||||
T: EnvelopedMessage,
|
|
||||||
Fut: 'static + Send + Future<Output = Result<()>>,
|
|
||||||
F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
|
|
||||||
{
|
|
||||||
if !self.handler_types.insert(TypeId::of::<T>()) {
|
|
||||||
panic!("duplicate handler type");
|
|
||||||
}
|
|
||||||
|
|
||||||
self.message_handlers
|
|
||||||
.push(Box::new(move |envelope, connection_id| {
|
|
||||||
if envelope.as_ref().map_or(false, T::matches_envelope) {
|
|
||||||
let envelope = Option::take(envelope).unwrap();
|
|
||||||
let message_id = envelope.id;
|
|
||||||
let future = handler(TypedEnvelope {
|
|
||||||
sender_id: connection_id,
|
|
||||||
original_sender_id: envelope.original_sender_id.map(PeerId),
|
|
||||||
message_id,
|
|
||||||
payload: T::from_envelope(envelope).unwrap(),
|
|
||||||
});
|
|
||||||
Some(
|
|
||||||
async move {
|
|
||||||
if let Err(error) = future.await {
|
|
||||||
log::error!(
|
|
||||||
"error handling message {} {}: {:?}",
|
|
||||||
T::NAME,
|
|
||||||
message_id,
|
|
||||||
error
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.boxed(),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ForegroundRouter {
|
|
||||||
pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
|
|
||||||
where
|
|
||||||
T: EnvelopedMessage,
|
|
||||||
Fut: 'static + Future<Output = Result<()>>,
|
|
||||||
F: 'static + Fn(TypedEnvelope<T>) -> Fut,
|
|
||||||
{
|
|
||||||
if !self.handler_types.insert(TypeId::of::<T>()) {
|
|
||||||
panic!("duplicate handler type");
|
|
||||||
}
|
|
||||||
|
|
||||||
self.message_handlers
|
|
||||||
.push(Box::new(move |envelope, connection_id| {
|
|
||||||
if envelope.as_ref().map_or(false, T::matches_envelope) {
|
|
||||||
let envelope = Option::take(envelope).unwrap();
|
|
||||||
let message_id = envelope.id;
|
|
||||||
let future = handler(TypedEnvelope {
|
|
||||||
sender_id: connection_id,
|
|
||||||
original_sender_id: envelope.original_sender_id.map(PeerId),
|
|
||||||
message_id,
|
|
||||||
payload: T::from_envelope(envelope).unwrap(),
|
|
||||||
});
|
|
||||||
Some(
|
|
||||||
async move {
|
|
||||||
if let Err(error) = future.await {
|
|
||||||
log::error!(
|
|
||||||
"error handling message {} {}: {:?}",
|
|
||||||
T::NAME,
|
|
||||||
message_id,
|
|
||||||
error
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.boxed_local(),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> Clone for Receipt<T> {
|
|
||||||
fn clone(&self) -> Self {
|
|
||||||
Self {
|
|
||||||
sender_id: self.sender_id,
|
|
||||||
message_id: self.message_id,
|
|
||||||
payload_type: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> Copy for Receipt<T> {}
|
|
||||||
|
|
||||||
impl fmt::Display for ConnectionId {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
self.0.fmt(f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for PeerId {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
self.0.fmt(f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::test;
|
use crate::{test, TypedEnvelope};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_request_response() {
|
fn test_request_response() {
|
||||||
|
@ -470,139 +330,37 @@ mod tests {
|
||||||
let client1 = Peer::new();
|
let client1 = Peer::new();
|
||||||
let client2 = Peer::new();
|
let client2 = Peer::new();
|
||||||
|
|
||||||
let mut router = Router::new();
|
|
||||||
router.add_message_handler({
|
|
||||||
let server = server.clone();
|
|
||||||
move |envelope: TypedEnvelope<proto::Auth>| {
|
|
||||||
let server = server.clone();
|
|
||||||
async move {
|
|
||||||
let receipt = envelope.receipt();
|
|
||||||
let message = envelope.payload;
|
|
||||||
server
|
|
||||||
.respond(
|
|
||||||
receipt,
|
|
||||||
match message.user_id {
|
|
||||||
1 => {
|
|
||||||
assert_eq!(message.access_token, "access-token-1");
|
|
||||||
proto::AuthResponse {
|
|
||||||
credentials_valid: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
2 => {
|
|
||||||
assert_eq!(message.access_token, "access-token-2");
|
|
||||||
proto::AuthResponse {
|
|
||||||
credentials_valid: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
panic!("unexpected user id {}", message.user_id);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
router.add_message_handler({
|
|
||||||
let server = server.clone();
|
|
||||||
move |envelope: TypedEnvelope<proto::OpenBuffer>| {
|
|
||||||
let server = server.clone();
|
|
||||||
async move {
|
|
||||||
let receipt = envelope.receipt();
|
|
||||||
let message = envelope.payload;
|
|
||||||
server
|
|
||||||
.respond(
|
|
||||||
receipt,
|
|
||||||
match message.path.as_str() {
|
|
||||||
"path/one" => {
|
|
||||||
assert_eq!(message.worktree_id, 1);
|
|
||||||
proto::OpenBufferResponse {
|
|
||||||
buffer: Some(proto::Buffer {
|
|
||||||
id: 101,
|
|
||||||
content: "path/one content".to_string(),
|
|
||||||
history: vec![],
|
|
||||||
selections: vec![],
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"path/two" => {
|
|
||||||
assert_eq!(message.worktree_id, 2);
|
|
||||||
proto::OpenBufferResponse {
|
|
||||||
buffer: Some(proto::Buffer {
|
|
||||||
id: 102,
|
|
||||||
content: "path/two content".to_string(),
|
|
||||||
history: vec![],
|
|
||||||
selections: vec![],
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
panic!("unexpected path {}", message.path);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
let router = Arc::new(router);
|
|
||||||
|
|
||||||
let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
|
let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
|
||||||
let (client1_conn_id, io_task1, msg_task1) = client1
|
let (client1_conn_id, io_task1, _) =
|
||||||
.add_connection(client1_to_server_conn, router.clone())
|
client1.add_connection(client1_to_server_conn).await;
|
||||||
.await;
|
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
|
||||||
let (_, io_task2, msg_task2) = server
|
|
||||||
.add_connection(server_to_client_1_conn, router.clone())
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
|
let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
|
||||||
let (client2_conn_id, io_task3, msg_task3) = client2
|
let (client2_conn_id, io_task3, _) =
|
||||||
.add_connection(client2_to_server_conn, router.clone())
|
client2.add_connection(client2_to_server_conn).await;
|
||||||
.await;
|
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
|
||||||
let (_, io_task4, msg_task4) = server
|
|
||||||
.add_connection(server_to_client_2_conn, router.clone())
|
|
||||||
.await;
|
|
||||||
|
|
||||||
smol::spawn(io_task1).detach();
|
smol::spawn(io_task1).detach();
|
||||||
smol::spawn(io_task2).detach();
|
smol::spawn(io_task2).detach();
|
||||||
smol::spawn(io_task3).detach();
|
smol::spawn(io_task3).detach();
|
||||||
smol::spawn(io_task4).detach();
|
smol::spawn(io_task4).detach();
|
||||||
smol::spawn(msg_task1).detach();
|
smol::spawn(handle_messages(incoming1, server.clone())).detach();
|
||||||
smol::spawn(msg_task2).detach();
|
smol::spawn(handle_messages(incoming2, server.clone())).detach();
|
||||||
smol::spawn(msg_task3).detach();
|
|
||||||
smol::spawn(msg_task4).detach();
|
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
client1
|
client1
|
||||||
.request(
|
.request(client1_conn_id, proto::Ping { id: 1 },)
|
||||||
client1_conn_id,
|
|
||||||
proto::Auth {
|
|
||||||
user_id: 1,
|
|
||||||
access_token: "access-token-1".to_string(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
proto::AuthResponse {
|
proto::Pong { id: 1 }
|
||||||
credentials_valid: true,
|
|
||||||
}
|
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
client2
|
client2
|
||||||
.request(
|
.request(client2_conn_id, proto::Ping { id: 2 },)
|
||||||
client2_conn_id,
|
|
||||||
proto::Auth {
|
|
||||||
user_id: 2,
|
|
||||||
access_token: "access-token-2".to_string(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
proto::AuthResponse {
|
proto::Pong { id: 2 }
|
||||||
credentials_valid: false,
|
|
||||||
}
|
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -649,6 +407,62 @@ mod tests {
|
||||||
|
|
||||||
client1.disconnect(client1_conn_id).await;
|
client1.disconnect(client1_conn_id).await;
|
||||||
client2.disconnect(client1_conn_id).await;
|
client2.disconnect(client1_conn_id).await;
|
||||||
|
|
||||||
|
async fn handle_messages(
|
||||||
|
mut messages: mpsc::Receiver<Box<dyn Any + Sync + Send>>,
|
||||||
|
peer: Arc<Peer>,
|
||||||
|
) -> Result<()> {
|
||||||
|
while let Some(envelope) = messages.next().await {
|
||||||
|
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
|
||||||
|
let receipt = envelope.receipt();
|
||||||
|
peer.respond(
|
||||||
|
receipt,
|
||||||
|
proto::Pong {
|
||||||
|
id: envelope.payload.id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
} else if let Some(envelope) =
|
||||||
|
envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
|
||||||
|
{
|
||||||
|
let message = &envelope.payload;
|
||||||
|
let receipt = envelope.receipt();
|
||||||
|
let response = match message.path.as_str() {
|
||||||
|
"path/one" => {
|
||||||
|
assert_eq!(message.worktree_id, 1);
|
||||||
|
proto::OpenBufferResponse {
|
||||||
|
buffer: Some(proto::Buffer {
|
||||||
|
id: 101,
|
||||||
|
content: "path/one content".to_string(),
|
||||||
|
history: vec![],
|
||||||
|
selections: vec![],
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"path/two" => {
|
||||||
|
assert_eq!(message.worktree_id, 2);
|
||||||
|
proto::OpenBufferResponse {
|
||||||
|
buffer: Some(proto::Buffer {
|
||||||
|
id: 102,
|
||||||
|
content: "path/two content".to_string(),
|
||||||
|
history: vec![],
|
||||||
|
selections: vec![],
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
panic!("unexpected path {}", message.path);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
peer.respond(receipt, response).await?
|
||||||
|
} else {
|
||||||
|
panic!("unknown message type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -658,9 +472,8 @@ mod tests {
|
||||||
let (client_conn, mut server_conn) = test::Channel::bidirectional();
|
let (client_conn, mut server_conn) = test::Channel::bidirectional();
|
||||||
|
|
||||||
let client = Peer::new();
|
let client = Peer::new();
|
||||||
let router = Arc::new(Router::new());
|
let (connection_id, io_handler, mut incoming) =
|
||||||
let (connection_id, io_handler, message_handler) =
|
client.add_connection(client_conn).await;
|
||||||
client.add_connection(client_conn, router).await;
|
|
||||||
|
|
||||||
let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
|
let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
|
||||||
smol::spawn(async move {
|
smol::spawn(async move {
|
||||||
|
@ -671,7 +484,7 @@ mod tests {
|
||||||
|
|
||||||
let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
|
let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
|
||||||
smol::spawn(async move {
|
smol::spawn(async move {
|
||||||
message_handler.await;
|
incoming.next().await;
|
||||||
messages_ended_tx.send(()).await.unwrap();
|
messages_ended_tx.send(()).await.unwrap();
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
|
@ -695,11 +508,10 @@ mod tests {
|
||||||
drop(server_conn);
|
drop(server_conn);
|
||||||
|
|
||||||
let client = Peer::new();
|
let client = Peer::new();
|
||||||
let router = Arc::new(Router::new());
|
let (connection_id, io_handler, mut incoming) =
|
||||||
let (connection_id, io_handler, message_handler) =
|
client.add_connection(client_conn).await;
|
||||||
client.add_connection(client_conn, router).await;
|
|
||||||
smol::spawn(io_handler).detach();
|
smol::spawn(io_handler).detach();
|
||||||
smol::spawn(message_handler).detach();
|
smol::spawn(async move { incoming.next().await }).detach();
|
||||||
|
|
||||||
let err = client
|
let err = client
|
||||||
.request(
|
.request(
|
||||||
|
|
|
@ -1,470 +0,0 @@
|
||||||
use crate::{
|
|
||||||
proto::{self, EnvelopedMessage, MessageStream, RequestMessage},
|
|
||||||
ConnectionId, PeerId, Receipt,
|
|
||||||
};
|
|
||||||
use anyhow::{anyhow, Context, Result};
|
|
||||||
use async_lock::{Mutex, RwLock};
|
|
||||||
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
|
|
||||||
use futures::{FutureExt, StreamExt};
|
|
||||||
use postage::{
|
|
||||||
mpsc,
|
|
||||||
prelude::{Sink as _, Stream as _},
|
|
||||||
};
|
|
||||||
use std::{
|
|
||||||
any::Any,
|
|
||||||
collections::HashMap,
|
|
||||||
future::Future,
|
|
||||||
sync::{
|
|
||||||
atomic::{self, AtomicU32},
|
|
||||||
Arc,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
pub struct Peer {
|
|
||||||
connections: RwLock<HashMap<ConnectionId, Connection>>,
|
|
||||||
next_connection_id: AtomicU32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct Connection {
|
|
||||||
outgoing_tx: mpsc::Sender<proto::Envelope>,
|
|
||||||
next_message_id: Arc<AtomicU32>,
|
|
||||||
response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Peer {
|
|
||||||
pub fn new() -> Arc<Self> {
|
|
||||||
Arc::new(Self {
|
|
||||||
connections: Default::default(),
|
|
||||||
next_connection_id: Default::default(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn add_connection<Conn>(
|
|
||||||
self: &Arc<Self>,
|
|
||||||
conn: Conn,
|
|
||||||
) -> (
|
|
||||||
ConnectionId,
|
|
||||||
impl Future<Output = anyhow::Result<()>> + Send,
|
|
||||||
mpsc::Receiver<Box<dyn Any + Sync + Send>>,
|
|
||||||
)
|
|
||||||
where
|
|
||||||
Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
|
|
||||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
|
|
||||||
+ Send
|
|
||||||
+ Unpin,
|
|
||||||
{
|
|
||||||
let (tx, rx) = conn.split();
|
|
||||||
let connection_id = ConnectionId(
|
|
||||||
self.next_connection_id
|
|
||||||
.fetch_add(1, atomic::Ordering::SeqCst),
|
|
||||||
);
|
|
||||||
let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
|
|
||||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
|
|
||||||
let connection = Connection {
|
|
||||||
outgoing_tx,
|
|
||||||
next_message_id: Default::default(),
|
|
||||||
response_channels: Default::default(),
|
|
||||||
};
|
|
||||||
let mut writer = MessageStream::new(tx);
|
|
||||||
let mut reader = MessageStream::new(rx);
|
|
||||||
|
|
||||||
let response_channels = connection.response_channels.clone();
|
|
||||||
let handle_io = async move {
|
|
||||||
loop {
|
|
||||||
let read_message = reader.read_message().fuse();
|
|
||||||
futures::pin_mut!(read_message);
|
|
||||||
loop {
|
|
||||||
futures::select_biased! {
|
|
||||||
incoming = read_message => match incoming {
|
|
||||||
Ok(incoming) => {
|
|
||||||
if let Some(responding_to) = incoming.responding_to {
|
|
||||||
let channel = response_channels.lock().await.remove(&responding_to);
|
|
||||||
if let Some(mut tx) = channel {
|
|
||||||
tx.send(incoming).await.ok();
|
|
||||||
} else {
|
|
||||||
log::warn!("received RPC response to unknown request {}", responding_to);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
|
|
||||||
if incoming_tx.send(envelope).await.is_err() {
|
|
||||||
response_channels.lock().await.clear();
|
|
||||||
return Ok(())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log::error!("unable to construct a typed envelope");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Err(error) => {
|
|
||||||
response_channels.lock().await.clear();
|
|
||||||
Err(error).context("received invalid RPC message")?;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
outgoing = outgoing_rx.recv().fuse() => match outgoing {
|
|
||||||
Some(outgoing) => {
|
|
||||||
if let Err(result) = writer.write_message(&outgoing).await {
|
|
||||||
response_channels.lock().await.clear();
|
|
||||||
Err(result).context("failed to write RPC message")?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
response_channels.lock().await.clear();
|
|
||||||
return Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
self.connections
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.insert(connection_id, connection);
|
|
||||||
|
|
||||||
(connection_id, handle_io, incoming_rx)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn disconnect(&self, connection_id: ConnectionId) {
|
|
||||||
self.connections.write().await.remove(&connection_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn reset(&self) {
|
|
||||||
self.connections.write().await.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn request<T: RequestMessage>(
|
|
||||||
self: &Arc<Self>,
|
|
||||||
receiver_id: ConnectionId,
|
|
||||||
request: T,
|
|
||||||
) -> impl Future<Output = Result<T::Response>> {
|
|
||||||
self.request_internal(None, receiver_id, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forward_request<T: RequestMessage>(
|
|
||||||
self: &Arc<Self>,
|
|
||||||
sender_id: ConnectionId,
|
|
||||||
receiver_id: ConnectionId,
|
|
||||||
request: T,
|
|
||||||
) -> impl Future<Output = Result<T::Response>> {
|
|
||||||
self.request_internal(Some(sender_id), receiver_id, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn request_internal<T: RequestMessage>(
|
|
||||||
self: &Arc<Self>,
|
|
||||||
original_sender_id: Option<ConnectionId>,
|
|
||||||
receiver_id: ConnectionId,
|
|
||||||
request: T,
|
|
||||||
) -> impl Future<Output = Result<T::Response>> {
|
|
||||||
let this = self.clone();
|
|
||||||
let (tx, mut rx) = mpsc::channel(1);
|
|
||||||
async move {
|
|
||||||
let mut connection = this.connection(receiver_id).await?;
|
|
||||||
let message_id = connection
|
|
||||||
.next_message_id
|
|
||||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
|
||||||
connection
|
|
||||||
.response_channels
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.insert(message_id, tx);
|
|
||||||
connection
|
|
||||||
.outgoing_tx
|
|
||||||
.send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
|
|
||||||
.await
|
|
||||||
.map_err(|_| anyhow!("connection was closed"))?;
|
|
||||||
let response = rx
|
|
||||||
.recv()
|
|
||||||
.await
|
|
||||||
.ok_or_else(|| anyhow!("connection was closed"))?;
|
|
||||||
T::Response::from_envelope(response)
|
|
||||||
.ok_or_else(|| anyhow!("received response of the wrong type"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send<T: EnvelopedMessage>(
|
|
||||||
self: &Arc<Self>,
|
|
||||||
receiver_id: ConnectionId,
|
|
||||||
message: T,
|
|
||||||
) -> impl Future<Output = Result<()>> {
|
|
||||||
let this = self.clone();
|
|
||||||
async move {
|
|
||||||
let mut connection = this.connection(receiver_id).await?;
|
|
||||||
let message_id = connection
|
|
||||||
.next_message_id
|
|
||||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
|
||||||
connection
|
|
||||||
.outgoing_tx
|
|
||||||
.send(message.into_envelope(message_id, None, None))
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forward_send<T: EnvelopedMessage>(
|
|
||||||
self: &Arc<Self>,
|
|
||||||
sender_id: ConnectionId,
|
|
||||||
receiver_id: ConnectionId,
|
|
||||||
message: T,
|
|
||||||
) -> impl Future<Output = Result<()>> {
|
|
||||||
let this = self.clone();
|
|
||||||
async move {
|
|
||||||
let mut connection = this.connection(receiver_id).await?;
|
|
||||||
let message_id = connection
|
|
||||||
.next_message_id
|
|
||||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
|
||||||
connection
|
|
||||||
.outgoing_tx
|
|
||||||
.send(message.into_envelope(message_id, None, Some(sender_id.0)))
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn respond<T: RequestMessage>(
|
|
||||||
self: &Arc<Self>,
|
|
||||||
receipt: Receipt<T>,
|
|
||||||
response: T::Response,
|
|
||||||
) -> impl Future<Output = Result<()>> {
|
|
||||||
let this = self.clone();
|
|
||||||
async move {
|
|
||||||
let mut connection = this.connection(receipt.sender_id).await?;
|
|
||||||
let message_id = connection
|
|
||||||
.next_message_id
|
|
||||||
.fetch_add(1, atomic::Ordering::SeqCst);
|
|
||||||
connection
|
|
||||||
.outgoing_tx
|
|
||||||
.send(response.into_envelope(message_id, Some(receipt.message_id), None))
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn connection(
|
|
||||||
self: &Arc<Self>,
|
|
||||||
connection_id: ConnectionId,
|
|
||||||
) -> impl Future<Output = Result<Connection>> {
|
|
||||||
let this = self.clone();
|
|
||||||
async move {
|
|
||||||
let connections = this.connections.read().await;
|
|
||||||
let connection = connections
|
|
||||||
.get(&connection_id)
|
|
||||||
.ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
|
|
||||||
Ok(connection.clone())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::{test, TypedEnvelope};
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_request_response() {
|
|
||||||
smol::block_on(async move {
|
|
||||||
// create 2 clients connected to 1 server
|
|
||||||
let server = Peer::new();
|
|
||||||
let client1 = Peer::new();
|
|
||||||
let client2 = Peer::new();
|
|
||||||
|
|
||||||
let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
|
|
||||||
let (client1_conn_id, io_task1, _) =
|
|
||||||
client1.add_connection(client1_to_server_conn).await;
|
|
||||||
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
|
|
||||||
|
|
||||||
let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
|
|
||||||
let (client2_conn_id, io_task3, _) =
|
|
||||||
client2.add_connection(client2_to_server_conn).await;
|
|
||||||
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
|
|
||||||
|
|
||||||
smol::spawn(io_task1).detach();
|
|
||||||
smol::spawn(io_task2).detach();
|
|
||||||
smol::spawn(io_task3).detach();
|
|
||||||
smol::spawn(io_task4).detach();
|
|
||||||
smol::spawn(handle_messages(incoming1, server.clone())).detach();
|
|
||||||
smol::spawn(handle_messages(incoming2, server.clone())).detach();
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
client1
|
|
||||||
.request(client1_conn_id, proto::Ping { id: 1 },)
|
|
||||||
.await
|
|
||||||
.unwrap(),
|
|
||||||
proto::Pong { id: 1 }
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
client2
|
|
||||||
.request(client2_conn_id, proto::Ping { id: 2 },)
|
|
||||||
.await
|
|
||||||
.unwrap(),
|
|
||||||
proto::Pong { id: 2 }
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
client1
|
|
||||||
.request(
|
|
||||||
client1_conn_id,
|
|
||||||
proto::OpenBuffer {
|
|
||||||
worktree_id: 1,
|
|
||||||
path: "path/one".to_string(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap(),
|
|
||||||
proto::OpenBufferResponse {
|
|
||||||
buffer: Some(proto::Buffer {
|
|
||||||
id: 101,
|
|
||||||
content: "path/one content".to_string(),
|
|
||||||
history: vec![],
|
|
||||||
selections: vec![],
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
client2
|
|
||||||
.request(
|
|
||||||
client2_conn_id,
|
|
||||||
proto::OpenBuffer {
|
|
||||||
worktree_id: 2,
|
|
||||||
path: "path/two".to_string(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap(),
|
|
||||||
proto::OpenBufferResponse {
|
|
||||||
buffer: Some(proto::Buffer {
|
|
||||||
id: 102,
|
|
||||||
content: "path/two content".to_string(),
|
|
||||||
history: vec![],
|
|
||||||
selections: vec![],
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
client1.disconnect(client1_conn_id).await;
|
|
||||||
client2.disconnect(client1_conn_id).await;
|
|
||||||
|
|
||||||
async fn handle_messages(
|
|
||||||
mut messages: mpsc::Receiver<Box<dyn Any + Sync + Send>>,
|
|
||||||
peer: Arc<Peer>,
|
|
||||||
) -> Result<()> {
|
|
||||||
while let Some(envelope) = messages.next().await {
|
|
||||||
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
|
|
||||||
let receipt = envelope.receipt();
|
|
||||||
peer.respond(
|
|
||||||
receipt,
|
|
||||||
proto::Pong {
|
|
||||||
id: envelope.payload.id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?
|
|
||||||
} else if let Some(envelope) =
|
|
||||||
envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
|
|
||||||
{
|
|
||||||
let message = &envelope.payload;
|
|
||||||
let receipt = envelope.receipt();
|
|
||||||
let response = match message.path.as_str() {
|
|
||||||
"path/one" => {
|
|
||||||
assert_eq!(message.worktree_id, 1);
|
|
||||||
proto::OpenBufferResponse {
|
|
||||||
buffer: Some(proto::Buffer {
|
|
||||||
id: 101,
|
|
||||||
content: "path/one content".to_string(),
|
|
||||||
history: vec![],
|
|
||||||
selections: vec![],
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"path/two" => {
|
|
||||||
assert_eq!(message.worktree_id, 2);
|
|
||||||
proto::OpenBufferResponse {
|
|
||||||
buffer: Some(proto::Buffer {
|
|
||||||
id: 102,
|
|
||||||
content: "path/two content".to_string(),
|
|
||||||
history: vec![],
|
|
||||||
selections: vec![],
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
panic!("unexpected path {}", message.path);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
peer.respond(receipt, response).await?
|
|
||||||
} else {
|
|
||||||
panic!("unknown message type");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_disconnect() {
|
|
||||||
smol::block_on(async move {
|
|
||||||
let (client_conn, mut server_conn) = test::Channel::bidirectional();
|
|
||||||
|
|
||||||
let client = Peer::new();
|
|
||||||
let (connection_id, io_handler, mut incoming) =
|
|
||||||
client.add_connection(client_conn).await;
|
|
||||||
|
|
||||||
let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
|
|
||||||
smol::spawn(async move {
|
|
||||||
io_handler.await.ok();
|
|
||||||
io_ended_tx.send(()).await.unwrap();
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
|
|
||||||
let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
|
|
||||||
smol::spawn(async move {
|
|
||||||
incoming.next().await;
|
|
||||||
messages_ended_tx.send(()).await.unwrap();
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
|
|
||||||
client.disconnect(connection_id).await;
|
|
||||||
|
|
||||||
io_ended_rx.recv().await;
|
|
||||||
messages_ended_rx.recv().await;
|
|
||||||
assert!(
|
|
||||||
futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
|
|
||||||
.await
|
|
||||||
.is_err()
|
|
||||||
);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_io_error() {
|
|
||||||
smol::block_on(async move {
|
|
||||||
let (client_conn, server_conn) = test::Channel::bidirectional();
|
|
||||||
drop(server_conn);
|
|
||||||
|
|
||||||
let client = Peer::new();
|
|
||||||
let (connection_id, io_handler, mut incoming) =
|
|
||||||
client.add_connection(client_conn).await;
|
|
||||||
smol::spawn(io_handler).detach();
|
|
||||||
smol::spawn(async move { incoming.next().await }).detach();
|
|
||||||
|
|
||||||
let err = client
|
|
||||||
.request(
|
|
||||||
connection_id,
|
|
||||||
proto::Auth {
|
|
||||||
user_id: 42,
|
|
||||||
access_token: "token".to_string(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap_err();
|
|
||||||
assert_eq!(err.to_string(), "connection was closed");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Add table
Add a link
Reference in a new issue