use super::*; type ChannelDescendants = HashMap>; impl Database { #[cfg(test)] pub async fn all_channels(&self) -> Result> { self.transaction(move |tx| async move { let mut channels = Vec::new(); let mut rows = channel::Entity::find().stream(&*tx).await?; while let Some(row) = rows.next().await { let row = row?; channels.push((row.id, row.name)); } Ok(channels) }) .await } pub async fn create_root_channel( &self, name: &str, live_kit_room: &str, creator_id: UserId, ) -> Result { self.create_channel(name, None, live_kit_room, creator_id) .await } pub async fn create_channel( &self, name: &str, parent: Option, live_kit_room: &str, creator_id: UserId, ) -> Result { let name = Self::sanitize_channel_name(name)?; self.transaction(move |tx| async move { if let Some(parent) = parent { self.check_user_is_channel_admin(parent, creator_id, &*tx) .await?; } let channel = channel::ActiveModel { name: ActiveValue::Set(name.to_string()), ..Default::default() } .insert(&*tx) .await?; let channel_paths_stmt; if let Some(parent) = parent { let sql = r#" INSERT INTO channel_paths (id_path, channel_id) SELECT id_path || $1 || '/', $2 FROM channel_paths WHERE channel_id = $3 "#; channel_paths_stmt = Statement::from_sql_and_values( self.pool.get_database_backend(), sql, [ channel.id.to_proto().into(), channel.id.to_proto().into(), parent.to_proto().into(), ], ); tx.execute(channel_paths_stmt).await?; } else { channel_path::Entity::insert(channel_path::ActiveModel { channel_id: ActiveValue::Set(channel.id), id_path: ActiveValue::Set(format!("/{}/", channel.id)), }) .exec(&*tx) .await?; } channel_member::ActiveModel { channel_id: ActiveValue::Set(channel.id), user_id: ActiveValue::Set(creator_id), accepted: ActiveValue::Set(true), admin: ActiveValue::Set(true), ..Default::default() } .insert(&*tx) .await?; room::ActiveModel { channel_id: ActiveValue::Set(Some(channel.id)), live_kit_room: ActiveValue::Set(live_kit_room.to_string()), ..Default::default() } .insert(&*tx) .await?; Ok(channel.id) }) .await } pub async fn delete_channel( &self, channel_id: ChannelId, user_id: UserId, ) -> Result<(Vec, Vec)> { self.transaction(move |tx| async move { self.check_user_is_channel_admin(channel_id, user_id, &*tx) .await?; // Don't remove descendant channels that have additional parents. let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?; { let mut channels_to_keep = channel_path::Entity::find() .filter( channel_path::Column::ChannelId .is_in( channels_to_remove .keys() .copied() .filter(|&id| id != channel_id), ) .and( channel_path::Column::IdPath .not_like(&format!("%/{}/%", channel_id)), ), ) .stream(&*tx) .await?; while let Some(row) = channels_to_keep.next().await { let row = row?; channels_to_remove.remove(&row.channel_id); } } let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?; let members_to_notify: Vec = channel_member::Entity::find() .filter(channel_member::Column::ChannelId.is_in(channel_ancestors)) .select_only() .column(channel_member::Column::UserId) .distinct() .into_values::<_, QueryUserIds>() .all(&*tx) .await?; channel::Entity::delete_many() .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied())) .exec(&*tx) .await?; // Delete any other paths that incldue this channel let sql = r#" DELETE FROM channel_paths WHERE id_path LIKE '%' || $1 || '%' "#; let channel_paths_stmt = Statement::from_sql_and_values( self.pool.get_database_backend(), sql, [channel_id.to_proto().into()], ); tx.execute(channel_paths_stmt).await?; Ok((channels_to_remove.into_keys().collect(), members_to_notify)) }) .await } pub async fn invite_channel_member( &self, channel_id: ChannelId, invitee_id: UserId, inviter_id: UserId, is_admin: bool, ) -> Result<()> { self.transaction(move |tx| async move { self.check_user_is_channel_admin(channel_id, inviter_id, &*tx) .await?; channel_member::ActiveModel { channel_id: ActiveValue::Set(channel_id), user_id: ActiveValue::Set(invitee_id), accepted: ActiveValue::Set(false), admin: ActiveValue::Set(is_admin), ..Default::default() } .insert(&*tx) .await?; Ok(()) }) .await } fn sanitize_channel_name(name: &str) -> Result<&str> { let new_name = name.trim().trim_start_matches('#'); if new_name == "" { Err(anyhow!("channel name can't be blank"))?; } Ok(new_name) } pub async fn rename_channel( &self, channel_id: ChannelId, user_id: UserId, new_name: &str, ) -> Result { self.transaction(move |tx| async move { let new_name = Self::sanitize_channel_name(new_name)?.to_string(); self.check_user_is_channel_admin(channel_id, user_id, &*tx) .await?; channel::ActiveModel { id: ActiveValue::Unchanged(channel_id), name: ActiveValue::Set(new_name.clone()), ..Default::default() } .update(&*tx) .await?; Ok(new_name) }) .await } pub async fn respond_to_channel_invite( &self, channel_id: ChannelId, user_id: UserId, accept: bool, ) -> Result<()> { self.transaction(move |tx| async move { let rows_affected = if accept { channel_member::Entity::update_many() .set(channel_member::ActiveModel { accepted: ActiveValue::Set(accept), ..Default::default() }) .filter( channel_member::Column::ChannelId .eq(channel_id) .and(channel_member::Column::UserId.eq(user_id)) .and(channel_member::Column::Accepted.eq(false)), ) .exec(&*tx) .await? .rows_affected } else { channel_member::ActiveModel { channel_id: ActiveValue::Unchanged(channel_id), user_id: ActiveValue::Unchanged(user_id), ..Default::default() } .delete(&*tx) .await? .rows_affected }; if rows_affected == 0 { Err(anyhow!("no such invitation"))?; } Ok(()) }) .await } pub async fn remove_channel_member( &self, channel_id: ChannelId, member_id: UserId, remover_id: UserId, ) -> Result<()> { self.transaction(|tx| async move { self.check_user_is_channel_admin(channel_id, remover_id, &*tx) .await?; let result = channel_member::Entity::delete_many() .filter( channel_member::Column::ChannelId .eq(channel_id) .and(channel_member::Column::UserId.eq(member_id)), ) .exec(&*tx) .await?; if result.rows_affected == 0 { Err(anyhow!("no such member"))?; } Ok(()) }) .await } pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result> { self.transaction(|tx| async move { let channel_invites = channel_member::Entity::find() .filter( channel_member::Column::UserId .eq(user_id) .and(channel_member::Column::Accepted.eq(false)), ) .all(&*tx) .await?; let channels = channel::Entity::find() .filter( channel::Column::Id.is_in( channel_invites .into_iter() .map(|channel_member| channel_member.channel_id), ), ) .all(&*tx) .await?; let channels = channels .into_iter() .map(|channel| Channel { id: channel.id, name: channel.name, parent_id: None, }) .collect(); Ok(channels) }) .await } async fn get_all_channels( &self, parents_by_child_id: ChannelDescendants, tx: &DatabaseTransaction, ) -> Result> { let mut channels = Vec::with_capacity(parents_by_child_id.len()); { let mut rows = channel::Entity::find() .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied())) .stream(&*tx) .await?; while let Some(row) = rows.next().await { let row = row?; // As these rows are pulled from the map's keys, this unwrap is safe. let parents = parents_by_child_id.get(&row.id).unwrap(); if parents.len() > 0 { for parent in parents { channels.push(Channel { id: row.id, name: row.name.clone(), parent_id: Some(*parent), }); } } else { channels.push(Channel { id: row.id, name: row.name, parent_id: None, }); } } } Ok(channels) } pub async fn get_channels_for_user(&self, user_id: UserId) -> Result { self.transaction(|tx| async move { let tx = tx; let channel_memberships = channel_member::Entity::find() .filter( channel_member::Column::UserId .eq(user_id) .and(channel_member::Column::Accepted.eq(true)), ) .all(&*tx) .await?; let parents_by_child_id = self .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx) .await?; let channels_with_admin_privileges = channel_memberships .iter() .filter_map(|membership| membership.admin.then_some(membership.channel_id)) .collect(); let channels = self.get_all_channels(parents_by_child_id, &tx).await?; #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryUserIdsAndChannelIds { ChannelId, UserId, } let mut channel_participants: HashMap> = HashMap::default(); { let mut rows = room_participant::Entity::find() .inner_join(room::Entity) .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id))) .select_only() .column(room::Column::ChannelId) .column(room_participant::Column::UserId) .into_values::<_, QueryUserIdsAndChannelIds>() .stream(&*tx) .await?; while let Some(row) = rows.next().await { let row: (ChannelId, UserId) = row?; channel_participants.entry(row.0).or_default().push(row.1) } } Ok(ChannelsForUser { channels, channel_participants, channels_with_admin_privileges, }) }) .await } pub async fn get_channel_members(&self, id: ChannelId) -> Result> { self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await }) .await } pub async fn set_channel_member_admin( &self, channel_id: ChannelId, from: UserId, for_user: UserId, admin: bool, ) -> Result<()> { self.transaction(|tx| async move { self.check_user_is_channel_admin(channel_id, from, &*tx) .await?; let result = channel_member::Entity::update_many() .filter( channel_member::Column::ChannelId .eq(channel_id) .and(channel_member::Column::UserId.eq(for_user)), ) .set(channel_member::ActiveModel { admin: ActiveValue::set(admin), ..Default::default() }) .exec(&*tx) .await?; if result.rows_affected == 0 { Err(anyhow!("no such member"))?; } Ok(()) }) .await } pub async fn get_channel_member_details( &self, channel_id: ChannelId, user_id: UserId, ) -> Result> { self.transaction(|tx| async move { self.check_user_is_channel_admin(channel_id, user_id, &*tx) .await?; #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryMemberDetails { UserId, Admin, IsDirectMember, Accepted, } let tx = tx; let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?; let mut stream = channel_member::Entity::find() .distinct() .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied())) .select_only() .column(channel_member::Column::UserId) .column(channel_member::Column::Admin) .column_as( channel_member::Column::ChannelId.eq(channel_id), QueryMemberDetails::IsDirectMember, ) .column(channel_member::Column::Accepted) .order_by_asc(channel_member::Column::UserId) .into_values::<_, QueryMemberDetails>() .stream(&*tx) .await?; let mut rows = Vec::::new(); while let Some(row) = stream.next().await { let (user_id, is_admin, is_direct_member, is_invite_accepted): ( UserId, bool, bool, bool, ) = row?; let kind = match (is_direct_member, is_invite_accepted) { (true, true) => proto::channel_member::Kind::Member, (true, false) => proto::channel_member::Kind::Invitee, (false, true) => proto::channel_member::Kind::AncestorMember, (false, false) => continue, }; let user_id = user_id.to_proto(); let kind = kind.into(); if let Some(last_row) = rows.last_mut() { if last_row.user_id == user_id { if is_direct_member { last_row.kind = kind; last_row.admin = is_admin; } continue; } } rows.push(proto::ChannelMember { user_id, kind, admin: is_admin, }); } Ok(rows) }) .await } pub async fn get_channel_members_internal( &self, id: ChannelId, tx: &DatabaseTransaction, ) -> Result> { let ancestor_ids = self.get_channel_ancestors(id, tx).await?; let user_ids = channel_member::Entity::find() .distinct() .filter( channel_member::Column::ChannelId .is_in(ancestor_ids.iter().copied()) .and(channel_member::Column::Accepted.eq(true)), ) .select_only() .column(channel_member::Column::UserId) .into_values::<_, QueryUserIds>() .all(&*tx) .await?; Ok(user_ids) } pub async fn check_user_is_channel_member( &self, channel_id: ChannelId, user_id: UserId, tx: &DatabaseTransaction, ) -> Result<()> { let channel_ids = self.get_channel_ancestors(channel_id, tx).await?; channel_member::Entity::find() .filter( channel_member::Column::ChannelId .is_in(channel_ids) .and(channel_member::Column::UserId.eq(user_id)), ) .one(&*tx) .await? .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?; Ok(()) } pub async fn check_user_is_channel_admin( &self, channel_id: ChannelId, user_id: UserId, tx: &DatabaseTransaction, ) -> Result<()> { let channel_ids = self.get_channel_ancestors(channel_id, tx).await?; channel_member::Entity::find() .filter( channel_member::Column::ChannelId .is_in(channel_ids) .and(channel_member::Column::UserId.eq(user_id)) .and(channel_member::Column::Admin.eq(true)), ) .one(&*tx) .await? .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?; Ok(()) } /// Returns the channel ancestors, deepest first pub async fn get_channel_ancestors( &self, channel_id: ChannelId, tx: &DatabaseTransaction, ) -> Result> { let paths = channel_path::Entity::find() .filter(channel_path::Column::ChannelId.eq(channel_id)) .order_by(channel_path::Column::IdPath, sea_query::Order::Desc) .all(tx) .await?; let mut channel_ids = Vec::new(); for path in paths { for id in path.id_path.trim_matches('/').split('/') { if let Ok(id) = id.parse() { let id = ChannelId::from_proto(id); if let Err(ix) = channel_ids.binary_search(&id) { channel_ids.insert(ix, id); } } } } Ok(channel_ids) } /// Returns the channel descendants, /// Structured as a map from child ids to their parent ids /// For example, the descendants of 'a' in this DAG: /// /// /- b -\ /// a -- c -- d /// /// would be: /// { /// a: [], /// b: [a], /// c: [a], /// d: [a, c], /// } async fn get_channel_descendants( &self, channel_ids: impl IntoIterator, tx: &DatabaseTransaction, ) -> Result { let mut values = String::new(); for id in channel_ids { if !values.is_empty() { values.push_str(", "); } write!(&mut values, "({})", id).unwrap(); } if values.is_empty() { return Ok(HashMap::default()); } let sql = format!( r#" SELECT descendant_paths.* FROM channel_paths parent_paths, channel_paths descendant_paths WHERE parent_paths.channel_id IN ({values}) AND descendant_paths.id_path LIKE (parent_paths.id_path || '%') "# ); let stmt = Statement::from_string(self.pool.get_database_backend(), sql); let mut parents_by_child_id: ChannelDescendants = HashMap::default(); let mut paths = channel_path::Entity::find() .from_raw_sql(stmt) .stream(tx) .await?; while let Some(path) = paths.next().await { let path = path?; let ids = path.id_path.trim_matches('/').split('/'); let mut parent_id = None; for id in ids { if let Ok(id) = id.parse() { let id = ChannelId::from_proto(id); if id == path.channel_id { break; } parent_id = Some(id); } } let entry = parents_by_child_id.entry(path.channel_id).or_default(); if let Some(parent_id) = parent_id { entry.insert(parent_id); } } Ok(parents_by_child_id) } /// Returns the channel with the given ID and: /// - true if the user is a member /// - false if the user hasn't accepted the invitation yet pub async fn get_channel( &self, channel_id: ChannelId, user_id: UserId, ) -> Result> { self.transaction(|tx| async move { let tx = tx; let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?; if let Some(channel) = channel { if self .check_user_is_channel_member(channel_id, user_id, &*tx) .await .is_err() { return Ok(None); } let channel_membership = channel_member::Entity::find() .filter( channel_member::Column::ChannelId .eq(channel_id) .and(channel_member::Column::UserId.eq(user_id)), ) .one(&*tx) .await?; let is_accepted = channel_membership .map(|membership| membership.accepted) .unwrap_or(false); Ok(Some(( Channel { id: channel.id, name: channel.name, parent_id: None, }, is_accepted, ))) } else { Ok(None) } }) .await } pub async fn room_id_for_channel(&self, channel_id: ChannelId) -> Result { self.transaction(|tx| async move { let tx = tx; let room = channel::Model { id: channel_id, ..Default::default() } .find_related(room::Entity) .one(&*tx) .await? .ok_or_else(|| anyhow!("invalid channel"))?; Ok(room.id) }) .await } async fn link_channel( &self, from: ChannelId, to: ChannelId, tx: &DatabaseTransaction, ) -> Result { let to_ancestors = self.get_channel_ancestors(to, &*tx).await?; let from_descendants = self.get_channel_descendants([from], &*tx).await?; for ancestor in to_ancestors { if from_descendants.contains_key(&ancestor) { return Err(anyhow!("Cannot create a channel cycle").into()); } } let sql = r#" INSERT INTO channel_paths (id_path, channel_id) SELECT id_path || $1 || '/', $2 FROM channel_paths WHERE channel_id = $3 ON CONFLICT (id_path) DO NOTHING; "#; let channel_paths_stmt = Statement::from_sql_and_values( self.pool.get_database_backend(), sql, [ from.to_proto().into(), from.to_proto().into(), to.to_proto().into(), ], ); tx.execute(channel_paths_stmt).await?; for (from_id, to_ids) in from_descendants.iter().filter(|(id, _)| id != &&from) { for to_id in to_ids { let channel_paths_stmt = Statement::from_sql_and_values( self.pool.get_database_backend(), sql, [ from_id.to_proto().into(), from_id.to_proto().into(), to_id.to_proto().into(), ], ); tx.execute(channel_paths_stmt).await?; } } Ok(from_descendants) } async fn remove_channel_from_parent( &self, from: ChannelId, parent: ChannelId, tx: &DatabaseTransaction, ) -> Result<()> { let sql = r#" DELETE FROM channel_paths WHERE id_path LIKE '%' || $1 || '/' || $2 || '%' "#; let channel_paths_stmt = Statement::from_sql_and_values( self.pool.get_database_backend(), sql, [parent.to_proto().into(), from.to_proto().into()], ); tx.execute(channel_paths_stmt).await?; Ok(()) } /// Move a channel from one parent to another. /// Note that this requires a valid parent_id in the 'from_parent' field. /// As channels are a DAG, we need to know which parent to remove the channel from. /// Here's a list of the parameters to this function and their behavior: /// /// - (`None`, `None`) No op /// - (`None`, `Some(id)`) Link the channel without removing it from any of it's parents /// - (`Some(id)`, `None`) Remove a channel from a given parent, and leave other parents /// - (`Some(id)`, `Some(id)`) Move channel from one parent to another, leaving other parents /// /// Returns the channel that was moved + it's sub channels pub async fn move_channel( &self, user: UserId, from: ChannelId, from_parent: Option, to: Option, ) -> Result> { self.transaction(|tx| async move { // Note that even with these maxed permissions, this linking operation // is still insecure because you can't remove someone's permissions to a // channel if they've linked the channel to one where they're an admin. self.check_user_is_channel_admin(from, user, &*tx).await?; let mut channel_descendants = None; if let Some(from_parent) = from_parent { self.check_user_is_channel_admin(from_parent, user, &*tx) .await?; self.remove_channel_from_parent(from, from_parent, &*tx) .await?; } if let Some(to) = to { self.check_user_is_channel_admin(to, user, &*tx).await?; channel_descendants = Some(self.link_channel(from, to, &*tx).await?); } let mut channel_descendants = match channel_descendants { Some(channel_descendants) => channel_descendants, None => self.get_channel_descendants([from], &*tx).await?, }; // Repair the parent ID of the channel in case it was from a cached call if let Some(channel) = channel_descendants.get_mut(&from) { if let Some(from_parent) = from_parent { channel.remove(&from_parent); } if let Some(to) = to { channel.insert(to); } } let channels = self .get_all_channels(channel_descendants, &*tx) .await?; Ok(channels) }) .await } } #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryUserIds { UserId, }