From d9d997b218712b6efdf673841edbe68a0d03297e Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 2 Oct 2023 15:58:34 -0700 Subject: [PATCH] Avoid N+1 query for channels with notes changes Also, start work on new timing for recording observed notes edits. Co-authored-by: Mikayla --- .../20221109000000_test_schema.sql | 1 + .../20230925210437_add_channel_changes.sql | 1 + crates/collab/src/db.rs | 4 +- crates/collab/src/db/queries/buffers.rs | 269 ++++++++++++------ crates/collab/src/db/queries/channels.rs | 12 +- .../src/db/tables/observed_buffer_edits.rs | 1 + crates/collab/src/db/tests/buffer_tests.rs | 190 +++++++++++++ 7 files changed, 381 insertions(+), 97 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 277a78d2d6..2d963ff15f 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -296,6 +296,7 @@ CREATE TABLE "observed_buffer_edits" ( "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE, "epoch" INTEGER NOT NULL, "lamport_timestamp" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, PRIMARY KEY (user_id, buffer_id) ); diff --git a/crates/collab/migrations/20230925210437_add_channel_changes.sql b/crates/collab/migrations/20230925210437_add_channel_changes.sql index 7787975c1c..250a9ac731 100644 --- a/crates/collab/migrations/20230925210437_add_channel_changes.sql +++ b/crates/collab/migrations/20230925210437_add_channel_changes.sql @@ -3,6 +3,7 @@ CREATE TABLE IF NOT EXISTS "observed_buffer_edits" ( "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE, "epoch" INTEGER NOT NULL, "lamport_timestamp" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, PRIMARY KEY (user_id, buffer_id) ); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 8f7f9cc975..b0223bbf27 100644 --- a/crates/collab/src/db.rs +++ 
b/crates/collab/src/db.rs @@ -119,7 +119,7 @@ impl Database { Ok(new_migrations) } - async fn transaction(&self, f: F) -> Result + pub async fn transaction(&self, f: F) -> Result where F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, @@ -321,7 +321,7 @@ fn is_serialization_error(error: &Error) -> bool { } } -struct TransactionHandle(Arc>); +pub struct TransactionHandle(Arc>); impl Deref for TransactionHandle { type Target = DatabaseTransaction; diff --git a/crates/collab/src/db/queries/buffers.rs b/crates/collab/src/db/queries/buffers.rs index 1e8dd30c6b..b22bfc80cf 100644 --- a/crates/collab/src/db/queries/buffers.rs +++ b/crates/collab/src/db/queries/buffers.rs @@ -79,12 +79,13 @@ impl Database { self.get_buffer_state(&buffer, &tx).await?; // Save the last observed operation - if let Some(max_operation) = max_operation { + if let Some(op) = max_operation { observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel { user_id: ActiveValue::Set(user_id), buffer_id: ActiveValue::Set(buffer.id), - epoch: ActiveValue::Set(max_operation.0), - lamport_timestamp: ActiveValue::Set(max_operation.1), + epoch: ActiveValue::Set(op.epoch), + lamport_timestamp: ActiveValue::Set(op.lamport_timestamp), + replica_id: ActiveValue::Set(op.replica_id), }) .on_conflict( OnConflict::columns([ @@ -99,37 +100,6 @@ impl Database { ) .exec(&*tx) .await?; - } else { - let buffer_max = buffer_operation::Entity::find() - .filter(buffer_operation::Column::BufferId.eq(buffer.id)) - .filter(buffer_operation::Column::Epoch.eq(buffer.epoch.saturating_sub(1))) - .order_by(buffer_operation::Column::Epoch, Desc) - .order_by(buffer_operation::Column::LamportTimestamp, Desc) - .one(&*tx) - .await? 
- .map(|model| (model.epoch, model.lamport_timestamp)); - - if let Some(buffer_max) = buffer_max { - observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel { - user_id: ActiveValue::Set(user_id), - buffer_id: ActiveValue::Set(buffer.id), - epoch: ActiveValue::Set(buffer_max.0), - lamport_timestamp: ActiveValue::Set(buffer_max.1), - }) - .on_conflict( - OnConflict::columns([ - observed_buffer_edits::Column::UserId, - observed_buffer_edits::Column::BufferId, - ]) - .update_columns([ - observed_buffer_edits::Column::Epoch, - observed_buffer_edits::Column::LamportTimestamp, - ]) - .to_owned(), - ) - .exec(&*tx) - .await?; - } } Ok(proto::JoinChannelBufferResponse { @@ -487,13 +457,8 @@ impl Database { if !operations.is_empty() { // get current channel participants and save the max operation above - self.save_max_operation_for_collaborators( - operations.as_slice(), - channel_id, - buffer.id, - &*tx, - ) - .await?; + self.save_max_operation(user, buffer.id, buffer.epoch, operations.as_slice(), &*tx) + .await?; channel_members = self.get_channel_members_internal(channel_id, &*tx).await?; let collaborators = self @@ -539,54 +504,55 @@ impl Database { .await } - async fn save_max_operation_for_collaborators( + async fn save_max_operation( &self, - operations: &[buffer_operation::ActiveModel], - channel_id: ChannelId, + user_id: UserId, buffer_id: BufferId, + epoch: i32, + operations: &[buffer_operation::ActiveModel], tx: &DatabaseTransaction, ) -> Result<()> { + use observed_buffer_edits::Column; + let max_operation = operations .iter() - .map(|storage_model| { - ( - storage_model.epoch.clone(), - storage_model.lamport_timestamp.clone(), - ) - }) - .max_by( - |(epoch_a, lamport_timestamp_a), (epoch_b, lamport_timestamp_b)| { - epoch_a.as_ref().cmp(epoch_b.as_ref()).then( - lamport_timestamp_a - .as_ref() - .cmp(lamport_timestamp_b.as_ref()), - ) - }, - ) + .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref())) .unwrap(); - let users = 
self - .get_channel_buffer_collaborators_internal(channel_id, tx) - .await?; - - observed_buffer_edits::Entity::insert_many(users.iter().map(|id| { - observed_buffer_edits::ActiveModel { - user_id: ActiveValue::Set(*id), - buffer_id: ActiveValue::Set(buffer_id), - epoch: max_operation.0.clone(), - lamport_timestamp: ActiveValue::Set(*max_operation.1.as_ref()), - } - })) + observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel { + user_id: ActiveValue::Set(user_id), + buffer_id: ActiveValue::Set(buffer_id), + epoch: ActiveValue::Set(epoch), + replica_id: max_operation.replica_id.clone(), + lamport_timestamp: max_operation.lamport_timestamp.clone(), + }) .on_conflict( - OnConflict::columns([ - observed_buffer_edits::Column::UserId, - observed_buffer_edits::Column::BufferId, - ]) - .update_columns([ - observed_buffer_edits::Column::Epoch, - observed_buffer_edits::Column::LamportTimestamp, - ]) - .to_owned(), + OnConflict::columns([Column::UserId, Column::BufferId]) + .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId]) + .target_cond_where( + Condition::any() + .add(Column::Epoch.lt(*max_operation.epoch.as_ref())) + .add( + Condition::all() + .add(Column::Epoch.eq(*max_operation.epoch.as_ref())) + .add( + Condition::any() + .add( + Column::LamportTimestamp + .lt(*max_operation.lamport_timestamp.as_ref()), + ) + .add( + Column::LamportTimestamp + .eq(*max_operation.lamport_timestamp.as_ref()) + .and( + Column::ReplicaId + .lt(*max_operation.replica_id.as_ref()), + ), + ), + ), + ), + ) + .to_owned(), ) .exec(tx) .await?; @@ -611,7 +577,7 @@ impl Database { .ok_or_else(|| anyhow!("missing buffer snapshot"))?) 
} - async fn get_channel_buffer( + pub async fn get_channel_buffer( &self, channel_id: ChannelId, tx: &DatabaseTransaction, @@ -630,7 +596,11 @@ &self, buffer: &buffer::Model, tx: &DatabaseTransaction, - ) -> Result<(String, Vec, Option<(i32, i32)>)> { + ) -> Result<( + String, + Vec, + Option, + )> { let id = buffer.id; let (base_text, version) = if buffer.epoch > 0 { let snapshot = buffer_snapshot::Entity::find() @@ -655,24 +625,28 @@ .eq(id) .and(buffer_operation::Column::Epoch.eq(buffer.epoch)), ) + .order_by_asc(buffer_operation::Column::LamportTimestamp) + .order_by_asc(buffer_operation::Column::ReplicaId) .stream(&*tx) .await?; - let mut operations = Vec::new(); - let mut max_epoch: Option = None; - let mut max_timestamp: Option = None; + let mut operations = Vec::new(); + let mut last_row = None; while let Some(row) = rows.next().await { let row = row?; - - max_assign(&mut max_epoch, row.epoch); - max_assign(&mut max_timestamp, row.lamport_timestamp); - + last_row = Some(buffer_operation::Model { + buffer_id: row.buffer_id, + epoch: row.epoch, + lamport_timestamp: row.lamport_timestamp, + replica_id: row.replica_id, + value: Default::default(), + }); operations.push(proto::Operation { variant: Some(operation_from_storage(row, version)?), - }) + }); } - Ok((base_text, operations, max_epoch.zip(max_timestamp))) + Ok((base_text, operations, last_row)) } async fn snapshot_channel_buffer( @@ -725,6 +699,119 @@ .await } + pub async fn channels_with_changed_notes( + &self, + user_id: UserId, + channel_ids: impl IntoIterator, + tx: &DatabaseTransaction, + ) -> Result> { + #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] + enum QueryIds { + ChannelId, + Id, + } + + let mut channel_ids_by_buffer_id = HashMap::default(); + let mut rows = buffer::Entity::find() + .filter(buffer::Column::ChannelId.is_in(channel_ids)) + .stream(&*tx) + .await?; + while let Some(row) = rows.next().await { + let row = row?; + 
channel_ids_by_buffer_id.insert(row.id, row.channel_id); + } + drop(rows); + + let mut observed_edits_by_buffer_id = HashMap::default(); + let mut rows = observed_buffer_edits::Entity::find() + .filter(observed_buffer_edits::Column::UserId.eq(user_id)) + .filter( + observed_buffer_edits::Column::BufferId + .is_in(channel_ids_by_buffer_id.keys().copied()), + ) + .stream(&*tx) + .await?; + while let Some(row) = rows.next().await { + let row = row?; + observed_edits_by_buffer_id.insert(row.buffer_id, row); + } + drop(rows); + + let last_operations = self + .get_last_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx) + .await?; + + let mut channels_with_new_changes = HashSet::default(); + for last_operation in last_operations { + if let Some(observed_edit) = observed_edits_by_buffer_id.get(&last_operation.buffer_id) + { + if observed_edit.epoch == last_operation.epoch + && observed_edit.lamport_timestamp == last_operation.lamport_timestamp + && observed_edit.replica_id == last_operation.replica_id + { + continue; + } + } + + if let Some(channel_id) = channel_ids_by_buffer_id.get(&last_operation.buffer_id) { + channels_with_new_changes.insert(*channel_id); + } + } + + Ok(channels_with_new_changes) + } + + pub async fn get_last_operations_for_buffers( + &self, + channel_ids: impl IntoIterator, + tx: &DatabaseTransaction, + ) -> Result> { + let mut values = String::new(); + for id in channel_ids { + if !values.is_empty() { + values.push_str(", "); + } + write!(&mut values, "({})", id).unwrap(); + } + + if values.is_empty() { + return Ok(Vec::default()); + } + + let sql = format!( + r#" + SELECT + * + FROM ( + SELECT + buffer_id, + epoch, + lamport_timestamp, + replica_id, + value, + row_number() OVER ( + PARTITION BY buffer_id + ORDER BY + epoch DESC, + lamport_timestamp DESC, + replica_id DESC + ) as row_number + FROM buffer_operations + WHERE + buffer_id in ({values}) + ) AS operations + WHERE + row_number = 1 + "#, + ); + + let stmt = 
Statement::from_string(self.pool.get_database_backend(), sql); + let operations = buffer_operation::Model::find_by_statement(stmt) + .all(&*tx) + .await?; + Ok(operations) + } + pub async fn has_note_changed( &self, user_id: UserId, diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index ea9f64fe5e..6d976b310e 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -463,12 +463,16 @@ impl Database { } } - let mut channels_with_changed_notes = HashSet::default(); + let channels_with_changed_notes = self + .channels_with_changed_notes( + user_id, + graph.channels.iter().map(|channel| channel.id), + &*tx, + ) + .await?; + let mut channels_with_new_messages = HashSet::default(); for channel in graph.channels.iter() { - if self.has_note_changed(user_id, channel.id, tx).await? { - channels_with_changed_notes.insert(channel.id); - } if self.has_new_message(channel.id, user_id, tx).await? { channels_with_new_messages.insert(channel.id); } diff --git a/crates/collab/src/db/tables/observed_buffer_edits.rs b/crates/collab/src/db/tables/observed_buffer_edits.rs index db027f78b2..e8e7aafaa2 100644 --- a/crates/collab/src/db/tables/observed_buffer_edits.rs +++ b/crates/collab/src/db/tables/observed_buffer_edits.rs @@ -9,6 +9,7 @@ pub struct Model { pub buffer_id: BufferId, pub epoch: i32, pub lamport_timestamp: i32, + pub replica_id: i32, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/db/tests/buffer_tests.rs b/crates/collab/src/db/tests/buffer_tests.rs index 115a20ffa6..5a5fe6a812 100644 --- a/crates/collab/src/db/tests/buffer_tests.rs +++ b/crates/collab/src/db/tests/buffer_tests.rs @@ -272,3 +272,193 @@ async fn test_channel_buffers_diffs(db: &Database) { assert!(!db.test_has_note_changed(a_id, zed_id).await.unwrap()); assert!(!db.test_has_note_changed(b_id, zed_id).await.unwrap()); } + +test_both_dbs!( + 
test_channel_buffers_last_operations, + test_channel_buffers_last_operations_postgres, + test_channel_buffers_last_operations_sqlite +); + +async fn test_channel_buffers_last_operations(db: &Database) { + let user_id = db + .create_user( + "user_a@example.com", + false, + NewUserParams { + github_login: "user_a".into(), + github_user_id: 101, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; + let owner_id = db.create_server("production").await.unwrap().0 as u32; + let connection_id = ConnectionId { + owner_id, + id: user_id.0 as u32, + }; + + let mut buffers = Vec::new(); + let mut text_buffers = Vec::new(); + for i in 0..3 { + let channel = db + .create_root_channel(&format!("channel-{i}"), &format!("room-{i}"), user_id) + .await + .unwrap(); + + db.join_channel_buffer(channel, user_id, connection_id) + .await + .unwrap(); + + buffers.push( + db.transaction(|tx| async move { db.get_channel_buffer(channel, &*tx).await }) + .await + .unwrap(), + ); + + text_buffers.push(Buffer::new(0, 0, "".to_string())); + } + + let operations = db + .transaction(|tx| { + let buffers = &buffers; + async move { + db.get_last_operations_for_buffers([buffers[0].id, buffers[2].id], &*tx) + .await + } + }) + .await + .unwrap(); + + assert!(operations.is_empty()); + + update_buffer( + buffers[0].channel_id, + user_id, + db, + vec![ + text_buffers[0].edit([(0..0, "a")]), + text_buffers[0].edit([(0..0, "b")]), + text_buffers[0].edit([(0..0, "c")]), + ], + ) + .await; + + update_buffer( + buffers[1].channel_id, + user_id, + db, + vec![ + text_buffers[1].edit([(0..0, "d")]), + text_buffers[1].edit([(1..1, "e")]), + text_buffers[1].edit([(2..2, "f")]), + ], + ) + .await; + + // cause buffer 1's epoch to increment. 
+ db.leave_channel_buffer(buffers[1].channel_id, connection_id) + .await + .unwrap(); + db.join_channel_buffer(buffers[1].channel_id, user_id, connection_id) + .await + .unwrap(); + text_buffers[1] = Buffer::new(1, 0, "def".to_string()); + update_buffer( + buffers[1].channel_id, + user_id, + db, + vec![ + text_buffers[1].edit([(0..0, "g")]), + text_buffers[1].edit([(0..0, "h")]), + ], + ) + .await; + + update_buffer( + buffers[2].channel_id, + user_id, + db, + vec![text_buffers[2].edit([(0..0, "i")])], + ) + .await; + + let operations = db + .transaction(|tx| { + let buffers = &buffers; + async move { + db.get_last_operations_for_buffers([buffers[1].id, buffers[2].id], &*tx) + .await + } + }) + .await + .unwrap(); + assert_operations( + &operations, + &[ + (buffers[1].id, 1, &text_buffers[1]), + (buffers[2].id, 0, &text_buffers[2]), + ], + ); + + let operations = db + .transaction(|tx| { + let buffers = &buffers; + async move { + db.get_last_operations_for_buffers([buffers[0].id, buffers[1].id], &*tx) + .await + } + }) + .await + .unwrap(); + assert_operations( + &operations, + &[ + (buffers[0].id, 0, &text_buffers[0]), + (buffers[1].id, 1, &text_buffers[1]), + ], + ); + + async fn update_buffer( + channel_id: ChannelId, + user_id: UserId, + db: &Database, + operations: Vec, + ) { + let operations = operations + .into_iter() + .map(|op| proto::serialize_operation(&language::Operation::Buffer(op))) + .collect::>(); + db.update_channel_buffer(channel_id, user_id, &operations) + .await + .unwrap(); + } + + fn assert_operations( + operations: &[buffer_operation::Model], + expected: &[(BufferId, i32, &text::Buffer)], + ) { + let actual = operations + .iter() + .map(|op| buffer_operation::Model { + buffer_id: op.buffer_id, + epoch: op.epoch, + lamport_timestamp: op.lamport_timestamp, + replica_id: op.replica_id, + value: vec![], + }) + .collect::>(); + let expected = expected + .iter() + .map(|(buffer_id, epoch, buffer)| buffer_operation::Model { + buffer_id: *buffer_id, 
+ epoch: *epoch, + lamport_timestamp: buffer.lamport_clock.value as i32 - 1, + replica_id: buffer.replica_id() as i32, + value: vec![], + }) + .collect::>(); + assert_eq!(actual, expected, "unexpected operations") + } +}