diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 57976f16fd..df5071b90f 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -5,6 +5,7 @@ pub mod extensions; pub mod ips_file; pub mod slack; +use crate::db::Database; use crate::{ AppState, Error, Result, auth, db::{User, UserId}, @@ -97,6 +98,7 @@ impl std::fmt::Display for SystemIdHeader { pub fn routes(rpc_server: Arc) -> Router<(), Body> { Router::new() .route("/user", get(get_authenticated_user)) + .route("/users/look_up", get(look_up_user)) .route("/users/:id/access_tokens", post(create_access_token)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) .merge(billing::router()) @@ -181,6 +183,87 @@ async fn get_authenticated_user( })) } +#[derive(Debug, Deserialize)] +struct LookUpUserParams { + identifier: String, +} + +#[derive(Debug, Serialize)] +struct LookUpUserResponse { + user: Option, +} + +async fn look_up_user( + Query(params): Query, + Extension(app): Extension>, +) -> Result> { + let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?; + let user = if let Some(user) = user { + match user { + UserOrId::User(user) => Some(user), + UserOrId::Id(id) => app.db.get_user_by_id(id).await?, + } + } else { + None + }; + + Ok(Json(LookUpUserResponse { user })) +} + +enum UserOrId { + User(User), + Id(UserId), +} + +async fn resolve_identifier_to_user( + db: &Arc, + identifier: &str, +) -> Result> { + if let Some(identifier) = identifier.parse::().ok() { + let user = db.get_user_by_id(UserId(identifier)).await?; + + return Ok(user.map(UserOrId::User)); + } + + if identifier.starts_with("cus_") { + let billing_customer = db + .get_billing_customer_by_stripe_customer_id(&identifier) + .await?; + + return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))); + } + + if identifier.starts_with("sub_") { + let billing_subscription = db + .get_billing_subscription_by_stripe_subscription_id(&identifier) + .await?; + + if let Some(billing_subscription) = billing_subscription { + let billing_customer = db + .get_billing_customer_by_id(billing_subscription.billing_customer_id) + .await?; + + return Ok( + billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)) + ); + } else { + return Ok(None); + } + } + + if identifier.contains('@') { + let user = db.get_user_by_email(identifier).await?; + + return Ok(user.map(UserOrId::User)); + } + + if let Some(user) = db.get_user_by_github_login(identifier).await? { + return Ok(Some(UserOrId::User(user))); + } + + Ok(None) +} + #[derive(Deserialize, Debug)] struct CreateUserParams { github_user_id: i32, diff --git a/crates/collab/src/db/queries/billing_customers.rs b/crates/collab/src/db/queries/billing_customers.rs index 47e31bbe65..ead9e6cd32 100644 --- a/crates/collab/src/db/queries/billing_customers.rs +++ b/crates/collab/src/db/queries/billing_customers.rs @@ -57,6 +57,19 @@ impl Database { .await } + pub async fn get_billing_customer_by_id( + &self, + id: BillingCustomerId, + ) -> Result> { + self.transaction(|tx| async move { + Ok(billing_customer::Entity::find() + .filter(billing_customer::Column::Id.eq(id)) + .one(&*tx) + .await?) + }) + .await + } + /// Returns the billing customer for the user with the specified ID. pub async fn get_billing_customer_by_user_id( &self,