WIP: Add channel DAG related RPC messages, change update message

This commit is contained in:
Mikayla 2023-09-08 11:38:00 -07:00
parent 49fbb27ce9
commit 9e68d4a8ea
No known key found for this signature in database
8 changed files with 402 additions and 98 deletions

View file

@ -1,5 +1,7 @@
use super::*;
type ChannelDescendants = HashMap<ChannelId, HashSet<ChannelId>>;
impl Database {
#[cfg(test)]
pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
@ -68,7 +70,6 @@ impl Database {
],
);
tx.execute(channel_paths_stmt).await?;
} else {
channel_path::Entity::insert(channel_path::ActiveModel {
channel_id: ActiveValue::Set(channel.id),
@ -101,7 +102,7 @@ impl Database {
.await
}
pub async fn remove_channel(
pub async fn delete_channel(
&self,
channel_id: ChannelId,
user_id: UserId,
@ -159,9 +160,7 @@ impl Database {
let channel_paths_stmt = Statement::from_sql_and_values(
self.pool.get_database_backend(),
sql,
[
channel_id.to_proto().into(),
],
[channel_id.to_proto().into()],
);
tx.execute(channel_paths_stmt).await?;
@ -335,6 +334,43 @@ impl Database {
.await
}
async fn get_all_channels(
&self,
parents_by_child_id: ChannelDescendants,
tx: &DatabaseTransaction,
) -> Result<Vec<Channel>> {
let mut channels = Vec::with_capacity(parents_by_child_id.len());
{
let mut rows = channel::Entity::find()
.filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
.stream(&*tx)
.await?;
while let Some(row) = rows.next().await {
let row = row?;
// As these rows are pulled from the map's keys, this unwrap is safe.
let parents = parents_by_child_id.get(&row.id).unwrap();
if parents.len() > 0 {
for parent in parents {
channels.push(Channel {
id: row.id,
name: row.name.clone(),
parent_id: Some(*parent),
});
}
} else {
channels.push(Channel {
id: row.id,
name: row.name,
parent_id: None,
});
}
}
}
Ok(channels)
}
pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
self.transaction(|tx| async move {
let tx = tx;
@ -352,40 +388,12 @@ impl Database {
.get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
.await?;
let channels_with_admin_privileges = channel_memberships
.iter()
.filter_map(|membership| membership.admin.then_some(membership.channel_id))
.collect();
let mut channels = Vec::with_capacity(parents_by_child_id.len());
{
let mut rows = channel::Entity::find()
.filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
.stream(&*tx)
.await?;
while let Some(row) = rows.next().await {
let row = row?;
// As these rows are pulled from the map's keys, this unwrap is safe.
let parents = parents_by_child_id.get(&row.id).unwrap();
if parents.len() > 0 {
for parent in parents {
channels.push(Channel {
id: row.id,
name: row.name.clone(),
parent_id: Some(*parent),
});
}
} else {
channels.push(Channel {
id: row.id,
name: row.name,
parent_id: None,
});
}
}
}
let channels = self.get_all_channels(parents_by_child_id, &tx).await?;
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIdsAndChannelIds {
@ -632,7 +640,7 @@ impl Database {
&self,
channel_ids: impl IntoIterator<Item = ChannelId>,
tx: &DatabaseTransaction,
) -> Result<HashMap<ChannelId, HashSet<ChannelId>>> {
) -> Result<ChannelDescendants> {
let mut values = String::new();
for id in channel_ids {
if !values.is_empty() {
@ -659,7 +667,7 @@ impl Database {
let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
let mut parents_by_child_id: HashMap<ChannelId, HashSet<ChannelId>> = HashMap::default();
let mut parents_by_child_id: ChannelDescendants = HashMap::default();
let mut paths = channel_path::Entity::find()
.from_raw_sql(stmt)
.stream(tx)
@ -758,7 +766,7 @@ impl Database {
from: ChannelId,
to: ChannelId,
tx: &DatabaseTransaction,
) -> Result<()> {
) -> Result<ChannelDescendants> {
let to_ancestors = self.get_channel_ancestors(to, &*tx).await?;
let from_descendants = self.get_channel_descendants([from], &*tx).await?;
for ancestor in to_ancestors {
@ -767,8 +775,6 @@ impl Database {
}
}
let sql = r#"
INSERT INTO channel_paths
(id_path, channel_id)
@ -806,8 +812,7 @@ impl Database {
}
}
Ok(())
Ok(from_descendants)
}
async fn remove_channel_from_parent(
@ -816,8 +821,6 @@ impl Database {
parent: ChannelId,
tx: &DatabaseTransaction,
) -> Result<()> {
let sql = r#"
DELETE FROM channel_paths
WHERE
@ -826,14 +829,10 @@ impl Database {
let channel_paths_stmt = Statement::from_sql_and_values(
self.pool.get_database_backend(),
sql,
[
parent.to_proto().into(),
from.to_proto().into(),
],
[parent.to_proto().into(), from.to_proto().into()],
);
tx.execute(channel_paths_stmt).await?;
Ok(())
}
@ -846,19 +845,22 @@ impl Database {
/// - (`None`, `Some(id)`) Link the channel without removing it from any of it's parents
/// - (`Some(id)`, `None`) Remove a channel from a given parent, and leave other parents
/// - (`Some(id)`, `Some(id)`) Move channel from one parent to another, leaving other parents
///
/// Returns the channel that was moved + it's sub channels
pub async fn move_channel(
&self,
user: UserId,
from: ChannelId,
from_parent: Option<ChannelId>,
to: Option<ChannelId>,
) -> Result<()> {
) -> Result<Vec<Channel>> {
self.transaction(|tx| async move {
// Note that even with these maxed permissions, this linking operation
// is still insecure because you can't remove someone's permissions to a
// channel if they've linked the channel to one where they're an admin.
self.check_user_is_channel_admin(from, user, &*tx).await?;
let mut channel_descendants = None;
if let Some(from_parent) = from_parent {
self.check_user_is_channel_admin(from_parent, user, &*tx)
.await?;
@ -870,10 +872,30 @@ impl Database {
if let Some(to) = to {
self.check_user_is_channel_admin(to, user, &*tx).await?;
self.link_channel(from, to, &*tx).await?;
channel_descendants = Some(self.link_channel(from, to, &*tx).await?);
}
Ok(())
let mut channel_descendants = match channel_descendants {
Some(channel_descendants) => channel_descendants,
None => self.get_channel_descendants([from], &*tx).await?,
};
// Repair the parent ID of the channel in case it was from a cached call
if let Some(channel) = channel_descendants.get_mut(&from) {
if let Some(from_parent) = from_parent {
channel.remove(&from_parent);
}
if let Some(to) = to {
channel.insert(to);
}
}
let channels = self
.get_all_channels(channel_descendants, &*tx)
.await?;
Ok(channels)
})
.await
}