Reworked thread safe connection be threadsafer,,,, again

Co-Authored-By: kay@zed.dev
This commit is contained in:
Mikayla Maki 2022-12-01 18:31:05 -08:00
parent 189a820113
commit 5e240f98f0
12 changed files with 741 additions and 584 deletions

View file

@ -1,26 +1,27 @@
pub mod kvp; pub mod kvp;
pub mod query;
// Re-export // Re-export
pub use anyhow; pub use anyhow;
use anyhow::Context; use anyhow::Context;
pub use indoc::indoc; pub use indoc::indoc;
pub use lazy_static; pub use lazy_static;
use parking_lot::{Mutex, RwLock};
pub use smol; pub use smol;
pub use sqlez; pub use sqlez;
pub use sqlez_macros; pub use sqlez_macros;
pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
pub use util::paths::DB_DIR;
use sqlez::domain::Migrator; use sqlez::domain::Migrator;
use sqlez::thread_safe_connection::ThreadSafeConnection; use sqlez::thread_safe_connection::ThreadSafeConnection;
use sqlez_macros::sql; use sqlez_macros::sql;
use std::fs::{create_dir_all, remove_dir_all}; use std::fs::{create_dir_all, remove_dir_all};
use std::path::Path; use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use util::{async_iife, ResultExt}; use util::{async_iife, ResultExt};
use util::channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}; use util::channel::ReleaseChannel;
use util::paths::DB_DIR;
// TODO: Add a savepoint to the thread safe connection initialization and migrations
const CONNECTION_INITIALIZE_QUERY: &'static str = sql!( const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
PRAGMA synchronous=NORMAL; PRAGMA synchronous=NORMAL;
@ -36,79 +37,117 @@ const DB_INITIALIZE_QUERY: &'static str = sql!(
const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB"; const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
lazy_static::lazy_static! { lazy_static::lazy_static! {
static ref DB_WIPED: AtomicBool = AtomicBool::new(false); static ref DB_FILE_OPERATIONS: Mutex<()> = Mutex::new(());
static ref DB_WIPED: RwLock<bool> = RwLock::new(false);
pub static ref BACKUP_DB_PATH: RwLock<Option<PathBuf>> = RwLock::new(None);
pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
} }
/// Open or create a database at the given directory path. /// Open or create a database at the given directory path.
pub async fn open_db<M: Migrator>() -> ThreadSafeConnection<M> { /// This will retry a couple times if there are failures. If opening fails once, the db directory
let db_dir = (*DB_DIR).join(Path::new(&format!("0-{}", *RELEASE_CHANNEL_NAME))); /// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
/// In either case, static variables are set so that the user can be notified.
pub async fn open_db<M: Migrator + 'static>(wipe_db: bool, db_dir: &Path, release_channel: &ReleaseChannel) -> ThreadSafeConnection<M> {
let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel.name())));
// If WIPE_DB, delete 0-{channel} // If WIPE_DB, delete 0-{channel}
if *RELEASE_CHANNEL == ReleaseChannel::Dev if release_channel == &ReleaseChannel::Dev
&& std::env::var("WIPE_DB").is_ok() && wipe_db
&& !DB_WIPED.load(Ordering::Acquire) && !*DB_WIPED.read()
{ {
remove_dir_all(&db_dir).ok(); let mut db_wiped = DB_WIPED.write();
DB_WIPED.store(true, Ordering::Release); if !*db_wiped {
remove_dir_all(&main_db_dir).ok();
*db_wiped = true;
}
} }
let connection = async_iife!({ let connection = async_iife!({
// Note: This still has a race condition where 1 set of migrations succeeds
// (e.g. (Workspace, Editor)) and another fails (e.g. (Workspace, Terminal))
// This will cause the first connection to have the database taken out
// from under it. This *should* be fine though. The second dabatase failure will
// cause errors in the log and so should be observed by developers while writing
// soon-to-be good migrations. If user databases are corrupted, we toss them out
// and try again from a blank. As long as running all migrations from start to end
// is ok, this race condition will never be triggered.
//
// Basically: Don't ever push invalid migrations to stable or everyone will have
// a bad time.
// If no db folder, create one at 0-{channel} // If no db folder, create one at 0-{channel}
create_dir_all(&db_dir).context("Could not create db directory")?; create_dir_all(&main_db_dir).context("Could not create db directory")?;
let db_path = db_dir.join(Path::new("db.sqlite")); let db_path = main_db_dir.join(Path::new("db.sqlite"));
// Try building a connection // Optimistically open databases in parallel
if let Some(connection) = ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true) if !DB_FILE_OPERATIONS.is_locked() {
.with_db_initialization_query(DB_INITIALIZE_QUERY) // Try building a connection
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) if let Some(connection) = open_main_db(&db_path).await {
.build() return Ok(connection)
.await };
.log_err() {
return Ok(connection)
} }
// Take a lock in the failure case so that we move the db once per process instead
// of potentially multiple times from different threads. This shouldn't happen in the
// normal path
let _lock = DB_FILE_OPERATIONS.lock();
if let Some(connection) = open_main_db(&db_path).await {
return Ok(connection)
};
let backup_timestamp = SystemTime::now() let backup_timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
.expect( .expect("System clock is set before the unix timestamp, Zed does not support this region of spacetime")
"System clock is set before the unix timestamp, Zed does not support this region of spacetime"
)
.as_millis(); .as_millis();
// If failed, move 0-{channel} to {current unix timestamp}-{channel} // If failed, move 0-{channel} to {current unix timestamp}-{channel}
let backup_db_dir = (*DB_DIR).join(Path::new(&format!( let backup_db_dir = db_dir.join(Path::new(&format!(
"{}{}", "{}-{}",
backup_timestamp, backup_timestamp,
*RELEASE_CHANNEL_NAME release_channel.name(),
))); )));
std::fs::rename(&db_dir, backup_db_dir) std::fs::rename(&main_db_dir, &backup_db_dir)
.context("Failed clean up corrupted database, panicking.")?; .context("Failed clean up corrupted database, panicking.")?;
// TODO: Set a constant with the failed timestamp and error so we can notify the user // Set a static ref with the failed timestamp and error so we can notify the user
{
let mut guard = BACKUP_DB_PATH.write();
*guard = Some(backup_db_dir);
}
// Create a new 0-{channel} // Create a new 0-{channel}
create_dir_all(&db_dir).context("Should be able to create the database directory")?; create_dir_all(&main_db_dir).context("Should be able to create the database directory")?;
let db_path = db_dir.join(Path::new("db.sqlite")); let db_path = main_db_dir.join(Path::new("db.sqlite"));
// Try again // Try again
ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true) open_main_db(&db_path).await.context("Could not newly created db")
.with_db_initialization_query(DB_INITIALIZE_QUERY)
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
.build()
.await
}).await.log_err(); }).await.log_err();
if let Some(connection) = connection { if let Some(connection) = connection {
return connection; return connection;
} }
// TODO: Set another constant so that we can escalate the notification // Set another static ref so that we can escalate the notification
ALL_FILE_DB_FAILED.store(true, Ordering::Release);
// If still failed, create an in memory db with a known name // If still failed, create an in memory db with a known name
open_fallback_db().await open_fallback_db().await
} }
async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
println!("Opening main db");
ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
.with_db_initialization_query(DB_INITIALIZE_QUERY)
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
.build()
.await
.log_err()
}
async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> { async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
println!("Opening fallback db");
ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false) ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
.with_db_initialization_query(DB_INITIALIZE_QUERY) .with_db_initialization_query(DB_INITIALIZE_QUERY)
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
@ -135,17 +174,27 @@ pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M>
/// Implements a basic DB wrapper for a given domain /// Implements a basic DB wrapper for a given domain
#[macro_export] #[macro_export]
macro_rules! connection { macro_rules! define_connection {
($id:ident: $t:ident<$d:ty>) => { (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$d>); pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>);
impl ::std::ops::Deref for $t { impl ::std::ops::Deref for $t {
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$d>; type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.0 &self.0
} }
} }
impl $crate::sqlez::domain::Domain for $t {
fn name() -> &'static str {
stringify!($t)
}
fn migrations() -> &'static [&'static str] {
$migrations
}
}
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
$crate::lazy_static::lazy_static! { $crate::lazy_static::lazy_static! {
@ -154,322 +203,124 @@ macro_rules! connection {
#[cfg(not(any(test, feature = "test-support")))] #[cfg(not(any(test, feature = "test-support")))]
$crate::lazy_static::lazy_static! { $crate::lazy_static::lazy_static! {
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db())); pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(std::env::var("WIPE_DB").is_ok(), &$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
}
};
(pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>);
impl ::std::ops::Deref for $t {
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl $crate::sqlez::domain::Domain for $t {
fn name() -> &'static str {
stringify!($t)
}
fn migrations() -> &'static [&'static str] {
$migrations
}
}
#[cfg(any(test, feature = "test-support"))]
$crate::lazy_static::lazy_static! {
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
}
#[cfg(not(any(test, feature = "test-support")))]
$crate::lazy_static::lazy_static! {
pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(std::env::var("WIPE_DB").is_ok(), &$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
} }
}; };
} }
#[macro_export] #[cfg(test)]
macro_rules! query { mod tests {
($vis:vis fn $id:ident() -> Result<()> { $($sql:tt)+ }) => { use std::thread;
$vis fn $id(&self) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); use sqlez::domain::Domain;
use sqlez_macros::sql;
use tempdir::TempDir;
use util::channel::ReleaseChannel;
self.exec(sql_stmt)?().context(::std::format!( use crate::open_db;
"Error in {}, exec failed to execute or parse for: {}",
::std::stringify!($id), enum TestDB {}
sql_stmt,
)) impl Domain for TestDB {
fn name() -> &'static str {
"db_tests"
} }
};
($vis:vis async fn $id:ident() -> Result<()> { $($sql:tt)+ }) => {
$vis async fn $id(&self) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
self.write(|connection| { fn migrations() -> &'static [&'static str] {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); &[sql!(
CREATE TABLE test(value);
connection.exec(sql_stmt)?().context(::std::format!( )]
"Error in {}, exec failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
} }
}; }
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> { // Test that wipe_db exists and works and gives a new db
use $crate::anyhow::Context; #[test]
fn test_wipe_db() {
env_logger::try_init().ok();
smol::block_on(async {
let tempdir = TempDir::new("DbTests").unwrap();
let test_db = open_db::<TestDB>(false, tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
test_db.write(|connection|
connection.exec(sql!(
INSERT INTO test(value) VALUES (10)
)).unwrap()().unwrap()
).await;
drop(test_db);
let mut guards = vec![];
for _ in 0..5 {
let path = tempdir.path().to_path_buf();
let guard = thread::spawn(move || smol::block_on(async {
let test_db = open_db::<TestDB>(true, &path, &ReleaseChannel::Dev).await;
assert!(test_db.select_row::<()>(sql!(SELECT value FROM test)).unwrap()().unwrap().is_none())
}));
guards.push(guard);
}
for guard in guards {
guard.join().unwrap();
}
})
}
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); // Test a file system failure (like in create_dir_all())
#[test]
self.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+)) fn test_file_system_failure() {
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}", }
::std::stringify!($id),
sql_stmt // Test happy path where everything exists and opens
)) #[test]
} fn test_open_db() {
};
($vis:vis async fn $id:ident($arg:ident: $arg_type:ty) -> Result<()> { $($sql:tt)+ }) => { }
$vis async fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context; // Test bad migration panics
#[test]
self.write(move |connection| { fn test_bad_migration_panics() {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
}
connection.exec_bound::<$arg_type>(sql_stmt)?($arg)
.context(::std::format!( /// Test that DB exists but corrupted (causing recreate)
"Error in {}, exec_bound failed to execute or parse for: {}", #[test]
::std::stringify!($id), fn test_db_corruption() {
sql_stmt
))
}).await // open_db(db_dir, release_channel)
} }
};
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
self.write(move |connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident() -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select::<$return_type>(sql_stmt)?(())
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident() -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
pub async fn $id(&self) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select::<$return_type>(sql_stmt)?(())
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident() -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident() -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis async fn $id(&self) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg)
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row_bound::<($($arg_type),+), $return_type>(indoc! { $sql })?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row::<$return_type>(indoc! { $sql })?()
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis async fn $id(&self) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
pub fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg)
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis fn async $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
} }

