From 05a6bd914d6a907c5cb1e7d3f87ec42912c7ef2c Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 10 Nov 2022 11:03:52 -0800 Subject: [PATCH] Get integration tests passing with sqlite Co-authored-by: Antonio Scandurra --- crates/client/src/channel.rs | 820 ----------- crates/client/src/client.rs | 2 - .../20221109000000_test_schema.sql | 108 +- crates/collab/src/api.rs | 103 +- crates/collab/src/db.rs | 1296 +++++------------ crates/collab/src/db_tests.rs | 434 +----- crates/collab/src/integration_tests.rs | 551 +------ crates/collab/src/main.rs | 4 +- crates/collab/src/rpc.rs | 324 +---- crates/collab/src/rpc/store.rs | 131 +- crates/zed/src/main.rs | 1 - 11 files changed, 473 insertions(+), 3301 deletions(-) delete mode 100644 crates/client/src/channel.rs diff --git a/crates/client/src/channel.rs b/crates/client/src/channel.rs deleted file mode 100644 index 7b4f6073ce..0000000000 --- a/crates/client/src/channel.rs +++ /dev/null @@ -1,820 +0,0 @@ -use super::{ - proto, - user::{User, UserStore}, - Client, Status, Subscription, TypedEnvelope, -}; -use anyhow::{anyhow, Context, Result}; -use futures::lock::Mutex; -use gpui::{ - AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle, -}; -use postage::prelude::Stream; -use rand::prelude::*; -use std::{ - collections::{HashMap, HashSet}, - mem, - ops::Range, - sync::Arc, -}; -use sum_tree::{Bias, SumTree}; -use time::OffsetDateTime; -use util::{post_inc, ResultExt as _, TryFutureExt}; - -pub struct ChannelList { - available_channels: Option>, - channels: HashMap>, - client: Arc, - user_store: ModelHandle, - _task: Task>, -} - -#[derive(Clone, Debug, PartialEq)] -pub struct ChannelDetails { - pub id: u64, - pub name: String, -} - -pub struct Channel { - details: ChannelDetails, - messages: SumTree, - loaded_all_messages: bool, - next_pending_message_id: usize, - user_store: ModelHandle, - rpc: Arc, - outgoing_messages_lock: Arc>, - rng: StdRng, - _subscription: Subscription, -} - -#[derive(Clone, Debug)] -pub struct ChannelMessage { - pub id: ChannelMessageId, - pub body: String, - pub timestamp: OffsetDateTime, - pub sender: Arc, - pub nonce: u128, -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub enum ChannelMessageId { - Saved(u64), - Pending(usize), -} - -#[derive(Clone, Debug, Default)] -pub struct ChannelMessageSummary { - max_id: ChannelMessageId, - count: usize, -} - -#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] -struct Count(usize); - -pub enum ChannelListEvent {} - -#[derive(Clone, Debug, PartialEq)] -pub enum ChannelEvent { - MessagesUpdated { - old_range: Range, - new_count: usize, - }, -} - -impl Entity for ChannelList { - type Event = ChannelListEvent; -} - -impl ChannelList { - pub fn new( - user_store: ModelHandle, - rpc: Arc, - cx: &mut ModelContext, - ) -> Self { - let _task = cx.spawn_weak(|this, mut cx| { - let rpc = rpc.clone(); - async move { - let mut status = rpc.status(); - while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) { - match status { - Status::Connected { .. } => { - let response = rpc - .request(proto::GetChannels {}) - .await - .context("failed to fetch available channels")?; - this.update(&mut cx, |this, cx| { - this.available_channels = - Some(response.channels.into_iter().map(Into::into).collect()); - - let mut to_remove = Vec::new(); - for (channel_id, channel) in &this.channels { - if let Some(channel) = channel.upgrade(cx) { - channel.update(cx, |channel, cx| channel.rejoin(cx)) - } else { - to_remove.push(*channel_id); - } - } - - for channel_id in to_remove { - this.channels.remove(&channel_id); - } - cx.notify(); - }); - } - Status::SignedOut { .. } => { - this.update(&mut cx, |this, cx| { - this.available_channels = None; - this.channels.clear(); - cx.notify(); - }); - } - _ => {} - } - } - Ok(()) - } - .log_err() - }); - - Self { - available_channels: None, - channels: Default::default(), - user_store, - client: rpc, - _task, - } - } - - pub fn available_channels(&self) -> Option<&[ChannelDetails]> { - self.available_channels.as_deref() - } - - pub fn get_channel( - &mut self, - id: u64, - cx: &mut MutableAppContext, - ) -> Option> { - if let Some(channel) = self.channels.get(&id).and_then(|c| c.upgrade(cx)) { - return Some(channel); - } - - let channels = self.available_channels.as_ref()?; - let details = channels.iter().find(|details| details.id == id)?.clone(); - let channel = cx.add_model(|cx| { - Channel::new(details, self.user_store.clone(), self.client.clone(), cx) - }); - self.channels.insert(id, channel.downgrade()); - Some(channel) - } -} - -impl Entity for Channel { - type Event = ChannelEvent; - - fn release(&mut self, _: &mut MutableAppContext) { - self.rpc - .send(proto::LeaveChannel { - channel_id: self.details.id, - }) - .log_err(); - } -} - -impl Channel { - pub fn init(rpc: &Arc) { - rpc.add_model_message_handler(Self::handle_message_sent); - } - - pub fn new( - details: ChannelDetails, - user_store: ModelHandle, - rpc: Arc, - cx: &mut ModelContext, - ) -> Self { - let _subscription = rpc.add_model_for_remote_entity(details.id, cx); - - { - let user_store = user_store.clone(); - let rpc = rpc.clone(); - let channel_id = details.id; - cx.spawn(|channel, mut cx| { - async move { - let response = rpc.request(proto::JoinChannel { channel_id }).await?; - let messages = - messages_from_proto(response.messages, &user_store, &mut cx).await?; - let loaded_all_messages = response.done; - - channel.update(&mut cx, |channel, cx| { - channel.insert_messages(messages, cx); - channel.loaded_all_messages = loaded_all_messages; - }); - - Ok(()) - } - .log_err() - }) - .detach(); - } - - Self { - details, - user_store, - rpc, - outgoing_messages_lock: Default::default(), - messages: Default::default(), - loaded_all_messages: false, - next_pending_message_id: 0, - rng: StdRng::from_entropy(), - _subscription, - } - } - - pub fn name(&self) -> &str { - &self.details.name - } - - pub fn send_message( - &mut self, - body: String, - cx: &mut ModelContext, - ) -> Result>> { - if body.is_empty() { - Err(anyhow!("message body can't be empty"))?; - } - - let current_user = self - .user_store - .read(cx) - .current_user() - .ok_or_else(|| anyhow!("current_user is not present"))?; - - let channel_id = self.details.id; - let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id)); - let nonce = self.rng.gen(); - self.insert_messages( - SumTree::from_item( - ChannelMessage { - id: pending_id, - body: body.clone(), - sender: current_user, - timestamp: OffsetDateTime::now_utc(), - nonce, - }, - &(), - ), - cx, - ); - let user_store = self.user_store.clone(); - let rpc = self.rpc.clone(); - let outgoing_messages_lock = self.outgoing_messages_lock.clone(); - Ok(cx.spawn(|this, mut cx| async move { - let outgoing_message_guard = outgoing_messages_lock.lock().await; - let request = rpc.request(proto::SendChannelMessage { - channel_id, - body, - nonce: Some(nonce.into()), - }); - let response = request.await?; - drop(outgoing_message_guard); - let message = ChannelMessage::from_proto( - response.message.ok_or_else(|| anyhow!("invalid message"))?, - &user_store, - &mut cx, - ) - .await?; - this.update(&mut cx, |this, cx| { - this.insert_messages(SumTree::from_item(message, &()), cx); - Ok(()) - }) - })) - } - - pub fn load_more_messages(&mut self, cx: &mut ModelContext) -> bool { - if !self.loaded_all_messages { - let rpc = self.rpc.clone(); - let user_store = self.user_store.clone(); - let channel_id = self.details.id; - if let Some(before_message_id) = - self.messages.first().and_then(|message| match message.id { - ChannelMessageId::Saved(id) => Some(id), - ChannelMessageId::Pending(_) => None, - }) - { - cx.spawn(|this, mut cx| { - async move { - let response = rpc - .request(proto::GetChannelMessages { - channel_id, - before_message_id, - }) - .await?; - let loaded_all_messages = response.done; - let messages = - messages_from_proto(response.messages, &user_store, &mut cx).await?; - this.update(&mut cx, |this, cx| { - this.loaded_all_messages = loaded_all_messages; - this.insert_messages(messages, cx); - }); - Ok(()) - } - .log_err() - }) - .detach(); - return true; - } - } - false - } - - pub fn rejoin(&mut self, cx: &mut ModelContext) { - let user_store = self.user_store.clone(); - let rpc = self.rpc.clone(); - let channel_id = self.details.id; - cx.spawn(|this, mut cx| { - async move { - let response = rpc.request(proto::JoinChannel { channel_id }).await?; - let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?; - let loaded_all_messages = response.done; - - let pending_messages = this.update(&mut cx, |this, cx| { - if let Some((first_new_message, last_old_message)) = - messages.first().zip(this.messages.last()) - { - if first_new_message.id > last_old_message.id { - let old_messages = mem::take(&mut this.messages); - cx.emit(ChannelEvent::MessagesUpdated { - old_range: 0..old_messages.summary().count, - new_count: 0, - }); - this.loaded_all_messages = loaded_all_messages; - } - } - - this.insert_messages(messages, cx); - if loaded_all_messages { - this.loaded_all_messages = loaded_all_messages; - } - - this.pending_messages().cloned().collect::>() - }); - - for pending_message in pending_messages { - let request = rpc.request(proto::SendChannelMessage { - channel_id, - body: pending_message.body, - nonce: Some(pending_message.nonce.into()), - }); - let response = request.await?; - let message = ChannelMessage::from_proto( - response.message.ok_or_else(|| anyhow!("invalid message"))?, - &user_store, - &mut cx, - ) - .await?; - this.update(&mut cx, |this, cx| { - this.insert_messages(SumTree::from_item(message, &()), cx); - }); - } - - Ok(()) - } - .log_err() - }) - .detach(); - } - - pub fn message_count(&self) -> usize { - self.messages.summary().count - } - - pub fn messages(&self) -> &SumTree { - &self.messages - } - - pub fn message(&self, ix: usize) -> &ChannelMessage { - let mut cursor = self.messages.cursor::(); - cursor.seek(&Count(ix), Bias::Right, &()); - cursor.item().unwrap() - } - - pub fn messages_in_range(&self, range: Range) -> impl Iterator { - let mut cursor = self.messages.cursor::(); - cursor.seek(&Count(range.start), Bias::Right, &()); - cursor.take(range.len()) - } - - pub fn pending_messages(&self) -> impl Iterator { - let mut cursor = self.messages.cursor::(); - cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &()); - cursor - } - - async fn handle_message_sent( - this: ModelHandle, - message: TypedEnvelope, - _: Arc, - mut cx: AsyncAppContext, - ) -> Result<()> { - let user_store = this.read_with(&cx, |this, _| this.user_store.clone()); - let message = message - .payload - .message - .ok_or_else(|| anyhow!("empty message"))?; - - let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?; - this.update(&mut cx, |this, cx| { - this.insert_messages(SumTree::from_item(message, &()), cx) - }); - - Ok(()) - } - - fn insert_messages(&mut self, messages: SumTree, cx: &mut ModelContext) { - if let Some((first_message, last_message)) = messages.first().zip(messages.last()) { - let nonces = messages - .cursor::<()>() - .map(|m| m.nonce) - .collect::>(); - - let mut old_cursor = self.messages.cursor::<(ChannelMessageId, Count)>(); - let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &()); - let start_ix = old_cursor.start().1 .0; - let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &()); - let removed_count = removed_messages.summary().count; - let new_count = messages.summary().count; - let end_ix = start_ix + removed_count; - - new_messages.push_tree(messages, &()); - - let mut ranges = Vec::>::new(); - if new_messages.last().unwrap().is_pending() { - new_messages.push_tree(old_cursor.suffix(&()), &()); - } else { - new_messages.push_tree( - old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()), - &(), - ); - - while let Some(message) = old_cursor.item() { - let message_ix = old_cursor.start().1 .0; - if nonces.contains(&message.nonce) { - if ranges.last().map_or(false, |r| r.end == message_ix) { - ranges.last_mut().unwrap().end += 1; - } else { - ranges.push(message_ix..message_ix + 1); - } - } else { - new_messages.push(message.clone(), &()); - } - old_cursor.next(&()); - } - } - - drop(old_cursor); - self.messages = new_messages; - - for range in ranges.into_iter().rev() { - cx.emit(ChannelEvent::MessagesUpdated { - old_range: range, - new_count: 0, - }); - } - cx.emit(ChannelEvent::MessagesUpdated { - old_range: start_ix..end_ix, - new_count, - }); - cx.notify(); - } - } -} - -async fn messages_from_proto( - proto_messages: Vec, - user_store: &ModelHandle, - cx: &mut AsyncAppContext, -) -> Result> { - let unique_user_ids = proto_messages - .iter() - .map(|m| m.sender_id) - .collect::>() - .into_iter() - .collect(); - user_store - .update(cx, |user_store, cx| { - user_store.get_users(unique_user_ids, cx) - }) - .await?; - - let mut messages = Vec::with_capacity(proto_messages.len()); - for message in proto_messages { - messages.push(ChannelMessage::from_proto(message, user_store, cx).await?); - } - let mut result = SumTree::new(); - result.extend(messages, &()); - Ok(result) -} - -impl From for ChannelDetails { - fn from(message: proto::Channel) -> Self { - Self { - id: message.id, - name: message.name, - } - } -} - -impl ChannelMessage { - pub async fn from_proto( - message: proto::ChannelMessage, - user_store: &ModelHandle, - cx: &mut AsyncAppContext, - ) -> Result { - let sender = user_store - .update(cx, |user_store, cx| { - user_store.get_user(message.sender_id, cx) - }) - .await?; - Ok(ChannelMessage { - id: ChannelMessageId::Saved(message.id), - body: message.body, - timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?, - sender, - nonce: message - .nonce - .ok_or_else(|| anyhow!("nonce is required"))? - .into(), - }) - } - - pub fn is_pending(&self) -> bool { - matches!(self.id, ChannelMessageId::Pending(_)) - } -} - -impl sum_tree::Item for ChannelMessage { - type Summary = ChannelMessageSummary; - - fn summary(&self) -> Self::Summary { - ChannelMessageSummary { - max_id: self.id, - count: 1, - } - } -} - -impl Default for ChannelMessageId { - fn default() -> Self { - Self::Saved(0) - } -} - -impl sum_tree::Summary for ChannelMessageSummary { - type Context = (); - - fn add_summary(&mut self, summary: &Self, _: &()) { - self.max_id = summary.max_id; - self.count += summary.count; - } -} - -impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId { - fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) { - debug_assert!(summary.max_id > *self); - *self = summary.max_id; - } -} - -impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count { - fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) { - self.0 += summary.count; - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test::{FakeHttpClient, FakeServer}; - use gpui::TestAppContext; - - #[gpui::test] - async fn test_channel_messages(cx: &mut TestAppContext) { - cx.foreground().forbid_parking(); - - let user_id = 5; - let http_client = FakeHttpClient::with_404_response(); - let client = cx.update(|cx| Client::new(http_client.clone(), cx)); - let server = FakeServer::for_client(user_id, &client, cx).await; - - Channel::init(&client); - let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx)); - - let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx)); - channel_list.read_with(cx, |list, _| assert_eq!(list.available_channels(), None)); - - // Get the available channels. - let get_channels = server.receive::().await.unwrap(); - server - .respond( - get_channels.receipt(), - proto::GetChannelsResponse { - channels: vec![proto::Channel { - id: 5, - name: "the-channel".to_string(), - }], - }, - ) - .await; - channel_list.next_notification(cx).await; - channel_list.read_with(cx, |list, _| { - assert_eq!( - list.available_channels().unwrap(), - &[ChannelDetails { - id: 5, - name: "the-channel".into(), - }] - ) - }); - - let get_users = server.receive::().await.unwrap(); - assert_eq!(get_users.payload.user_ids, vec![5]); - server - .respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 5, - github_login: "nathansobo".into(), - avatar_url: "http://avatar.com/nathansobo".into(), - }], - }, - ) - .await; - - // Join a channel and populate its existing messages. - let channel = channel_list - .update(cx, |list, cx| { - let channel_id = list.available_channels().unwrap()[0].id; - list.get_channel(channel_id, cx) - }) - .unwrap(); - channel.read_with(cx, |channel, _| assert!(channel.messages().is_empty())); - let join_channel = server.receive::().await.unwrap(); - server - .respond( - join_channel.receipt(), - proto::JoinChannelResponse { - messages: vec![ - proto::ChannelMessage { - id: 10, - body: "a".into(), - timestamp: 1000, - sender_id: 5, - nonce: Some(1.into()), - }, - proto::ChannelMessage { - id: 11, - body: "b".into(), - timestamp: 1001, - sender_id: 6, - nonce: Some(2.into()), - }, - ], - done: false, - }, - ) - .await; - - // Client requests all users for the received messages - let mut get_users = server.receive::().await.unwrap(); - get_users.payload.user_ids.sort(); - assert_eq!(get_users.payload.user_ids, vec![6]); - server - .respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 6, - github_login: "maxbrunsfeld".into(), - avatar_url: "http://avatar.com/maxbrunsfeld".into(), - }], - }, - ) - .await; - - assert_eq!( - channel.next_event(cx).await, - ChannelEvent::MessagesUpdated { - old_range: 0..0, - new_count: 2, - } - ); - channel.read_with(cx, |channel, _| { - assert_eq!( - channel - .messages_in_range(0..2) - .map(|message| (message.sender.github_login.clone(), message.body.clone())) - .collect::>(), - &[ - ("nathansobo".into(), "a".into()), - ("maxbrunsfeld".into(), "b".into()) - ] - ); - }); - - // Receive a new message. - server.send(proto::ChannelMessageSent { - channel_id: channel.read_with(cx, |channel, _| channel.details.id), - message: Some(proto::ChannelMessage { - id: 12, - body: "c".into(), - timestamp: 1002, - sender_id: 7, - nonce: Some(3.into()), - }), - }); - - // Client requests user for message since they haven't seen them yet - let get_users = server.receive::().await.unwrap(); - assert_eq!(get_users.payload.user_ids, vec![7]); - server - .respond( - get_users.receipt(), - proto::UsersResponse { - users: vec![proto::User { - id: 7, - github_login: "as-cii".into(), - avatar_url: "http://avatar.com/as-cii".into(), - }], - }, - ) - .await; - - assert_eq!( - channel.next_event(cx).await, - ChannelEvent::MessagesUpdated { - old_range: 2..2, - new_count: 1, - } - ); - channel.read_with(cx, |channel, _| { - assert_eq!( - channel - .messages_in_range(2..3) - .map(|message| (message.sender.github_login.clone(), message.body.clone())) - .collect::>(), - &[("as-cii".into(), "c".into())] - ) - }); - - // Scroll up to view older messages. - channel.update(cx, |channel, cx| { - assert!(channel.load_more_messages(cx)); - }); - let get_messages = server.receive::().await.unwrap(); - assert_eq!(get_messages.payload.channel_id, 5); - assert_eq!(get_messages.payload.before_message_id, 10); - server - .respond( - get_messages.receipt(), - proto::GetChannelMessagesResponse { - done: true, - messages: vec![ - proto::ChannelMessage { - id: 8, - body: "y".into(), - timestamp: 998, - sender_id: 5, - nonce: Some(4.into()), - }, - proto::ChannelMessage { - id: 9, - body: "z".into(), - timestamp: 999, - sender_id: 6, - nonce: Some(5.into()), - }, - ], - }, - ) - .await; - - assert_eq!( - channel.next_event(cx).await, - ChannelEvent::MessagesUpdated { - old_range: 0..0, - new_count: 2, - } - ); - channel.read_with(cx, |channel, _| { - assert_eq!( - channel - .messages_in_range(0..2) - .map(|message| (message.sender.github_login.clone(), message.body.clone())) - .collect::>(), - &[ - ("nathansobo".into(), "y".into()), - ("maxbrunsfeld".into(), "z".into()) - ] - ); - }); - } -} diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 587961f2a7..c943b27417 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -1,7 +1,6 @@ #[cfg(any(test, feature = "test-support"))] pub mod test; -pub mod channel; pub mod http; pub mod telemetry; pub mod user; @@ -44,7 +43,6 @@ use thiserror::Error; use url::Url; use util::{ResultExt, TryFutureExt}; -pub use channel::*; pub use rpc::*; pub use user::*; diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index aef8e5562b..63d2661de5 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -1,19 +1,14 @@ -CREATE TABLE IF NOT EXISTS "sessions" ( - "id" VARCHAR NOT NULL PRIMARY KEY, - "expires" TIMESTAMP WITH TIME ZONE NULL, - "session" TEXT NOT NULL -); - CREATE TABLE IF NOT EXISTS "users" ( - "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "id" INTEGER PRIMARY KEY, "github_login" VARCHAR, "admin" BOOLEAN, - email_address VARCHAR(255) DEFAULT NULL, - invite_code VARCHAR(64), - invite_count INTEGER NOT NULL DEFAULT 0, - inviter_id INTEGER REFERENCES users (id), - connected_once BOOLEAN NOT NULL DEFAULT false, - created_at TIMESTAMP NOT NULL DEFAULT now, + "email_address" VARCHAR(255) DEFAULT NULL, + "invite_code" VARCHAR(64), + "invite_count" INTEGER NOT NULL DEFAULT 0, + "inviter_id" INTEGER REFERENCES users (id), + "connected_once" BOOLEAN NOT NULL DEFAULT false, + "created_at" TIMESTAMP NOT NULL DEFAULT now, + "metrics_id" VARCHAR(255), "github_user_id" INTEGER ); CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login"); @@ -22,56 +17,14 @@ CREATE INDEX "index_users_on_email_address" ON "users" ("email_address"); CREATE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id"); CREATE TABLE IF NOT EXISTS "access_tokens" ( - "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "id" INTEGER PRIMARY KEY, "user_id" INTEGER REFERENCES users (id), "hash" VARCHAR(128) ); CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id"); -CREATE TABLE IF NOT EXISTS "orgs" ( - "id" SERIAL PRIMARY KEY, - "name" VARCHAR NOT NULL, - "slug" VARCHAR NOT NULL -); -CREATE UNIQUE INDEX "index_orgs_slug" ON "orgs" ("slug"); - -CREATE TABLE IF NOT EXISTS "org_memberships" ( - "id" SERIAL PRIMARY KEY, - "org_id" INTEGER REFERENCES orgs (id) NOT NULL, - "user_id" INTEGER REFERENCES users (id) NOT NULL, - "admin" BOOLEAN NOT NULL -); -CREATE INDEX "index_org_memberships_user_id" ON "org_memberships" ("user_id"); -CREATE UNIQUE INDEX "index_org_memberships_org_id_and_user_id" ON "org_memberships" ("org_id", "user_id"); - -CREATE TABLE IF NOT EXISTS "channels" ( - "id" SERIAL PRIMARY KEY, - "owner_id" INTEGER NOT NULL, - "owner_is_user" BOOLEAN NOT NULL, - "name" VARCHAR NOT NULL -); -CREATE UNIQUE INDEX "index_channels_owner_and_name" ON "channels" ("owner_is_user", "owner_id", "name"); - -CREATE TABLE IF NOT EXISTS "channel_memberships" ( - "id" SERIAL PRIMARY KEY, - "channel_id" INTEGER REFERENCES channels (id) NOT NULL, - "user_id" INTEGER REFERENCES users (id) NOT NULL, - "admin" BOOLEAN NOT NULL -); -CREATE INDEX "index_channel_memberships_user_id" ON "channel_memberships" ("user_id"); -CREATE UNIQUE INDEX "index_channel_memberships_channel_id_and_user_id" ON "channel_memberships" ("channel_id", "user_id"); - -CREATE TABLE IF NOT EXISTS "channel_messages" ( - "id" SERIAL PRIMARY KEY, - "channel_id" INTEGER REFERENCES channels (id) NOT NULL, - "sender_id" INTEGER REFERENCES users (id) NOT NULL, - "body" TEXT NOT NULL, - "sent_at" TIMESTAMP -); -CREATE INDEX "index_channel_messages_channel_id" ON "channel_messages" ("channel_id"); - CREATE TABLE IF NOT EXISTS "contacts" ( - "id" SERIAL PRIMARY KEY, + "id" INTEGER PRIMARY KEY, "user_id_a" INTEGER REFERENCES users (id) NOT NULL, "user_id_b" INTEGER REFERENCES users (id) NOT NULL, "a_to_b" BOOLEAN NOT NULL, @@ -82,46 +35,7 @@ CREATE UNIQUE INDEX "index_contacts_user_ids" ON "contacts" ("user_id_a", "user_ CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b"); CREATE TABLE IF NOT EXISTS "projects" ( - "id" SERIAL PRIMARY KEY, + "id" INTEGER PRIMARY KEY, "host_user_id" INTEGER REFERENCES users (id) NOT NULL, "unregistered" BOOLEAN NOT NULL DEFAULT false ); - -CREATE TABLE IF NOT EXISTS "worktree_extensions" ( - "id" SERIAL PRIMARY KEY, - "project_id" INTEGER REFERENCES projects (id) NOT NULL, - "worktree_id" INTEGER NOT NULL, - "extension" VARCHAR(255), - "count" INTEGER NOT NULL -); -CREATE UNIQUE INDEX "index_worktree_extensions_on_project_id_and_worktree_id_and_extension" ON "worktree_extensions" ("project_id", "worktree_id", "extension"); - -CREATE TABLE IF NOT EXISTS "project_activity_periods" ( - "id" SERIAL PRIMARY KEY, - "duration_millis" INTEGER NOT NULL, - "ended_at" TIMESTAMP NOT NULL, - "user_id" INTEGER REFERENCES users (id) NOT NULL, - "project_id" INTEGER REFERENCES projects (id) NOT NULL -); -CREATE INDEX "index_project_activity_periods_on_ended_at" ON "project_activity_periods" ("ended_at"); - -CREATE TABLE IF NOT EXISTS "signups" ( - "id" SERIAL PRIMARY KEY, - "email_address" VARCHAR NOT NULL, - "email_confirmation_code" VARCHAR(64) NOT NULL, - "email_confirmation_sent" BOOLEAN NOT NULL, - "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - "device_id" VARCHAR, - "user_id" INTEGER REFERENCES users (id) ON DELETE CASCADE, - "inviting_user_id" INTEGER REFERENCES users (id) ON DELETE SET NULL, - - "platform_mac" BOOLEAN NOT NULL, - "platform_linux" BOOLEAN NOT NULL, - "platform_windows" BOOLEAN NOT NULL, - "platform_unknown" BOOLEAN NOT NULL, - - "editor_features" VARCHAR[], - "programming_languages" VARCHAR[] -); -CREATE UNIQUE INDEX "index_signups_on_email_address" ON "signups" ("email_address"); -CREATE INDEX "index_signups_on_email_confirmation_sent" ON "signups" ("email_confirmation_sent"); diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index fbf45a3799..5fcdc5fcfd 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,6 +1,6 @@ use crate::{ auth, - db::{Invite, NewUserParams, ProjectId, Signup, User, UserId, WaitlistSummary}, + db::{Invite, NewUserParams, Signup, User, UserId, WaitlistSummary}, rpc::{self, ResultExt}, AppState, Error, Result, }; @@ -16,9 +16,7 @@ use axum::{ }; use axum_extra::response::ErasedJson; use serde::{Deserialize, Serialize}; -use serde_json::json; -use std::{sync::Arc, time::Duration}; -use time::OffsetDateTime; +use std::sync::Arc; use tower::ServiceBuilder; use tracing::instrument; @@ -32,16 +30,6 @@ pub fn routes(rpc_server: Arc, state: Arc) -> Router, - Extension(app): Extension>, -) -> Result { - let summary = app - .db - .get_top_users_activity_summary(params.start..params.end, 100) - .await?; - Ok(ErasedJson::pretty(summary)) -} - -async fn get_user_activity_timeline( - Path(user_id): Path, - Query(params): Query, - Extension(app): Extension>, -) -> Result { - let summary = app - .db - .get_user_activity_timeline(params.start..params.end, UserId(user_id)) - .await?; - Ok(ErasedJson::pretty(summary)) -} - -#[derive(Deserialize)] -struct ActiveUserCountParams { - #[serde(flatten)] - period: TimePeriodParams, - durations_in_minutes: String, - #[serde(default)] - only_collaborative: bool, -} - -#[derive(Serialize)] -struct ActiveUserSet { - active_time_in_minutes: u64, - user_count: usize, -} - -async fn get_active_user_counts( - Query(params): Query, - Extension(app): Extension>, -) -> Result { - let durations_in_minutes = params.durations_in_minutes.split(','); - let mut user_sets = Vec::new(); - for duration in durations_in_minutes { - let duration = duration - .parse() - .map_err(|_| anyhow!("invalid duration: {duration}"))?; - user_sets.push(ActiveUserSet { - active_time_in_minutes: duration, - user_count: app - .db - .get_active_user_count( - params.period.start..params.period.end, - Duration::from_secs(duration * 60), - params.only_collaborative, - ) - .await?, - }) - } - Ok(ErasedJson::pretty(user_sets)) -} - -#[derive(Deserialize)] -struct GetProjectMetadataParams { - project_id: u64, -} - -async fn get_project_metadata( - Query(params): Query, - Extension(app): Extension>, -) -> Result { - let extensions = app - .db - .get_project_extensions(ProjectId::from_proto(params.project_id)) - .await?; - Ok(ErasedJson::pretty(json!({ "extensions": extensions }))) -} - #[derive(Deserialize)] struct CreateAccessTokenQueryParams { public_key: String, diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index ab0833b38e..9f00c02918 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -7,9 +7,9 @@ use serde::{Deserialize, Serialize}; use sqlx::{ migrate::{Migrate as _, Migration, MigrationSource}, types::Uuid, - FromRow, QueryBuilder, + FromRow, }; -use std::{cmp, ops::Range, path::Path, time::Duration}; +use std::{path::Path, time::Duration}; use time::{OffsetDateTime, PrimitiveDateTime}; #[cfg(test)] @@ -58,8 +58,8 @@ impl RowsAffected for sqlx::postgres::PgQueryResult { } } +#[cfg(test)] impl Db { - #[cfg(test)] pub async fn new(url: &str, max_connections: u32) -> Result { let pool = sqlx::sqlite::SqlitePoolOptions::new() .max_connections(max_connections) @@ -70,6 +70,76 @@ impl Db { background: None, }) } + + pub async fn get_user_metrics_id(&self, id: UserId) -> Result { + test_support!(self, { + let query = " + SELECT metrics_id + FROM users + WHERE id = $1 + "; + Ok(sqlx::query_scalar(query) + .bind(id) + .fetch_one(&self.pool) + .await?) + }) + } + + pub async fn create_user( + &self, + email_address: &str, + admin: bool, + params: NewUserParams, + ) -> Result { + test_support!(self, { + let query = " + INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login + RETURNING id, metrics_id + "; + + let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) + .bind(email_address) + .bind(params.github_login) + .bind(params.github_user_id) + .bind(admin) + .bind(Uuid::new_v4().to_string()) + .fetch_one(&self.pool) + .await?; + Ok(NewUserResult { + user_id, + metrics_id, + signup_device_id: None, + inviting_user_id: None, + }) + }) + } + + pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result> { + unimplemented!() + } + + pub async fn create_user_from_invite( + &self, + _invite: &Invite, + _user: NewUserParams, + ) -> Result> { + unimplemented!() + } + + pub async fn create_signup(&self, _signup: Signup) -> Result<()> { + unimplemented!() + } + + pub async fn create_invite_from_code( + &self, + _code: &str, + _email_address: &str, + _device_id: Option<&str>, + ) -> Result { + unimplemented!() + } } impl Db { @@ -84,6 +154,302 @@ impl Db { background: None, }) } + + pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { + test_support!(self, { + let like_string = Self::fuzzy_like_string(name_query); + let query = " + SELECT users.* + FROM users + WHERE github_login ILIKE $1 + ORDER BY github_login <-> $2 + LIMIT $3 + "; + Ok(sqlx::query_as(query) + .bind(like_string) + .bind(name_query) + .bind(limit as i32) + .fetch_all(&self.pool) + .await?) + }) + } + + pub async fn get_user_metrics_id(&self, id: UserId) -> Result { + test_support!(self, { + let query = " + SELECT metrics_id::text + FROM users + WHERE id = $1 + "; + Ok(sqlx::query_scalar(query) + .bind(id) + .fetch_one(&self.pool) + .await?) + }) + } + + pub async fn create_user( + &self, + email_address: &str, + admin: bool, + params: NewUserParams, + ) -> Result { + test_support!(self, { + let query = " + INSERT INTO users (email_address, github_login, github_user_id, admin) + VALUES ($1, $2, $3, $4) + ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login + RETURNING id, metrics_id::text + "; + + let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) + .bind(email_address) + .bind(params.github_login) + .bind(params.github_user_id) + .bind(admin) + .fetch_one(&self.pool) + .await?; + Ok(NewUserResult { + user_id, + metrics_id, + signup_device_id: None, + inviting_user_id: None, + }) + }) + } + + pub async fn create_user_from_invite( + &self, + invite: &Invite, + user: NewUserParams, + ) -> Result> { + test_support!(self, { + let mut tx = self.pool.begin().await?; + + let (signup_id, existing_user_id, inviting_user_id, signup_device_id): ( + i32, + Option, + Option, + Option, + ) = sqlx::query_as( + " + SELECT id, user_id, inviting_user_id, device_id + FROM signups + WHERE + email_address = $1 AND + email_confirmation_code = $2 + ", + ) + .bind(&invite.email_address) + .bind(&invite.email_confirmation_code) + .fetch_optional(&mut tx) + .await? + .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; + + if existing_user_id.is_some() { + return Ok(None); + } + + let (user_id, metrics_id): (UserId, String) = sqlx::query_as( + " + INSERT INTO users + (email_address, github_login, github_user_id, admin, invite_count, invite_code) + VALUES + ($1, $2, $3, FALSE, $4, $5) + ON CONFLICT (github_login) DO UPDATE SET + email_address = excluded.email_address, + github_user_id = excluded.github_user_id, + admin = excluded.admin + RETURNING id, metrics_id::text + ", + ) + .bind(&invite.email_address) + .bind(&user.github_login) + .bind(&user.github_user_id) + .bind(&user.invite_count) + .bind(random_invite_code()) + .fetch_one(&mut tx) + .await?; + + sqlx::query( + " + UPDATE signups + SET user_id = $1 + WHERE id = $2 + ", + ) + .bind(&user_id) + .bind(&signup_id) + .execute(&mut tx) + .await?; + + if let Some(inviting_user_id) = inviting_user_id { + let id: Option = sqlx::query_scalar( + " + UPDATE users + SET invite_count = invite_count - 1 + WHERE id = $1 AND invite_count > 0 + RETURNING id + ", + ) + .bind(&inviting_user_id) + .fetch_optional(&mut tx) + .await?; + + if id.is_none() { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))?; + } + + sqlx::query( + " + INSERT INTO contacts + (user_id_a, user_id_b, a_to_b, should_notify, accepted) + VALUES + ($1, $2, TRUE, TRUE, TRUE) + ON CONFLICT DO NOTHING + ", + ) + .bind(inviting_user_id) + .bind(user_id) + .execute(&mut tx) + .await?; + } + + tx.commit().await?; + Ok(Some(NewUserResult { + user_id, + metrics_id, + inviting_user_id, + signup_device_id, + })) + }) + } + + pub async fn create_signup(&self, signup: Signup) -> Result<()> { + test_support!(self, { + sqlx::query( + " + INSERT INTO signups + ( + email_address, + email_confirmation_code, + email_confirmation_sent, + platform_linux, + platform_mac, + platform_windows, + platform_unknown, + editor_features, + programming_languages, + device_id + ) + VALUES + ($1, $2, FALSE, $3, $4, $5, FALSE, $6) + RETURNING id + ", + ) + .bind(&signup.email_address) + .bind(&random_email_confirmation_code()) + .bind(&signup.platform_linux) + .bind(&signup.platform_mac) + .bind(&signup.platform_windows) + .bind(&signup.editor_features) + .bind(&signup.programming_languages) + .bind(&signup.device_id) + .execute(&self.pool) + .await?; + Ok(()) + }) + } + + pub async fn create_invite_from_code( + &self, + code: &str, + email_address: &str, + device_id: Option<&str>, + ) -> Result { + test_support!(self, { + let mut tx = self.pool.begin().await?; + + let existing_user: Option = sqlx::query_scalar( + " + SELECT id + FROM users + WHERE email_address = $1 + ", + ) + .bind(email_address) + .fetch_optional(&mut tx) + .await?; + if existing_user.is_some() { + Err(anyhow!("email address is already in use"))?; + } + + let row: Option<(UserId, i32)> = sqlx::query_as( + " + SELECT id, invite_count + FROM users + WHERE invite_code = $1 + ", + ) + .bind(code) + .fetch_optional(&mut tx) + .await?; + + let (inviter_id, invite_count) = match row { + Some(row) => row, + None => Err(Error::Http( + StatusCode::NOT_FOUND, + "invite code not found".to_string(), + ))?, + }; + + if invite_count == 0 { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))?; + } + + let email_confirmation_code: String = sqlx::query_scalar( + " + INSERT INTO signups + ( + email_address, + email_confirmation_code, + email_confirmation_sent, + inviting_user_id, + platform_linux, + platform_mac, + platform_windows, + platform_unknown, + device_id + ) + VALUES + ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4) + ON CONFLICT (email_address) + DO UPDATE SET + inviting_user_id = excluded.inviting_user_id + RETURNING email_confirmation_code + ", + ) + .bind(&email_address) + .bind(&random_email_confirmation_code()) + .bind(&inviter_id) + .bind(&device_id) + .fetch_one(&mut tx) + .await?; + + tx.commit().await?; + + Ok(Invite { + email_address: email_address.into(), + email_confirmation_code, + }) + }) + } } impl Db @@ -172,36 +538,6 @@ where // users - pub async fn create_user( - &self, - email_address: &str, - admin: bool, - params: NewUserParams, - ) -> Result { - test_support!(self, { - let query = " - INSERT INTO users (email_address, github_login, github_user_id, admin) - VALUES ($1, $2, $3, $4) - -- ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login - RETURNING id, 'the-metrics-id' - "; - - let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) - .bind(email_address) - .bind(params.github_login) - .bind(params.github_user_id) - .bind(admin) - .fetch_one(&self.pool) - .await?; - Ok(NewUserResult { - user_id, - metrics_id, - signup_device_id: None, - inviting_user_id: None, - }) - }) - } - pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { test_support!(self, { let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2"; @@ -213,25 +549,6 @@ where }) } - pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { - test_support!(self, { - let like_string = Self::fuzzy_like_string(name_query); - let query = " - SELECT users.* - FROM users - WHERE github_login ILIKE $1 - ORDER BY github_login <-> $2 - LIMIT $3 - "; - Ok(sqlx::query_as(query) - .bind(like_string) - .bind(name_query) - .bind(limit as i32) - .fetch_all(&self.pool) - .await?) - }) - } - pub async fn get_user_by_id(&self, id: UserId) -> Result> { test_support!(self, { let query = " @@ -247,20 +564,6 @@ where }) } - pub async fn get_user_metrics_id(&self, id: UserId) -> Result { - test_support!(self, { - let query = " - SELECT metrics_id::text - FROM users - WHERE id = $1 - "; - Ok(sqlx::query_scalar(query) - .bind(id) - .fetch_one(&self.pool) - .await?) - }) - } - pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { test_support!(self, { let query = " @@ -389,42 +692,6 @@ where // signups - pub async fn create_signup(&self, signup: Signup) -> Result<()> { - test_support!(self, { - sqlx::query( - " - INSERT INTO signups - ( - email_address, - email_confirmation_code, - email_confirmation_sent, - platform_linux, - platform_mac, - platform_windows, - platform_unknown, - editor_features, - programming_languages, - device_id - ) - VALUES - ($1, $2, FALSE, $3, $4, $5, FALSE, $6) - RETURNING id - ", - ) - .bind(&signup.email_address) - .bind(&random_email_confirmation_code()) - .bind(&signup.platform_linux) - .bind(&signup.platform_mac) - .bind(&signup.platform_windows) - // .bind(&signup.editor_features) - // .bind(&signup.programming_languages) - .bind(&signup.device_id) - .execute(&self.pool) - .await?; - Ok(()) - }) - } - pub async fn get_waitlist_summary(&self) -> Result { test_support!(self, { Ok(sqlx::query_as( @@ -487,116 +754,6 @@ where }) } - pub async fn create_user_from_invite( - &self, - invite: &Invite, - user: NewUserParams, - ) -> Result> { - test_support!(self, { - let mut tx = self.pool.begin().await?; - - let (signup_id, existing_user_id, inviting_user_id, signup_device_id): ( - i32, - Option, - Option, - Option, - ) = sqlx::query_as( - " - SELECT id, user_id, inviting_user_id, device_id - FROM signups - WHERE - email_address = $1 AND - email_confirmation_code = $2 - ", - ) - .bind(&invite.email_address) - .bind(&invite.email_confirmation_code) - .fetch_optional(&mut tx) - .await? - .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; - - if existing_user_id.is_some() { - return Ok(None); - } - - let (user_id, metrics_id): (UserId, String) = sqlx::query_as( - " - INSERT INTO users - (email_address, github_login, github_user_id, admin, invite_count, invite_code) - VALUES - ($1, $2, $3, FALSE, $4, $5) - ON CONFLICT (github_login) DO UPDATE SET - email_address = excluded.email_address, - github_user_id = excluded.github_user_id, - admin = excluded.admin - RETURNING id, metrics_id::text - ", - ) - .bind(&invite.email_address) - .bind(&user.github_login) - .bind(&user.github_user_id) - .bind(&user.invite_count) - .bind(random_invite_code()) - .fetch_one(&mut tx) - .await?; - - sqlx::query( - " - UPDATE signups - SET user_id = $1 - WHERE id = $2 - ", - ) - .bind(&user_id) - .bind(&signup_id) - .execute(&mut tx) - .await?; - - if let Some(inviting_user_id) = inviting_user_id { - let id: Option = sqlx::query_scalar( - " - UPDATE users - SET invite_count = invite_count - 1 - WHERE id = $1 AND invite_count > 0 - RETURNING id - ", - ) - .bind(&inviting_user_id) - .fetch_optional(&mut tx) - .await?; - - if id.is_none() { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; - } - - sqlx::query( - " - INSERT INTO contacts - (user_id_a, user_id_b, a_to_b, should_notify, accepted) - VALUES - ($1, $2, TRUE, TRUE, TRUE) - ON CONFLICT DO NOTHING - ", - ) - .bind(inviting_user_id) - .bind(user_id) - .execute(&mut tx) - .await?; - } - - tx.commit().await?; - Ok(Some(NewUserResult { - user_id, - metrics_id, - inviting_user_id, - signup_device_id, - })) - }) - } - // invite codes pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { @@ -673,93 +830,6 @@ where }) } - pub async fn create_invite_from_code( - &self, - code: &str, - email_address: &str, - device_id: Option<&str>, - ) -> Result { - test_support!(self, { - let mut tx = self.pool.begin().await?; - - let existing_user: Option = sqlx::query_scalar( - " - SELECT id - FROM users - WHERE email_address = $1 - ", - ) - .bind(email_address) - .fetch_optional(&mut tx) - .await?; - if existing_user.is_some() { - Err(anyhow!("email address is already in use"))?; - } - - let row: Option<(UserId, i32)> = sqlx::query_as( - " - SELECT id, invite_count - FROM users - WHERE invite_code = $1 - ", - ) - .bind(code) - .fetch_optional(&mut tx) - .await?; - - let (inviter_id, invite_count) = match row { - Some(row) => row, - None => Err(Error::Http( - StatusCode::NOT_FOUND, - "invite code not found".to_string(), - ))?, - }; - - if invite_count == 0 { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; - } - - let email_confirmation_code: String = sqlx::query_scalar( - " - INSERT INTO signups - ( - email_address, - email_confirmation_code, - email_confirmation_sent, - inviting_user_id, - platform_linux, - platform_mac, - platform_windows, - platform_unknown, - device_id - ) - VALUES - ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4) - ON CONFLICT (email_address) - DO UPDATE SET - inviting_user_id = excluded.inviting_user_id - RETURNING email_confirmation_code - ", - ) - .bind(&email_address) - .bind(&random_email_confirmation_code()) - .bind(&inviter_id) - .bind(&device_id) - .fetch_one(&mut tx) - .await?; - - tx.commit().await?; - - Ok(Invite { - email_address: email_address.into(), - email_confirmation_code, - }) - }) - } - // projects /// Registers a new project for the given user. @@ -796,345 +866,6 @@ where }) } - /// Update file counts by extension for the given project and worktree. - pub async fn update_worktree_extensions( - &self, - project_id: ProjectId, - worktree_id: u64, - extensions: HashMap, - ) -> Result<()> { - test_support!(self, { - if extensions.is_empty() { - return Ok(()); - } - - let mut query = QueryBuilder::new( - "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)", - ); - query.push_values(extensions, |mut query, (extension, count)| { - query - .push_bind(project_id) - .push_bind(worktree_id as i32) - .push_bind(extension) - .push_bind(count as i32); - }); - query.push( - " - ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET - count = excluded.count - ", - ); - // query.build().execute(&self.pool).await?; - - Ok(()) - }) - } - - /// Get the file counts on the given project keyed by their worktree and extension. - pub async fn get_project_extensions( - &self, - project_id: ProjectId, - ) -> Result>> { - test_support!(self, { - #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] - struct WorktreeExtension { - worktree_id: i32, - extension: String, - count: i32, - } - - let query = " - SELECT worktree_id, extension, count - FROM worktree_extensions - WHERE project_id = $1 - "; - let counts = sqlx::query_as::<_, WorktreeExtension>(query) - .bind(&project_id) - .fetch_all(&self.pool) - .await?; - - let mut extension_counts = HashMap::default(); - for count in counts { - extension_counts - .entry(count.worktree_id as u64) - .or_insert_with(HashMap::default) - .insert(count.extension, count.count as usize); - } - Ok(extension_counts) - }) - } - - /// Record which users have been active in which projects during - /// a given period of time. - pub async fn record_user_activity( - &self, - time_period: Range, - projects: &[(UserId, ProjectId)], - ) -> Result<()> { - test_support!(self, { - let query = " - INSERT INTO project_activity_periods - (ended_at, duration_millis, user_id, project_id) - VALUES - ($1, $2, $3, $4); - "; - - let mut tx = self.pool.begin().await?; - let duration_millis = - ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32; - for (user_id, project_id) in projects { - sqlx::query(query) - .bind(time_period.end) - .bind(duration_millis) - .bind(user_id) - .bind(project_id) - .execute(&mut tx) - .await?; - } - tx.commit().await?; - - Ok(()) - }) - } - - /// Get the number of users who have been active in the given - /// time period for at least the given time duration. - pub async fn get_active_user_count( - &self, - time_period: Range, - min_duration: Duration, - only_collaborative: bool, - ) -> Result { - test_support!(self, { - let mut with_clause = String::new(); - with_clause.push_str("WITH\n"); - with_clause.push_str( - " - project_durations AS ( - SELECT user_id, project_id, SUM(duration_millis) AS project_duration - FROM project_activity_periods - WHERE $1 < ended_at AND ended_at <= $2 - GROUP BY user_id, project_id - ), - ", - ); - with_clause.push_str( - " - project_collaborators as ( - SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators - FROM project_durations - GROUP BY project_id - ), - ", - ); - - if only_collaborative { - with_clause.push_str( - " - user_durations AS ( - SELECT user_id, SUM(project_duration) as total_duration - FROM project_durations, project_collaborators - WHERE - project_durations.project_id = project_collaborators.project_id AND - max_collaborators > 1 - GROUP BY user_id - ORDER BY total_duration DESC - LIMIT $3 - ) - ", - ); - } else { - with_clause.push_str( - " - user_durations AS ( - SELECT user_id, SUM(project_duration) as total_duration - FROM project_durations - GROUP BY user_id - ORDER BY total_duration DESC - LIMIT $3 - ) - ", - ); - } - - let query = format!( - " - {with_clause} - SELECT count(user_durations.user_id) - FROM user_durations - WHERE user_durations.total_duration >= $3 - " - ); - - let count: i64 = sqlx::query_scalar(&query) - .bind(time_period.start) - .bind(time_period.end) - .bind(min_duration.as_millis() as i64) - .fetch_one(&self.pool) - .await?; - Ok(count as usize) - }) - } - - /// Get the users that have been most active during the given time period, - /// along with the amount of time they have been active in each project. - pub async fn get_top_users_activity_summary( - &self, - time_period: Range, - max_user_count: usize, - ) -> Result> { - test_support!(self, { - let query = " - WITH - project_durations AS ( - SELECT user_id, project_id, SUM(duration_millis) AS project_duration - FROM project_activity_periods - WHERE $1 < ended_at AND ended_at <= $2 - GROUP BY user_id, project_id - ), - user_durations AS ( - SELECT user_id, SUM(project_duration) as total_duration - FROM project_durations - GROUP BY user_id - ORDER BY total_duration DESC - LIMIT $3 - ), - project_collaborators as ( - SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators - FROM project_durations - GROUP BY project_id - ) - SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators - FROM user_durations, project_durations, project_collaborators, users - WHERE - user_durations.user_id = project_durations.user_id AND - user_durations.user_id = users.id AND - project_durations.project_id = project_collaborators.project_id - ORDER BY total_duration DESC, user_id ASC, project_id ASC - "; - - let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query) - .bind(time_period.start) - .bind(time_period.end) - .bind(max_user_count as i32) - .fetch(&self.pool); - - let mut result = Vec::::new(); - while let Some(row) = rows.next().await { - let (user_id, github_login, project_id, duration_millis, project_collaborators) = - row?; - let project_id = project_id; - let duration = Duration::from_millis(duration_millis as u64); - let project_activity = ProjectActivitySummary { - id: project_id, - duration, - max_collaborators: project_collaborators as usize, - }; - if let Some(last_summary) = result.last_mut() { - if last_summary.id == user_id { - last_summary.project_activity.push(project_activity); - continue; - } - } - result.push(UserActivitySummary { - id: user_id, - project_activity: vec![project_activity], - github_login, - }); - } - - Ok(result) - }) - } - - /// Get the project activity for the given user and time period. - pub async fn get_user_activity_timeline( - &self, - time_period: Range, - user_id: UserId, - ) -> Result> { - test_support!(self, { - const COALESCE_THRESHOLD: Duration = Duration::from_secs(30); - - let query = " - SELECT - project_activity_periods.ended_at, - project_activity_periods.duration_millis, - project_activity_periods.project_id, - worktree_extensions.extension, - worktree_extensions.count - FROM project_activity_periods - LEFT OUTER JOIN - worktree_extensions - ON - project_activity_periods.project_id = worktree_extensions.project_id - WHERE - project_activity_periods.user_id = $1 AND - $2 < project_activity_periods.ended_at AND - project_activity_periods.ended_at <= $3 - ORDER BY project_activity_periods.id ASC - "; - - let mut rows = sqlx::query_as::< - _, - ( - PrimitiveDateTime, - i32, - ProjectId, - Option, - Option, - ), - >(query) - .bind(user_id) - .bind(time_period.start) - .bind(time_period.end) - .fetch(&self.pool); - - let mut time_periods: HashMap> = Default::default(); - while let Some(row) = rows.next().await { - let (ended_at, duration_millis, project_id, extension, extension_count) = row?; - let ended_at = ended_at.assume_utc(); - let duration = Duration::from_millis(duration_millis as u64); - let started_at = ended_at - duration; - let project_time_periods = time_periods.entry(project_id).or_default(); - - if let Some(prev_duration) = project_time_periods.last_mut() { - if started_at <= prev_duration.end + COALESCE_THRESHOLD - && ended_at >= prev_duration.start - { - prev_duration.end = cmp::max(prev_duration.end, ended_at); - } else { - project_time_periods.push(UserActivityPeriod { - project_id, - start: started_at, - end: ended_at, - extensions: Default::default(), - }); - } - } else { - project_time_periods.push(UserActivityPeriod { - project_id, - start: started_at, - end: ended_at, - extensions: Default::default(), - }); - } - - if let Some((extension, extension_count)) = extension.zip(extension_count) { - project_time_periods - .last_mut() - .unwrap() - .extensions - .insert(extension, extension_count as usize); - } - } - - let mut durations = time_periods.into_values().flatten().collect::>(); - durations.sort_unstable_by_key(|duration| duration.start); - Ok(durations) - }) - } - // contacts pub async fn get_contacts(&self, user_id: UserId) -> Result> { @@ -1370,6 +1101,7 @@ where SELECT id from access_tokens WHERE user_id = $1 ORDER BY id DESC + LIMIT 10000 OFFSET $3 ) "; @@ -1404,222 +1136,6 @@ where .await?) }) } - - // orgs - - #[allow(unused)] // Help rust-analyzer - #[cfg(any(test, feature = "seed-support"))] - pub async fn find_org_by_slug(&self, slug: &str) -> Result> { - test_support!(self, { - let query = " - SELECT * - FROM orgs - WHERE slug = $1 - "; - Ok(sqlx::query_as(query) - .bind(slug) - .fetch_optional(&self.pool) - .await?) - }) - } - - #[cfg(any(test, feature = "seed-support"))] - pub async fn create_org(&self, name: &str, slug: &str) -> Result { - test_support!(self, { - let query = " - INSERT INTO orgs (name, slug) - VALUES ($1, $2) - RETURNING id - "; - Ok(sqlx::query_scalar(query) - .bind(name) - .bind(slug) - .fetch_one(&self.pool) - .await - .map(OrgId)?) - }) - } - - #[cfg(any(test, feature = "seed-support"))] - pub async fn add_org_member( - &self, - org_id: OrgId, - user_id: UserId, - is_admin: bool, - ) -> Result<()> { - test_support!(self, { - let query = " - INSERT INTO org_memberships (org_id, user_id, admin) - VALUES ($1, $2, $3) - ON CONFLICT DO NOTHING - "; - Ok(sqlx::query(query) - .bind(org_id.0) - .bind(user_id.0) - .bind(is_admin) - .execute(&self.pool) - .await - .map(drop)?) - }) - } - - // channels - - #[cfg(any(test, feature = "seed-support"))] - pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { - test_support!(self, { - let query = " - INSERT INTO channels (owner_id, owner_is_user, name) - VALUES ($1, false, $2) - RETURNING id - "; - Ok(sqlx::query_scalar(query) - .bind(org_id.0) - .bind(name) - .fetch_one(&self.pool) - .await - .map(ChannelId)?) - }) - } - - #[allow(unused)] // Help rust-analyzer - #[cfg(any(test, feature = "seed-support"))] - pub async fn get_org_channels(&self, org_id: OrgId) -> Result> { - test_support!(self, { - let query = " - SELECT * - FROM channels - WHERE - channels.owner_is_user = false AND - channels.owner_id = $1 - "; - Ok(sqlx::query_as(query) - .bind(org_id.0) - .fetch_all(&self.pool) - .await?) - }) - } - - pub async fn get_accessible_channels(&self, user_id: UserId) -> Result> { - test_support!(self, { - let query = " - SELECT - channels.* - FROM - channel_memberships, channels - WHERE - channel_memberships.user_id = $1 AND - channel_memberships.channel_id = channels.id - "; - Ok(sqlx::query_as(query) - .bind(user_id.0) - .fetch_all(&self.pool) - .await?) - }) - } - - pub async fn can_user_access_channel( - &self, - user_id: UserId, - channel_id: ChannelId, - ) -> Result { - test_support!(self, { - let query = " - SELECT id - FROM channel_memberships - WHERE user_id = $1 AND channel_id = $2 - LIMIT 1 - "; - Ok(sqlx::query_scalar::<_, i32>(query) - .bind(user_id.0) - .bind(channel_id.0) - .fetch_optional(&self.pool) - .await - .map(|e| e.is_some())?) - }) - } - - #[cfg(any(test, feature = "seed-support"))] - pub async fn add_channel_member( - &self, - channel_id: ChannelId, - user_id: UserId, - is_admin: bool, - ) -> Result<()> { - test_support!(self, { - let query = " - INSERT INTO channel_memberships (channel_id, user_id, admin) - VALUES ($1, $2, $3) - ON CONFLICT DO NOTHING - "; - Ok(sqlx::query(query) - .bind(channel_id.0) - .bind(user_id.0) - .bind(is_admin) - .execute(&self.pool) - .await - .map(drop)?) - }) - } - - // messages - - pub async fn create_channel_message( - &self, - channel_id: ChannelId, - sender_id: UserId, - body: &str, - timestamp: OffsetDateTime, - nonce: u128, - ) -> Result { - test_support!(self, { - let query = " - INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce - RETURNING id - "; - Ok(sqlx::query_scalar(query) - .bind(channel_id.0) - .bind(sender_id.0) - .bind(body) - .bind(timestamp) - .bind(Uuid::from_u128(nonce)) - .fetch_one(&self.pool) - .await - .map(MessageId)?) - }) - } - - pub async fn get_channel_messages( - &self, - channel_id: ChannelId, - count: usize, - before_id: Option, - ) -> Result> { - test_support!(self, { - let query = r#" - SELECT * FROM ( - SELECT - id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce - FROM - channel_messages - WHERE - channel_id = $1 AND - id < $2 - ORDER BY id DESC - LIMIT $3 - ) as recent_messages - ORDER BY id ASC - "#; - Ok(sqlx::query_as(query) - .bind(channel_id.0) - .bind(before_id.unwrap_or(MessageId::MAX)) - .bind(count as i64) - .fetch_all(&self.pool) - .await?) - }) - } } macro_rules! id_type { @@ -1686,58 +1202,6 @@ pub struct Project { pub unregistered: bool, } -#[derive(Clone, Debug, PartialEq, Serialize)] -pub struct UserActivitySummary { - pub id: UserId, - pub github_login: String, - pub project_activity: Vec, -} - -#[derive(Clone, Debug, PartialEq, Serialize)] -pub struct ProjectActivitySummary { - pub id: ProjectId, - pub duration: Duration, - pub max_collaborators: usize, -} - -#[derive(Clone, Debug, PartialEq, Serialize)] -pub struct UserActivityPeriod { - pub project_id: ProjectId, - #[serde(with = "time::serde::iso8601")] - pub start: OffsetDateTime, - #[serde(with = "time::serde::iso8601")] - pub end: OffsetDateTime, - pub extensions: HashMap, -} - -id_type!(OrgId); -#[derive(FromRow)] -pub struct Org { - pub id: OrgId, - pub name: String, - pub slug: String, -} - -id_type!(ChannelId); -#[derive(Clone, Debug, FromRow, Serialize)] -pub struct Channel { - pub id: ChannelId, - pub name: String, - pub owner_id: i32, - pub owner_is_user: bool, -} - -id_type!(MessageId); -#[derive(Clone, Debug, FromRow)] -pub struct ChannelMessage { - pub id: MessageId, - pub channel_id: ChannelId, - pub sender_id: UserId, - pub body: String, - pub sent_at: OffsetDateTime, - pub nonce: Uuid, -} - #[derive(Clone, Debug, PartialEq, Eq)] pub enum Contact { Accepted { @@ -1840,14 +1304,22 @@ mod test { } impl TestDb { - pub async fn new(background: Arc) -> Self { + pub fn new(background: Arc) -> Self { let mut rng = StdRng::from_entropy(); let url = format!("/tmp/zed-test-{}", rng.gen::()); - sqlx::Sqlite::create_database(&url).await.unwrap(); - let mut db = DefaultDb::new(&url, 5).await.unwrap(); - db.background = Some(background); - let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"); - db.migrate(Path::new(migrations_path), false).await.unwrap(); + let db = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap() + .block_on(async { + sqlx::Sqlite::create_database(&url).await.unwrap(); + let mut db = DefaultDb::new(&url, 5).await.unwrap(); + db.background = Some(background); + let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"); + db.migrate(Path::new(migrations_path), false).await.unwrap(); + db + }); Self { db: Some(Arc::new(db)), url, diff --git a/crates/collab/src/db_tests.rs b/crates/collab/src/db_tests.rs index ecedc9973d..b6a785e9f1 100644 --- a/crates/collab/src/db_tests.rs +++ b/crates/collab/src/db_tests.rs @@ -1,12 +1,10 @@ use super::db::*; -use collections::HashMap; use gpui::executor::{Background, Deterministic}; -use std::{sync::Arc, time::Duration}; -use time::OffsetDateTime; +use std::sync::Arc; -#[tokio::test(flavor = "multi_thread")] +#[gpui::test] async fn test_get_users_by_ids() { - let test_db = TestDb::new(build_background_executor()).await; + let test_db = TestDb::new(build_background_executor()); let db = test_db.db(); let mut user_ids = Vec::new(); @@ -66,9 +64,9 @@ async fn test_get_users_by_ids() { ); } -#[tokio::test(flavor = "multi_thread")] +#[gpui::test] async fn test_get_user_by_github_account() { - let test_db = TestDb::new(build_background_executor()).await; + let test_db = TestDb::new(build_background_executor()); let db = test_db.db(); let user_id1 = db .create_user( @@ -122,407 +120,9 @@ async fn test_get_user_by_github_account() { assert_eq!(user.github_user_id, Some(102)); } -#[tokio::test(flavor = "multi_thread")] -async fn test_worktree_extensions() { - let test_db = TestDb::new(build_background_executor()).await; - let db = test_db.db(); - - let user = db - .create_user( - "u1@example.com", - false, - NewUserParams { - github_login: "u1".into(), - github_user_id: 0, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let project = db.register_project(user).await.unwrap(); - - db.update_worktree_extensions(project, 100, Default::default()) - .await - .unwrap(); - db.update_worktree_extensions( - project, - 100, - [("rs".to_string(), 5), ("md".to_string(), 3)] - .into_iter() - .collect(), - ) - .await - .unwrap(); - db.update_worktree_extensions( - project, - 100, - [("rs".to_string(), 6), ("md".to_string(), 5)] - .into_iter() - .collect(), - ) - .await - .unwrap(); - db.update_worktree_extensions( - project, - 101, - [("ts".to_string(), 2), ("md".to_string(), 1)] - .into_iter() - .collect(), - ) - .await - .unwrap(); - - assert_eq!( - db.get_project_extensions(project).await.unwrap(), - [ - ( - 100, - [("rs".into(), 6), ("md".into(), 5),] - .into_iter() - .collect::>() - ), - ( - 101, - [("ts".into(), 2), ("md".into(), 1),] - .into_iter() - .collect::>() - ) - ] - .into_iter() - .collect() - ); -} - -#[tokio::test(flavor = "multi_thread")] -async fn test_user_activity() { - let test_db = TestDb::new(build_background_executor()).await; - let db = test_db.db(); - - let mut user_ids = Vec::new(); - for i in 0..=2 { - user_ids.push( - db.create_user( - &format!("user{i}@example.com"), - false, - NewUserParams { - github_login: format!("user{i}"), - github_user_id: i, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id, - ); - } - - let project_1 = db.register_project(user_ids[0]).await.unwrap(); - db.update_worktree_extensions( - project_1, - 1, - HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]), - ) - .await - .unwrap(); - let project_2 = db.register_project(user_ids[1]).await.unwrap(); - let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60); - - // User 2 opens a project - let t1 = t0 + Duration::from_secs(10); - db.record_user_activity(t0..t1, &[(user_ids[1], project_2)]) - .await - .unwrap(); - - let t2 = t1 + Duration::from_secs(10); - db.record_user_activity(t1..t2, &[(user_ids[1], project_2)]) - .await - .unwrap(); - - // User 1 joins the project - let t3 = t2 + Duration::from_secs(10); - db.record_user_activity( - t2..t3, - &[(user_ids[1], project_2), (user_ids[0], project_2)], - ) - .await - .unwrap(); - - // User 1 opens another project - let t4 = t3 + Duration::from_secs(10); - db.record_user_activity( - t3..t4, - &[ - (user_ids[1], project_2), - (user_ids[0], project_2), - (user_ids[0], project_1), - ], - ) - .await - .unwrap(); - - // User 3 joins that project - let t5 = t4 + Duration::from_secs(10); - db.record_user_activity( - t4..t5, - &[ - (user_ids[1], project_2), - (user_ids[0], project_2), - (user_ids[0], project_1), - (user_ids[2], project_1), - ], - ) - .await - .unwrap(); - - // User 2 leaves - let t6 = t5 + Duration::from_secs(5); - db.record_user_activity( - t5..t6, - &[(user_ids[0], project_1), (user_ids[2], project_1)], - ) - .await - .unwrap(); - - let t7 = t6 + Duration::from_secs(60); - let t8 = t7 + Duration::from_secs(10); - db.record_user_activity(t7..t8, &[(user_ids[0], project_1)]) - .await - .unwrap(); - - assert_eq!( - db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(), - &[ - UserActivitySummary { - id: user_ids[0], - github_login: "user0".to_string(), - project_activity: vec![ - ProjectActivitySummary { - id: project_1, - duration: Duration::from_secs(25), - max_collaborators: 2 - }, - ProjectActivitySummary { - id: project_2, - duration: Duration::from_secs(30), - max_collaborators: 2 - } - ] - }, - UserActivitySummary { - id: user_ids[1], - github_login: "user1".to_string(), - project_activity: vec![ProjectActivitySummary { - id: project_2, - duration: Duration::from_secs(50), - max_collaborators: 2 - }] - }, - UserActivitySummary { - id: user_ids[2], - github_login: "user2".to_string(), - project_activity: vec![ProjectActivitySummary { - id: project_1, - duration: Duration::from_secs(15), - max_collaborators: 2 - }] - }, - ] - ); - - assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(56), false) - .await - .unwrap(), - 0 - ); - assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(56), true) - .await - .unwrap(), - 0 - ); - assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(54), false) - .await - .unwrap(), - 1 - ); - assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(54), true) - .await - .unwrap(), - 1 - ); - assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(30), false) - .await - .unwrap(), - 2 - ); - assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(30), true) - .await - .unwrap(), - 2 - ); - assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(10), false) - .await - .unwrap(), - 3 - ); - assert_eq!( - db.get_active_user_count(t0..t6, Duration::from_secs(10), true) - .await - .unwrap(), - 3 - ); - assert_eq!( - db.get_active_user_count(t0..t1, Duration::from_secs(5), false) - .await - .unwrap(), - 1 - ); - assert_eq!( - db.get_active_user_count(t0..t1, Duration::from_secs(5), true) - .await - .unwrap(), - 0 - ); - - assert_eq!( - db.get_user_activity_timeline(t3..t6, user_ids[0]) - .await - .unwrap(), - &[ - UserActivityPeriod { - project_id: project_1, - start: t3, - end: t6, - extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]), - }, - UserActivityPeriod { - project_id: project_2, - start: t3, - end: t5, - extensions: Default::default(), - }, - ] - ); - assert_eq!( - db.get_user_activity_timeline(t0..t8, user_ids[0]) - .await - .unwrap(), - &[ - UserActivityPeriod { - project_id: project_2, - start: t2, - end: t5, - extensions: Default::default(), - }, - UserActivityPeriod { - project_id: project_1, - start: t3, - end: t6, - extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]), - }, - UserActivityPeriod { - project_id: project_1, - start: t7, - end: t8, - extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]), - }, - ] - ); -} - -#[tokio::test(flavor = "multi_thread")] -async fn test_recent_channel_messages() { - let test_db = TestDb::new(build_background_executor()).await; - let db = test_db.db(); - let user = db - .create_user( - "u@example.com", - false, - NewUserParams { - github_login: "u".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let org = db.create_org("org", "org").await.unwrap(); - let channel = db.create_org_channel(org, "channel").await.unwrap(); - for i in 0..10 { - db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i) - .await - .unwrap(); - } - - let messages = db.get_channel_messages(channel, 5, None).await.unwrap(); - assert_eq!( - messages.iter().map(|m| &m.body).collect::>(), - ["5", "6", "7", "8", "9"] - ); - - let prev_messages = db - .get_channel_messages(channel, 4, Some(messages[0].id)) - .await - .unwrap(); - assert_eq!( - prev_messages.iter().map(|m| &m.body).collect::>(), - ["1", "2", "3", "4"] - ); -} - -#[tokio::test(flavor = "multi_thread")] -async fn test_channel_message_nonces() { - let test_db = TestDb::new(build_background_executor()).await; - let db = test_db.db(); - let user = db - .create_user( - "user@example.com", - false, - NewUserParams { - github_login: "user".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let org = db.create_org("org", "org").await.unwrap(); - let channel = db.create_org_channel(org, "channel").await.unwrap(); - - let msg1_id = db - .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) - .await - .unwrap(); - let msg2_id = db - .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); - let msg3_id = db - .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) - .await - .unwrap(); - let msg4_id = db - .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); - - assert_ne!(msg1_id, msg2_id); - assert_eq!(msg1_id, msg3_id); - assert_eq!(msg2_id, msg4_id); -} - -#[tokio::test(flavor = "multi_thread")] +#[gpui::test] async fn test_create_access_tokens() { - let test_db = TestDb::new(build_background_executor()).await; + let test_db = TestDb::new(build_background_executor()); let db = test_db.db(); let user = db .create_user( @@ -571,9 +171,9 @@ fn test_fuzzy_like_string() { assert_eq!(DefaultDb::fuzzy_like_string(" z "), "%z%"); } -#[tokio::test(flavor = "multi_thread")] +#[gpui::test] async fn test_fuzzy_search_users() { - let test_db = TestDb::new(build_background_executor()).await; + let test_db = TestDb::new(build_background_executor()); let db = test_db.db(); for (i, github_login) in [ "California", @@ -619,9 +219,9 @@ async fn test_fuzzy_search_users() { } } -#[tokio::test(flavor = "multi_thread")] +#[gpui::test] async fn test_add_contacts() { - let test_db = TestDb::new(build_background_executor()).await; + let test_db = TestDb::new(build_background_executor()); let db = test_db.db(); let mut user_ids = Vec::new(); @@ -783,9 +383,9 @@ async fn test_add_contacts() { ); } -#[tokio::test(flavor = "multi_thread")] +#[gpui::test] async fn test_invite_codes() { - let test_db = TestDb::new(build_background_executor()).await; + let test_db = TestDb::new(build_background_executor()); let db = test_db.db(); let NewUserResult { user_id: user1, .. } = db .create_user( @@ -978,9 +578,9 @@ async fn test_invite_codes() { assert_eq!(invite_count, 1); } -#[tokio::test(flavor = "multi_thread")] +#[gpui::test] async fn test_signups() { - let test_db = TestDb::new(build_background_executor()).await; + let test_db = TestDb::new(build_background_executor()); let db = test_db.db(); // people sign up on the waitlist @@ -1124,9 +724,9 @@ async fn test_signups() { .unwrap_err(); } -#[tokio::test(flavor = "multi_thread")] +#[gpui::test] async fn test_metrics_id() { - let test_db = TestDb::new(build_background_executor()).await; + let test_db = TestDb::new(build_background_executor()); let db = test_db.db(); let NewUserResult { diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index ef51ff7152..906424f9c9 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -1,14 +1,14 @@ use crate::{ db::{NewUserParams, ProjectId, TestDb, UserId}, - rpc::{Executor, Server, Store}, + rpc::{Executor, Server}, AppState, }; use ::rpc::Peer; use anyhow::anyhow; use call::{room, ActiveCall, ParticipantLocation, Room}; use client::{ - self, test::FakeHttpClient, Channel, ChannelDetails, ChannelList, Client, Connection, - Credentials, EstablishConnectionError, PeerId, User, UserStore, RECEIVE_TIMEOUT, + self, test::FakeHttpClient, Client, Connection, Credentials, EstablishConnectionError, PeerId, + User, UserStore, RECEIVE_TIMEOUT, }; use collections::{BTreeMap, HashMap, HashSet}; use editor::{ @@ -16,10 +16,7 @@ use editor::{ ToggleCodeActions, Undo, }; use fs::{FakeFs, Fs as _, HomeDir, LineEnding}; -use futures::{ - channel::{mpsc, oneshot}, - Future, StreamExt as _, -}; +use futures::{channel::oneshot, Future, StreamExt as _}; use gpui::{ executor::{self, Deterministic}, geometry::vector::vec2f, @@ -39,7 +36,6 @@ use project::{ use rand::prelude::*; use serde_json::json; use settings::{Formatter, Settings}; -use sqlx::types::time::OffsetDateTime; use std::{ cell::{Cell, RefCell}, env, mem, @@ -72,11 +68,8 @@ async fn test_basic_calls( cx_b2: &mut TestAppContext, cx_c: &mut TestAppContext, ) { - // let runtime = tokio::runtime::Runtime::new().unwrap(); - // let _enter_guard = runtime.enter(); - deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let start = std::time::Instant::now(); @@ -279,7 +272,7 @@ async fn test_room_uniqueness( cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let _client_a2 = server.create_client(cx_a2, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; @@ -384,7 +377,7 @@ async fn test_leaving_room_on_disconnection( cx_b: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -513,7 +506,7 @@ async fn test_calls_on_multiple_connections( cx_b2: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b1 = server.create_client(cx_b1, "user_b").await; let client_b2 = server.create_client(cx_b2, "user_b").await; @@ -662,7 +655,7 @@ async fn test_share_project( ) { deterministic.forbid_parking(); let (_, window_b) = cx_b.add_window(|_| EmptyView); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -799,7 +792,7 @@ async fn test_unshare_project( cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -882,7 +875,7 @@ async fn test_host_disconnect( ) { cx_b.update(editor::init); deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -987,7 +980,7 @@ async fn test_active_call_events( cx_b: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; client_a.fs.insert_tree("/a", json!({})).await; @@ -1076,7 +1069,7 @@ async fn test_room_location( cx_b: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; client_a.fs.insert_tree("/a", json!({})).await; @@ -1242,7 +1235,7 @@ async fn test_propagate_saves_and_fs_changes( cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -1417,7 +1410,7 @@ async fn test_git_diff_base_change( cx_b: &mut TestAppContext, ) { executor.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -1669,7 +1662,7 @@ async fn test_fs_operations( cx_b: &mut TestAppContext, ) { executor.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -1935,7 +1928,7 @@ async fn test_fs_operations( #[gpui::test(iterations = 10)] async fn test_buffer_conflict_after_save(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -1989,7 +1982,7 @@ async fn test_buffer_conflict_after_save(cx_a: &mut TestAppContext, cx_b: &mut T #[gpui::test(iterations = 10)] async fn test_buffer_reloading(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2048,7 +2041,7 @@ async fn test_editing_while_guest_opens_buffer( cx_b: &mut TestAppContext, ) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2095,7 +2088,7 @@ async fn test_leaving_worktree_while_opening_buffer( cx_b: &mut TestAppContext, ) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2140,7 +2133,7 @@ async fn test_canceling_buffer_opening( ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2191,7 +2184,7 @@ async fn test_leaving_project( cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -2324,7 +2317,7 @@ async fn test_collaborating_with_diagnostics( cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -2589,7 +2582,7 @@ async fn test_collaborating_with_diagnostics( #[gpui::test(iterations = 10)] async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2763,7 +2756,7 @@ async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mu #[gpui::test(iterations = 10)] async fn test_reloading_buffer_manually(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2856,7 +2849,7 @@ async fn test_reloading_buffer_manually(cx_a: &mut TestAppContext, cx_b: &mut Te async fn test_formatting_buffer(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { use project::FormatTrigger; - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -2957,7 +2950,7 @@ async fn test_formatting_buffer(cx_a: &mut TestAppContext, cx_b: &mut TestAppCon #[gpui::test(iterations = 10)] async fn test_definition(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3101,7 +3094,7 @@ async fn test_definition(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { #[gpui::test(iterations = 10)] async fn test_references(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3202,7 +3195,7 @@ async fn test_references(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { #[gpui::test(iterations = 10)] async fn test_project_search(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3281,7 +3274,7 @@ async fn test_project_search(cx_a: &mut TestAppContext, cx_b: &mut TestAppContex #[gpui::test(iterations = 10)] async fn test_document_highlights(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3383,7 +3376,7 @@ async fn test_document_highlights(cx_a: &mut TestAppContext, cx_b: &mut TestAppC #[gpui::test(iterations = 10)] async fn test_lsp_hover(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3486,7 +3479,7 @@ async fn test_lsp_hover(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { #[gpui::test(iterations = 10)] async fn test_project_symbols(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3594,7 +3587,7 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it( mut rng: StdRng, ) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3670,7 +3663,7 @@ async fn test_collaborating_with_code_actions( ) { cx_a.foreground().forbid_parking(); cx_b.update(editor::init); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -3881,7 +3874,7 @@ async fn test_collaborating_with_code_actions( async fn test_collaborating_with_renames(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { cx_a.foreground().forbid_parking(); cx_b.update(editor::init); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -4073,7 +4066,7 @@ async fn test_language_server_statuses( deterministic.forbid_parking(); cx_b.update(editor::init); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -4177,415 +4170,6 @@ async fn test_language_server_statuses( }); } -#[gpui::test(iterations = 10)] -async fn test_basic_chat(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { - cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; - let client_a = server.create_client(cx_a, "user_a").await; - let client_b = server.create_client(cx_b, "user_b").await; - - // Create an org that includes these 2 users. - let db = &server.app_state.db; - let org_id = db.create_org("Test Org", "test-org").await.unwrap(); - db.add_org_member(org_id, client_a.current_user_id(cx_a), false) - .await - .unwrap(); - db.add_org_member(org_id, client_b.current_user_id(cx_b), false) - .await - .unwrap(); - - // Create a channel that includes all the users. - let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); - db.add_channel_member(channel_id, client_a.current_user_id(cx_a), false) - .await - .unwrap(); - db.add_channel_member(channel_id, client_b.current_user_id(cx_b), false) - .await - .unwrap(); - db.create_channel_message( - channel_id, - client_b.current_user_id(cx_b), - "hello A, it's B.", - OffsetDateTime::now_utc(), - 1, - ) - .await - .unwrap(); - - let channels_a = - cx_a.add_model(|cx| ChannelList::new(client_a.user_store.clone(), client_a.clone(), cx)); - channels_a - .condition(cx_a, |list, _| list.available_channels().is_some()) - .await; - channels_a.read_with(cx_a, |list, _| { - assert_eq!( - list.available_channels().unwrap(), - &[ChannelDetails { - id: channel_id.to_proto(), - name: "test-channel".to_string() - }] - ) - }); - let channel_a = channels_a.update(cx_a, |this, cx| { - this.get_channel(channel_id.to_proto(), cx).unwrap() - }); - channel_a.read_with(cx_a, |channel, _| assert!(channel.messages().is_empty())); - channel_a - .condition(cx_a, |channel, _| { - channel_messages(channel) - == [("user_b".to_string(), "hello A, it's B.".to_string(), false)] - }) - .await; - - let channels_b = - cx_b.add_model(|cx| ChannelList::new(client_b.user_store.clone(), client_b.clone(), cx)); - channels_b - .condition(cx_b, |list, _| list.available_channels().is_some()) - .await; - channels_b.read_with(cx_b, |list, _| { - assert_eq!( - list.available_channels().unwrap(), - &[ChannelDetails { - id: channel_id.to_proto(), - name: "test-channel".to_string() - }] - ) - }); - - let channel_b = channels_b.update(cx_b, |this, cx| { - this.get_channel(channel_id.to_proto(), cx).unwrap() - }); - channel_b.read_with(cx_b, |channel, _| assert!(channel.messages().is_empty())); - channel_b - .condition(cx_b, |channel, _| { - channel_messages(channel) - == [("user_b".to_string(), "hello A, it's B.".to_string(), false)] - }) - .await; - - channel_a - .update(cx_a, |channel, cx| { - channel - .send_message("oh, hi B.".to_string(), cx) - .unwrap() - .detach(); - let task = channel.send_message("sup".to_string(), cx).unwrap(); - assert_eq!( - channel_messages(channel), - &[ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_a".to_string(), "oh, hi B.".to_string(), true), - ("user_a".to_string(), "sup".to_string(), true) - ] - ); - task - }) - .await - .unwrap(); - - channel_b - .condition(cx_b, |channel, _| { - channel_messages(channel) - == [ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_a".to_string(), "oh, hi B.".to_string(), false), - ("user_a".to_string(), "sup".to_string(), false), - ] - }) - .await; - - assert_eq!( - server - .store() - .await - .channel(channel_id) - .unwrap() - .connection_ids - .len(), - 2 - ); - cx_b.update(|_| drop(channel_b)); - server - .condition(|state| state.channel(channel_id).unwrap().connection_ids.len() == 1) - .await; - - cx_a.update(|_| drop(channel_a)); - server - .condition(|state| state.channel(channel_id).is_none()) - .await; -} - -#[gpui::test(iterations = 10)] -async fn test_chat_message_validation(cx_a: &mut TestAppContext) { - cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; - let client_a = server.create_client(cx_a, "user_a").await; - - let db = &server.app_state.db; - let org_id = db.create_org("Test Org", "test-org").await.unwrap(); - let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); - db.add_org_member(org_id, client_a.current_user_id(cx_a), false) - .await - .unwrap(); - db.add_channel_member(channel_id, client_a.current_user_id(cx_a), false) - .await - .unwrap(); - - let channels_a = - cx_a.add_model(|cx| ChannelList::new(client_a.user_store.clone(), client_a.clone(), cx)); - channels_a - .condition(cx_a, |list, _| list.available_channels().is_some()) - .await; - let channel_a = channels_a.update(cx_a, |this, cx| { - this.get_channel(channel_id.to_proto(), cx).unwrap() - }); - - // Messages aren't allowed to be too long. - channel_a - .update(cx_a, |channel, cx| { - let long_body = "this is long.\n".repeat(1024); - channel.send_message(long_body, cx).unwrap() - }) - .await - .unwrap_err(); - - // Messages aren't allowed to be blank. - channel_a.update(cx_a, |channel, cx| { - channel.send_message(String::new(), cx).unwrap_err() - }); - - // Leading and trailing whitespace are trimmed. - channel_a - .update(cx_a, |channel, cx| { - channel - .send_message("\n surrounded by whitespace \n".to_string(), cx) - .unwrap() - }) - .await - .unwrap(); - assert_eq!( - db.get_channel_messages(channel_id, 10, None) - .await - .unwrap() - .iter() - .map(|m| &m.body) - .collect::>(), - &["surrounded by whitespace"] - ); -} - -#[gpui::test(iterations = 10)] -async fn test_chat_reconnection(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) { - cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; - let client_a = server.create_client(cx_a, "user_a").await; - let client_b = server.create_client(cx_b, "user_b").await; - - let mut status_b = client_b.status(); - - // Create an org that includes these 2 users. - let db = &server.app_state.db; - let org_id = db.create_org("Test Org", "test-org").await.unwrap(); - db.add_org_member(org_id, client_a.current_user_id(cx_a), false) - .await - .unwrap(); - db.add_org_member(org_id, client_b.current_user_id(cx_b), false) - .await - .unwrap(); - - // Create a channel that includes all the users. - let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); - db.add_channel_member(channel_id, client_a.current_user_id(cx_a), false) - .await - .unwrap(); - db.add_channel_member(channel_id, client_b.current_user_id(cx_b), false) - .await - .unwrap(); - db.create_channel_message( - channel_id, - client_b.current_user_id(cx_b), - "hello A, it's B.", - OffsetDateTime::now_utc(), - 2, - ) - .await - .unwrap(); - - let channels_a = - cx_a.add_model(|cx| ChannelList::new(client_a.user_store.clone(), client_a.clone(), cx)); - channels_a - .condition(cx_a, |list, _| list.available_channels().is_some()) - .await; - - channels_a.read_with(cx_a, |list, _| { - assert_eq!( - list.available_channels().unwrap(), - &[ChannelDetails { - id: channel_id.to_proto(), - name: "test-channel".to_string() - }] - ) - }); - let channel_a = channels_a.update(cx_a, |this, cx| { - this.get_channel(channel_id.to_proto(), cx).unwrap() - }); - channel_a.read_with(cx_a, |channel, _| assert!(channel.messages().is_empty())); - channel_a - .condition(cx_a, |channel, _| { - channel_messages(channel) - == [("user_b".to_string(), "hello A, it's B.".to_string(), false)] - }) - .await; - - let channels_b = - cx_b.add_model(|cx| ChannelList::new(client_b.user_store.clone(), client_b.clone(), cx)); - channels_b - .condition(cx_b, |list, _| list.available_channels().is_some()) - .await; - channels_b.read_with(cx_b, |list, _| { - assert_eq!( - list.available_channels().unwrap(), - &[ChannelDetails { - id: channel_id.to_proto(), - name: "test-channel".to_string() - }] - ) - }); - - let channel_b = channels_b.update(cx_b, |this, cx| { - this.get_channel(channel_id.to_proto(), cx).unwrap() - }); - channel_b.read_with(cx_b, |channel, _| assert!(channel.messages().is_empty())); - channel_b - .condition(cx_b, |channel, _| { - channel_messages(channel) - == [("user_b".to_string(), "hello A, it's B.".to_string(), false)] - }) - .await; - - // Disconnect client B, ensuring we can still access its cached channel data. - server.forbid_connections(); - server.disconnect_client(client_b.peer_id().unwrap()); - cx_b.foreground().advance_clock(rpc::RECEIVE_TIMEOUT); - while !matches!( - status_b.next().await, - Some(client::Status::ReconnectionError { .. }) - ) {} - - channels_b.read_with(cx_b, |channels, _| { - assert_eq!( - channels.available_channels().unwrap(), - [ChannelDetails { - id: channel_id.to_proto(), - name: "test-channel".to_string() - }] - ) - }); - channel_b.read_with(cx_b, |channel, _| { - assert_eq!( - channel_messages(channel), - [("user_b".to_string(), "hello A, it's B.".to_string(), false)] - ) - }); - - // Send a message from client B while it is disconnected. - channel_b - .update(cx_b, |channel, cx| { - let task = channel - .send_message("can you see this?".to_string(), cx) - .unwrap(); - assert_eq!( - channel_messages(channel), - &[ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_b".to_string(), "can you see this?".to_string(), true) - ] - ); - task - }) - .await - .unwrap_err(); - - // Send a message from client A while B is disconnected. - channel_a - .update(cx_a, |channel, cx| { - channel - .send_message("oh, hi B.".to_string(), cx) - .unwrap() - .detach(); - let task = channel.send_message("sup".to_string(), cx).unwrap(); - assert_eq!( - channel_messages(channel), - &[ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_a".to_string(), "oh, hi B.".to_string(), true), - ("user_a".to_string(), "sup".to_string(), true) - ] - ); - task - }) - .await - .unwrap(); - - // Give client B a chance to reconnect. - server.allow_connections(); - cx_b.foreground().advance_clock(Duration::from_secs(10)); - - // Verify that B sees the new messages upon reconnection, as well as the message client B - // sent while offline. - channel_b - .condition(cx_b, |channel, _| { - channel_messages(channel) - == [ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_a".to_string(), "oh, hi B.".to_string(), false), - ("user_a".to_string(), "sup".to_string(), false), - ("user_b".to_string(), "can you see this?".to_string(), false), - ] - }) - .await; - - // Ensure client A and B can communicate normally after reconnection. - channel_a - .update(cx_a, |channel, cx| { - channel.send_message("you online?".to_string(), cx).unwrap() - }) - .await - .unwrap(); - channel_b - .condition(cx_b, |channel, _| { - channel_messages(channel) - == [ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_a".to_string(), "oh, hi B.".to_string(), false), - ("user_a".to_string(), "sup".to_string(), false), - ("user_b".to_string(), "can you see this?".to_string(), false), - ("user_a".to_string(), "you online?".to_string(), false), - ] - }) - .await; - - channel_b - .update(cx_b, |channel, cx| { - channel.send_message("yep".to_string(), cx).unwrap() - }) - .await - .unwrap(); - channel_a - .condition(cx_a, |channel, _| { - channel_messages(channel) - == [ - ("user_b".to_string(), "hello A, it's B.".to_string(), false), - ("user_a".to_string(), "oh, hi B.".to_string(), false), - ("user_a".to_string(), "sup".to_string(), false), - ("user_b".to_string(), "can you see this?".to_string(), false), - ("user_a".to_string(), "you online?".to_string(), false), - ("user_b".to_string(), "yep".to_string(), false), - ] - }) - .await; -} - #[gpui::test(iterations = 10)] async fn test_contacts( deterministic: Arc, @@ -4594,7 +4178,7 @@ async fn test_contacts( cx_c: &mut TestAppContext, ) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -4920,7 +4504,7 @@ async fn test_contact_requests( cx_a.foreground().forbid_parking(); // Connect to a server as 3 clients. - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_a2 = server.create_client(cx_a2, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; @@ -5101,7 +4685,7 @@ async fn test_following( cx_a.update(editor::init); cx_b.update(editor::init); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -5375,7 +4959,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T cx_a.update(editor::init); cx_b.update(editor::init); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -5553,7 +5137,7 @@ async fn test_auto_unfollowing(cx_a: &mut TestAppContext, cx_b: &mut TestAppCont cx_b.update(editor::init); // 2 clients connect to a server. - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -5727,7 +5311,7 @@ async fn test_peers_simultaneously_following_each_other( cx_a.update(editor::init); cx_b.update(editor::init); - let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; + let mut server = TestServer::start(cx_a.background()).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; server @@ -5797,7 +5381,7 @@ async fn test_random_collaboration( .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) .unwrap_or(10); - let mut server = TestServer::start(cx.foreground(), cx.background()).await; + let mut server = TestServer::start(cx.background()).await; let db = server.app_state.db.clone(); let mut available_guests = Vec::new(); @@ -6084,8 +5668,6 @@ struct TestServer { peer: Arc, app_state: Arc, server: Arc, - foreground: Rc, - notifications: mpsc::UnboundedReceiver<()>, connection_killers: Arc>>>, forbid_connections: Arc, _test_db: TestDb, @@ -6093,18 +5675,10 @@ struct TestServer { } impl TestServer { - async fn start( - foreground: Rc, - background: Arc, - ) -> Self { + async fn start(background: Arc) -> Self { static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); - let test_db = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap() - .block_on(TestDb::new(background.clone())); + let test_db = TestDb::new(background.clone()); let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); let live_kit_server = live_kit_client::TestServer::create( format!("http://livekit.{}.test", live_kit_server_id), @@ -6115,14 +5689,11 @@ impl TestServer { .unwrap(); let app_state = Self::build_app_state(&test_db, &live_kit_server).await; let peer = Peer::new(); - let notifications = mpsc::unbounded(); - let server = Server::new(app_state.clone(), Some(notifications.0)); + let server = Server::new(app_state.clone()); Self { peer, app_state, server, - foreground, - notifications: notifications.1, connection_killers: Default::default(), forbid_connections: Default::default(), _test_db: test_db, @@ -6238,7 +5809,6 @@ impl TestServer { default_item_factory: |_, _| unimplemented!(), }); - Channel::init(&client); Project::init(&client); cx.update(|cx| { workspace::init(app_state.clone(), cx); @@ -6339,21 +5909,6 @@ impl TestServer { config: Default::default(), }) } - - async fn condition(&mut self, mut predicate: F) - where - F: FnMut(&Store) -> bool, - { - assert!( - self.foreground.parking_forbidden(), - "you must call forbid_parking to use server conditions so we don't block indefinitely" - ); - while !(predicate)(&*self.server.store.lock().await) { - self.foreground.start_waiting(); - self.notifications.next().await; - self.foreground.finish_waiting(); - } - } } impl Deref for TestServer { @@ -7069,20 +6624,6 @@ impl Executor for Arc { } } -fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> { - channel - .messages() - .cursor::<()>() - .map(|m| { - ( - m.sender.github_login.clone(), - m.body.clone(), - m.is_pending(), - ) - }) - .collect() -} - #[derive(Debug, Eq, PartialEq)] struct RoomParticipants { remote: Vec, diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 334e67ead9..dc98a2ee68 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -121,9 +121,7 @@ async fn main() -> Result<()> { let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port)) .expect("failed to bind TCP listener"); - let rpc_server = rpc::Server::new(state.clone(), None); - rpc_server - .start_recording_project_activity(Duration::from_secs(5 * 60), rpc::RealExecutor); + let rpc_server = rpc::Server::new(state.clone()); let app = api::routes(rpc_server.clone(), state.clone()) .merge(rpc::routes(rpc_server.clone())) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 9fd9bef825..7bc2b43b9b 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2,7 +2,7 @@ mod store; use crate::{ auth, - db::{self, ChannelId, MessageId, ProjectId, User, UserId}, + db::{self, ProjectId, User, UserId}, AppState, Result, }; use anyhow::anyhow; @@ -24,7 +24,7 @@ use axum::{ }; use collections::{HashMap, HashSet}; use futures::{ - channel::{mpsc, oneshot}, + channel::oneshot, future::{self, BoxFuture}, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt, TryStreamExt, @@ -51,7 +51,6 @@ use std::{ time::Duration, }; pub use store::{Store, Worktree}; -use time::OffsetDateTime; use tokio::{ sync::{Mutex, MutexGuard}, time::Sleep, @@ -62,10 +61,6 @@ use tracing::{info_span, instrument, Instrument}; lazy_static! { static ref METRIC_CONNECTIONS: IntGauge = register_int_gauge!("connections", "number of connections").unwrap(); - static ref METRIC_REGISTERED_PROJECTS: IntGauge = - register_int_gauge!("registered_projects", "number of registered projects").unwrap(); - static ref METRIC_ACTIVE_PROJECTS: IntGauge = - register_int_gauge!("active_projects", "number of active projects").unwrap(); static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!( "shared_projects", "number of open projects with one or more guests" @@ -95,7 +90,6 @@ pub struct Server { pub(crate) store: Mutex, app_state: Arc, handlers: HashMap, - notifications: Option>, } pub trait Executor: Send + Clone { @@ -107,9 +101,6 @@ pub trait Executor: Send + Clone { #[derive(Clone)] pub struct RealExecutor; -const MESSAGE_COUNT_PER_PAGE: usize = 100; -const MAX_MESSAGE_LEN: usize = 1024; - pub(crate) struct StoreGuard<'a> { guard: MutexGuard<'a, Store>, _not_send: PhantomData>, @@ -132,16 +123,12 @@ where } impl Server { - pub fn new( - app_state: Arc, - notifications: Option>, - ) -> Arc { + pub fn new(app_state: Arc) -> Arc { let mut server = Self { peer: Peer::new(), app_state, store: Default::default(), handlers: Default::default(), - notifications, }; server @@ -158,9 +145,7 @@ impl Server { .add_request_handler(Server::join_project) .add_message_handler(Server::leave_project) .add_message_handler(Server::update_project) - .add_message_handler(Server::register_project_activity) .add_request_handler(Server::update_worktree) - .add_message_handler(Server::update_worktree_extensions) .add_message_handler(Server::start_language_server) .add_message_handler(Server::update_language_server) .add_message_handler(Server::update_diagnostic_summary) @@ -194,19 +179,14 @@ impl Server { .add_message_handler(Server::buffer_reloaded) .add_message_handler(Server::buffer_saved) .add_request_handler(Server::save_buffer) - .add_request_handler(Server::get_channels) .add_request_handler(Server::get_users) .add_request_handler(Server::fuzzy_search_users) .add_request_handler(Server::request_contact) .add_request_handler(Server::remove_contact) .add_request_handler(Server::respond_to_contact_request) - .add_request_handler(Server::join_channel) - .add_message_handler(Server::leave_channel) - .add_request_handler(Server::send_channel_message) .add_request_handler(Server::follow) .add_message_handler(Server::unfollow) .add_message_handler(Server::update_followers) - .add_request_handler(Server::get_channel_messages) .add_message_handler(Server::update_diff_base) .add_request_handler(Server::get_private_user_info); @@ -290,58 +270,6 @@ impl Server { }) } - /// Start a long lived task that records which users are active in which projects. - pub fn start_recording_project_activity( - self: &Arc, - interval: Duration, - executor: E, - ) { - executor.spawn_detached({ - let this = Arc::downgrade(self); - let executor = executor.clone(); - async move { - let mut period_start = OffsetDateTime::now_utc(); - let mut active_projects = Vec::<(UserId, ProjectId)>::new(); - loop { - let sleep = executor.sleep(interval); - sleep.await; - let this = if let Some(this) = this.upgrade() { - this - } else { - break; - }; - - active_projects.clear(); - active_projects.extend(this.store().await.projects().flat_map( - |(project_id, project)| { - project.guests.values().chain([&project.host]).filter_map( - |collaborator| { - if !collaborator.admin - && collaborator - .last_activity - .map_or(false, |activity| activity > period_start) - { - Some((collaborator.user_id, *project_id)) - } else { - None - } - }, - ) - }, - )); - - let period_end = OffsetDateTime::now_utc(); - this.app_state - .db - .record_user_activity(period_start..period_end, &active_projects) - .await - .trace_err(); - period_start = period_end; - } - } - }); - } - pub fn handle_connection( self: &Arc, connection: Connection, @@ -432,18 +360,11 @@ impl Server { let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name); let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { - let notifications = this.notifications.clone(); let is_background = message.is_background(); let handle_message = (handler)(this.clone(), message); - drop(span_enter); - let handle_message = async move { - handle_message.await; - if let Some(mut notifications) = notifications { - let _ = notifications.send(()).await; - } - }.instrument(span); + let handle_message = handle_message.instrument(span); if is_background { executor.spawn_detached(handle_message); } else { @@ -1172,17 +1093,6 @@ impl Server { Ok(()) } - async fn register_project_activity( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - self.store().await.register_project_activity( - ProjectId::from_proto(request.payload.project_id), - request.sender_id, - )?; - Ok(()) - } - async fn update_worktree( self: Arc, request: TypedEnvelope, @@ -1209,25 +1119,6 @@ impl Server { Ok(()) } - async fn update_worktree_extensions( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let worktree_id = request.payload.worktree_id; - let extensions = request - .payload - .extensions - .into_iter() - .zip(request.payload.counts) - .collect(); - self.app_state - .db - .update_worktree_extensions(project_id, worktree_id, extensions) - .await?; - Ok(()) - } - async fn update_diagnostic_summary( self: Arc, request: TypedEnvelope, @@ -1363,8 +1254,7 @@ impl Server { ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let receiver_ids = { - let mut store = self.store().await; - store.register_project_activity(project_id, request.sender_id)?; + let store = self.store().await; store.project_connection_ids(project_id, request.sender_id)? }; @@ -1430,15 +1320,13 @@ impl Server { let leader_id = ConnectionId(request.payload.leader_id); let follower_id = request.sender_id; { - let mut store = self.store().await; + let store = self.store().await; if !store .project_connection_ids(project_id, follower_id)? .contains(&leader_id) { Err(anyhow!("no such peer"))?; } - - store.register_project_activity(project_id, follower_id)?; } let mut response_payload = self @@ -1455,14 +1343,13 @@ impl Server { async fn unfollow(self: Arc, request: TypedEnvelope) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let leader_id = ConnectionId(request.payload.leader_id); - let mut store = self.store().await; + let store = self.store().await; if !store .project_connection_ids(project_id, request.sender_id)? .contains(&leader_id) { Err(anyhow!("no such peer"))?; } - store.register_project_activity(project_id, request.sender_id)?; self.peer .forward_send(request.sender_id, leader_id, request.payload)?; Ok(()) @@ -1473,8 +1360,7 @@ impl Server { request: TypedEnvelope, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); - let mut store = self.store().await; - store.register_project_activity(project_id, request.sender_id)?; + let store = self.store().await; let connection_ids = store.project_connection_ids(project_id, request.sender_id)?; let leader_id = request .payload @@ -1495,28 +1381,6 @@ impl Server { Ok(()) } - async fn get_channels( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let channels = self.app_state.db.get_accessible_channels(user_id).await?; - response.send(proto::GetChannelsResponse { - channels: channels - .into_iter() - .map(|chan| proto::Channel { - id: chan.id.to_proto(), - name: chan.name, - }) - .collect(), - })?; - Ok(()) - } - async fn get_users( self: Arc, request: TypedEnvelope, @@ -1712,175 +1576,6 @@ impl Server { Ok(()) } - async fn join_channel( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let channel_id = ChannelId::from_proto(request.payload.channel_id); - if !self - .app_state - .db - .can_user_access_channel(user_id, channel_id) - .await? - { - Err(anyhow!("access denied"))?; - } - - self.store() - .await - .join_channel(request.sender_id, channel_id); - let messages = self - .app_state - .db - .get_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None) - .await? - .into_iter() - .map(|msg| proto::ChannelMessage { - id: msg.id.to_proto(), - body: msg.body, - timestamp: msg.sent_at.unix_timestamp() as u64, - sender_id: msg.sender_id.to_proto(), - nonce: Some(msg.nonce.as_u128().into()), - }) - .collect::>(); - response.send(proto::JoinChannelResponse { - done: messages.len() < MESSAGE_COUNT_PER_PAGE, - messages, - })?; - Ok(()) - } - - async fn leave_channel( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let channel_id = ChannelId::from_proto(request.payload.channel_id); - if !self - .app_state - .db - .can_user_access_channel(user_id, channel_id) - .await? - { - Err(anyhow!("access denied"))?; - } - - self.store() - .await - .leave_channel(request.sender_id, channel_id); - - Ok(()) - } - - async fn send_channel_message( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let channel_id = ChannelId::from_proto(request.payload.channel_id); - let user_id; - let connection_ids; - { - let state = self.store().await; - user_id = state.user_id_for_connection(request.sender_id)?; - connection_ids = state.channel_connection_ids(channel_id)?; - } - - // Validate the message body. - let body = request.payload.body.trim().to_string(); - if body.len() > MAX_MESSAGE_LEN { - return Err(anyhow!("message is too long"))?; - } - if body.is_empty() { - return Err(anyhow!("message can't be blank"))?; - } - - let timestamp = OffsetDateTime::now_utc(); - let nonce = request - .payload - .nonce - .ok_or_else(|| anyhow!("nonce can't be blank"))?; - - let message_id = self - .app_state - .db - .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into()) - .await? - .to_proto(); - let message = proto::ChannelMessage { - sender_id: user_id.to_proto(), - id: message_id, - body, - timestamp: timestamp.unix_timestamp() as u64, - nonce: Some(nonce), - }; - broadcast(request.sender_id, connection_ids, |conn_id| { - self.peer.send( - conn_id, - proto::ChannelMessageSent { - channel_id: channel_id.to_proto(), - message: Some(message.clone()), - }, - ) - }); - response.send(proto::SendChannelMessageResponse { - message: Some(message), - })?; - Ok(()) - } - - async fn get_channel_messages( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let channel_id = ChannelId::from_proto(request.payload.channel_id); - if !self - .app_state - .db - .can_user_access_channel(user_id, channel_id) - .await? - { - Err(anyhow!("access denied"))?; - } - - let messages = self - .app_state - .db - .get_channel_messages( - channel_id, - MESSAGE_COUNT_PER_PAGE, - Some(MessageId::from_proto(request.payload.before_message_id)), - ) - .await? - .into_iter() - .map(|msg| proto::ChannelMessage { - id: msg.id.to_proto(), - body: msg.body, - timestamp: msg.sent_at.unix_timestamp() as u64, - sender_id: msg.sender_id.to_proto(), - nonce: Some(msg.nonce.as_u128().into()), - }) - .collect::>(); - response.send(proto::GetChannelMessagesResponse { - done: messages.len() < MESSAGE_COUNT_PER_PAGE, - messages, - })?; - Ok(()) - } - async fn update_diff_base( self: Arc, request: TypedEnvelope, @@ -2061,11 +1756,8 @@ pub async fn handle_websocket_request( } pub async fn handle_metrics(Extension(server): Extension>) -> axum::response::Response { - // We call `store_mut` here for its side effects of updating metrics. let metrics = server.store().await.metrics(); METRIC_CONNECTIONS.set(metrics.connections as _); - METRIC_REGISTERED_PROJECTS.set(metrics.registered_projects as _); - METRIC_ACTIVE_PROJECTS.set(metrics.active_projects as _); METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _); let encoder = prometheus::TextEncoder::new(); diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index c9358ddc2a..81ef594ccd 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -1,11 +1,10 @@ -use crate::db::{self, ChannelId, ProjectId, UserId}; +use crate::db::{self, ProjectId, UserId}; use anyhow::{anyhow, Result}; use collections::{btree_map, BTreeMap, BTreeSet, HashMap, HashSet}; use nanoid::nanoid; use rpc::{proto, ConnectionId}; use serde::Serialize; -use std::{borrow::Cow, mem, path::PathBuf, str, time::Duration}; -use time::OffsetDateTime; +use std::{borrow::Cow, mem, path::PathBuf, str}; use tracing::instrument; use util::post_inc; @@ -18,8 +17,6 @@ pub struct Store { next_room_id: RoomId, rooms: BTreeMap, projects: BTreeMap, - #[serde(skip)] - channels: BTreeMap, } #[derive(Default, Serialize)] @@ -33,7 +30,6 @@ struct ConnectionState { user_id: UserId, admin: bool, projects: BTreeSet, - channels: HashSet, } #[derive(Copy, Clone, Eq, PartialEq, Serialize)] @@ -60,8 +56,6 @@ pub struct Project { pub struct Collaborator { pub replica_id: ReplicaId, pub user_id: UserId, - #[serde(skip)] - pub last_activity: Option, pub admin: bool, } @@ -78,11 +72,6 @@ pub struct Worktree { pub is_complete: bool, } -#[derive(Default)] -pub struct Channel { - pub connection_ids: HashSet, -} - pub type ReplicaId = u16; #[derive(Default)] @@ -113,38 +102,23 @@ pub struct LeftRoom<'a> { #[derive(Copy, Clone)] pub struct Metrics { pub connections: usize, - pub registered_projects: usize, - pub active_projects: usize, pub shared_projects: usize, } impl Store { pub fn metrics(&self) -> Metrics { - const ACTIVE_PROJECT_TIMEOUT: Duration = Duration::from_secs(60); - let active_window_start = OffsetDateTime::now_utc() - ACTIVE_PROJECT_TIMEOUT; - let connections = self.connections.values().filter(|c| !c.admin).count(); - let mut registered_projects = 0; - let mut active_projects = 0; let mut shared_projects = 0; for project in self.projects.values() { if let Some(connection) = self.connections.get(&project.host_connection_id) { if !connection.admin { - registered_projects += 1; - if project.is_active_since(active_window_start) { - active_projects += 1; - if !project.guests.is_empty() { - shared_projects += 1; - } - } + shared_projects += 1; } } } Metrics { connections, - registered_projects, - active_projects, shared_projects, } } @@ -162,7 +136,6 @@ impl Store { user_id, admin, projects: Default::default(), - channels: Default::default(), }, ); let connected_user = self.connected_users.entry(user_id).or_default(); @@ -201,18 +174,12 @@ impl Store { .ok_or_else(|| anyhow!("no such connection"))?; let user_id = connection.user_id; - let connection_channels = mem::take(&mut connection.channels); let mut result = RemovedConnectionState { user_id, ..Default::default() }; - // Leave all channels. - for channel_id in connection_channels { - self.leave_channel(connection_id, channel_id); - } - let connected_user = self.connected_users.get(&user_id).unwrap(); if let Some(active_call) = connected_user.active_call.as_ref() { let room_id = active_call.room_id; @@ -238,34 +205,6 @@ impl Store { Ok(result) } - #[cfg(test)] - pub fn channel(&self, id: ChannelId) -> Option<&Channel> { - self.channels.get(&id) - } - - pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) { - if let Some(connection) = self.connections.get_mut(&connection_id) { - connection.channels.insert(channel_id); - self.channels - .entry(channel_id) - .or_default() - .connection_ids - .insert(connection_id); - } - } - - pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) { - if let Some(connection) = self.connections.get_mut(&connection_id) { - connection.channels.remove(&channel_id); - if let btree_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) { - entry.get_mut().connection_ids.remove(&connection_id); - if entry.get_mut().connection_ids.is_empty() { - entry.remove(); - } - } - } - } - pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result { Ok(self .connections @@ -760,7 +699,6 @@ impl Store { host: Collaborator { user_id: connection.user_id, replica_id: 0, - last_activity: None, admin: connection.admin, }, guests: Default::default(), @@ -959,12 +897,10 @@ impl Store { Collaborator { replica_id, user_id: connection.user_id, - last_activity: Some(OffsetDateTime::now_utc()), admin: connection.admin, }, ); - project.host.last_activity = Some(OffsetDateTime::now_utc()); Ok((project, replica_id)) } @@ -1056,44 +992,12 @@ impl Store { .connection_ids()) } - pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Result> { - Ok(self - .channels - .get(&channel_id) - .ok_or_else(|| anyhow!("no such channel"))? - .connection_ids()) - } - pub fn project(&self, project_id: ProjectId) -> Result<&Project> { self.projects .get(&project_id) .ok_or_else(|| anyhow!("no such project")) } - pub fn register_project_activity( - &mut self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result<()> { - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - let collaborator = if connection_id == project.host_connection_id { - &mut project.host - } else if let Some(guest) = project.guests.get_mut(&connection_id) { - guest - } else { - return Err(anyhow!("no such project"))?; - }; - collaborator.last_activity = Some(OffsetDateTime::now_utc()); - Ok(()) - } - - pub fn projects(&self) -> impl Iterator { - self.projects.iter() - } - pub fn read_project( &self, project_id: ProjectId, @@ -1154,10 +1058,7 @@ impl Store { } } } - for channel_id in &connection.channels { - let channel = self.channels.get(channel_id).unwrap(); - assert!(channel.connection_ids.contains(connection_id)); - } + assert!(self .connected_users .get(&connection.user_id) @@ -1253,28 +1154,10 @@ impl Store { "project was not shared in room" ); } - - for (channel_id, channel) in &self.channels { - for connection_id in &channel.connection_ids { - let connection = self.connections.get(connection_id).unwrap(); - assert!(connection.channels.contains(channel_id)); - } - } } } impl Project { - fn is_active_since(&self, start_time: OffsetDateTime) -> bool { - self.guests - .values() - .chain([&self.host]) - .any(|collaborator| { - collaborator - .last_activity - .map_or(false, |active_time| active_time > start_time) - }) - } - pub fn guest_connection_ids(&self) -> Vec { self.guests.keys().copied().collect() } @@ -1287,9 +1170,3 @@ impl Project { .collect() } } - -impl Channel { - fn connection_ids(&self) -> Vec { - self.connection_ids.iter().copied().collect() - } -} diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 74a38599ec..e849632a2d 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -115,7 +115,6 @@ fn main() { context_menu::init(cx); project::Project::init(&client); - client::Channel::init(&client); client::init(client.clone(), cx); command_palette::init(cx); editor::init(cx);