diff --git a/crates/channel/src/channel_chat.rs b/crates/channel/src/channel_chat.rs index 6a3644aeea..a07b7d395d 100644 --- a/crates/channel/src/channel_chat.rs +++ b/crates/channel/src/channel_chat.rs @@ -135,7 +135,7 @@ impl ChannelChat { &mut self, message: MessageParams, cx: &mut ModelContext, - ) -> Result>> { + ) -> Result>> { if message.text.is_empty() { Err(anyhow!("message body can't be empty"))?; } @@ -176,15 +176,12 @@ impl ChannelChat { }); let response = request.await?; drop(outgoing_message_guard); - let message = ChannelMessage::from_proto( - response.message.ok_or_else(|| anyhow!("invalid message"))?, - &user_store, - &mut cx, - ) - .await?; + let response = response.message.ok_or_else(|| anyhow!("invalid message"))?; + let id = response.id; + let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?; this.update(&mut cx, |this, cx| { this.insert_messages(SumTree::from_item(message, &()), cx); - Ok(()) + Ok(id) }) })) } diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 2eb229046a..5f3d0fc0c7 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -386,6 +386,13 @@ impl Contact { pub type NotificationBatch = Vec<(UserId, proto::Notification)>; +pub struct CreatedChannelMessage { + pub message_id: MessageId, + pub participant_connection_ids: Vec, + pub channel_members: Vec, + pub notifications: NotificationBatch, +} + #[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)] pub struct Invite { pub email_address: String, diff --git a/crates/collab/src/db/queries/messages.rs b/crates/collab/src/db/queries/messages.rs index 9515f856c3..ca364c5596 100644 --- a/crates/collab/src/db/queries/messages.rs +++ b/crates/collab/src/db/queries/messages.rs @@ -1,4 +1,5 @@ use super::*; +use futures::Stream; use sea_orm::TryInsertResult; use time::OffsetDateTime; @@ -88,61 +89,46 @@ impl Database { condition = condition.add(channel_message::Column::Id.lt(before_message_id)); } - let mut rows = channel_message::Entity::find() + let rows = channel_message::Entity::find() .filter(condition) .order_by_desc(channel_message::Column::Id) .limit(count as u64) .stream(&*tx) .await?; - let mut messages = Vec::new(); - while let Some(row) = rows.next().await { - let row = row?; - let nonce = row.nonce.as_u64_pair(); - messages.push(proto::ChannelMessage { - id: row.id.to_proto(), - sender_id: row.sender_id.to_proto(), - body: row.body, - timestamp: row.sent_at.assume_utc().unix_timestamp() as u64, - mentions: vec![], - nonce: Some(proto::Nonce { - upper_half: nonce.0, - lower_half: nonce.1, - }), - }); - } - drop(rows); - messages.reverse(); + self.load_channel_messages(rows, &*tx).await + }) + .await + } - let mut mentions = channel_message_mention::Entity::find() - .filter( - channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)), - ) - .order_by_asc(channel_message_mention::Column::MessageId) - .order_by_asc(channel_message_mention::Column::StartOffset) + pub async fn get_channel_messages_by_id( + &self, + user_id: UserId, + message_ids: &[MessageId], + ) -> Result> { + self.transaction(|tx| async move { + let rows = channel_message::Entity::find() + .filter(channel_message::Column::Id.is_in(message_ids.iter().copied())) + .order_by_desc(channel_message::Column::Id) .stream(&*tx) .await?; - let mut message_ix = 0; - while let Some(mention) = mentions.next().await { - let mention = mention?; - let message_id = mention.message_id.to_proto(); - while let Some(message) = messages.get_mut(message_ix) { - if message.id < message_id { - message_ix += 1; - } else { - if message.id == message_id { - message.mentions.push(proto::ChatMention { - range: Some(proto::Range { - start: mention.start_offset as u64, - end: mention.end_offset as u64, - }), - user_id: mention.user_id.to_proto(), - }); - } - break; - } - } + let mut channel_ids = HashSet::::default(); + let messages = self + .load_channel_messages( + rows.map(|row| { + row.map(|row| { + channel_ids.insert(row.channel_id); + row + }) + }), + &*tx, + ) + .await?; + + for channel_id in channel_ids { + self.check_user_is_channel_member(channel_id, user_id, &*tx) + .await?; } Ok(messages) @@ -150,6 +136,62 @@ impl Database { .await } + async fn load_channel_messages( + &self, + mut rows: impl Send + Unpin + Stream>, + tx: &DatabaseTransaction, + ) -> Result> { + let mut messages = Vec::new(); + while let Some(row) = rows.next().await { + let row = row?; + let nonce = row.nonce.as_u64_pair(); + messages.push(proto::ChannelMessage { + id: row.id.to_proto(), + sender_id: row.sender_id.to_proto(), + body: row.body, + timestamp: row.sent_at.assume_utc().unix_timestamp() as u64, + mentions: vec![], + nonce: Some(proto::Nonce { + upper_half: nonce.0, + lower_half: nonce.1, + }), + }); + } + drop(rows); + messages.reverse(); + + let mut mentions = channel_message_mention::Entity::find() + .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id))) + .order_by_asc(channel_message_mention::Column::MessageId) + .order_by_asc(channel_message_mention::Column::StartOffset) + .stream(&*tx) + .await?; + + let mut message_ix = 0; + while let Some(mention) = mentions.next().await { + let mention = mention?; + let message_id = mention.message_id.to_proto(); + while let Some(message) = messages.get_mut(message_ix) { + if message.id < message_id { + message_ix += 1; + } else { + if message.id == message_id { + message.mentions.push(proto::ChatMention { + range: Some(proto::Range { + start: mention.start_offset as u64, + end: mention.end_offset as u64, + }), + user_id: mention.user_id.to_proto(), + }); + } + break; + } + } + } + + Ok(messages) + } + pub async fn create_channel_message( &self, channel_id: ChannelId, @@ -158,7 +200,7 @@ impl Database { mentions: &[proto::ChatMention], timestamp: OffsetDateTime, nonce: u128, - ) -> Result<(MessageId, Vec, Vec)> { + ) -> Result { self.transaction(|tx| async move { let mut rows = channel_chat_participant::Entity::find() .filter(channel_chat_participant::Column::ChannelId.eq(channel_id)) @@ -206,10 +248,13 @@ impl Database { .await?; let message_id; + let mut notifications = Vec::new(); match result { TryInsertResult::Inserted(result) => { message_id = result.last_insert_id; - let models = mentions + let mentioned_user_ids = + mentions.iter().map(|m| m.user_id).collect::>(); + let mentions = mentions .iter() .filter_map(|mention| { let range = mention.range.as_ref()?; @@ -226,12 +271,28 @@ impl Database { }) }) .collect::>(); - if !models.is_empty() { - channel_message_mention::Entity::insert_many(models) + if !mentions.is_empty() { + channel_message_mention::Entity::insert_many(mentions) .exec(&*tx) .await?; } + for mentioned_user in mentioned_user_ids { + notifications.extend( + self.create_notification( + UserId::from_proto(mentioned_user), + rpc::Notification::ChannelMessageMention { + message_id: message_id.to_proto(), + sender_id: user_id.to_proto(), + channel_id: channel_id.to_proto(), + }, + false, + &*tx, + ) + .await?, + ); + } + self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx) .await?; } @@ -250,7 +311,12 @@ impl Database { .await?; channel_members.retain(|member| !participant_user_ids.contains(member)); - Ok((message_id, participant_connection_ids, channel_members)) + Ok(CreatedChannelMessage { + message_id, + participant_connection_ids, + channel_members, + notifications, + }) }) .await } diff --git a/crates/collab/src/db/tests/message_tests.rs b/crates/collab/src/db/tests/message_tests.rs index 49c4ebf4a2..97b3142930 100644 --- a/crates/collab/src/db/tests/message_tests.rs +++ b/crates/collab/src/db/tests/message_tests.rs @@ -35,7 +35,7 @@ async fn test_channel_message_retrieval(db: &Arc) { ) .await .unwrap() - .0 + .message_id .to_proto(), ); } @@ -109,7 +109,7 @@ async fn test_channel_message_nonces(db: &Arc) { ) .await .unwrap() - .0; + .message_id; let id2 = db .create_channel_message( channel, @@ -121,7 +121,7 @@ async fn test_channel_message_nonces(db: &Arc) { ) .await .unwrap() - .0; + .message_id; let id3 = db .create_channel_message( channel, @@ -133,7 +133,7 @@ async fn test_channel_message_nonces(db: &Arc) { ) .await .unwrap() - .0; + .message_id; let id4 = db .create_channel_message( channel, @@ -145,7 +145,7 @@ async fn test_channel_message_nonces(db: &Arc) { ) .await .unwrap() - .0; + .message_id; // As a different user, reuse one of the same nonces. This request succeeds // and returns a different id. @@ -160,7 +160,7 @@ async fn test_channel_message_nonces(db: &Arc) { ) .await .unwrap() - .0; + .message_id; assert_ne!(id1, id2); assert_eq!(id1, id3); @@ -235,24 +235,27 @@ async fn test_unseen_channel_messages(db: &Arc) { .await .unwrap(); - let (second_message, _, _) = db + let second_message = db .create_channel_message(channel_1, user, "1_2", &[], OffsetDateTime::now_utc(), 2) .await - .unwrap(); + .unwrap() + .message_id; - let (third_message, _, _) = db + let third_message = db .create_channel_message(channel_1, user, "1_3", &[], OffsetDateTime::now_utc(), 3) .await - .unwrap(); + .unwrap() + .message_id; db.join_channel_chat(channel_2, user_connection_id, user) .await .unwrap(); - let (fourth_message, _, _) = db + let fourth_message = db .create_channel_message(channel_2, user, "2_1", &[], OffsetDateTime::now_utc(), 4) .await - .unwrap(); + .unwrap() + .message_id; // Check that observer has new messages let unseen_messages = db diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 191ab76274..7ff1dc7717 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -3,8 +3,8 @@ mod connection_pool; use crate::{ auth, db::{ - self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, Database, MessageId, - ProjectId, RoomId, ServerId, User, UserId, + self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, CreatedChannelMessage, + Database, MessageId, ProjectId, RoomId, ServerId, User, UserId, }, executor::Executor, AppState, Result, @@ -271,6 +271,7 @@ impl Server { .add_request_handler(send_channel_message) .add_request_handler(remove_channel_message) .add_request_handler(get_channel_messages) + .add_request_handler(get_channel_messages_by_id) .add_request_handler(get_notifications) .add_request_handler(link_channel) .add_request_handler(unlink_channel) @@ -2969,7 +2970,12 @@ async fn send_channel_message( .ok_or_else(|| anyhow!("nonce can't be blank"))?; let channel_id = ChannelId::from_proto(request.channel_id); - let (message_id, connection_ids, non_participants) = session + let CreatedChannelMessage { + message_id, + participant_connection_ids, + channel_members, + notifications, + } = session .db() .await .create_channel_message( @@ -2989,15 +2995,19 @@ async fn send_channel_message( timestamp: timestamp.unix_timestamp() as u64, nonce: Some(nonce), }; - broadcast(Some(session.connection_id), connection_ids, |connection| { - session.peer.send( - connection, - proto::ChannelMessageSent { - channel_id: channel_id.to_proto(), - message: Some(message.clone()), - }, - ) - }); + broadcast( + Some(session.connection_id), + participant_connection_ids, + |connection| { + session.peer.send( + connection, + proto::ChannelMessageSent { + channel_id: channel_id.to_proto(), + message: Some(message.clone()), + }, + ) + }, + ); response.send(proto::SendChannelMessageResponse { message: Some(message), })?; @@ -3005,7 +3015,7 @@ async fn send_channel_message( let pool = &*session.connection_pool().await; broadcast( None, - non_participants + channel_members .iter() .flat_map(|user_id| pool.user_connection_ids(*user_id)), |peer_id| { @@ -3021,6 +3031,7 @@ async fn send_channel_message( ) }, ); + send_notifications(pool, &session.peer, notifications); Ok(()) } @@ -3129,6 +3140,28 @@ async fn get_channel_messages( Ok(()) } +async fn get_channel_messages_by_id( + request: proto::GetChannelMessagesById, + response: Response, + session: Session, +) -> Result<()> { + let message_ids = request + .message_ids + .iter() + .map(|id| MessageId::from_proto(*id)) + .collect::>(); + let messages = session + .db() + .await + .get_channel_messages_by_id(session.user_id, &message_ids) + .await?; + response.send(proto::GetChannelMessagesResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + })?; + Ok(()) +} + async fn get_notifications( request: proto::GetNotifications, response: Response, diff --git a/crates/collab/src/tests/channel_message_tests.rs b/crates/collab/src/tests/channel_message_tests.rs index 7b252bef8f..0e63f96bf9 100644 --- a/crates/collab/src/tests/channel_message_tests.rs +++ b/crates/collab/src/tests/channel_message_tests.rs @@ -2,6 +2,7 @@ use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; use channel::{ChannelChat, ChannelMessageId, MessageParams}; use collab_ui::chat_panel::ChatPanel; use gpui::{executor::Deterministic, BorrowAppContext, ModelHandle, TestAppContext}; +use rpc::Notification; use std::sync::Arc; use workspace::dock::Panel; @@ -38,7 +39,7 @@ async fn test_basic_channel_messages( .await .unwrap(); - channel_chat_a + let message_id = channel_chat_a .update(cx_a, |c, cx| { c.send_message( MessageParams { @@ -91,6 +92,27 @@ async fn test_basic_channel_messages( ); }); } + + client_c.notification_store().update(cx_c, |store, _| { + assert_eq!(store.notification_count(), 2); + assert_eq!(store.unread_notification_count(), 1); + assert_eq!( + store.notification_at(0).unwrap().notification, + Notification::ChannelMessageMention { + message_id, + sender_id: client_a.id(), + channel_id, + } + ); + assert_eq!( + store.notification_at(1).unwrap().notification, + Notification::ChannelInvitation { + channel_id, + channel_name: "the-channel".to_string(), + inviter_id: client_a.id() + } + ); + }); } #[gpui::test]