From 5e240f98f0b80a5f2ebd902c690957e11a7d63b6 Mon Sep 17 00:00:00 2001 From: Mikayla Maki Date: Thu, 1 Dec 2022 18:31:05 -0800 Subject: [PATCH] Reworked thread safe connection be threadsafer,,,, again Co-Authored-By: kay@zed.dev --- crates/db/src/db.rs | 559 ++++++++------------- crates/db/src/kvp.rs | 29 +- crates/db/src/query.rs | 314 ++++++++++++ crates/editor/src/persistence.rs | 27 +- crates/sqlez/src/bindable.rs | 164 +++--- crates/sqlez/src/connection.rs | 14 +- crates/sqlez/src/domain.rs | 4 +- crates/sqlez/src/migrations.rs | 3 + crates/sqlez/src/thread_safe_connection.rs | 143 +++--- crates/terminal/src/persistence.rs | 19 +- crates/workspace/src/persistence.rs | 44 +- crates/workspace/src/workspace.rs | 5 +- 12 files changed, 741 insertions(+), 584 deletions(-) create mode 100644 crates/db/src/query.rs diff --git a/crates/db/src/db.rs b/crates/db/src/db.rs index 6de51cb0e6..6c6688b0d1 100644 --- a/crates/db/src/db.rs +++ b/crates/db/src/db.rs @@ -1,26 +1,27 @@ pub mod kvp; +pub mod query; // Re-export pub use anyhow; use anyhow::Context; pub use indoc::indoc; pub use lazy_static; +use parking_lot::{Mutex, RwLock}; pub use smol; pub use sqlez; pub use sqlez_macros; +pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}; +pub use util::paths::DB_DIR; use sqlez::domain::Migrator; use sqlez::thread_safe_connection::ThreadSafeConnection; use sqlez_macros::sql; use std::fs::{create_dir_all, remove_dir_all}; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::{SystemTime, UNIX_EPOCH}; use util::{async_iife, ResultExt}; -use util::channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}; -use util::paths::DB_DIR; - -// TODO: Add a savepoint to the thread safe connection initialization and migrations +use util::channel::ReleaseChannel; const CONNECTION_INITIALIZE_QUERY: &'static str = sql!( PRAGMA synchronous=NORMAL; @@ -36,79 +37,117 @@ const DB_INITIALIZE_QUERY: &'static str = sql!( const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB"; lazy_static::lazy_static! { - static ref DB_WIPED: AtomicBool = AtomicBool::new(false); + static ref DB_FILE_OPERATIONS: Mutex<()> = Mutex::new(()); + static ref DB_WIPED: RwLock = RwLock::new(false); + pub static ref BACKUP_DB_PATH: RwLock> = RwLock::new(None); + pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false); } /// Open or create a database at the given directory path. -pub async fn open_db() -> ThreadSafeConnection { - let db_dir = (*DB_DIR).join(Path::new(&format!("0-{}", *RELEASE_CHANNEL_NAME))); +/// This will retry a couple times if there are failures. If opening fails once, the db directory +/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created. +/// In either case, static variables are set so that the user can be notified. +pub async fn open_db(wipe_db: bool, db_dir: &Path, release_channel: &ReleaseChannel) -> ThreadSafeConnection { + let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel.name()))); // If WIPE_DB, delete 0-{channel} - if *RELEASE_CHANNEL == ReleaseChannel::Dev - && std::env::var("WIPE_DB").is_ok() - && !DB_WIPED.load(Ordering::Acquire) + if release_channel == &ReleaseChannel::Dev + && wipe_db + && !*DB_WIPED.read() { - remove_dir_all(&db_dir).ok(); - DB_WIPED.store(true, Ordering::Release); + let mut db_wiped = DB_WIPED.write(); + if !*db_wiped { + remove_dir_all(&main_db_dir).ok(); + + *db_wiped = true; + } } let connection = async_iife!({ + // Note: This still has a race condition where 1 set of migrations succeeds + // (e.g. (Workspace, Editor)) and another fails (e.g. (Workspace, Terminal)) + // This will cause the first connection to have the database taken out + // from under it. This *should* be fine though. The second dabatase failure will + // cause errors in the log and so should be observed by developers while writing + // soon-to-be good migrations. If user databases are corrupted, we toss them out + // and try again from a blank. As long as running all migrations from start to end + // is ok, this race condition will never be triggered. + // + // Basically: Don't ever push invalid migrations to stable or everyone will have + // a bad time. + // If no db folder, create one at 0-{channel} - create_dir_all(&db_dir).context("Could not create db directory")?; - let db_path = db_dir.join(Path::new("db.sqlite")); - - // Try building a connection - if let Some(connection) = ThreadSafeConnection::::builder(db_path.to_string_lossy().as_ref(), true) - .with_db_initialization_query(DB_INITIALIZE_QUERY) - .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) - .build() - .await - .log_err() { - return Ok(connection) + create_dir_all(&main_db_dir).context("Could not create db directory")?; + let db_path = main_db_dir.join(Path::new("db.sqlite")); + + // Optimistically open databases in parallel + if !DB_FILE_OPERATIONS.is_locked() { + // Try building a connection + if let Some(connection) = open_main_db(&db_path).await { + return Ok(connection) + }; } + // Take a lock in the failure case so that we move the db once per process instead + // of potentially multiple times from different threads. This shouldn't happen in the + // normal path + let _lock = DB_FILE_OPERATIONS.lock(); + if let Some(connection) = open_main_db(&db_path).await { + return Ok(connection) + }; + let backup_timestamp = SystemTime::now() .duration_since(UNIX_EPOCH) - .expect( - "System clock is set before the unix timestamp, Zed does not support this region of spacetime" - ) + .expect("System clock is set before the unix timestamp, Zed does not support this region of spacetime") .as_millis(); // If failed, move 0-{channel} to {current unix timestamp}-{channel} - let backup_db_dir = (*DB_DIR).join(Path::new(&format!( - "{}{}", + let backup_db_dir = db_dir.join(Path::new(&format!( + "{}-{}", backup_timestamp, - *RELEASE_CHANNEL_NAME + release_channel.name(), ))); - std::fs::rename(&db_dir, backup_db_dir) + std::fs::rename(&main_db_dir, &backup_db_dir) .context("Failed clean up corrupted database, panicking.")?; - // TODO: Set a constant with the failed timestamp and error so we can notify the user - + // Set a static ref with the failed timestamp and error so we can notify the user + { + let mut guard = BACKUP_DB_PATH.write(); + *guard = Some(backup_db_dir); + } + // Create a new 0-{channel} - create_dir_all(&db_dir).context("Should be able to create the database directory")?; - let db_path = db_dir.join(Path::new("db.sqlite")); + create_dir_all(&main_db_dir).context("Should be able to create the database directory")?; + let db_path = main_db_dir.join(Path::new("db.sqlite")); // Try again - ThreadSafeConnection::::builder(db_path.to_string_lossy().as_ref(), true) - .with_db_initialization_query(DB_INITIALIZE_QUERY) - .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) - .build() - .await + open_main_db(&db_path).await.context("Could not newly created db") }).await.log_err(); - if let Some(connection) = connection { + if let Some(connection) = connection { return connection; } - // TODO: Set another constant so that we can escalate the notification + // Set another static ref so that we can escalate the notification + ALL_FILE_DB_FAILED.store(true, Ordering::Release); // If still failed, create an in memory db with a known name open_fallback_db().await } +async fn open_main_db(db_path: &PathBuf) -> Option> { + println!("Opening main db"); + ThreadSafeConnection::::builder(db_path.to_string_lossy().as_ref(), true) + .with_db_initialization_query(DB_INITIALIZE_QUERY) + .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) + .build() + .await + .log_err() +} + async fn open_fallback_db() -> ThreadSafeConnection { + println!("Opening fallback db"); ThreadSafeConnection::::builder(FALLBACK_DB_NAME, false) .with_db_initialization_query(DB_INITIALIZE_QUERY) .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) @@ -135,17 +174,27 @@ pub async fn open_test_db(db_name: &str) -> ThreadSafeConnection /// Implements a basic DB wrapper for a given domain #[macro_export] -macro_rules! connection { - ($id:ident: $t:ident<$d:ty>) => { - pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$d>); +macro_rules! define_connection { + (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => { + pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>); impl ::std::ops::Deref for $t { - type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$d>; + type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>; fn deref(&self) -> &Self::Target { &self.0 } } + + impl $crate::sqlez::domain::Domain for $t { + fn name() -> &'static str { + stringify!($t) + } + + fn migrations() -> &'static [&'static str] { + $migrations + } + } #[cfg(any(test, feature = "test-support"))] $crate::lazy_static::lazy_static! { @@ -154,322 +203,124 @@ macro_rules! connection { #[cfg(not(any(test, feature = "test-support")))] $crate::lazy_static::lazy_static! { - pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db())); + pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(std::env::var("WIPE_DB").is_ok(), &$crate::DB_DIR, &$crate::RELEASE_CHANNEL))); + } + }; + (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => { + pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>); + + impl ::std::ops::Deref for $t { + type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl $crate::sqlez::domain::Domain for $t { + fn name() -> &'static str { + stringify!($t) + } + + fn migrations() -> &'static [&'static str] { + $migrations + } + } + + #[cfg(any(test, feature = "test-support"))] + $crate::lazy_static::lazy_static! { + pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id)))); + } + + #[cfg(not(any(test, feature = "test-support")))] + $crate::lazy_static::lazy_static! { + pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(std::env::var("WIPE_DB").is_ok(), &$crate::DB_DIR, &$crate::RELEASE_CHANNEL))); } }; } -#[macro_export] -macro_rules! query { - ($vis:vis fn $id:ident() -> Result<()> { $($sql:tt)+ }) => { - $vis fn $id(&self) -> $crate::anyhow::Result<()> { - use $crate::anyhow::Context; +#[cfg(test)] +mod tests { + use std::thread; - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + use sqlez::domain::Domain; + use sqlez_macros::sql; + use tempdir::TempDir; + use util::channel::ReleaseChannel; - self.exec(sql_stmt)?().context(::std::format!( - "Error in {}, exec failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt, - )) + use crate::open_db; + + enum TestDB {} + + impl Domain for TestDB { + fn name() -> &'static str { + "db_tests" } - }; - ($vis:vis async fn $id:ident() -> Result<()> { $($sql:tt)+ }) => { - $vis async fn $id(&self) -> $crate::anyhow::Result<()> { - use $crate::anyhow::Context; - self.write(|connection| { - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - connection.exec(sql_stmt)?().context(::std::format!( - "Error in {}, exec failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - }).await + fn migrations() -> &'static [&'static str] { + &[sql!( + CREATE TABLE test(value); + )] } - }; - ($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => { - $vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> { - use $crate::anyhow::Context; + } + + // Test that wipe_db exists and works and gives a new db + #[test] + fn test_wipe_db() { + env_logger::try_init().ok(); + + smol::block_on(async { + let tempdir = TempDir::new("DbTests").unwrap(); + + let test_db = open_db::(false, tempdir.path(), &util::channel::ReleaseChannel::Dev).await; + test_db.write(|connection| + connection.exec(sql!( + INSERT INTO test(value) VALUES (10) + )).unwrap()().unwrap() + ).await; + drop(test_db); + + let mut guards = vec![]; + for _ in 0..5 { + let path = tempdir.path().to_path_buf(); + let guard = thread::spawn(move || smol::block_on(async { + let test_db = open_db::(true, &path, &ReleaseChannel::Dev).await; + + assert!(test_db.select_row::<()>(sql!(SELECT value FROM test)).unwrap()().unwrap().is_none()) + })); + + guards.push(guard); + } + + for guard in guards { + guard.join().unwrap(); + } + }) + } - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - self.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+)) - .context(::std::format!( - "Error in {}, exec_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - } - }; - ($vis:vis async fn $id:ident($arg:ident: $arg_type:ty) -> Result<()> { $($sql:tt)+ }) => { - $vis async fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<()> { - use $crate::anyhow::Context; - - self.write(move |connection| { - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - connection.exec_bound::<$arg_type>(sql_stmt)?($arg) - .context(::std::format!( - "Error in {}, exec_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - }).await - } - }; - ($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => { - $vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> { - use $crate::anyhow::Context; - - self.write(move |connection| { - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - connection.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+)) - .context(::std::format!( - "Error in {}, exec_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - }).await - } - }; - ($vis:vis fn $id:ident() -> Result> { $($sql:tt)+ }) => { - $vis fn $id(&self) -> $crate::anyhow::Result> { - use $crate::anyhow::Context; - - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - self.select::<$return_type>(sql_stmt)?(()) - .context(::std::format!( - "Error in {}, select_row failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - } - }; - ($vis:vis async fn $id:ident() -> Result> { $($sql:tt)+ }) => { - pub async fn $id(&self) -> $crate::anyhow::Result> { - use $crate::anyhow::Context; - - self.write(|connection| { - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - connection.select::<$return_type>(sql_stmt)?(()) - .context(::std::format!( - "Error in {}, select_row failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - }).await - } - }; - ($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result> { $($sql:tt)+ }) => { - $vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result> { - use $crate::anyhow::Context; - - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - self.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) - .context(::std::format!( - "Error in {}, exec_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - } - }; - ($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result> { $($sql:tt)+ }) => { - $vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result> { - use $crate::anyhow::Context; - - self.write(|connection| { - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - connection.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) - .context(::std::format!( - "Error in {}, exec_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - }).await - } - }; - ($vis:vis fn $id:ident() -> Result> { $($sql:tt)+ }) => { - $vis fn $id(&self) -> $crate::anyhow::Result> { - use $crate::anyhow::Context; - - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - self.select_row::<$return_type>(sql_stmt)?() - .context(::std::format!( - "Error in {}, select_row failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - } - }; - ($vis:vis async fn $id:ident() -> Result> { $($sql:tt)+ }) => { - $vis async fn $id(&self) -> $crate::anyhow::Result> { - use $crate::anyhow::Context; - - self.write(|connection| { - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - connection.select_row::<$return_type>(sql_stmt)?() - .context(::std::format!( - "Error in {}, select_row failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - }).await - } - }; - ($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result> { $($sql:tt)+ }) => { - $vis fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result> { - use $crate::anyhow::Context; - - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg) - .context(::std::format!( - "Error in {}, select_row_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - - } - }; - ($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result> { $($sql:tt)+ }) => { - $vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result> { - use $crate::anyhow::Context; - - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) - .context(::std::format!( - "Error in {}, select_row_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - - } - }; - ($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result> { $($sql:tt)+ }) => { - $vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result> { - use $crate::anyhow::Context; - - - self.write(|connection| { - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - connection.select_row_bound::<($($arg_type),+), $return_type>(indoc! { $sql })?(($($arg),+)) - .context(::std::format!( - "Error in {}, select_row_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - )) - }).await - } - }; - ($vis:vis fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => { - $vis fn $id(&self) -> $crate::anyhow::Result<$return_type> { - use $crate::anyhow::Context; - - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - self.select_row::<$return_type>(indoc! { $sql })?() - .context(::std::format!( - "Error in {}, select_row_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - ))? - .context(::std::format!( - "Error in {}, select_row_bound expected single row result but found none for: {}", - ::std::stringify!($id), - sql_stmt - )) - } - }; - ($vis:vis async fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => { - $vis async fn $id(&self) -> $crate::anyhow::Result<$return_type> { - use $crate::anyhow::Context; - - self.write(|connection| { - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - connection.select_row::<$return_type>(sql_stmt)?() - .context(::std::format!( - "Error in {}, select_row_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - ))? - .context(::std::format!( - "Error in {}, select_row_bound expected single row result but found none for: {}", - ::std::stringify!($id), - sql_stmt - )) - }).await - } - }; - ($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<$return_type:ty> { $($sql:tt)+ }) => { - pub fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<$return_type> { - use $crate::anyhow::Context; - - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg) - .context(::std::format!( - "Error in {}, select_row_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - ))? - .context(::std::format!( - "Error in {}, select_row_bound expected single row result but found none for: {}", - ::std::stringify!($id), - sql_stmt - )) - } - }; - ($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => { - $vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> { - use $crate::anyhow::Context; - - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) - .context(::std::format!( - "Error in {}, select_row_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - ))? - .context(::std::format!( - "Error in {}, select_row_bound expected single row result but found none for: {}", - ::std::stringify!($id), - sql_stmt - )) - } - }; - ($vis:vis fn async $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => { - $vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> { - use $crate::anyhow::Context; - - - self.write(|connection| { - let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); - - connection.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) - .context(::std::format!( - "Error in {}, select_row_bound failed to execute or parse for: {}", - ::std::stringify!($id), - sql_stmt - ))? - .context(::std::format!( - "Error in {}, select_row_bound expected single row result but found none for: {}", - ::std::stringify!($id), - sql_stmt - )) - }).await - } - }; + // Test a file system failure (like in create_dir_all()) + #[test] + fn test_file_system_failure() { + + } + + // Test happy path where everything exists and opens + #[test] + fn test_open_db() { + + } + + // Test bad migration panics + #[test] + fn test_bad_migration_panics() { + + } + + /// Test that DB exists but corrupted (causing recreate) + #[test] + fn test_db_corruption() { + + + // open_db(db_dir, release_channel) + } } diff --git a/crates/db/src/kvp.rs b/crates/db/src/kvp.rs index 70ee9f64da..0b0cdd9aa1 100644 --- a/crates/db/src/kvp.rs +++ b/crates/db/src/kvp.rs @@ -1,26 +1,15 @@ -use sqlez::domain::Domain; use sqlez_macros::sql; -use crate::{connection, query}; +use crate::{define_connection, query}; -connection!(KEY_VALUE_STORE: KeyValueStore); - -impl Domain for KeyValueStore { - fn name() -> &'static str { - "kvp" - } - - fn migrations() -> &'static [&'static str] { - // Legacy migrations using rusqlite may have already created kv_store during alpha, - // migrations must be infallible so this must have 'IF NOT EXISTS' - &[sql!( - CREATE TABLE IF NOT EXISTS kv_store( - key TEXT PRIMARY KEY, - value TEXT NOT NULL - ) STRICT; - )] - } -} +define_connection!(pub static ref KEY_VALUE_STORE: KeyValueStore<()> = + &[sql!( + CREATE TABLE IF NOT EXISTS kv_store( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ) STRICT; + )]; +); impl KeyValueStore { query! { diff --git a/crates/db/src/query.rs b/crates/db/src/query.rs new file mode 100644 index 0000000000..731fca15cb --- /dev/null +++ b/crates/db/src/query.rs @@ -0,0 +1,314 @@ +#[macro_export] +macro_rules! query { + ($vis:vis fn $id:ident() -> Result<()> { $($sql:tt)+ }) => { + $vis fn $id(&self) -> $crate::anyhow::Result<()> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.exec(sql_stmt)?().context(::std::format!( + "Error in {}, exec failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt, + )) + } + }; + ($vis:vis async fn $id:ident() -> Result<()> { $($sql:tt)+ }) => { + $vis async fn $id(&self) -> $crate::anyhow::Result<()> { + use $crate::anyhow::Context; + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.exec(sql_stmt)?().context(::std::format!( + "Error in {}, exec failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => { + $vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, exec_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis async fn $id:ident($arg:ident: $arg_type:ty) -> Result<()> { $($sql:tt)+ }) => { + $vis async fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<()> { + use $crate::anyhow::Context; + + self.write(move |connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.exec_bound::<$arg_type>(sql_stmt)?($arg) + .context(::std::format!( + "Error in {}, exec_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => { + $vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> { + use $crate::anyhow::Context; + + self.write(move |connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, exec_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident() -> Result> { $($sql:tt)+ }) => { + $vis fn $id(&self) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select::<$return_type>(sql_stmt)?(()) + .context(::std::format!( + "Error in {}, select_row failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis async fn $id:ident() -> Result> { $($sql:tt)+ }) => { + pub async fn $id(&self) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.select::<$return_type>(sql_stmt)?(()) + .context(::std::format!( + "Error in {}, select_row failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result> { $($sql:tt)+ }) => { + $vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, exec_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result> { $($sql:tt)+ }) => { + $vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, exec_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident() -> Result> { $($sql:tt)+ }) => { + $vis fn $id(&self) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_row::<$return_type>(sql_stmt)?() + .context(::std::format!( + "Error in {}, select_row failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis async fn $id:ident() -> Result> { $($sql:tt)+ }) => { + $vis async fn $id(&self) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.select_row::<$return_type>(sql_stmt)?() + .context(::std::format!( + "Error in {}, select_row failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result> { $($sql:tt)+ }) => { + $vis fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg) + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + + } + }; + ($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result> { $($sql:tt)+ }) => { + $vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + + } + }; + ($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result> { $($sql:tt)+ }) => { + $vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.select_row_bound::<($($arg_type),+), $return_type>(indoc! { $sql })?(($($arg),+)) + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => { + $vis fn $id(&self) -> $crate::anyhow::Result<$return_type> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_row::<$return_type>(indoc! { $sql })?() + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + ))? + .context(::std::format!( + "Error in {}, select_row_bound expected single row result but found none for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis async fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => { + $vis async fn $id(&self) -> $crate::anyhow::Result<$return_type> { + use $crate::anyhow::Context; + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.select_row::<$return_type>(sql_stmt)?() + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + ))? + .context(::std::format!( + "Error in {}, select_row_bound expected single row result but found none for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<$return_type:ty> { $($sql:tt)+ }) => { + pub fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<$return_type> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg) + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + ))? + .context(::std::format!( + "Error in {}, select_row_bound expected single row result but found none for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => { + $vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + ))? + .context(::std::format!( + "Error in {}, select_row_bound expected single row result but found none for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis fn async $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => { + $vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> { + use $crate::anyhow::Context; + + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + ))? + .context(::std::format!( + "Error in {}, select_row_bound expected single row result but found none for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; +} diff --git a/crates/editor/src/persistence.rs b/crates/editor/src/persistence.rs index 3416f479e7..31ada105af 100644 --- a/crates/editor/src/persistence.rs +++ b/crates/editor/src/persistence.rs @@ -1,19 +1,11 @@ use std::path::PathBuf; -use crate::Editor; use db::sqlez_macros::sql; -use db::{connection, query}; -use sqlez::domain::Domain; -use workspace::{ItemId, Workspace, WorkspaceId}; +use db::{define_connection, query}; +use workspace::{ItemId, WorkspaceDb, WorkspaceId}; -connection!(DB: EditorDb<(Workspace, Editor)>); - -impl Domain for Editor { - fn name() -> &'static str { - "editor" - } - - fn migrations() -> &'static [&'static str] { +define_connection!( + pub static ref DB: EditorDb = &[sql! ( CREATE TABLE editors( item_id INTEGER NOT NULL, @@ -21,12 +13,11 @@ impl Domain for Editor { path BLOB NOT NULL, PRIMARY KEY(item_id, workspace_id), FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) - ON DELETE CASCADE - ON UPDATE CASCADE - ) STRICT; - )] - } -} + ON DELETE CASCADE + ON UPDATE CASCADE + ) STRICT; + )]; +); impl EditorDb { query! { diff --git a/crates/sqlez/src/bindable.rs b/crates/sqlez/src/bindable.rs index ffef7814f9..3649037e50 100644 --- a/crates/sqlez/src/bindable.rs +++ b/crates/sqlez/src/bindable.rs @@ -137,13 +137,6 @@ impl Column for usize { } } -impl Bind for () { - fn bind(&self, statement: &Statement, start_index: i32) -> Result { - statement.bind_null(start_index)?; - Ok(start_index + 1) - } -} - impl Bind for &str { fn bind(&self, statement: &Statement, start_index: i32) -> Result { statement.bind_text(start_index, self)?; @@ -179,78 +172,6 @@ impl Column for String { } } -impl Bind for (T1, T2) { - fn bind(&self, statement: &Statement, start_index: i32) -> Result { - let next_index = self.0.bind(statement, start_index)?; - self.1.bind(statement, next_index) - } -} - -impl Column for (T1, T2) { - fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { - let (first, next_index) = T1::column(statement, start_index)?; - let (second, next_index) = T2::column(statement, next_index)?; - Ok(((first, second), next_index)) - } -} - -impl Bind for (T1, T2, T3) { - fn bind(&self, statement: &Statement, start_index: i32) -> Result { - let next_index = self.0.bind(statement, start_index)?; - let next_index = self.1.bind(statement, next_index)?; - self.2.bind(statement, next_index) - } -} - -impl Column for (T1, T2, T3) { - fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { - let (first, next_index) = T1::column(statement, start_index)?; - let (second, next_index) = T2::column(statement, next_index)?; - let (third, next_index) = T3::column(statement, next_index)?; - Ok(((first, second, third), next_index)) - } -} - -impl Bind for (T1, T2, T3, T4) { - fn bind(&self, statement: &Statement, start_index: i32) -> Result { - let next_index = self.0.bind(statement, start_index)?; - let next_index = self.1.bind(statement, next_index)?; - let next_index = self.2.bind(statement, next_index)?; - self.3.bind(statement, next_index) - } -} - -impl Column for (T1, T2, T3, T4) { - fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { - let (first, next_index) = T1::column(statement, start_index)?; - let (second, next_index) = T2::column(statement, next_index)?; - let (third, next_index) = T3::column(statement, next_index)?; - let (fourth, next_index) = T4::column(statement, next_index)?; - Ok(((first, second, third, fourth), next_index)) - } -} - -impl Bind for (T1, T2, T3, T4, T5) { - fn bind(&self, statement: &Statement, start_index: i32) -> Result { - let next_index = self.0.bind(statement, start_index)?; - let next_index = self.1.bind(statement, next_index)?; - let next_index = self.2.bind(statement, next_index)?; - let next_index = self.3.bind(statement, next_index)?; - self.4.bind(statement, next_index) - } -} - -impl Column for (T1, T2, T3, T4, T5) { - fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { - let (first, next_index) = T1::column(statement, start_index)?; - let (second, next_index) = T2::column(statement, next_index)?; - let (third, next_index) = T3::column(statement, next_index)?; - let (fourth, next_index) = T4::column(statement, next_index)?; - let (fifth, next_index) = T5::column(statement, next_index)?; - Ok(((first, second, third, fourth, fifth), next_index)) - } -} - impl Bind for Option { fn bind(&self, statement: &Statement, start_index: i32) -> Result { if let Some(this) = self { @@ -344,3 +265,88 @@ impl Column for PathBuf { )) } } + +/// Unit impls do nothing. This simplifies query macros +impl Bind for () { + fn bind(&self, _statement: &Statement, start_index: i32) -> Result { + Ok(start_index) + } +} + +impl Column for () { + fn column(_statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + Ok(((), start_index)) + } +} + +impl Bind for (T1, T2) { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + let next_index = self.0.bind(statement, start_index)?; + self.1.bind(statement, next_index) + } +} + +impl Column for (T1, T2) { + fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let (first, next_index) = T1::column(statement, start_index)?; + let (second, next_index) = T2::column(statement, next_index)?; + Ok(((first, second), next_index)) + } +} + +impl Bind for (T1, T2, T3) { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + let next_index = self.0.bind(statement, start_index)?; + let next_index = self.1.bind(statement, next_index)?; + self.2.bind(statement, next_index) + } +} + +impl Column for (T1, T2, T3) { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let (first, next_index) = T1::column(statement, start_index)?; + let (second, next_index) = T2::column(statement, next_index)?; + let (third, next_index) = T3::column(statement, next_index)?; + Ok(((first, second, third), next_index)) + } +} + +impl Bind for (T1, T2, T3, T4) { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + let next_index = self.0.bind(statement, start_index)?; + let next_index = self.1.bind(statement, next_index)?; + let next_index = self.2.bind(statement, next_index)?; + self.3.bind(statement, next_index) + } +} + +impl Column for (T1, T2, T3, T4) { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let (first, next_index) = T1::column(statement, start_index)?; + let (second, next_index) = T2::column(statement, next_index)?; + let (third, next_index) = T3::column(statement, next_index)?; + let (fourth, next_index) = T4::column(statement, next_index)?; + Ok(((first, second, third, fourth), next_index)) + } +} + +impl Bind for (T1, T2, T3, T4, T5) { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + let next_index = self.0.bind(statement, start_index)?; + let next_index = self.1.bind(statement, next_index)?; + let next_index = self.2.bind(statement, next_index)?; + let next_index = self.3.bind(statement, next_index)?; + self.4.bind(statement, next_index) + } +} + +impl Column for (T1, T2, T3, T4, T5) { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let (first, next_index) = T1::column(statement, start_index)?; + let (second, next_index) = T2::column(statement, next_index)?; + let (third, next_index) = T3::column(statement, next_index)?; + let (fourth, next_index) = T4::column(statement, next_index)?; + let (fifth, next_index) = T5::column(statement, next_index)?; + Ok(((first, second, third, fourth, fifth), next_index)) + } +} diff --git a/crates/sqlez/src/connection.rs b/crates/sqlez/src/connection.rs index 0456266594..3342845d14 100644 --- a/crates/sqlez/src/connection.rs +++ b/crates/sqlez/src/connection.rs @@ -1,4 +1,5 @@ use std::{ + cell::RefCell, ffi::{CStr, CString}, marker::PhantomData, path::Path, @@ -11,7 +12,7 @@ use libsqlite3_sys::*; pub struct Connection { pub(crate) sqlite3: *mut sqlite3, persistent: bool, - pub(crate) write: bool, + pub(crate) write: RefCell, _sqlite: PhantomData, } unsafe impl Send for Connection {} @@ -21,7 +22,7 @@ impl Connection { let mut connection = Self { sqlite3: 0 as *mut _, persistent, - write: true, + write: RefCell::new(true), _sqlite: PhantomData, }; @@ -64,7 +65,7 @@ impl Connection { } pub fn can_write(&self) -> bool { - self.write + *self.write.borrow() } pub fn backup_main(&self, destination: &Connection) -> Result<()> { @@ -152,6 +153,13 @@ impl Connection { )) } } + + pub(crate) fn with_write(&self, callback: impl FnOnce(&Connection) -> T) -> T { + *self.write.borrow_mut() = true; + let result = callback(self); + *self.write.borrow_mut() = false; + result + } } impl Drop for Connection { diff --git a/crates/sqlez/src/domain.rs b/crates/sqlez/src/domain.rs index 3a477b2bc9..a83f4e18d6 100644 --- a/crates/sqlez/src/domain.rs +++ b/crates/sqlez/src/domain.rs @@ -1,11 +1,11 @@ use crate::connection::Connection; -pub trait Domain { +pub trait Domain: 'static { fn name() -> &'static str; fn migrations() -> &'static [&'static str]; } -pub trait Migrator { +pub trait Migrator: 'static { fn migrate(connection: &Connection) -> anyhow::Result<()>; } diff --git a/crates/sqlez/src/migrations.rs b/crates/sqlez/src/migrations.rs index 41c505f85b..aa8d5fe00b 100644 --- a/crates/sqlez/src/migrations.rs +++ b/crates/sqlez/src/migrations.rs @@ -12,6 +12,7 @@ use crate::connection::Connection; impl Connection { pub fn migrate(&self, domain: &'static str, migrations: &[&'static str]) -> Result<()> { self.with_savepoint("migrating", || { + println!("Processing domain"); // Setup the migrations table unconditionally self.exec(indoc! {" CREATE TABLE IF NOT EXISTS migrations ( @@ -43,11 +44,13 @@ impl Connection { {}", domain, index, completed_migration, migration})); } else { // Migration already run. Continue + println!("Migration already run"); continue; } } self.exec(migration)?()?; + println!("Ran migration"); store_completed_migration((domain, index, *migration))?; } diff --git a/crates/sqlez/src/thread_safe_connection.rs b/crates/sqlez/src/thread_safe_connection.rs index 4849e785b5..77ba3406a2 100644 --- a/crates/sqlez/src/thread_safe_connection.rs +++ b/crates/sqlez/src/thread_safe_connection.rs @@ -5,17 +5,13 @@ use parking_lot::{Mutex, RwLock}; use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread}; use thread_local::ThreadLocal; -use crate::{ - connection::Connection, - domain::{Domain, Migrator}, - util::UnboundedSyncSender, -}; +use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender}; const MIGRATION_RETRIES: usize = 10; -type QueuedWrite = Box; +type QueuedWrite = Box; type WriteQueueConstructor = - Box Box>; + Box Box>; lazy_static! { /// List of queues of tasks by database uri. This lets us serialize writes to the database /// and have a single worker thread per db file. This means many thread safe connections @@ -28,18 +24,18 @@ lazy_static! { /// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static, /// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection /// may be accessed by passing a callback to the `write` function which will queue the callback -pub struct ThreadSafeConnection { +pub struct ThreadSafeConnection { uri: Arc, persistent: bool, connection_initialize_query: Option<&'static str>, connections: Arc>, - _migrator: PhantomData, + _migrator: PhantomData<*mut M>, } -unsafe impl Send for ThreadSafeConnection {} -unsafe impl Sync for ThreadSafeConnection {} +unsafe impl Send for ThreadSafeConnection {} +unsafe impl Sync for ThreadSafeConnection {} -pub struct ThreadSafeConnectionBuilder { +pub struct ThreadSafeConnectionBuilder { db_initialize_query: Option<&'static str>, write_queue_constructor: Option, connection: ThreadSafeConnection, @@ -54,6 +50,13 @@ impl ThreadSafeConnectionBuilder { self } + /// Queues an initialization query for the database file. This must be infallible + /// but may cause changes to the database file such as with `PRAGMA journal_mode` + pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self { + self.db_initialize_query = Some(initialize_query); + self + } + /// Specifies how the thread safe connection should serialize writes. If provided /// the connection will call the write_queue_constructor for each database file in /// this process. The constructor is responsible for setting up a background thread or @@ -66,13 +69,6 @@ impl ThreadSafeConnectionBuilder { self } - /// Queues an initialization query for the database file. This must be infallible - /// but may cause changes to the database file such as with `PRAGMA journal_mode` - pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self { - self.db_initialize_query = Some(initialize_query); - self - } - pub async fn build(self) -> anyhow::Result> { self.connection .initialize_queues(self.write_queue_constructor); @@ -100,6 +96,7 @@ impl ThreadSafeConnectionBuilder { .with_savepoint("thread_safe_multi_migration", || M::migrate(connection)); if migration_result.is_ok() { + println!("Migration succeded"); break; } } @@ -113,38 +110,17 @@ impl ThreadSafeConnectionBuilder { } impl ThreadSafeConnection { - fn initialize_queues(&self, write_queue_constructor: Option) { + fn initialize_queues(&self, write_queue_constructor: Option) -> bool { if !QUEUES.read().contains_key(&self.uri) { let mut queues = QUEUES.write(); if !queues.contains_key(&self.uri) { - let mut write_connection = self.create_connection(); - // Enable writes for this connection - write_connection.write = true; - if let Some(mut write_queue_constructor) = write_queue_constructor { - let write_channel = write_queue_constructor(write_connection); - queues.insert(self.uri.clone(), write_channel); - } else { - use std::sync::mpsc::channel; - - let (sender, reciever) = channel::(); - thread::spawn(move || { - while let Ok(write) = reciever.recv() { - write(&write_connection) - } - }); - - let sender = UnboundedSyncSender::new(sender); - queues.insert( - self.uri.clone(), - Box::new(move |queued_write| { - sender - .send(queued_write) - .expect("Could not send write action to backgorund thread"); - }), - ); - } + let mut write_queue_constructor = + write_queue_constructor.unwrap_or(background_thread_queue()); + queues.insert(self.uri.clone(), write_queue_constructor()); + return true; } } + return false; } pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder { @@ -163,20 +139,21 @@ impl ThreadSafeConnection { /// Opens a new db connection with the initialized file path. This is internal and only /// called from the deref function. - fn open_file(&self) -> Connection { - Connection::open_file(self.uri.as_ref()) + fn open_file(uri: &str) -> Connection { + Connection::open_file(uri) } /// Opens a shared memory connection using the file path as the identifier. This is internal /// and only called from the deref function. - fn open_shared_memory(&self) -> Connection { - Connection::open_memory(Some(self.uri.as_ref())) + fn open_shared_memory(uri: &str) -> Connection { + Connection::open_memory(Some(uri)) } pub fn write( &self, callback: impl 'static + Send + FnOnce(&Connection) -> T, ) -> impl Future { + // Check and invalidate queue and maybe recreate queue let queues = QUEUES.read(); let write_channel = queues .get(&self.uri) @@ -185,24 +162,32 @@ impl ThreadSafeConnection { // Create a one shot channel for the result of the queued write // so we can await on the result let (sender, reciever) = oneshot::channel(); - write_channel(Box::new(move |connection| { - sender.send(callback(connection)).ok(); + + let thread_safe_connection = (*self).clone(); + write_channel(Box::new(move || { + let connection = thread_safe_connection.deref(); + let result = connection.with_write(|connection| callback(connection)); + sender.send(result).ok(); })); reciever.map(|response| response.expect("Background writer thread unexpectedly closed")) } - pub(crate) fn create_connection(&self) -> Connection { - let mut connection = if self.persistent { - self.open_file() + pub(crate) fn create_connection( + persistent: bool, + uri: &str, + connection_initialize_query: Option<&'static str>, + ) -> Connection { + let mut connection = if persistent { + Self::open_file(uri) } else { - self.open_shared_memory() + Self::open_shared_memory(uri) }; // Disallow writes on the connection. The only writes allowed for thread safe connections // are from the background thread that can serialize them. - connection.write = false; + *connection.write.get_mut() = false; - if let Some(initialize_query) = self.connection_initialize_query { + if let Some(initialize_query) = connection_initialize_query { connection.exec(initialize_query).expect(&format!( "Initialize query failed to execute: {}", initialize_query @@ -236,7 +221,7 @@ impl ThreadSafeConnection<()> { } } -impl Clone for ThreadSafeConnection { +impl Clone for ThreadSafeConnection { fn clone(&self) -> Self { Self { uri: self.uri.clone(), @@ -252,16 +237,41 @@ impl Deref for ThreadSafeConnection { type Target = Connection; fn deref(&self) -> &Self::Target { - self.connections.get_or(|| self.create_connection()) + self.connections.get_or(|| { + Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query) + }) } } -pub fn locking_queue() -> WriteQueueConstructor { - Box::new(|connection| { - let connection = Mutex::new(connection); +pub fn background_thread_queue() -> WriteQueueConstructor { + use std::sync::mpsc::channel; + + Box::new(|| { + let (sender, reciever) = channel::(); + + thread::spawn(move || { + while let Ok(write) = reciever.recv() { + write() + } + }); + + let sender = UnboundedSyncSender::new(sender); Box::new(move |queued_write| { - let connection = connection.lock(); - queued_write(&connection) + sender + .send(queued_write) + .expect("Could not send write action to background thread"); + }) + }) +} + +pub fn locking_queue() -> WriteQueueConstructor { + Box::new(|| { + let mutex = Mutex::new(()); + Box::new(move |queued_write| { + eprintln!("Write started"); + let _ = mutex.lock(); + queued_write(); + eprintln!("Write finished"); }) }) } @@ -269,7 +279,8 @@ pub fn locking_queue() -> WriteQueueConstructor { #[cfg(test)] mod test { use indoc::indoc; - use std::ops::Deref; + use lazy_static::__Deref; + use std::thread; use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection}; diff --git a/crates/terminal/src/persistence.rs b/crates/terminal/src/persistence.rs index f9cfb6fc01..1669a3a546 100644 --- a/crates/terminal/src/persistence.rs +++ b/crates/terminal/src/persistence.rs @@ -1,19 +1,11 @@ use std::path::PathBuf; -use db::{connection, query, sqlez::domain::Domain, sqlez_macros::sql}; +use db::{define_connection, query, sqlez_macros::sql}; -use workspace::{ItemId, Workspace, WorkspaceId}; +use workspace::{ItemId, WorkspaceDb, WorkspaceId}; -use crate::Terminal; - -connection!(TERMINAL_CONNECTION: TerminalDb<(Workspace, Terminal)>); - -impl Domain for Terminal { - fn name() -> &'static str { - "terminal" - } - - fn migrations() -> &'static [&'static str] { +define_connection! { + pub static ref TERMINAL_CONNECTION: TerminalDb = &[sql!( CREATE TABLE terminals ( workspace_id INTEGER, @@ -23,8 +15,7 @@ impl Domain for Terminal { FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) ON DELETE CASCADE ) STRICT; - )] - } + )]; } impl TerminalDb { diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index db59141087..a0cc48ca1c 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -5,30 +5,21 @@ pub mod model; use std::path::Path; use anyhow::{anyhow, bail, Context, Result}; -use db::{connection, query, sqlez::connection::Connection, sqlez_macros::sql}; +use db::{define_connection, query, sqlez::connection::Connection, sqlez_macros::sql}; use gpui::Axis; -use db::sqlez::domain::Domain; use util::{iife, unzip_option, ResultExt}; use crate::dock::DockPosition; use crate::WorkspaceId; -use super::Workspace; - use model::{ GroupId, PaneId, SerializedItem, SerializedPane, SerializedPaneGroup, SerializedWorkspace, WorkspaceLocation, }; -connection!(DB: WorkspaceDb); - -impl Domain for Workspace { - fn name() -> &'static str { - "workspace" - } - - fn migrations() -> &'static [&'static str] { +define_connection! { + pub static ref DB: WorkspaceDb<()> = &[sql!( CREATE TABLE workspaces( workspace_id INTEGER PRIMARY KEY, @@ -40,7 +31,7 @@ impl Domain for Workspace { timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL, FOREIGN KEY(dock_pane) REFERENCES panes(pane_id) ) STRICT; - + CREATE TABLE pane_groups( group_id INTEGER PRIMARY KEY, workspace_id INTEGER NOT NULL, @@ -48,29 +39,29 @@ impl Domain for Workspace { position INTEGER, // NULL indicates that this is a root node axis TEXT NOT NULL, // Enum: 'Vertical' / 'Horizontal' FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) - ON DELETE CASCADE - ON UPDATE CASCADE, + ON DELETE CASCADE + ON UPDATE CASCADE, FOREIGN KEY(parent_group_id) REFERENCES pane_groups(group_id) ON DELETE CASCADE ) STRICT; - + CREATE TABLE panes( pane_id INTEGER PRIMARY KEY, workspace_id INTEGER NOT NULL, active INTEGER NOT NULL, // Boolean FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) - ON DELETE CASCADE - ON UPDATE CASCADE + ON DELETE CASCADE + ON UPDATE CASCADE ) STRICT; - + CREATE TABLE center_panes( pane_id INTEGER PRIMARY KEY, parent_group_id INTEGER, // NULL means that this is a root pane position INTEGER, // NULL means that this is a root pane FOREIGN KEY(pane_id) REFERENCES panes(pane_id) - ON DELETE CASCADE, + ON DELETE CASCADE, FOREIGN KEY(parent_group_id) REFERENCES pane_groups(group_id) ON DELETE CASCADE ) STRICT; - + CREATE TABLE items( item_id INTEGER NOT NULL, // This is the item's view id, so this is not unique workspace_id INTEGER NOT NULL, @@ -79,14 +70,13 @@ impl Domain for Workspace { position INTEGER NOT NULL, active INTEGER NOT NULL, FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) - ON DELETE CASCADE - ON UPDATE CASCADE, + ON DELETE CASCADE + ON UPDATE CASCADE, FOREIGN KEY(pane_id) REFERENCES panes(pane_id) - ON DELETE CASCADE, + ON DELETE CASCADE, PRIMARY KEY(item_id, workspace_id) ) STRICT; - )] - } + )]; } impl WorkspaceDb { @@ -149,7 +139,7 @@ impl WorkspaceDb { UPDATE workspaces SET dock_pane = NULL WHERE workspace_id = ?1; DELETE FROM pane_groups WHERE workspace_id = ?1; DELETE FROM panes WHERE workspace_id = ?1;))?(workspace.id) - .context("Clearing old panes")?; + .expect("Clearing old panes"); conn.exec_bound(sql!( DELETE FROM workspaces WHERE workspace_location = ? AND workspace_id != ? diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 66ef63f27f..8e9131839d 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -44,8 +44,11 @@ use language::LanguageRegistry; use log::{error, warn}; pub use pane::*; pub use pane_group::*; -pub use persistence::model::{ItemId, WorkspaceLocation}; use persistence::{model::SerializedItem, DB}; +pub use persistence::{ + model::{ItemId, WorkspaceLocation}, + WorkspaceDb, +}; use postage::prelude::Stream; use project::{Project, ProjectEntryId, ProjectPath, ProjectStore, Worktree, WorktreeId}; use serde::Deserialize;