Only allow one release channel in a call

This commit is contained in:
Conrad Irwin 2023-10-09 12:59:18 -06:00
parent abfb4490d5
commit 162cb19cff
8 changed files with 152 additions and 19 deletions

View file

@ -37,6 +37,7 @@ CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b");
CREATE TABLE "rooms" ( CREATE TABLE "rooms" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT, "id" INTEGER PRIMARY KEY AUTOINCREMENT,
"live_kit_room" VARCHAR NOT NULL, "live_kit_room" VARCHAR NOT NULL,
"release_channel" VARCHAR,
"channel_id" INTEGER REFERENCES channels (id) ON DELETE CASCADE "channel_id" INTEGER REFERENCES channels (id) ON DELETE CASCADE
); );

View file

@ -0,0 +1 @@
ALTER TABLE rooms ADD COLUMN release_channel TEXT;

View file

@ -107,10 +107,12 @@ impl Database {
user_id: UserId, user_id: UserId,
connection: ConnectionId, connection: ConnectionId,
live_kit_room: &str, live_kit_room: &str,
release_channel: &str,
) -> Result<proto::Room> { ) -> Result<proto::Room> {
self.transaction(|tx| async move { self.transaction(|tx| async move {
let room = room::ActiveModel { let room = room::ActiveModel {
live_kit_room: ActiveValue::set(live_kit_room.into()), live_kit_room: ActiveValue::set(live_kit_room.into()),
release_channel: ActiveValue::set(Some(release_channel.to_string())),
..Default::default() ..Default::default()
} }
.insert(&*tx) .insert(&*tx)
@ -270,20 +272,31 @@ impl Database {
room_id: RoomId, room_id: RoomId,
user_id: UserId, user_id: UserId,
connection: ConnectionId, connection: ConnectionId,
collab_release_channel: &str,
) -> Result<RoomGuard<JoinRoom>> { ) -> Result<RoomGuard<JoinRoom>> {
self.room_transaction(room_id, |tx| async move { self.room_transaction(room_id, |tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryChannelId { enum QueryChannelIdAndReleaseChannel {
ChannelId, ChannelId,
ReleaseChannel,
}
let (channel_id, release_channel): (Option<ChannelId>, Option<String>) =
room::Entity::find()
.select_only()
.column(room::Column::ChannelId)
.column(room::Column::ReleaseChannel)
.filter(room::Column::Id.eq(room_id))
.into_values::<_, QueryChannelIdAndReleaseChannel>()
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such room"))?;
if let Some(release_channel) = release_channel {
if &release_channel != collab_release_channel {
Err(anyhow!("must join using the {} release", release_channel))?;
}
} }
let channel_id: Option<ChannelId> = room::Entity::find()
.select_only()
.column(room::Column::ChannelId)
.filter(room::Column::Id.eq(room_id))
.into_values::<_, QueryChannelId>()
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such room"))?;
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryParticipantIndices { enum QueryParticipantIndices {
@ -300,6 +313,7 @@ impl Database {
.into_values::<_, QueryParticipantIndices>() .into_values::<_, QueryParticipantIndices>()
.all(&*tx) .all(&*tx)
.await?; .await?;
let mut participant_index = 0; let mut participant_index = 0;
while existing_participant_indices.contains(&participant_index) { while existing_participant_indices.contains(&participant_index) {
participant_index += 1; participant_index += 1;

View file

@ -8,6 +8,7 @@ pub struct Model {
pub id: RoomId, pub id: RoomId,
pub live_kit_room: String, pub live_kit_room: String,
pub channel_id: Option<ChannelId>, pub channel_id: Option<ChannelId>,
pub release_channel: Option<String>,
} }
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View file

@ -12,6 +12,8 @@ use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase; use sqlx::migrate::MigrateDatabase;
use std::sync::Arc; use std::sync::Arc;
const TEST_RELEASE_CHANNEL: &'static str = "test";
pub struct TestDb { pub struct TestDb {
pub db: Option<Arc<Database>>, pub db: Option<Arc<Database>>,
pub connection: Option<sqlx::AnyConnection>, pub connection: Option<sqlx::AnyConnection>,

View file

@ -5,7 +5,11 @@ use rpc::{
}; };
use crate::{ use crate::{
db::{queries::channels::ChannelGraph, tests::graph, ChannelId, Database, NewUserParams}, db::{
queries::channels::ChannelGraph,
tests::{graph, TEST_RELEASE_CHANNEL},
ChannelId, Database, NewUserParams,
},
test_both_dbs, test_both_dbs,
}; };
use std::sync::Arc; use std::sync::Arc;
@ -206,7 +210,12 @@ async fn test_joining_channels(db: &Arc<Database>) {
// can join a room with membership to its channel // can join a room with membership to its channel
let joined_room = db let joined_room = db
.join_room(room_1, user_1, ConnectionId { owner_id, id: 1 }) .join_room(
room_1,
user_1,
ConnectionId { owner_id, id: 1 },
TEST_RELEASE_CHANNEL,
)
.await .await
.unwrap(); .unwrap();
assert_eq!(joined_room.room.participants.len(), 1); assert_eq!(joined_room.room.participants.len(), 1);
@ -214,7 +223,12 @@ async fn test_joining_channels(db: &Arc<Database>) {
drop(joined_room); drop(joined_room);
// cannot join a room without membership to its channel // cannot join a room without membership to its channel
assert!(db assert!(db
.join_room(room_1, user_2, ConnectionId { owner_id, id: 1 }) .join_room(
room_1,
user_2,
ConnectionId { owner_id, id: 1 },
TEST_RELEASE_CHANNEL
)
.await .await
.is_err()); .is_err());
} }

View file

@ -479,7 +479,7 @@ async fn test_project_count(db: &Arc<Database>) {
.unwrap(); .unwrap();
let room_id = RoomId::from_proto( let room_id = RoomId::from_proto(
db.create_room(user1.user_id, ConnectionId { owner_id, id: 0 }, "") db.create_room(user1.user_id, ConnectionId { owner_id, id: 0 }, "", "dev")
.await .await
.unwrap() .unwrap()
.id, .id,
@ -493,9 +493,14 @@ async fn test_project_count(db: &Arc<Database>) {
) )
.await .await
.unwrap(); .unwrap();
db.join_room(room_id, user2.user_id, ConnectionId { owner_id, id: 1 }) db.join_room(
.await room_id,
.unwrap(); user2.user_id,
ConnectionId { owner_id, id: 1 },
"dev",
)
.await
.unwrap();
assert_eq!(db.project_count_excluding_admins().await.unwrap(), 0); assert_eq!(db.project_count_excluding_admins().await.unwrap(), 0);
db.share_project(room_id, ConnectionId { owner_id, id: 1 }, &[]) db.share_project(room_id, ConnectionId { owner_id, id: 1 }, &[])
@ -575,6 +580,85 @@ async fn test_fuzzy_search_users() {
} }
} }
test_both_dbs!(
test_non_matching_release_channels,
test_non_matching_release_channels_postgres,
test_non_matching_release_channels_sqlite
);
async fn test_non_matching_release_channels(db: &Arc<Database>) {
let owner_id = db.create_server("test").await.unwrap().0 as u32;
let user1 = db
.create_user(
&format!("admin@example.com"),
true,
NewUserParams {
github_login: "admin".into(),
github_user_id: 0,
invite_count: 0,
},
)
.await
.unwrap();
let user2 = db
.create_user(
&format!("user@example.com"),
false,
NewUserParams {
github_login: "user".into(),
github_user_id: 1,
invite_count: 0,
},
)
.await
.unwrap();
let room = db
.create_room(
user1.user_id,
ConnectionId { owner_id, id: 0 },
"",
"stable",
)
.await
.unwrap();
db.call(
RoomId::from_proto(room.id),
user1.user_id,
ConnectionId { owner_id, id: 0 },
user2.user_id,
None,
)
.await
.unwrap();
// User attempts to join from preview
let result = db
.join_room(
RoomId::from_proto(room.id),
user2.user_id,
ConnectionId { owner_id, id: 1 },
"preview",
)
.await;
assert!(result.is_err());
// User switches to stable
let result = db
.join_room(
RoomId::from_proto(room.id),
user2.user_id,
ConnectionId { owner_id, id: 1 },
"stable",
)
.await;
assert!(result.is_ok())
}
fn build_background_executor() -> Arc<Background> { fn build_background_executor() -> Arc<Background> {
Deterministic::new(0).build_background() Deterministic::new(0).build_background()
} }

View file

@ -63,6 +63,7 @@ use time::OffsetDateTime;
use tokio::sync::{watch, Semaphore}; use tokio::sync::{watch, Semaphore};
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tracing::{info_span, instrument, Instrument}; use tracing::{info_span, instrument, Instrument};
use util::channel::RELEASE_CHANNEL_NAME;
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10); pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
@ -957,7 +958,12 @@ async fn create_room(
let room = session let room = session
.db() .db()
.await .await
.create_room(session.user_id, session.connection_id, &live_kit_room) .create_room(
session.user_id,
session.connection_id,
&live_kit_room,
RELEASE_CHANNEL_NAME.as_str(),
)
.await?; .await?;
response.send(proto::CreateRoomResponse { response.send(proto::CreateRoomResponse {
@ -979,7 +985,12 @@ async fn join_room(
let room = session let room = session
.db() .db()
.await .await
.join_room(room_id, session.user_id, session.connection_id) .join_room(
room_id,
session.user_id,
session.connection_id,
RELEASE_CHANNEL_NAME.as_str(),
)
.await?; .await?;
room_updated(&room.room, &session.peer); room_updated(&room.room, &session.peer);
room.into_inner() room.into_inner()
@ -2616,7 +2627,12 @@ async fn join_channel(
let room_id = db.room_id_for_channel(channel_id).await?; let room_id = db.room_id_for_channel(channel_id).await?;
let joined_room = db let joined_room = db
.join_room(room_id, session.user_id, session.connection_id) .join_room(
room_id,
session.user_id,
session.connection_id,
RELEASE_CHANNEL_NAME.as_str(),
)
.await?; .await?;
let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| { let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {