WIP: Clear cached credentials if authentication fails

Still need to actually handle an HTTP response from the server indicating there was an invalid token.

Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Nathan Sobo 2021-09-14 19:19:11 -06:00
parent 77a4a36eb3
commit 4a9918979e
8 changed files with 149 additions and 67 deletions

9
Cargo.lock generated
View file

@ -5108,18 +5108,18 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.24" version = "1.0.29"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0f4a65597094d4483ddaed134f409b2cb7c1beccf25201a9f73c719254fa98e" checksum = "602eca064b2d83369e2b2f34b09c70b605402801927c65c11071ac911d299b88"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.24" version = "1.0.29"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7765189610d8241a44529806d6fd1f2e0a08734313a35d5b3a556f92b381f3c0" checksum = "bad553cc2c78e8de258400763a647e80e6d1b31ee237275d756f6836d204494c"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -5914,6 +5914,7 @@ dependencies = [
"smol", "smol",
"surf", "surf",
"tempdir", "tempdir",
"thiserror",
"time 0.3.2", "time 0.3.2",
"tiny_http", "tiny_http",
"toml 0.5.8", "toml 0.5.8",

View file

@ -27,7 +27,7 @@ use time::OffsetDateTime;
use zrpc::{ use zrpc::{
auth::random_token, auth::random_token,
proto::{self, AnyTypedEnvelope, EnvelopedMessage}, proto::{self, AnyTypedEnvelope, EnvelopedMessage},
Conn, ConnectionId, Peer, TypedEnvelope, Connection, ConnectionId, Peer, TypedEnvelope,
}; };
type ReplicaId = u16; type ReplicaId = u16;
@ -48,13 +48,13 @@ pub struct Server {
#[derive(Default)] #[derive(Default)]
struct ServerState { struct ServerState {
connections: HashMap<ConnectionId, Connection>, connections: HashMap<ConnectionId, ConnectionState>,
pub worktrees: HashMap<u64, Worktree>, pub worktrees: HashMap<u64, Worktree>,
channels: HashMap<ChannelId, Channel>, channels: HashMap<ChannelId, Channel>,
next_worktree_id: u64, next_worktree_id: u64,
} }
struct Connection { struct ConnectionState {
user_id: UserId, user_id: UserId,
worktrees: HashSet<u64>, worktrees: HashSet<u64>,
channels: HashSet<ChannelId>, channels: HashSet<ChannelId>,
@ -133,7 +133,7 @@ impl Server {
pub fn handle_connection( pub fn handle_connection(
self: &Arc<Self>, self: &Arc<Self>,
connection: Conn, connection: Connection,
addr: String, addr: String,
user_id: UserId, user_id: UserId,
) -> impl Future<Output = ()> { ) -> impl Future<Output = ()> {
@ -211,7 +211,7 @@ impl Server {
async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) { async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
self.state.write().await.connections.insert( self.state.write().await.connections.insert(
connection_id, connection_id,
Connection { ConnectionState {
user_id, user_id,
worktrees: Default::default(), worktrees: Default::default(),
channels: Default::default(), channels: Default::default(),
@ -972,7 +972,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?; let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
task::spawn(async move { task::spawn(async move {
if let Some(stream) = upgrade_receiver.await { if let Some(stream) = upgrade_receiver.await {
server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await; server.handle_connection(Connection::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
} }
}); });
@ -1023,7 +1023,7 @@ mod tests {
editor::{Editor, Insert}, editor::{Editor, Insert},
fs::{FakeFs, Fs as _}, fs::{FakeFs, Fs as _},
language::LanguageRegistry, language::LanguageRegistry,
rpc::{self, Client, Credentials}, rpc::{self, Client, Credentials, EstablishConnectionError},
settings, settings,
test::FakeHttpClient, test::FakeHttpClient,
user::UserStore, user::UserStore,
@ -1941,9 +1941,11 @@ mod tests {
let client_name = client_name.clone(); let client_name = client_name.clone();
cx.spawn(move |cx| async move { cx.spawn(move |cx| async move {
if forbid_connections.load(SeqCst) { if forbid_connections.load(SeqCst) {
Err(anyhow!("server is forbidding connections")) Err(EstablishConnectionError::other(anyhow!(
"server is forbidding connections"
)))
} else { } else {
let (client_conn, server_conn, kill_conn) = Conn::in_memory(); let (client_conn, server_conn, kill_conn) = Connection::in_memory();
connection_killers.lock().insert(client_user_id, kill_conn); connection_killers.lock().insert(client_user_id, kill_conn);
cx.background() cx.background()
.spawn(server.handle_connection( .spawn(server.handle_connection(

View file

@ -50,6 +50,7 @@ smallvec = { version = "1.6", features = ["union"] }
smol = "1.2.5" smol = "1.2.5"
surf = "2.2" surf = "2.2"
tempdir = { version = "0.3.7", optional = true } tempdir = { version = "0.3.7", optional = true }
thiserror = "1.0.29"
time = { version = "0.3" } time = { version = "0.3" }
tiny_http = "0.8" tiny_http = "0.8"
toml = "0.5" toml = "0.5"

View file

@ -15,10 +15,11 @@ use std::{
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use surf::Url; use surf::Url;
use thiserror::Error;
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope}; pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
use zrpc::{ use zrpc::{
proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage}, proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
Conn, Peer, Receipt, Connection, Peer, Receipt,
}; };
lazy_static! { lazy_static! {
@ -32,10 +33,32 @@ pub struct Client {
authenticate: authenticate:
Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>, Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
establish_connection: Option< establish_connection: Option<
Box<dyn 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Conn>>>, Box<
dyn 'static
+ Send
+ Sync
+ Fn(
&Credentials,
&AsyncAppContext,
) -> Task<Result<Connection, EstablishConnectionError>>,
>,
>, >,
} }
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
#[error("invalid access token")]
InvalidAccessToken,
#[error("{0}")]
Other(anyhow::Error),
}
impl EstablishConnectionError {
pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
Self::Other(error.into())
}
}
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub enum Status { pub enum Status {
SignedOut, SignedOut,
@ -122,7 +145,10 @@ impl Client {
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
where where
F: 'static + Send + Sync + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Conn>>, F: 'static
+ Send
+ Sync
+ Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
{ {
self.establish_connection = Some(Box::new(connect)); self.establish_connection = Some(Box::new(connect));
self self
@ -288,13 +314,18 @@ impl Client {
Ok(()) Ok(())
} }
Err(err) => { Err(err) => {
eprintln!("error in authenticate and connect {}", err);
if matches!(err, EstablishConnectionError::InvalidAccessToken) {
eprintln!("nuking credentials");
self.state.write().credentials.take();
}
self.set_status(Status::ConnectionError, cx); self.set_status(Status::ConnectionError, cx);
Err(err) Err(err)?
} }
} }
} }
async fn set_connection(self: &Arc<Self>, conn: Conn, cx: &AsyncAppContext) { async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await; let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
cx.foreground() cx.foreground()
.spawn({ .spawn({
@ -359,7 +390,7 @@ impl Client {
self: &Arc<Self>, self: &Arc<Self>,
credentials: &Credentials, credentials: &Credentials,
cx: &AsyncAppContext, cx: &AsyncAppContext,
) -> Task<Result<Conn>> { ) -> Task<Result<Connection, EstablishConnectionError>> {
if let Some(callback) = self.establish_connection.as_ref() { if let Some(callback) = self.establish_connection.as_ref() {
callback(credentials, cx) callback(credentials, cx)
} else { } else {
@ -371,28 +402,43 @@ impl Client {
self: &Arc<Self>, self: &Arc<Self>,
credentials: &Credentials, credentials: &Credentials,
cx: &AsyncAppContext, cx: &AsyncAppContext,
) -> Task<Result<Conn>> { ) -> Task<Result<Connection, EstablishConnectionError>> {
let request = Request::builder().header( let request = Request::builder().header(
"Authorization", "Authorization",
format!("{} {}", credentials.user_id, credentials.access_token), format!("{} {}", credentials.user_id, credentials.access_token),
); );
cx.background().spawn(async move { cx.background().spawn(async move {
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") { if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
let stream = smol::net::TcpStream::connect(host).await?; let stream = smol::net::TcpStream::connect(host)
let request = request.uri(format!("wss://{}/rpc", host)).body(())?; .await
.map_err(EstablishConnectionError::other)?;
let request = request
.uri(format!("wss://{}/rpc", host))
.body(())
.map_err(EstablishConnectionError::other)?;
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream) let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
.await .await
.context("websocket handshake")?; .context("websocket handshake")
Ok(Conn::new(stream)) .map_err(EstablishConnectionError::other)?;
Ok(Connection::new(stream))
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") { } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
let stream = smol::net::TcpStream::connect(host).await?; let stream = smol::net::TcpStream::connect(host)
let request = request.uri(format!("ws://{}/rpc", host)).body(())?; .await
.map_err(EstablishConnectionError::other)?;
let request = request
.uri(format!("ws://{}/rpc", host))
.body(())
.map_err(EstablishConnectionError::other)?;
let (stream, _) = async_tungstenite::client_async(request, stream) let (stream, _) = async_tungstenite::client_async(request, stream)
.await .await
.context("websocket handshake")?; .context("websocket handshake")
Ok(Conn::new(stream)) .map_err(EstablishConnectionError::other)?;
Ok(Connection::new(stream))
} else { } else {
Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL)) Err(EstablishConnectionError::other(anyhow!(
"invalid server url: {}",
*ZED_SERVER_URL
)))
} }
}) })
} }
@ -591,6 +637,19 @@ mod tests {
cx.foreground().advance_clock(Duration::from_secs(10)); cx.foreground().advance_clock(Duration::from_secs(10));
while !matches!(status.recv().await, Some(Status::Connected { .. })) {} while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
server.forbid_connections();
server.disconnect().await;
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
// Clear cached credentials after authentication fails
server.roll_access_token();
server.allow_connections();
cx.foreground().advance_clock(Duration::from_secs(10));
assert_eq!(server.auth_count(), 1);
cx.foreground().advance_clock(Duration::from_secs(10));
while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
} }
#[test] #[test]

View file

@ -4,7 +4,7 @@ use crate::{
fs::RealFs, fs::RealFs,
http::{HttpClient, Request, Response, ServerResponse}, http::{HttpClient, Request, Response, ServerResponse},
language::LanguageRegistry, language::LanguageRegistry,
rpc::{self, Client, Credentials}, rpc::{self, Client, Credentials, EstablishConnectionError},
settings::{self, ThemeRegistry}, settings::{self, ThemeRegistry},
time::ReplicaId, time::ReplicaId,
user::UserStore, user::UserStore,
@ -26,7 +26,7 @@ use std::{
}, },
}; };
use tempdir::TempDir; use tempdir::TempDir;
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope}; use zrpc::{proto, Connection, ConnectionId, Peer, Receipt, TypedEnvelope};
#[cfg(test)] #[cfg(test)]
#[ctor::ctor] #[ctor::ctor]
@ -210,6 +210,8 @@ pub struct FakeServer {
connection_id: Mutex<Option<ConnectionId>>, connection_id: Mutex<Option<ConnectionId>>,
forbid_connections: AtomicBool, forbid_connections: AtomicBool,
auth_count: AtomicUsize, auth_count: AtomicUsize,
access_token: AtomicUsize,
user_id: u64,
} }
impl FakeServer { impl FakeServer {
@ -224,6 +226,8 @@ impl FakeServer {
connection_id: Default::default(), connection_id: Default::default(),
forbid_connections: Default::default(), forbid_connections: Default::default(),
auth_count: Default::default(), auth_count: Default::default(),
access_token: Default::default(),
user_id: client_user_id,
}); });
Arc::get_mut(client) Arc::get_mut(client)
@ -232,8 +236,8 @@ impl FakeServer {
let server = server.clone(); let server = server.clone();
move |cx| { move |cx| {
server.auth_count.fetch_add(1, SeqCst); server.auth_count.fetch_add(1, SeqCst);
let access_token = server.access_token.load(SeqCst).to_string();
cx.spawn(move |_| async move { cx.spawn(move |_| async move {
let access_token = "the-token".to_string();
Ok(Credentials { Ok(Credentials {
user_id: client_user_id, user_id: client_user_id,
access_token, access_token,
@ -244,11 +248,10 @@ impl FakeServer {
.override_establish_connection({ .override_establish_connection({
let server = server.clone(); let server = server.clone();
move |credentials, cx| { move |credentials, cx| {
assert_eq!(credentials.user_id, client_user_id); let credentials = credentials.clone();
assert_eq!(credentials.access_token, "the-token");
cx.spawn({ cx.spawn({
let server = server.clone(); let server = server.clone();
move |cx| async move { server.connect(&cx).await } move |cx| async move { server.establish_connection(&credentials, &cx).await }
}) })
} }
}); });
@ -266,23 +269,39 @@ impl FakeServer {
self.incoming.lock().take(); self.incoming.lock().take();
} }
async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> { async fn establish_connection(
&self,
credentials: &Credentials,
cx: &AsyncAppContext,
) -> Result<Connection, EstablishConnectionError> {
assert_eq!(credentials.user_id, self.user_id);
if self.forbid_connections.load(SeqCst) { if self.forbid_connections.load(SeqCst) {
Err(anyhow!("server is forbidding connections")) Err(EstablishConnectionError::Other(anyhow!(
} else { "server is forbidding connections"
let (client_conn, server_conn, _) = Conn::in_memory(); )))?
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
cx.background().spawn(io).detach();
*self.incoming.lock() = Some(incoming);
*self.connection_id.lock() = Some(connection_id);
Ok(client_conn)
} }
if credentials.access_token != self.access_token.load(SeqCst).to_string() {
Err(EstablishConnectionError::InvalidAccessToken)?
}
let (client_conn, server_conn, _) = Connection::in_memory();
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
cx.background().spawn(io).detach();
*self.incoming.lock() = Some(incoming);
*self.connection_id.lock() = Some(connection_id);
Ok(client_conn)
} }
pub fn auth_count(&self) -> usize { pub fn auth_count(&self) -> usize {
self.auth_count.load(SeqCst) self.auth_count.load(SeqCst)
} }
pub fn roll_access_token(&self) {
self.access_token.fetch_add(1, SeqCst);
}
pub fn forbid_connections(&self) { pub fn forbid_connections(&self) {
self.forbid_connections.store(true, SeqCst); self.forbid_connections.store(true, SeqCst);
} }

View file

@ -2,7 +2,7 @@ use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSock
use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _}; use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
use std::{io, task::Poll}; use std::{io, task::Poll};
pub struct Conn { pub struct Connection {
pub(crate) tx: pub(crate) tx:
Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>, Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
pub(crate) rx: Box< pub(crate) rx: Box<
@ -13,7 +13,7 @@ pub struct Conn {
>, >,
} }
impl Conn { impl Connection {
pub fn new<S>(stream: S) -> Self pub fn new<S>(stream: S) -> Self
where where
S: 'static S: 'static

View file

@ -2,5 +2,5 @@ pub mod auth;
mod conn; mod conn;
mod peer; mod peer;
pub mod proto; pub mod proto;
pub use conn::Conn; pub use conn::Connection;
pub use peer::*; pub use peer::*;

View file

@ -1,5 +1,5 @@
use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage}; use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
use super::Conn; use super::Connection;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock}; use async_lock::{Mutex, RwLock};
use futures::FutureExt as _; use futures::FutureExt as _;
@ -79,12 +79,12 @@ impl<T: RequestMessage> TypedEnvelope<T> {
} }
pub struct Peer { pub struct Peer {
connections: RwLock<HashMap<ConnectionId, Connection>>, connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
next_connection_id: AtomicU32, next_connection_id: AtomicU32,
} }
#[derive(Clone)] #[derive(Clone)]
struct Connection { struct ConnectionState {
outgoing_tx: mpsc::Sender<proto::Envelope>, outgoing_tx: mpsc::Sender<proto::Envelope>,
next_message_id: Arc<AtomicU32>, next_message_id: Arc<AtomicU32>,
response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>, response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
@ -100,7 +100,7 @@ impl Peer {
pub async fn add_connection( pub async fn add_connection(
self: &Arc<Self>, self: &Arc<Self>,
conn: Conn, connection: Connection,
) -> ( ) -> (
ConnectionId, ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send, impl Future<Output = anyhow::Result<()>> + Send,
@ -112,16 +112,16 @@ impl Peer {
); );
let (mut incoming_tx, incoming_rx) = mpsc::channel(64); let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64); let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
let connection = Connection { let connection_state = ConnectionState {
outgoing_tx, outgoing_tx,
next_message_id: Default::default(), next_message_id: Default::default(),
response_channels: Default::default(), response_channels: Default::default(),
}; };
let mut writer = MessageStream::new(conn.tx); let mut writer = MessageStream::new(connection.tx);
let mut reader = MessageStream::new(conn.rx); let mut reader = MessageStream::new(connection.rx);
let this = self.clone(); let this = self.clone();
let response_channels = connection.response_channels.clone(); let response_channels = connection_state.response_channels.clone();
let handle_io = async move { let handle_io = async move {
loop { loop {
let read_message = reader.read_message().fuse(); let read_message = reader.read_message().fuse();
@ -179,7 +179,7 @@ impl Peer {
self.connections self.connections
.write() .write()
.await .await
.insert(connection_id, connection); .insert(connection_id, connection_state);
(connection_id, handle_io, incoming_rx) (connection_id, handle_io, incoming_rx)
} }
@ -218,7 +218,7 @@ impl Peer {
let this = self.clone(); let this = self.clone();
let (tx, mut rx) = mpsc::channel(1); let (tx, mut rx) = mpsc::channel(1);
async move { async move {
let mut connection = this.connection(receiver_id).await?; let mut connection = this.connection_state(receiver_id).await?;
let message_id = connection let message_id = connection
.next_message_id .next_message_id
.fetch_add(1, atomic::Ordering::SeqCst); .fetch_add(1, atomic::Ordering::SeqCst);
@ -252,7 +252,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> { ) -> impl Future<Output = Result<()>> {
let this = self.clone(); let this = self.clone();
async move { async move {
let mut connection = this.connection(receiver_id).await?; let mut connection = this.connection_state(receiver_id).await?;
let message_id = connection let message_id = connection
.next_message_id .next_message_id
.fetch_add(1, atomic::Ordering::SeqCst); .fetch_add(1, atomic::Ordering::SeqCst);
@ -272,7 +272,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> { ) -> impl Future<Output = Result<()>> {
let this = self.clone(); let this = self.clone();
async move { async move {
let mut connection = this.connection(receiver_id).await?; let mut connection = this.connection_state(receiver_id).await?;
let message_id = connection let message_id = connection
.next_message_id .next_message_id
.fetch_add(1, atomic::Ordering::SeqCst); .fetch_add(1, atomic::Ordering::SeqCst);
@ -291,7 +291,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> { ) -> impl Future<Output = Result<()>> {
let this = self.clone(); let this = self.clone();
async move { async move {
let mut connection = this.connection(receipt.sender_id).await?; let mut connection = this.connection_state(receipt.sender_id).await?;
let message_id = connection let message_id = connection
.next_message_id .next_message_id
.fetch_add(1, atomic::Ordering::SeqCst); .fetch_add(1, atomic::Ordering::SeqCst);
@ -310,7 +310,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> { ) -> impl Future<Output = Result<()>> {
let this = self.clone(); let this = self.clone();
async move { async move {
let mut connection = this.connection(receipt.sender_id).await?; let mut connection = this.connection_state(receipt.sender_id).await?;
let message_id = connection let message_id = connection
.next_message_id .next_message_id
.fetch_add(1, atomic::Ordering::SeqCst); .fetch_add(1, atomic::Ordering::SeqCst);
@ -322,10 +322,10 @@ impl Peer {
} }
} }
fn connection( fn connection_state(
self: &Arc<Self>, self: &Arc<Self>,
connection_id: ConnectionId, connection_id: ConnectionId,
) -> impl Future<Output = Result<Connection>> { ) -> impl Future<Output = Result<ConnectionState>> {
let this = self.clone(); let this = self.clone();
async move { async move {
let connections = this.connections.read().await; let connections = this.connections.read().await;
@ -352,12 +352,12 @@ mod tests {
let client1 = Peer::new(); let client1 = Peer::new();
let client2 = Peer::new(); let client2 = Peer::new();
let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory(); let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
let (client1_conn_id, io_task1, _) = let (client1_conn_id, io_task1, _) =
client1.add_connection(client1_to_server_conn).await; client1.add_connection(client1_to_server_conn).await;
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await; let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory(); let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
let (client2_conn_id, io_task3, _) = let (client2_conn_id, io_task3, _) =
client2.add_connection(client2_to_server_conn).await; client2.add_connection(client2_to_server_conn).await;
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await; let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
@ -486,7 +486,7 @@ mod tests {
#[test] #[test]
fn test_disconnect() { fn test_disconnect() {
smol::block_on(async move { smol::block_on(async move {
let (client_conn, mut server_conn, _) = Conn::in_memory(); let (client_conn, mut server_conn, _) = Connection::in_memory();
let client = Peer::new(); let client = Peer::new();
let (connection_id, io_handler, mut incoming) = let (connection_id, io_handler, mut incoming) =
@ -520,7 +520,7 @@ mod tests {
#[test] #[test]
fn test_io_error() { fn test_io_error() {
smol::block_on(async move { smol::block_on(async move {
let (client_conn, server_conn, _) = Conn::in_memory(); let (client_conn, server_conn, _) = Connection::in_memory();
drop(server_conn); drop(server_conn);
let client = Peer::new(); let client = Peer::new();