Send websocket pings from both the client and the server

Remove the client-only logic for sending protobuf pings.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
Co-Authored-By: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Max Brunsfeld 2022-03-04 13:53:40 -08:00
parent c61a1bd659
commit 9017a1363b
9 changed files with 174 additions and 92 deletions

View file

@ -94,6 +94,7 @@ pub struct ConnectionState {
Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
}
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2);
const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
impl Peer {
@ -104,14 +105,20 @@ impl Peer {
})
}
pub async fn add_connection(
pub async fn add_connection<F, Fut, Out>(
self: &Arc<Self>,
connection: Connection,
create_timer: F,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
) {
)
where
F: Send + Fn(Duration) -> Fut,
Fut: Send + Future<Output = Out>,
Out: Send,
{
// For outgoing messages, use an unbounded channel so that application code
// can always send messages without yielding. For incoming messages, use a
// bounded channel so that other peers will receive backpressure if they send
@ -121,7 +128,7 @@ impl Peer {
let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
let connection_state = ConnectionState {
outgoing_tx,
outgoing_tx: outgoing_tx.clone(),
next_message_id: Default::default(),
response_channels: Arc::new(Mutex::new(Some(Default::default()))),
};
@ -131,39 +138,43 @@ impl Peer {
let this = self.clone();
let response_channels = connection_state.response_channels.clone();
let handle_io = async move {
let result = 'outer: loop {
let _end_connection = util::defer(|| {
response_channels.lock().take();
this.connections.write().remove(&connection_id);
});
loop {
let read_message = reader.read_message().fuse();
futures::pin_mut!(read_message);
loop {
futures::select_biased! {
outgoing = outgoing_rx.next().fuse() => match outgoing {
Some(outgoing) => {
match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
None => break 'outer Err(anyhow!("timed out writing RPC message")),
Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
_ => {}
if let Some(result) = writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
result.context("failed to write RPC message")?;
} else {
Err(anyhow!("timed out writing message"))?;
}
}
None => break 'outer Ok(()),
None => return Ok(()),
},
incoming = read_message => match incoming {
Ok(incoming) => {
if incoming_tx.send(incoming).await.is_err() {
break 'outer Ok(());
}
break;
}
Err(error) => {
break 'outer Err(error).context("received invalid RPC message")
incoming = read_message => {
let incoming = incoming.context("received invalid rpc message")?;
if incoming_tx.send(incoming).await.is_err() {
return Ok(());
}
break;
},
_ = create_timer(KEEPALIVE_INTERVAL).fuse() => {
if let Some(result) = writer.ping().timeout(WRITE_TIMEOUT).await {
result.context("failed to send websocket ping")?;
} else {
Err(anyhow!("timed out sending websocket ping"))?;
}
}
}
}
};
response_channels.lock().take();
this.connections.write().remove(&connection_id);
result
}
};
let response_channels = connection_state.response_channels.clone();
@ -191,18 +202,31 @@ impl Peer {
None
} else {
if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
Some(envelope)
} else {
proto::build_typed_envelope(connection_id, incoming).or_else(|| {
log::error!("unable to construct a typed envelope");
None
}
})
}
}
});
(connection_id, handle_io, incoming_rx.boxed())
}
#[cfg(any(test, feature = "test-support"))]
pub async fn add_test_connection(
self: &Arc<Self>,
connection: Connection,
executor: Arc<gpui::executor::Background>,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
) {
let executor = executor.clone();
self.add_connection(connection, move |duration| executor.timer(duration))
.await
}
pub fn disconnect(&self, connection_id: ConnectionId) {
self.connections.write().remove(&connection_id);
}
@ -349,15 +373,21 @@ mod tests {
let (client1_to_server_conn, server_to_client_1_conn, _) =
Connection::in_memory(cx.background());
let (client1_conn_id, io_task1, client1_incoming) =
client1.add_connection(client1_to_server_conn).await;
let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;
let (client1_conn_id, io_task1, client1_incoming) = client1
.add_test_connection(client1_to_server_conn, cx.background())
.await;
let (_, io_task2, server_incoming1) = server
.add_test_connection(server_to_client_1_conn, cx.background())
.await;
let (client2_to_server_conn, server_to_client_2_conn, _) =
Connection::in_memory(cx.background());
let (client2_conn_id, io_task3, client2_incoming) =
client2.add_connection(client2_to_server_conn).await;
let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;
let (client2_conn_id, io_task3, client2_incoming) = client2
.add_test_connection(client2_to_server_conn, cx.background())
.await;
let (_, io_task4, server_incoming2) = server
.add_test_connection(server_to_client_2_conn, cx.background())
.await;
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
@ -440,10 +470,12 @@ mod tests {
let (client_to_server_conn, server_to_client_conn, _) =
Connection::in_memory(cx.background());
let (client_to_server_conn_id, io_task1, mut client_incoming) =
client.add_connection(client_to_server_conn).await;
let (server_to_client_conn_id, io_task2, mut server_incoming) =
server.add_connection(server_to_client_conn).await;
let (client_to_server_conn_id, io_task1, mut client_incoming) = client
.add_test_connection(client_to_server_conn, cx.background())
.await;
let (server_to_client_conn_id, io_task2, mut server_incoming) = server
.add_test_connection(server_to_client_conn, cx.background())
.await;
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
@ -538,10 +570,12 @@ mod tests {
let (client_to_server_conn, server_to_client_conn, _) =
Connection::in_memory(cx.background());
let (client_to_server_conn_id, io_task1, mut client_incoming) =
client.add_connection(client_to_server_conn).await;
let (server_to_client_conn_id, io_task2, mut server_incoming) =
server.add_connection(server_to_client_conn).await;
let (client_to_server_conn_id, io_task1, mut client_incoming) = client
.add_test_connection(client_to_server_conn, cx.background())
.await;
let (server_to_client_conn_id, io_task2, mut server_incoming) = server
.add_test_connection(server_to_client_conn, cx.background())
.await;
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
@ -649,7 +683,9 @@ mod tests {
let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
let client = Peer::new();
let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
let (connection_id, io_handler, mut incoming) = client
.add_test_connection(client_conn, cx.background())
.await;
let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
executor
@ -683,7 +719,9 @@ mod tests {
let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
let client = Peer::new();
let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
let (connection_id, io_handler, mut incoming) = client
.add_test_connection(client_conn, cx.background())
.await;
executor.spawn(io_handler).detach();
executor
.spawn(async move { incoming.next().await })

View file

@ -318,6 +318,13 @@ where
self.stream.send(WebSocketMessage::Binary(buffer)).await?;
Ok(())
}
pub async fn ping(&mut self) -> Result<(), WebSocketError> {
self.stream
.send(WebSocketMessage::Ping(Default::default()))
.await?;
Ok(())
}
}
impl<S> MessageStream<S>