diff --git a/Cargo.lock b/Cargo.lock index ded74ab07e..7973316c9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -836,7 +836,7 @@ dependencies = [ "target_build_utils", "term", "toml 0.4.10", - "uuid", + "uuid 0.5.1", "walkdir", ] @@ -884,7 +884,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e7fb075b9b54e939006aa12e1f6cd2d3194041ff4ebe7f2efcbedf18f25b667" dependencies = [ "byteorder", - "uuid", + "uuid 0.5.1", ] [[package]] @@ -2963,7 +2963,7 @@ dependencies = [ "byteorder", "cfb", "encoding", - "uuid", + "uuid 0.5.1", ] [[package]] @@ -4784,6 +4784,7 @@ dependencies = [ "thiserror", "time 0.2.25", "url", + "uuid 0.8.2", "webpki", "webpki-roots", "whoami", @@ -5606,6 +5607,12 @@ dependencies = [ "sha1 0.2.0", ] +[[package]] +name = "uuid" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" + [[package]] name = "value-bag" version = "1.0.0-alpha.7" diff --git a/server/Cargo.toml b/server/Cargo.toml index b73c70102a..b295ff21ac 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -5,6 +5,9 @@ edition = "2018" name = "zed-server" version = "0.1.0" +[[bin]] +name = "zed-server" + [[bin]] name = "seed" required-features = ["seed-support"] @@ -47,7 +50,7 @@ default-features = false [dependencies.sqlx] version = "0.5.2" -features = ["runtime-async-std-rustls", "postgres", "time"] +features = ["runtime-async-std-rustls", "postgres", "time", "uuid"] [dev-dependencies] gpui = { path = "../gpui" } diff --git a/server/src/bin/seed.rs b/server/src/bin/seed.rs index b259dc4c14..d2427d495c 100644 --- a/server/src/bin/seed.rs +++ b/server/src/bin/seed.rs @@ -73,7 +73,7 @@ async fn main() { for timestamp in timestamps { let sender_id = *zed_user_ids.choose(&mut rng).unwrap(); let body = lipsum::lipsum_words(rng.gen_range(1..=50)); - db.create_channel_message(channel_id, sender_id, &body, timestamp) + db.create_channel_message(channel_id, sender_id, &body, timestamp, rng.gen()) .await .expect("failed to insert message"); } diff --git a/server/src/db.rs b/server/src/db.rs index c3e270bc87..14ad85b68a 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -1,7 +1,7 @@ use anyhow::Context; use async_std::task::{block_on, yield_now}; use serde::Serialize; -use sqlx::{FromRow, Result}; +use sqlx::{types::Uuid, FromRow, Result}; use time::OffsetDateTime; pub use async_sqlx_session::PostgresSessionStore as SessionStore; @@ -402,11 +402,13 @@ impl Db { sender_id: UserId, body: &str, timestamp: OffsetDateTime, + nonce: u128, ) -> Result { test_support!(self, { let query = " - INSERT INTO channel_messages (channel_id, sender_id, body, sent_at) - VALUES ($1, $2, $3, $4) + INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce RETURNING id "; sqlx::query_scalar(query) @@ -414,6 +416,7 @@ impl Db { .bind(sender_id.0) .bind(body) .bind(timestamp) + .bind(Uuid::from_u128(nonce)) .fetch_one(&self.pool) .await .map(MessageId) @@ -430,7 +433,7 @@ impl Db { let query = r#" SELECT * FROM ( SELECT - id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at + id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce FROM channel_messages WHERE @@ -514,6 +517,7 @@ pub struct ChannelMessage { pub sender_id: UserId, pub body: String, pub sent_at: OffsetDateTime, + pub nonce: Uuid, } #[cfg(test)] @@ -677,7 +681,7 @@ pub mod tests { let org = db.create_org("org", "org").await.unwrap(); let channel = db.create_org_channel(org, "channel").await.unwrap(); for i in 0..10 { - db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc()) + db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i) .await .unwrap(); } @@ -697,4 +701,34 @@ pub mod tests { ["1", "2", "3", "4"] ); } + + #[gpui::test] + async fn test_channel_message_nonces() { + let test_db = TestDb::new(); + let db = test_db.db(); + let user = db.create_user("user", false).await.unwrap(); + let org = db.create_org("org", "org").await.unwrap(); + let channel = db.create_org_channel(org, "channel").await.unwrap(); + + let msg1_id = db + .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) + .await + .unwrap(); + let msg2_id = db + .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) + .await + .unwrap(); + let msg3_id = db + .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) + .await + .unwrap(); + let msg4_id = db + .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) + .await + .unwrap(); + + assert_ne!(msg1_id, msg2_id); + assert_eq!(msg1_id, msg3_id); + assert_eq!(msg2_id, msg4_id); + } } diff --git a/server/src/rpc.rs b/server/src/rpc.rs index e6a48ae410..debd982366 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -602,6 +602,7 @@ impl Server { body: msg.body, timestamp: msg.sent_at.unix_timestamp() as u64, sender_id: msg.sender_id.to_proto(), + nonce: Some(msg.nonce.as_u128().into()), }) .collect::>(); self.peer @@ -687,10 +688,24 @@ impl Server { } let timestamp = OffsetDateTime::now_utc(); + let nonce = if let Some(nonce) = request.payload.nonce { + nonce + } else { + self.peer + .respond_with_error( + receipt, + proto::Error { + message: "nonce can't be blank".to_string(), + }, + ) + .await?; + return Ok(()); + }; + let message_id = self .app_state .db - .create_channel_message(channel_id, user_id, &body, timestamp) + .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into()) .await? .to_proto(); let message = proto::ChannelMessage { @@ -698,6 +713,7 @@ impl Server { id: message_id, body, timestamp: timestamp.unix_timestamp() as u64, + nonce: Some(nonce), }; broadcast(request.sender_id, connection_ids, |conn_id| { self.peer.send( @@ -754,6 +770,7 @@ impl Server { body: msg.body, timestamp: msg.sent_at.unix_timestamp() as u64, sender_id: msg.sender_id.to_proto(), + nonce: Some(msg.nonce.as_u128().into()), }) .collect::>(); self.peer @@ -1513,6 +1530,7 @@ mod tests { current_user_id(&user_store_b), "hello A, it's B.", OffsetDateTime::now_utc(), + 1, ) .await .unwrap(); @@ -1707,6 +1725,7 @@ mod tests { current_user_id(&user_store_b), "hello A, it's B.", OffsetDateTime::now_utc(), + 2, ) .await .unwrap(); @@ -1787,6 +1806,24 @@ mod tests { ) }); + // Send a message from client B while it is disconnected. + channel_b + .update(&mut cx_b, |channel, cx| { + let task = channel + .send_message("can you see this?".to_string(), cx) + .unwrap(); + assert_eq!( + channel_messages(channel), + &[ + ("user_b".to_string(), "hello A, it's B.".to_string(), false), + ("user_b".to_string(), "can you see this?".to_string(), true) + ] + ); + task + }) + .await + .unwrap_err(); + // Send a message from client A while B is disconnected. channel_a .update(&mut cx_a, |channel, cx| { @@ -1812,7 +1849,8 @@ mod tests { server.allow_connections(); cx_b.foreground().advance_clock(Duration::from_secs(10)); - // Verify that B sees the new messages upon reconnection. + // Verify that B sees the new messages upon reconnection, as well as the message client B + // sent while offline. channel_b .condition(&cx_b, |channel, _| { channel_messages(channel) @@ -1820,6 +1858,7 @@ mod tests { ("user_b".to_string(), "hello A, it's B.".to_string(), false), ("user_a".to_string(), "oh, hi B.".to_string(), false), ("user_a".to_string(), "sup".to_string(), false), + ("user_b".to_string(), "can you see this?".to_string(), false), ] }) .await; @@ -1838,6 +1877,7 @@ mod tests { ("user_b".to_string(), "hello A, it's B.".to_string(), false), ("user_a".to_string(), "oh, hi B.".to_string(), false), ("user_a".to_string(), "sup".to_string(), false), + ("user_b".to_string(), "can you see this?".to_string(), false), ("user_a".to_string(), "you online?".to_string(), false), ] }) @@ -1856,6 +1896,7 @@ mod tests { ("user_b".to_string(), "hello A, it's B.".to_string(), false), ("user_a".to_string(), "oh, hi B.".to_string(), false), ("user_a".to_string(), "sup".to_string(), false), + ("user_b".to_string(), "can you see this?".to_string(), false), ("user_a".to_string(), "you online?".to_string(), false), ("user_b".to_string(), "yep".to_string(), false), ] diff --git a/zed/src/channel.rs b/zed/src/channel.rs index ed7fc4d6c9..c43cf2e6f7 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -9,6 +9,7 @@ use gpui::{ Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle, }; use postage::prelude::Stream; +use rand::prelude::*; use std::{ collections::{HashMap, HashSet}, mem, @@ -42,6 +43,7 @@ pub struct Channel { next_pending_message_id: usize, user_store: Arc, rpc: Arc, + rng: StdRng, _subscription: rpc::Subscription, } @@ -51,6 +53,7 @@ pub struct ChannelMessage { pub body: String, pub timestamp: OffsetDateTime, pub sender: Arc, + pub nonce: u128, } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] @@ -218,6 +221,7 @@ impl Channel { messages: Default::default(), loaded_all_messages: false, next_pending_message_id: 0, + rng: StdRng::from_entropy(), _subscription, } } @@ -242,6 +246,7 @@ impl Channel { let channel_id = self.details.id; let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id)); + let nonce = self.rng.gen(); self.insert_messages( SumTree::from_item( ChannelMessage { @@ -249,6 +254,7 @@ impl Channel { body: body.clone(), sender: current_user, timestamp: OffsetDateTime::now_utc(), + nonce, }, &(), ), @@ -257,7 +263,11 @@ impl Channel { let user_store = self.user_store.clone(); let rpc = self.rpc.clone(); Ok(cx.spawn(|this, mut cx| async move { - let request = rpc.request(proto::SendChannelMessage { channel_id, body }); + let request = rpc.request(proto::SendChannelMessage { + channel_id, + body, + nonce: Some(nonce.into()), + }); let response = request.await?; let message = ChannelMessage::from_proto( response.message.ok_or_else(|| anyhow!("invalid message"))?, @@ -265,7 +275,6 @@ impl Channel { ) .await?; this.update(&mut cx, |this, cx| { - this.remove_message(pending_id, cx); this.insert_messages(SumTree::from_item(message, &()), cx); Ok(()) }) @@ -312,32 +321,51 @@ impl Channel { let user_store = self.user_store.clone(); let rpc = self.rpc.clone(); let channel_id = self.details.id; - cx.spawn(|channel, mut cx| { + cx.spawn(|this, mut cx| { async move { let response = rpc.request(proto::JoinChannel { channel_id }).await?; let messages = messages_from_proto(response.messages, &user_store).await?; let loaded_all_messages = response.done; - channel.update(&mut cx, |channel, cx| { + let pending_messages = this.update(&mut cx, |this, cx| { if let Some((first_new_message, last_old_message)) = - messages.first().zip(channel.messages.last()) + messages.first().zip(this.messages.last()) { if first_new_message.id > last_old_message.id { - let old_messages = mem::take(&mut channel.messages); + let old_messages = mem::take(&mut this.messages); cx.emit(ChannelEvent::MessagesUpdated { old_range: 0..old_messages.summary().count, new_count: 0, }); - channel.loaded_all_messages = loaded_all_messages; + this.loaded_all_messages = loaded_all_messages; } } - channel.insert_messages(messages, cx); + this.insert_messages(messages, cx); if loaded_all_messages { - channel.loaded_all_messages = loaded_all_messages; + this.loaded_all_messages = loaded_all_messages; } + + this.pending_messages().cloned().collect::>() }); + for pending_message in pending_messages { + let request = rpc.request(proto::SendChannelMessage { + channel_id, + body: pending_message.body, + nonce: Some(pending_message.nonce.into()), + }); + let response = request.await?; + let message = ChannelMessage::from_proto( + response.message.ok_or_else(|| anyhow!("invalid message"))?, + &user_store, + ) + .await?; + this.update(&mut cx, |this, cx| { + this.insert_messages(SumTree::from_item(message, &()), cx); + }); + } + Ok(()) } .log_err() @@ -365,6 +393,12 @@ impl Channel { cursor.take(range.len()) } + pub fn pending_messages(&self) -> impl Iterator { + let mut cursor = self.messages.cursor::(); + cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &()); + cursor + } + fn handle_message_sent( &mut self, message: TypedEnvelope, @@ -391,29 +425,13 @@ impl Channel { Ok(()) } - fn remove_message(&mut self, message_id: ChannelMessageId, cx: &mut ModelContext) { - let mut old_cursor = self.messages.cursor::(); - let mut new_messages = old_cursor.slice(&message_id, Bias::Left, &()); - let start_ix = old_cursor.sum_start().0; - let removed_messages = old_cursor.slice(&message_id, Bias::Right, &()); - let removed_count = removed_messages.summary().count; - new_messages.push_tree(old_cursor.suffix(&()), &()); - - drop(old_cursor); - self.messages = new_messages; - - if removed_count > 0 { - let end_ix = start_ix + removed_count; - cx.emit(ChannelEvent::MessagesUpdated { - old_range: start_ix..end_ix, - new_count: 0, - }); - cx.notify(); - } - } - fn insert_messages(&mut self, messages: SumTree, cx: &mut ModelContext) { if let Some((first_message, last_message)) = messages.first().zip(messages.last()) { + let nonces = messages + .cursor::<(), ()>() + .map(|m| m.nonce) + .collect::>(); + let mut old_cursor = self.messages.cursor::(); let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &()); let start_ix = old_cursor.sum_start().0; @@ -423,10 +441,40 @@ impl Channel { let end_ix = start_ix + removed_count; new_messages.push_tree(messages, &()); - new_messages.push_tree(old_cursor.suffix(&()), &()); + + let mut ranges = Vec::>::new(); + if new_messages.last().unwrap().is_pending() { + new_messages.push_tree(old_cursor.suffix(&()), &()); + } else { + new_messages.push_tree( + old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()), + &(), + ); + + while let Some(message) = old_cursor.item() { + let message_ix = old_cursor.sum_start().0; + if nonces.contains(&message.nonce) { + if ranges.last().map_or(false, |r| r.end == message_ix) { + ranges.last_mut().unwrap().end += 1; + } else { + ranges.push(message_ix..message_ix + 1); + } + } else { + new_messages.push(message.clone(), &()); + } + old_cursor.next(&()); + } + } + drop(old_cursor); self.messages = new_messages; + for range in ranges.into_iter().rev() { + cx.emit(ChannelEvent::MessagesUpdated { + old_range: range, + new_count: 0, + }); + } cx.emit(ChannelEvent::MessagesUpdated { old_range: start_ix..end_ix, new_count, @@ -477,6 +525,10 @@ impl ChannelMessage { body: message.body, timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?, sender, + nonce: message + .nonce + .ok_or_else(|| anyhow!("nonce is required"))? + .into(), }) } @@ -606,12 +658,14 @@ mod tests { body: "a".into(), timestamp: 1000, sender_id: 5, + nonce: Some(1.into()), }, proto::ChannelMessage { id: 11, body: "b".into(), timestamp: 1001, sender_id: 6, + nonce: Some(2.into()), }, ], done: false, @@ -665,6 +719,7 @@ mod tests { body: "c".into(), timestamp: 1002, sender_id: 7, + nonce: Some(3.into()), }), }) .await; @@ -720,12 +775,14 @@ mod tests { body: "y".into(), timestamp: 998, sender_id: 5, + nonce: Some(4.into()), }, proto::ChannelMessage { id: 9, body: "z".into(), timestamp: 999, sender_id: 6, + nonce: Some(5.into()), }, ], }, diff --git a/zrpc/proto/zed.proto b/zrpc/proto/zed.proto index c9f1dc0f80..4e42441eb2 100644 --- a/zrpc/proto/zed.proto +++ b/zrpc/proto/zed.proto @@ -151,6 +151,7 @@ message GetUsersResponse { message SendChannelMessage { uint64 channel_id = 1; string body = 2; + Nonce nonce = 3; } message SendChannelMessageResponse { @@ -296,6 +297,11 @@ message Range { uint64 end = 2; } +message Nonce { + uint64 upper_half = 1; + uint64 lower_half = 2; +} + message Channel { uint64 id = 1; string name = 2; @@ -306,4 +312,5 @@ message ChannelMessage { string body = 2; uint64 timestamp = 3; uint64 sender_id = 4; + Nonce nonce = 5; } diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index af9dbf3abc..b2d4de3bbf 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -248,3 +248,22 @@ impl From for Timestamp { } } } + +impl From for Nonce { + fn from(nonce: u128) -> Self { + let upper_half = (nonce >> 64) as u64; + let lower_half = nonce as u64; + Self { + upper_half, + lower_half, + } + } +} + +impl From for u128 { + fn from(nonce: Nonce) -> Self { + let upper_half = (nonce.upper_half as u128) << 64; + let lower_half = nonce.lower_half as u128; + upper_half | lower_half + } +}