diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index e6bcf14bda..41a40617ee 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -1,6 +1,6 @@ use super::{http::HttpClient, proto, Client, Status, TypedEnvelope}; use anyhow::{anyhow, Result}; -use futures::{future, AsyncReadExt}; +use futures::{future, AsyncReadExt, Future}; use gpui::{AsyncAppContext, Entity, ImageData, ModelContext, ModelHandle, Task}; use postage::{prelude::Stream, sink::Sink, watch}; use rpc::proto::{RequestMessage, UsersResponse}; @@ -121,6 +121,13 @@ impl UserStore { user_ids.insert(contact.user_id); user_ids.extend(contact.projects.iter().flat_map(|w| &w.guests).copied()); } + user_ids.extend(message.pending_requests_to_user_ids.iter()); + user_ids.extend( + message + .pending_requests_from_user_ids + .iter() + .map(|req| req.user_id), + ); let load_users = self.get_users(user_ids.into_iter().collect(), cx); cx.spawn(|this, mut cx| async move { @@ -153,6 +160,39 @@ impl UserStore { .is_ok() } + pub fn request_contact(&self, to_user_id: u64) -> impl Future> { + let client = self.client.upgrade(); + async move { + client + .ok_or_else(|| anyhow!("not logged in"))? + .request(proto::RequestContact { to_user_id }) + .await?; + Ok(()) + } + } + + pub fn respond_to_contact_request( + &self, + from_user_id: u64, + accept: bool, + ) -> impl Future> { + let client = self.client.upgrade(); + async move { + client + .ok_or_else(|| anyhow!("not logged in"))? + .request(proto::RespondToContactRequest { + requesting_user_id: from_user_id, + response: if accept { + proto::ContactRequestResponse::Accept + } else { + proto::ContactRequestResponse::Reject + } as i32, + }) + .await?; + Ok(()) + } + } + pub fn get_users( &mut self, mut user_ids: Vec, diff --git a/crates/collab/migrations/20220506130724_create_contacts.sql b/crates/collab/migrations/20220506130724_create_contacts.sql index 216635b319..56beb70fd0 100644 --- a/crates/collab/migrations/20220506130724_create_contacts.sql +++ b/crates/collab/migrations/20220506130724_create_contacts.sql @@ -3,6 +3,7 @@ CREATE TABLE IF NOT EXISTS "contacts" ( "user_id_a" INTEGER REFERENCES users (id) NOT NULL, "user_id_b" INTEGER REFERENCES users (id) NOT NULL, "a_to_b" BOOLEAN NOT NULL, + "should_notify" BOOLEAN NOT NULL, "accepted" BOOLEAN NOT NULL ); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 0f2c700c2c..5d5f55fc92 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -19,6 +19,11 @@ pub trait Db: Send + Sync { async fn get_contacts(&self, id: UserId) -> Result; async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>; + async fn dismiss_contact_request( + &self, + responder_id: UserId, + requester_id: UserId, + ) -> Result<()>; async fn respond_to_contact_request( &self, responder_id: UserId, @@ -184,12 +189,12 @@ impl Db for PostgresDb { async fn get_contacts(&self, user_id: UserId) -> Result { let query = " - SELECT user_id_a, user_id_b, a_to_b, accepted + SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify FROM contacts WHERE user_id_a = $1 OR user_id_b = $1; "; - let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool)>(query) + let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query) .bind(user_id) .fetch(&self.pool); @@ -197,7 +202,7 @@ impl Db for PostgresDb { let mut requests_sent = Vec::new(); let mut requests_received = Vec::new(); while let Some(row) = rows.next().await { - let (user_id_a, user_id_b, a_to_b, accepted) = row?; + let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?; if user_id_a == user_id { if accepted { @@ -205,13 +210,19 @@ impl Db for PostgresDb { } else if a_to_b { requests_sent.push(user_id_b); } else { - requests_received.push(user_id_b); + requests_received.push(IncomingContactRequest { + requesting_user_id: user_id_b, + should_notify, + }); } } else { if accepted { current.push(user_id_a); } else if a_to_b { - requests_received.push(user_id_a); + requests_received.push(IncomingContactRequest { + requesting_user_id: user_id_a, + should_notify, + }); } else { requests_sent.push(user_id_a); } @@ -232,8 +243,8 @@ impl Db for PostgresDb { (receiver_id, sender_id, false) }; let query = " - INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted) - VALUES ($1, $2, $3, 'f') + INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify) + VALUES ($1, $2, $3, 'f', 't') ON CONFLICT (user_id_a, user_id_b) DO UPDATE SET accepted = 't' @@ -270,7 +281,7 @@ impl Db for PostgresDb { let result = if accept { let query = " UPDATE contacts - SET accepted = 't' + SET accepted = 't', should_notify = 'f' WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; "; sqlx::query(query) @@ -298,6 +309,37 @@ impl Db for PostgresDb { } } + async fn dismiss_contact_request( + &self, + responder_id: UserId, + requester_id: UserId, + ) -> Result<()> { + let (id_a, id_b, a_to_b) = if responder_id < requester_id { + (responder_id, requester_id, false) + } else { + (requester_id, responder_id, true) + }; + + let query = " + UPDATE contacts + SET should_notify = 'f' + WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; + "; + + let result = sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .bind(a_to_b) + .execute(&self.pool) + .await?; + + if result.rows_affected() == 0 { + Err(anyhow!("no such contact request"))?; + } + + Ok(()) + } + // access tokens async fn create_access_token_hash( @@ -628,7 +670,13 @@ pub struct ChannelMessage { pub struct Contacts { pub current: Vec, pub requests_sent: Vec, - pub requests_received: Vec, + pub requests_received: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct IncomingContactRequest { + pub requesting_user_id: UserId, + pub should_notify: bool, } fn fuzzy_like_string(string: &str) -> String { @@ -886,7 +934,28 @@ pub mod tests { Contacts { current: vec![], requests_sent: vec![], - requests_received: vec![user_1], + requests_received: vec![IncomingContactRequest { + requesting_user_id: user_1, + should_notify: true + }], + }, + ); + + // User 2 dismisses the contact request notification without accepting or rejecting. + // We shouldn't notify them again. + db.dismiss_contact_request(user_1, user_2) + .await + .unwrap_err(); + db.dismiss_contact_request(user_2, user_1).await.unwrap(); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + Contacts { + current: vec![], + requests_sent: vec![], + requests_received: vec![IncomingContactRequest { + requesting_user_id: user_1, + should_notify: false + }], }, ); @@ -1032,6 +1101,7 @@ pub mod tests { requester_id: UserId, responder_id: UserId, accepted: bool, + should_notify: bool, } impl FakeDb { @@ -1124,7 +1194,10 @@ pub mod tests { if contact.accepted { current.push(contact.requester_id); } else { - requests_received.push(contact.requester_id); + requests_received.push(IncomingContactRequest { + requesting_user_id: contact.requester_id, + should_notify: contact.should_notify, + }); } } } @@ -1162,10 +1235,29 @@ pub mod tests { requester_id, responder_id, accepted: false, + should_notify: true, }); Ok(()) } + async fn dismiss_contact_request( + &self, + responder_id: UserId, + requester_id: UserId, + ) -> Result<()> { + let mut contacts = self.contacts.lock(); + for contact in contacts.iter_mut() { + if contact.requester_id == requester_id && contact.responder_id == responder_id { + if contact.accepted { + return Err(anyhow!("contact already confirmed")); + } + contact.should_notify = false; + return Ok(()); + } + } + Err(anyhow!("no such contact request")) + } + async fn respond_to_contact_request( &self, responder_id: UserId, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 184592f033..0097d2580a 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -154,6 +154,8 @@ impl Server { .add_request_handler(Server::get_channels) .add_request_handler(Server::get_users) .add_request_handler(Server::fuzzy_search_users) + .add_request_handler(Server::request_contact) + .add_request_handler(Server::respond_to_contact_request) .add_request_handler(Server::join_channel) .add_message_handler(Server::leave_channel) .add_request_handler(Server::send_channel_message) @@ -914,6 +916,48 @@ impl Server { Ok(()) } + async fn request_contact( + self: Arc, + request: TypedEnvelope, + response: Response, + ) -> Result<()> { + let requester_id = self + .store + .read() + .await + .user_id_for_connection(request.sender_id)?; + let responder_id = UserId::from_proto(request.payload.to_user_id); + self.app_state + .db + .send_contact_request(requester_id, responder_id) + .await?; + response.send(proto::Ack {})?; + Ok(()) + } + + async fn respond_to_contact_request( + self: Arc, + request: TypedEnvelope, + response: Response, + ) -> Result<()> { + let responder_id = self + .store + .read() + .await + .user_id_for_connection(request.sender_id)?; + let requester_id = UserId::from_proto(request.payload.requesting_user_id); + self.app_state + .db + .respond_to_contact_request( + responder_id, + requester_id, + request.payload.response == proto::ContactRequestResponse::Accept as i32, + ) + .await?; + response.send(proto::Ack {})?; + Ok(()) + } + #[instrument(skip(self, state, user_ids))] fn update_contacts_for_users<'a>( self: &Arc, @@ -4911,6 +4955,33 @@ mod tests { } } + #[gpui::test(iterations = 10)] + async fn test_contacts_requests(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { + cx_a.foreground().forbid_parking(); + + // Connect to a server as 3 clients. + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + client_a + .user_store + .read_with(cx_a, |store, _| { + store.request_contact(client_b.user_id().unwrap()) + }) + .await + .unwrap(); + + client_a.user_store.read_with(cx_a, |store, _| { + let contacts = store + .contacts() + .iter() + .map(|contact| contact.user.github_login.clone()) + .collect::>(); + assert_eq!(contacts, &["user_b"]) + }); + } + #[gpui::test(iterations = 10)] async fn test_following(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); diff --git a/crates/rpc/src/macros.rs b/crates/rpc/src/macros.rs new file mode 100644 index 0000000000..38d35893ee --- /dev/null +++ b/crates/rpc/src/macros.rs @@ -0,0 +1,67 @@ +#[macro_export] +macro_rules! messages { + ($(($name:ident, $priority:ident)),* $(,)?) => { + pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option> { + match envelope.payload { + $(Some(envelope::Payload::$name(payload)) => { + Some(Box::new(TypedEnvelope { + sender_id, + original_sender_id: envelope.original_sender_id.map(PeerId), + message_id: envelope.id, + payload, + })) + }, )* + _ => None + } + } + + $( + impl EnvelopedMessage for $name { + const NAME: &'static str = std::stringify!($name); + const PRIORITY: MessagePriority = MessagePriority::$priority; + + fn into_envelope( + self, + id: u32, + responding_to: Option, + original_sender_id: Option, + ) -> Envelope { + Envelope { + id, + responding_to, + original_sender_id, + payload: Some(envelope::Payload::$name(self)), + } + } + + fn from_envelope(envelope: Envelope) -> Option { + if let Some(envelope::Payload::$name(msg)) = envelope.payload { + Some(msg) + } else { + None + } + } + } + )* + }; +} + +#[macro_export] +macro_rules! request_messages { + ($(($request_name:ident, $response_name:ident)),* $(,)?) => { + $(impl RequestMessage for $request_name { + type Response = $response_name; + })* + }; +} + +#[macro_export] +macro_rules! entity_messages { + ($id_field:ident, $($name:ident),* $(,)?) => { + $(impl EntityMessage for $name { + fn remote_entity_id(&self) -> u64 { + self.$id_field + } + })* + }; +} diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 0935dc6265..2674e8a0d8 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -1,4 +1,4 @@ -use super::{ConnectionId, PeerId, TypedEnvelope}; +use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope}; use anyhow::{anyhow, Result}; use async_tungstenite::tungstenite::Message as WebSocketMessage; use futures::{SinkExt as _, StreamExt as _}; @@ -73,71 +73,6 @@ impl AnyTypedEnvelope for TypedEnvelope { } } -macro_rules! messages { - ($(($name:ident, $priority:ident)),* $(,)?) => { - pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option> { - match envelope.payload { - $(Some(envelope::Payload::$name(payload)) => { - Some(Box::new(TypedEnvelope { - sender_id, - original_sender_id: envelope.original_sender_id.map(PeerId), - message_id: envelope.id, - payload, - })) - }, )* - _ => None - } - } - - $( - impl EnvelopedMessage for $name { - const NAME: &'static str = std::stringify!($name); - const PRIORITY: MessagePriority = MessagePriority::$priority; - - fn into_envelope( - self, - id: u32, - responding_to: Option, - original_sender_id: Option, - ) -> Envelope { - Envelope { - id, - responding_to, - original_sender_id, - payload: Some(envelope::Payload::$name(self)), - } - } - - fn from_envelope(envelope: Envelope) -> Option { - if let Some(envelope::Payload::$name(msg)) = envelope.payload { - Some(msg) - } else { - None - } - } - } - )* - }; -} - -macro_rules! request_messages { - ($(($request_name:ident, $response_name:ident)),* $(,)?) => { - $(impl RequestMessage for $request_name { - type Response = $response_name; - })* - }; -} - -macro_rules! entity_messages { - ($id_field:ident, $($name:ident),* $(,)?) => { - $(impl EntityMessage for $name { - fn remote_entity_id(&self) -> u64 { - self.$id_field - } - })* - }; -} - messages!( (Ack, Foreground), (AddProjectCollaborator, Foreground), @@ -198,6 +133,8 @@ messages!( (ReloadBuffersResponse, Foreground), (RemoveProjectCollaborator, Foreground), (RenameProjectEntry, Foreground), + (RequestContact, Foreground), + (RespondToContactRequest, Foreground), (SaveBuffer, Foreground), (SearchProject, Background), (SearchProjectResponse, Background), @@ -250,6 +187,8 @@ request_messages!( (RegisterProject, RegisterProjectResponse), (RegisterWorktree, Ack), (ReloadBuffers, ReloadBuffersResponse), + (RequestContact, Ack), + (RespondToContactRequest, Ack), (RenameProjectEntry, ProjectEntryResponse), (SaveBuffer, BufferSaved), (SearchProject, SearchProjectResponse), diff --git a/crates/rpc/src/rpc.rs b/crates/rpc/src/rpc.rs index ffddcb9cd3..f21a0ba76e 100644 --- a/crates/rpc/src/rpc.rs +++ b/crates/rpc/src/rpc.rs @@ -4,5 +4,6 @@ mod peer; pub mod proto; pub use conn::Connection; pub use peer::*; +mod macros; pub const PROTOCOL_VERSION: u32 = 16;