View file

@ -1,26 +1,15 @@
use sqlez::domain::Domain;
use sqlez_macros::sql; use sqlez_macros::sql;
use crate::{connection, query}; use crate::{define_connection, query};
connection!(KEY_VALUE_STORE: KeyValueStore<KeyValueStore>); define_connection!(pub static ref KEY_VALUE_STORE: KeyValueStore<()> =
&[sql!(
impl Domain for KeyValueStore { CREATE TABLE IF NOT EXISTS kv_store(
fn name() -> &'static str { key TEXT PRIMARY KEY,
"kvp" value TEXT NOT NULL
} ) STRICT;
)];
fn migrations() -> &'static [&'static str] { );
// Legacy migrations using rusqlite may have already created kv_store during alpha,
// migrations must be infallible so this must have 'IF NOT EXISTS'
&[sql!(
CREATE TABLE IF NOT EXISTS kv_store(
key TEXT PRIMARY KEY,
value TEXT NOT NULL
) STRICT;
)]
}
}
impl KeyValueStore { impl KeyValueStore {
query! { query! {

314
crates/db/src/query.rs Normal file
View file

@ -0,0 +1,314 @@
#[macro_export]
macro_rules! query {
($vis:vis fn $id:ident() -> Result<()> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.exec(sql_stmt)?().context(::std::format!(
"Error in {}, exec failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt,
))
}
};
($vis:vis async fn $id:ident() -> Result<()> { $($sql:tt)+ }) => {
$vis async fn $id(&self) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.exec(sql_stmt)?().context(::std::format!(
"Error in {}, exec failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident($arg:ident: $arg_type:ty) -> Result<()> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
self.write(move |connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.exec_bound::<$arg_type>(sql_stmt)?($arg)
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> {
use $crate::anyhow::Context;
self.write(move |connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident() -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select::<$return_type>(sql_stmt)?(())
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident() -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
pub async fn $id(&self) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select::<$return_type>(sql_stmt)?(())
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Vec<$return_type:ty>> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Vec<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, exec_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident() -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident() -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis async fn $id(&self) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg)
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<Option<$return_type:ty>> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<Option<$return_type>> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row_bound::<($($arg_type),+), $return_type>(indoc! { $sql })?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis fn $id(&self) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row::<$return_type>(indoc! { $sql })?()
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis async fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis async fn $id(&self) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row::<$return_type>(sql_stmt)?()
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
pub fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg)
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}
};
($vis:vis fn async $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => {
$vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> {
use $crate::anyhow::Context;
self.write(|connection| {
let sql_stmt = $crate::sqlez_macros::sql!($($sql)+);
connection.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+))
.context(::std::format!(
"Error in {}, select_row_bound failed to execute or parse for: {}",
::std::stringify!($id),
sql_stmt
))?
.context(::std::format!(
"Error in {}, select_row_bound expected single row result but found none for: {}",
::std::stringify!($id),
sql_stmt
))
}).await
}
};
}

View file

@ -1,19 +1,11 @@
use std::path::PathBuf; use std::path::PathBuf;
use crate::Editor;
use db::sqlez_macros::sql; use db::sqlez_macros::sql;
use db::{connection, query}; use db::{define_connection, query};
use sqlez::domain::Domain; use workspace::{ItemId, WorkspaceDb, WorkspaceId};
use workspace::{ItemId, Workspace, WorkspaceId};
connection!(DB: EditorDb<(Workspace, Editor)>); define_connection!(
pub static ref DB: EditorDb<WorkspaceDb> =
impl Domain for Editor {
fn name() -> &'static str {
"editor"
}
fn migrations() -> &'static [&'static str] {
&[sql! ( &[sql! (
CREATE TABLE editors( CREATE TABLE editors(
item_id INTEGER NOT NULL, item_id INTEGER NOT NULL,
@ -21,12 +13,11 @@ impl Domain for Editor {
path BLOB NOT NULL, path BLOB NOT NULL,
PRIMARY KEY(item_id, workspace_id), PRIMARY KEY(item_id, workspace_id),
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE ON DELETE CASCADE
ON UPDATE CASCADE ON UPDATE CASCADE
) STRICT; ) STRICT;
)] )];
} );
}
impl EditorDb { impl EditorDb {
query! { query! {

View file

@ -137,13 +137,6 @@ impl Column for usize {
} }
} }
impl Bind for () {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
statement.bind_null(start_index)?;
Ok(start_index + 1)
}
}
impl Bind for &str { impl Bind for &str {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> { fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
statement.bind_text(start_index, self)?; statement.bind_text(start_index, self)?;
@ -179,78 +172,6 @@ impl Column for String {
} }
} }
impl<T1: Bind, T2: Bind> Bind for (T1, T2) {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
let next_index = self.0.bind(statement, start_index)?;
self.1.bind(statement, next_index)
}
}
impl<T1: Column, T2: Column> Column for (T1, T2) {
fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
let (first, next_index) = T1::column(statement, start_index)?;
let (second, next_index) = T2::column(statement, next_index)?;
Ok(((first, second), next_index))
}
}
impl<T1: Bind, T2: Bind, T3: Bind> Bind for (T1, T2, T3) {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
let next_index = self.0.bind(statement, start_index)?;
let next_index = self.1.bind(statement, next_index)?;
self.2.bind(statement, next_index)
}
}
impl<T1: Column, T2: Column, T3: Column> Column for (T1, T2, T3) {
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
let (first, next_index) = T1::column(statement, start_index)?;
let (second, next_index) = T2::column(statement, next_index)?;
let (third, next_index) = T3::column(statement, next_index)?;
Ok(((first, second, third), next_index))
}
}
impl<T1: Bind, T2: Bind, T3: Bind, T4: Bind> Bind for (T1, T2, T3, T4) {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
let next_index = self.0.bind(statement, start_index)?;
let next_index = self.1.bind(statement, next_index)?;
let next_index = self.2.bind(statement, next_index)?;
self.3.bind(statement, next_index)
}
}
impl<T1: Column, T2: Column, T3: Column, T4: Column> Column for (T1, T2, T3, T4) {
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
let (first, next_index) = T1::column(statement, start_index)?;
let (second, next_index) = T2::column(statement, next_index)?;
let (third, next_index) = T3::column(statement, next_index)?;
let (fourth, next_index) = T4::column(statement, next_index)?;
Ok(((first, second, third, fourth), next_index))
}
}
impl<T1: Bind, T2: Bind, T3: Bind, T4: Bind, T5: Bind> Bind for (T1, T2, T3, T4, T5) {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
let next_index = self.0.bind(statement, start_index)?;
let next_index = self.1.bind(statement, next_index)?;
let next_index = self.2.bind(statement, next_index)?;
let next_index = self.3.bind(statement, next_index)?;
self.4.bind(statement, next_index)
}
}
impl<T1: Column, T2: Column, T3: Column, T4: Column, T5: Column> Column for (T1, T2, T3, T4, T5) {
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
let (first, next_index) = T1::column(statement, start_index)?;
let (second, next_index) = T2::column(statement, next_index)?;
let (third, next_index) = T3::column(statement, next_index)?;
let (fourth, next_index) = T4::column(statement, next_index)?;
let (fifth, next_index) = T5::column(statement, next_index)?;
Ok(((first, second, third, fourth, fifth), next_index))
}
}
impl<T: Bind> Bind for Option<T> { impl<T: Bind> Bind for Option<T> {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> { fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
if let Some(this) = self { if let Some(this) = self {
@ -344,3 +265,88 @@ impl Column for PathBuf {
)) ))
} }
} }
/// Unit impls do nothing. This simplifies query macros
impl Bind for () {
fn bind(&self, _statement: &Statement, start_index: i32) -> Result<i32> {
Ok(start_index)
}
}
impl Column for () {
fn column(_statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
Ok(((), start_index))
}
}
impl<T1: Bind, T2: Bind> Bind for (T1, T2) {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
let next_index = self.0.bind(statement, start_index)?;
self.1.bind(statement, next_index)
}
}
impl<T1: Column, T2: Column> Column for (T1, T2) {
fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
let (first, next_index) = T1::column(statement, start_index)?;
let (second, next_index) = T2::column(statement, next_index)?;
Ok(((first, second), next_index))
}
}
impl<T1: Bind, T2: Bind, T3: Bind> Bind for (T1, T2, T3) {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
let next_index = self.0.bind(statement, start_index)?;
let next_index = self.1.bind(statement, next_index)?;
self.2.bind(statement, next_index)
}
}
impl<T1: Column, T2: Column, T3: Column> Column for (T1, T2, T3) {
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
let (first, next_index) = T1::column(statement, start_index)?;
let (second, next_index) = T2::column(statement, next_index)?;
let (third, next_index) = T3::column(statement, next_index)?;
Ok(((first, second, third), next_index))
}
}
impl<T1: Bind, T2: Bind, T3: Bind, T4: Bind> Bind for (T1, T2, T3, T4) {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
let next_index = self.0.bind(statement, start_index)?;
let next_index = self.1.bind(statement, next_index)?;
let next_index = self.2.bind(statement, next_index)?;
self.3.bind(statement, next_index)
}
}
impl<T1: Column, T2: Column, T3: Column, T4: Column> Column for (T1, T2, T3, T4) {
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
let (first, next_index) = T1::column(statement, start_index)?;
let (second, next_index) = T2::column(statement, next_index)?;
let (third, next_index) = T3::column(statement, next_index)?;
let (fourth, next_index) = T4::column(statement, next_index)?;
Ok(((first, second, third, fourth), next_index))
}
}
impl<T1: Bind, T2: Bind, T3: Bind, T4: Bind, T5: Bind> Bind for (T1, T2, T3, T4, T5) {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
let next_index = self.0.bind(statement, start_index)?;
let next_index = self.1.bind(statement, next_index)?;
let next_index = self.2.bind(statement, next_index)?;
let next_index = self.3.bind(statement, next_index)?;
self.4.bind(statement, next_index)
}
}
impl<T1: Column, T2: Column, T3: Column, T4: Column, T5: Column> Column for (T1, T2, T3, T4, T5) {
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
let (first, next_index) = T1::column(statement, start_index)?;
let (second, next_index) = T2::column(statement, next_index)?;
let (third, next_index) = T3::column(statement, next_index)?;
let (fourth, next_index) = T4::column(statement, next_index)?;
let (fifth, next_index) = T5::column(statement, next_index)?;
Ok(((first, second, third, fourth, fifth), next_index))
}
}

