From 9e03e9d6dfc4a1caef3f685ce2c49423c46c070f Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 19 Jul 2023 17:41:29 +0200 Subject: [PATCH] Sketch in a bi-directional sync (not yet tested) --- crates/crdb/src/crdb.rs | 248 +++++++++++++++++++++++--------- crates/crdb/src/messages.rs | 19 +++ crates/sum_tree/src/tree_map.rs | 33 ++++- 3 files changed, 231 insertions(+), 69 deletions(-) diff --git a/crates/crdb/src/crdb.rs b/crates/crdb/src/crdb.rs index b3b3bdf019..c8af454935 100644 --- a/crates/crdb/src/crdb.rs +++ b/crates/crdb/src/crdb.rs @@ -8,7 +8,7 @@ mod test; use anyhow::{anyhow, Result}; use collections::{btree_map, BTreeMap, Bound, HashMap}; use dense_id::DenseId; -use futures::{future::BoxFuture, FutureExt}; +use futures::{channel::mpsc, future::BoxFuture, FutureExt, StreamExt}; use messages::{MessageEnvelope, Operation, RequestEnvelope}; use operations::CreateBranch; use parking_lot::{Mutex, RwLock}; @@ -28,6 +28,8 @@ use sum_tree::{Bias, SumTree, TreeMap}; use util::ResultExt; use uuid::Uuid; +const CHUNK_SIZE: usize = 64; + #[derive( Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, )] @@ -139,46 +141,88 @@ pub trait ClientRoom: 'static + Send + Sync { fn handle_messages(&self, handle_message: impl 'static + Send + Fn(Vec)); } -struct Client { +pub trait Executor: 'static + Send + Sync { + fn spawn(&self, future: F) + where + F: 'static + Send + Future; +} + +struct Client { db: Db, network: Arc, - repo_rooms: Arc>>>, + checkouts: Arc>>>, + executor: Arc, } -struct RepoRoom { - network_room: N::Room, +struct Checkout { + repo: Repo, + network_room: Arc, + operations_tx: mpsc::UnboundedSender, message_handlers: - Arc, RepoId, Box)>>>>, + Arc, RepoId, Box)>>>>, } -impl RepoRoom { - fn new(client: Client, repo_id: RepoId, network_room: N::Room) -> Self { +impl Clone for Checkout { + fn clone(&self) -> Self { + Self { + repo: self.repo.clone(), + network_room: self.network_room.clone(), + operations_tx: self.operations_tx.clone(), + message_handlers: self.message_handlers.clone(), + } + } +} + +impl Checkout { + fn new(client: Client, repo: Repo, network_room: N::Room) -> Self { + let (operations_tx, operations_rx) = mpsc::unbounded(); let this = Self { - network_room, + repo: repo.clone(), + network_room: Arc::new(network_room), + operations_tx, message_handlers: Default::default(), }; { let handlers = this.message_handlers.clone(); + let client = client.clone(); this.network_room.handle_messages(move |message| { if let Some(envelope) = serde_bare::from_slice::(&message).log_err() { let message = envelope.unwrap(); if let Some(handler) = handlers.read().get(&message.as_ref().type_id()) { - handler(client.clone(), repo_id, message); + handler(client.clone(), repo.id, message); } }; }); } + client.executor.spawn({ + let this = this.clone(); + let client = client.clone(); + async move { + this.sync(&client).await.expect("network is infallible"); + let mut operations_rx = operations_rx.ready_chunks(CHUNK_SIZE); + while let Some(operations) = operations_rx.next().await { + client + .request(messages::PublishOperations { + repo_id: this.repo.id, + operations, + }) + .await + .expect("network is infallible"); + } + } + }); + this } fn handle_messages(&self, handle_message: H) where M: Message, - H: 'static + Fn(Client, RepoId, M) + Send + Sync, + H: 'static + Fn(Client, RepoId, M) + Send + Sync, { self.message_handlers.write().insert( TypeId::of::(), @@ -188,9 +232,38 @@ impl RepoRoom { ); } - fn broadcast(&self, message: M) { + fn broadcast(&self, message: &M) { self.network_room.broadcast(message.to_bytes()); } + + fn broadcast_operation(&self, operation: Operation) { + self.broadcast(&operation); + self.operations_tx.unbounded_send(operation).unwrap(); + } + + async fn sync(&self, client: &Client) -> Result<()> { + let response = client + .request(messages::SyncRepo { + id: self.repo.id, + max_operation_ids: self.repo.read(|repo| (&repo.max_operation_ids).into()), + }) + .await?; + + let operations = self + .repo + .operations_since(&(&response.max_operation_ids).into()); + + for chunk in operations.chunks(CHUNK_SIZE) { + client + .request(messages::PublishOperations { + repo_id: self.repo.id, + operations: chunk.to_vec(), + }) + .await?; + } + + Ok(()) + } } #[derive(Clone, Serialize, Deserialize)] @@ -199,22 +272,24 @@ pub struct RoomCredentials { token: RoomToken, } -impl Clone for Client { +impl Clone for Client { fn clone(&self) -> Self { Self { db: self.db.clone(), network: self.network.clone(), - repo_rooms: self.repo_rooms.clone(), + checkouts: self.checkouts.clone(), + executor: self.executor.clone(), } } } -impl Client { - pub fn new(network: N) -> Self { +impl Client { + pub fn new(executor: E, network: N) -> Self { let mut this = Self { db: Db::new(), network: Arc::new(network), - repo_rooms: Default::default(), + checkouts: Default::default(), + executor: Arc::new(executor), }; this.db.on_local_operation({ let this = this.clone(); @@ -240,23 +315,25 @@ impl Client { async move { let response = this.request(messages::CloneRepo { name }).await?; let repo_id = response.repo_id; - let room = RepoRoom::new( - this.clone(), - repo_id, - this.network.room(response.credentials), - ); - room.handle_messages(Self::handle_remote_operation); - this.repo_rooms.lock().insert(repo_id, room); + let repo = Repo { + id: repo_id, + db: this.db.clone(), + }; this.db .snapshot .lock() .repos .insert(repo_id, Default::default()); - Ok(Repo { - id: repo_id, - db: this.db.clone(), - }) + let checkout = Checkout::new( + this.clone(), + repo.clone(), + this.network.room(response.credentials), + ); + checkout.handle_messages(Self::handle_remote_operation); + this.checkouts.lock().insert(repo_id, checkout); + + Ok(repo) } } @@ -266,21 +343,27 @@ impl Client { name: impl Into>, ) -> impl Future> { let this = self.clone(); - let id = repo.id; let name = name.into(); + let repo = repo.clone(); async move { - let response = this.request(messages::PublishRepo { id, name }).await?; - let room = RepoRoom::new(this.clone(), id, this.network.room(response.credentials)); - room.handle_messages(Self::handle_remote_operation); - this.repo_rooms.lock().insert(id, room); + let response = this + .request(messages::PublishRepo { id: repo.id, name }) + .await?; + let checkout = Checkout::new( + this.clone(), + repo.clone(), + this.network.room(response.credentials), + ); + checkout.handle_messages(Self::handle_remote_operation); + this.checkouts.lock().insert(repo.id, checkout); Ok(()) } } fn handle_local_operation(&self, repo_id: RepoId, operation: Operation) { - if let Some(room) = self.repo_rooms.lock().get(&repo_id) { - room.broadcast(operation); + if let Some(checkout) = self.checkouts.lock().get(&repo_id) { + checkout.broadcast_operation(operation); } } @@ -454,35 +537,11 @@ impl Server { .get(&request.id) .ok_or_else(|| anyhow!("repo not found"))? .clone(); - let mut response = messages::SyncRepoResponse { - operations: Default::default(), - }; - for (replica_id, end_op_count) in repo.max_operation_ids.iter() { - let end_op = OperationId { - replica_id: *replica_id, - operation_count: *end_op_count, - }; - if let Some(start_op_count) = request.max_operation_ids.get(&replica_id) { - let start_op = OperationId { - replica_id: *replica_id, - operation_count: *start_op_count, - }; - response.operations.extend( - repo.operations - .range((Bound::Excluded(&start_op), Bound::Included(&end_op))) - .map(|(_, op)| op.clone()), - ); - } else { - let start_op = OperationId::new(*replica_id); - response.operations.extend( - repo.operations - .range((Bound::Included(&start_op), Bound::Included(&end_op))) - .map(|(_, op)| op.clone()), - ); - } - } - Ok(response) + Ok(messages::SyncRepoResponse { + operations: repo.operations_since(&(&request.max_operation_ids).into()), + max_operation_ids: (&repo.max_operation_ids).into(), + }) } } @@ -506,6 +565,17 @@ impl Db { ) { self.local_operation_created = Some(Arc::new(operation_created)); } + + fn repo(&self, id: RepoId) -> Option { + self.snapshot + .lock() + .repos + .contains_key(&id) + .then_some(Repo { + id, + db: self.clone(), + }) + } } #[derive(Clone)] @@ -523,6 +593,10 @@ impl Repo { } } + fn operations_since(&self, version: &TreeMap) -> Vec { + self.read(|repo| repo.operations_since(version)) + } + fn read(&self, f: F) -> T where F: FnOnce(&RepoSnapshot) -> T, @@ -549,8 +623,8 @@ impl Repo { repo.max_operation_ids.insert(replica_id, count); } - if let Some(operation_created) = self.db.local_operation_created.as_ref() { - operation_created(self.id, operation); + if let Some(local_operation_created) = self.db.local_operation_created.as_ref() { + local_operation_created(self.id, operation); } result }) @@ -1209,6 +1283,35 @@ impl RepoSnapshot { fn apply_operation(&mut self, operation: Operation) { todo!() } + + fn operations_since(&self, version: &TreeMap) -> Vec { + let mut new_operations = Vec::new(); + for (replica_id, end_op_count) in self.max_operation_ids.iter() { + let end_op = OperationId { + replica_id: *replica_id, + operation_count: *end_op_count, + }; + if let Some(start_op_count) = version.get(&replica_id) { + let start_op = OperationId { + replica_id: *replica_id, + operation_count: *start_op_count, + }; + new_operations.extend( + self.operations + .range((Bound::Excluded(&start_op), Bound::Included(&end_op))) + .map(|(_, op)| op.clone()), + ); + } else { + let start_op = OperationId::new(*replica_id); + new_operations.extend( + self.operations + .range((Bound::Included(&start_op), Bound::Included(&end_op))) + .map(|(_, op)| op.clone()), + ); + } + } + new_operations + } } #[derive(Clone, Debug)] @@ -1228,7 +1331,7 @@ struct Revision { #[cfg(test)] mod tests { - use gpui::executor::Deterministic; + use gpui::executor::{Background, Deterministic}; use super::*; use crate::test::TestNetwork; @@ -1238,7 +1341,7 @@ mod tests { let network = TestNetwork::new(deterministic.build_background()); let server = Server::new(network.server()); - let client_a = Client::new(network.client("client-a")); + let client_a = Client::new(deterministic.build_background(), network.client("client-a")); let repo_a = client_a.create_repo(); let branch_a = repo_a.create_empty_branch("main"); @@ -1252,7 +1355,16 @@ mod tests { assert_eq!(doc2.text().to_string(), "def"); client_a.publish_repo(&repo_a, "repo-1").await.unwrap(); - let db_b = Client::new(network.client("client-b")); + let db_b = Client::new(deterministic.build_background(), network.client("client-b")); let repo_b = db_b.clone_repo("repo-1").await.unwrap(); } + + impl Executor for Arc { + fn spawn(&self, future: F) + where + F: 'static + Send + Future, + { + todo!() + } + } } diff --git a/crates/crdb/src/messages.rs b/crates/crdb/src/messages.rs index 4d01b348dd..fc607baad7 100644 --- a/crates/crdb/src/messages.rs +++ b/crates/crdb/src/messages.rs @@ -11,6 +11,7 @@ pub enum RequestEnvelope { PublishRepo(PublishRepo), CloneRepo(CloneRepo), SyncRepo(SyncRepo), + PublishOperations(PublishOperations), } impl RequestEnvelope { @@ -19,6 +20,7 @@ impl RequestEnvelope { RequestEnvelope::PublishRepo(request) => Box::new(request), RequestEnvelope::CloneRepo(request) => Box::new(request), RequestEnvelope::SyncRepo(request) => Box::new(request), + RequestEnvelope::PublishOperations(request) => Box::new(request), } } } @@ -91,6 +93,23 @@ impl Into for SyncRepo { #[derive(Clone, Serialize, Deserialize)] pub struct SyncRepoResponse { pub operations: Vec, + pub max_operation_ids: BTreeMap, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct PublishOperations { + pub repo_id: RepoId, + pub operations: Vec, +} + +impl Request for PublishOperations { + type Response = (); +} + +impl Into for PublishOperations { + fn into(self) -> RequestEnvelope { + RequestEnvelope::PublishOperations(self) + } } #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/crates/sum_tree/src/tree_map.rs b/crates/sum_tree/src/tree_map.rs index 14d00b0ce2..8a6095eb68 100644 --- a/crates/sum_tree/src/tree_map.rs +++ b/crates/sum_tree/src/tree_map.rs @@ -1,5 +1,6 @@ use std::{ cmp::Ordering, + collections::BTreeMap, fmt::Debug, ops::{Bound, RangeBounds}, }; @@ -29,7 +30,11 @@ pub struct TreeSet(TreeMap) where K: Clone + Debug + Default + Ord; -impl TreeMap { +impl TreeMap +where + K: Clone + Debug + Default + Ord, + V: Clone + Debug, +{ pub fn from_ordered_entries(entries: impl IntoIterator) -> Self { let tree = SumTree::from_iter( entries @@ -58,6 +63,10 @@ impl TreeMap { } } + pub fn contains_key<'a>(&self, key: &'a K) -> bool { + self.get(key).is_some() + } + pub fn insert(&mut self, key: K, value: V) { self.0.insert_or_replace(MapEntry { key, value }, &()); } @@ -192,6 +201,28 @@ impl TreeMap { } } +impl Into> for &TreeMap +where + K: Clone + Debug + Default + Ord, + V: Clone + Debug, +{ + fn into(self) -> BTreeMap { + self.iter() + .map(|(replica_id, count)| (replica_id.clone(), count.clone())) + .collect() + } +} + +impl From<&BTreeMap> for TreeMap +where + K: Clone + Debug + Default + Ord, + V: Clone + Debug, +{ + fn from(value: &BTreeMap) -> Self { + TreeMap::from_ordered_entries(value.into_iter().map(|(k, v)| (k.clone(), v.clone()))) + } +} + #[derive(Debug)] struct MapSeekTargetAdaptor<'a, T>(&'a T);