Rename Store to ConnectionPool

This commit is contained in:
Antonio Scandurra 2022-11-17 19:03:50 +01:00
parent 6c83be3f89
commit 44bb2ce024
3 changed files with 133 additions and 112 deletions

View file

@ -1,5 +1,5 @@
use crate::{ use crate::{
db::{NewUserParams, SqliteTestDb as TestDb, UserId}, db::{self, NewUserParams, SqliteTestDb as TestDb, UserId},
rpc::{Executor, Server}, rpc::{Executor, Server},
AppState, AppState,
}; };
@ -5469,18 +5469,15 @@ async fn test_random_collaboration(
} }
for user_id in &user_ids { for user_id in &user_ids {
let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap(); let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap();
let contacts = server let pool = server.connection_pool.lock().await;
.store
.lock()
.await
.build_initial_contacts_update(contacts)
.contacts;
for contact in contacts { for contact in contacts {
if contact.online { if let db::Contact::Accepted { user_id, .. } = contact {
assert_ne!( if pool.is_user_online(user_id) {
contact.user_id, removed_guest_id.0 as u64, assert_ne!(
"removed guest is still a contact of another peer" user_id, removed_guest_id,
); "removed guest is still a contact of another peer"
);
}
} }
} }
} }

View file

@ -1,4 +1,4 @@
mod store; mod connection_pool;
use crate::{ use crate::{
auth, auth,
@ -23,6 +23,7 @@ use axum::{
Extension, Router, TypedHeader, Extension, Router, TypedHeader,
}; };
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
pub use connection_pool::ConnectionPool;
use futures::{ use futures::{
channel::oneshot, channel::oneshot,
future::{self, BoxFuture}, future::{self, BoxFuture},
@ -49,7 +50,6 @@ use std::{
}, },
time::Duration, time::Duration,
}; };
pub use store::Store;
use tokio::{ use tokio::{
sync::{Mutex, MutexGuard}, sync::{Mutex, MutexGuard},
time::Sleep, time::Sleep,
@ -103,7 +103,7 @@ impl<R: RequestMessage> Response<R> {
pub struct Server { pub struct Server {
peer: Arc<Peer>, peer: Arc<Peer>,
pub(crate) store: Mutex<Store>, pub(crate) connection_pool: Mutex<ConnectionPool>,
app_state: Arc<AppState>, app_state: Arc<AppState>,
handlers: HashMap<TypeId, MessageHandler>, handlers: HashMap<TypeId, MessageHandler>,
} }
@ -117,8 +117,8 @@ pub trait Executor: Send + Clone {
#[derive(Clone)] #[derive(Clone)]
pub struct RealExecutor; pub struct RealExecutor;
pub(crate) struct StoreGuard<'a> { pub(crate) struct ConnectionPoolGuard<'a> {
guard: MutexGuard<'a, Store>, guard: MutexGuard<'a, ConnectionPool>,
_not_send: PhantomData<Rc<()>>, _not_send: PhantomData<Rc<()>>,
} }
@ -126,7 +126,7 @@ pub(crate) struct StoreGuard<'a> {
pub struct ServerSnapshot<'a> { pub struct ServerSnapshot<'a> {
peer: &'a Peer, peer: &'a Peer,
#[serde(serialize_with = "serialize_deref")] #[serde(serialize_with = "serialize_deref")]
store: StoreGuard<'a>, connection_pool: ConnectionPoolGuard<'a>,
} }
pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error> pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
@ -143,7 +143,7 @@ impl Server {
let mut server = Self { let mut server = Self {
peer: Peer::new(), peer: Peer::new(),
app_state, app_state,
store: Default::default(), connection_pool: Default::default(),
handlers: Default::default(), handlers: Default::default(),
}; };
@ -257,8 +257,6 @@ impl Server {
self self
} }
/// Handle a request while holding a lock to the store. This is useful when we're registering
/// a connection but we want to respond on the connection before anybody else can send on it.
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where where
F: 'static + Send + Sync + Fn(Arc<Self>, M, Response<M>, Session) -> Fut, F: 'static + Send + Sync + Fn(Arc<Self>, M, Response<M>, Session) -> Fut,
@ -342,9 +340,9 @@ impl Server {
).await?; ).await?;
{ {
let mut store = this.store().await; let mut pool = this.connection_pool().await;
store.add_connection(connection_id, user_id, user.admin); pool.add_connection(connection_id, user_id, user.admin);
this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?; this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
if let Some((code, count)) = invite_code { if let Some((code, count)) = invite_code {
this.peer.send(connection_id, proto::UpdateInviteInfo { this.peer.send(connection_id, proto::UpdateInviteInfo {
@ -435,9 +433,9 @@ impl Server {
) -> Result<()> { ) -> Result<()> {
self.peer.disconnect(connection_id); self.peer.disconnect(connection_id);
let decline_calls = { let decline_calls = {
let mut store = self.store().await; let mut pool = self.connection_pool().await;
store.remove_connection(connection_id)?; pool.remove_connection(connection_id)?;
let mut connections = store.user_connection_ids(user_id); let mut connections = pool.user_connection_ids(user_id);
connections.next().is_none() connections.next().is_none()
}; };
@ -468,9 +466,9 @@ impl Server {
) -> Result<()> { ) -> Result<()> {
if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
if let Some(code) = &user.invite_code { if let Some(code) = &user.invite_code {
let store = self.store().await; let pool = self.connection_pool().await;
let invitee_contact = store.contact_for_user(invitee_id, true, false); let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
for connection_id in store.user_connection_ids(inviter_id) { for connection_id in pool.user_connection_ids(inviter_id) {
self.peer.send( self.peer.send(
connection_id, connection_id,
proto::UpdateContacts { proto::UpdateContacts {
@ -494,8 +492,8 @@ impl Server {
pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> { pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? { if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
if let Some(invite_code) = &user.invite_code { if let Some(invite_code) = &user.invite_code {
let store = self.store().await; let pool = self.connection_pool().await;
for connection_id in store.user_connection_ids(user_id) { for connection_id in pool.user_connection_ids(user_id) {
self.peer.send( self.peer.send(
connection_id, connection_id,
proto::UpdateInviteInfo { proto::UpdateInviteInfo {
@ -582,7 +580,11 @@ impl Server {
session.connection_id, session.connection_id,
) )
.await?; .await?;
for connection_id in self.store().await.user_connection_ids(session.user_id) { for connection_id in self
.connection_pool()
.await
.user_connection_ids(session.user_id)
{
self.peer self.peer
.send(connection_id, proto::CallCanceled {}) .send(connection_id, proto::CallCanceled {})
.trace_err(); .trace_err();
@ -672,9 +674,9 @@ impl Server {
self.room_updated(&left_room.room); self.room_updated(&left_room.room);
{ {
let store = self.store().await; let pool = self.connection_pool().await;
for canceled_user_id in left_room.canceled_calls_to_user_ids { for canceled_user_id in left_room.canceled_calls_to_user_ids {
for connection_id in store.user_connection_ids(canceled_user_id) { for connection_id in pool.user_connection_ids(canceled_user_id) {
self.peer self.peer
.send(connection_id, proto::CallCanceled {}) .send(connection_id, proto::CallCanceled {})
.trace_err(); .trace_err();
@ -742,7 +744,7 @@ impl Server {
self.update_user_contacts(called_user_id).await?; self.update_user_contacts(called_user_id).await?;
let mut calls = self let mut calls = self
.store() .connection_pool()
.await .await
.user_connection_ids(called_user_id) .user_connection_ids(called_user_id)
.map(|connection_id| self.peer.request(connection_id, incoming_call.clone())) .map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
@ -784,7 +786,11 @@ impl Server {
.db .db
.cancel_call(Some(room_id), session.connection_id, called_user_id) .cancel_call(Some(room_id), session.connection_id, called_user_id)
.await?; .await?;
for connection_id in self.store().await.user_connection_ids(called_user_id) { for connection_id in self
.connection_pool()
.await
.user_connection_ids(called_user_id)
{
self.peer self.peer
.send(connection_id, proto::CallCanceled {}) .send(connection_id, proto::CallCanceled {})
.trace_err(); .trace_err();
@ -807,7 +813,11 @@ impl Server {
.db .db
.decline_call(Some(room_id), session.user_id) .decline_call(Some(room_id), session.user_id)
.await?; .await?;
for connection_id in self.store().await.user_connection_ids(session.user_id) { for connection_id in self
.connection_pool()
.await
.user_connection_ids(session.user_id)
{
self.peer self.peer
.send(connection_id, proto::CallCanceled {}) .send(connection_id, proto::CallCanceled {})
.trace_err(); .trace_err();
@ -897,15 +907,15 @@ impl Server {
async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> { async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
let contacts = self.app_state.db.get_contacts(user_id).await?; let contacts = self.app_state.db.get_contacts(user_id).await?;
let busy = self.app_state.db.is_user_busy(user_id).await?; let busy = self.app_state.db.is_user_busy(user_id).await?;
let store = self.store().await; let pool = self.connection_pool().await;
let updated_contact = store.contact_for_user(user_id, false, busy); let updated_contact = contact_for_user(user_id, false, busy, &pool);
for contact in contacts { for contact in contacts {
if let db::Contact::Accepted { if let db::Contact::Accepted {
user_id: contact_user_id, user_id: contact_user_id,
.. ..
} = contact } = contact
{ {
for contact_conn_id in store.user_connection_ids(contact_user_id) { for contact_conn_id in pool.user_connection_ids(contact_user_id) {
self.peer self.peer
.send( .send(
contact_conn_id, contact_conn_id,
@ -1522,7 +1532,11 @@ impl Server {
// Update outgoing contact requests of requester // Update outgoing contact requests of requester
let mut update = proto::UpdateContacts::default(); let mut update = proto::UpdateContacts::default();
update.outgoing_requests.push(responder_id.to_proto()); update.outgoing_requests.push(responder_id.to_proto());
for connection_id in self.store().await.user_connection_ids(requester_id) { for connection_id in self
.connection_pool()
.await
.user_connection_ids(requester_id)
{
self.peer.send(connection_id, update.clone())?; self.peer.send(connection_id, update.clone())?;
} }
@ -1534,7 +1548,11 @@ impl Server {
requester_id: requester_id.to_proto(), requester_id: requester_id.to_proto(),
should_notify: true, should_notify: true,
}); });
for connection_id in self.store().await.user_connection_ids(responder_id) { for connection_id in self
.connection_pool()
.await
.user_connection_ids(responder_id)
{
self.peer.send(connection_id, update.clone())?; self.peer.send(connection_id, update.clone())?;
} }
@ -1563,18 +1581,18 @@ impl Server {
.await?; .await?;
let busy = self.app_state.db.is_user_busy(requester_id).await?; let busy = self.app_state.db.is_user_busy(requester_id).await?;
let store = self.store().await; let pool = self.connection_pool().await;
// Update responder with new contact // Update responder with new contact
let mut update = proto::UpdateContacts::default(); let mut update = proto::UpdateContacts::default();
if accept { if accept {
update update
.contacts .contacts
.push(store.contact_for_user(requester_id, false, busy)); .push(contact_for_user(requester_id, false, busy, &pool));
} }
update update
.remove_incoming_requests .remove_incoming_requests
.push(requester_id.to_proto()); .push(requester_id.to_proto());
for connection_id in store.user_connection_ids(responder_id) { for connection_id in pool.user_connection_ids(responder_id) {
self.peer.send(connection_id, update.clone())?; self.peer.send(connection_id, update.clone())?;
} }
@ -1583,12 +1601,12 @@ impl Server {
if accept { if accept {
update update
.contacts .contacts
.push(store.contact_for_user(responder_id, true, busy)); .push(contact_for_user(responder_id, true, busy, &pool));
} }
update update
.remove_outgoing_requests .remove_outgoing_requests
.push(responder_id.to_proto()); .push(responder_id.to_proto());
for connection_id in store.user_connection_ids(requester_id) { for connection_id in pool.user_connection_ids(requester_id) {
self.peer.send(connection_id, update.clone())?; self.peer.send(connection_id, update.clone())?;
} }
} }
@ -1615,7 +1633,11 @@ impl Server {
update update
.remove_outgoing_requests .remove_outgoing_requests
.push(responder_id.to_proto()); .push(responder_id.to_proto());
for connection_id in self.store().await.user_connection_ids(requester_id) { for connection_id in self
.connection_pool()
.await
.user_connection_ids(requester_id)
{
self.peer.send(connection_id, update.clone())?; self.peer.send(connection_id, update.clone())?;
} }
@ -1624,7 +1646,11 @@ impl Server {
update update
.remove_incoming_requests .remove_incoming_requests
.push(requester_id.to_proto()); .push(requester_id.to_proto());
for connection_id in self.store().await.user_connection_ids(responder_id) { for connection_id in self
.connection_pool()
.await
.user_connection_ids(responder_id)
{
self.peer.send(connection_id, update.clone())?; self.peer.send(connection_id, update.clone())?;
} }
@ -1678,13 +1704,13 @@ impl Server {
Ok(()) Ok(())
} }
pub(crate) async fn store(&self) -> StoreGuard<'_> { pub(crate) async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
#[cfg(test)] #[cfg(test)]
tokio::task::yield_now().await; tokio::task::yield_now().await;
let guard = self.store.lock().await; let guard = self.connection_pool.lock().await;
#[cfg(test)] #[cfg(test)]
tokio::task::yield_now().await; tokio::task::yield_now().await;
StoreGuard { ConnectionPoolGuard {
guard, guard,
_not_send: PhantomData, _not_send: PhantomData,
} }
@ -1692,27 +1718,27 @@ impl Server {
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> { pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
ServerSnapshot { ServerSnapshot {
store: self.store().await, connection_pool: self.connection_pool().await,
peer: &self.peer, peer: &self.peer,
} }
} }
} }
impl<'a> Deref for StoreGuard<'a> { impl<'a> Deref for ConnectionPoolGuard<'a> {
type Target = Store; type Target = ConnectionPool;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&*self.guard &*self.guard
} }
} }
impl<'a> DerefMut for StoreGuard<'a> { impl<'a> DerefMut for ConnectionPoolGuard<'a> {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
&mut *self.guard &mut *self.guard
} }
} }
impl<'a> Drop for StoreGuard<'a> { impl<'a> Drop for ConnectionPoolGuard<'a> {
fn drop(&mut self) { fn drop(&mut self) {
#[cfg(test)] #[cfg(test)]
self.check_invariants(); self.check_invariants();
@ -1821,7 +1847,7 @@ pub async fn handle_websocket_request(
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> { pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
let connections = server let connections = server
.store() .connection_pool()
.await .await
.connections() .connections()
.filter(|connection| !connection.admin) .filter(|connection| !connection.admin)
@ -1868,6 +1894,53 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
} }
} }
fn build_initial_contacts_update(
contacts: Vec<db::Contact>,
pool: &ConnectionPool,
) -> proto::UpdateContacts {
let mut update = proto::UpdateContacts::default();
for contact in contacts {
match contact {
db::Contact::Accepted {
user_id,
should_notify,
busy,
} => {
update
.contacts
.push(contact_for_user(user_id, should_notify, busy, &pool));
}
db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
db::Contact::Incoming {
user_id,
should_notify,
} => update
.incoming_requests
.push(proto::IncomingContactRequest {
requester_id: user_id.to_proto(),
should_notify,
}),
}
}
update
}
fn contact_for_user(
user_id: UserId,
should_notify: bool,
busy: bool,
pool: &ConnectionPool,
) -> proto::Contact {
proto::Contact {
user_id: user_id.to_proto(),
online: pool.is_user_online(user_id),
busy,
should_notify,
}
}
pub trait ResultExt { pub trait ResultExt {
type Ok; type Ok;

View file

@ -1,12 +1,12 @@
use crate::db::{self, UserId}; use crate::db::UserId;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashSet}; use collections::{BTreeMap, HashSet};
use rpc::{proto, ConnectionId}; use rpc::ConnectionId;
use serde::Serialize; use serde::Serialize;
use tracing::instrument; use tracing::instrument;
#[derive(Default, Serialize)] #[derive(Default, Serialize)]
pub struct Store { pub struct ConnectionPool {
connections: BTreeMap<ConnectionId, Connection>, connections: BTreeMap<ConnectionId, Connection>,
connected_users: BTreeMap<UserId, ConnectedUser>, connected_users: BTreeMap<UserId, ConnectedUser>,
} }
@ -22,7 +22,7 @@ pub struct Connection {
pub admin: bool, pub admin: bool,
} }
impl Store { impl ConnectionPool {
#[instrument(skip(self))] #[instrument(skip(self))]
pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) { pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) {
self.connections self.connections
@ -70,55 +70,6 @@ impl Store {
.is_empty() .is_empty()
} }
pub fn build_initial_contacts_update(
&self,
contacts: Vec<db::Contact>,
) -> proto::UpdateContacts {
let mut update = proto::UpdateContacts::default();
for contact in contacts {
match contact {
db::Contact::Accepted {
user_id,
should_notify,
busy,
} => {
update
.contacts
.push(self.contact_for_user(user_id, should_notify, busy));
}
db::Contact::Outgoing { user_id } => {
update.outgoing_requests.push(user_id.to_proto())
}
db::Contact::Incoming {
user_id,
should_notify,
} => update
.incoming_requests
.push(proto::IncomingContactRequest {
requester_id: user_id.to_proto(),
should_notify,
}),
}
}
update
}
pub fn contact_for_user(
&self,
user_id: UserId,
should_notify: bool,
busy: bool,
) -> proto::Contact {
proto::Contact {
user_id: user_id.to_proto(),
online: self.is_user_online(user_id),
busy,
should_notify,
}
}
#[cfg(test)] #[cfg(test)]
pub fn check_invariants(&self) { pub fn check_invariants(&self) {
for (connection_id, connection) in &self.connections { for (connection_id, connection) in &self.connections {