Avoid creating duplicate invite notifications

This commit is contained in:
Max Brunsfeld 2023-10-13 16:57:28 -07:00
parent 5a0afcc835
commit cb7b011d6b
6 changed files with 109 additions and 44 deletions

View file

@ -330,4 +330,4 @@ CREATE TABLE "notifications" (
"content" TEXT "content" TEXT
); );
CREATE INDEX "index_notifications_on_recipient_id" ON "notifications" ("recipient_id"); CREATE INDEX "index_notifications_on_recipient_id_is_read" ON "notifications" ("recipient_id", "is_read");

View file

@ -161,7 +161,7 @@ impl Database {
invitee_id: UserId, invitee_id: UserId,
inviter_id: UserId, inviter_id: UserId,
is_admin: bool, is_admin: bool,
) -> Result<()> { ) -> Result<Option<proto::Notification>> {
self.transaction(move |tx| async move { self.transaction(move |tx| async move {
self.check_user_is_channel_admin(channel_id, inviter_id, &*tx) self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
.await?; .await?;
@ -176,7 +176,16 @@ impl Database {
.insert(&*tx) .insert(&*tx)
.await?; .await?;
Ok(()) self.create_notification(
invitee_id,
rpc::Notification::ChannelInvitation {
actor_id: inviter_id.to_proto(),
channel_id: channel_id.to_proto(),
},
true,
&*tx,
)
.await
}) })
.await .await
} }

View file

