diff --git a/Cargo.lock b/Cargo.lock index 01153ca0f8..a426a6a1ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6403,6 +6403,7 @@ dependencies = [ "serde_derive", "smol", "smol-timeout", + "strum", "tempdir", "tracing", "util", @@ -6623,6 +6624,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" + [[package]] name = "rustybuzz" version = "0.3.0" @@ -7698,6 +7705,22 @@ name = "strum" version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad8d03b598d3d0fff69bf533ee3ef19b8eeb342729596df84bcc7e1f96ec4059" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.37", +] [[package]] name = "subtle" diff --git a/Cargo.toml b/Cargo.toml index 532610efd6..adb7fedb26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -112,6 +112,7 @@ serde_derive = { version = "1.0", features = ["deserialize_in_place"] } serde_json = { version = "1.0", features = ["preserve_order", "raw_value"] } smallvec = { version = "1.6", features = ["union"] } smol = { version = "1.2" } +strum = { version = "0.25.0", features = ["derive"] } sysinfo = "0.29.10" tempdir = { version = "0.3.7" } thiserror = { version = "1.0.29" } diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 5a84bfd796..0e811d8455 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -312,3 +312,22 @@ CREATE TABLE IF NOT EXISTS "observed_channel_messages" ( ); CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id"); + +CREATE TABLE "notification_kinds" ( + "id" INTEGER PRIMARY KEY NOT NULL, + "name" VARCHAR NOT NULL, +); + +CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name"); + +CREATE TABLE "notifications" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "created_at" TIMESTAMP NOT NULL default now, + "recipent_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "kind" INTEGER NOT NULL REFERENCES notification_kinds (id), + "is_read" BOOLEAN NOT NULL DEFAULT FALSE, + "entity_id_1" INTEGER, + "entity_id_2" INTEGER +); + +CREATE INDEX "index_notifications_on_recipient_id" ON "notifications" ("recipient_id"); diff --git a/crates/collab/migrations/20231004130100_create_notifications.sql b/crates/collab/migrations/20231004130100_create_notifications.sql new file mode 100644 index 0000000000..e0c7b290b4 --- /dev/null +++ b/crates/collab/migrations/20231004130100_create_notifications.sql @@ -0,0 +1,18 @@ +CREATE TABLE "notification_kinds" ( + "id" INTEGER PRIMARY KEY NOT NULL, + "name" VARCHAR NOT NULL, +); + +CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name"); + +CREATE TABLE notifications ( + "id" SERIAL PRIMARY KEY, + "created_at" TIMESTAMP NOT NULL DEFAULT now(), + "recipent_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "kind" INTEGER NOT NULL REFERENCES notification_kinds (id), + "is_read" BOOLEAN NOT NULL DEFAULT FALSE + "entity_id_1" INTEGER, + "entity_id_2" INTEGER +); + +CREATE INDEX "index_notifications_on_recipient_id" ON "notifications" ("recipient_id"); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index e60b7cc33d..56e7c0d942 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -20,7 +20,7 @@ use rpc::{ }; use sea_orm::{ entity::prelude::*, - sea_query::{Alias, Expr, OnConflict, Query}, + sea_query::{Alias, Expr, OnConflict}, ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement, TransactionTrait, diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index 23bb9e53bf..b5873a152f 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -80,3 +80,4 @@ id_type!(SignupId); id_type!(UserId); id_type!(ChannelBufferCollaboratorId); id_type!(FlagId); +id_type!(NotificationId); diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 80bd8704b2..629e26f1a9 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -5,6 +5,7 @@ pub mod buffers; pub mod channels; pub mod contacts; pub mod messages; +pub mod notifications; pub mod projects; pub mod rooms; pub mod servers; diff --git a/crates/collab/src/db/queries/access_tokens.rs b/crates/collab/src/db/queries/access_tokens.rs index def9428a2b..589b6483df 100644 --- a/crates/collab/src/db/queries/access_tokens.rs +++ b/crates/collab/src/db/queries/access_tokens.rs @@ -1,4 +1,5 @@ use super::*; +use sea_orm::sea_query::Query; impl Database { pub async fn create_access_token( diff --git a/crates/collab/src/db/queries/notifications.rs b/crates/collab/src/db/queries/notifications.rs new file mode 100644 index 0000000000..2907ad85b7 --- /dev/null +++ b/crates/collab/src/db/queries/notifications.rs @@ -0,0 +1,140 @@ +use super::*; +use rpc::{Notification, NotificationEntityKind, NotificationKind}; + +impl Database { + pub async fn ensure_notification_kinds(&self) -> Result<()> { + self.transaction(|tx| async move { + notification_kind::Entity::insert_many(NotificationKind::all().map(|kind| { + notification_kind::ActiveModel { + id: ActiveValue::Set(kind as i32), + name: ActiveValue::Set(kind.to_string()), + } + })) + .on_conflict(OnConflict::new().do_nothing().to_owned()) + .exec(&*tx) + .await?; + Ok(()) + }) + .await + } + + pub async fn get_notifications( + &self, + recipient_id: UserId, + limit: usize, + ) -> Result { + self.transaction(|tx| async move { + let mut result = proto::AddNotifications::default(); + + let mut rows = notification::Entity::find() + .filter(notification::Column::RecipientId.eq(recipient_id)) + .order_by_desc(notification::Column::Id) + .limit(limit as u64) + .stream(&*tx) + .await?; + + let mut user_ids = Vec::new(); + let mut channel_ids = Vec::new(); + let mut message_ids = Vec::new(); + while let Some(row) = rows.next().await { + let row = row?; + + let Some(kind) = NotificationKind::from_i32(row.kind) else { + continue; + }; + let Some(notification) = Notification::from_fields( + kind, + [ + row.entity_id_1.map(|id| id as u64), + row.entity_id_2.map(|id| id as u64), + row.entity_id_3.map(|id| id as u64), + ], + ) else { + continue; + }; + + // Gather the ids of all associated entities. + let (_, associated_entities) = notification.to_fields(); + for entity in associated_entities { + let Some((id, kind)) = entity else { + break; + }; + match kind { + NotificationEntityKind::User => &mut user_ids, + NotificationEntityKind::Channel => &mut channel_ids, + NotificationEntityKind::ChannelMessage => &mut message_ids, + } + .push(id); + } + + result.notifications.push(proto::Notification { + kind: row.kind as u32, + timestamp: row.created_at.assume_utc().unix_timestamp() as u64, + is_read: row.is_read, + entity_id_1: row.entity_id_1.map(|id| id as u64), + entity_id_2: row.entity_id_2.map(|id| id as u64), + entity_id_3: row.entity_id_3.map(|id| id as u64), + }); + } + + let users = user::Entity::find() + .filter(user::Column::Id.is_in(user_ids)) + .all(&*tx) + .await?; + let channels = channel::Entity::find() + .filter(user::Column::Id.is_in(channel_ids)) + .all(&*tx) + .await?; + let messages = channel_message::Entity::find() + .filter(user::Column::Id.is_in(message_ids)) + .all(&*tx) + .await?; + + for user in users { + result.users.push(proto::User { + id: user.id.to_proto(), + github_login: user.github_login, + avatar_url: String::new(), + }); + } + for channel in channels { + result.channels.push(proto::Channel { + id: channel.id.to_proto(), + name: channel.name, + }); + } + for message in messages { + result.messages.push(proto::ChannelMessage { + id: message.id.to_proto(), + body: message.body, + timestamp: message.sent_at.assume_utc().unix_timestamp() as u64, + sender_id: message.sender_id.to_proto(), + nonce: None, + }); + } + + Ok(result) + }) + .await + } + + pub async fn create_notification( + &self, + recipient_id: UserId, + notification: Notification, + tx: &DatabaseTransaction, + ) -> Result<()> { + let (kind, associated_entities) = notification.to_fields(); + notification::ActiveModel { + recipient_id: ActiveValue::Set(recipient_id), + kind: ActiveValue::Set(kind as i32), + entity_id_1: ActiveValue::Set(associated_entities[0].map(|(id, _)| id as i32)), + entity_id_2: ActiveValue::Set(associated_entities[1].map(|(id, _)| id as i32)), + entity_id_3: ActiveValue::Set(associated_entities[2].map(|(id, _)| id as i32)), + ..Default::default() + } + .save(&*tx) + .await?; + Ok(()) + } +} diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index e19391da7d..4336217b23 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -12,6 +12,8 @@ pub mod contact; pub mod feature_flag; pub mod follower; pub mod language_server; +pub mod notification; +pub mod notification_kind; pub mod observed_buffer_edits; pub mod observed_channel_messages; pub mod project; diff --git a/crates/collab/src/db/tables/notification.rs b/crates/collab/src/db/tables/notification.rs new file mode 100644 index 0000000000..6a0abe9dc6 --- /dev/null +++ b/crates/collab/src/db/tables/notification.rs @@ -0,0 +1,29 @@ +use crate::db::{NotificationId, UserId}; +use sea_orm::entity::prelude::*; +use time::PrimitiveDateTime; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "notifications")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: NotificationId, + pub recipient_id: UserId, + pub kind: i32, + pub is_read: bool, + pub created_at: PrimitiveDateTime, + pub entity_id_1: Option, + pub entity_id_2: Option, + pub entity_id_3: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::RecipientId", + to = "super::user::Column::Id" + )] + Recipient, +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/notification_kind.rs b/crates/collab/src/db/tables/notification_kind.rs new file mode 100644 index 0000000000..32dfb2065a --- /dev/null +++ b/crates/collab/src/db/tables/notification_kind.rs @@ -0,0 +1,14 @@ +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "notification_kinds")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/rpc/Cargo.toml b/crates/rpc/Cargo.toml index 3c307be4fb..bc750374dd 100644 --- a/crates/rpc/Cargo.toml +++ b/crates/rpc/Cargo.toml @@ -29,6 +29,7 @@ rsa = "0.4" serde.workspace = true serde_derive.workspace = true smol-timeout = "0.6" +strum.workspace = true tracing = { version = "0.1.34", features = ["log"] } zstd = "0.11" diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 3501e70e6a..f51d11d3db 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -170,7 +170,9 @@ message Envelope { LinkChannel link_channel = 140; UnlinkChannel unlink_channel = 141; - MoveChannel move_channel = 142; // current max: 144 + MoveChannel move_channel = 142; + + AddNotifications add_notification = 145; // Current max } } @@ -1557,3 +1559,40 @@ message UpdateDiffBase { uint64 buffer_id = 2; optional string diff_base = 3; } + +message AddNotifications { + repeated Notification notifications = 1; + repeated User users = 2; + repeated Channel channels = 3; + repeated ChannelMessage messages = 4; +} + +message Notification { + uint32 kind = 1; + uint64 timestamp = 2; + bool is_read = 3; + optional uint64 entity_id_1 = 4; + optional uint64 entity_id_2 = 5; + optional uint64 entity_id_3 = 6; + + // oneof variant { + // ContactRequest contact_request = 3; + // ChannelInvitation channel_invitation = 4; + // ChatMessageMention chat_message_mention = 5; + // }; + + // message ContactRequest { + // uint64 requester_id = 1; + // } + + // message ChannelInvitation { + // uint64 inviter_id = 1; + // uint64 channel_id = 2; + // } + + // message ChatMessageMention { + // uint64 sender_id = 1; + // uint64 channel_id = 2; + // uint64 message_id = 3; + // } +} diff --git a/crates/rpc/src/notification.rs b/crates/rpc/src/notification.rs new file mode 100644 index 0000000000..40794a11c3 --- /dev/null +++ b/crates/rpc/src/notification.rs @@ -0,0 +1,105 @@ +use strum::{Display, EnumIter, EnumString, IntoEnumIterator}; + +// An integer indicating a type of notification. The variants' numerical +// values are stored in the database, so they should never be removed +// or changed. +#[repr(i32)] +#[derive(Copy, Clone, Debug, EnumIter, EnumString, Display)] +pub enum NotificationKind { + ContactRequest = 0, + ChannelInvitation = 1, + ChannelMessageMention = 2, +} + +pub enum Notification { + ContactRequest { + requester_id: u64, + }, + ChannelInvitation { + inviter_id: u64, + channel_id: u64, + }, + ChannelMessageMention { + sender_id: u64, + channel_id: u64, + message_id: u64, + }, +} + +#[derive(Copy, Clone)] +pub enum NotificationEntityKind { + User, + Channel, + ChannelMessage, +} + +impl Notification { + pub fn from_fields(kind: NotificationKind, entity_ids: [Option; 3]) -> Option { + use NotificationKind::*; + + Some(match kind { + ContactRequest => Self::ContactRequest { + requester_id: entity_ids[0]?, + }, + ChannelInvitation => Self::ChannelInvitation { + inviter_id: entity_ids[0]?, + channel_id: entity_ids[1]?, + }, + ChannelMessageMention => Self::ChannelMessageMention { + sender_id: entity_ids[0]?, + channel_id: entity_ids[1]?, + message_id: entity_ids[2]?, + }, + }) + } + + pub fn to_fields(&self) -> (NotificationKind, [Option<(u64, NotificationEntityKind)>; 3]) { + use NotificationKind::*; + + match self { + Self::ContactRequest { requester_id } => ( + ContactRequest, + [ + Some((*requester_id, NotificationEntityKind::User)), + None, + None, + ], + ), + + Self::ChannelInvitation { + inviter_id, + channel_id, + } => ( + ChannelInvitation, + [ + Some((*inviter_id, NotificationEntityKind::User)), + Some((*channel_id, NotificationEntityKind::User)), + None, + ], + ), + + Self::ChannelMessageMention { + sender_id, + channel_id, + message_id, + } => ( + ChannelMessageMention, + [ + Some((*sender_id, NotificationEntityKind::User)), + Some((*channel_id, NotificationEntityKind::ChannelMessage)), + Some((*message_id, NotificationEntityKind::Channel)), + ], + ), + } + } +} + +impl NotificationKind { + pub fn all() -> impl Iterator { + Self::iter() + } + + pub fn from_i32(i: i32) -> Option { + Self::iter().find(|kind| *kind as i32 == i) + } +} diff --git a/crates/rpc/src/rpc.rs b/crates/rpc/src/rpc.rs index 942672b94b..539ef014bb 100644 --- a/crates/rpc/src/rpc.rs +++ b/crates/rpc/src/rpc.rs @@ -1,9 +1,12 @@ pub mod auth; mod conn; +mod notification; mod peer; pub mod proto; + pub use conn::Connection; pub use peer::*; +pub use notification::*; mod macros; pub const PROTOCOL_VERSION: u32 = 64;