diff --git a/server/src/db.rs b/server/src/db.rs index 1e489aae36..2ae8fc8f1d 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -162,7 +162,7 @@ impl Db { FROM users, channel_memberships WHERE - users.id IN $1 AND + users.id = ANY ($1) AND channel_memberships.user_id = users.id AND channel_memberships.channel_id IN ( SELECT channel_id diff --git a/server/src/rpc.rs b/server/src/rpc.rs index a44738e8ff..c869dd1aea 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -939,6 +939,7 @@ mod tests { language::LanguageRegistry, rpc::Client, settings, test, + user::UserStore, worktree::Worktree, }; use zrpc::Peer; @@ -1425,7 +1426,8 @@ mod tests { .await .unwrap(); - let channels_a = cx_a.add_model(|cx| ChannelList::new(client_a, cx)); + let user_store_a = Arc::new(UserStore::new(client_a.clone())); + let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx)); channels_a .condition(&mut cx_a, |list, _| list.available_channels().is_some()) .await; @@ -1445,11 +1447,12 @@ mod tests { channel_a .condition(&cx_a, |channel, _| { channel_messages(channel) - == [(user_id_b.to_proto(), "hello A, it's B.".to_string())] + == [("user_b".to_string(), "hello A, it's B.".to_string())] }) .await; - let channels_b = cx_b.add_model(|cx| ChannelList::new(client_b, cx)); + let user_store_b = Arc::new(UserStore::new(client_b.clone())); + let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx)); channels_b .condition(&mut cx_b, |list, _| list.available_channels().is_some()) .await; @@ -1470,7 +1473,7 @@ mod tests { channel_b .condition(&cx_b, |channel, _| { channel_messages(channel) - == [(user_id_b.to_proto(), "hello A, it's B.".to_string())] + == [("user_b".to_string(), "hello A, it's B.".to_string())] }) .await; @@ -1494,9 +1497,9 @@ mod tests { .condition(&cx_b, |channel, _| { channel_messages(channel) == [ - (user_id_b.to_proto(), "hello A, it's B.".to_string()), - (user_id_a.to_proto(), "oh, hi B.".to_string()), - (user_id_a.to_proto(), "sup".to_string()), + ("user_b".to_string(), "hello A, it's B.".to_string()), + ("user_a".to_string(), "oh, hi B.".to_string()), + ("user_a".to_string(), "sup".to_string()), ] }) .await; @@ -1517,11 +1520,11 @@ mod tests { .condition(|state| !state.channels.contains_key(&channel_id)) .await; - fn channel_messages(channel: &Channel) -> Vec<(u64, String)> { + fn channel_messages(channel: &Channel) -> Vec<(String, String)> { channel .messages() .cursor::<(), ()>() - .map(|m| (m.sender_id, m.body.clone())) + .map(|m| (m.sender.github_login.clone(), m.body.clone())) .collect() } } diff --git a/zed/src/channel.rs b/zed/src/channel.rs index a70b7dd068..73eb015884 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -1,5 +1,6 @@ use crate::{ rpc::{self, Client}, + user::{User, UserStore}, util::TryFutureExt, }; use anyhow::{anyhow, Context, Result}; @@ -9,7 +10,7 @@ use gpui::{ }; use postage::prelude::Stream; use std::{ - collections::{hash_map, HashMap}, + collections::{hash_map, HashMap, HashSet}, ops::Range, sync::Arc, }; @@ -22,6 +23,7 @@ pub struct ChannelList { available_channels: Option>, channels: HashMap>, rpc: Arc, + user_store: Arc, _task: Task>, } @@ -36,6 +38,7 @@ pub struct Channel { messages: SumTree, pending_messages: Vec, next_local_message_id: u64, + user_store: Arc, rpc: Arc, _subscription: rpc::Subscription, } @@ -43,8 +46,8 @@ pub struct Channel { #[derive(Clone, Debug, PartialEq)] pub struct ChannelMessage { pub id: u64, - pub sender_id: u64, pub body: String, + pub sender: Arc, } pub struct PendingChannelMessage { @@ -76,7 +79,11 @@ impl Entity for ChannelList { } impl ChannelList { - pub fn new(rpc: Arc, cx: &mut ModelContext) -> Self { + pub fn new( + user_store: Arc, + rpc: Arc, + cx: &mut ModelContext, + ) -> Self { let _task = cx.spawn(|this, mut cx| { let rpc = rpc.clone(); async move { @@ -114,6 +121,7 @@ impl ChannelList { Self { available_channels: None, channels: Default::default(), + user_store, rpc, _task, } @@ -136,8 +144,10 @@ impl ChannelList { .as_ref() .and_then(|channels| channels.iter().find(|details| details.id == id)) { + let user_store = self.user_store.clone(); let rpc = self.rpc.clone(); - let channel = cx.add_model(|cx| Channel::new(details.clone(), rpc, cx)); + let channel = + cx.add_model(|cx| Channel::new(details.clone(), user_store, rpc, cx)); entry.insert(channel.downgrade()); Some(channel) } else { @@ -165,34 +175,58 @@ impl Entity for Channel { } impl Channel { - pub fn new(details: ChannelDetails, rpc: Arc, cx: &mut ModelContext) -> Self { + pub fn new( + details: ChannelDetails, + user_store: Arc, + rpc: Arc, + cx: &mut ModelContext, + ) -> Self { let _subscription = rpc.subscribe_from_model(details.id, cx, Self::handle_message_sent); { + let user_store = user_store.clone(); let rpc = rpc.clone(); let channel_id = details.id; - cx.spawn(|channel, mut cx| async move { - match rpc.request(proto::JoinChannel { channel_id }).await { - Ok(response) => channel.update(&mut cx, |channel, cx| { + cx.spawn(|channel, mut cx| { + async move { + let response = rpc.request(proto::JoinChannel { channel_id }).await?; + + let unique_user_ids = response + .messages + .iter() + .map(|m| m.sender_id) + .collect::>() + .into_iter() + .collect(); + user_store.load_users(unique_user_ids).await?; + + let mut messages = Vec::with_capacity(response.messages.len()); + for message in response.messages { + messages.push(ChannelMessage::from_proto(message, &user_store).await?); + } + + channel.update(&mut cx, |channel, cx| { let old_count = channel.messages.summary().count.0; - let new_count = response.messages.len(); + let new_count = messages.len(); + channel.messages = SumTree::new(); - channel - .messages - .extend(response.messages.into_iter().map(Into::into), &()); + channel.messages.extend(messages, &()); cx.emit(ChannelEvent::Message { old_range: 0..old_count, new_count, }); - }), - Err(error) => log::error!("error joining channel: {}", error), + }); + + Ok(()) } + .log_err() }) .detach(); } Self { details, + user_store, rpc, messages: Default::default(), pending_messages: Default::default(), @@ -210,11 +244,14 @@ impl Channel { local_id, body: body.clone(), }); + let user_store = self.user_store.clone(); let rpc = self.rpc.clone(); cx.spawn(|this, mut cx| { async move { let request = rpc.request(proto::SendChannelMessage { channel_id, body }); let response = request.await?; + let sender = user_store.get_user(current_user_id).await?; + this.update(&mut cx, |this, cx| { if let Ok(i) = this .pending_messages @@ -224,8 +261,8 @@ impl Channel { this.insert_message( ChannelMessage { id: response.message_id, - sender_id: current_user_id, body, + sender, }, cx, ); @@ -267,11 +304,21 @@ impl Channel { _: Arc, cx: &mut ModelContext, ) -> Result<()> { + let user_store = self.user_store.clone(); let message = message .payload .message .ok_or_else(|| anyhow!("empty message"))?; - self.insert_message(message.into(), cx); + + cx.spawn(|this, mut cx| { + async move { + let message = ChannelMessage::from_proto(message, &user_store).await?; + this.update(&mut cx, |this, cx| this.insert_message(message, cx)); + Ok(()) + } + .log_err() + }) + .detach(); Ok(()) } @@ -307,13 +354,17 @@ impl From for ChannelDetails { } } -impl From for ChannelMessage { - fn from(message: proto::ChannelMessage) -> Self { - ChannelMessage { +impl ChannelMessage { + pub async fn from_proto( + message: proto::ChannelMessage, + user_store: &UserStore, + ) -> Result { + let sender = user_store.get_user(message.sender_id).await?; + Ok(ChannelMessage { id: message.id, - sender_id: message.sender_id, body: message.body, - } + sender, + }) } } @@ -368,15 +419,16 @@ mod tests { let user_id = 5; let client = Client::new(); let mut server = FakeServer::for_client(user_id, &client, &cx).await; + let user_store = Arc::new(UserStore::new(client.clone())); - let channel_list = cx.add_model(|cx| ChannelList::new(client.clone(), cx)); + let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx)); channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None)); // Get the available channels. - let message = server.receive::().await; + let get_channels = server.receive::().await; server .respond( - message.receipt(), + get_channels.receipt(), proto::GetChannelsResponse { channels: vec![proto::Channel { id: 5, @@ -404,10 +456,10 @@ mod tests { }) .unwrap(); channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty())); - let message = server.receive::().await; + let join_channel = server.receive::().await; server .respond( - message.receipt(), + join_channel.receipt(), proto::JoinChannelResponse { messages: vec![ proto::ChannelMessage { @@ -420,12 +472,36 @@ mod tests { id: 11, body: "b".into(), timestamp: 1001, - sender_id: 5, + sender_id: 6, }, ], }, ) .await; + // Client requests all users for the received messages + let mut get_users = server.receive::().await; + get_users.payload.user_ids.sort(); + assert_eq!(get_users.payload.user_ids, vec![5, 6]); + server + .respond( + get_users.receipt(), + proto::GetUsersResponse { + users: vec![ + proto::User { + id: 5, + github_login: "nathansobo".into(), + avatar_url: "http://avatar.com/nathansobo".into(), + }, + proto::User { + id: 6, + github_login: "maxbrunsfeld".into(), + avatar_url: "http://avatar.com/maxbrunsfeld".into(), + }, + ], + }, + ) + .await; + assert_eq!( channel.next_event(&cx).await, ChannelEvent::Message { @@ -437,9 +513,12 @@ mod tests { assert_eq!( channel .messages_in_range(0..2) - .map(|message| &message.body) + .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), - &["a", "b"] + &[ + ("nathansobo".into(), "a".into()), + ("maxbrunsfeld".into(), "b".into()) + ] ); }); @@ -451,10 +530,27 @@ mod tests { id: 12, body: "c".into(), timestamp: 1002, - sender_id: 5, + sender_id: 7, }), }) .await; + + // Client requests user for message since they haven't seen them yet + let get_users = server.receive::().await; + assert_eq!(get_users.payload.user_ids, vec![7]); + server + .respond( + get_users.receipt(), + proto::GetUsersResponse { + users: vec![proto::User { + id: 7, + github_login: "as-cii".into(), + avatar_url: "http://avatar.com/as-cii".into(), + }], + }, + ) + .await; + assert_eq!( channel.next_event(&cx).await, ChannelEvent::Message { @@ -466,9 +562,9 @@ mod tests { assert_eq!( channel .messages_in_range(2..3) - .map(|message| &message.body) + .map(|message| (message.sender.github_login.clone(), message.body.clone())) .collect::>(), - &["c"] + &[("as-cii".into(), "c".into())] ) }) } diff --git a/zed/src/lib.rs b/zed/src/lib.rs index d674fd5ea6..d451c187d5 100644 --- a/zed/src/lib.rs +++ b/zed/src/lib.rs @@ -15,6 +15,7 @@ pub mod test; pub mod theme; pub mod theme_selector; mod time; +pub mod user; mod util; pub mod workspace; pub mod worktree; diff --git a/zed/src/main.rs b/zed/src/main.rs index 7eaa5ea838..2258f74686 100644 --- a/zed/src/main.rs +++ b/zed/src/main.rs @@ -12,6 +12,7 @@ use zed::{ chat_panel, editor, file_finder, fs::RealFs, language, menus, rpc, settings, theme_selector, + user::UserStore, workspace::{self, OpenParams, OpenPaths}, AppState, }; @@ -29,12 +30,13 @@ fn main() { app.run(move |cx| { let rpc = rpc::Client::new(); + let user_store = Arc::new(UserStore::new(rpc.clone())); let app_state = Arc::new(AppState { languages: languages.clone(), settings_tx: Arc::new(Mutex::new(settings_tx)), settings, themes, - channel_list: cx.add_model(|cx| ChannelList::new(rpc.clone(), cx)), + channel_list: cx.add_model(|cx| ChannelList::new(user_store, rpc.clone(), cx)), rpc, fs: Arc::new(RealFs), }); diff --git a/zed/src/test.rs b/zed/src/test.rs index 6fac3c8bc9..f406df1946 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -5,6 +5,7 @@ use crate::{ rpc, settings::{self, ThemeRegistry}, time::ReplicaId, + user::UserStore, AppState, Settings, }; use gpui::{AppContext, Entity, ModelHandle, MutableAppContext}; @@ -164,12 +165,13 @@ pub fn build_app_state(cx: &mut MutableAppContext) -> Arc { let languages = Arc::new(LanguageRegistry::new()); let themes = ThemeRegistry::new(()); let rpc = rpc::Client::new(); + let user_store = Arc::new(UserStore::new(rpc.clone())); Arc::new(AppState { settings_tx: Arc::new(Mutex::new(settings_tx)), settings, themes, languages: languages.clone(), - channel_list: cx.add_model(|cx| ChannelList::new(rpc.clone(), cx)), + channel_list: cx.add_model(|cx| ChannelList::new(user_store, rpc.clone(), cx)), rpc, fs: Arc::new(RealFs), }) diff --git a/zed/src/user.rs b/zed/src/user.rs new file mode 100644 index 0000000000..df98707a8e --- /dev/null +++ b/zed/src/user.rs @@ -0,0 +1,59 @@ +use crate::rpc::Client; +use anyhow::{anyhow, Result}; +use parking_lot::Mutex; +use std::{collections::HashMap, sync::Arc}; +use zrpc::proto; + +pub use proto::User; + +pub struct UserStore { + users: Mutex>>, + rpc: Arc, +} + +impl UserStore { + pub fn new(rpc: Arc) -> Self { + Self { + users: Default::default(), + rpc, + } + } + + pub async fn load_users(&self, mut user_ids: Vec) -> Result<()> { + { + let users = self.users.lock(); + user_ids.retain(|id| !users.contains_key(id)); + } + + if !user_ids.is_empty() { + let response = self.rpc.request(proto::GetUsers { user_ids }).await?; + let mut users = self.users.lock(); + for user in response.users { + users.insert(user.id, Arc::new(user)); + } + } + + Ok(()) + } + + pub async fn get_user(&self, user_id: u64) -> Result> { + if let Some(user) = self.users.lock().get(&user_id).cloned() { + return Ok(user); + } + + let response = self + .rpc + .request(proto::GetUsers { + user_ids: vec![user_id], + }) + .await?; + + if let Some(user) = response.users.into_iter().next() { + let user = Arc::new(user); + self.users.lock().insert(user_id, user.clone()); + Ok(user) + } else { + Err(anyhow!("server responded with no users")) + } + } +}