Acknowledge channel notes and chat changes when views are active

Co-authored-by: Mikayla <mikayla@zed.dev>
This commit is contained in:
Max Brunsfeld 2023-10-03 17:40:10 -07:00
parent af09861f5c
commit 61e0289014
19 changed files with 478 additions and 209 deletions

View file

@ -439,8 +439,8 @@ pub struct ChannelsForUser {
pub channels: ChannelGraph,
pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
pub channels_with_admin_privileges: HashSet<ChannelId>,
pub channels_with_changed_notes: HashSet<ChannelId>,
pub channels_with_new_messages: HashSet<ChannelId>,
pub unseen_buffer_changes: Vec<proto::UnseenChannelBufferChange>,
pub channel_messages: Vec<proto::UnseenChannelMessage>,
}
#[derive(Debug)]

View file

@ -432,7 +432,12 @@ impl Database {
channel_id: ChannelId,
user: UserId,
operations: &[proto::Operation],
) -> Result<(Vec<ConnectionId>, Vec<UserId>)> {
) -> Result<(
Vec<ConnectionId>,
Vec<UserId>,
i32,
Vec<proto::VectorClockEntry>,
)> {
self.transaction(move |tx| async move {
self.check_user_is_channel_member(channel_id, user, &*tx)
.await?;
@ -453,6 +458,7 @@ impl Database {
.collect::<Vec<_>>();
let mut channel_members;
let max_version;
if !operations.is_empty() {
let max_operation = operations
@ -460,6 +466,11 @@ impl Database {
.max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
.unwrap();
max_version = vec![proto::VectorClockEntry {
replica_id: *max_operation.replica_id.as_ref() as u32,
timestamp: *max_operation.lamport_timestamp.as_ref() as u32,
}];
// get current channel participants and save the max operation above
self.save_max_operation(
user,
@ -492,6 +503,7 @@ impl Database {
.await?;
} else {
channel_members = Vec::new();
max_version = Vec::new();
}
let mut connections = Vec::new();
@ -510,7 +522,7 @@ impl Database {
});
}
Ok((connections, channel_members))
Ok((connections, channel_members, buffer.epoch, max_version))
})
.await
}
@ -712,12 +724,12 @@ impl Database {
.await
}
pub async fn channels_with_changed_notes(
pub async fn unseen_channel_buffer_changes(
&self,
user_id: UserId,
channel_ids: &[ChannelId],
tx: &DatabaseTransaction,
) -> Result<HashSet<ChannelId>> {
) -> Result<Vec<proto::UnseenChannelBufferChange>> {
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryIds {
ChannelId,
@ -750,37 +762,45 @@ impl Database {
}
drop(rows);
let last_operations = self
.get_last_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
let latest_operations = self
.get_latest_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
.await?;
let mut channels_with_new_changes = HashSet::default();
for last_operation in last_operations {
if let Some(observed_edit) = observed_edits_by_buffer_id.get(&last_operation.buffer_id)
{
if observed_edit.epoch == last_operation.epoch
&& observed_edit.lamport_timestamp == last_operation.lamport_timestamp
&& observed_edit.replica_id == last_operation.replica_id
let mut changes = Vec::default();
for latest in latest_operations {
if let Some(observed) = observed_edits_by_buffer_id.get(&latest.buffer_id) {
if (
observed.epoch,
observed.lamport_timestamp,
observed.replica_id,
) >= (latest.epoch, latest.lamport_timestamp, latest.replica_id)
{
continue;
}
}
if let Some(channel_id) = channel_ids_by_buffer_id.get(&last_operation.buffer_id) {
channels_with_new_changes.insert(*channel_id);
if let Some(channel_id) = channel_ids_by_buffer_id.get(&latest.buffer_id) {
changes.push(proto::UnseenChannelBufferChange {
channel_id: channel_id.to_proto(),
epoch: latest.epoch as u64,
version: vec![proto::VectorClockEntry {
replica_id: latest.replica_id as u32,
timestamp: latest.lamport_timestamp as u32,
}],
});
}
}
Ok(channels_with_new_changes)
Ok(changes)
}
pub async fn get_last_operations_for_buffers(
pub async fn get_latest_operations_for_buffers(
&self,
channel_ids: impl IntoIterator<Item = BufferId>,
buffer_ids: impl IntoIterator<Item = BufferId>,
tx: &DatabaseTransaction,
) -> Result<Vec<buffer_operation::Model>> {
let mut values = String::new();
for id in channel_ids {
for id in buffer_ids {
if !values.is_empty() {
values.push_str(", ");
}
@ -795,13 +815,10 @@ impl Database {
r#"
SELECT
*
FROM (
FROM
(
SELECT
buffer_id,
epoch,
lamport_timestamp,
replica_id,
value,
*,
row_number() OVER (
PARTITION BY buffer_id
ORDER BY
@ -812,17 +829,17 @@ impl Database {
FROM buffer_operations
WHERE
buffer_id in ({values})
) AS operations
) AS last_operations
WHERE
row_number = 1
"#,
);
let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
let operations = buffer_operation::Model::find_by_statement(stmt)
Ok(buffer_operation::Entity::find()
.from_raw_sql(stmt)
.all(&*tx)
.await?;
Ok(operations)
.await?)
}
}

View file

@ -463,20 +463,20 @@ impl Database {
}
let channel_ids = graph.channels.iter().map(|c| c.id).collect::<Vec<_>>();
let channels_with_changed_notes = self
.channels_with_changed_notes(user_id, &channel_ids, &*tx)
let channel_buffer_changes = self
.unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
.await?;
let channels_with_new_messages = self
.channels_with_new_messages(user_id, &channel_ids, &*tx)
let unseen_messages = self
.unseen_channel_messages(user_id, &channel_ids, &*tx)
.await?;
Ok(ChannelsForUser {
channels: graph,
channel_participants,
channels_with_admin_privileges,
channels_with_changed_notes,
channels_with_new_messages,
unseen_buffer_changes: channel_buffer_changes,
channel_messages: unseen_messages,
})
}

View file

@ -279,12 +279,12 @@ impl Database {
Ok(())
}
pub async fn channels_with_new_messages(
pub async fn unseen_channel_messages(
&self,
user_id: UserId,
channel_ids: &[ChannelId],
tx: &DatabaseTransaction,
) -> Result<collections::HashSet<ChannelId>> {
) -> Result<Vec<proto::UnseenChannelMessage>> {
let mut observed_messages_by_channel_id = HashMap::default();
let mut rows = observed_channel_messages::Entity::find()
.filter(observed_channel_messages::Column::UserId.eq(user_id))
@ -334,7 +334,7 @@ impl Database {
.all(&*tx)
.await?;
let mut channels_with_new_changes = HashSet::default();
let mut changes = Vec::new();
for last_message in last_messages {
if let Some(observed_message) =
observed_messages_by_channel_id.get(&last_message.channel_id)
@ -343,10 +343,13 @@ impl Database {
continue;
}
}
channels_with_new_changes.insert(last_message.channel_id);
changes.push(proto::UnseenChannelMessage {
channel_id: last_message.channel_id.to_proto(),
message_id: last_message.id.to_proto(),
});
}
Ok(channels_with_new_changes)
Ok(changes)
}
pub async fn remove_channel_message(

View file

@ -235,7 +235,7 @@ async fn test_channel_buffers_last_operations(db: &Database) {
.transaction(|tx| {
let buffers = &buffers;
async move {
db.get_last_operations_for_buffers([buffers[0].id, buffers[2].id], &*tx)
db.get_latest_operations_for_buffers([buffers[0].id, buffers[2].id], &*tx)
.await
}
})
@ -299,7 +299,7 @@ async fn test_channel_buffers_last_operations(db: &Database) {
.transaction(|tx| {
let buffers = &buffers;
async move {
db.get_last_operations_for_buffers([buffers[1].id, buffers[2].id], &*tx)
db.get_latest_operations_for_buffers([buffers[1].id, buffers[2].id], &*tx)
.await
}
})
@ -317,7 +317,7 @@ async fn test_channel_buffers_last_operations(db: &Database) {
.transaction(|tx| {
let buffers = &buffers;
async move {
db.get_last_operations_for_buffers([buffers[0].id, buffers[1].id], &*tx)
db.get_latest_operations_for_buffers([buffers[0].id, buffers[1].id], &*tx)
.await
}
})
@ -331,11 +331,11 @@ async fn test_channel_buffers_last_operations(db: &Database) {
],
);
let changed_channels = db
let buffer_changes = db
.transaction(|tx| {
let buffers = &buffers;
async move {
db.channels_with_changed_notes(
db.unseen_channel_buffer_changes(
observer_id,
&[
buffers[0].channel_id,
@ -349,31 +349,42 @@ async fn test_channel_buffers_last_operations(db: &Database) {
})
.await
.unwrap();
assert_eq!(
changed_channels,
buffer_changes,
[
buffers[0].channel_id,
buffers[1].channel_id,
buffers[2].channel_id,
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[0].channel_id.to_proto(),
epoch: 0,
version: serialize_version(&text_buffers[0].version()),
},
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[1].channel_id.to_proto(),
epoch: 1,
version: serialize_version(&text_buffers[1].version()),
},
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[2].channel_id.to_proto(),
epoch: 0,
version: serialize_version(&text_buffers[2].version()),
},
]
.into_iter()
.collect::<HashSet<_>>()
);
db.observe_buffer_version(
buffers[1].id,
observer_id,
1,
&serialize_version(&text_buffers[1].version()),
serialize_version(&text_buffers[1].version()).as_slice(),
)
.await
.unwrap();
let changed_channels = db
let buffer_changes = db
.transaction(|tx| {
let buffers = &buffers;
async move {
db.channels_with_changed_notes(
db.unseen_channel_buffer_changes(
observer_id,
&[
buffers[0].channel_id,
@ -387,11 +398,21 @@ async fn test_channel_buffers_last_operations(db: &Database) {
})
.await
.unwrap();
assert_eq!(
changed_channels,
[buffers[0].channel_id, buffers[2].channel_id,]
.into_iter()
.collect::<HashSet<_>>()
buffer_changes,
[
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[0].channel_id.to_proto(),
epoch: 0,
version: serialize_version(&text_buffers[0].version()),
},
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[2].channel_id.to_proto(),
epoch: 0,
version: serialize_version(&text_buffers[2].version()),
},
]
);
// Observe an earlier version of the buffer.
@ -407,11 +428,11 @@ async fn test_channel_buffers_last_operations(db: &Database) {
.await
.unwrap();
let changed_channels = db
let buffer_changes = db
.transaction(|tx| {
let buffers = &buffers;
async move {
db.channels_with_changed_notes(
db.unseen_channel_buffer_changes(
observer_id,
&[
buffers[0].channel_id,
@ -425,11 +446,21 @@ async fn test_channel_buffers_last_operations(db: &Database) {
})
.await
.unwrap();
assert_eq!(
changed_channels,
[buffers[0].channel_id, buffers[2].channel_id,]
.into_iter()
.collect::<HashSet<_>>()
buffer_changes,
[
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[0].channel_id.to_proto(),
epoch: 0,
version: serialize_version(&text_buffers[0].version()),
},
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[2].channel_id.to_proto(),
epoch: 0,
version: serialize_version(&text_buffers[2].version()),
},
]
);
}

View file

@ -144,25 +144,32 @@ async fn test_channel_message_new_notification(db: &Arc<Database>) {
.await
.unwrap();
let _ = db
let (fourth_message, _, _) = db
.create_channel_message(channel_2, user, "2_1", OffsetDateTime::now_utc(), 4)
.await
.unwrap();
// Check that observer has new messages
let channels_with_new_messages = db
let unseen_messages = db
.transaction(|tx| async move {
db.channels_with_new_messages(observer, &[channel_1, channel_2], &*tx)
db.unseen_channel_messages(observer, &[channel_1, channel_2], &*tx)
.await
})
.await
.unwrap();
assert_eq!(
channels_with_new_messages,
[channel_1, channel_2]
.into_iter()
.collect::<collections::HashSet<_>>()
unseen_messages,
[
rpc::proto::UnseenChannelMessage {
channel_id: channel_1.to_proto(),
message_id: third_message.to_proto(),
},
rpc::proto::UnseenChannelMessage {
channel_id: channel_2.to_proto(),
message_id: fourth_message.to_proto(),
},
]
);
// Observe the second message
@ -171,18 +178,25 @@ async fn test_channel_message_new_notification(db: &Arc<Database>) {
.unwrap();
// Make sure the observer still has a new message
let channels_with_new_messages = db
let unseen_messages = db
.transaction(|tx| async move {
db.channels_with_new_messages(observer, &[channel_1, channel_2], &*tx)
db.unseen_channel_messages(observer, &[channel_1, channel_2], &*tx)
.await
})
.await
.unwrap();
assert_eq!(
channels_with_new_messages,
[channel_1, channel_2]
.into_iter()
.collect::<collections::HashSet<_>>()
unseen_messages,
[
rpc::proto::UnseenChannelMessage {
channel_id: channel_1.to_proto(),
message_id: third_message.to_proto(),
},
rpc::proto::UnseenChannelMessage {
channel_id: channel_2.to_proto(),
message_id: fourth_message.to_proto(),
},
]
);
// Observe the third message,
@ -191,16 +205,20 @@ async fn test_channel_message_new_notification(db: &Arc<Database>) {
.unwrap();
// Make sure the observer does not have a new method
let channels_with_new_messages = db
let unseen_messages = db
.transaction(|tx| async move {
db.channels_with_new_messages(observer, &[channel_1, channel_2], &*tx)
db.unseen_channel_messages(observer, &[channel_1, channel_2], &*tx)
.await
})
.await
.unwrap();
assert_eq!(
channels_with_new_messages,
[channel_2].into_iter().collect::<collections::HashSet<_>>()
unseen_messages,
[rpc::proto::UnseenChannelMessage {
channel_id: channel_2.to_proto(),
message_id: fourth_message.to_proto(),
}]
);
// Observe the second message again, should not regress our observed state
@ -208,16 +226,19 @@ async fn test_channel_message_new_notification(db: &Arc<Database>) {
.await
.unwrap();
// Make sure the observer does not have a new method
let channels_with_new_messages = db
// Make sure the observer does not have a new message
let unseen_messages = db
.transaction(|tx| async move {
db.channels_with_new_messages(observer, &[channel_1, channel_2], &*tx)
db.unseen_channel_messages(observer, &[channel_1, channel_2], &*tx)
.await
})
.await
.unwrap();
assert_eq!(
channels_with_new_messages,
[channel_2].into_iter().collect::<collections::HashSet<_>>()
unseen_messages,
[rpc::proto::UnseenChannelMessage {
channel_id: channel_2.to_proto(),
message_id: fourth_message.to_proto(),
}]
);
}

View file

@ -274,7 +274,8 @@ impl Server {
.add_message_handler(unfollow)
.add_message_handler(update_followers)
.add_message_handler(update_diff_base)
.add_request_handler(get_private_user_info);
.add_request_handler(get_private_user_info)
.add_message_handler(acknowledge_channel_message);
Arc::new(server)
}
@ -2568,16 +2569,8 @@ async fn respond_to_channel_invite(
name: channel.name,
}),
);
update.notes_changed = result
.channels_with_changed_notes
.iter()
.map(|id| id.to_proto())
.collect();
update.new_messages = result
.channels_with_new_messages
.iter()
.map(|id| id.to_proto())
.collect();
update.unseen_channel_messages = result.channel_messages;
update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
update.insert_edge = result.channels.edges;
update
.channel_participants
@ -2701,7 +2694,7 @@ async fn update_channel_buffer(
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
let (collaborators, non_collaborators) = db
let (collaborators, non_collaborators, epoch, version) = db
.update_channel_buffer(channel_id, session.user_id, &request.operations)
.await?;
@ -2726,7 +2719,11 @@ async fn update_channel_buffer(
session.peer.send(
peer_id.into(),
proto::UpdateChannels {
notes_changed: vec![channel_id.to_proto()],
unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
channel_id: channel_id.to_proto(),
epoch: epoch as u64,
version: version.clone(),
}],
..Default::default()
},
)
@ -2859,9 +2856,7 @@ async fn send_channel_message(
message: Some(message),
})?;
dbg!(&non_participants);
let pool = &*session.connection_pool().await;
broadcast(
None,
non_participants
@ -2871,7 +2866,10 @@ async fn send_channel_message(
session.peer.send(
peer_id.into(),
proto::UpdateChannels {
new_messages: vec![channel_id.to_proto()],
unseen_channel_messages: vec![proto::UnseenChannelMessage {
channel_id: channel_id.to_proto(),
message_id: message_id.to_proto(),
}],
..Default::default()
},
)
@ -2900,6 +2898,20 @@ async fn remove_channel_message(
Ok(())
}
async fn acknowledge_channel_message(
request: proto::AckChannelMessage,
session: Session,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id);
session
.db()
.await
.observe_channel_message(channel_id, session.user_id, message_id)
.await?;
Ok(())
}
async fn join_channel_chat(
request: proto::JoinChannelChat,
response: Response<proto::JoinChannelChat>,
@ -3035,18 +3047,8 @@ fn build_initial_channels_update(
});
}
update.notes_changed = channels
.channels_with_changed_notes
.iter()
.map(|channel_id| channel_id.to_proto())
.collect();
update.new_messages = channels
.channels_with_new_messages
.iter()
.map(|channel_id| channel_id.to_proto())
.collect();
update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
update.unseen_channel_messages = channels.channel_messages;
update.insert_edge = channels.channels.edges;
for (channel_id, participants) in channels.channel_participants {

View file

@ -445,8 +445,8 @@ fn channel(id: u64, name: &'static str) -> Channel {
Channel {
id,
name: name.to_string(),
has_note_changed: false,
has_new_messages: false,
unseen_note_version: None,
unseen_message_id: None,
}
}

View file

@ -151,12 +151,12 @@ impl TestServer {
Arc::get_mut(&mut client)
.unwrap()
.set_id(user_id.0 as usize)
.set_id(user_id.to_proto())
.override_authenticate(move |cx| {
cx.spawn(|_| async move {
let access_token = "the-token".to_string();
Ok(Credentials {
user_id: user_id.0 as u64,
user_id: user_id.to_proto(),
access_token,
})
})