Remove remaining instances of router

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
Co-Authored-By: Max Brunsfeld <max@zed.dev>
This commit is contained in:
Antonio Scandurra 2021-08-19 19:38:17 +02:00
parent d398b96f56
commit 5338b30c00
12 changed files with 241 additions and 171 deletions

View file

@ -29,7 +29,7 @@ use tide::{
use time::OffsetDateTime; use time::OffsetDateTime;
use zrpc::{ use zrpc::{
auth::random_token, auth::random_token,
proto::{self, EnvelopedMessage}, proto::{self, AnyTypedEnvelope, EnvelopedMessage},
ConnectionId, Peer, TypedEnvelope, ConnectionId, Peer, TypedEnvelope,
}; };
@ -38,16 +38,12 @@ type ReplicaId = u16;
type MessageHandler = Box< type MessageHandler = Box<
dyn Send dyn Send
+ Sync + Sync
+ Fn( + Fn(Box<dyn AnyTypedEnvelope>, Arc<Server>) -> BoxFuture<'static, tide::Result<()>>,
&mut Option<Box<dyn Any + Send + Sync>>,
Arc<Server>,
) -> Option<BoxFuture<'static, tide::Result<()>>>,
>; >;
#[derive(Default)] #[derive(Default)]
struct ServerBuilder { struct ServerBuilder {
handlers: Vec<MessageHandler>, handlers: HashMap<TypeId, MessageHandler>,
handler_types: HashSet<TypeId>,
} }
impl ServerBuilder { impl ServerBuilder {
@ -57,24 +53,17 @@ impl ServerBuilder {
Fut: 'static + Send + Future<Output = tide::Result<()>>, Fut: 'static + Send + Future<Output = tide::Result<()>>,
M: EnvelopedMessage, M: EnvelopedMessage,
{ {
if self.handler_types.insert(TypeId::of::<M>()) { let prev_handler = self.handlers.insert(
TypeId::of::<M>(),
Box::new(move |envelope, server| {
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
(handler)(envelope, server).boxed()
}),
);
if prev_handler.is_some() {
panic!("registered a handler for the same message twice"); panic!("registered a handler for the same message twice");
} }
self.handlers
.push(Box::new(move |untyped_envelope, server| {
if let Some(typed_envelope) = untyped_envelope.take() {
match typed_envelope.downcast::<TypedEnvelope<M>>() {
Ok(typed_envelope) => Some((handler)(typed_envelope, server).boxed()),
Err(envelope) => {
*untyped_envelope = Some(envelope);
None
}
}
} else {
None
}
}));
self self
} }
@ -90,16 +79,17 @@ impl ServerBuilder {
pub struct Server { pub struct Server {
rpc: Arc<Peer>, rpc: Arc<Peer>,
state: Arc<AppState>, state: Arc<AppState>,
handlers: Vec<MessageHandler>, handlers: HashMap<TypeId, MessageHandler>,
} }
impl Server { impl Server {
pub async fn handle_connection<Conn>( pub fn handle_connection<Conn>(
self: &Arc<Self>, self: &Arc<Self>,
connection: Conn, connection: Conn,
addr: String, addr: String,
user_id: UserId, user_id: UserId,
) where ) -> impl Future<Output = ()>
where
Conn: 'static Conn: 'static
+ futures::Sink<WebSocketMessage, Error = WebSocketError> + futures::Sink<WebSocketMessage, Error = WebSocketError>
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
@ -107,7 +97,9 @@ impl Server {
+ Unpin, + Unpin,
{ {
let this = self.clone(); let this = self.clone();
let (connection_id, handle_io, mut incoming_rx) = this.rpc.add_connection(connection).await; async move {
let (connection_id, handle_io, mut incoming_rx) =
this.rpc.add_connection(connection).await;
this.state this.state
.rpc .rpc
.write() .write()
@ -123,21 +115,15 @@ impl Server {
message = next_message => { message = next_message => {
if let Some(message) = message { if let Some(message) = message {
let start_time = Instant::now(); let start_time = Instant::now();
log::info!("RPC message received"); log::info!("RPC message received: {}", message.payload_type_name());
let mut message = Some(message); if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
for handler in &this.handlers { if let Err(err) = (handler)(message, this.clone()).await {
if let Some(future) = (handler)(&mut message, this.clone()) {
if let Err(err) = future.await {
log::error!("error handling message: {:?}", err); log::error!("error handling message: {:?}", err);
} else { } else {
log::info!("RPC message handled. duration:{:?}", start_time.elapsed()); log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
} }
break; } else {
} log::warn!("unhandled message: {}", message.payload_type_name());
}
if let Some(message) = message {
log::warn!("unhandled message: {:?}", message);
} }
} else { } else {
log::info!("rpc connection closed {:?}", addr); log::info!("rpc connection closed {:?}", addr);
@ -158,6 +144,7 @@ impl Server {
} }
} }
} }
}
#[derive(Default)] #[derive(Default)]
pub struct State { pub struct State {

View file

@ -1,9 +1,7 @@
use crate::{ use crate::{
auth, auth,
db::{self, UserId}, db::{self, UserId},
github, github, rpc, AppState, Config,
rpc::{self, build_server},
AppState, Config,
}; };
use async_std::task; use async_std::task;
use gpui::TestAppContext; use gpui::TestAppContext;
@ -28,6 +26,8 @@ use zrpc::Peer;
#[gpui::test] #[gpui::test]
async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
tide::log::start();
let (window_b, _) = cx_b.add_window(|_| EmptyView); let (window_b, _) = cx_b.add_window(|_| EmptyView);
let settings = settings::channel(&cx_b.font_cache()).unwrap().1; let settings = settings::channel(&cx_b.font_cache()).unwrap().1;
let lang_registry = Arc::new(LanguageRegistry::new()); let lang_registry = Arc::new(LanguageRegistry::new());
@ -514,9 +514,9 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
.await .await
.unwrap(); .unwrap();
let channels_a = client_a.get_channels().await; // let channels_a = client_a.get_channels().await;
assert_eq!(channels_a.len(), 1); // assert_eq!(channels_a.len(), 1);
assert_eq!(channels_a[0].read(&cx_a).name(), "test-channel"); // assert_eq!(channels_a[0].read(&cx_a).name(), "test-channel");
// assert_eq!( // assert_eq!(
// db.get_recent_channel_messages(channel_id, 50) // db.get_recent_channel_messages(channel_id, 50)
@ -530,8 +530,8 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
struct TestServer { struct TestServer {
peer: Arc<Peer>, peer: Arc<Peer>,
app_state: Arc<AppState>, app_state: Arc<AppState>,
server: Arc<rpc::Server>,
db_name: String, db_name: String,
router: Arc<Router>,
} }
impl TestServer { impl TestServer {
@ -540,36 +540,27 @@ impl TestServer {
let db_name = format!("zed-test-{}", rng.gen::<u128>()); let db_name = format!("zed-test-{}", rng.gen::<u128>());
let app_state = Self::build_app_state(&db_name).await; let app_state = Self::build_app_state(&db_name).await;
let peer = Peer::new(); let peer = Peer::new();
let mut router = Router::new(); let server = rpc::build_server(&app_state, &peer);
build_server(&mut router, &app_state, &peer);
Self { Self {
peer, peer,
router: Arc::new(router),
app_state, app_state,
server,
db_name, db_name,
} }
} }
async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> (UserId, Client) { async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> (UserId, Client) {
let user_id = self.app_state.db.create_user(name, false).await.unwrap(); let user_id = self.app_state.db.create_user(name, false).await.unwrap();
let lang_registry = Arc::new(LanguageRegistry::new()); let client = Client::new();
let client = Client::new(lang_registry.clone());
let mut client_router = ForegroundRouter::new();
cx.update(|cx| zed::worktree::init(cx, &client, &mut client_router));
let (client_conn, server_conn) = Channel::bidirectional(); let (client_conn, server_conn) = Channel::bidirectional();
cx.background() cx.background()
.spawn(rpc::handle_connection( .spawn(
self.peer.clone(), self.server
self.router.clone(), .handle_connection(server_conn, name.to_string(), user_id),
self.app_state.clone(), )
name.to_string(),
server_conn,
user_id,
))
.detach(); .detach();
client client
.add_connection(client_conn, Arc::new(client_router), cx.to_async()) .add_connection(client_conn, cx.to_async())
.await .await
.unwrap(); .unwrap();

View file

@ -1,6 +1,6 @@
use crate::rpc::{self, Client}; use crate::rpc::{self, Client};
use anyhow::Result; use anyhow::Result;
use gpui::{Entity, ModelContext, Task, WeakModelHandle}; use gpui::{Entity, ModelContext, WeakModelHandle};
use std::{ use std::{
collections::{HashMap, VecDeque}, collections::{HashMap, VecDeque},
sync::Arc, sync::Arc,
@ -22,7 +22,7 @@ pub struct Channel {
first_message_id: Option<u64>, first_message_id: Option<u64>,
messages: Option<VecDeque<ChannelMessage>>, messages: Option<VecDeque<ChannelMessage>>,
rpc: Arc<Client>, rpc: Arc<Client>,
_message_handler: Task<()>, _subscription: rpc::Subscription,
} }
pub struct ChannelMessage { pub struct ChannelMessage {
@ -50,20 +50,20 @@ impl Entity for Channel {
impl Channel { impl Channel {
pub fn new(details: ChannelDetails, rpc: Arc<Client>, cx: &mut ModelContext<Self>) -> Self { pub fn new(details: ChannelDetails, rpc: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
let _message_handler = rpc.subscribe_from_model(details.id, cx, Self::handle_message_sent); let _subscription = rpc.subscribe_from_model(details.id, cx, Self::handle_message_sent);
Self { Self {
details, details,
rpc, rpc,
first_message_id: None, first_message_id: None,
messages: None, messages: None,
_message_handler, _subscription,
} }
} }
fn handle_message_sent( fn handle_message_sent(
&mut self, &mut self,
message: &TypedEnvelope<ChannelMessageSent>, message: TypedEnvelope<ChannelMessageSent>,
rpc: rpc::Client, rpc: rpc::Client,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Result<()> { ) -> Result<()> {

View file

@ -13,7 +13,6 @@ use zed::{
workspace::{self, OpenParams}, workspace::{self, OpenParams},
AppState, AppState,
}; };
use zrpc::ForegroundRouter;
fn main() { fn main() {
init_logger(); init_logger();
@ -31,8 +30,7 @@ fn main() {
settings_tx: Arc::new(Mutex::new(settings_tx)), settings_tx: Arc::new(Mutex::new(settings_tx)),
settings, settings,
themes, themes,
rpc_router: Arc::new(ForegroundRouter::new()), rpc: rpc::Client::new(),
rpc: rpc::Client::new(languages),
fs: Arc::new(RealFs), fs: Arc::new(RealFs),
}; };

View file

@ -19,13 +19,13 @@ pub fn menus(state: &Arc<AppState>) -> Vec<Menu<'static>> {
name: "Share", name: "Share",
keystroke: None, keystroke: None,
action: "workspace:share_worktree", action: "workspace:share_worktree",
arg: Some(Box::new(state.clone())), arg: None,
}, },
MenuItem::Action { MenuItem::Action {
name: "Join", name: "Join",
keystroke: None, keystroke: None,
action: "workspace:join_worktree", action: "workspace:join_worktree",
arg: Some(Box::new(state.clone())), arg: None,
}, },
MenuItem::Action { MenuItem::Action {
name: "Quit", name: "Quit",

View file

@ -1,15 +1,17 @@
use crate::language::LanguageRegistry;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use async_tungstenite::tungstenite::http::Request; use async_tungstenite::tungstenite::http::Request;
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::StreamExt;
use gpui::{AsyncAppContext, Entity, ModelContext, Task}; use gpui::{AsyncAppContext, Entity, ModelContext, Task};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use smol::lock::RwLock; use parking_lot::RwLock;
use std::time::Duration; use postage::prelude::Stream;
use std::any::TypeId;
use std::collections::HashMap;
use std::sync::Weak;
use std::time::{Duration, Instant};
use std::{convert::TryFrom, future::Future, sync::Arc}; use std::{convert::TryFrom, future::Future, sync::Arc};
use surf::Url; use surf::Url;
use zrpc::proto::EntityMessage; use zrpc::proto::{AnyTypedEnvelope, EntityMessage};
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope}; pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
use zrpc::{ use zrpc::{
proto::{EnvelopedMessage, RequestMessage}, proto::{EnvelopedMessage, RequestMessage},
@ -24,22 +26,37 @@ lazy_static! {
#[derive(Clone)] #[derive(Clone)]
pub struct Client { pub struct Client {
peer: Arc<Peer>, peer: Arc<Peer>,
pub state: Arc<RwLock<ClientState>>, state: Arc<RwLock<ClientState>>,
} }
#[derive(Default)]
pub struct ClientState { pub struct ClientState {
connection_id: Option<ConnectionId>, connection_id: Option<ConnectionId>,
pub languages: Arc<LanguageRegistry>, entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
model_handlers: HashMap<
(TypeId, u64),
Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
>,
}
pub struct Subscription {
state: Weak<RwLock<ClientState>>,
id: (TypeId, u64),
}
impl Drop for Subscription {
fn drop(&mut self) {
if let Some(state) = self.state.upgrade() {
let _ = state.write().model_handlers.remove(&self.id).unwrap();
}
}
} }
impl Client { impl Client {
pub fn new(languages: Arc<LanguageRegistry>) -> Self { pub fn new() -> Self {
Self { Self {
peer: Peer::new(), peer: Peer::new(),
state: Arc::new(RwLock::new(ClientState { state: Default::default(),
connection_id: None,
languages,
})),
} }
} }
@ -48,31 +65,56 @@ impl Client {
remote_id: u64, remote_id: u64,
cx: &mut ModelContext<M>, cx: &mut ModelContext<M>,
mut handler: F, mut handler: F,
) -> Task<()> ) -> Subscription
where where
T: EntityMessage, T: EntityMessage,
M: Entity, M: Entity,
F: 'static + FnMut(&mut M, &TypedEnvelope<T>, Client, &mut ModelContext<M>) -> Result<()>, F: 'static
+ Send
+ Sync
+ FnMut(&mut M, TypedEnvelope<T>, Client, &mut ModelContext<M>) -> Result<()>,
{ {
let rpc = self.clone(); let subscription_id = (TypeId::of::<T>(), remote_id);
let mut incoming = self.peer.subscribe::<T>(); let client = self.clone();
cx.spawn_weak(|model, mut cx| async move { let mut state = self.state.write();
while let Some(envelope) = incoming.next().await { let model = cx.handle().downgrade();
if envelope.payload.remote_entity_id() == remote_id { state
if let Some(model) = model.upgrade(&cx) { .entity_id_extractors
model.update(&mut cx, |model, cx| { .entry(subscription_id.0)
if let Err(error) = handler(model, &envelope, rpc.clone(), cx) { .or_insert_with(|| {
Box::new(|envelope| {
let envelope = envelope
.as_any()
.downcast_ref::<TypedEnvelope<T>>()
.unwrap();
envelope.payload.remote_entity_id()
})
});
let prev_handler = state.model_handlers.insert(
subscription_id,
Box::new(move |envelope, cx| {
if let Some(model) = model.upgrade(cx) {
let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
model.update(cx, |model, cx| {
if let Err(error) = handler(model, *envelope, client.clone(), cx) {
log::error!("error handling message: {}", error) log::error!("error handling message: {}", error)
} }
}); });
} }
}),
);
if prev_handler.is_some() {
panic!("registered a handler for the same entity twice")
} }
Subscription {
state: Arc::downgrade(&self.state),
id: subscription_id,
} }
})
} }
pub async fn log_in_and_connect(&self, cx: AsyncAppContext) -> surf::Result<()> { pub async fn log_in_and_connect(&self, cx: AsyncAppContext) -> surf::Result<()> {
if self.state.read().await.connection_id.is_some() { if self.state.read().connection_id.is_some() {
return Ok(()); return Ok(());
} }
@ -110,8 +152,39 @@ impl Client {
+ Unpin + Unpin
+ Send, + Send,
{ {
let (connection_id, handle_io, handle_messages) = self.peer.add_connection(conn).await; let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
cx.foreground().spawn(handle_messages).detach(); {
let mut cx = cx.clone();
let state = self.state.clone();
cx.foreground()
.spawn(async move {
while let Some(message) = incoming.recv().await {
let mut state = state.write();
if let Some(extract_entity_id) =
state.entity_id_extractors.get(&message.payload_type_id())
{
let entity_id = (extract_entity_id)(message.as_ref());
if let Some(handler) = state
.model_handlers
.get_mut(&(message.payload_type_id(), entity_id))
{
let start_time = Instant::now();
log::info!("RPC client message {}", message.payload_type_name());
(handler)(message, &mut cx);
log::info!(
"RPC message handled. duration:{:?}",
start_time.elapsed()
);
} else {
log::info!("unhandled message {}", message.payload_type_name());
}
} else {
log::info!("unhandled message {}", message.payload_type_name());
}
}
})
.detach();
}
cx.background() cx.background()
.spawn(async move { .spawn(async move {
if let Err(error) = handle_io.await { if let Err(error) = handle_io.await {
@ -119,7 +192,7 @@ impl Client {
} }
}) })
.detach(); .detach();
self.state.write().await.connection_id = Some(connection_id); self.state.write().connection_id = Some(connection_id);
Ok(()) Ok(())
} }
@ -200,27 +273,24 @@ impl Client {
} }
pub async fn disconnect(&self) -> Result<()> { pub async fn disconnect(&self) -> Result<()> {
let conn_id = self.connection_id().await?; let conn_id = self.connection_id()?;
self.peer.disconnect(conn_id).await; self.peer.disconnect(conn_id).await;
Ok(()) Ok(())
} }
async fn connection_id(&self) -> Result<ConnectionId> { fn connection_id(&self) -> Result<ConnectionId> {
self.state self.state
.read() .read()
.await
.connection_id .connection_id
.ok_or_else(|| anyhow!("not connected")) .ok_or_else(|| anyhow!("not connected"))
} }
pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> { pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
self.peer.send(self.connection_id().await?, message).await self.peer.send(self.connection_id()?, message).await
} }
pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> { pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
self.peer self.peer.request(self.connection_id()?, request).await
.request(self.connection_id().await?, request)
.await
} }
pub fn respond<T: RequestMessage>( pub fn respond<T: RequestMessage>(

View file

@ -162,7 +162,7 @@ pub fn build_app_state(cx: &AppContext) -> Arc<AppState> {
settings, settings,
themes, themes,
languages: languages.clone(), languages: languages.clone(),
rpc: rpc::Client::new(languages), rpc: rpc::Client::new(),
fs: Arc::new(RealFs), fs: Arc::new(RealFs),
}) })
} }

View file

@ -82,16 +82,14 @@ impl<T: Rng> Iterator for RandomCharIter<T> {
} }
} }
pub async fn log_async_errors<F>(f: F) -> impl Future<Output = ()> pub async fn log_async_errors<F>(f: F)
where where
F: Future<Output = anyhow::Result<()>>, F: Future<Output = anyhow::Result<()>>,
{ {
async {
if let Err(error) = f.await { if let Err(error) = f.await {
log::error!("{}", error) log::error!("{}", error)
} }
} }
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {

View file

@ -108,7 +108,7 @@ fn open_new(app_state: &Arc<AppState>, cx: &mut MutableAppContext) {
fn join_worktree(app_state: &Arc<AppState>, cx: &mut MutableAppContext) { fn join_worktree(app_state: &Arc<AppState>, cx: &mut MutableAppContext) {
cx.add_window(|cx| { cx.add_window(|cx| {
let mut view = Workspace::new(app_state.as_ref(), cx); let mut view = Workspace::new(app_state.as_ref(), cx);
view.join_worktree(&app_state, cx); view.join_worktree(&(), cx);
view view
}); });
} }
@ -725,7 +725,7 @@ impl Workspace {
}; };
} }
fn share_worktree(&mut self, app_state: &Arc<AppState>, cx: &mut ViewContext<Self>) { fn share_worktree(&mut self, _: &(), cx: &mut ViewContext<Self>) {
let rpc = self.rpc.clone(); let rpc = self.rpc.clone();
let platform = cx.platform(); let platform = cx.platform();
@ -757,7 +757,7 @@ impl Workspace {
.detach(); .detach();
} }
fn join_worktree(&mut self, app_state: &Arc<AppState>, cx: &mut ViewContext<Self>) { fn join_worktree(&mut self, _: &(), cx: &mut ViewContext<Self>) {
let rpc = self.rpc.clone(); let rpc = self.rpc.clone();
let languages = self.languages.clone(); let languages = self.languages.clone();

View file

@ -213,7 +213,7 @@ impl Worktree {
.detach(); .detach();
} }
let _message_handlers = vec![ let _subscriptions = vec![
rpc.subscribe_from_model(remote_id, cx, Self::handle_add_peer), rpc.subscribe_from_model(remote_id, cx, Self::handle_add_peer),
rpc.subscribe_from_model(remote_id, cx, Self::handle_remove_peer), rpc.subscribe_from_model(remote_id, cx, Self::handle_remove_peer),
rpc.subscribe_from_model(remote_id, cx, Self::handle_update), rpc.subscribe_from_model(remote_id, cx, Self::handle_update),
@ -234,7 +234,7 @@ impl Worktree {
.map(|p| (PeerId(p.peer_id), p.replica_id as ReplicaId)) .map(|p| (PeerId(p.peer_id), p.replica_id as ReplicaId))
.collect(), .collect(),
languages, languages,
_message_handlers, _subscriptions,
}) })
}) })
}); });
@ -282,7 +282,7 @@ impl Worktree {
pub fn handle_add_peer( pub fn handle_add_peer(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::AddPeer>, envelope: TypedEnvelope<proto::AddPeer>,
_: rpc::Client, _: rpc::Client,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Result<()> { ) -> Result<()> {
@ -294,7 +294,7 @@ impl Worktree {
pub fn handle_remove_peer( pub fn handle_remove_peer(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::RemovePeer>, envelope: TypedEnvelope<proto::RemovePeer>,
_: rpc::Client, _: rpc::Client,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Result<()> { ) -> Result<()> {
@ -306,7 +306,7 @@ impl Worktree {
pub fn handle_update( pub fn handle_update(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::UpdateWorktree>, envelope: TypedEnvelope<proto::UpdateWorktree>,
_: rpc::Client, _: rpc::Client,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
@ -317,7 +317,7 @@ impl Worktree {
pub fn handle_open_buffer( pub fn handle_open_buffer(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::OpenBuffer>, envelope: TypedEnvelope<proto::OpenBuffer>,
rpc: rpc::Client, rpc: rpc::Client,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
@ -340,7 +340,7 @@ impl Worktree {
pub fn handle_close_buffer( pub fn handle_close_buffer(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::CloseBuffer>, envelope: TypedEnvelope<proto::CloseBuffer>,
_: rpc::Client, _: rpc::Client,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
@ -396,7 +396,7 @@ impl Worktree {
pub fn handle_update_buffer( pub fn handle_update_buffer(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::UpdateBuffer>, envelope: TypedEnvelope<proto::UpdateBuffer>,
_: rpc::Client, _: rpc::Client,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Result<()> { ) -> Result<()> {
@ -443,7 +443,7 @@ impl Worktree {
pub fn handle_save_buffer( pub fn handle_save_buffer(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::SaveBuffer>, envelope: TypedEnvelope<proto::SaveBuffer>,
rpc: rpc::Client, rpc: rpc::Client,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Result<()> { ) -> Result<()> {
@ -485,7 +485,7 @@ impl Worktree {
pub fn handle_buffer_saved( pub fn handle_buffer_saved(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::BufferSaved>, envelope: TypedEnvelope<proto::BufferSaved>,
_: rpc::Client, _: rpc::Client,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Result<()> { ) -> Result<()> {
@ -791,7 +791,7 @@ impl LocalWorktree {
pub fn open_remote_buffer( pub fn open_remote_buffer(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::OpenBuffer>, envelope: TypedEnvelope<proto::OpenBuffer>,
cx: &mut ModelContext<Worktree>, cx: &mut ModelContext<Worktree>,
) -> Task<Result<proto::OpenBufferResponse>> { ) -> Task<Result<proto::OpenBufferResponse>> {
let peer_id = envelope.original_sender_id(); let peer_id = envelope.original_sender_id();
@ -818,11 +818,12 @@ impl LocalWorktree {
pub fn close_remote_buffer( pub fn close_remote_buffer(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::CloseBuffer>, envelope: TypedEnvelope<proto::CloseBuffer>,
_: &mut ModelContext<Worktree>, cx: &mut ModelContext<Worktree>,
) -> Result<()> { ) -> Result<()> {
if let Some(shared_buffers) = self.shared_buffers.get_mut(&envelope.original_sender_id()?) { if let Some(shared_buffers) = self.shared_buffers.get_mut(&envelope.original_sender_id()?) {
shared_buffers.remove(&envelope.payload.buffer_id); shared_buffers.remove(&envelope.payload.buffer_id);
cx.notify();
} }
Ok(()) Ok(())
@ -830,7 +831,7 @@ impl LocalWorktree {
pub fn add_peer( pub fn add_peer(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::AddPeer>, envelope: TypedEnvelope<proto::AddPeer>,
cx: &mut ModelContext<Worktree>, cx: &mut ModelContext<Worktree>,
) -> Result<()> { ) -> Result<()> {
let peer = envelope let peer = envelope
@ -847,7 +848,7 @@ impl LocalWorktree {
pub fn remove_peer( pub fn remove_peer(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::RemovePeer>, envelope: TypedEnvelope<proto::RemovePeer>,
cx: &mut ModelContext<Worktree>, cx: &mut ModelContext<Worktree>,
) -> Result<()> { ) -> Result<()> {
let peer_id = PeerId(envelope.payload.peer_id); let peer_id = PeerId(envelope.payload.peer_id);
@ -994,7 +995,7 @@ impl LocalWorktree {
.detach(); .detach();
this.update(&mut cx, |worktree, cx| { this.update(&mut cx, |worktree, cx| {
let _message_handlers = vec![ let _subscriptions = vec![
rpc.subscribe_from_model(remote_id, cx, Worktree::handle_add_peer), rpc.subscribe_from_model(remote_id, cx, Worktree::handle_add_peer),
rpc.subscribe_from_model(remote_id, cx, Worktree::handle_remove_peer), rpc.subscribe_from_model(remote_id, cx, Worktree::handle_remove_peer),
rpc.subscribe_from_model(remote_id, cx, Worktree::handle_open_buffer), rpc.subscribe_from_model(remote_id, cx, Worktree::handle_open_buffer),
@ -1008,7 +1009,7 @@ impl LocalWorktree {
rpc, rpc,
remote_id: share_response.worktree_id, remote_id: share_response.worktree_id,
snapshots_tx: snapshots_to_send_tx, snapshots_tx: snapshots_to_send_tx,
_message_handlers, _subscriptions,
}); });
}); });
@ -1068,7 +1069,7 @@ struct ShareState {
rpc: rpc::Client, rpc: rpc::Client,
remote_id: u64, remote_id: u64,
snapshots_tx: Sender<Snapshot>, snapshots_tx: Sender<Snapshot>,
_message_handlers: Vec<Task<()>>, _subscriptions: Vec<rpc::Subscription>,
} }
pub struct RemoteWorktree { pub struct RemoteWorktree {
@ -1081,7 +1082,7 @@ pub struct RemoteWorktree {
open_buffers: HashMap<usize, RemoteBuffer>, open_buffers: HashMap<usize, RemoteBuffer>,
peers: HashMap<PeerId, ReplicaId>, peers: HashMap<PeerId, ReplicaId>,
languages: Arc<LanguageRegistry>, languages: Arc<LanguageRegistry>,
_message_handlers: Vec<Task<()>>, _subscriptions: Vec<rpc::Subscription>,
} }
impl RemoteWorktree { impl RemoteWorktree {
@ -1151,7 +1152,7 @@ impl RemoteWorktree {
fn update_from_remote( fn update_from_remote(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::UpdateWorktree>, envelope: TypedEnvelope<proto::UpdateWorktree>,
cx: &mut ModelContext<Worktree>, cx: &mut ModelContext<Worktree>,
) -> Result<()> { ) -> Result<()> {
let mut tx = self.updates_tx.clone(); let mut tx = self.updates_tx.clone();
@ -1167,7 +1168,7 @@ impl RemoteWorktree {
pub fn add_peer( pub fn add_peer(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::AddPeer>, envelope: TypedEnvelope<proto::AddPeer>,
cx: &mut ModelContext<Worktree>, cx: &mut ModelContext<Worktree>,
) -> Result<()> { ) -> Result<()> {
let peer = envelope let peer = envelope
@ -1183,7 +1184,7 @@ impl RemoteWorktree {
pub fn remove_peer( pub fn remove_peer(
&mut self, &mut self,
envelope: &TypedEnvelope<proto::RemovePeer>, envelope: TypedEnvelope<proto::RemovePeer>,
cx: &mut ModelContext<Worktree>, cx: &mut ModelContext<Worktree>,
) -> Result<()> { ) -> Result<()> {
let peer_id = PeerId(envelope.payload.peer_id); let peer_id = PeerId(envelope.payload.peer_id);
@ -2761,7 +2762,7 @@ mod tests {
replica_id: 1, replica_id: 1,
peers: Vec::new(), peers: Vec::new(),
}, },
rpc::Client::new(Default::default()), rpc::Client::new(),
Default::default(), Default::default(),
&mut cx.to_async(), &mut cx.to_async(),
) )

View file

@ -1,4 +1,4 @@
use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage}; use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock}; use async_lock::{Mutex, RwLock};
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
@ -8,7 +8,6 @@ use postage::{
prelude::{Sink as _, Stream as _}, prelude::{Sink as _, Stream as _},
}; };
use std::{ use std::{
any::Any,
collections::HashMap, collections::HashMap,
fmt, fmt,
future::Future, future::Future,
@ -105,7 +104,7 @@ impl Peer {
) -> ( ) -> (
ConnectionId, ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send, impl Future<Output = anyhow::Result<()>> + Send,
mpsc::Receiver<Box<dyn Any + Sync + Send>>, mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
) )
where where
Conn: futures::Sink<WebSocketMessage, Error = WebSocketError> Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
@ -409,10 +408,11 @@ mod tests {
client2.disconnect(client1_conn_id).await; client2.disconnect(client1_conn_id).await;
async fn handle_messages( async fn handle_messages(
mut messages: mpsc::Receiver<Box<dyn Any + Sync + Send>>, mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
peer: Arc<Peer>, peer: Arc<Peer>,
) -> Result<()> { ) -> Result<()> {
while let Some(envelope) = messages.next().await { while let Some(envelope) = messages.next().await {
let envelope = envelope.into_any();
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() { if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
let receipt = envelope.receipt(); let receipt = envelope.receipt();
peer.respond( peer.respond(

View file

@ -3,7 +3,7 @@ use anyhow::Result;
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{SinkExt as _, StreamExt as _}; use futures::{SinkExt as _, StreamExt as _};
use prost::Message; use prost::Message;
use std::any::Any; use std::any::{Any, TypeId};
use std::{ use std::{
io, io,
time::{Duration, SystemTime, UNIX_EPOCH}, time::{Duration, SystemTime, UNIX_EPOCH},
@ -31,9 +31,34 @@ pub trait RequestMessage: EnvelopedMessage {
type Response: EnvelopedMessage; type Response: EnvelopedMessage;
} }
pub trait AnyTypedEnvelope: 'static + Send + Sync {
fn payload_type_id(&self) -> TypeId;
fn payload_type_name(&self) -> &'static str;
fn as_any(&self) -> &dyn Any;
fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
}
impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
fn payload_type_id(&self) -> TypeId {
TypeId::of::<T>()
}
fn payload_type_name(&self) -> &'static str {
T::NAME
}
fn as_any(&self) -> &dyn Any {
self
}
fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
self
}
}
macro_rules! messages { macro_rules! messages {
($($name:ident),* $(,)?) => { ($($name:ident),* $(,)?) => {
pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn Any + Send + Sync>> { pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
match envelope.payload { match envelope.payload {
$(Some(envelope::Payload::$name(payload)) => { $(Some(envelope::Payload::$name(payload)) => {
Some(Box::new(TypedEnvelope { Some(Box::new(TypedEnvelope {