From da5a77badf9f39ac7e2614b2c33fd65fc1857629 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 7 Sep 2023 12:24:25 -0700 Subject: [PATCH] Start work on restoring server-side code for chat messages --- crates/channel/src/channel.rs | 1 + crates/channel/src/channel_chat.rs | 8 +- .../20221109000000_test_schema.sql | 20 +++ .../20230907114200_add_channel_messages.sql | 19 +++ crates/collab/src/db/ids.rs | 2 + crates/collab/src/db/queries.rs | 1 + crates/collab/src/db/queries/messages.rs | 152 ++++++++++++++++++ crates/collab/src/db/tables.rs | 2 + crates/collab/src/db/tables/channel.rs | 8 + .../src/db/tables/channel_chat_participant.rs | 41 +++++ .../collab/src/db/tables/channel_message.rs | 45 ++++++ crates/collab/src/db/tests.rs | 1 + crates/collab/src/db/tests/message_tests.rs | 53 ++++++ crates/collab/src/rpc.rs | 119 +++++++++++++- crates/collab/src/tests.rs | 1 + .../collab/src/tests/channel_message_tests.rs | 56 +++++++ 16 files changed, 524 insertions(+), 5 deletions(-) create mode 100644 crates/collab/migrations/20230907114200_add_channel_messages.sql create mode 100644 crates/collab/src/db/queries/messages.rs create mode 100644 crates/collab/src/db/tables/channel_chat_participant.rs create mode 100644 crates/collab/src/db/tables/channel_message.rs create mode 100644 crates/collab/src/db/tests/message_tests.rs create mode 100644 crates/collab/src/tests/channel_message_tests.rs diff --git a/crates/channel/src/channel.rs b/crates/channel/src/channel.rs index b4acc0c25f..37f1c0ce44 100644 --- a/crates/channel/src/channel.rs +++ b/crates/channel/src/channel.rs @@ -14,4 +14,5 @@ mod channel_store_tests; pub fn init(client: &Arc) { channel_buffer::init(client); + channel_chat::init(client); } diff --git a/crates/channel/src/channel_chat.rs b/crates/channel/src/channel_chat.rs index ebe491e5b8..3916189363 100644 --- a/crates/channel/src/channel_chat.rs +++ b/crates/channel/src/channel_chat.rs @@ -57,6 +57,10 @@ pub enum ChannelChatEvent { }, } +pub fn init(client: &Arc) { + client.add_model_message_handler(ChannelChat::handle_message_sent); +} + impl Entity for ChannelChat { type Event = ChannelChatEvent; @@ -70,10 +74,6 @@ impl Entity for ChannelChat { } impl ChannelChat { - pub fn init(rpc: &Arc) { - rpc.add_model_message_handler(Self::handle_message_sent); - } - pub async fn new( channel: Arc, user_store: ModelHandle, diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 80477dcb3c..ab12039b10 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -192,6 +192,26 @@ CREATE TABLE "channels" ( "created_at" TIMESTAMP NOT NULL DEFAULT now ); +CREATE TABLE IF NOT EXISTS "channel_chat_participants" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "user_id" INTEGER NOT NULL REFERENCES users (id), + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "connection_id" INTEGER NOT NULL, + "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE +); +CREATE INDEX "index_channel_chat_participants_on_channel_id" ON "channel_chat_participants" ("channel_id"); + +CREATE TABLE IF NOT EXISTS "channel_messages" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "sender_id" INTEGER NOT NULL REFERENCES users (id), + "body" TEXT NOT NULL, + "sent_at" TIMESTAMP, + "nonce" BLOB NOT NULL +); +CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id"); +CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce"); + CREATE TABLE "channel_paths" ( "id_path" TEXT NOT NULL PRIMARY KEY, "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE diff --git a/crates/collab/migrations/20230907114200_add_channel_messages.sql b/crates/collab/migrations/20230907114200_add_channel_messages.sql new file mode 100644 index 0000000000..abe7753ca6 --- /dev/null +++ b/crates/collab/migrations/20230907114200_add_channel_messages.sql @@ -0,0 +1,19 @@ +CREATE TABLE IF NOT EXISTS "channel_messages" ( + "id" SERIAL PRIMARY KEY, + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "sender_id" INTEGER NOT NULL REFERENCES users (id), + "body" TEXT NOT NULL, + "sent_at" TIMESTAMP, + "nonce" UUID NOT NULL +); +CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id"); +CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce"); + +CREATE TABLE IF NOT EXISTS "channel_chat_participants" ( + "id" SERIAL PRIMARY KEY, + "user_id" INTEGER NOT NULL REFERENCES users (id), + "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE, + "connection_id" INTEGER NOT NULL, + "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE +); +CREATE INDEX "index_channel_chat_participants_on_channel_id" ON "channel_chat_participants" ("channel_id"); diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index b33ea57183..865a39fd71 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -112,8 +112,10 @@ fn value_to_integer(v: Value) -> Result { id_type!(BufferId); id_type!(AccessTokenId); +id_type!(ChannelChatParticipantId); id_type!(ChannelId); id_type!(ChannelMemberId); +id_type!(MessageId); id_type!(ContactId); id_type!(FollowerId); id_type!(RoomId); diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 09a8f073b4..54db0663d2 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -4,6 +4,7 @@ pub mod access_tokens; pub mod buffers; pub mod channels; pub mod contacts; +pub mod messages; pub mod projects; pub mod rooms; pub mod servers; diff --git a/crates/collab/src/db/queries/messages.rs b/crates/collab/src/db/queries/messages.rs new file mode 100644 index 0000000000..e8b9ec7416 --- /dev/null +++ b/crates/collab/src/db/queries/messages.rs @@ -0,0 +1,152 @@ +use super::*; +use time::OffsetDateTime; + +impl Database { + pub async fn join_channel_chat( + &self, + channel_id: ChannelId, + connection_id: ConnectionId, + user_id: UserId, + ) -> Result<()> { + self.transaction(|tx| async move { + self.check_user_is_channel_member(channel_id, user_id, &*tx) + .await?; + channel_chat_participant::ActiveModel { + id: ActiveValue::NotSet, + channel_id: ActiveValue::Set(channel_id), + user_id: ActiveValue::Set(user_id), + connection_id: ActiveValue::Set(connection_id.id as i32), + connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)), + } + .insert(&*tx) + .await?; + Ok(()) + }) + .await + } + + pub async fn leave_channel_chat( + &self, + channel_id: ChannelId, + connection_id: ConnectionId, + _user_id: UserId, + ) -> Result<()> { + self.transaction(|tx| async move { + channel_chat_participant::Entity::delete_many() + .filter( + Condition::all() + .add( + channel_chat_participant::Column::ConnectionServerId + .eq(connection_id.owner_id), + ) + .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)) + .add(channel_chat_participant::Column::ChannelId.eq(channel_id)), + ) + .exec(&*tx) + .await?; + + Ok(()) + }) + .await + } + + pub async fn get_channel_messages( + &self, + channel_id: ChannelId, + user_id: UserId, + count: usize, + before_message_id: Option, + ) -> Result> { + self.transaction(|tx| async move { + self.check_user_is_channel_member(channel_id, user_id, &*tx) + .await?; + + let mut condition = + Condition::all().add(channel_message::Column::ChannelId.eq(channel_id)); + + if let Some(before_message_id) = before_message_id { + condition = condition.add(channel_message::Column::Id.lt(before_message_id)); + } + + let mut rows = channel_message::Entity::find() + .filter(condition) + .limit(count as u64) + .stream(&*tx) + .await?; + + let mut messages = Vec::new(); + while let Some(row) = rows.next().await { + let row = row?; + let nonce = row.nonce.as_u64_pair(); + messages.push(proto::ChannelMessage { + id: row.id.to_proto(), + sender_id: row.sender_id.to_proto(), + body: row.body, + timestamp: row.sent_at.unix_timestamp() as u64, + nonce: Some(proto::Nonce { + upper_half: nonce.0, + lower_half: nonce.1, + }), + }); + } + + Ok(messages) + }) + .await + } + + pub async fn create_channel_message( + &self, + channel_id: ChannelId, + user_id: UserId, + body: &str, + timestamp: OffsetDateTime, + nonce: u128, + ) -> Result<(MessageId, Vec)> { + self.transaction(|tx| async move { + let mut rows = channel_chat_participant::Entity::find() + .filter(channel_chat_participant::Column::ChannelId.eq(channel_id)) + .stream(&*tx) + .await?; + + let mut is_participant = false; + let mut participant_connection_ids = Vec::new(); + while let Some(row) = rows.next().await { + let row = row?; + if row.user_id == user_id { + is_participant = true; + } + participant_connection_ids.push(row.connection()); + } + drop(rows); + + if !is_participant { + Err(anyhow!("not a chat participant"))?; + } + + let message = channel_message::Entity::insert(channel_message::ActiveModel { + channel_id: ActiveValue::Set(channel_id), + sender_id: ActiveValue::Set(user_id), + body: ActiveValue::Set(body.to_string()), + sent_at: ActiveValue::Set(timestamp), + nonce: ActiveValue::Set(Uuid::from_u128(nonce)), + id: ActiveValue::NotSet, + }) + .on_conflict( + OnConflict::column(channel_message::Column::Nonce) + .update_column(channel_message::Column::Nonce) + .to_owned(), + ) + .exec(&*tx) + .await?; + + #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] + enum QueryConnectionId { + ConnectionId, + } + + Ok((message.last_insert_id, participant_connection_ids)) + }) + .await + } +} diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index 1765cee065..81200df3e7 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -4,7 +4,9 @@ pub mod buffer_operation; pub mod buffer_snapshot; pub mod channel; pub mod channel_buffer_collaborator; +pub mod channel_chat_participant; pub mod channel_member; +pub mod channel_message; pub mod channel_path; pub mod contact; pub mod feature_flag; diff --git a/crates/collab/src/db/tables/channel.rs b/crates/collab/src/db/tables/channel.rs index 05895ede4c..54f12defc1 100644 --- a/crates/collab/src/db/tables/channel.rs +++ b/crates/collab/src/db/tables/channel.rs @@ -21,6 +21,8 @@ pub enum Relation { Member, #[sea_orm(has_many = "super::channel_buffer_collaborator::Entity")] BufferCollaborators, + #[sea_orm(has_many = "super::channel_chat_participant::Entity")] + ChatParticipants, } impl Related for Entity { @@ -46,3 +48,9 @@ impl Related for Entity { Relation::BufferCollaborators.def() } } + +impl Related for Entity { + fn to() -> RelationDef { + Relation::ChatParticipants.def() + } +} diff --git a/crates/collab/src/db/tables/channel_chat_participant.rs b/crates/collab/src/db/tables/channel_chat_participant.rs new file mode 100644 index 0000000000..f3ef36c289 --- /dev/null +++ b/crates/collab/src/db/tables/channel_chat_participant.rs @@ -0,0 +1,41 @@ +use crate::db::{ChannelChatParticipantId, ChannelId, ServerId, UserId}; +use rpc::ConnectionId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "channel_chat_participants")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ChannelChatParticipantId, + pub channel_id: ChannelId, + pub user_id: UserId, + pub connection_id: i32, + pub connection_server_id: ServerId, +} + +impl Model { + pub fn connection(&self) -> ConnectionId { + ConnectionId { + owner_id: self.connection_server_id.0 as u32, + id: self.connection_id as u32, + } + } +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::channel::Entity", + from = "Column::ChannelId", + to = "super::channel::Column::Id" + )] + Channel, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Channel.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/channel_message.rs b/crates/collab/src/db/tables/channel_message.rs new file mode 100644 index 0000000000..d043d5b668 --- /dev/null +++ b/crates/collab/src/db/tables/channel_message.rs @@ -0,0 +1,45 @@ +use crate::db::{ChannelId, MessageId, UserId}; +use sea_orm::entity::prelude::*; +use time::OffsetDateTime; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "channel_messages")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: MessageId, + pub channel_id: ChannelId, + pub sender_id: UserId, + pub body: String, + pub sent_at: OffsetDateTime, + pub nonce: Uuid, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::channel::Entity", + from = "Column::ChannelId", + to = "super::channel::Column::Id" + )] + Channel, + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::SenderId", + to = "super::user::Column::Id" + )] + Sender, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Channel.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Sender.def() + } +} diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index ee961006cb..1d6c550865 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -1,6 +1,7 @@ mod buffer_tests; mod db_tests; mod feature_flag_tests; +mod message_tests; use super::*; use gpui::executor::Background; diff --git a/crates/collab/src/db/tests/message_tests.rs b/crates/collab/src/db/tests/message_tests.rs new file mode 100644 index 0000000000..1dd686a1d2 --- /dev/null +++ b/crates/collab/src/db/tests/message_tests.rs @@ -0,0 +1,53 @@ +use crate::{ + db::{Database, NewUserParams}, + test_both_dbs, +}; +use std::sync::Arc; +use time::OffsetDateTime; + +test_both_dbs!( + test_channel_message_nonces, + test_channel_message_nonces_postgres, + test_channel_message_nonces_sqlite +); + +async fn test_channel_message_nonces(db: &Arc) { + let user = db + .create_user( + "user@example.com", + false, + NewUserParams { + github_login: "user".into(), + github_user_id: 1, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; + let channel = db + .create_channel("channel", None, "room", user) + .await + .unwrap(); + + let msg1_id = db + .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) + .await + .unwrap(); + let msg2_id = db + .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) + .await + .unwrap(); + let msg3_id = db + .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) + .await + .unwrap(); + let msg4_id = db + .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) + .await + .unwrap(); + + assert_ne!(msg1_id, msg2_id); + assert_eq!(msg1_id, msg3_id); + assert_eq!(msg2_id, msg4_id); +} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index e454fcbb9e..b508c45dc5 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2,7 +2,10 @@ mod connection_pool; use crate::{ auth, - db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId}, + db::{ + self, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId, ServerId, User, + UserId, + }, executor::Executor, AppState, Result, }; @@ -56,6 +59,7 @@ use std::{ }, time::{Duration, Instant}, }; +use time::OffsetDateTime; use tokio::sync::{watch, Semaphore}; use tower::ServiceBuilder; use tracing::{info_span, instrument, Instrument}; @@ -63,6 +67,9 @@ use tracing::{info_span, instrument, Instrument}; pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10); +const MESSAGE_COUNT_PER_PAGE: usize = 100; +const MAX_MESSAGE_LEN: usize = 1024; + lazy_static! { static ref METRIC_CONNECTIONS: IntGauge = register_int_gauge!("connections", "number of connections").unwrap(); @@ -255,6 +262,10 @@ impl Server { .add_request_handler(get_channel_members) .add_request_handler(respond_to_channel_invite) .add_request_handler(join_channel) + .add_request_handler(join_channel_chat) + .add_message_handler(leave_channel_chat) + .add_request_handler(send_channel_message) + .add_request_handler(get_channel_messages) .add_request_handler(follow) .add_message_handler(unfollow) .add_message_handler(update_followers) @@ -2641,6 +2652,112 @@ fn channel_buffer_updated( }); } +async fn send_channel_message( + request: proto::SendChannelMessage, + response: Response, + session: Session, +) -> Result<()> { + // Validate the message body. + let body = request.body.trim().to_string(); + if body.len() > MAX_MESSAGE_LEN { + return Err(anyhow!("message is too long"))?; + } + if body.is_empty() { + return Err(anyhow!("message can't be blank"))?; + } + + let timestamp = OffsetDateTime::now_utc(); + let nonce = request + .nonce + .ok_or_else(|| anyhow!("nonce can't be blank"))?; + + let channel_id = ChannelId::from_proto(request.channel_id); + let (message_id, connection_ids) = session + .db() + .await + .create_channel_message( + channel_id, + session.user_id, + &body, + timestamp, + nonce.clone().into(), + ) + .await?; + let message = proto::ChannelMessage { + sender_id: session.user_id.to_proto(), + id: message_id.to_proto(), + body, + timestamp: timestamp.unix_timestamp() as u64, + nonce: Some(nonce), + }; + broadcast(Some(session.connection_id), connection_ids, |connection| { + session.peer.send( + connection, + proto::ChannelMessageSent { + channel_id: channel_id.to_proto(), + message: Some(message.clone()), + }, + ) + }); + response.send(proto::SendChannelMessageResponse { + message: Some(message), + })?; + Ok(()) +} + +async fn join_channel_chat( + request: proto::JoinChannelChat, + response: Response, + session: Session, +) -> Result<()> { + let channel_id = ChannelId::from_proto(request.channel_id); + + let db = session.db().await; + db.join_channel_chat(channel_id, session.connection_id, session.user_id) + .await?; + let messages = db + .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None) + .await?; + response.send(proto::JoinChannelChatResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + })?; + Ok(()) +} + +async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> { + let channel_id = ChannelId::from_proto(request.channel_id); + session + .db() + .await + .leave_channel_chat(channel_id, session.connection_id, session.user_id) + .await?; + Ok(()) +} + +async fn get_channel_messages( + request: proto::GetChannelMessages, + response: Response, + session: Session, +) -> Result<()> { + let channel_id = ChannelId::from_proto(request.channel_id); + let messages = session + .db() + .await + .get_channel_messages( + channel_id, + session.user_id, + MESSAGE_COUNT_PER_PAGE, + Some(MessageId::from_proto(request.before_message_id)), + ) + .await?; + response.send(proto::GetChannelMessagesResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + })?; + Ok(()) +} + async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = session diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 3000f0d8c3..b0f5b96fde 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -2,6 +2,7 @@ use call::Room; use gpui::{ModelHandle, TestAppContext}; mod channel_buffer_tests; +mod channel_message_tests; mod channel_tests; mod integration_tests; mod random_channel_buffer_tests; diff --git a/crates/collab/src/tests/channel_message_tests.rs b/crates/collab/src/tests/channel_message_tests.rs new file mode 100644 index 0000000000..e7afd136f3 --- /dev/null +++ b/crates/collab/src/tests/channel_message_tests.rs @@ -0,0 +1,56 @@ +use crate::tests::TestServer; +use gpui::{executor::Deterministic, TestAppContext}; +use std::sync::Arc; + +#[gpui::test] +async fn test_basic_channel_messages( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let channel_id = server + .make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)]) + .await; + + let channel_chat_a = client_a + .channel_store() + .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx)) + .await + .unwrap(); + let channel_chat_b = client_b + .channel_store() + .update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx)) + .await + .unwrap(); + + channel_chat_a + .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) + .await + .unwrap(); + channel_chat_a + .update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap()) + .await + .unwrap(); + + deterministic.run_until_parked(); + channel_chat_b + .update(cx_b, |c, cx| c.send_message("three".into(), cx).unwrap()) + .await + .unwrap(); + + deterministic.run_until_parked(); + channel_chat_a.update(cx_a, |c, _| { + assert_eq!( + c.messages() + .iter() + .map(|m| m.body.as_str()) + .collect::>(), + vec!["one", "two", "three"] + ); + }) +}