Send websocket pings from both the client and the server
Remove the client-only logic for sending protobuf pings. Co-Authored-By: Nathan Sobo <nathan@zed.dev> Co-Authored-By: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
c61a1bd659
commit
9017a1363b
9 changed files with 174 additions and 92 deletions
|
@ -94,6 +94,7 @@ pub struct ConnectionState {
|
|||
Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
|
||||
}
|
||||
|
||||
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2);
|
||||
const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
impl Peer {
|
||||
|
@ -104,14 +105,20 @@ impl Peer {
|
|||
})
|
||||
}
|
||||
|
||||
pub async fn add_connection(
|
||||
pub async fn add_connection<F, Fut, Out>(
|
||||
self: &Arc<Self>,
|
||||
connection: Connection,
|
||||
create_timer: F,
|
||||
) -> (
|
||||
ConnectionId,
|
||||
impl Future<Output = anyhow::Result<()>> + Send,
|
||||
BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
|
||||
) {
|
||||
)
|
||||
where
|
||||
F: Send + Fn(Duration) -> Fut,
|
||||
Fut: Send + Future<Output = Out>,
|
||||
Out: Send,
|
||||
{
|
||||
// For outgoing messages, use an unbounded channel so that application code
|
||||
// can always send messages without yielding. For incoming messages, use a
|
||||
// bounded channel so that other peers will receive backpressure if they send
|
||||
|
@ -121,7 +128,7 @@ impl Peer {
|
|||
|
||||
let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
|
||||
let connection_state = ConnectionState {
|
||||
outgoing_tx,
|
||||
outgoing_tx: outgoing_tx.clone(),
|
||||
next_message_id: Default::default(),
|
||||
response_channels: Arc::new(Mutex::new(Some(Default::default()))),
|
||||
};
|
||||
|
@ -131,39 +138,43 @@ impl Peer {
|
|||
let this = self.clone();
|
||||
let response_channels = connection_state.response_channels.clone();
|
||||
let handle_io = async move {
|
||||
let result = 'outer: loop {
|
||||
let _end_connection = util::defer(|| {
|
||||
response_channels.lock().take();
|
||||
this.connections.write().remove(&connection_id);
|
||||
});
|
||||
|
||||
loop {
|
||||
let read_message = reader.read_message().fuse();
|
||||
futures::pin_mut!(read_message);
|
||||
loop {
|
||||
futures::select_biased! {
|
||||
outgoing = outgoing_rx.next().fuse() => match outgoing {
|
||||
Some(outgoing) => {
|
||||
match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
|
||||
None => break 'outer Err(anyhow!("timed out writing RPC message")),
|
||||
Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
|
||||
_ => {}
|
||||
if let Some(result) = writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
|
||||
result.context("failed to write RPC message")?;
|
||||
} else {
|
||||
Err(anyhow!("timed out writing message"))?;
|
||||
}
|
||||
}
|
||||
None => break 'outer Ok(()),
|
||||
None => return Ok(()),
|
||||
},
|
||||
incoming = read_message => match incoming {
|
||||
Ok(incoming) => {
|
||||
if incoming_tx.send(incoming).await.is_err() {
|
||||
break 'outer Ok(());
|
||||
}
|
||||
break;
|
||||
}
|
||||
Err(error) => {
|
||||
break 'outer Err(error).context("received invalid RPC message")
|
||||
incoming = read_message => {
|
||||
let incoming = incoming.context("received invalid rpc message")?;
|
||||
if incoming_tx.send(incoming).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
break;
|
||||
},
|
||||
_ = create_timer(KEEPALIVE_INTERVAL).fuse() => {
|
||||
if let Some(result) = writer.ping().timeout(WRITE_TIMEOUT).await {
|
||||
result.context("failed to send websocket ping")?;
|
||||
} else {
|
||||
Err(anyhow!("timed out sending websocket ping"))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
response_channels.lock().take();
|
||||
this.connections.write().remove(&connection_id);
|
||||
result
|
||||
}
|
||||
};
|
||||
|
||||
let response_channels = connection_state.response_channels.clone();
|
||||
|
@ -191,18 +202,31 @@ impl Peer {
|
|||
|
||||
None
|
||||
} else {
|
||||
if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
|
||||
Some(envelope)
|
||||
} else {
|
||||
proto::build_typed_envelope(connection_id, incoming).or_else(|| {
|
||||
log::error!("unable to construct a typed envelope");
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
});
|
||||
(connection_id, handle_io, incoming_rx.boxed())
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub async fn add_test_connection(
|
||||
self: &Arc<Self>,
|
||||
connection: Connection,
|
||||
executor: Arc<gpui::executor::Background>,
|
||||
) -> (
|
||||
ConnectionId,
|
||||
impl Future<Output = anyhow::Result<()>> + Send,
|
||||
BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
|
||||
) {
|
||||
let executor = executor.clone();
|
||||
self.add_connection(connection, move |duration| executor.timer(duration))
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn disconnect(&self, connection_id: ConnectionId) {
|
||||
self.connections.write().remove(&connection_id);
|
||||
}
|
||||
|
@ -349,15 +373,21 @@ mod tests {
|
|||
|
||||
let (client1_to_server_conn, server_to_client_1_conn, _) =
|
||||
Connection::in_memory(cx.background());
|
||||
let (client1_conn_id, io_task1, client1_incoming) =
|
||||
client1.add_connection(client1_to_server_conn).await;
|
||||
let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;
|
||||
let (client1_conn_id, io_task1, client1_incoming) = client1
|
||||
.add_test_connection(client1_to_server_conn, cx.background())
|
||||
.await;
|
||||
let (_, io_task2, server_incoming1) = server
|
||||
.add_test_connection(server_to_client_1_conn, cx.background())
|
||||
.await;
|
||||
|
||||
let (client2_to_server_conn, server_to_client_2_conn, _) =
|
||||
Connection::in_memory(cx.background());
|
||||
let (client2_conn_id, io_task3, client2_incoming) =
|
||||
client2.add_connection(client2_to_server_conn).await;
|
||||
let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;
|
||||
let (client2_conn_id, io_task3, client2_incoming) = client2
|
||||
.add_test_connection(client2_to_server_conn, cx.background())
|
||||
.await;
|
||||
let (_, io_task4, server_incoming2) = server
|
||||
.add_test_connection(server_to_client_2_conn, cx.background())
|
||||
.await;
|
||||
|
||||
executor.spawn(io_task1).detach();
|
||||
executor.spawn(io_task2).detach();
|
||||
|
@ -440,10 +470,12 @@ mod tests {
|
|||
|
||||
let (client_to_server_conn, server_to_client_conn, _) =
|
||||
Connection::in_memory(cx.background());
|
||||
let (client_to_server_conn_id, io_task1, mut client_incoming) =
|
||||
client.add_connection(client_to_server_conn).await;
|
||||
let (server_to_client_conn_id, io_task2, mut server_incoming) =
|
||||
server.add_connection(server_to_client_conn).await;
|
||||
let (client_to_server_conn_id, io_task1, mut client_incoming) = client
|
||||
.add_test_connection(client_to_server_conn, cx.background())
|
||||
.await;
|
||||
let (server_to_client_conn_id, io_task2, mut server_incoming) = server
|
||||
.add_test_connection(server_to_client_conn, cx.background())
|
||||
.await;
|
||||
|
||||
executor.spawn(io_task1).detach();
|
||||
executor.spawn(io_task2).detach();
|
||||
|
@ -538,10 +570,12 @@ mod tests {
|
|||
|
||||
let (client_to_server_conn, server_to_client_conn, _) =
|
||||
Connection::in_memory(cx.background());
|
||||
let (client_to_server_conn_id, io_task1, mut client_incoming) =
|
||||
client.add_connection(client_to_server_conn).await;
|
||||
let (server_to_client_conn_id, io_task2, mut server_incoming) =
|
||||
server.add_connection(server_to_client_conn).await;
|
||||
let (client_to_server_conn_id, io_task1, mut client_incoming) = client
|
||||
.add_test_connection(client_to_server_conn, cx.background())
|
||||
.await;
|
||||
let (server_to_client_conn_id, io_task2, mut server_incoming) = server
|
||||
.add_test_connection(server_to_client_conn, cx.background())
|
||||
.await;
|
||||
|
||||
executor.spawn(io_task1).detach();
|
||||
executor.spawn(io_task2).detach();
|
||||
|
@ -649,7 +683,9 @@ mod tests {
|
|||
let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
|
||||
|
||||
let client = Peer::new();
|
||||
let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
|
||||
let (connection_id, io_handler, mut incoming) = client
|
||||
.add_test_connection(client_conn, cx.background())
|
||||
.await;
|
||||
|
||||
let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
|
||||
executor
|
||||
|
@ -683,7 +719,9 @@ mod tests {
|
|||
let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
|
||||
|
||||
let client = Peer::new();
|
||||
let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
|
||||
let (connection_id, io_handler, mut incoming) = client
|
||||
.add_test_connection(client_conn, cx.background())
|
||||
.await;
|
||||
executor.spawn(io_handler).detach();
|
||||
executor
|
||||
.spawn(async move { incoming.next().await })
|
||||
|
|
|
@ -318,6 +318,13 @@ where
|
|||
self.stream.send(WebSocketMessage::Binary(buffer)).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn ping(&mut self) -> Result<(), WebSocketError> {
|
||||
self.stream
|
||||
.send(WebSocketMessage::Ping(Default::default()))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> MessageStream<S>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue