diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 2d963ff15f..a9b8d9709d 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -37,6 +37,7 @@ CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b"); CREATE TABLE "rooms" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "live_kit_room" VARCHAR NOT NULL, + "release_channel" VARCHAR, "channel_id" INTEGER REFERENCES channels (id) ON DELETE CASCADE ); diff --git a/crates/collab/migrations/20231009181554_add_release_channel_to_rooms.sql b/crates/collab/migrations/20231009181554_add_release_channel_to_rooms.sql new file mode 100644 index 0000000000..95d3c400fc --- /dev/null +++ b/crates/collab/migrations/20231009181554_add_release_channel_to_rooms.sql @@ -0,0 +1 @@ +ALTER TABLE rooms ADD COLUMN release_channel TEXT; diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index b103ae1c73..6589f23791 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -107,10 +107,12 @@ impl Database { user_id: UserId, connection: ConnectionId, live_kit_room: &str, + release_channel: &str, ) -> Result { self.transaction(|tx| async move { let room = room::ActiveModel { live_kit_room: ActiveValue::set(live_kit_room.into()), + release_channel: ActiveValue::set(Some(release_channel.to_string())), ..Default::default() } .insert(&*tx) @@ -270,20 +272,31 @@ impl Database { room_id: RoomId, user_id: UserId, connection: ConnectionId, + collab_release_channel: &str, ) -> Result> { self.room_transaction(room_id, |tx| async move { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryChannelId { + enum QueryChannelIdAndReleaseChannel { ChannelId, + ReleaseChannel, + } + + let (channel_id, release_channel): (Option, Option) = + room::Entity::find() + .select_only() + .column(room::Column::ChannelId) + .column(room::Column::ReleaseChannel) + .filter(room::Column::Id.eq(room_id)) + .into_values::<_, QueryChannelIdAndReleaseChannel>() + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such room"))?; + + if let Some(release_channel) = release_channel { + if &release_channel != collab_release_channel { + Err(anyhow!("must join using the {} release", release_channel))?; + } } - let channel_id: Option = room::Entity::find() - .select_only() - .column(room::Column::ChannelId) - .filter(room::Column::Id.eq(room_id)) - .into_values::<_, QueryChannelId>() - .one(&*tx) - .await? - .ok_or_else(|| anyhow!("no such room"))?; #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryParticipantIndices { @@ -300,6 +313,7 @@ impl Database { .into_values::<_, QueryParticipantIndices>() .all(&*tx) .await?; + let mut participant_index = 0; while existing_participant_indices.contains(&participant_index) { participant_index += 1; diff --git a/crates/collab/src/db/tables/room.rs b/crates/collab/src/db/tables/room.rs index f72f7000a7..7f31edcdbd 100644 --- a/crates/collab/src/db/tables/room.rs +++ b/crates/collab/src/db/tables/room.rs @@ -8,6 +8,7 @@ pub struct Model { pub id: RoomId, pub live_kit_room: String, pub channel_id: Option, + pub release_channel: Option, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 75584ff90b..6a91fd6ffe 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -12,6 +12,8 @@ use sea_orm::ConnectionTrait; use sqlx::migrate::MigrateDatabase; use std::sync::Arc; +const TEST_RELEASE_CHANNEL: &'static str = "test"; + pub struct TestDb { pub db: Option>, pub connection: Option, diff --git a/crates/collab/src/db/tests/channel_tests.rs b/crates/collab/src/db/tests/channel_tests.rs index 429852d128..2631e0d191 100644 --- a/crates/collab/src/db/tests/channel_tests.rs +++ b/crates/collab/src/db/tests/channel_tests.rs @@ -5,7 +5,11 @@ use rpc::{ }; use crate::{ - db::{queries::channels::ChannelGraph, tests::graph, ChannelId, Database, NewUserParams}, + db::{ + queries::channels::ChannelGraph, + tests::{graph, TEST_RELEASE_CHANNEL}, + ChannelId, Database, NewUserParams, + }, test_both_dbs, }; use std::sync::Arc; @@ -206,7 +210,12 @@ async fn test_joining_channels(db: &Arc) { // can join a room with membership to its channel let joined_room = db - .join_room(room_1, user_1, ConnectionId { owner_id, id: 1 }) + .join_room( + room_1, + user_1, + ConnectionId { owner_id, id: 1 }, + TEST_RELEASE_CHANNEL, + ) .await .unwrap(); assert_eq!(joined_room.room.participants.len(), 1); @@ -214,7 +223,12 @@ async fn test_joining_channels(db: &Arc) { drop(joined_room); // cannot join a room without membership to its channel assert!(db - .join_room(room_1, user_2, ConnectionId { owner_id, id: 1 }) + .join_room( + room_1, + user_2, + ConnectionId { owner_id, id: 1 }, + TEST_RELEASE_CHANNEL + ) .await .is_err()); } diff --git a/crates/collab/src/db/tests/db_tests.rs b/crates/collab/src/db/tests/db_tests.rs index 9a617166fe..1520e081c0 100644 --- a/crates/collab/src/db/tests/db_tests.rs +++ b/crates/collab/src/db/tests/db_tests.rs @@ -479,7 +479,7 @@ async fn test_project_count(db: &Arc) { .unwrap(); let room_id = RoomId::from_proto( - db.create_room(user1.user_id, ConnectionId { owner_id, id: 0 }, "") + db.create_room(user1.user_id, ConnectionId { owner_id, id: 0 }, "", "dev") .await .unwrap() .id, @@ -493,9 +493,14 @@ async fn test_project_count(db: &Arc) { ) .await .unwrap(); - db.join_room(room_id, user2.user_id, ConnectionId { owner_id, id: 1 }) - .await - .unwrap(); + db.join_room( + room_id, + user2.user_id, + ConnectionId { owner_id, id: 1 }, + "dev", + ) + .await + .unwrap(); assert_eq!(db.project_count_excluding_admins().await.unwrap(), 0); db.share_project(room_id, ConnectionId { owner_id, id: 1 }, &[]) @@ -575,6 +580,85 @@ async fn test_fuzzy_search_users() { } } +test_both_dbs!( + test_non_matching_release_channels, + test_non_matching_release_channels_postgres, + test_non_matching_release_channels_sqlite +); + +async fn test_non_matching_release_channels(db: &Arc) { + let owner_id = db.create_server("test").await.unwrap().0 as u32; + + let user1 = db + .create_user( + &format!("admin@example.com"), + true, + NewUserParams { + github_login: "admin".into(), + github_user_id: 0, + invite_count: 0, + }, + ) + .await + .unwrap(); + let user2 = db + .create_user( + &format!("user@example.com"), + false, + NewUserParams { + github_login: "user".into(), + github_user_id: 1, + invite_count: 0, + }, + ) + .await + .unwrap(); + + let room = db + .create_room( + user1.user_id, + ConnectionId { owner_id, id: 0 }, + "", + "stable", + ) + .await + .unwrap(); + + db.call( + RoomId::from_proto(room.id), + user1.user_id, + ConnectionId { owner_id, id: 0 }, + user2.user_id, + None, + ) + .await + .unwrap(); + + // User attempts to join from preview + let result = db + .join_room( + RoomId::from_proto(room.id), + user2.user_id, + ConnectionId { owner_id, id: 1 }, + "preview", + ) + .await; + + assert!(result.is_err()); + + // User switches to stable + let result = db + .join_room( + RoomId::from_proto(room.id), + user2.user_id, + ConnectionId { owner_id, id: 1 }, + "stable", + ) + .await; + + assert!(result.is_ok()) +} + fn build_background_executor() -> Arc { Deterministic::new(0).build_background() } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 6171803341..70052468d4 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -63,6 +63,7 @@ use time::OffsetDateTime; use tokio::sync::{watch, Semaphore}; use tower::ServiceBuilder; use tracing::{info_span, instrument, Instrument}; +use util::channel::RELEASE_CHANNEL_NAME; pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10); @@ -957,7 +958,12 @@ async fn create_room( let room = session .db() .await - .create_room(session.user_id, session.connection_id, &live_kit_room) + .create_room( + session.user_id, + session.connection_id, + &live_kit_room, + RELEASE_CHANNEL_NAME.as_str(), + ) .await?; response.send(proto::CreateRoomResponse { @@ -979,7 +985,12 @@ async fn join_room( let room = session .db() .await - .join_room(room_id, session.user_id, session.connection_id) + .join_room( + room_id, + session.user_id, + session.connection_id, + RELEASE_CHANNEL_NAME.as_str(), + ) .await?; room_updated(&room.room, &session.peer); room.into_inner() @@ -2616,7 +2627,12 @@ async fn join_channel( let room_id = db.room_id_for_channel(channel_id).await?; let joined_room = db - .join_room(room_id, session.user_id, session.connection_id) + .join_room( + room_id, + session.user_id, + session.connection_id, + RELEASE_CHANNEL_NAME.as_str(), + ) .await?; let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {