diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index 3e8fbafb6a..57b183f7de 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -972,6 +972,7 @@ impl ChannelStore { let mut all_user_ids = Vec::new(); let channel_participants = payload.channel_participants; + dbg!(&channel_participants); for entry in &channel_participants { for user_id in entry.participant_user_ids.iter() { if let Err(ix) = all_user_ids.binary_search(user_id) { diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index 0b7e9eb2d8..d4276603f9 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -88,6 +88,84 @@ impl Database { .await } + pub async fn join_channel_internal( + &self, + channel_id: ChannelId, + user_id: UserId, + connection: ConnectionId, + environment: &str, + tx: &DatabaseTransaction, + ) -> Result<(JoinRoom, bool)> { + let mut joined = false; + + let channel = channel::Entity::find() + .filter(channel::Column::Id.eq(channel_id)) + .one(&*tx) + .await?; + + let mut role = self + .channel_role_for_user(channel_id, user_id, &*tx) + .await?; + + if role.is_none() { + if channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public) { + channel_member::Entity::insert(channel_member::ActiveModel { + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel_id), + user_id: ActiveValue::Set(user_id), + accepted: ActiveValue::Set(true), + role: ActiveValue::Set(ChannelRole::Guest), + }) + .on_conflict( + OnConflict::columns([ + channel_member::Column::UserId, + channel_member::Column::ChannelId, + ]) + .update_columns([channel_member::Column::Accepted]) + .to_owned(), + ) + .exec(&*tx) + .await?; + + debug_assert!( + self.channel_role_for_user(channel_id, user_id, &*tx) + .await? + == Some(ChannelRole::Guest) + ); + + role = Some(ChannelRole::Guest); + joined = true; + } + } + + if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) { + Err(anyhow!("no such channel, or not allowed"))? + } + + let live_kit_room = format!("channel-{}", nanoid::nanoid!(30)); + let room_id = self + .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx) + .await?; + + self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx) + .await + .map(|jr| (jr, joined)) + } + + pub async fn join_channel( + &self, + channel_id: ChannelId, + user_id: UserId, + connection: ConnectionId, + environment: &str, + ) -> Result<(JoinRoom, bool)> { + self.transaction(move |tx| async move { + self.join_channel_internal(channel_id, user_id, connection, environment, &*tx) + .await + }) + .await + } + pub async fn set_channel_visibility( &self, channel_id: ChannelId, @@ -981,38 +1059,39 @@ impl Database { .await } - pub async fn get_or_create_channel_room( + pub(crate) async fn get_or_create_channel_room( &self, channel_id: ChannelId, live_kit_room: &str, - enviroment: &str, + environment: &str, + tx: &DatabaseTransaction, ) -> Result { - self.transaction(|tx| async move { - let tx = tx; + let room = room::Entity::find() + .filter(room::Column::ChannelId.eq(channel_id)) + .one(&*tx) + .await?; - let room = room::Entity::find() - .filter(room::Column::ChannelId.eq(channel_id)) - .one(&*tx) - .await?; + let room_id = if let Some(room) = room { + if let Some(env) = room.enviroment { + if &env != environment { + Err(anyhow!("must join using the {} release", env))?; + } + } + room.id + } else { + let result = room::Entity::insert(room::ActiveModel { + channel_id: ActiveValue::Set(Some(channel_id)), + live_kit_room: ActiveValue::Set(live_kit_room.to_string()), + enviroment: ActiveValue::Set(Some(environment.to_string())), + ..Default::default() + }) + .exec(&*tx) + .await?; - let room_id = if let Some(room) = room { - room.id - } else { - let result = room::Entity::insert(room::ActiveModel { - channel_id: ActiveValue::Set(Some(channel_id)), - live_kit_room: ActiveValue::Set(live_kit_room.to_string()), - enviroment: ActiveValue::Set(Some(enviroment.to_string())), - ..Default::default() - }) - .exec(&*tx) - .await?; + result.last_insert_id + }; - result.last_insert_id - }; - - Ok(room_id) - }) - .await + Ok(room_id) } // Insert an edge from the given channel to the given other channel. diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index 625615db5f..d2120495b0 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -300,99 +300,139 @@ impl Database { } } - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryParticipantIndices { - ParticipantIndex, + if channel_id.is_some() { + Err(anyhow!("tried to join channel call directly"))? } - let existing_participant_indices: Vec = room_participant::Entity::find() - .filter( - room_participant::Column::RoomId - .eq(room_id) - .and(room_participant::Column::ParticipantIndex.is_not_null()), - ) - .select_only() - .column(room_participant::Column::ParticipantIndex) - .into_values::<_, QueryParticipantIndices>() - .all(&*tx) + + let participant_index = self + .get_next_participant_index_internal(room_id, &*tx) .await?; - let mut participant_index = 0; - while existing_participant_indices.contains(&participant_index) { - participant_index += 1; - } - - if let Some(channel_id) = channel_id { - self.check_user_is_channel_member(channel_id, user_id, &*tx) - .await?; - - room_participant::Entity::insert_many([room_participant::ActiveModel { - room_id: ActiveValue::set(room_id), - user_id: ActiveValue::set(user_id), + let result = room_participant::Entity::update_many() + .filter( + Condition::all() + .add(room_participant::Column::RoomId.eq(room_id)) + .add(room_participant::Column::UserId.eq(user_id)) + .add(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .set(room_participant::ActiveModel { + participant_index: ActiveValue::Set(Some(participant_index)), answering_connection_id: ActiveValue::set(Some(connection.id as i32)), answering_connection_server_id: ActiveValue::set(Some(ServerId( connection.owner_id as i32, ))), answering_connection_lost: ActiveValue::set(false), - calling_user_id: ActiveValue::set(user_id), - calling_connection_id: ActiveValue::set(connection.id as i32), - calling_connection_server_id: ActiveValue::set(Some(ServerId( - connection.owner_id as i32, - ))), - participant_index: ActiveValue::Set(Some(participant_index)), ..Default::default() - }]) - .on_conflict( - OnConflict::columns([room_participant::Column::UserId]) - .update_columns([ - room_participant::Column::AnsweringConnectionId, - room_participant::Column::AnsweringConnectionServerId, - room_participant::Column::AnsweringConnectionLost, - room_participant::Column::ParticipantIndex, - ]) - .to_owned(), - ) + }) .exec(&*tx) .await?; - } else { - let result = room_participant::Entity::update_many() - .filter( - Condition::all() - .add(room_participant::Column::RoomId.eq(room_id)) - .add(room_participant::Column::UserId.eq(user_id)) - .add(room_participant::Column::AnsweringConnectionId.is_null()), - ) - .set(room_participant::ActiveModel { - participant_index: ActiveValue::Set(Some(participant_index)), - answering_connection_id: ActiveValue::set(Some(connection.id as i32)), - answering_connection_server_id: ActiveValue::set(Some(ServerId( - connection.owner_id as i32, - ))), - answering_connection_lost: ActiveValue::set(false), - ..Default::default() - }) - .exec(&*tx) - .await?; - if result.rows_affected == 0 { - Err(anyhow!("room does not exist or was already joined"))?; - } + if result.rows_affected == 0 { + Err(anyhow!("room does not exist or was already joined"))?; } let room = self.get_room(room_id, &tx).await?; - let channel_members = if let Some(channel_id) = channel_id { - self.get_channel_participants_internal(channel_id, &tx) - .await? - } else { - Vec::new() - }; Ok(JoinRoom { room, - channel_id, - channel_members, + channel_id: None, + channel_members: vec![], }) }) .await } + async fn get_next_participant_index_internal( + &self, + room_id: RoomId, + tx: &DatabaseTransaction, + ) -> Result { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryParticipantIndices { + ParticipantIndex, + } + let existing_participant_indices: Vec = room_participant::Entity::find() + .filter( + room_participant::Column::RoomId + .eq(room_id) + .and(room_participant::Column::ParticipantIndex.is_not_null()), + ) + .select_only() + .column(room_participant::Column::ParticipantIndex) + .into_values::<_, QueryParticipantIndices>() + .all(&*tx) + .await?; + + let mut participant_index = 0; + while existing_participant_indices.contains(&participant_index) { + participant_index += 1; + } + + Ok(participant_index) + } + + pub async fn channel_id_for_room(&self, room_id: RoomId) -> Result> { + self.transaction(|tx| async move { + let room: Option = room::Entity::find() + .filter(room::Column::Id.eq(room_id)) + .one(&*tx) + .await?; + + Ok(room.and_then(|room| room.channel_id)) + }) + .await + } + + pub(crate) async fn join_channel_room_internal( + &self, + channel_id: ChannelId, + room_id: RoomId, + user_id: UserId, + connection: ConnectionId, + tx: &DatabaseTransaction, + ) -> Result { + let participant_index = self + .get_next_participant_index_internal(room_id, &*tx) + .await?; + + room_participant::Entity::insert_many([room_participant::ActiveModel { + room_id: ActiveValue::set(room_id), + user_id: ActiveValue::set(user_id), + answering_connection_id: ActiveValue::set(Some(connection.id as i32)), + answering_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + answering_connection_lost: ActiveValue::set(false), + calling_user_id: ActiveValue::set(user_id), + calling_connection_id: ActiveValue::set(connection.id as i32), + calling_connection_server_id: ActiveValue::set(Some(ServerId( + connection.owner_id as i32, + ))), + participant_index: ActiveValue::Set(Some(participant_index)), + ..Default::default() + }]) + .on_conflict( + OnConflict::columns([room_participant::Column::UserId]) + .update_columns([ + room_participant::Column::AnsweringConnectionId, + room_participant::Column::AnsweringConnectionServerId, + room_participant::Column::AnsweringConnectionLost, + room_participant::Column::ParticipantIndex, + ]) + .to_owned(), + ) + .exec(&*tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + let channel_members = self + .get_channel_participants_internal(channel_id, &tx) + .await?; + Ok(JoinRoom { + room, + channel_id: Some(channel_id), + channel_members, + }) + } + pub async fn rejoin_room( &self, rejoin_room: proto::RejoinRoom, diff --git a/crates/collab/src/db/tests/channel_tests.rs b/crates/collab/src/db/tests/channel_tests.rs index b969711232..9b6d8d1525 100644 --- a/crates/collab/src/db/tests/channel_tests.rs +++ b/crates/collab/src/db/tests/channel_tests.rs @@ -8,7 +8,7 @@ use crate::{ db::{ queries::channels::ChannelGraph, tests::{graph, TEST_RELEASE_CHANNEL}, - ChannelId, ChannelRole, Database, NewUserParams, UserId, + ChannelId, ChannelRole, Database, NewUserParams, RoomId, UserId, }, test_both_dbs, }; @@ -207,15 +207,11 @@ async fn test_joining_channels(db: &Arc) { .user_id; let channel_1 = db.create_root_channel("channel_1", user_1).await.unwrap(); - let room_1 = db - .get_or_create_channel_room(channel_1, "1", TEST_RELEASE_CHANNEL) - .await - .unwrap(); // can join a room with membership to its channel - let joined_room = db - .join_room( - room_1, + let (joined_room, _) = db + .join_channel( + channel_1, user_1, ConnectionId { owner_id, id: 1 }, TEST_RELEASE_CHANNEL, @@ -224,11 +220,12 @@ async fn test_joining_channels(db: &Arc) { .unwrap(); assert_eq!(joined_room.room.participants.len(), 1); + let room_id = RoomId::from_proto(joined_room.room.id); drop(joined_room); // cannot join a room without membership to its channel assert!(db .join_room( - room_1, + room_id, user_2, ConnectionId { owner_id, id: 1 }, TEST_RELEASE_CHANNEL diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index c3d8a25ab7..26ad2f281a 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -38,7 +38,7 @@ use lazy_static::lazy_static; use prometheus::{register_int_gauge, IntGauge}; use rpc::{ proto::{ - self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage, + self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage, JoinRoom, LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators, }, Connection, ConnectionId, Peer, Receipt, TypedEnvelope, @@ -977,6 +977,13 @@ async fn join_room( session: Session, ) -> Result<()> { let room_id = RoomId::from_proto(request.id); + + let channel_id = session.db().await.channel_id_for_room(room_id).await?; + + if let Some(channel_id) = channel_id { + return join_channel_internal(channel_id, Box::new(response), session).await; + } + let joined_room = { let room = session .db() @@ -992,16 +999,6 @@ async fn join_room( room.into_inner() }; - if let Some(channel_id) = joined_room.channel_id { - channel_updated( - channel_id, - &joined_room.room, - &joined_room.channel_members, - &session.peer, - &*session.connection_pool().await, - ) - } - for connection_id in session .connection_pool() .await @@ -1039,7 +1036,7 @@ async fn join_room( response.send(proto::JoinRoomResponse { room: Some(joined_room.room), - channel_id: joined_room.channel_id.map(|id| id.to_proto()), + channel_id: None, live_kit_connection_info, })?; @@ -2602,54 +2599,68 @@ async fn respond_to_channel_invite( db.respond_to_channel_invite(channel_id, session.user_id, request.accept) .await?; + if request.accept { + channel_membership_updated(db, channel_id, &session).await?; + } else { + let mut update = proto::UpdateChannels::default(); + update + .remove_channel_invitations + .push(channel_id.to_proto()); + session.peer.send(session.connection_id, update)?; + } + response.send(proto::Ack {})?; + + Ok(()) +} + +async fn channel_membership_updated( + db: tokio::sync::MutexGuard<'_, DbHandle>, + channel_id: ChannelId, + session: &Session, +) -> Result<(), crate::Error> { let mut update = proto::UpdateChannels::default(); update .remove_channel_invitations .push(channel_id.to_proto()); - if request.accept { - let result = db.get_channel_for_user(channel_id, session.user_id).await?; - update - .channels - .extend( - result - .channels - .channels - .into_iter() - .map(|channel| proto::Channel { - id: channel.id.to_proto(), - visibility: channel.visibility.into(), - name: channel.name, - }), - ); - update.unseen_channel_messages = result.channel_messages; - update.unseen_channel_buffer_changes = result.unseen_buffer_changes; - update.insert_edge = result.channels.edges; - update - .channel_participants - .extend( - result - .channel_participants - .into_iter() - .map(|(channel_id, user_ids)| proto::ChannelParticipants { - channel_id: channel_id.to_proto(), - participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(), - }), - ); - update - .channel_permissions - .extend( - result - .channels_with_admin_privileges - .into_iter() - .map(|channel_id| proto::ChannelPermission { - channel_id: channel_id.to_proto(), - role: proto::ChannelRole::Admin.into(), - }), - ); - } - session.peer.send(session.connection_id, update)?; - response.send(proto::Ack {})?; + let result = db.get_channel_for_user(channel_id, session.user_id).await?; + update.channels.extend( + result + .channels + .channels + .into_iter() + .map(|channel| proto::Channel { + id: channel.id.to_proto(), + visibility: channel.visibility.into(), + name: channel.name, + }), + ); + update.unseen_channel_messages = result.channel_messages; + update.unseen_channel_buffer_changes = result.unseen_buffer_changes; + update.insert_edge = result.channels.edges; + update + .channel_participants + .extend( + result + .channel_participants + .into_iter() + .map(|(channel_id, user_ids)| proto::ChannelParticipants { + channel_id: channel_id.to_proto(), + participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(), + }), + ); + update + .channel_permissions + .extend( + result + .channels_with_admin_privileges + .into_iter() + .map(|channel_id| proto::ChannelPermission { + channel_id: channel_id.to_proto(), + role: proto::ChannelRole::Admin.into(), + }), + ); + session.peer.send(session.connection_id, update)?; Ok(()) } @@ -2659,19 +2670,35 @@ async fn join_channel( session: Session, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); - let live_kit_room = format!("channel-{}", nanoid::nanoid!(30)); + join_channel_internal(channel_id, Box::new(response), session).await +} +trait JoinChannelInternalResponse { + fn send(self, result: proto::JoinRoomResponse) -> Result<()>; +} +impl JoinChannelInternalResponse for Response { + fn send(self, result: proto::JoinRoomResponse) -> Result<()> { + Response::::send(self, result) + } +} +impl JoinChannelInternalResponse for Response { + fn send(self, result: proto::JoinRoomResponse) -> Result<()> { + Response::::send(self, result) + } +} + +async fn join_channel_internal( + channel_id: ChannelId, + response: Box, + session: Session, +) -> Result<()> { let joined_room = { leave_room_for_session(&session).await?; let db = session.db().await; - let room_id = db - .get_or_create_channel_room(channel_id, &live_kit_room, &*RELEASE_CHANNEL_NAME) - .await?; - - let joined_room = db - .join_room( - room_id, + let (joined_room, joined_channel) = db + .join_channel( + channel_id, session.user_id, session.connection_id, RELEASE_CHANNEL_NAME.as_str(), @@ -2698,9 +2725,13 @@ async fn join_channel( live_kit_connection_info, })?; + if joined_channel { + channel_membership_updated(db, channel_id, &session).await? + } + room_updated(&joined_room.room, &session.peer); - joined_room.into_inner() + joined_room }; channel_updated( @@ -2712,7 +2743,6 @@ async fn join_channel( ); update_user_contacts(session.user_id, &session).await?; - Ok(()) } diff --git a/crates/collab/src/tests/channel_tests.rs b/crates/collab/src/tests/channel_tests.rs index 95a672e76c..1700dfc5d3 100644 --- a/crates/collab/src/tests/channel_tests.rs +++ b/crates/collab/src/tests/channel_tests.rs @@ -912,6 +912,58 @@ async fn test_lost_channel_creation( ], ); } +#[gpui::test] +async fn test_guest_access( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channels = server + .make_channel_tree(&[("channel-a", None)], (&client_a, cx_a)) + .await; + let channel_a_id = channels[0]; + + let active_call_b = cx_b.read(ActiveCall::global); + + // should not be allowed to join + assert!(active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_a_id, cx)) + .await + .is_err()); + + client_a + .channel_store() + .update(cx_a, |channel_store, cx| { + channel_store.set_channel_visibility(channel_a_id, proto::ChannelVisibility::Public, cx) + }) + .await + .unwrap(); + + active_call_b + .update(cx_b, |call, cx| call.join_channel(channel_a_id, cx)) + .await + .unwrap(); + + deterministic.run_until_parked(); + + assert!(client_b + .channel_store() + .update(cx_b, |channel_store, _| channel_store + .channel_for_id(channel_a_id) + .is_some())); + + client_a.channel_store().update(cx_a, |channel_store, _| { + let participants = channel_store.channel_participants(channel_a_id); + assert_eq!(participants.len(), 1); + assert_eq!(participants[0].id, client_b.user_id().unwrap()); + }) +} #[gpui::test] async fn test_channel_moving( diff --git a/crates/collab_ui/src/collab_panel/channel_modal.rs b/crates/collab_ui/src/collab_panel/channel_modal.rs index bf04e4f7e6..da6edbde69 100644 --- a/crates/collab_ui/src/collab_panel/channel_modal.rs +++ b/crates/collab_ui/src/collab_panel/channel_modal.rs @@ -1,4 +1,4 @@ -use channel::{Channel, ChannelId, ChannelMembership, ChannelStore}; +use channel::{ChannelId, ChannelMembership, ChannelStore}; use client::{ proto::{self, ChannelRole, ChannelVisibility}, User, UserId, UserStore,