@ -123,7 +123,7 @@ impl Database {
&self, &self,
sender_id: UserId, sender_id: UserId,
receiver_id: UserId, receiver_id: UserId,
) -> Result<proto::Notification> { ) -> Result<Option<proto::Notification>> {
self.transaction(|tx| async move { self.transaction(|tx| async move {
let (id_a, id_b, a_to_b) = if sender_id < receiver_id { let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
(sender_id, receiver_id, true) (sender_id, receiver_id, true)
@ -169,6 +169,7 @@ impl Database {
rpc::Notification::ContactRequest { rpc::Notification::ContactRequest {
actor_id: sender_id.to_proto(), actor_id: sender_id.to_proto(),
}, },
true,
&*tx, &*tx,
) )
.await .await
@ -212,7 +213,7 @@ impl Database {
let mut deleted_notification_id = None; let mut deleted_notification_id = None;
if !contact.accepted { if !contact.accepted {
deleted_notification_id = self deleted_notification_id = self
.delete_notification( .remove_notification(
responder_id, responder_id,
rpc::Notification::ContactRequest { rpc::Notification::ContactRequest {
actor_id: requester_id.to_proto(), actor_id: requester_id.to_proto(),
@ -273,7 +274,7 @@ impl Database {
responder_id: UserId, responder_id: UserId,
requester_id: UserId, requester_id: UserId,
accept: bool, accept: bool,
) -> Result<proto::Notification> { ) -> Result<Option<proto::Notification>> {
self.transaction(|tx| async move { self.transaction(|tx| async move {
let (id_a, id_b, a_to_b) = if responder_id < requester_id { let (id_a, id_b, a_to_b) = if responder_id < requester_id {
(responder_id, requester_id, false) (responder_id, requester_id, false)
@ -320,6 +321,7 @@ impl Database {
rpc::Notification::ContactRequestAccepted { rpc::Notification::ContactRequestAccepted {
actor_id: responder_id.to_proto(), actor_id: responder_id.to_proto(),
}, },
true,
&*tx, &*tx,
) )
.await .await

View file

@ -51,18 +51,12 @@ impl Database {
.await?; .await?;
while let Some(row) = rows.next().await { while let Some(row) = rows.next().await {
let row = row?; let row = row?;
let Some(kind) = self.notification_kinds_by_id.get(&row.kind) else { let kind = row.kind;
log::warn!("unknown notification kind {:?}", row.kind); if let Some(proto) = self.model_to_proto(row) {
continue; result.push(proto);
}; } else {
result.push(proto::Notification { log::warn!("unknown notification kind {:?}", kind);
id: row.id.to_proto(), }
kind: kind.to_string(),
timestamp: row.created_at.assume_utc().unix_timestamp() as u64,
is_read: row.is_read,
content: row.content,
actor_id: row.actor_id.map(|id| id.to_proto()),
});
} }
result.reverse(); result.reverse();
Ok(result) Ok(result)
@ -74,19 +68,48 @@ impl Database {
&self, &self,
recipient_id: UserId, recipient_id: UserId,
notification: Notification, notification: Notification,
avoid_duplicates: bool,
tx: &DatabaseTransaction, tx: &DatabaseTransaction,
) -> Result<proto::Notification> { ) -> Result<Option<proto::Notification>> {
let notification = notification.to_proto(); let notification_proto = notification.to_proto();
let kind = *self let kind = *self
.notification_kinds_by_name .notification_kinds_by_name
.get(&notification.kind) .get(&notification_proto.kind)
.ok_or_else(|| anyhow!("invalid notification kind {:?}", notification.kind))?; .ok_or_else(|| anyhow!("invalid notification kind {:?}", notification_proto.kind))?;
let actor_id = notification_proto.actor_id.map(|id| UserId::from_proto(id));
if avoid_duplicates {
let mut existing_notifications = notification::Entity::find()
.filter(
Condition::all()
.add(notification::Column::RecipientId.eq(recipient_id))
.add(notification::Column::IsRead.eq(false))
.add(notification::Column::Kind.eq(kind))
.add(notification::Column::ActorId.eq(actor_id)),
)
.stream(&*tx)
.await?;
// Check if this notification already exists. Don't rely on the
// JSON serialization being identical, in case the notification enum
// is changed in backward-compatible ways over time.
while let Some(row) = existing_notifications.next().await {
let row = row?;
if let Some(proto) = self.model_to_proto(row) {
if let Some(existing) = Notification::from_proto(&proto) {
if existing == notification {
return Ok(None);
}
}
}
}
}
let model = notification::ActiveModel { let model = notification::ActiveModel {
recipient_id: ActiveValue::Set(recipient_id), recipient_id: ActiveValue::Set(recipient_id),
kind: ActiveValue::Set(kind), kind: ActiveValue::Set(kind),
content: ActiveValue::Set(notification.content.clone()), content: ActiveValue::Set(notification_proto.content.clone()),
actor_id: ActiveValue::Set(notification.actor_id.map(|id| UserId::from_proto(id))), actor_id: ActiveValue::Set(actor_id),
is_read: ActiveValue::NotSet, is_read: ActiveValue::NotSet,
created_at: ActiveValue::NotSet, created_at: ActiveValue::NotSet,
id: ActiveValue::NotSet, id: ActiveValue::NotSet,
@ -94,17 +117,17 @@ impl Database {
.save(&*tx) .save(&*tx)
.await?; .await?;
Ok(proto::Notification { Ok(Some(proto::Notification {
id: model.id.as_ref().to_proto(), id: model.id.as_ref().to_proto(),
kind: notification.kind.to_string(), kind: notification_proto.kind.to_string(),
timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64, timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64,
is_read: false, is_read: false,
content: notification.content, content: notification_proto.content,
actor_id: notification.actor_id, actor_id: notification_proto.actor_id,
}) }))
} }
pub async fn delete_notification( pub async fn remove_notification(
&self, &self,
recipient_id: UserId, recipient_id: UserId,
notification: Notification, notification: Notification,
@ -133,4 +156,16 @@ impl Database {
} }
Ok(notification.map(|notification| notification.id)) Ok(notification.map(|notification| notification.id))
} }
fn model_to_proto(&self, row: notification::Model) -> Option<proto::Notification> {
let kind = self.notification_kinds_by_id.get(&row.kind)?;
Some(proto::Notification {
id: row.id.to_proto(),
kind: kind.to_string(),
timestamp: row.created_at.assume_utc().unix_timestamp() as u64,
is_read: row.is_read,
content: row.content,
actor_id: row.actor_id.map(|id| id.to_proto()),
})
}
} }

View file

@ -2097,6 +2097,7 @@ async fn request_contact(
.user_connection_ids(responder_id) .user_connection_ids(responder_id)
{ {
session.peer.send(connection_id, update.clone())?; session.peer.send(connection_id, update.clone())?;
if let Some(notification) = &notification {
session.peer.send( session.peer.send(
connection_id, connection_id,
proto::NewNotification { proto::NewNotification {
@ -2104,6 +2105,7 @@ async fn request_contact(
}, },
)?; )?;
} }
}
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
Ok(()) Ok(())
@ -2156,6 +2158,7 @@ async fn respond_to_contact_request(
.push(responder_id.to_proto()); .push(responder_id.to_proto());
for connection_id in pool.user_connection_ids(requester_id) { for connection_id in pool.user_connection_ids(requester_id) {
session.peer.send(connection_id, update.clone())?; session.peer.send(connection_id, update.clone())?;
if let Some(notification) = &notification {
session.peer.send( session.peer.send(
connection_id, connection_id,
proto::NewNotification { proto::NewNotification {
@ -2164,6 +2167,7 @@ async fn respond_to_contact_request(
)?; )?;
} }
} }
}
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
Ok(()) Ok(())
@ -2306,7 +2310,8 @@ async fn invite_channel_member(
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let invitee_id = UserId::from_proto(request.user_id); let invitee_id = UserId::from_proto(request.user_id);
db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin) let notification = db
.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
.await?; .await?;
let (channel, _) = db let (channel, _) = db
@ -2319,12 +2324,21 @@ async fn invite_channel_member(
id: channel.id.to_proto(), id: channel.id.to_proto(),
name: channel.name, name: channel.name,
}); });
for connection_id in session for connection_id in session
.connection_pool() .connection_pool()
.await .await
.user_connection_ids(invitee_id) .user_connection_ids(invitee_id)
{ {
session.peer.send(connection_id, update.clone())?; session.peer.send(connection_id, update.clone())?;
if let Some(notification) = &notification {
session.peer.send(
connection_id,
proto::NewNotification {
notification: Some(notification.clone()),
},
)?;
}
} }
response.send(proto::Ack {})?; response.send(proto::Ack {})?;

View file

@ -209,7 +209,12 @@ impl NotificationPanel {
channel_id, channel_id,
} => { } => {
actor = user_store.get_cached_user(inviter_id)?; actor = user_store.get_cached_user(inviter_id)?;
let channel = channel_store.channel_for_id(channel_id)?; let channel = channel_store.channel_for_id(channel_id).or_else(|| {
channel_store
.channel_invitations()
.iter()
.find(|c| c.id == channel_id)
})?;
icon = "icons/hash.svg"; icon = "icons/hash.svg";
text = format!( text = format!(