diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 792c65b075..c580e911bc 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -72,7 +72,6 @@ fs = { path = "../fs", features = ["test-support"] } git = { path = "../git", features = ["test-support"] } live_kit_client = { path = "../live_kit_client", features = ["test-support"] } lsp = { path = "../lsp", features = ["test-support"] } -pretty_assertions.workspace = true project = { path = "../project", features = ["test-support"] } rpc = { path = "../rpc", features = ["test-support"] } settings = { path = "../settings", features = ["test-support"] } @@ -81,6 +80,7 @@ workspace = { path = "../workspace", features = ["test-support"] } collab_ui = { path = "../collab_ui", features = ["test-support"] } async-trait.workspace = true +pretty_assertions.workspace = true ctor.workspace = true env_logger.workspace = true indoc.workspace = true diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index 4705cc9415..f31a1cde5d 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -68,6 +68,8 @@ impl Database { ], ); tx.execute(channel_paths_stmt).await?; + + dbg!(channel_path::Entity::find().all(&*tx).await?); } else { channel_path::Entity::insert(channel_path::ActiveModel { channel_id: ActiveValue::Set(channel.id), @@ -336,6 +338,8 @@ impl Database { .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx) .await?; + dbg!(&parents_by_child_id); + let channels_with_admin_privileges = channel_memberships .iter() .filter_map(|membership| membership.admin.then_some(membership.channel_id)) @@ -349,11 +353,24 @@ impl Database { .await?; while let Some(row) = rows.next().await { let row = row?; - channels.push(Channel { - id: row.id, - name: row.name, - parent_id: parents_by_child_id.get(&row.id).copied().flatten(), - }); + + // As these rows are pulled from the map's keys, this unwrap is safe. + let parents = parents_by_child_id.get(&row.id).unwrap(); + if parents.len() > 0 { + for parent in parents { + channels.push(Channel { + id: row.id, + name: row.name.clone(), + parent_id: Some(*parent), + }); + } + } else { + channels.push(Channel { + id: row.id, + name: row.name, + parent_id: None, + }); + } } } @@ -559,6 +576,7 @@ impl Database { Ok(()) } + /// Returns the channel ancestors, deepest first pub async fn get_channel_ancestors( &self, channel_id: ChannelId, @@ -566,6 +584,7 @@ impl Database { ) -> Result> { let paths = channel_path::Entity::find() .filter(channel_path::Column::ChannelId.eq(channel_id)) + .order_by(channel_path::Column::IdPath, sea_query::Order::Desc) .all(tx) .await?; let mut channel_ids = Vec::new(); @@ -586,7 +605,7 @@ impl Database { &self, channel_ids: impl IntoIterator, tx: &DatabaseTransaction, - ) -> Result>> { + ) -> Result>> { let mut values = String::new(); for id in channel_ids { if !values.is_empty() { @@ -613,7 +632,7 @@ impl Database { let stmt = Statement::from_string(self.pool.get_database_backend(), sql); - let mut parents_by_child_id = HashMap::default(); + let mut parents_by_child_id: HashMap> = HashMap::default(); let mut paths = channel_path::Entity::find() .from_raw_sql(stmt) .stream(tx) @@ -632,7 +651,10 @@ impl Database { parent_id = Some(id); } } - parents_by_child_id.insert(path.channel_id, parent_id); + let entry = parents_by_child_id.entry(path.channel_id).or_default(); + if let Some(parent_id) = parent_id { + entry.insert(parent_id); + } } Ok(parents_by_child_id) @@ -704,12 +726,74 @@ impl Database { .await } + pub async fn link_channel(&self, user: UserId, from: ChannelId, to: ChannelId) -> Result<()> { + self.transaction(|tx| async move { + self.check_user_is_channel_admin(to, user, &*tx).await?; + + // TODO: Downgrade this check once our permissions system isn't busted + // You should be able to safely link a member channel for your own uses. See: + // https://zed.dev/blog/this-week-at-zed-15 > Mikayla's section + // + // Note that even with these higher permissions, this linking operation + // is still insecure because you can't remove someone's permissions to a + // channel if they've linked the channel to one where they're an admin. + self.check_user_is_channel_admin(from, user, &*tx).await?; + + let to_ancestors = self.get_channel_ancestors(to, &*tx).await?; + let from_descendants = self.get_channel_descendants([from], &*tx).await?; + for ancestor in to_ancestors { + if from_descendants.contains_key(&ancestor) { + return Err(anyhow!("Cannot create a channel cycle").into()); + } + } + + let sql = r#" + INSERT INTO channel_paths + (id_path, channel_id) + SELECT + id_path || $1 || '/', $2 + FROM + channel_paths + WHERE + channel_id = $3 + ON CONFLICT (id_path) DO NOTHING; + "#; + let channel_paths_stmt = Statement::from_sql_and_values( + self.pool.get_database_backend(), + sql, + [ + from.to_proto().into(), + from.to_proto().into(), + to.to_proto().into(), + ], + ); + tx.execute(channel_paths_stmt).await?; + + for (from_id, to_ids) in from_descendants.iter().filter(|(id, _)| id == &&from) { + for to_id in to_ids { + let channel_paths_stmt = Statement::from_sql_and_values( + self.pool.get_database_backend(), + sql, + [ + from_id.to_proto().into(), + from_id.to_proto().into(), + to_id.to_proto().into(), + ], + ); + tx.execute(channel_paths_stmt).await?; + } + } + + Ok(()) + }) + .await + } + pub async fn move_channel( &self, user: UserId, from: ChannelId, to: Option, - link: bool, ) -> Result<()> { self.transaction(|tx| async move { todo!() }).await } diff --git a/crates/collab/src/db/tests/channel_tests.rs b/crates/collab/src/db/tests/channel_tests.rs index 04304ec848..e077950a3a 100644 --- a/crates/collab/src/db/tests/channel_tests.rs +++ b/crates/collab/src/db/tests/channel_tests.rs @@ -486,14 +486,24 @@ async fn test_channels_moving(db: &Arc) { .await .unwrap(); + let gpui2_id = db + .create_channel("gpui2", Some(zed_id), "3", a_id) + .await + .unwrap(); + let livestreaming_id = db - .create_channel("livestreaming", Some(crdb_id), "3", a_id) + .create_channel("livestreaming", Some(crdb_id), "4", a_id) + .await + .unwrap(); + + let livestreaming_dag_id = db + .create_channel("livestreaming_dag", Some(livestreaming_id), "5", a_id) .await .unwrap(); // sanity check let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_eq!( + pretty_assertions::assert_eq!( result.channels, vec![ Channel { @@ -506,33 +516,46 @@ async fn test_channels_moving(db: &Arc) { name: "crdb".to_string(), parent_id: Some(zed_id), }, + Channel { + id: gpui2_id, + name: "gpui2".to_string(), + parent_id: Some(zed_id), + }, Channel { id: livestreaming_id, name: "livestreaming".to_string(), parent_id: Some(crdb_id), }, + Channel { + id: livestreaming_dag_id, + name: "livestreaming_dag".to_string(), + parent_id: Some(livestreaming_id), + }, ] ); + // Initial DAG: + // /- gpui2 + // zed -- crdb - livestreaming - livestreaming_dag - // Move channel up - db.move_channel(a_id, livestreaming_id, Some(zed_id), false) - .await - .unwrap(); - - // Attempt to make a cycle + // Attemp to make a cycle assert!(db - .move_channel(a_id, zed_id, Some(livestreaming_id), false) + .link_channel(a_id, zed_id, livestreaming_id) .await .is_err()); // Make a link - db.move_channel(a_id, crdb_id, Some(livestreaming_id), true) + db.link_channel(a_id, livestreaming_id, zed_id) .await .unwrap(); + // DAG is now: + // /- gpui2 + // zed -- crdb - livestreaming - livestreaming_dag + // \---------/ + let result = db.get_channels_for_user(a_id).await.unwrap(); - assert_eq!( - result.channels, + pretty_assertions::assert_eq!( + dbg!(result.channels), vec![ Channel { id: zed_id, @@ -545,15 +568,234 @@ async fn test_channels_moving(db: &Arc) { parent_id: Some(zed_id), }, Channel { - id: crdb_id, - name: "crdb".to_string(), - parent_id: Some(livestreaming_id), + id: gpui2_id, + name: "gpui2".to_string(), + parent_id: Some(zed_id), }, Channel { id: livestreaming_id, name: "livestreaming".to_string(), parent_id: Some(zed_id), }, + Channel { + id: livestreaming_id, + name: "livestreaming".to_string(), + parent_id: Some(crdb_id), + }, + Channel { + id: livestreaming_dag_id, + name: "livestreaming_dag".to_string(), + parent_id: Some(livestreaming_id), + }, ] ); + + let livestreaming_dag_sub_id = db + .create_channel("livestreaming_dag_sub", Some(livestreaming_dag_id), "6", a_id) + .await + .unwrap(); + + // DAG is now: + // /- gpui2 + // zed -- crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id + // \---------/ + + let result = db.get_channels_for_user(a_id).await.unwrap(); + pretty_assertions::assert_eq!( + dbg!(result.channels), + vec![ + Channel { + id: zed_id, + name: "zed".to_string(), + parent_id: None, + }, + Channel { + id: crdb_id, + name: "crdb".to_string(), + parent_id: Some(zed_id), + }, + Channel { + id: gpui2_id, + name: "gpui2".to_string(), + parent_id: Some(zed_id), + }, + Channel { + id: livestreaming_id, + name: "livestreaming".to_string(), + parent_id: Some(zed_id), + }, + Channel { + id: livestreaming_id, + name: "livestreaming".to_string(), + parent_id: Some(crdb_id), + }, + Channel { + id: livestreaming_dag_id, + name: "livestreaming_dag".to_string(), + parent_id: Some(livestreaming_id), + }, + Channel { + id: livestreaming_dag_sub_id, + name: "livestreaming_dag_sub".to_string(), + parent_id: Some(livestreaming_dag_id), + }, + ] + ); + + // Make a link + db.link_channel(a_id, livestreaming_dag_sub_id, livestreaming_id) + .await + .unwrap(); + + // DAG is now: + // /- gpui2 /---------------------\ + // zed - crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id + // \--------/ + + let result = db.get_channels_for_user(a_id).await.unwrap(); + pretty_assertions::assert_eq!( + dbg!(result.channels), + vec![ + Channel { + id: zed_id, + name: "zed".to_string(), + parent_id: None, + }, + Channel { + id: crdb_id, + name: "crdb".to_string(), + parent_id: Some(zed_id), + }, + Channel { + id: gpui2_id, + name: "gpui2".to_string(), + parent_id: Some(zed_id), + }, + Channel { + id: livestreaming_id, + name: "livestreaming".to_string(), + parent_id: Some(zed_id), + }, + Channel { + id: livestreaming_id, + name: "livestreaming".to_string(), + parent_id: Some(crdb_id), + }, + Channel { + id: livestreaming_dag_id, + name: "livestreaming_dag".to_string(), + parent_id: Some(livestreaming_id), + }, + Channel { + id: livestreaming_dag_sub_id, + name: "livestreaming_dag_sub".to_string(), + parent_id: Some(livestreaming_id), + }, + Channel { + id: livestreaming_dag_sub_id, + name: "livestreaming_dag_sub".to_string(), + parent_id: Some(livestreaming_dag_id), + }, + ] + ); + + // Make another link + db.link_channel(a_id, livestreaming_id, gpui2_id) + .await + .unwrap(); + + // DAG is now: + // /- gpui2 -\ /---------------------\ + // zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub_id + // \---------/ + + let result = db.get_channels_for_user(a_id).await.unwrap(); + pretty_assertions::assert_eq!( + dbg!(result.channels), + vec![ + Channel { + id: zed_id, + name: "zed".to_string(), + parent_id: None, + }, + Channel { + id: crdb_id, + name: "crdb".to_string(), + parent_id: Some(zed_id), + }, + Channel { + id: gpui2_id, + name: "gpui2".to_string(), + parent_id: Some(zed_id), + }, + Channel { + id: livestreaming_id, + name: "livestreaming".to_string(), + parent_id: Some(gpui2_id), + }, + Channel { + id: livestreaming_id, + name: "livestreaming".to_string(), + parent_id: Some(zed_id), + }, + Channel { + id: livestreaming_id, + name: "livestreaming".to_string(), + parent_id: Some(crdb_id), + }, + Channel { + id: livestreaming_dag_id, + name: "livestreaming_dag".to_string(), + parent_id: Some(livestreaming_id), + }, + Channel { + id: livestreaming_dag_sub_id, + name: "livestreaming_dag_sub".to_string(), + parent_id: Some(livestreaming_id), + }, + Channel { + id: livestreaming_dag_sub_id, + name: "livestreaming_dag_sub".to_string(), + parent_id: Some(livestreaming_dag_id), + }, + ] + ); + + // // Attempt to make a cycle + // assert!(db + // .move_channel(a_id, zed_id, Some(livestreaming_id)) + // .await + // .is_err()); + + // // Move channel up + // db.move_channel(a_id, livestreaming_id, Some(zed_id)) + // .await + // .unwrap(); + + // let result = db.get_channels_for_user(a_id).await.unwrap(); + // pretty_assertions::assert_eq!( + // result.channels, + // vec![ + // Channel { + // id: zed_id, + // name: "zed".to_string(), + // parent_id: None, + // }, + // Channel { + // id: crdb_id, + // name: "crdb".to_string(), + // parent_id: Some(zed_id), + // }, + // Channel { + // id: crdb_id, + // name: "crdb".to_string(), + // parent_id: Some(livestreaming_id), + // }, + // Channel { + // id: livestreaming_id, + // name: "livestreaming".to_string(), + // parent_id: Some(zed_id), + // }, + // ] + // ); }