Rename Store to ConnectionPool

This commit is contained in:
Antonio Scandurra 2022-11-17 19:03:50 +01:00
parent 6c83be3f89
commit 44bb2ce024
3 changed files with 133 additions and 112 deletions

View file

@ -1,5 +1,5 @@
use crate::{
db::{NewUserParams, SqliteTestDb as TestDb, UserId},
db::{self, NewUserParams, SqliteTestDb as TestDb, UserId},
rpc::{Executor, Server},
AppState,
};
@ -5469,18 +5469,15 @@ async fn test_random_collaboration(
}
for user_id in &user_ids {
let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap();
let contacts = server
.store
.lock()
.await
.build_initial_contacts_update(contacts)
.contacts;
let pool = server.connection_pool.lock().await;
for contact in contacts {
if contact.online {
assert_ne!(
contact.user_id, removed_guest_id.0 as u64,
"removed guest is still a contact of another peer"
);
if let db::Contact::Accepted { user_id, .. } = contact {
if pool.is_user_online(user_id) {
assert_ne!(
user_id, removed_guest_id,
"removed guest is still a contact of another peer"
);
}
}
}
}

View file

@ -1,4 +1,4 @@
mod store;
mod connection_pool;
use crate::{
auth,
@ -23,6 +23,7 @@ use axum::{
Extension, Router, TypedHeader,
};
use collections::{HashMap, HashSet};
pub use connection_pool::ConnectionPool;
use futures::{
channel::oneshot,
future::{self, BoxFuture},
@ -49,7 +50,6 @@ use std::{
},
time::Duration,
};
pub use store::Store;
use tokio::{
sync::{Mutex, MutexGuard},
time::Sleep,
@ -103,7 +103,7 @@ impl<R: RequestMessage> Response<R> {
pub struct Server {
peer: Arc<Peer>,
pub(crate) store: Mutex<Store>,
pub(crate) connection_pool: Mutex<ConnectionPool>,
app_state: Arc<AppState>,
handlers: HashMap<TypeId, MessageHandler>,
}
@ -117,8 +117,8 @@ pub trait Executor: Send + Clone {
#[derive(Clone)]
pub struct RealExecutor;
pub(crate) struct StoreGuard<'a> {
guard: MutexGuard<'a, Store>,
pub(crate) struct ConnectionPoolGuard<'a> {
guard: MutexGuard<'a, ConnectionPool>,
_not_send: PhantomData<Rc<()>>,
}
@ -126,7 +126,7 @@ pub(crate) struct StoreGuard<'a> {
pub struct ServerSnapshot<'a> {
peer: &'a Peer,
#[serde(serialize_with = "serialize_deref")]
store: StoreGuard<'a>,
connection_pool: ConnectionPoolGuard<'a>,
}
pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
@ -143,7 +143,7 @@ impl Server {
let mut server = Self {
peer: Peer::new(),
app_state,
store: Default::default(),
connection_pool: Default::default(),
handlers: Default::default(),
};
@ -257,8 +257,6 @@ impl Server {
self
}
/// Handle a request while holding a lock to the store. This is useful when we're registering
/// a connection but we want to respond on the connection before anybody else can send on it.
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(Arc<Self>, M, Response<M>, Session) -> Fut,
@ -342,9 +340,9 @@ impl Server {
).await?;
{
let mut store = this.store().await;
store.add_connection(connection_id, user_id, user.admin);
this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?;
let mut pool = this.connection_pool().await;
pool.add_connection(connection_id, user_id, user.admin);
this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
if let Some((code, count)) = invite_code {
this.peer.send(connection_id, proto::UpdateInviteInfo {
@ -435,9 +433,9 @@ impl Server {
) -> Result<()> {
self.peer.disconnect(connection_id);
let decline_calls = {
let mut store = self.store().await;
store.remove_connection(connection_id)?;
let mut connections = store.user_connection_ids(user_id);
let mut pool = self.connection_pool().await;
pool.remove_connection(connection_id)?;
let mut connections = pool.user_connection_ids(user_id);
connections.next().is_none()
};
@ -468,9 +466,9 @@ impl Server {
) -> Result<()> {
if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
if let Some(code) = &user.invite_code {
let store = self.store().await;
let invitee_contact = store.contact_for_user(invitee_id, true, false);
for connection_id in store.user_connection_ids(inviter_id) {
let pool = self.connection_pool().await;
let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
for connection_id in pool.user_connection_ids(inviter_id) {
self.peer.send(
connection_id,
proto::UpdateContacts {
@ -494,8 +492,8 @@ impl Server {
pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
if let Some(invite_code) = &user.invite_code {
let store = self.store().await;
for connection_id in store.user_connection_ids(user_id) {
let pool = self.connection_pool().await;
for connection_id in pool.user_connection_ids(user_id) {
self.peer.send(
connection_id,
proto::UpdateInviteInfo {
@ -582,7 +580,11 @@ impl Server {
session.connection_id,
)
.await?;
for connection_id in self.store().await.user_connection_ids(session.user_id) {
for connection_id in self
.connection_pool()
.await
.user_connection_ids(session.user_id)
{
self.peer
.send(connection_id, proto::CallCanceled {})
.trace_err();
@ -672,9 +674,9 @@ impl Server {
self.room_updated(&left_room.room);
{
let store = self.store().await;
let pool = self.connection_pool().await;
for canceled_user_id in left_room.canceled_calls_to_user_ids {
for connection_id in store.user_connection_ids(canceled_user_id) {
for connection_id in pool.user_connection_ids(canceled_user_id) {
self.peer
.send(connection_id, proto::CallCanceled {})
.trace_err();
@ -742,7 +744,7 @@ impl Server {
self.update_user_contacts(called_user_id).await?;
let mut calls = self
.store()
.connection_pool()
.await
.user_connection_ids(called_user_id)
.map(|connection_id| self.peer.request(connection_id, incoming_call.clone()))
@ -784,7 +786,11 @@ impl Server {
.db
.cancel_call(Some(room_id), session.connection_id, called_user_id)
.await?;
for connection_id in self.store().await.user_connection_ids(called_user_id) {
for connection_id in self
.connection_pool()
.await
.user_connection_ids(called_user_id)
{
self.peer
.send(connection_id, proto::CallCanceled {})
.trace_err();
@ -807,7 +813,11 @@ impl Server {
.db
.decline_call(Some(room_id), session.user_id)
.await?;
for connection_id in self.store().await.user_connection_ids(session.user_id) {
for connection_id in self
.connection_pool()
.await
.user_connection_ids(session.user_id)
{
self.peer
.send(connection_id, proto::CallCanceled {})
.trace_err();
@ -897,15 +907,15 @@ impl Server {
async fn update_user_contacts(self: &Arc<Server>, user_id: UserId) -> Result<()> {
let contacts = self.app_state.db.get_contacts(user_id).await?;
let busy = self.app_state.db.is_user_busy(user_id).await?;
let store = self.store().await;
let updated_contact = store.contact_for_user(user_id, false, busy);
let pool = self.connection_pool().await;
let updated_contact = contact_for_user(user_id, false, busy, &pool);
for contact in contacts {
if let db::Contact::Accepted {
user_id: contact_user_id,
..
} = contact
{
for contact_conn_id in store.user_connection_ids(contact_user_id) {
for contact_conn_id in pool.user_connection_ids(contact_user_id) {
self.peer
.send(
contact_conn_id,
@ -1522,7 +1532,11 @@ impl Server {
// Update outgoing contact requests of requester
let mut update = proto::UpdateContacts::default();
update.outgoing_requests.push(responder_id.to_proto());
for connection_id in self.store().await.user_connection_ids(requester_id) {
for connection_id in self
.connection_pool()
.await
.user_connection_ids(requester_id)
{
self.peer.send(connection_id, update.clone())?;
}
@ -1534,7 +1548,11 @@ impl Server {
requester_id: requester_id.to_proto(),
should_notify: true,
});
for connection_id in self.store().await.user_connection_ids(responder_id) {
for connection_id in self
.connection_pool()
.await
.user_connection_ids(responder_id)
{
self.peer.send(connection_id, update.clone())?;
}
@ -1563,18 +1581,18 @@ impl Server {
.await?;
let busy = self.app_state.db.is_user_busy(requester_id).await?;
let store = self.store().await;
let pool = self.connection_pool().await;
// Update responder with new contact
let mut update = proto::UpdateContacts::default();
if accept {
update
.contacts
.push(store.contact_for_user(requester_id, false, busy));
.push(contact_for_user(requester_id, false, busy, &pool));
}
update
.remove_incoming_requests
.push(requester_id.to_proto());
for connection_id in store.user_connection_ids(responder_id) {
for connection_id in pool.user_connection_ids(responder_id) {
self.peer.send(connection_id, update.clone())?;
}
@ -1583,12 +1601,12 @@ impl Server {
if accept {
update
.contacts
.push(store.contact_for_user(responder_id, true, busy));
.push(contact_for_user(responder_id, true, busy, &pool));
}
update
.remove_outgoing_requests
.push(responder_id.to_proto());
for connection_id in store.user_connection_ids(requester_id) {
for connection_id in pool.user_connection_ids(requester_id) {
self.peer.send(connection_id, update.clone())?;
}
}
@ -1615,7 +1633,11 @@ impl Server {
update
.remove_outgoing_requests
.push(responder_id.to_proto());
for connection_id in self.store().await.user_connection_ids(requester_id) {
for connection_id in self
.connection_pool()
.await
.user_connection_ids(requester_id)
{
self.peer.send(connection_id, update.clone())?;
}
@ -1624,7 +1646,11 @@ impl Server {
update
.remove_incoming_requests
.push(requester_id.to_proto());
for connection_id in self.store().await.user_connection_ids(responder_id) {
for connection_id in self
.connection_pool()
.await
.user_connection_ids(responder_id)
{
self.peer.send(connection_id, update.clone())?;
}
@ -1678,13 +1704,13 @@ impl Server {
Ok(())
}
pub(crate) async fn store(&self) -> StoreGuard<'_> {
pub(crate) async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
#[cfg(test)]
tokio::task::yield_now().await;
let guard = self.store.lock().await;
let guard = self.connection_pool.lock().await;
#[cfg(test)]
tokio::task::yield_now().await;
StoreGuard {
ConnectionPoolGuard {
guard,
_not_send: PhantomData,
}
@ -1692,27 +1718,27 @@ impl Server {
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
ServerSnapshot {
store: self.store().await,
connection_pool: self.connection_pool().await,
peer: &self.peer,
}
}
}
impl<'a> Deref for StoreGuard<'a> {
type Target = Store;
impl<'a> Deref for ConnectionPoolGuard<'a> {
type Target = ConnectionPool;
fn deref(&self) -> &Self::Target {
&*self.guard
}
}
impl<'a> DerefMut for StoreGuard<'a> {
impl<'a> DerefMut for ConnectionPoolGuard<'a> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut *self.guard
}
}
impl<'a> Drop for StoreGuard<'a> {
impl<'a> Drop for ConnectionPoolGuard<'a> {
fn drop(&mut self) {
#[cfg(test)]
self.check_invariants();
@ -1821,7 +1847,7 @@ pub async fn handle_websocket_request(
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
let connections = server
.store()
.connection_pool()
.await
.connections()
.filter(|connection| !connection.admin)
@ -1868,6 +1894,53 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
}
}
fn build_initial_contacts_update(
contacts: Vec<db::Contact>,
pool: &ConnectionPool,
) -> proto::UpdateContacts {
let mut update = proto::UpdateContacts::default();
for contact in contacts {
match contact {
db::Contact::Accepted {
user_id,
should_notify,
busy,
} => {
update
.contacts
.push(contact_for_user(user_id, should_notify, busy, &pool));
}
db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
db::Contact::Incoming {
user_id,
should_notify,
} => update
.incoming_requests
.push(proto::IncomingContactRequest {
requester_id: user_id.to_proto(),
should_notify,
}),
}
}
update
}
fn contact_for_user(
user_id: UserId,
should_notify: bool,
busy: bool,
pool: &ConnectionPool,
) -> proto::Contact {
proto::Contact {
user_id: user_id.to_proto(),
online: pool.is_user_online(user_id),
busy,
should_notify,
}
}
pub trait ResultExt {
type Ok;

View file

@ -1,12 +1,12 @@
use crate::db::{self, UserId};
use crate::db::UserId;
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashSet};
use rpc::{proto, ConnectionId};
use rpc::ConnectionId;
use serde::Serialize;
use tracing::instrument;
#[derive(Default, Serialize)]
pub struct Store {
pub struct ConnectionPool {
connections: BTreeMap<ConnectionId, Connection>,
connected_users: BTreeMap<UserId, ConnectedUser>,
}
@ -22,7 +22,7 @@ pub struct Connection {
pub admin: bool,
}
impl Store {
impl ConnectionPool {
#[instrument(skip(self))]
pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) {
self.connections
@ -70,55 +70,6 @@ impl Store {
.is_empty()
}
pub fn build_initial_contacts_update(
&self,
contacts: Vec<db::Contact>,
) -> proto::UpdateContacts {
let mut update = proto::UpdateContacts::default();
for contact in contacts {
match contact {
db::Contact::Accepted {
user_id,
should_notify,
busy,
} => {
update
.contacts
.push(self.contact_for_user(user_id, should_notify, busy));
}
db::Contact::Outgoing { user_id } => {
update.outgoing_requests.push(user_id.to_proto())
}
db::Contact::Incoming {
user_id,
should_notify,
} => update
.incoming_requests
.push(proto::IncomingContactRequest {
requester_id: user_id.to_proto(),
should_notify,
}),
}
}
update
}
pub fn contact_for_user(
&self,
user_id: UserId,
should_notify: bool,
busy: bool,
) -> proto::Contact {
proto::Contact {
user_id: user_id.to_proto(),
online: self.is_user_online(user_id),
busy,
should_notify,
}
}
#[cfg(test)]
pub fn check_invariants(&self) {
for (connection_id, connection) in &self.connections {