From fe1a861bf3f37e0a218afa7df8a19c947e18c833 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 13 Jun 2022 10:53:56 +0200 Subject: [PATCH] Expose a new `POST /api/bulk_users` API to create many users at once This API will accept a vector of JSON entries containing the GitHub login, the email address and the invite count. If that user already exist, the invite count will be updated to the new one. --- crates/collab/src/api.rs | 37 ++++++++++++++ crates/collab/src/db.rs | 106 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 142 insertions(+), 1 deletion(-) diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 86db1d28d0..d1d4625c90 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -28,6 +28,7 @@ pub fn routes(rpc_server: &Arc, state: Arc) -> Router, +} + +#[derive(Deserialize)] +struct CreateUsersEntry { + github_login: String, + email_address: String, + invite_count: usize, +} + +async fn create_users( + Json(params): Json, + Extension(app): Extension>, +) -> Result>> { + let user_ids = app + .db + .create_users( + params + .users + .into_iter() + .map(|params| { + ( + params.github_login, + params.email_address, + params.invite_count, + ) + }) + .collect(), + ) + .await?; + let users = app.db.get_users_by_ids(user_ids).await?; + Ok(Json(users)) +} + #[derive(Debug, Deserialize)] struct Panic { version: String, diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index fc8c3f8d3e..439c6985a8 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -6,7 +6,7 @@ use futures::StreamExt; use nanoid::nanoid; use serde::Serialize; pub use sqlx::postgres::PgPoolOptions as DbOptions; -use sqlx::{types::Uuid, FromRow}; +use sqlx::{types::Uuid, FromRow, QueryBuilder, Row}; use time::OffsetDateTime; #[async_trait] @@ -17,6 +17,7 @@ pub trait Db: Send + Sync { email_address: Option<&str>, admin: bool, ) -> Result; + async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result>; async fn get_all_users(&self) -> Result>; async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result>; async fn get_user_by_id(&self, id: UserId) -> Result>; @@ -141,6 +142,41 @@ impl Db for PostgresDb { .map(UserId)?) } + async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result> { + let mut query = QueryBuilder::new( + "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)", + ); + query.push_values( + users, + |mut query, (github_login, email_address, invite_count)| { + query + .push_bind(github_login) + .push_bind(email_address) + .push_bind(false) + .push_bind(nanoid!(16)) + .push_bind(invite_count as u32); + }, + ); + query.push( + " + ON CONFLICT (github_login) DO UPDATE SET + github_login = excluded.github_login, + invite_count = excluded.invite_count, + invite_code = CASE WHEN users.invite_code IS NULL + THEN excluded.invite_code + ELSE users.invite_code + END + RETURNING id + ", + ); + + let rows = query.build().fetch_all(&self.pool).await?; + Ok(rows + .into_iter() + .filter_map(|row| row.try_get::(0).ok()) + .collect()) + } + async fn get_all_users(&self) -> Result> { let query = "SELECT * FROM users ORDER BY github_login ASC"; Ok(sqlx::query_as(query).fetch_all(&self.pool).await?) @@ -1021,6 +1057,70 @@ pub mod tests { } } + #[tokio::test(flavor = "multi_thread")] + async fn test_create_users() { + let db = TestDb::postgres().await; + let db = db.db(); + + // Create the first batch of users, ensuring invite counts are assigned + // correctly and the respective invite codes are unique. + let user_ids_batch_1 = db + .create_users(vec![ + ("user1".to_string(), "hi@user1.com".to_string(), 5), + ("user2".to_string(), "hi@user2.com".to_string(), 4), + ("user3".to_string(), "hi@user3.com".to_string(), 3), + ]) + .await + .unwrap(); + assert_eq!(user_ids_batch_1.len(), 3); + + let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap(); + assert_eq!(users.len(), 3); + assert_eq!(users[0].github_login, "user1"); + assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com")); + assert_eq!(users[0].invite_count, 5); + assert_eq!(users[1].github_login, "user2"); + assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com")); + assert_eq!(users[1].invite_count, 4); + assert_eq!(users[2].github_login, "user3"); + assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com")); + assert_eq!(users[2].invite_count, 3); + + let invite_code_1 = users[0].invite_code.clone().unwrap(); + let invite_code_2 = users[1].invite_code.clone().unwrap(); + let invite_code_3 = users[2].invite_code.clone().unwrap(); + assert_ne!(invite_code_1, invite_code_2); + assert_ne!(invite_code_1, invite_code_3); + assert_ne!(invite_code_2, invite_code_3); + + // Create the second batch of users and include a user that is already in the database, ensuring + // the invite count for the existing user is updated without changing their invite code. + let user_ids_batch_2 = db + .create_users(vec![ + ("user2".to_string(), "hi@user2.com".to_string(), 10), + ("user4".to_string(), "hi@user4.com".to_string(), 2), + ]) + .await + .unwrap(); + assert_eq!(user_ids_batch_2.len(), 2); + assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]); + + let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap(); + assert_eq!(users.len(), 2); + assert_eq!(users[0].github_login, "user2"); + assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com")); + assert_eq!(users[0].invite_count, 10); + assert_eq!(users[0].invite_code, Some(invite_code_2.clone())); + assert_eq!(users[1].github_login, "user4"); + assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com")); + assert_eq!(users[1].invite_count, 2); + + let invite_code_4 = users[1].invite_code.clone().unwrap(); + assert_ne!(invite_code_4, invite_code_1); + assert_ne!(invite_code_4, invite_code_2); + assert_ne!(invite_code_4, invite_code_3); + } + #[tokio::test(flavor = "multi_thread")] async fn test_recent_channel_messages() { for test_db in [ @@ -1665,6 +1765,10 @@ pub mod tests { } } + async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result> { + unimplemented!() + } + async fn get_all_users(&self) -> Result> { unimplemented!() }