From 0998440bddf09a8425a6890dc8bc89052d62feb9 Mon Sep 17 00:00:00 2001 From: Mikayla Maki Date: Fri, 28 Jul 2023 13:14:24 -0700 Subject: [PATCH] implement recursive channel query --- .../20221109000000_test_schema.sql | 3 - crates/collab/src/db.rs | 256 +++++++++--------- crates/collab/src/db/channel.rs | 1 - crates/collab/src/db/channel_parent.rs | 3 + crates/collab/src/tests/channel_tests.rs | 66 ++++- 5 files changed, 191 insertions(+), 138 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index ed7459e4a0..b397438e27 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -187,7 +187,6 @@ CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id"); CREATE TABLE "channels" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, - -- "id_path" TEXT NOT NULL, "name" VARCHAR NOT NULL, "room_id" INTEGER REFERENCES rooms (id) ON DELETE SET NULL, "created_at" TIMESTAMP NOT NULL DEFAULT now @@ -199,8 +198,6 @@ CREATE TABLE "channel_parents" ( PRIMARY KEY(child_id, parent_id) ); --- CREATE UNIQUE INDEX "index_channels_on_id_path" ON "channels" ("id_path"); - CREATE TABLE "channel_members" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index c8bec8a3f9..5755ed73e2 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1,7 +1,7 @@ mod access_token; mod channel; mod channel_member; -// mod channel_parent; +mod channel_parent; mod contact; mod follower; mod language_server; @@ -39,7 +39,10 @@ use sea_orm::{ DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement, TransactionTrait, }; -use sea_query::{Alias, Expr, OnConflict, Query, SelectStatement}; +use sea_query::{ + Alias, ColumnRef, CommonTableExpression, Expr, OnConflict, Order, Query, QueryStatementWriter, + SelectStatement, UnionType, WithClause, +}; use serde::{Deserialize, Serialize}; pub use signup::{Invite, NewSignup, WaitlistSummary}; use sqlx::migrate::{Migrate, Migration, MigrationSource}; @@ -3032,7 +3035,11 @@ impl Database { // channels - pub async fn create_channel(&self, name: &str) -> Result { + pub async fn create_root_channel(&self, name: &str) -> Result { + self.create_channel(name, None).await + } + + pub async fn create_channel(&self, name: &str, parent: Option) -> Result { self.transaction(move |tx| async move { let tx = tx; @@ -3043,10 +3050,21 @@ impl Database { let channel = channel.insert(&*tx).await?; + if let Some(parent) = parent { + channel_parent::ActiveModel { + child_id: ActiveValue::Set(channel.id), + parent_id: ActiveValue::Set(parent), + } + .insert(&*tx) + .await?; + } + Ok(channel.id) - }).await + }) + .await } + // Property: Members are only pub async fn add_channel_member(&self, channel_id: ChannelId, user_id: UserId) -> Result<()> { self.transaction(move |tx| async move { let tx = tx; @@ -3060,139 +3078,108 @@ impl Database { channel_membership.insert(&*tx).await?; Ok(()) - }).await + }) + .await } - pub async fn get_channels(&self, user_id: UserId) -> Vec { + pub async fn get_channels(&self, user_id: UserId) -> Result> { self.transaction(|tx| async move { let tx = tx; + // This is the SQL statement we want to generate: + let sql = r#" + WITH RECURSIVE channel_tree(child_id, parent_id, depth) AS ( + SELECT channel_id as child_id, NULL as parent_id, 0 + FROM channel_members + WHERE user_id = ? + UNION ALL + SELECT channel_parents.child_id, channel_parents.parent_id, channel_tree.depth + 1 + FROM channel_parents, channel_tree + WHERE channel_parents.parent_id = channel_tree.child_id + ) + SELECT channel_tree.child_id as id, channels.name, channel_tree.parent_id + FROM channel_tree + JOIN channels ON channels.id = channel_tree.child_id + ORDER BY channel_tree.depth; + "#; + + // let root_channel_ids_query = SelectStatement::new() + // .column(channel_member::Column::ChannelId) + // .expr(Expr::value("NULL")) + // .from(channel_member::Entity.table_ref()) + // .and_where( + // Expr::col(channel_member::Column::UserId) + // .eq(Expr::cust_with_values("?", vec![user_id])), + // ); + + // let build_tree_query = SelectStatement::new() + // .column(channel_parent::Column::ChildId) + // .column(channel_parent::Column::ParentId) + // .expr(Expr::col(Alias::new("channel_tree.depth")).add(1i32)) + // .from(Alias::new("channel_tree")) + // .and_where( + // Expr::col(channel_parent::Column::ParentId) + // .equals(Alias::new("channel_tree"), Alias::new("child_id")), + // ) + // .to_owned(); + + // let common_table_expression = CommonTableExpression::new() + // .query( + // root_channel_ids_query + // .union(UnionType::Distinct, build_tree_query) + // .to_owned(), + // ) + // .column(Alias::new("child_id")) + // .column(Alias::new("parent_id")) + // .column(Alias::new("depth")) + // .table_name(Alias::new("channel_tree")) + // .to_owned(); + + // let select = SelectStatement::new() + // .expr_as( + // Expr::col(Alias::new("channel_tree.child_id")), + // Alias::new("id"), + // ) + // .column(channel::Column::Name) + // .column(Alias::new("channel_tree.parent_id")) + // .from(Alias::new("channel_tree")) + // .inner_join( + // channel::Entity.table_ref(), + // Expr::eq( + // channel::Column::Id.into_expr(), + // Expr::tbl(Alias::new("channel_tree"), Alias::new("child_id")), + // ), + // ) + // .order_by(Alias::new("channel_tree.child_id"), Order::Asc) + // .to_owned(); + + // let with_clause = WithClause::new() + // .recursive(true) + // .cte(common_table_expression) + // .to_owned(); + + // let query = select.with(with_clause); + + // let query = SelectStatement::new() + // .column(ColumnRef::Asterisk) + // .from_subquery(query, Alias::new("channel_tree") + // .to_owned(); + + // let stmt = self.pool.get_database_backend().build(&query); + + let stmt = Statement::from_sql_and_values( + self.pool.get_database_backend(), + sql, + vec![user_id.into()], + ); + + Ok(channel_parent::Entity::find() + .from_raw_sql(stmt) + .into_model::() + .all(&*tx) + .await?) }) - // let user = user::Model { - // id: user_id, - // ..Default::default() - // }; - // let mut channel_ids = user - // .find_related(channel_member::Entity) - // .select_only() - // .column(channel_member::Column::ChannelId) - // .all(&*tx) - // .await; - - // // let descendants = Alias::new("descendants"); - // // let cte_referencing = SelectStatement::new() - // // .column(channel_parent::Column::ChildId) - // // .from(channel::Entity) - // // .and_where( - // // Expr::col(channel_parent::Column::ParentId) - // // .in_subquery(SelectStatement::new().from(descendants).take()) - // // ); - - // // /* - // // WITH RECURSIVE descendant_ids(id) AS ( - // // $1 - // // UNION ALL - // // SELECT child_id as id FROM channel_parents WHERE parent_id IN descendants - // // ) - // // SELECT * from channels where id in descendant_ids - // // */ - - - // // // WITH RECURSIVE descendants(id) AS ( - // // // // SQL QUERY FOR SELECTING Initial IDs - // // // UNION - // // // SELECT id FROM ancestors WHERE p.parent = id - // // // ) - // // // SELECT * FROM descendants; - - - - // // // let descendant_channel_ids = - - - - // // // let query = sea_query::Query::with().recursive(true); - - - // // for id_path in id_paths { - // // // - // // } - - - // // // zed/public/plugins - // // // zed/public/plugins/js - // // // zed/zed-livekit - // // // livekit/zed-livekit - // // // zed - 101 - // // // livekit - 500 - // // // zed-livekit - 510 - // // // public - 150 - // // // plugins - 200 - // // // js - 300 - // // // - // // // Channel, Parent - edges - // // // 510 - 500 - // // // 510 - 101 - // // // - // // // Given the channel 'Zed' (101) - // // // Select * from EDGES where parent = 101 => 510 - // // // - - - // // "SELECT * from channels where id_path like '$1?'" - - // // // https://www.postgresql.org/docs/current/queries-with.html - // // // https://www.sqlite.org/lang_with.html - - // // "SELECT channel_id from channel_ancestors where ancestor_id IN $()" - - // // // | channel_id | ancestor_ids | - // // // 150 150 - // // // 150 101 - // // // 200 101 - // // // 300 101 - // // // 200 150 - // // // 300 150 - // // // 300 200 - // // // - // // // // | channel_id | ancestor_ids | - // // // 150 101 - // // // 200 101 - // // // 300 101 - // // // 200 150 - // // // 300 [150, 200] - - // // channel::Entity::find() - // // .filter(channel::Column::IdPath.like(id_paths.unwrap())) - - // // dbg!(&id_paths.unwrap()[0].id_path); - - // // // let mut channel_members_by_channel_id = HashMap::new(); - // // // for channel_member in channel_members { - // // // channel_members_by_channel_id - // // // .entry(channel_member.channel_id) - // // // .or_insert_with(Vec::new) - // // // .push(channel_member); - // // // } - - // // // let mut channel_messages = channel_message::Entity::find() - // // // .filter(channel_message::Column::ChannelId.in_selection(channel_ids)) - // // // .all(&*tx) - // // // .await?; - - // // // let mut channel_messages_by_channel_id = HashMap::new(); - // // // for channel_message in channel_messages { - // // // channel_messages_by_channel_id - // // // .entry(channel_message.channel_id) - // // // .or_insert_with(Vec::new) - // // // .push(channel_message); - // // // } - - // // todo!(); - // // // Ok(channels) - // Err(Error("not implemented")) - // }) - // .await + .await } async fn transaction(&self, f: F) -> Result @@ -3440,6 +3427,13 @@ pub struct NewUserResult { pub signup_device_id: Option, } +#[derive(FromQueryResult, Debug, PartialEq)] +pub struct Channel { + pub id: ChannelId, + pub name: String, + pub parent_id: Option, +} + fn random_invite_code() -> String { nanoid::nanoid!(16) } diff --git a/crates/collab/src/db/channel.rs b/crates/collab/src/db/channel.rs index f8e2c3b85b..48e5d50e3e 100644 --- a/crates/collab/src/db/channel.rs +++ b/crates/collab/src/db/channel.rs @@ -8,7 +8,6 @@ pub struct Model { pub id: ChannelId, pub name: String, pub room_id: Option, - // pub id_path: String, } impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/channel_parent.rs b/crates/collab/src/db/channel_parent.rs index bf6cb44711..b0072155a3 100644 --- a/crates/collab/src/db/channel_parent.rs +++ b/crates/collab/src/db/channel_parent.rs @@ -11,3 +11,6 @@ pub struct Model { } impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} diff --git a/crates/collab/src/tests/channel_tests.rs b/crates/collab/src/tests/channel_tests.rs index 24754adeb3..8ab33adcbf 100644 --- a/crates/collab/src/tests/channel_tests.rs +++ b/crates/collab/src/tests/channel_tests.rs @@ -1,6 +1,8 @@ use gpui::{executor::Deterministic, TestAppContext}; use std::sync::Arc; +use crate::db::Channel; + use super::TestServer; #[gpui::test] @@ -11,13 +13,71 @@ async fn test_basic_channels(deterministic: Arc, cx: &mut TestApp let a_id = crate::db::UserId(client_a.user_id().unwrap() as i32); let db = server._test_db.db(); - let zed_id = db.create_channel("zed").await.unwrap(); + let zed_id = db.create_root_channel("zed").await.unwrap(); + let crdb_id = db.create_channel("crdb", Some(zed_id)).await.unwrap(); + let livestreaming_id = db + .create_channel("livestreaming", Some(zed_id)) + .await + .unwrap(); + let replace_id = db.create_channel("replace", Some(zed_id)).await.unwrap(); + let rust_id = db.create_root_channel("rust").await.unwrap(); + let cargo_id = db.create_channel("cargo", Some(rust_id)).await.unwrap(); db.add_channel_member(zed_id, a_id).await.unwrap(); + db.add_channel_member(rust_id, a_id).await.unwrap(); - let channels = db.get_channels(a_id).await; + let channels = db.get_channels(a_id).await.unwrap(); + assert_eq!( + channels, + vec![ + Channel { + id: zed_id, + name: "zed".to_string(), + parent_id: None, + }, + Channel { + id: rust_id, + name: "rust".to_string(), + parent_id: None, + }, + Channel { + id: crdb_id, + name: "crdb".to_string(), + parent_id: Some(zed_id), + }, + Channel { + id: livestreaming_id, + name: "livestreaming".to_string(), + parent_id: Some(zed_id), + }, + Channel { + id: replace_id, + name: "replace".to_string(), + parent_id: Some(zed_id), + }, + Channel { + id: cargo_id, + name: "cargo".to_string(), + parent_id: Some(rust_id), + } + ] + ); +} - assert_eq!(channels, vec![zed_id]); +#[gpui::test] +async fn test_block_cycle_creation(deterministic: Arc, cx: &mut TestAppContext) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx, "user_a").await; + let a_id = crate::db::UserId(client_a.user_id().unwrap() as i32); + let db = server._test_db.db(); + + let zed_id = db.create_root_channel("zed").await.unwrap(); + let first_id = db.create_channel("first", Some(zed_id)).await.unwrap(); + let second_id = db + .create_channel("second_id", Some(first_id)) + .await + .unwrap(); } /*