Expose a new POST /api/bulk_users API to create many users at once

This API will accept a vector of JSON entries containing the GitHub login,
the email address and the invite count. If that user already exist, the
invite count will be updated to the new one.
This commit is contained in:
Antonio Scandurra 2022-06-13 10:53:56 +02:00
parent b1e8e81513
commit fe1a861bf3
2 changed files with 142 additions and 1 deletions

View file

@ -28,6 +28,7 @@ pub fn routes(rpc_server: &Arc<rpc::Server>, state: Arc<AppState>) -> Router<Bod
put(update_user).delete(destroy_user).get(get_user), put(update_user).delete(destroy_user).get(get_user),
) )
.route("/users/:id/access_tokens", post(create_access_token)) .route("/users/:id/access_tokens", post(create_access_token))
.route("/bulk_users", post(create_users))
.route("/invite_codes/:code", get(get_user_for_invite_code)) .route("/invite_codes/:code", get(get_user_for_invite_code))
.route("/panic", post(trace_panic)) .route("/panic", post(trace_panic))
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
@ -167,6 +168,42 @@ async fn get_user(
Ok(Json(user)) Ok(Json(user))
} }
#[derive(Deserialize)]
struct CreateUsersParams {
users: Vec<CreateUsersEntry>,
}
#[derive(Deserialize)]
struct CreateUsersEntry {
github_login: String,
email_address: String,
invite_count: usize,
}
async fn create_users(
Json(params): Json<CreateUsersParams>,
Extension(app): Extension<Arc<AppState>>,
) -> Result<Json<Vec<User>>> {
let user_ids = app
.db
.create_users(
params
.users
.into_iter()
.map(|params| {
(
params.github_login,
params.email_address,
params.invite_count,
)
})
.collect(),
)
.await?;
let users = app.db.get_users_by_ids(user_ids).await?;
Ok(Json(users))
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct Panic { struct Panic {
version: String, version: String,

View file

@ -6,7 +6,7 @@ use futures::StreamExt;
use nanoid::nanoid; use nanoid::nanoid;
use serde::Serialize; use serde::Serialize;
pub use sqlx::postgres::PgPoolOptions as DbOptions; pub use sqlx::postgres::PgPoolOptions as DbOptions;
use sqlx::{types::Uuid, FromRow}; use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
use time::OffsetDateTime; use time::OffsetDateTime;
#[async_trait] #[async_trait]
@ -17,6 +17,7 @@ pub trait Db: Send + Sync {
email_address: Option<&str>, email_address: Option<&str>,
admin: bool, admin: bool,
) -> Result<UserId>; ) -> Result<UserId>;
async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
async fn get_all_users(&self) -> Result<Vec<User>>; async fn get_all_users(&self) -> Result<Vec<User>>;
async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>; async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>; async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
@ -141,6 +142,41 @@ impl Db for PostgresDb {
.map(UserId)?) .map(UserId)?)
} }
async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
let mut query = QueryBuilder::new(
"INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
);
query.push_values(
users,
|mut query, (github_login, email_address, invite_count)| {
query
.push_bind(github_login)
.push_bind(email_address)
.push_bind(false)
.push_bind(nanoid!(16))
.push_bind(invite_count as u32);
},
);
query.push(
"
ON CONFLICT (github_login) DO UPDATE SET
github_login = excluded.github_login,
invite_count = excluded.invite_count,
invite_code = CASE WHEN users.invite_code IS NULL
THEN excluded.invite_code
ELSE users.invite_code
END
RETURNING id
",
);
let rows = query.build().fetch_all(&self.pool).await?;
Ok(rows
.into_iter()
.filter_map(|row| row.try_get::<UserId, _>(0).ok())
.collect())
}
async fn get_all_users(&self) -> Result<Vec<User>> { async fn get_all_users(&self) -> Result<Vec<User>> {
let query = "SELECT * FROM users ORDER BY github_login ASC"; let query = "SELECT * FROM users ORDER BY github_login ASC";
Ok(sqlx::query_as(query).fetch_all(&self.pool).await?) Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
@ -1021,6 +1057,70 @@ pub mod tests {
} }
} }
#[tokio::test(flavor = "multi_thread")]
async fn test_create_users() {
let db = TestDb::postgres().await;
let db = db.db();
// Create the first batch of users, ensuring invite counts are assigned
// correctly and the respective invite codes are unique.
let user_ids_batch_1 = db
.create_users(vec![
("user1".to_string(), "hi@user1.com".to_string(), 5),
("user2".to_string(), "hi@user2.com".to_string(), 4),
("user3".to_string(), "hi@user3.com".to_string(), 3),
])
.await
.unwrap();
assert_eq!(user_ids_batch_1.len(), 3);
let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
assert_eq!(users.len(), 3);
assert_eq!(users[0].github_login, "user1");
assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
assert_eq!(users[0].invite_count, 5);
assert_eq!(users[1].github_login, "user2");
assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
assert_eq!(users[1].invite_count, 4);
assert_eq!(users[2].github_login, "user3");
assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
assert_eq!(users[2].invite_count, 3);
let invite_code_1 = users[0].invite_code.clone().unwrap();
let invite_code_2 = users[1].invite_code.clone().unwrap();
let invite_code_3 = users[2].invite_code.clone().unwrap();
assert_ne!(invite_code_1, invite_code_2);
assert_ne!(invite_code_1, invite_code_3);
assert_ne!(invite_code_2, invite_code_3);
// Create the second batch of users and include a user that is already in the database, ensuring
// the invite count for the existing user is updated without changing their invite code.
let user_ids_batch_2 = db
.create_users(vec![
("user2".to_string(), "hi@user2.com".to_string(), 10),
("user4".to_string(), "hi@user4.com".to_string(), 2),
])
.await
.unwrap();
assert_eq!(user_ids_batch_2.len(), 2);
assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
assert_eq!(users.len(), 2);
assert_eq!(users[0].github_login, "user2");
assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
assert_eq!(users[0].invite_count, 10);
assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
assert_eq!(users[1].github_login, "user4");
assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
assert_eq!(users[1].invite_count, 2);
let invite_code_4 = users[1].invite_code.clone().unwrap();
assert_ne!(invite_code_4, invite_code_1);
assert_ne!(invite_code_4, invite_code_2);
assert_ne!(invite_code_4, invite_code_3);
}
#[tokio::test(flavor = "multi_thread")] #[tokio::test(flavor = "multi_thread")]
async fn test_recent_channel_messages() { async fn test_recent_channel_messages() {
for test_db in [ for test_db in [
@ -1665,6 +1765,10 @@ pub mod tests {
} }
} }
async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
unimplemented!()
}
async fn get_all_users(&self) -> Result<Vec<User>> { async fn get_all_users(&self) -> Result<Vec<User>> {
unimplemented!() unimplemented!()
} }