Consolidate logic for protobuf message handling between ssh and web socket clients (#17185)
This is a refactor to prepare for adding LSP support in SSH remote projects. Release Notes: - N/A --------- Co-authored-by: Mikayla <mikayla@zed.dev> Co-authored-by: Conrad <conrad@zed.dev>
This commit is contained in:
parent
144793bf16
commit
b8e6098f60
20 changed files with 1002 additions and 963 deletions
|
@ -14,22 +14,18 @@ use async_tungstenite::tungstenite::{
|
|||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use clock::SystemClock;
|
||||
use collections::HashMap;
|
||||
use futures::{
|
||||
channel::oneshot,
|
||||
future::{BoxFuture, LocalBoxFuture},
|
||||
AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
|
||||
};
|
||||
use gpui::{
|
||||
actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel,
|
||||
channel::oneshot, future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
|
||||
TryFutureExt as _, TryStreamExt,
|
||||
};
|
||||
use gpui::{actions, AppContext, AsyncAppContext, Global, Model, Task, WeakModel};
|
||||
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
|
||||
use parking_lot::RwLock;
|
||||
use postage::watch;
|
||||
use proto::ProtoClient;
|
||||
use proto::{AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet};
|
||||
use rand::prelude::*;
|
||||
use release_channel::{AppVersion, ReleaseChannel};
|
||||
use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
|
||||
use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
|
@ -208,6 +204,7 @@ pub struct Client {
|
|||
telemetry: Arc<Telemetry>,
|
||||
credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
|
||||
state: RwLock<ClientState>,
|
||||
handler_set: parking_lot::Mutex<ProtoMessageHandlerSet>,
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
|
@ -304,30 +301,7 @@ impl Status {
|
|||
struct ClientState {
|
||||
credentials: Option<Credentials>,
|
||||
status: (watch::Sender<Status>, watch::Receiver<Status>),
|
||||
entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
|
||||
_reconnect_task: Option<Task<()>>,
|
||||
entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
|
||||
models_by_message_type: HashMap<TypeId, AnyWeakModel>,
|
||||
entity_types_by_message_type: HashMap<TypeId, TypeId>,
|
||||
#[allow(clippy::type_complexity)]
|
||||
message_handlers: HashMap<
|
||||
TypeId,
|
||||
Arc<
|
||||
dyn Send
|
||||
+ Sync
|
||||
+ Fn(
|
||||
AnyModel,
|
||||
Box<dyn AnyTypedEnvelope>,
|
||||
&Arc<Client>,
|
||||
AsyncAppContext,
|
||||
) -> LocalBoxFuture<'static, Result<()>>,
|
||||
>,
|
||||
>,
|
||||
}
|
||||
|
||||
enum WeakSubscriber {
|
||||
Entity { handle: AnyWeakModel },
|
||||
Pending(Vec<Box<dyn AnyTypedEnvelope>>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
|
@ -379,12 +353,7 @@ impl Default for ClientState {
|
|||
Self {
|
||||
credentials: None,
|
||||
status: watch::channel_with(Status::SignedOut),
|
||||
entity_id_extractors: Default::default(),
|
||||
_reconnect_task: None,
|
||||
models_by_message_type: Default::default(),
|
||||
entities_by_type_and_remote_id: Default::default(),
|
||||
entity_types_by_message_type: Default::default(),
|
||||
message_handlers: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -405,13 +374,13 @@ impl Drop for Subscription {
|
|||
match self {
|
||||
Subscription::Entity { client, id } => {
|
||||
if let Some(client) = client.upgrade() {
|
||||
let mut state = client.state.write();
|
||||
let mut state = client.handler_set.lock();
|
||||
let _ = state.entities_by_type_and_remote_id.remove(id);
|
||||
}
|
||||
}
|
||||
Subscription::Message { client, id } => {
|
||||
if let Some(client) = client.upgrade() {
|
||||
let mut state = client.state.write();
|
||||
let mut state = client.handler_set.lock();
|
||||
let _ = state.entity_types_by_message_type.remove(id);
|
||||
let _ = state.message_handlers.remove(id);
|
||||
}
|
||||
|
@ -430,21 +399,21 @@ pub struct PendingEntitySubscription<T: 'static> {
|
|||
impl<T: 'static> PendingEntitySubscription<T> {
|
||||
pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
|
||||
self.consumed = true;
|
||||
let mut state = self.client.state.write();
|
||||
let mut handlers = self.client.handler_set.lock();
|
||||
let id = (TypeId::of::<T>(), self.remote_id);
|
||||
let Some(WeakSubscriber::Pending(messages)) =
|
||||
state.entities_by_type_and_remote_id.remove(&id)
|
||||
let Some(EntityMessageSubscriber::Pending(messages)) =
|
||||
handlers.entities_by_type_and_remote_id.remove(&id)
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
state.entities_by_type_and_remote_id.insert(
|
||||
handlers.entities_by_type_and_remote_id.insert(
|
||||
id,
|
||||
WeakSubscriber::Entity {
|
||||
EntityMessageSubscriber::Entity {
|
||||
handle: model.downgrade().into(),
|
||||
},
|
||||
);
|
||||
drop(state);
|
||||
drop(handlers);
|
||||
for message in messages {
|
||||
let client_id = self.client.id();
|
||||
let type_name = message.payload_type_name();
|
||||
|
@ -467,8 +436,8 @@ impl<T: 'static> PendingEntitySubscription<T> {
|
|||
impl<T: 'static> Drop for PendingEntitySubscription<T> {
|
||||
fn drop(&mut self) {
|
||||
if !self.consumed {
|
||||
let mut state = self.client.state.write();
|
||||
if let Some(WeakSubscriber::Pending(messages)) = state
|
||||
let mut state = self.client.handler_set.lock();
|
||||
if let Some(EntityMessageSubscriber::Pending(messages)) = state
|
||||
.entities_by_type_and_remote_id
|
||||
.remove(&(TypeId::of::<T>(), self.remote_id))
|
||||
{
|
||||
|
@ -549,6 +518,7 @@ impl Client {
|
|||
http,
|
||||
credentials_provider,
|
||||
state: Default::default(),
|
||||
handler_set: Default::default(),
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
authenticate: Default::default(),
|
||||
|
@ -592,10 +562,7 @@ impl Client {
|
|||
pub fn teardown(&self) {
|
||||
let mut state = self.state.write();
|
||||
state._reconnect_task.take();
|
||||
state.message_handlers.clear();
|
||||
state.models_by_message_type.clear();
|
||||
state.entities_by_type_and_remote_id.clear();
|
||||
state.entity_id_extractors.clear();
|
||||
self.handler_set.lock().clear();
|
||||
self.peer.teardown();
|
||||
}
|
||||
|
||||
|
@ -708,14 +675,14 @@ impl Client {
|
|||
{
|
||||
let id = (TypeId::of::<T>(), remote_id);
|
||||
|
||||
let mut state = self.state.write();
|
||||
let mut state = self.handler_set.lock();
|
||||
if state.entities_by_type_and_remote_id.contains_key(&id) {
|
||||
return Err(anyhow!("already subscribed to entity"));
|
||||
}
|
||||
|
||||
state
|
||||
.entities_by_type_and_remote_id
|
||||
.insert(id, WeakSubscriber::Pending(Default::default()));
|
||||
.insert(id, EntityMessageSubscriber::Pending(Default::default()));
|
||||
|
||||
Ok(PendingEntitySubscription {
|
||||
client: self.clone(),
|
||||
|
@ -752,13 +719,13 @@ impl Client {
|
|||
E: 'static,
|
||||
H: 'static
|
||||
+ Sync
|
||||
+ Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F
|
||||
+ Fn(Model<E>, TypedEnvelope<M>, AnyProtoClient, AsyncAppContext) -> F
|
||||
+ Send
|
||||
+ Sync,
|
||||
F: 'static + Future<Output = Result<()>>,
|
||||
{
|
||||
let message_type_id = TypeId::of::<M>();
|
||||
let mut state = self.state.write();
|
||||
let mut state = self.handler_set.lock();
|
||||
state
|
||||
.models_by_message_type
|
||||
.insert(message_type_id, entity.into());
|
||||
|
@ -803,85 +770,18 @@ impl Client {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
|
||||
where
|
||||
M: EntityMessage,
|
||||
E: 'static,
|
||||
H: 'static + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
|
||||
F: 'static + Future<Output = Result<()>>,
|
||||
{
|
||||
self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, _, cx| {
|
||||
handler(subscriber.downcast::<E>().unwrap(), message, cx)
|
||||
})
|
||||
}
|
||||
|
||||
fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
|
||||
where
|
||||
M: EntityMessage,
|
||||
E: 'static,
|
||||
H: 'static + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F + Send + Sync,
|
||||
F: 'static + Future<Output = Result<()>>,
|
||||
{
|
||||
let model_type_id = TypeId::of::<E>();
|
||||
let message_type_id = TypeId::of::<M>();
|
||||
|
||||
let mut state = self.state.write();
|
||||
state
|
||||
.entity_types_by_message_type
|
||||
.insert(message_type_id, model_type_id);
|
||||
state
|
||||
.entity_id_extractors
|
||||
.entry(message_type_id)
|
||||
.or_insert_with(|| {
|
||||
|envelope| {
|
||||
envelope
|
||||
.as_any()
|
||||
.downcast_ref::<TypedEnvelope<M>>()
|
||||
.unwrap()
|
||||
.payload
|
||||
.remote_entity_id()
|
||||
}
|
||||
});
|
||||
let prev_handler = state.message_handlers.insert(
|
||||
message_type_id,
|
||||
Arc::new(move |handle, envelope, client, cx| {
|
||||
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
|
||||
handler(handle, *envelope, client.clone(), cx).boxed_local()
|
||||
}),
|
||||
);
|
||||
if prev_handler.is_some() {
|
||||
panic!("registered handler for the same message twice");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_model_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
|
||||
where
|
||||
M: EntityMessage + RequestMessage,
|
||||
E: 'static,
|
||||
H: 'static + Fn(Model<E>, TypedEnvelope<M>, AsyncAppContext) -> F + Send + Sync,
|
||||
F: 'static + Future<Output = Result<M::Response>>,
|
||||
{
|
||||
self.add_entity_message_handler::<M, E, _, _>(move |entity, envelope, client, cx| {
|
||||
Self::respond_to_request::<M, _>(
|
||||
envelope.receipt(),
|
||||
handler(entity.downcast::<E>().unwrap(), envelope, cx),
|
||||
client,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
|
||||
receipt: Receipt<T>,
|
||||
response: F,
|
||||
client: Arc<Self>,
|
||||
client: AnyProtoClient,
|
||||
) -> Result<()> {
|
||||
match response.await {
|
||||
Ok(response) => {
|
||||
client.respond(receipt, response)?;
|
||||
client.send_response(receipt.message_id, response)?;
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
client.respond_with_error(receipt, error.to_proto())?;
|
||||
client.send_response(receipt.message_id, error.to_proto())?;
|
||||
Err(error)
|
||||
}
|
||||
}
|
||||
|
@ -1541,16 +1441,6 @@ impl Client {
|
|||
self.peer.send(self.connection_id()?, message)
|
||||
}
|
||||
|
||||
pub fn send_dynamic(
|
||||
&self,
|
||||
envelope: proto::Envelope,
|
||||
message_type: &'static str,
|
||||
) -> Result<()> {
|
||||
log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
|
||||
let connection_id = self.connection_id()?;
|
||||
self.peer.send_dynamic(connection_id, envelope)
|
||||
}
|
||||
|
||||
pub fn request<T: RequestMessage>(
|
||||
&self,
|
||||
request: T,
|
||||
|
@ -1632,115 +1522,56 @@ impl Client {
|
|||
}
|
||||
}
|
||||
|
||||
fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
|
||||
log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
|
||||
self.peer.respond(receipt, response)
|
||||
}
|
||||
|
||||
fn respond_with_error<T: RequestMessage>(
|
||||
&self,
|
||||
receipt: Receipt<T>,
|
||||
error: proto::Error,
|
||||
) -> Result<()> {
|
||||
log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
|
||||
self.peer.respond_with_error(receipt, error)
|
||||
}
|
||||
|
||||
fn handle_message(
|
||||
self: &Arc<Client>,
|
||||
message: Box<dyn AnyTypedEnvelope>,
|
||||
cx: &AsyncAppContext,
|
||||
) {
|
||||
let mut state = self.state.write();
|
||||
let sender_id = message.sender_id();
|
||||
let request_id = message.message_id();
|
||||
let type_name = message.payload_type_name();
|
||||
let payload_type_id = message.payload_type_id();
|
||||
let sender_id = message.original_sender_id();
|
||||
let original_sender_id = message.original_sender_id();
|
||||
|
||||
let mut subscriber = None;
|
||||
|
||||
if let Some(handle) = state
|
||||
.models_by_message_type
|
||||
.get(&payload_type_id)
|
||||
.and_then(|handle| handle.upgrade())
|
||||
{
|
||||
subscriber = Some(handle);
|
||||
} else if let Some((extract_entity_id, entity_type_id)) =
|
||||
state.entity_id_extractors.get(&payload_type_id).zip(
|
||||
state
|
||||
.entity_types_by_message_type
|
||||
.get(&payload_type_id)
|
||||
.copied(),
|
||||
)
|
||||
{
|
||||
let entity_id = (extract_entity_id)(message.as_ref());
|
||||
|
||||
match state
|
||||
.entities_by_type_and_remote_id
|
||||
.get_mut(&(entity_type_id, entity_id))
|
||||
{
|
||||
Some(WeakSubscriber::Pending(pending)) => {
|
||||
pending.push(message);
|
||||
return;
|
||||
}
|
||||
Some(weak_subscriber) => match weak_subscriber {
|
||||
WeakSubscriber::Entity { handle } => {
|
||||
subscriber = handle.upgrade();
|
||||
}
|
||||
|
||||
WeakSubscriber::Pending(_) => {}
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let subscriber = if let Some(subscriber) = subscriber {
|
||||
subscriber
|
||||
} else {
|
||||
log::info!("unhandled message {}", type_name);
|
||||
self.peer.respond_with_unhandled_message(message).log_err();
|
||||
return;
|
||||
};
|
||||
|
||||
let handler = state.message_handlers.get(&payload_type_id).cloned();
|
||||
// Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
|
||||
// It also ensures we don't hold the lock while yielding back to the executor, as
|
||||
// that might cause the executor thread driving this future to block indefinitely.
|
||||
drop(state);
|
||||
|
||||
if let Some(handler) = handler {
|
||||
let future = handler(subscriber, message, self, cx.clone());
|
||||
if let Some(future) = ProtoMessageHandlerSet::handle_message(
|
||||
&self.handler_set,
|
||||
message,
|
||||
self.clone().into(),
|
||||
cx.clone(),
|
||||
) {
|
||||
let client_id = self.id();
|
||||
log::debug!(
|
||||
"rpc message received. client_id:{}, sender_id:{:?}, type:{}",
|
||||
client_id,
|
||||
sender_id,
|
||||
original_sender_id,
|
||||
type_name
|
||||
);
|
||||
cx.spawn(move |_| async move {
|
||||
match future.await {
|
||||
Ok(()) => {
|
||||
log::debug!(
|
||||
"rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
|
||||
client_id,
|
||||
sender_id,
|
||||
type_name
|
||||
);
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!(
|
||||
"error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
|
||||
client_id,
|
||||
sender_id,
|
||||
type_name,
|
||||
error
|
||||
);
|
||||
}
|
||||
match future.await {
|
||||
Ok(()) => {
|
||||
log::debug!(
|
||||
"rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
|
||||
client_id,
|
||||
original_sender_id,
|
||||
type_name
|
||||
);
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
Err(error) => {
|
||||
log::error!(
|
||||
"error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
|
||||
client_id,
|
||||
original_sender_id,
|
||||
type_name,
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
} else {
|
||||
log::info!("unhandled message {}", type_name);
|
||||
self.peer.respond_with_unhandled_message(message).log_err();
|
||||
self.peer
|
||||
.respond_with_unhandled_message(sender_id.into(), request_id, type_name)
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1759,7 +1590,23 @@ impl ProtoClient for Client {
|
|||
}
|
||||
|
||||
fn send(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
|
||||
self.send_dynamic(envelope, message_type)
|
||||
log::debug!("rpc send. client_id:{}, name:{}", self.id(), message_type);
|
||||
let connection_id = self.connection_id()?;
|
||||
self.peer.send_dynamic(connection_id, envelope)
|
||||
}
|
||||
|
||||
fn send_response(&self, envelope: proto::Envelope, message_type: &'static str) -> Result<()> {
|
||||
log::debug!(
|
||||
"rpc respond. client_id:{}, name:{}",
|
||||
self.id(),
|
||||
message_type
|
||||
);
|
||||
let connection_id = self.connection_id()?;
|
||||
self.peer.send_dynamic(connection_id, envelope)
|
||||
}
|
||||
|
||||
fn message_handler_set(&self) -> &parking_lot::Mutex<ProtoMessageHandlerSet> {
|
||||
&self.handler_set
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2103,7 +1950,7 @@ mod tests {
|
|||
|
||||
let (done_tx1, mut done_rx1) = smol::channel::unbounded();
|
||||
let (done_tx2, mut done_rx2) = smol::channel::unbounded();
|
||||
client.add_model_message_handler(
|
||||
AnyProtoClient::from(client.clone()).add_model_message_handler(
|
||||
move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, mut cx| {
|
||||
match model.update(&mut cx, |model, _| model.id).unwrap() {
|
||||
1 => done_tx1.try_send(()).unwrap(),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue