Start work on rejoining channel buffers

This commit is contained in:
Max Brunsfeld 2023-09-01 15:31:52 -07:00
parent 2bf417fa45
commit d370c72fbf
9 changed files with 526 additions and 163 deletions

View file

@ -10,8 +10,6 @@ impl Database {
connection: ConnectionId,
) -> Result<proto::JoinChannelBufferResponse> {
self.transaction(|tx| async move {
let tx = tx;
self.check_user_is_channel_member(channel_id, user_id, &tx)
.await?;
@ -70,7 +68,6 @@ impl Database {
.await?;
collaborators.push(collaborator);
// Assemble the buffer state
let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
Ok(proto::JoinChannelBufferResponse {
@ -78,6 +75,7 @@ impl Database {
replica_id: replica_id.to_proto() as u32,
base_text,
operations,
epoch: buffer.epoch as u64,
collaborators: collaborators
.into_iter()
.map(|collaborator| proto::Collaborator {
@ -91,6 +89,113 @@ impl Database {
.await
}
pub async fn rejoin_channel_buffers(
&self,
buffers: &[proto::ChannelBufferVersion],
user_id: UserId,
connection_id: ConnectionId,
) -> Result<proto::RejoinChannelBuffersResponse> {
self.transaction(|tx| async move {
let mut response = proto::RejoinChannelBuffersResponse::default();
for client_buffer in buffers {
let channel_id = ChannelId::from_proto(client_buffer.channel_id);
if self
.check_user_is_channel_member(channel_id, user_id, &*tx)
.await
.is_err()
{
log::info!("user is not a member of channel");
continue;
}
let buffer = self.get_channel_buffer(channel_id, &*tx).await?;
let mut collaborators = channel_buffer_collaborator::Entity::find()
.filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
.all(&*tx)
.await?;
// If the buffer epoch hasn't changed since the client lost
// connection, then the client's buffer can be syncronized with
// the server's buffer.
if buffer.epoch as u64 != client_buffer.epoch {
continue;
}
// If there is still a disconnected collaborator for the user,
// update the connection associated with that collaborator, and reuse
// that replica id.
if let Some(ix) = collaborators
.iter()
.position(|c| c.user_id == user_id && c.connection_lost)
{
let self_collaborator = &mut collaborators[ix];
*self_collaborator = channel_buffer_collaborator::ActiveModel {
id: ActiveValue::Unchanged(self_collaborator.id),
connection_id: ActiveValue::Set(connection_id.id as i32),
connection_server_id: ActiveValue::Set(ServerId(
connection_id.owner_id as i32,
)),
connection_lost: ActiveValue::Set(false),
..Default::default()
}
.update(&*tx)
.await?;
} else {
continue;
}
let client_version = version_from_wire(&client_buffer.version);
let serialization_version = self
.get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
.await?;
let mut rows = buffer_operation::Entity::find()
.filter(
buffer_operation::Column::BufferId
.eq(buffer.id)
.and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
)
.stream(&*tx)
.await?;
// Find the server's version vector and any operations
// that the client has not seen.
let mut server_version = clock::Global::new();
let mut operations = Vec::new();
while let Some(row) = rows.next().await {
let row = row?;
let timestamp = clock::Lamport {
replica_id: row.replica_id as u16,
value: row.lamport_timestamp as u32,
};
server_version.observe(timestamp);
if !client_version.observed(timestamp) {
operations.push(proto::Operation {
variant: Some(operation_from_storage(row, serialization_version)?),
})
}
}
response.buffers.push(proto::RejoinedChannelBuffer {
channel_id: client_buffer.channel_id,
version: version_to_wire(&server_version),
operations,
collaborators: collaborators
.into_iter()
.map(|collaborator| proto::Collaborator {
peer_id: Some(collaborator.connection().into()),
user_id: collaborator.user_id.to_proto(),
replica_id: collaborator.replica_id.0 as u32,
})
.collect(),
});
}
Ok(response)
})
.await
}
pub async fn leave_channel_buffer(
&self,
channel_id: ChannelId,
@ -103,6 +208,39 @@ impl Database {
.await
}
pub async fn leave_channel_buffers(
&self,
connection: ConnectionId,
) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
self.transaction(|tx| async move {
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryChannelIds {
ChannelId,
}
let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
.select_only()
.column(channel_buffer_collaborator::Column::ChannelId)
.filter(Condition::all().add(
channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
))
.into_values::<_, QueryChannelIds>()
.all(&*tx)
.await?;
let mut result = Vec::new();
for channel_id in channel_ids {
let collaborators = self
.leave_channel_buffer_internal(channel_id, connection, &*tx)
.await?;
result.push((channel_id, collaborators));
}
Ok(result)
})
.await
}
pub async fn leave_channel_buffer_internal(
&self,
channel_id: ChannelId,
@ -143,45 +281,12 @@ impl Database {
drop(rows);
if connections.is_empty() {
self.snapshot_buffer(channel_id, &tx).await?;
self.snapshot_channel_buffer(channel_id, &tx).await?;
}
Ok(connections)
}
pub async fn leave_channel_buffers(
&self,
connection: ConnectionId,
) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
self.transaction(|tx| async move {
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryChannelIds {
ChannelId,
}
let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
.select_only()
.column(channel_buffer_collaborator::Column::ChannelId)
.filter(Condition::all().add(
channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
))
.into_values::<_, QueryChannelIds>()
.all(&*tx)
.await?;
let mut result = Vec::new();
for channel_id in channel_ids {
let collaborators = self
.leave_channel_buffer_internal(channel_id, connection, &*tx)
.await?;
result.push((channel_id, collaborators));
}
Ok(result)
})
.await
}
pub async fn get_channel_buffer_collaborators(
&self,
channel_id: ChannelId,
@ -224,20 +329,9 @@ impl Database {
.await?
.ok_or_else(|| anyhow!("no such buffer"))?;
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryVersion {
OperationSerializationVersion,
}
let serialization_version: i32 = buffer
.find_related(buffer_snapshot::Entity)
.select_only()
.column(buffer_snapshot::Column::OperationSerializationVersion)
.filter(buffer_snapshot::Column::Epoch.eq(buffer.epoch))
.into_values::<_, QueryVersion>()
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("missing buffer snapshot"))?;
let serialization_version = self
.get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
.await?;
let operations = operations
.iter()
@ -270,6 +364,38 @@ impl Database {
.await
}
async fn get_buffer_operation_serialization_version(
&self,
buffer_id: BufferId,
epoch: i32,
tx: &DatabaseTransaction,
) -> Result<i32> {
Ok(buffer_snapshot::Entity::find()
.filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
.filter(buffer_snapshot::Column::Epoch.eq(epoch))
.select_only()
.column(buffer_snapshot::Column::OperationSerializationVersion)
.into_values::<_, QueryOperationSerializationVersion>()
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("missing buffer snapshot"))?)
}
async fn get_channel_buffer(
&self,
channel_id: ChannelId,
tx: &DatabaseTransaction,
) -> Result<buffer::Model> {
Ok(channel::Model {
id: channel_id,
..Default::default()
}
.find_related(buffer::Entity)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such buffer"))?)
}
async fn get_buffer_state(
&self,
buffer: &buffer::Model,
@ -303,27 +429,20 @@ impl Database {
.await?;
let mut operations = Vec::new();
while let Some(row) = rows.next().await {
let row = row?;
let operation = operation_from_storage(row, version)?;
operations.push(proto::Operation {
variant: Some(operation),
variant: Some(operation_from_storage(row?, version)?),
})
}
Ok((base_text, operations))
}
async fn snapshot_buffer(&self, channel_id: ChannelId, tx: &DatabaseTransaction) -> Result<()> {
let buffer = channel::Model {
id: channel_id,
..Default::default()
}
.find_related(buffer::Entity)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such buffer"))?;
async fn snapshot_channel_buffer(
&self,
channel_id: ChannelId,
tx: &DatabaseTransaction,
) -> Result<()> {
let buffer = self.get_channel_buffer(channel_id, tx).await?;
let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
if operations.is_empty() {
return Ok(());
@ -527,6 +646,22 @@ fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
version
}
fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
let mut message = Vec::new();
for entry in version.iter() {
message.push(proto::VectorClockEntry {
replica_id: entry.replica_id as u32,
timestamp: entry.value,
});
}
message
}
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
OperationSerializationVersion,
}
mod storage {
#![allow(non_snake_case)]
use prost::Message;

View file

@ -251,6 +251,7 @@ impl Server {
.add_request_handler(join_channel_buffer)
.add_request_handler(leave_channel_buffer)
.add_message_handler(update_channel_buffer)
.add_request_handler(rejoin_channel_buffers)
.add_request_handler(get_channel_members)
.add_request_handler(respond_to_channel_invite)
.add_request_handler(join_channel)
@ -854,13 +855,12 @@ async fn connection_lost(
.await
.trace_err();
leave_channel_buffers_for_session(&session)
.await
.trace_err();
futures::select_biased! {
_ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
leave_room_for_session(&session).await.trace_err();
leave_channel_buffers_for_session(&session)
.await
.trace_err();
if !session
.connection_pool()
@ -2547,6 +2547,23 @@ async fn update_channel_buffer(
Ok(())
}
async fn rejoin_channel_buffers(
request: proto::RejoinChannelBuffers,
response: Response<proto::RejoinChannelBuffers>,
session: Session,
) -> Result<()> {
let db = session.db().await;
let rejoin_response = db
.rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
.await?;
// TODO: inform channel buffer collaborators that this user has rejoined.
response.send(rejoin_response)?;
Ok(())
}
async fn leave_channel_buffer(
request: proto::LeaveChannelBuffer,
response: Response<proto::LeaveChannelBuffer>,

View file

@ -21,20 +21,19 @@ async fn test_core_channel_buffers(
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let zed_id = server
let channel_id = server
.make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)])
.await;
// Client A joins the channel buffer
let channel_buffer_a = client_a
.channel_store()
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx))
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
// Client A edits the buffer
let buffer_a = channel_buffer_a.read_with(cx_a, |buffer, _| buffer.buffer());
buffer_a.update(cx_a, |buffer, cx| {
buffer.edit([(0..0, "hello world")], None, cx)
});
@ -45,17 +44,15 @@ async fn test_core_channel_buffers(
buffer.edit([(0..5, "goodbye")], None, cx)
});
buffer_a.update(cx_a, |buffer, cx| buffer.undo(cx));
deterministic.run_until_parked();
assert_eq!(buffer_text(&buffer_a, cx_a), "hello, cruel world");
deterministic.run_until_parked();
// Client B joins the channel buffer
let channel_buffer_b = client_b
.channel_store()
.update(cx_b, |channel, cx| channel.open_channel_buffer(zed_id, cx))
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
channel_buffer_b.read_with(cx_b, |buffer, _| {
assert_collaborators(
buffer.collaborators(),
@ -91,9 +88,7 @@ async fn test_core_channel_buffers(
// Client A rejoins the channel buffer
let _channel_buffer_a = client_a
.channel_store()
.update(cx_a, |channels, cx| {
channels.open_channel_buffer(zed_id, cx)
})
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
deterministic.run_until_parked();
@ -136,7 +131,7 @@ async fn test_channel_buffer_replica_ids(
let channel_id = server
.make_channel(
"zed",
"the-channel",
(&client_a, cx_a),
&mut [(&client_b, cx_b), (&client_c, cx_c)],
)
@ -160,23 +155,17 @@ async fn test_channel_buffer_replica_ids(
// C first so that the replica IDs in the project and the channel buffer are different
let channel_buffer_c = client_c
.channel_store()
.update(cx_c, |channel, cx| {
channel.open_channel_buffer(channel_id, cx)
})
.update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
let channel_buffer_b = client_b
.channel_store()
.update(cx_b, |channel, cx| {
channel.open_channel_buffer(channel_id, cx)
})
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
let channel_buffer_a = client_a
.channel_store()
.update(cx_a, |channel, cx| {
channel.open_channel_buffer(channel_id, cx)
})
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
@ -286,28 +275,30 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
let zed_id = server.make_channel("zed", (&client_a, cx_a), &mut []).await;
let channel_id = server
.make_channel("the-channel", (&client_a, cx_a), &mut [])
.await;
let channel_buffer_1 = client_a
.channel_store()
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx));
let channel_buffer_2 = client_a
.channel_store()
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx));
let channel_buffer_3 = client_a
.channel_store()
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx));
// All concurrent tasks for opening a channel buffer return the same model handle.
let (channel_buffer_1, channel_buffer_2, channel_buffer_3) =
let (channel_buffer, channel_buffer_2, channel_buffer_3) =
future::try_join3(channel_buffer_1, channel_buffer_2, channel_buffer_3)
.await
.unwrap();
let model_id = channel_buffer_1.id();
assert_eq!(channel_buffer_1, channel_buffer_2);
assert_eq!(channel_buffer_1, channel_buffer_3);
let channel_buffer_model_id = channel_buffer.id();
assert_eq!(channel_buffer, channel_buffer_2);
assert_eq!(channel_buffer, channel_buffer_3);
channel_buffer_1.update(cx_a, |buffer, cx| {
channel_buffer.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "hello")], None, cx);
})
@ -315,7 +306,7 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
deterministic.run_until_parked();
cx_a.update(|_| {
drop(channel_buffer_1);
drop(channel_buffer);
drop(channel_buffer_2);
drop(channel_buffer_3);
});
@ -324,10 +315,10 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
// The channel buffer can be reopened after dropping it.
let channel_buffer = client_a
.channel_store()
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx))
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
assert_ne!(channel_buffer.id(), model_id);
assert_ne!(channel_buffer.id(), channel_buffer_model_id);
channel_buffer.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, _| {
assert_eq!(buffer.text(), "hello");
@ -347,22 +338,17 @@ async fn test_channel_buffer_disconnect(
let client_b = server.create_client(cx_b, "user_b").await;
let channel_id = server
.make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)])
.make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)])
.await;
let channel_buffer_a = client_a
.channel_store()
.update(cx_a, |channel, cx| {
channel.open_channel_buffer(channel_id, cx)
})
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
let channel_buffer_b = client_b
.channel_store()
.update(cx_b, |channel, cx| {
channel.open_channel_buffer(channel_id, cx)
})
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
@ -375,7 +361,7 @@ async fn test_channel_buffer_disconnect(
buffer.channel().as_ref(),
&Channel {
id: channel_id,
name: "zed".to_string()
name: "the-channel".to_string()
}
);
assert!(!buffer.is_connected());
@ -403,13 +389,81 @@ async fn test_channel_buffer_disconnect(
buffer.channel().as_ref(),
&Channel {
id: channel_id,
name: "zed".to_string()
name: "the-channel".to_string()
}
);
assert!(!buffer.is_connected());
});
}
#[gpui::test]
async fn test_rejoin_channel_buffer(
deterministic: Arc<Deterministic>,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let channel_id = server
.make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)])
.await;
let channel_buffer_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
let channel_buffer_b = client_b
.channel_store()
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
channel_buffer_a.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "1")], None, cx);
})
});
deterministic.run_until_parked();
// Client A disconnects.
server.forbid_connections();
server.disconnect_client(client_a.peer_id().unwrap());
// deterministic.advance_clock(RECEIVE_TIMEOUT);
// Both clients make an edit. Both clients see their own edit.
channel_buffer_a.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(1..1, "2")], None, cx);
})
});
channel_buffer_b.update(cx_b, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "0")], None, cx);
})
});
deterministic.run_until_parked();
channel_buffer_a.read_with(cx_a, |buffer, cx| {
assert_eq!(buffer.buffer().read(cx).text(), "12");
});
channel_buffer_b.read_with(cx_b, |buffer, cx| {
assert_eq!(buffer.buffer().read(cx).text(), "01");
});
// Client A reconnects.
server.allow_connections();
deterministic.advance_clock(RECEIVE_TIMEOUT);
channel_buffer_a.read_with(cx_a, |buffer, cx| {
assert_eq!(buffer.buffer().read(cx).text(), "012");
});
channel_buffer_b.read_with(cx_b, |buffer, cx| {
assert_eq!(buffer.buffer().read(cx).text(), "012");
});
}
#[track_caller]
fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option<UserId>]) {
assert_eq!(