diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 24a1989ae7..7dbd28e5e9 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -26,6 +26,7 @@ pub fn routes(state: Arc) -> Router { put(update_user).delete(destroy_user).get(get_user), ) .route("/users/:id/access_tokens", post(create_access_token)) + .route("/invite_codes/:code", get(get_user_for_invite_code)) .route("/panic", post(trace_panic)) .layer( ServiceBuilder::new() @@ -210,3 +211,10 @@ async fn create_access_token( encrypted_access_token, })) } + +async fn get_user_for_invite_code( + Path(code): Path, + Extension(app): Extension>, +) -> Result> { + Ok(Json(app.db.get_user_for_invite_code(&code).await?)) +} diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 7463e3483d..2942821279 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -21,7 +21,8 @@ pub trait Db: Send + Sync { async fn destroy_user(&self, id: UserId) -> Result<()>; async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>; - async fn get_invite_code(&self, id: UserId) -> Result>; + async fn get_invite_code_for_user(&self, id: UserId) -> Result>; + async fn get_user_for_invite_code(&self, code: &str) -> Result; async fn redeem_invite_code(&self, code: &str, login: &str) -> Result; async fn get_contacts(&self, id: UserId) -> Result>; @@ -226,7 +227,7 @@ impl Db for PostgresDb { Ok(()) } - async fn get_invite_code(&self, id: UserId) -> Result> { + async fn get_invite_code_for_user(&self, id: UserId) -> Result> { let result: Option<(String, i32)> = sqlx::query_as( " SELECT invite_code, invite_count @@ -244,6 +245,25 @@ impl Db for PostgresDb { } } + async fn get_user_for_invite_code(&self, code: &str) -> Result { + sqlx::query_as( + " + SELECT * + FROM users + WHERE invite_code = $1 + ", + ) + .bind(code) + .fetch_optional(&self.pool) + .await? + .ok_or_else(|| { + Error::Http( + StatusCode::NOT_FOUND, + "that invite code does not exist".to_string(), + ) + }) + } + async fn redeem_invite_code(&self, code: &str, login: &str) -> Result { let mut tx = self.pool.begin().await?; @@ -1337,16 +1357,17 @@ pub mod tests { let user1 = db.create_user("user-1", false).await.unwrap(); // Initially, user 1 has no invite code - assert_eq!(db.get_invite_code(user1).await.unwrap(), None); + assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None); // User 1 creates an invite code that can be used twice. db.set_invite_count(user1, 2).await.unwrap(); - let (invite_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap(); + let (invite_code, invite_count) = + db.get_invite_code_for_user(user1).await.unwrap().unwrap(); assert_eq!(invite_count, 2); // User 2 redeems the invite code and becomes a contact of user 1. let user2 = db.redeem_invite_code(&invite_code, "user-2").await.unwrap(); - let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap(); + let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); assert_eq!(invite_count, 1); assert_eq!( db.get_contacts(user1).await.unwrap(), @@ -1377,7 +1398,7 @@ pub mod tests { // User 3 redeems the invite code and becomes a contact of user 1. let user3 = db.redeem_invite_code(&invite_code, "user-3").await.unwrap(); - let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap(); + let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); assert_eq!(invite_count, 0); assert_eq!( db.get_contacts(user1).await.unwrap(), @@ -1417,13 +1438,14 @@ pub mod tests { // Invite count can be updated after the code has been created. db.set_invite_count(user1, 2).await.unwrap(); - let (latest_code, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap(); + let (latest_code, invite_count) = + db.get_invite_code_for_user(user1).await.unwrap().unwrap(); assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0 assert_eq!(invite_count, 2); // User 4 can now redeem the invite code and becomes a contact of user 1. let user4 = db.redeem_invite_code(&invite_code, "user-4").await.unwrap(); - let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap(); + let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); assert_eq!(invite_count, 1); assert_eq!( db.get_contacts(user1).await.unwrap(), @@ -1464,7 +1486,7 @@ pub mod tests { db.redeem_invite_code(&invite_code, "user-2") .await .unwrap_err(); - let (_, invite_count) = db.get_invite_code(user1).await.unwrap().unwrap(); + let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); assert_eq!(invite_count, 1); } @@ -1626,7 +1648,11 @@ pub mod tests { unimplemented!() } - async fn get_invite_code(&self, _id: UserId) -> Result> { + async fn get_invite_code_for_user(&self, _id: UserId) -> Result> { + unimplemented!() + } + + async fn get_user_for_invite_code(&self, _code: &str) -> Result { unimplemented!() }