Refactor to avoid some (mostly hypothetical) races

Tidy up added code to reduce duplicity of X and X_internals.
This commit is contained in:
Conrad Irwin 2023-10-18 19:27:00 -06:00
parent 2b11463567
commit 3853009d92
13 changed files with 715 additions and 765 deletions

View file

@ -16,20 +16,39 @@ impl Database {
.await
}
#[cfg(test)]
pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
self.create_channel(name, None, creator_id).await
Ok(self
.create_channel(name, None, creator_id)
.await?
.channel
.id)
}
#[cfg(test)]
pub async fn create_sub_channel(
&self,
name: &str,
parent: ChannelId,
creator_id: UserId,
) -> Result<ChannelId> {
Ok(self
.create_channel(name, Some(parent), creator_id)
.await?
.channel
.id)
}
pub async fn create_channel(
&self,
name: &str,
parent: Option<ChannelId>,
creator_id: UserId,
) -> Result<ChannelId> {
admin_id: UserId,
) -> Result<CreateChannelResult> {
let name = Self::sanitize_channel_name(name)?;
self.transaction(move |tx| async move {
if let Some(parent) = parent {
self.check_user_is_channel_admin(parent, creator_id, &*tx)
self.check_user_is_channel_admin(parent, admin_id, &*tx)
.await?;
}
@ -71,17 +90,34 @@ impl Database {
.await?;
}
channel_member::ActiveModel {
id: ActiveValue::NotSet,
channel_id: ActiveValue::Set(channel.id),
user_id: ActiveValue::Set(creator_id),
accepted: ActiveValue::Set(true),
role: ActiveValue::Set(ChannelRole::Admin),
if parent.is_none() {
channel_member::ActiveModel {
id: ActiveValue::NotSet,
channel_id: ActiveValue::Set(channel.id),
user_id: ActiveValue::Set(admin_id),
accepted: ActiveValue::Set(true),
role: ActiveValue::Set(ChannelRole::Admin),
}
.insert(&*tx)
.await?;
}
.insert(&*tx)
.await?;
Ok(channel.id)
let participants_to_update = if let Some(parent) = parent {
self.participants_to_notify_for_channel_change(parent, &*tx)
.await?
} else {
vec![]
};
Ok(CreateChannelResult {
channel: Channel {
id: channel.id,
visibility: channel.visibility,
name: channel.name,
role: ChannelRole::Admin,
},
participants_to_update,
})
})
.await
}
@ -132,7 +168,7 @@ impl Database {
&& channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public)
{
let channel_id_to_join = self
.public_path_to_channel_internal(channel_id, &*tx)
.public_path_to_channel(channel_id, &*tx)
.await?
.first()
.cloned()
@ -178,13 +214,17 @@ impl Database {
&self,
channel_id: ChannelId,
visibility: ChannelVisibility,
user_id: UserId,
) -> Result<channel::Model> {
admin_id: UserId,
) -> Result<SetChannelVisibilityResult> {
self.transaction(move |tx| async move {
self.check_user_is_channel_admin(channel_id, user_id, &*tx)
self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
.await?;
let channel = channel::ActiveModel {
let previous_members = self
.get_channel_participant_details_internal(channel_id, &*tx)
.await?;
channel::ActiveModel {
id: ActiveValue::Unchanged(channel_id),
visibility: ActiveValue::Set(visibility),
..Default::default()
@ -192,7 +232,40 @@ impl Database {
.update(&*tx)
.await?;
Ok(channel)
let mut participants_to_update: HashMap<UserId, ChannelsForUser> = self
.participants_to_notify_for_channel_change(channel_id, &*tx)
.await?
.into_iter()
.collect();
let mut participants_to_remove: HashSet<UserId> = HashSet::default();
match visibility {
ChannelVisibility::Members => {
for member in previous_members {
if member.role.can_only_see_public_descendants() {
participants_to_remove.insert(member.user_id);
}
}
}
ChannelVisibility::Public => {
if let Some(public_parent_id) =
self.public_parent_channel_id(channel_id, &*tx).await?
{
let parent_updates = self
.participants_to_notify_for_channel_change(public_parent_id, &*tx)
.await?;
for (user_id, channels) in parent_updates {
participants_to_update.insert(user_id, channels);
}
}
}
}
Ok(SetChannelVisibilityResult {
participants_to_update,
participants_to_remove,
})
})
.await
}
@ -303,14 +376,14 @@ impl Database {
pub async fn rename_channel(
&self,
channel_id: ChannelId,
user_id: UserId,
admin_id: UserId,
new_name: &str,
) -> Result<Channel> {
) -> Result<RenameChannelResult> {
self.transaction(move |tx| async move {
let new_name = Self::sanitize_channel_name(new_name)?.to_string();
let role = self
.check_user_is_channel_admin(channel_id, user_id, &*tx)
.check_user_is_channel_admin(channel_id, admin_id, &*tx)
.await?;
let channel = channel::ActiveModel {
@ -321,11 +394,31 @@ impl Database {
.update(&*tx)
.await?;
Ok(Channel {
id: channel.id,
name: channel.name,
visibility: channel.visibility,
role,
let participants = self
.get_channel_participant_details_internal(channel_id, &*tx)
.await?;
Ok(RenameChannelResult {
channel: Channel {
id: channel.id,
name: channel.name,
visibility: channel.visibility,
role,
},
participants_to_update: participants
.iter()
.map(|participant| {
(
participant.user_id,
Channel {
id: channel.id,
name: new_name.clone(),
visibility: channel.visibility,
role: participant.role,
},
)
})
.collect(),
})
})
.await
@ -628,91 +721,83 @@ impl Database {
})
}
pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
self.transaction(|tx| async move { self.get_channel_participants_internal(id, &*tx).await })
.await
}
pub async fn participants_to_notify_for_channel_change(
async fn participants_to_notify_for_channel_change(
&self,
new_parent: ChannelId,
admin_id: UserId,
tx: &DatabaseTransaction,
) -> Result<Vec<(UserId, ChannelsForUser)>> {
self.transaction(|tx| async move {
let mut results: Vec<(UserId, ChannelsForUser)> = Vec::new();
let mut results: Vec<(UserId, ChannelsForUser)> = Vec::new();
let members = self
.get_channel_participant_details_internal(new_parent, admin_id, &*tx)
.await?;
let members = self
.get_channel_participant_details_internal(new_parent, &*tx)
.await?;
dbg!(&members);
dbg!(&members);
for member in members.iter() {
if !member.role.can_see_all_descendants() {
continue;
}
results.push((
member.user_id,
self.get_user_channels(
member.user_id,
vec![channel_member::Model {
id: Default::default(),
channel_id: new_parent,
user_id: member.user_id,
role: member.role,
accepted: true,
}],
&*tx,
)
.await?,
))
for member in members.iter() {
if !member.role.can_see_all_descendants() {
continue;
}
results.push((
member.user_id,
self.get_user_channels(
member.user_id,
vec![channel_member::Model {
id: Default::default(),
channel_id: new_parent,
user_id: member.user_id,
role: member.role,
accepted: true,
}],
&*tx,
)
.await?,
))
}
let public_parent = self
.public_path_to_channel_internal(new_parent, &*tx)
let public_parent = self
.public_path_to_channel(new_parent, &*tx)
.await?
.last()
.copied();
let Some(public_parent) = public_parent else {
return Ok(results);
};
// could save some time in the common case by skipping this if the
// new channel is not public and has no public descendants.
let public_members = if public_parent == new_parent {
members
} else {
self.get_channel_participant_details_internal(public_parent, &*tx)
.await?
.last()
.copied();
};
let Some(public_parent) = public_parent else {
return Ok(results);
dbg!(&public_members);
for member in public_members {
if !member.role.can_only_see_public_descendants() {
continue;
};
// could save some time in the common case by skipping this if the
// new channel is not public and has no public descendants.
let public_members = if public_parent == new_parent {
members
} else {
self.get_channel_participant_details_internal(public_parent, admin_id, &*tx)
.await?
};
dbg!(&public_members);
for member in public_members {
if !member.role.can_only_see_public_descendants() {
continue;
};
results.push((
results.push((
member.user_id,
self.get_user_channels(
member.user_id,
self.get_user_channels(
member.user_id,
vec![channel_member::Model {
id: Default::default(),
channel_id: public_parent,
user_id: member.user_id,
role: member.role,
accepted: true,
}],
&*tx,
)
.await?,
))
}
vec![channel_member::Model {
id: Default::default(),
channel_id: public_parent,
user_id: member.user_id,
role: member.role,
accepted: true,
}],
&*tx,
)
.await?,
))
}
Ok(results)
})
.await
Ok(results)
}
pub async fn set_channel_member_role(
@ -748,15 +833,11 @@ impl Database {
.await
}
pub async fn get_channel_participant_details_internal(
async fn get_channel_participant_details_internal(
&self,
channel_id: ChannelId,
admin_id: UserId,
tx: &DatabaseTransaction,
) -> Result<Vec<ChannelMember>> {
self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
.await?;
let channel_visibility = channel::Entity::find()
.filter(channel::Column::Id.eq(channel_id))
.one(&*tx)
@ -851,8 +932,11 @@ impl Database {
) -> Result<Vec<proto::ChannelMember>> {
let members = self
.transaction(move |tx| async move {
self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
.await?;
Ok(self
.get_channel_participant_details_internal(channel_id, admin_id, &*tx)
.get_channel_participant_details_internal(channel_id, &*tx)
.await?)
})
.await?;
@ -863,25 +947,18 @@ impl Database {
.collect())
}
pub async fn get_channel_participants_internal(
pub async fn get_channel_participants(
&self,
id: ChannelId,
channel_id: ChannelId,
tx: &DatabaseTransaction,
) -> Result<Vec<UserId>> {
let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
let user_ids = channel_member::Entity::find()
.distinct()
.filter(
channel_member::Column::ChannelId
.is_in(ancestor_ids.iter().copied())
.and(channel_member::Column::Accepted.eq(true)),
)
.select_only()
.column(channel_member::Column::UserId)
.into_values::<_, QueryUserIds>()
.all(&*tx)
let participants = self
.get_channel_participant_details_internal(channel_id, &*tx)
.await?;
Ok(user_ids)
Ok(participants
.into_iter()
.map(|member| member.user_id)
.collect())
}
pub async fn check_user_is_channel_admin(
@ -951,18 +1028,12 @@ impl Database {
Ok(row)
}
// ordered from higher in tree to lower
// only considers one path to a channel
// includes the channel itself
pub async fn path_to_channel(&self, channel_id: ChannelId) -> Result<Vec<ChannelId>> {
self.transaction(move |tx| async move {
Ok(self.path_to_channel_internal(channel_id, &*tx).await?)
})
.await
}
pub async fn parent_channel_id(&self, channel_id: ChannelId) -> Result<Option<ChannelId>> {
let path = self.path_to_channel(channel_id).await?;
pub async fn parent_channel_id(
&self,
channel_id: ChannelId,
tx: &DatabaseTransaction,
) -> Result<Option<ChannelId>> {
let path = self.path_to_channel(channel_id, &*tx).await?;
if path.len() >= 2 {
Ok(Some(path[path.len() - 2]))
} else {
@ -973,8 +1044,9 @@ impl Database {
pub async fn public_parent_channel_id(
&self,
channel_id: ChannelId,
tx: &DatabaseTransaction,
) -> Result<Option<ChannelId>> {
let path = self.path_to_channel(channel_id).await?;
let path = self.public_path_to_channel(channel_id, &*tx).await?;
if path.len() >= 2 && path.last().copied() == Some(channel_id) {
Ok(Some(path[path.len() - 2]))
} else {
@ -982,7 +1054,7 @@ impl Database {
}
}
pub async fn path_to_channel_internal(
pub async fn path_to_channel(
&self,
channel_id: ChannelId,
tx: &DatabaseTransaction,
@ -1005,27 +1077,12 @@ impl Database {
.collect())
}
// ordered from higher in tree to lower
// only considers one path to a channel
// includes the channel itself
pub async fn public_path_to_channel(&self, channel_id: ChannelId) -> Result<Vec<ChannelId>> {
self.transaction(move |tx| async move {
Ok(self
.public_path_to_channel_internal(channel_id, &*tx)
.await?)
})
.await
}
// ordered from higher in tree to lower
// only considers one path to a channel
// includes the channel itself
pub async fn public_path_to_channel_internal(
pub async fn public_path_to_channel(
&self,
channel_id: ChannelId,
tx: &DatabaseTransaction,
) -> Result<Vec<ChannelId>> {
let ancestor_ids = self.path_to_channel_internal(channel_id, &*tx).await?;
let ancestor_ids = self.path_to_channel(channel_id, &*tx).await?;
let rows = channel::Entity::find()
.filter(channel::Column::Id.is_in(ancestor_ids.iter().copied()))
@ -1151,27 +1208,6 @@ impl Database {
Ok(channel_ids)
}
// returns all ids of channels in the tree under this channel_id.
pub async fn get_channel_descendant_ids(
&self,
channel_id: ChannelId,
) -> Result<HashSet<ChannelId>> {
self.transaction(|tx| async move {
let pairs = self.get_channel_descendants([channel_id], &*tx).await?;
let mut results: HashSet<ChannelId> = HashSet::default();
for ChannelEdge {
parent_id: _,
channel_id,
} in pairs
{
results.insert(ChannelId::from_proto(channel_id));
}
Ok(results)
})
.await
}
// Returns the channel desendants as a sorted list of edges for further processing.
// The edges are sorted such that you will see unknown channel ids as children
// before you see them as parents.
@ -1388,9 +1424,6 @@ impl Database {
from: ChannelId,
) -> Result<()> {
self.transaction(|tx| async move {
// Note that even with these maxed permissions, this linking operation
// is still insecure because you can't remove someone's permissions to a
// channel if they've linked the channel to one where they're an admin.
self.check_user_is_channel_admin(channel, user, &*tx)
.await?;
@ -1433,6 +1466,8 @@ impl Database {
.await?
== 0;
dbg!(is_stranded, &paths);
// Make sure that there is always at least one path to the channel
if is_stranded {
let root_paths: Vec<_> = paths
@ -1445,6 +1480,8 @@ impl Database {
}
})
.collect();
dbg!(is_stranded, &root_paths);
channel_path::Entity::insert_many(root_paths)
.exec(&*tx)
.await?;
@ -1453,49 +1490,64 @@ impl Database {
Ok(())
}
/// Move a channel from one parent to another, returns the
/// Channels that were moved for notifying clients
/// Move a channel from one parent to another
pub async fn move_channel(
&self,
user: UserId,
channel: ChannelId,
from: ChannelId,
to: ChannelId,
) -> Result<ChannelGraph> {
if from == to {
return Ok(ChannelGraph {
channels: vec![],
edges: vec![],
});
}
channel_id: ChannelId,
old_parent_id: Option<ChannelId>,
new_parent_id: ChannelId,
admin_id: UserId,
) -> Result<Option<MoveChannelResult>> {
self.transaction(|tx| async move {
self.check_user_is_channel_admin(channel, user, &*tx)
self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
.await?;
let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
debug_assert_eq!(
self.parent_channel_id(channel_id, &*tx).await?,
old_parent_id
);
self.unlink_channel_internal(user, channel, from, &*tx)
if old_parent_id == Some(new_parent_id) {
return Ok(None);
}
let previous_participants = self
.get_channel_participant_details_internal(channel_id, &*tx)
.await?;
Ok(moved_channels)
})
.await
}
self.link_channel_internal(admin_id, channel_id, new_parent_id, &*tx)
.await?;
pub async fn assert_root_channel(&self, channel: ChannelId) -> Result<()> {
self.transaction(|tx| async move {
let path = channel_path::Entity::find()
.filter(channel_path::Column::ChannelId.eq(channel))
.one(&*tx)
if let Some(from) = old_parent_id {
self.unlink_channel_internal(admin_id, channel_id, from, &*tx)
.await?;
}
let participants_to_update: HashMap<UserId, ChannelsForUser> = self
.participants_to_notify_for_channel_change(new_parent_id, &*tx)
.await?
.ok_or_else(|| anyhow!("no such channel found"))?;
.into_iter()
.collect();
let mut id_parts = path.id_path.trim_matches('/').split('/');
let mut moved_channels: HashSet<ChannelId> = HashSet::default();
moved_channels.insert(channel_id);
for edge in self.get_channel_descendants([channel_id], &*tx).await? {
moved_channels.insert(ChannelId::from_proto(edge.channel_id));
}
(id_parts.next().is_some() && id_parts.next().is_none())
.then_some(())
.ok_or_else(|| anyhow!("channel is not a root channel").into())
let mut participants_to_remove: HashSet<UserId> = HashSet::default();
for participant in previous_participants {
if participant.kind == proto::channel_member::Kind::AncestorMember {
if !participants_to_update.contains_key(&participant.user_id) {
participants_to_remove.insert(participant.user_id);
}
}
}
Ok(Some(MoveChannelResult {
participants_to_remove,
participants_to_update,
moved_channels,
}))
})
.await
}