View file

@ -1,4 +1,5 @@
use std::{ use std::{
cell::RefCell,
ffi::{CStr, CString}, ffi::{CStr, CString},
marker::PhantomData, marker::PhantomData,
path::Path, path::Path,
@ -11,7 +12,7 @@ use libsqlite3_sys::*;
pub struct Connection { pub struct Connection {
pub(crate) sqlite3: *mut sqlite3, pub(crate) sqlite3: *mut sqlite3,
persistent: bool, persistent: bool,
pub(crate) write: bool, pub(crate) write: RefCell<bool>,
_sqlite: PhantomData<sqlite3>, _sqlite: PhantomData<sqlite3>,
} }
unsafe impl Send for Connection {} unsafe impl Send for Connection {}
@ -21,7 +22,7 @@ impl Connection {
let mut connection = Self { let mut connection = Self {
sqlite3: 0 as *mut _, sqlite3: 0 as *mut _,
persistent, persistent,
write: true, write: RefCell::new(true),
_sqlite: PhantomData, _sqlite: PhantomData,
}; };
@ -64,7 +65,7 @@ impl Connection {
} }
pub fn can_write(&self) -> bool { pub fn can_write(&self) -> bool {
self.write *self.write.borrow()
} }
pub fn backup_main(&self, destination: &Connection) -> Result<()> { pub fn backup_main(&self, destination: &Connection) -> Result<()> {
@ -152,6 +153,13 @@ impl Connection {
)) ))
} }
} }
pub(crate) fn with_write<T>(&self, callback: impl FnOnce(&Connection) -> T) -> T {
*self.write.borrow_mut() = true;
let result = callback(self);
*self.write.borrow_mut() = false;
result
}
} }
impl Drop for Connection { impl Drop for Connection {

View file

@ -1,11 +1,11 @@
use crate::connection::Connection; use crate::connection::Connection;
pub trait Domain { pub trait Domain: 'static {
fn name() -> &'static str; fn name() -> &'static str;
fn migrations() -> &'static [&'static str]; fn migrations() -> &'static [&'static str];
} }
pub trait Migrator { pub trait Migrator: 'static {
fn migrate(connection: &Connection) -> anyhow::Result<()>; fn migrate(connection: &Connection) -> anyhow::Result<()>;
} }

