diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index dd6e80150b..dcb793aa51 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -192,7 +192,8 @@ CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id"); CREATE TABLE "channels" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "name" VARCHAR NOT NULL, - "created_at" TIMESTAMP NOT NULL DEFAULT now + "created_at" TIMESTAMP NOT NULL DEFAULT now, + "visibility" VARCHAR NOT NULL ); CREATE TABLE IF NOT EXISTS "channel_chat_participants" ( diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index d2e990a640..5ba724dd12 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -91,6 +91,8 @@ pub enum ChannelRole { Member, #[sea_orm(string_value = "guest")] Guest, + #[sea_orm(string_value = "banned")] + Banned, } impl From for ChannelRole { @@ -99,6 +101,7 @@ impl From for ChannelRole { proto::ChannelRole::Admin => ChannelRole::Admin, proto::ChannelRole::Member => ChannelRole::Member, proto::ChannelRole::Guest => ChannelRole::Guest, + proto::ChannelRole::Banned => ChannelRole::Banned, } } } @@ -109,6 +112,7 @@ impl Into for ChannelRole { ChannelRole::Admin => proto::ChannelRole::Admin, ChannelRole::Member => proto::ChannelRole::Member, ChannelRole::Guest => proto::ChannelRole::Guest, + ChannelRole::Banned => proto::ChannelRole::Banned, } } } diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index 5c96955eba..7ce20e1a20 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -37,8 +37,9 @@ impl Database { } let channel = channel::ActiveModel { + id: ActiveValue::NotSet, name: ActiveValue::Set(name.to_string()), - ..Default::default() + visibility: ActiveValue::Set(ChannelVisibility::ChannelMembers), } .insert(&*tx) .await?; @@ -89,6 +90,29 @@ impl Database { .await } + pub async fn set_channel_visibility( + &self, + channel_id: ChannelId, + visibility: ChannelVisibility, + user_id: UserId, + ) -> Result<()> { + self.transaction(move |tx| async move { + self.check_user_is_channel_admin(channel_id, user_id, &*tx) + .await?; + + channel::ActiveModel { + id: ActiveValue::Unchanged(channel_id), + visibility: ActiveValue::Set(visibility), + ..Default::default() + } + .update(&*tx) + .await?; + + Ok(()) + }) + .await + } + pub async fn delete_channel( &self, channel_id: ChannelId, @@ -160,11 +184,11 @@ impl Database { &self, channel_id: ChannelId, invitee_id: UserId, - inviter_id: UserId, + admin_id: UserId, role: ChannelRole, ) -> Result<()> { self.transaction(move |tx| async move { - self.check_user_is_channel_admin(channel_id, inviter_id, &*tx) + self.check_user_is_channel_admin(channel_id, admin_id, &*tx) .await?; channel_member::ActiveModel { @@ -262,10 +286,10 @@ impl Database { &self, channel_id: ChannelId, member_id: UserId, - remover_id: UserId, + admin_id: UserId, ) -> Result<()> { self.transaction(|tx| async move { - self.check_user_is_channel_admin(channel_id, remover_id, &*tx) + self.check_user_is_channel_admin(channel_id, admin_id, &*tx) .await?; let result = channel_member::Entity::delete_many() @@ -481,12 +505,12 @@ impl Database { pub async fn set_channel_member_role( &self, channel_id: ChannelId, - from: UserId, + admin_id: UserId, for_user: UserId, role: ChannelRole, ) -> Result<()> { self.transaction(|tx| async move { - self.check_user_is_channel_admin(channel_id, from, &*tx) + self.check_user_is_channel_admin(channel_id, admin_id, &*tx) .await?; let result = channel_member::Entity::update_many() @@ -613,43 +637,147 @@ impl Database { Ok(user_ids) } - pub async fn check_user_is_channel_member( - &self, - channel_id: ChannelId, - user_id: UserId, - tx: &DatabaseTransaction, - ) -> Result<()> { - let channel_ids = self.get_channel_ancestors(channel_id, tx).await?; - channel_member::Entity::find() - .filter( - channel_member::Column::ChannelId - .is_in(channel_ids) - .and(channel_member::Column::UserId.eq(user_id)), - ) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?; - Ok(()) - } - pub async fn check_user_is_channel_admin( &self, channel_id: ChannelId, user_id: UserId, tx: &DatabaseTransaction, ) -> Result<()> { + match self.channel_role_for_user(channel_id, user_id, tx).await? { + Some(ChannelRole::Admin) => Ok(()), + Some(ChannelRole::Member) + | Some(ChannelRole::Banned) + | Some(ChannelRole::Guest) + | None => Err(anyhow!( + "user is not a channel admin or channel does not exist" + ))?, + } + } + + pub async fn check_user_is_channel_member( + &self, + channel_id: ChannelId, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result<()> { + match self.channel_role_for_user(channel_id, user_id, tx).await? { + Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(()), + Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!( + "user is not a channel member or channel does not exist" + ))?, + } + } + + pub async fn check_user_is_channel_participant( + &self, + channel_id: ChannelId, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result<()> { + match self.channel_role_for_user(channel_id, user_id, tx).await? { + Some(ChannelRole::Admin) | Some(ChannelRole::Member) | Some(ChannelRole::Guest) => { + Ok(()) + } + Some(ChannelRole::Banned) | None => Err(anyhow!( + "user is not a channel participant or channel does not exist" + ))?, + } + } + + pub async fn channel_role_for_user( + &self, + channel_id: ChannelId, + user_id: UserId, + tx: &DatabaseTransaction, + ) -> Result> { let channel_ids = self.get_channel_ancestors(channel_id, tx).await?; - channel_member::Entity::find() + + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryChannelMembership { + ChannelId, + Role, + Admin, + Visibility, + } + + let mut rows = channel_member::Entity::find() + .left_join(channel::Entity) .filter( channel_member::Column::ChannelId .is_in(channel_ids) - .and(channel_member::Column::UserId.eq(user_id)) - .and(channel_member::Column::Admin.eq(true)), + .and(channel_member::Column::UserId.eq(user_id)), ) - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?; - Ok(()) + .select_only() + .column(channel_member::Column::ChannelId) + .column(channel_member::Column::Role) + .column(channel_member::Column::Admin) + .column(channel::Column::Visibility) + .into_values::<_, QueryChannelMembership>() + .stream(&*tx) + .await?; + + let mut is_admin = false; + let mut is_member = false; + let mut is_participant = false; + let mut is_banned = false; + let mut current_channel_visibility = None; + + // note these channels are not iterated in any particular order, + // our current logic takes the highest permission available. + while let Some(row) = rows.next().await { + let (ch_id, role, admin, visibility): ( + ChannelId, + Option, + bool, + ChannelVisibility, + ) = row?; + match role { + Some(ChannelRole::Admin) => is_admin = true, + Some(ChannelRole::Member) => is_member = true, + Some(ChannelRole::Guest) => { + if visibility == ChannelVisibility::Public { + is_participant = true + } + } + Some(ChannelRole::Banned) => is_banned = true, + None => { + // rows created from pre-role collab server. + if admin { + is_admin = true + } else { + is_member = true + } + } + } + if channel_id == ch_id { + current_channel_visibility = Some(visibility); + } + } + // free up database connection + drop(rows); + + Ok(if is_admin { + Some(ChannelRole::Admin) + } else if is_member { + Some(ChannelRole::Member) + } else if is_banned { + Some(ChannelRole::Banned) + } else if is_participant { + if current_channel_visibility.is_none() { + current_channel_visibility = channel::Entity::find() + .filter(channel::Column::Id.eq(channel_id)) + .one(&*tx) + .await? + .map(|channel| channel.visibility); + } + if current_channel_visibility == Some(ChannelVisibility::Public) { + Some(ChannelRole::Guest) + } else { + None + } + } else { + None + }) } /// Returns the channel ancestors, deepest first diff --git a/crates/collab/src/db/tables/channel.rs b/crates/collab/src/db/tables/channel.rs index efda02ec43..0975a8cc30 100644 --- a/crates/collab/src/db/tables/channel.rs +++ b/crates/collab/src/db/tables/channel.rs @@ -7,7 +7,7 @@ pub struct Model { #[sea_orm(primary_key)] pub id: ChannelId, pub name: String, - pub visbility: ChannelVisibility, + pub visibility: ChannelVisibility, } impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tests/channel_tests.rs b/crates/collab/src/db/tests/channel_tests.rs index 90b3a0cd2e..2263920955 100644 --- a/crates/collab/src/db/tests/channel_tests.rs +++ b/crates/collab/src/db/tests/channel_tests.rs @@ -8,11 +8,14 @@ use crate::{ db::{ queries::channels::ChannelGraph, tests::{graph, TEST_RELEASE_CHANNEL}, - ChannelId, ChannelRole, Database, NewUserParams, + ChannelId, ChannelRole, Database, NewUserParams, UserId, }, test_both_dbs, }; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicI32, Ordering}, + Arc, +}; test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite); @@ -850,6 +853,101 @@ async fn test_db_channel_moving_bugs(db: &Arc) { ); } +test_both_dbs!( + test_user_is_channel_participant, + test_user_is_channel_participant_postgres, + test_user_is_channel_participant_sqlite +); + +async fn test_user_is_channel_participant(db: &Arc) { + let admin_id = new_test_user(db, "admin@example.com").await; + let member_id = new_test_user(db, "member@example.com").await; + let guest_id = new_test_user(db, "guest@example.com").await; + + let zed_id = db.create_root_channel("zed", admin_id).await.unwrap(); + let intermediate_id = db + .create_channel("active", Some(zed_id), admin_id) + .await + .unwrap(); + let public_id = db + .create_channel("active", Some(intermediate_id), admin_id) + .await + .unwrap(); + + db.set_channel_visibility(public_id, crate::db::ChannelVisibility::Public, admin_id) + .await + .unwrap(); + db.invite_channel_member(intermediate_id, member_id, admin_id, ChannelRole::Member) + .await + .unwrap(); + db.invite_channel_member(public_id, guest_id, admin_id, ChannelRole::Guest) + .await + .unwrap(); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant(public_id, admin_id, &*tx) + .await + }) + .await + .unwrap(); + db.transaction(|tx| async move { + db.check_user_is_channel_participant(public_id, member_id, &*tx) + .await + }) + .await + .unwrap(); + db.transaction(|tx| async move { + db.check_user_is_channel_participant(public_id, guest_id, &*tx) + .await + }) + .await + .unwrap(); + + db.set_channel_member_role(public_id, admin_id, guest_id, ChannelRole::Banned) + .await + .unwrap(); + assert!(db + .transaction(|tx| async move { + db.check_user_is_channel_participant(public_id, guest_id, &*tx) + .await + }) + .await + .is_err()); + + db.remove_channel_member(public_id, guest_id, admin_id) + .await + .unwrap(); + + db.set_channel_visibility(zed_id, crate::db::ChannelVisibility::Public, admin_id) + .await + .unwrap(); + + db.invite_channel_member(zed_id, guest_id, admin_id, ChannelRole::Guest) + .await + .unwrap(); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant(zed_id, guest_id, &*tx) + .await + }) + .await + .unwrap(); + assert!(db + .transaction(|tx| async move { + db.check_user_is_channel_participant(intermediate_id, guest_id, &*tx) + .await + }) + .await + .is_err(),); + + db.transaction(|tx| async move { + db.check_user_is_channel_participant(public_id, guest_id, &*tx) + .await + }) + .await + .unwrap(); +} + #[track_caller] fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option)]) { let mut actual_map: HashMap> = HashMap::default(); @@ -874,3 +972,22 @@ fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option)]) pretty_assertions::assert_eq!(actual_map, expected_map) } + +static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5); + +async fn new_test_user(db: &Arc, email: &str) -> UserId { + let gid = GITHUB_USER_ID.fetch_add(1, Ordering::SeqCst); + + db.create_user( + email, + false, + NewUserParams { + github_login: email[0..email.find("@").unwrap()].to_string(), + github_user_id: GITHUB_USER_ID.fetch_add(1, Ordering::SeqCst), + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id +} diff --git a/crates/collab/src/tests/channel_tests.rs b/crates/collab/src/tests/channel_tests.rs index bc814d06a2..95a672e76c 100644 --- a/crates/collab/src/tests/channel_tests.rs +++ b/crates/collab/src/tests/channel_tests.rs @@ -6,7 +6,10 @@ use call::ActiveCall; use channel::{ChannelId, ChannelMembership, ChannelStore}; use client::User; use gpui::{executor::Deterministic, ModelHandle, TestAppContext}; -use rpc::{proto, RECEIVE_TIMEOUT}; +use rpc::{ + proto::{self}, + RECEIVE_TIMEOUT, +}; use std::sync::Arc; #[gpui::test] diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index fec56ad9dc..90e425a39f 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -1040,6 +1040,7 @@ enum ChannelRole { Admin = 0; Member = 1; Guest = 2; + Banned = 3; } message SetChannelMemberRole {