working serialized writes with panics on failure. Everything seems to be working

This commit is contained in:
Kay Simmons 2022-11-23 01:53:58 -08:00 committed by Mikayla Maki
parent b01243109e
commit 1cc3e4820a
34 changed files with 669 additions and 312 deletions

View file

@ -9,4 +9,7 @@ edition = "2021"
anyhow = { version = "1.0.38", features = ["backtrace"] }
indoc = "1.0.7"
libsqlite3-sys = { version = "0.25.2", features = ["bundled"] }
thread_local = "1.1.4"
thread_local = "1.1.4"
lazy_static = "1.4"
parking_lot = "0.11.1"
futures = "0.3"

View file

@ -322,6 +322,18 @@ impl Bind for &Path {
}
}
impl Bind for Arc<Path> {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
self.as_ref().bind(statement, start_index)
}
}
impl Bind for PathBuf {
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
(self.as_ref() as &Path).bind(statement, start_index)
}
}
impl Column for PathBuf {
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
let blob = statement.column_blob(start_index)?;

View file

@ -10,16 +10,18 @@ use libsqlite3_sys::*;
pub struct Connection {
pub(crate) sqlite3: *mut sqlite3,
persistent: bool,
phantom: PhantomData<sqlite3>,
pub(crate) write: bool,
_sqlite: PhantomData<sqlite3>,
}
unsafe impl Send for Connection {}
impl Connection {
fn open(uri: &str, persistent: bool) -> Result<Self> {
pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
let mut connection = Self {
sqlite3: 0 as *mut _,
persistent,
phantom: PhantomData,
write: true,
_sqlite: PhantomData,
};
let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
@ -60,6 +62,10 @@ impl Connection {
self.persistent
}
pub fn can_write(&self) -> bool {
self.write
}
pub fn backup_main(&self, destination: &Connection) -> Result<()> {
unsafe {
let backup = sqlite3_backup_init(

View file

@ -1,5 +1,3 @@
pub use anyhow;
pub mod bindable;
pub mod connection;
pub mod domain;
@ -8,3 +6,6 @@ pub mod savepoint;
pub mod statement;
pub mod thread_safe_connection;
pub mod typed_statements;
mod util;
pub use anyhow;

View file

@ -11,46 +11,48 @@ use crate::connection::Connection;
impl Connection {
pub fn migrate(&self, domain: &'static str, migrations: &[&'static str]) -> Result<()> {
// Setup the migrations table unconditionally
self.exec(indoc! {"
CREATE TABLE IF NOT EXISTS migrations (
self.with_savepoint("migrating", || {
// Setup the migrations table unconditionally
self.exec(indoc! {"
CREATE TABLE IF NOT EXISTS migrations (
domain TEXT,
step INTEGER,
migration TEXT
)"})?()?;
)"})?()?;
let completed_migrations =
self.select_bound::<&str, (String, usize, String)>(indoc! {"
let completed_migrations =
self.select_bound::<&str, (String, usize, String)>(indoc! {"
SELECT domain, step, migration FROM migrations
WHERE domain = ?
ORDER BY step
"})?(domain)?;
"})?(domain)?;
let mut store_completed_migration =
self.exec_bound("INSERT INTO migrations (domain, step, migration) VALUES (?, ?, ?)")?;
let mut store_completed_migration = self
.exec_bound("INSERT INTO migrations (domain, step, migration) VALUES (?, ?, ?)")?;
for (index, migration) in migrations.iter().enumerate() {
if let Some((_, _, completed_migration)) = completed_migrations.get(index) {
if completed_migration != migration {
return Err(anyhow!(formatdoc! {"
Migration changed for {} at step {}
Stored migration:
{}
Proposed migration:
{}", domain, index, completed_migration, migration}));
} else {
// Migration already run. Continue
continue;
for (index, migration) in migrations.iter().enumerate() {
if let Some((_, _, completed_migration)) = completed_migrations.get(index) {
if completed_migration != migration {
return Err(anyhow!(formatdoc! {"
Migration changed for {} at step {}
Stored migration:
{}
Proposed migration:
{}", domain, index, completed_migration, migration}));
} else {
// Migration already run. Continue
continue;
}
}
self.exec(migration)?()?;
store_completed_migration((domain, index, *migration))?;
}
self.exec(migration)?()?;
store_completed_migration((domain, index, *migration))?;
}
Ok(())
Ok(())
})
}
}

View file

@ -2,7 +2,7 @@ use std::ffi::{c_int, CStr, CString};
use std::marker::PhantomData;
use std::{ptr, slice, str};
use anyhow::{anyhow, Context, Result};
use anyhow::{anyhow, bail, Context, Result};
use libsqlite3_sys::*;
use crate::bindable::{Bind, Column};
@ -57,12 +57,21 @@ impl<'a> Statement<'a> {
&mut raw_statement,
&mut remaining_sql_ptr,
);
remaining_sql = CStr::from_ptr(remaining_sql_ptr);
statement.raw_statements.push(raw_statement);
connection.last_error().with_context(|| {
format!("Prepare call failed for query:\n{}", query.as_ref())
})?;
if !connection.can_write() && sqlite3_stmt_readonly(raw_statement) == 0 {
let sql = CStr::from_ptr(sqlite3_sql(raw_statement));
bail!(
"Write statement prepared with connection that is not write capable. SQL:\n{} ",
sql.to_str()?)
}
}
}

View file

@ -1,36 +1,41 @@
use std::{marker::PhantomData, ops::Deref, sync::Arc};
use connection::Connection;
use futures::{Future, FutureExt};
use lazy_static::lazy_static;
use parking_lot::RwLock;
use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
use thread_local::ThreadLocal;
use crate::{
connection,
connection::Connection,
domain::{Domain, Migrator},
util::UnboundedSyncSender,
};
type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
lazy_static! {
static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
Default::default();
}
pub struct ThreadSafeConnection<M: Migrator> {
uri: Option<Arc<str>>,
uri: Arc<str>,
persistent: bool,
initialize_query: Option<&'static str>,
connection: Arc<ThreadLocal<Connection>>,
_pd: PhantomData<M>,
connections: Arc<ThreadLocal<Connection>>,
_migrator: PhantomData<M>,
}
unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
impl<M: Migrator> ThreadSafeConnection<M> {
pub fn new(uri: Option<&str>, persistent: bool) -> Self {
if persistent == true && uri == None {
// This panic is securing the unwrap in open_file(), don't remove it!
panic!("Cannot create a persistent connection without a URI")
}
pub fn new(uri: &str, persistent: bool) -> Self {
Self {
uri: uri.map(|str| Arc::from(str)),
uri: Arc::from(uri),
persistent,
initialize_query: None,
connection: Default::default(),
_pd: PhantomData,
connections: Default::default(),
_migrator: PhantomData,
}
}
@ -46,13 +51,13 @@ impl<M: Migrator> ThreadSafeConnection<M> {
/// If opening fails, the connection falls back to a shared memory connection
fn open_file(&self) -> Connection {
// This unwrap is secured by a panic in the constructor. Be careful if you remove it!
Connection::open_file(self.uri.as_ref().unwrap())
Connection::open_file(self.uri.as_ref())
}
/// Opens a shared memory connection using the file path as the identifier. This unwraps
/// as we expect it always to succeed
fn open_shared_memory(&self) -> Connection {
Connection::open_memory(self.uri.as_ref().map(|str| str.deref()))
Connection::open_memory(Some(self.uri.as_ref()))
}
// Open a new connection for the given domain, leaving this
@ -62,10 +67,74 @@ impl<M: Migrator> ThreadSafeConnection<M> {
uri: self.uri.clone(),
persistent: self.persistent,
initialize_query: self.initialize_query,
connection: Default::default(),
_pd: PhantomData,
connections: Default::default(),
_migrator: PhantomData,
}
}
pub fn write<T: 'static + Send + Sync>(
&self,
callback: impl 'static + Send + FnOnce(&Connection) -> T,
) -> impl Future<Output = T> {
// Startup write thread for this database if one hasn't already
// been started and insert a channel to queue work for it
if !QUEUES.read().contains_key(&self.uri) {
use std::sync::mpsc::channel;
let (sender, reciever) = channel::<QueuedWrite>();
let mut write_connection = self.create_connection();
// Enable writes for this connection
write_connection.write = true;
thread::spawn(move || {
while let Ok(write) = reciever.recv() {
write(&write_connection)
}
});
let mut queues = QUEUES.write();
queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
}
// Grab the queue for this database
let queues = QUEUES.read();
let write_channel = queues.get(&self.uri).unwrap();
// Create a one shot channel for the result of the queued write
// so we can await on the result
let (sender, reciever) = futures::channel::oneshot::channel();
write_channel
.send(Box::new(move |connection| {
sender.send(callback(connection)).ok();
}))
.expect("Could not send write action to background thread");
reciever.map(|response| response.expect("Background thread unexpectedly closed"))
}
pub(crate) fn create_connection(&self) -> Connection {
let mut connection = if self.persistent {
self.open_file()
} else {
self.open_shared_memory()
};
// Enable writes for the migrations and initialization queries
connection.write = true;
if let Some(initialize_query) = self.initialize_query {
connection.exec(initialize_query).expect(&format!(
"Initialize query failed to execute: {}",
initialize_query
))()
.unwrap();
}
M::migrate(&connection).expect("Migrations failed");
// Disable db writes for normal thread local connection
connection.write = false;
connection
}
}
impl<D: Domain> Clone for ThreadSafeConnection<D> {
@ -74,8 +143,8 @@ impl<D: Domain> Clone for ThreadSafeConnection<D> {
uri: self.uri.clone(),
persistent: self.persistent,
initialize_query: self.initialize_query.clone(),
connection: self.connection.clone(),
_pd: PhantomData,
connections: self.connections.clone(),
_migrator: PhantomData,
}
}
}
@ -88,25 +157,7 @@ impl<M: Migrator> Deref for ThreadSafeConnection<M> {
type Target = Connection;
fn deref(&self) -> &Self::Target {
self.connection.get_or(|| {
let connection = if self.persistent {
self.open_file()
} else {
self.open_shared_memory()
};
if let Some(initialize_query) = self.initialize_query {
connection.exec(initialize_query).expect(&format!(
"Initialize query failed to execute: {}",
initialize_query
))()
.unwrap();
}
M::migrate(&connection).expect("Migrations failed");
connection
})
self.connections.get_or(|| self.create_connection())
}
}
@ -151,7 +202,7 @@ mod test {
}
}
let _ = ThreadSafeConnection::<TestWorkspace>::new(None, false)
let _ = ThreadSafeConnection::<TestWorkspace>::new("wild_zed_lost_failure", false)
.with_initialize_query("PRAGMA FOREIGN_KEYS=true")
.deref();
}

28
crates/sqlez/src/util.rs Normal file
View file

@ -0,0 +1,28 @@
use std::ops::Deref;
use std::sync::mpsc::Sender;
use parking_lot::Mutex;
use thread_local::ThreadLocal;
pub struct UnboundedSyncSender<T: Send> {
clonable_sender: Mutex<Sender<T>>,
local_senders: ThreadLocal<Sender<T>>,
}
impl<T: Send> UnboundedSyncSender<T> {
pub fn new(sender: Sender<T>) -> Self {
Self {
clonable_sender: Mutex::new(sender),
local_senders: ThreadLocal::new(),
}
}
}
impl<T: Send> Deref for UnboundedSyncSender<T> {
type Target = Sender<T>;
fn deref(&self) -> &Self::Target {
self.local_senders
.get_or(|| self.clonable_sender.lock().clone())
}
}