View file

@ -12,6 +12,7 @@ use crate::connection::Connection;
impl Connection { impl Connection {
pub fn migrate(&self, domain: &'static str, migrations: &[&'static str]) -> Result<()> { pub fn migrate(&self, domain: &'static str, migrations: &[&'static str]) -> Result<()> {
self.with_savepoint("migrating", || { self.with_savepoint("migrating", || {
println!("Processing domain");
// Setup the migrations table unconditionally // Setup the migrations table unconditionally
self.exec(indoc! {" self.exec(indoc! {"
CREATE TABLE IF NOT EXISTS migrations ( CREATE TABLE IF NOT EXISTS migrations (
@ -43,11 +44,13 @@ impl Connection {
{}", domain, index, completed_migration, migration})); {}", domain, index, completed_migration, migration}));
} else { } else {
// Migration already run. Continue // Migration already run. Continue
println!("Migration already run");
continue; continue;
} }
} }
self.exec(migration)?()?; self.exec(migration)?()?;
println!("Ran migration");
store_completed_migration((domain, index, *migration))?; store_completed_migration((domain, index, *migration))?;
} }

View file

@ -5,17 +5,13 @@ use parking_lot::{Mutex, RwLock};
use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread}; use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
use thread_local::ThreadLocal; use thread_local::ThreadLocal;
use crate::{ use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
connection::Connection,
domain::{Domain, Migrator},
util::UnboundedSyncSender,
};
const MIGRATION_RETRIES: usize = 10; const MIGRATION_RETRIES: usize = 10;
type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>; type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
type WriteQueueConstructor = type WriteQueueConstructor =
Box<dyn 'static + Send + FnMut(Connection) -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>; Box<dyn 'static + Send + FnMut() -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
lazy_static! { lazy_static! {
/// List of queues of tasks by database uri. This lets us serialize writes to the database /// List of queues of tasks by database uri. This lets us serialize writes to the database
/// and have a single worker thread per db file. This means many thread safe connections /// and have a single worker thread per db file. This means many thread safe connections
@ -28,18 +24,18 @@ lazy_static! {
/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static, /// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection /// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
/// may be accessed by passing a callback to the `write` function which will queue the callback /// may be accessed by passing a callback to the `write` function which will queue the callback
pub struct ThreadSafeConnection<M: Migrator = ()> { pub struct ThreadSafeConnection<M: Migrator + 'static = ()> {
uri: Arc<str>, uri: Arc<str>,
persistent: bool, persistent: bool,
connection_initialize_query: Option<&'static str>, connection_initialize_query: Option<&'static str>,
connections: Arc<ThreadLocal<Connection>>, connections: Arc<ThreadLocal<Connection>>,
_migrator: PhantomData<M>, _migrator: PhantomData<*mut M>,
} }
unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {} unsafe impl<M: Migrator> Send for ThreadSafeConnection<M> {}
unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {} unsafe impl<M: Migrator> Sync for ThreadSafeConnection<M> {}
pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> { pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
db_initialize_query: Option<&'static str>, db_initialize_query: Option<&'static str>,
write_queue_constructor: Option<WriteQueueConstructor>, write_queue_constructor: Option<WriteQueueConstructor>,
connection: ThreadSafeConnection<M>, connection: ThreadSafeConnection<M>,
@ -54,6 +50,13 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
self self
} }
/// Queues an initialization query for the database file. This must be infallible
/// but may cause changes to the database file such as with `PRAGMA journal_mode`
pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
self.db_initialize_query = Some(initialize_query);
self
}
/// Specifies how the thread safe connection should serialize writes. If provided /// Specifies how the thread safe connection should serialize writes. If provided
/// the connection will call the write_queue_constructor for each database file in /// the connection will call the write_queue_constructor for each database file in
/// this process. The constructor is responsible for setting up a background thread or /// this process. The constructor is responsible for setting up a background thread or
@ -66,13 +69,6 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
self self
} }
/// Queues an initialization query for the database file. This must be infallible
/// but may cause changes to the database file such as with `PRAGMA journal_mode`
pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
self.db_initialize_query = Some(initialize_query);
self
}
pub async fn build(self) -> anyhow::Result<ThreadSafeConnection<M>> { pub async fn build(self) -> anyhow::Result<ThreadSafeConnection<M>> {
self.connection self.connection
.initialize_queues(self.write_queue_constructor); .initialize_queues(self.write_queue_constructor);
@ -100,6 +96,7 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
.with_savepoint("thread_safe_multi_migration", || M::migrate(connection)); .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
if migration_result.is_ok() { if migration_result.is_ok() {
println!("Migration succeded");
break; break;
} }
} }
@ -113,38 +110,17 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
} }
impl<M: Migrator> ThreadSafeConnection<M> { impl<M: Migrator> ThreadSafeConnection<M> {
fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) { fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
if !QUEUES.read().contains_key(&self.uri) { if !QUEUES.read().contains_key(&self.uri) {
let mut queues = QUEUES.write(); let mut queues = QUEUES.write();
if !queues.contains_key(&self.uri) { if !queues.contains_key(&self.uri) {
let mut write_connection = self.create_connection(); let mut write_queue_constructor =
// Enable writes for this connection write_queue_constructor.unwrap_or(background_thread_queue());
write_connection.write = true; queues.insert(self.uri.clone(), write_queue_constructor());
if let Some(mut write_queue_constructor) = write_queue_constructor { return true;
let write_channel = write_queue_constructor(write_connection);
queues.insert(self.uri.clone(), write_channel);
} else {
use std::sync::mpsc::channel;
let (sender, reciever) = channel::<QueuedWrite>();
thread::spawn(move || {
while let Ok(write) = reciever.recv() {
write(&write_connection)
}
});
let sender = UnboundedSyncSender::new(sender);
queues.insert(
self.uri.clone(),
Box::new(move |queued_write| {
sender
.send(queued_write)
.expect("Could not send write action to backgorund thread");
}),
);
}
} }
} }
return false;
} }
pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> { pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
@ -163,20 +139,21 @@ impl<M: Migrator> ThreadSafeConnection<M> {
/// Opens a new db connection with the initialized file path. This is internal and only /// Opens a new db connection with the initialized file path. This is internal and only
/// called from the deref function. /// called from the deref function.
fn open_file(&self) -> Connection { fn open_file(uri: &str) -> Connection {
Connection::open_file(self.uri.as_ref()) Connection::open_file(uri)
} }
/// Opens a shared memory connection using the file path as the identifier. This is internal /// Opens a shared memory connection using the file path as the identifier. This is internal
/// and only called from the deref function. /// and only called from the deref function.
fn open_shared_memory(&self) -> Connection { fn open_shared_memory(uri: &str) -> Connection {
Connection::open_memory(Some(self.uri.as_ref())) Connection::open_memory(Some(uri))
} }
pub fn write<T: 'static + Send + Sync>( pub fn write<T: 'static + Send + Sync>(
&self, &self,
callback: impl 'static + Send + FnOnce(&Connection) -> T, callback: impl 'static + Send + FnOnce(&Connection) -> T,
) -> impl Future<Output = T> { ) -> impl Future<Output = T> {
// Check and invalidate queue and maybe recreate queue
let queues = QUEUES.read(); let queues = QUEUES.read();
let write_channel = queues let write_channel = queues
.get(&self.uri) .get(&self.uri)
@ -185,24 +162,32 @@ impl<M: Migrator> ThreadSafeConnection<M> {
// Create a one shot channel for the result of the queued write // Create a one shot channel for the result of the queued write
// so we can await on the result // so we can await on the result
let (sender, reciever) = oneshot::channel(); let (sender, reciever) = oneshot::channel();
write_channel(Box::new(move |connection| {
sender.send(callback(connection)).ok(); let thread_safe_connection = (*self).clone();
write_channel(Box::new(move || {
let connection = thread_safe_connection.deref();
let result = connection.with_write(|connection| callback(connection));
sender.send(result).ok();
})); }));
reciever.map(|response| response.expect("Background writer thread unexpectedly closed")) reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
} }
pub(crate) fn create_connection(&self) -> Connection { pub(crate) fn create_connection(
let mut connection = if self.persistent { persistent: bool,
self.open_file() uri: &str,
connection_initialize_query: Option<&'static str>,
) -> Connection {
let mut connection = if persistent {
Self::open_file(uri)
} else { } else {
self.open_shared_memory() Self::open_shared_memory(uri)
}; };
// Disallow writes on the connection. The only writes allowed for thread safe connections // Disallow writes on the connection. The only writes allowed for thread safe connections
// are from the background thread that can serialize them. // are from the background thread that can serialize them.
connection.write = false; *connection.write.get_mut() = false;
if let Some(initialize_query) = self.connection_initialize_query { if let Some(initialize_query) = connection_initialize_query {
connection.exec(initialize_query).expect(&format!( connection.exec(initialize_query).expect(&format!(
"Initialize query failed to execute: {}", "Initialize query failed to execute: {}",
initialize_query initialize_query
@ -236,7 +221,7 @@ impl ThreadSafeConnection<()> {
} }
} }
impl<D: Domain> Clone for ThreadSafeConnection<D> { impl<M: Migrator> Clone for ThreadSafeConnection<M> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
uri: self.uri.clone(), uri: self.uri.clone(),
@ -252,16 +237,41 @@ impl<M: Migrator> Deref for ThreadSafeConnection<M> {
type Target = Connection; type Target = Connection;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
self.connections.get_or(|| self.create_connection()) self.connections.get_or(|| {
Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
})
} }
} }
pub fn locking_queue() -> WriteQueueConstructor { pub fn background_thread_queue() -> WriteQueueConstructor {
Box::new(|connection| { use std::sync::mpsc::channel;
let connection = Mutex::new(connection);
Box::new(|| {
let (sender, reciever) = channel::<QueuedWrite>();
thread::spawn(move || {
while let Ok(write) = reciever.recv() {
write()
}
});
let sender = UnboundedSyncSender::new(sender);
Box::new(move |queued_write| { Box::new(move |queued_write| {
let connection = connection.lock(); sender
queued_write(&connection) .send(queued_write)
.expect("Could not send write action to background thread");
})
})
}
pub fn locking_queue() -> WriteQueueConstructor {
Box::new(|| {
let mutex = Mutex::new(());
Box::new(move |queued_write| {
eprintln!("Write started");
let _ = mutex.lock();
queued_write();
eprintln!("Write finished");
}) })
}) })
} }
@ -269,7 +279,8 @@ pub fn locking_queue() -> WriteQueueConstructor {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use indoc::indoc; use indoc::indoc;
use std::ops::Deref; use lazy_static::__Deref;
use std::thread; use std::thread;
use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection}; use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

View file

@ -1,19 +1,11 @@
use std::path::PathBuf; use std::path::PathBuf;
use db::{connection, query, sqlez::domain::Domain, sqlez_macros::sql}; use db::{define_connection, query, sqlez_macros::sql};
use workspace::{ItemId, Workspace, WorkspaceId}; use workspace::{ItemId, WorkspaceDb, WorkspaceId};
use crate::Terminal; define_connection! {
pub static ref TERMINAL_CONNECTION: TerminalDb<WorkspaceDb> =
connection!(TERMINAL_CONNECTION: TerminalDb<(Workspace, Terminal)>);
impl Domain for Terminal {
fn name() -> &'static str {
"terminal"
}
fn migrations() -> &'static [&'static str] {
&[sql!( &[sql!(
CREATE TABLE terminals ( CREATE TABLE terminals (
workspace_id INTEGER, workspace_id INTEGER,
@ -23,8 +15,7 @@ impl Domain for Terminal {
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE ON DELETE CASCADE
) STRICT; ) STRICT;
)] )];
}
} }
impl TerminalDb { impl TerminalDb {

View file

@ -5,30 +5,21 @@ pub mod model;
use std::path::Path; use std::path::Path;
use anyhow::{anyhow, bail, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use db::{connection, query, sqlez::connection::Connection, sqlez_macros::sql}; use db::{define_connection, query, sqlez::connection::Connection, sqlez_macros::sql};
use gpui::Axis; use gpui::Axis;
use db::sqlez::domain::Domain;
use util::{iife, unzip_option, ResultExt}; use util::{iife, unzip_option, ResultExt};
use crate::dock::DockPosition; use crate::dock::DockPosition;
use crate::WorkspaceId; use crate::WorkspaceId;
use super::Workspace;
use model::{ use model::{
GroupId, PaneId, SerializedItem, SerializedPane, SerializedPaneGroup, SerializedWorkspace, GroupId, PaneId, SerializedItem, SerializedPane, SerializedPaneGroup, SerializedWorkspace,
WorkspaceLocation, WorkspaceLocation,
}; };
connection!(DB: WorkspaceDb<Workspace>); define_connection! {
pub static ref DB: WorkspaceDb<()> =
impl Domain for Workspace {
fn name() -> &'static str {
"workspace"
}
fn migrations() -> &'static [&'static str] {
&[sql!( &[sql!(
CREATE TABLE workspaces( CREATE TABLE workspaces(
workspace_id INTEGER PRIMARY KEY, workspace_id INTEGER PRIMARY KEY,
@ -40,7 +31,7 @@ impl Domain for Workspace {
timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL, timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
FOREIGN KEY(dock_pane) REFERENCES panes(pane_id) FOREIGN KEY(dock_pane) REFERENCES panes(pane_id)
) STRICT; ) STRICT;
CREATE TABLE pane_groups( CREATE TABLE pane_groups(
group_id INTEGER PRIMARY KEY, group_id INTEGER PRIMARY KEY,
workspace_id INTEGER NOT NULL, workspace_id INTEGER NOT NULL,
@ -48,29 +39,29 @@ impl Domain for Workspace {
position INTEGER, // NULL indicates that this is a root node position INTEGER, // NULL indicates that this is a root node
axis TEXT NOT NULL, // Enum: 'Vertical' / 'Horizontal' axis TEXT NOT NULL, // Enum: 'Vertical' / 'Horizontal'
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE ON DELETE CASCADE
ON UPDATE CASCADE, ON UPDATE CASCADE,
FOREIGN KEY(parent_group_id) REFERENCES pane_groups(group_id) ON DELETE CASCADE FOREIGN KEY(parent_group_id) REFERENCES pane_groups(group_id) ON DELETE CASCADE
) STRICT; ) STRICT;
CREATE TABLE panes( CREATE TABLE panes(
pane_id INTEGER PRIMARY KEY, pane_id INTEGER PRIMARY KEY,
workspace_id INTEGER NOT NULL, workspace_id INTEGER NOT NULL,
active INTEGER NOT NULL, // Boolean active INTEGER NOT NULL, // Boolean
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE ON DELETE CASCADE
ON UPDATE CASCADE ON UPDATE CASCADE
) STRICT; ) STRICT;
CREATE TABLE center_panes( CREATE TABLE center_panes(
pane_id INTEGER PRIMARY KEY, pane_id INTEGER PRIMARY KEY,
parent_group_id INTEGER, // NULL means that this is a root pane parent_group_id INTEGER, // NULL means that this is a root pane
position INTEGER, // NULL means that this is a root pane position INTEGER, // NULL means that this is a root pane
FOREIGN KEY(pane_id) REFERENCES panes(pane_id) FOREIGN KEY(pane_id) REFERENCES panes(pane_id)
ON DELETE CASCADE, ON DELETE CASCADE,
FOREIGN KEY(parent_group_id) REFERENCES pane_groups(group_id) ON DELETE CASCADE FOREIGN KEY(parent_group_id) REFERENCES pane_groups(group_id) ON DELETE CASCADE
) STRICT; ) STRICT;
CREATE TABLE items( CREATE TABLE items(
item_id INTEGER NOT NULL, // This is the item's view id, so this is not unique item_id INTEGER NOT NULL, // This is the item's view id, so this is not unique
workspace_id INTEGER NOT NULL, workspace_id INTEGER NOT NULL,
@ -79,14 +70,13 @@ impl Domain for Workspace {
position INTEGER NOT NULL, position INTEGER NOT NULL,
active INTEGER NOT NULL, active INTEGER NOT NULL,
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE ON DELETE CASCADE
ON UPDATE CASCADE, ON UPDATE CASCADE,
FOREIGN KEY(pane_id) REFERENCES panes(pane_id) FOREIGN KEY(pane_id) REFERENCES panes(pane_id)
ON DELETE CASCADE, ON DELETE CASCADE,
PRIMARY KEY(item_id, workspace_id) PRIMARY KEY(item_id, workspace_id)
) STRICT; ) STRICT;
)] )];
}
} }
impl WorkspaceDb { impl WorkspaceDb {
@ -149,7 +139,7 @@ impl WorkspaceDb {
UPDATE workspaces SET dock_pane = NULL WHERE workspace_id = ?1; UPDATE workspaces SET dock_pane = NULL WHERE workspace_id = ?1;
DELETE FROM pane_groups WHERE workspace_id = ?1; DELETE FROM pane_groups WHERE workspace_id = ?1;
DELETE FROM panes WHERE workspace_id = ?1;))?(workspace.id) DELETE FROM panes WHERE workspace_id = ?1;))?(workspace.id)
.context("Clearing old panes")?; .expect("Clearing old panes");
conn.exec_bound(sql!( conn.exec_bound(sql!(
DELETE FROM workspaces WHERE workspace_location = ? AND workspace_id != ? DELETE FROM workspaces WHERE workspace_location = ? AND workspace_id != ?

View file

@ -44,8 +44,11 @@ use language::LanguageRegistry;
use log::{error, warn}; use log::{error, warn};
pub use pane::*; pub use pane::*;
pub use pane_group::*; pub use pane_group::*;
pub use persistence::model::{ItemId, WorkspaceLocation};
use persistence::{model::SerializedItem, DB}; use persistence::{model::SerializedItem, DB};
pub use persistence::{
model::{ItemId, WorkspaceLocation},
WorkspaceDb,
};
use postage::prelude::Stream; use postage::prelude::Stream;
use project::{Project, ProjectEntryId, ProjectPath, ProjectStore, Worktree, WorktreeId}; use project::{Project, ProjectEntryId, ProjectPath, ProjectStore, Worktree, WorktreeId};
use serde::Deserialize; use serde::Deserialize;