Start work on using sqlite in tests

This commit is contained in:
Max Brunsfeld 2022-11-09 19:15:05 -08:00
parent d14dd27cdc
commit 7e02ac772a
7 changed files with 250 additions and 96 deletions

View file

@ -5,7 +5,6 @@ use axum::http::StatusCode;
use collections::HashMap;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
pub use sqlx::postgres::PgPoolOptions as DbOptions;
use sqlx::{
migrate::{Migrate as _, Migration, MigrationSource},
types::Uuid,
@ -181,11 +180,14 @@ pub trait Db: Send + Sync {
pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> =
Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
pub const TEST_MIGRATIONS_PATH: Option<&'static str> =
Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"));
#[cfg(not(any(test, debug_assertions)))]
pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None;
pub struct PostgresDb {
pool: sqlx::PgPool,
pub struct RealDb {
pool: sqlx::SqlitePool,
}
macro_rules! test_support {
@ -202,13 +204,13 @@ macro_rules! test_support {
}};
}
impl PostgresDb {
impl RealDb {
pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
let pool = DbOptions::new()
.max_connections(max_connections)
eprintln!("{url}");
let pool = sqlx::sqlite::SqlitePoolOptions::new()
.max_connections(1)
.connect(url)
.await
.context("failed to connect to postgres database")?;
.await?;
Ok(Self { pool })
}
@ -267,7 +269,7 @@ impl PostgresDb {
}
#[async_trait]
impl Db for PostgresDb {
impl Db for RealDb {
// users
async fn create_user(
@ -280,8 +282,8 @@ impl Db for PostgresDb {
let query = "
INSERT INTO users (email_address, github_login, github_user_id, admin)
VALUES ($1, $2, $3, $4)
ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
RETURNING id, metrics_id::text
-- ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
RETURNING id, 'the-metrics-id'
";
let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
@ -331,8 +333,18 @@ impl Db for PostgresDb {
}
async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
let users = self.get_users_by_ids(vec![id]).await?;
Ok(users.into_iter().next())
test_support!(self, {
let query = "
SELECT users.*
FROM users
WHERE id = $1
LIMIT 1
";
Ok(sqlx::query_as(query)
.bind(&id)
.fetch_optional(&self.pool)
.await?)
})
}
async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
@ -351,14 +363,13 @@ impl Db for PostgresDb {
async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
test_support!(self, {
let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
let query = "
SELECT users.*
FROM users
WHERE users.id = ANY ($1)
WHERE users.id IN (SELECT value from json_each($1))
";
Ok(sqlx::query_as(query)
.bind(&ids)
.bind(&serde_json::json!(ids))
.fetch_all(&self.pool)
.await?)
})
@ -493,7 +504,7 @@ impl Db for PostgresDb {
device_id
)
VALUES
($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8)
($1, $2, FALSE, $3, $4, $5, FALSE, $6)
RETURNING id
",
)
@ -502,8 +513,8 @@ impl Db for PostgresDb {
.bind(&signup.platform_linux)
.bind(&signup.platform_mac)
.bind(&signup.platform_windows)
.bind(&signup.editor_features)
.bind(&signup.programming_languages)
// .bind(&signup.editor_features)
// .bind(&signup.programming_languages)
.bind(&signup.device_id)
.execute(&self.pool)
.await?;
@ -555,21 +566,21 @@ impl Db for PostgresDb {
async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
test_support!(self, {
sqlx::query(
"
UPDATE signups
SET email_confirmation_sent = 't'
WHERE email_address = ANY ($1)
",
)
.bind(
&invites
.iter()
.map(|s| s.email_address.as_str())
.collect::<Vec<_>>(),
)
.execute(&self.pool)
.await?;
// sqlx::query(
// "
// UPDATE signups
// SET email_confirmation_sent = TRUE
// WHERE email_address = ANY ($1)
// ",
// )
// .bind(
// &invites
// .iter()
// .map(|s| s.email_address.as_str())
// .collect::<Vec<_>>(),
// )
// .execute(&self.pool)
// .await?;
Ok(())
})
}
@ -611,7 +622,7 @@ impl Db for PostgresDb {
INSERT INTO users
(email_address, github_login, github_user_id, admin, invite_count, invite_code)
VALUES
($1, $2, $3, 'f', $4, $5)
($1, $2, $3, FALSE, $4, $5)
ON CONFLICT (github_login) DO UPDATE SET
email_address = excluded.email_address,
github_user_id = excluded.github_user_id,
@ -664,7 +675,7 @@ impl Db for PostgresDb {
INSERT INTO contacts
(user_id_a, user_id_b, a_to_b, should_notify, accepted)
VALUES
($1, $2, 't', 't', 't')
($1, $2, TRUE, TRUE, TRUE)
ON CONFLICT DO NOTHING
",
)
@ -824,7 +835,7 @@ impl Db for PostgresDb {
device_id
)
VALUES
($1, $2, 'f', $3, 'f', 'f', 'f', 't', $4)
($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
ON CONFLICT (email_address)
DO UPDATE SET
inviting_user_id = excluded.inviting_user_id
@ -870,7 +881,7 @@ impl Db for PostgresDb {
sqlx::query(
"
UPDATE projects
SET unregistered = 't'
SET unregistered = TRUE
WHERE id = $1
",
)
@ -1274,7 +1285,7 @@ impl Db for PostgresDb {
let query = "
SELECT 1 FROM contacts
WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
LIMIT 1
";
Ok(sqlx::query_scalar::<_, i32>(query)
@ -1295,11 +1306,11 @@ impl Db for PostgresDb {
};
let query = "
INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
VALUES ($1, $2, $3, 'f', 't')
VALUES ($1, $2, $3, FALSE, TRUE)
ON CONFLICT (user_id_a, user_id_b) DO UPDATE
SET
accepted = 't',
should_notify = 'f'
accepted = TRUE,
should_notify = FALSE
WHERE
NOT contacts.accepted AND
((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
@ -1359,7 +1370,7 @@ impl Db for PostgresDb {
let query = "
UPDATE contacts
SET should_notify = 'f'
SET should_notify = FALSE
WHERE
user_id_a = $1 AND user_id_b = $2 AND
(
@ -1398,7 +1409,7 @@ impl Db for PostgresDb {
let result = if accept {
let query = "
UPDATE contacts
SET accepted = 't', should_notify = 't'
SET accepted = TRUE, should_notify = TRUE
WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
";
sqlx::query(query)
@ -1706,7 +1717,7 @@ impl Db for PostgresDb {
";
sqlx::query(query).execute(&self.pool).await.log_err();
self.pool.close().await;
<sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
<sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
.await
.log_err();
eprintln!("tore down database: {:?}", start.elapsed());
@ -1929,10 +1940,9 @@ mod test {
use anyhow::anyhow;
use collections::BTreeMap;
use gpui::executor::Background;
use lazy_static::lazy_static;
use parking_lot::Mutex;
use rand::prelude::*;
use sqlx::{migrate::MigrateDatabase, Postgres};
use sqlx::{migrate::MigrateDatabase, Sqlite};
use std::sync::Arc;
use util::post_inc;
@ -2587,22 +2597,14 @@ mod test {
impl TestDb {
#[allow(clippy::await_holding_lock)]
pub async fn postgres() -> Self {
lazy_static! {
static ref LOCK: Mutex<()> = Mutex::new(());
}
pub async fn real() -> Self {
eprintln!("creating database...");
let start = std::time::Instant::now();
let _guard = LOCK.lock();
let mut rng = StdRng::from_entropy();
let name = format!("zed-test-{}", rng.gen::<u128>());
let url = format!("postgres://postgres@localhost:5433/{}", name);
Postgres::create_database(&url)
.await
.expect("failed to create test db");
let db = PostgresDb::new(&url, 5).await.unwrap();
db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false)
let url = format!("/tmp/zed-test-{}", rng.gen::<u128>());
Sqlite::create_database(&url).await.unwrap();
let db = RealDb::new(&url, 5).await.unwrap();
db.migrate(Path::new(TEST_MIGRATIONS_PATH.unwrap()), false)
.await
.unwrap();
@ -2628,7 +2630,7 @@ mod test {
impl Drop for TestDb {
fn drop(&mut self) {
if let Some(db) = self.db.take() {
futures::executor::block_on(db.teardown(&self.url));
std::fs::remove_file(&self.url).ok();
}
}
}

View file

@ -7,7 +7,7 @@ use time::OffsetDateTime;
#[tokio::test(flavor = "multi_thread")]
async fn test_get_users_by_ids() {
for test_db in [
TestDb::postgres().await,
TestDb::real().await,
TestDb::fake(build_background_executor()),
] {
let db = test_db.db();
@ -73,7 +73,7 @@ async fn test_get_users_by_ids() {
#[tokio::test(flavor = "multi_thread")]
async fn test_get_user_by_github_account() {
for test_db in [
TestDb::postgres().await,
TestDb::real().await,
TestDb::fake(build_background_executor()),
] {
let db = test_db.db();
@ -132,7 +132,7 @@ async fn test_get_user_by_github_account() {
#[tokio::test(flavor = "multi_thread")]
async fn test_worktree_extensions() {
let test_db = TestDb::postgres().await;
let test_db = TestDb::real().await;
let db = test_db.db();
let user = db
@ -204,7 +204,7 @@ async fn test_worktree_extensions() {
#[tokio::test(flavor = "multi_thread")]
async fn test_user_activity() {
let test_db = TestDb::postgres().await;
let test_db = TestDb::real().await;
let db = test_db.db();
let mut user_ids = Vec::new();
@ -448,7 +448,7 @@ async fn test_user_activity() {
#[tokio::test(flavor = "multi_thread")]
async fn test_recent_channel_messages() {
for test_db in [
TestDb::postgres().await,
TestDb::real().await,
TestDb::fake(build_background_executor()),
] {
let db = test_db.db();
@ -493,7 +493,7 @@ async fn test_recent_channel_messages() {
#[tokio::test(flavor = "multi_thread")]
async fn test_channel_message_nonces() {
for test_db in [
TestDb::postgres().await,
TestDb::real().await,
TestDb::fake(build_background_executor()),
] {
let db = test_db.db();
@ -538,7 +538,7 @@ async fn test_channel_message_nonces() {
#[tokio::test(flavor = "multi_thread")]
async fn test_create_access_tokens() {
let test_db = TestDb::postgres().await;
let test_db = TestDb::real().await;
let db = test_db.db();
let user = db
.create_user(
@ -582,14 +582,14 @@ async fn test_create_access_tokens() {
#[test]
fn test_fuzzy_like_string() {
assert_eq!(PostgresDb::fuzzy_like_string("abcd"), "%a%b%c%d%");
assert_eq!(PostgresDb::fuzzy_like_string("x y"), "%x%y%");
assert_eq!(PostgresDb::fuzzy_like_string(" z "), "%z%");
assert_eq!(RealDb::fuzzy_like_string("abcd"), "%a%b%c%d%");
assert_eq!(RealDb::fuzzy_like_string("x y"), "%x%y%");
assert_eq!(RealDb::fuzzy_like_string(" z "), "%z%");
}
#[tokio::test(flavor = "multi_thread")]
async fn test_fuzzy_search_users() {
let test_db = TestDb::postgres().await;
let test_db = TestDb::real().await;
let db = test_db.db();
for (i, github_login) in [
"California",
@ -638,7 +638,7 @@ async fn test_fuzzy_search_users() {
#[tokio::test(flavor = "multi_thread")]
async fn test_add_contacts() {
for test_db in [
TestDb::postgres().await,
TestDb::real().await,
TestDb::fake(build_background_executor()),
] {
let db = test_db.db();
@ -805,7 +805,7 @@ async fn test_add_contacts() {
#[tokio::test(flavor = "multi_thread")]
async fn test_invite_codes() {
let postgres = TestDb::postgres().await;
let postgres = TestDb::real().await;
let db = postgres.db();
let NewUserResult { user_id: user1, .. } = db
.create_user(
@ -1000,7 +1000,7 @@ async fn test_invite_codes() {
#[tokio::test(flavor = "multi_thread")]
async fn test_signups() {
let postgres = TestDb::postgres().await;
let postgres = TestDb::real().await;
let db = postgres.db();
// people sign up on the waitlist
@ -1146,7 +1146,7 @@ async fn test_signups() {
#[tokio::test(flavor = "multi_thread")]
async fn test_metrics_id() {
let postgres = TestDb::postgres().await;
let postgres = TestDb::real().await;
let db = postgres.db();
let NewUserResult {

View file

@ -53,7 +53,6 @@ use std::{
time::Duration,
};
use theme::ThemeRegistry;
use tokio::runtime::{EnterGuard, Runtime};
use unindent::Unindent as _;
use util::post_inc;
use workspace::{shared_screen::SharedScreen, Item, SplitDirection, ToggleFollow, Workspace};
@ -80,7 +79,6 @@ async fn test_basic_calls(
let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let start = std::time::Instant::now();
eprintln!("test_basic_calls");
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
@ -6106,7 +6104,7 @@ impl TestServer {
.enable_time()
.build()
.unwrap()
.block_on(TestDb::postgres());
.block_on(TestDb::real());
let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst);
let live_kit_server = live_kit_client::TestServer::create(
format!("http://livekit.{}.test", live_kit_server_id),
@ -6162,7 +6160,7 @@ impl TestServer {
},
)
.await
.unwrap()
.expect("creating user failed")
.user_id
};
let client_name = name.to_string();
@ -6202,7 +6200,11 @@ impl TestServer {
let (client_conn, server_conn, killed) =
Connection::in_memory(cx.background());
let (connection_id_tx, connection_id_rx) = oneshot::channel();
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
let user = db
.get_user_by_id(user_id)
.await
.expect("retrieving user failed")
.unwrap();
cx.background()
.spawn(server.handle_connection(
server_conn,

View file

@ -13,7 +13,7 @@ use crate::rpc::ResultExt as _;
use anyhow::anyhow;
use axum::{routing::get, Router};
use collab::{Error, Result};
use db::{Db, PostgresDb};
use db::{Db, RealDb};
use serde::Deserialize;
use std::{
env::args,
@ -56,7 +56,7 @@ pub struct AppState {
impl AppState {
async fn new(config: Config) -> Result<Arc<Self>> {
let db = PostgresDb::new(&config.database_url, 5).await?;
let db = RealDb::new(&config.database_url, 5).await?;
let live_kit_client = if let Some(((server, key), secret)) = config
.live_kit_server
.as_ref()
@ -96,7 +96,7 @@ async fn main() -> Result<()> {
}
Some("migrate") => {
let config = envy::from_env::<MigrateConfig>().expect("error loading config");
let db = PostgresDb::new(&config.database_url, 5).await?;
let db = RealDb::new(&config.database_url, 5).await?;
let migrations_path = config
.migrations_path