working serialized writes with panics on failure. Everything seems to be working
This commit is contained in:
parent
b01243109e
commit
1cc3e4820a
34 changed files with 669 additions and 312 deletions
|
@ -9,4 +9,7 @@ edition = "2021"
|
|||
anyhow = { version = "1.0.38", features = ["backtrace"] }
|
||||
indoc = "1.0.7"
|
||||
libsqlite3-sys = { version = "0.25.2", features = ["bundled"] }
|
||||
thread_local = "1.1.4"
|
||||
thread_local = "1.1.4"
|
||||
lazy_static = "1.4"
|
||||
parking_lot = "0.11.1"
|
||||
futures = "0.3"
|
|
@ -322,6 +322,18 @@ impl Bind for &Path {
|
|||
}
|
||||
}
|
||||
|
||||
impl Bind for Arc<Path> {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
self.as_ref().bind(statement, start_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for PathBuf {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
(self.as_ref() as &Path).bind(statement, start_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for PathBuf {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let blob = statement.column_blob(start_index)?;
|
||||
|
|
|
@ -10,16 +10,18 @@ use libsqlite3_sys::*;
|
|||
pub struct Connection {
|
||||
pub(crate) sqlite3: *mut sqlite3,
|
||||
persistent: bool,
|
||||
phantom: PhantomData<sqlite3>,
|
||||
pub(crate) write: bool,
|
||||
_sqlite: PhantomData<sqlite3>,
|
||||
}
|
||||
unsafe impl Send for Connection {}
|
||||
|
||||
impl Connection {
|
||||
fn open(uri: &str, persistent: bool) -> Result<Self> {
|
||||
pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
|
||||
let mut connection = Self {
|
||||
sqlite3: 0 as *mut _,
|
||||
persistent,
|
||||
phantom: PhantomData,
|
||||
write: true,
|
||||
_sqlite: PhantomData,
|
||||
};
|
||||
|
||||
let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
|
||||
|
@ -60,6 +62,10 @@ impl Connection {
|
|||
self.persistent
|
||||
}
|
||||
|
||||
pub fn can_write(&self) -> bool {
|
||||
self.write
|
||||
}
|
||||
|
||||
pub fn backup_main(&self, destination: &Connection) -> Result<()> {
|
||||
unsafe {
|
||||
let backup = sqlite3_backup_init(
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
pub use anyhow;
|
||||
|
||||
pub mod bindable;
|
||||
pub mod connection;
|
||||
pub mod domain;
|
||||
|
@ -8,3 +6,6 @@ pub mod savepoint;
|
|||
pub mod statement;
|
||||
pub mod thread_safe_connection;
|
||||
pub mod typed_statements;
|
||||
mod util;
|
||||
|
||||
pub use anyhow;
|
||||
|
|
|
@ -11,46 +11,48 @@ use crate::connection::Connection;
|
|||
|
||||
impl Connection {
|
||||
pub fn migrate(&self, domain: &'static str, migrations: &[&'static str]) -> Result<()> {
|
||||
// Setup the migrations table unconditionally
|
||||
self.exec(indoc! {"
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
self.with_savepoint("migrating", || {
|
||||
// Setup the migrations table unconditionally
|
||||
self.exec(indoc! {"
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
domain TEXT,
|
||||
step INTEGER,
|
||||
migration TEXT
|
||||
)"})?()?;
|
||||
)"})?()?;
|
||||
|
||||
let completed_migrations =
|
||||
self.select_bound::<&str, (String, usize, String)>(indoc! {"
|
||||
let completed_migrations =
|
||||
self.select_bound::<&str, (String, usize, String)>(indoc! {"
|
||||
SELECT domain, step, migration FROM migrations
|
||||
WHERE domain = ?
|
||||
ORDER BY step
|
||||
"})?(domain)?;
|
||||
"})?(domain)?;
|
||||
|
||||
let mut store_completed_migration =
|
||||
self.exec_bound("INSERT INTO migrations (domain, step, migration) VALUES (?, ?, ?)")?;
|
||||
let mut store_completed_migration = self
|
||||
.exec_bound("INSERT INTO migrations (domain, step, migration) VALUES (?, ?, ?)")?;
|
||||
|
||||
for (index, migration) in migrations.iter().enumerate() {
|
||||
if let Some((_, _, completed_migration)) = completed_migrations.get(index) {
|
||||
if completed_migration != migration {
|
||||
return Err(anyhow!(formatdoc! {"
|
||||
Migration changed for {} at step {}
|
||||
|
||||
Stored migration:
|
||||
{}
|
||||
|
||||
Proposed migration:
|
||||
{}", domain, index, completed_migration, migration}));
|
||||
} else {
|
||||
// Migration already run. Continue
|
||||
continue;
|
||||
for (index, migration) in migrations.iter().enumerate() {
|
||||
if let Some((_, _, completed_migration)) = completed_migrations.get(index) {
|
||||
if completed_migration != migration {
|
||||
return Err(anyhow!(formatdoc! {"
|
||||
Migration changed for {} at step {}
|
||||
|
||||
Stored migration:
|
||||
{}
|
||||
|
||||
Proposed migration:
|
||||
{}", domain, index, completed_migration, migration}));
|
||||
} else {
|
||||
// Migration already run. Continue
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
self.exec(migration)?()?;
|
||||
store_completed_migration((domain, index, *migration))?;
|
||||
}
|
||||
|
||||
self.exec(migration)?()?;
|
||||
store_completed_migration((domain, index, *migration))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::ffi::{c_int, CStr, CString};
|
|||
use std::marker::PhantomData;
|
||||
use std::{ptr, slice, str};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use libsqlite3_sys::*;
|
||||
|
||||
use crate::bindable::{Bind, Column};
|
||||
|
@ -57,12 +57,21 @@ impl<'a> Statement<'a> {
|
|||
&mut raw_statement,
|
||||
&mut remaining_sql_ptr,
|
||||
);
|
||||
|
||||
remaining_sql = CStr::from_ptr(remaining_sql_ptr);
|
||||
statement.raw_statements.push(raw_statement);
|
||||
|
||||
connection.last_error().with_context(|| {
|
||||
format!("Prepare call failed for query:\n{}", query.as_ref())
|
||||
})?;
|
||||
|
||||
if !connection.can_write() && sqlite3_stmt_readonly(raw_statement) == 0 {
|
||||
let sql = CStr::from_ptr(sqlite3_sql(raw_statement));
|
||||
|
||||
bail!(
|
||||
"Write statement prepared with connection that is not write capable. SQL:\n{} ",
|
||||
sql.to_str()?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,36 +1,41 @@
|
|||
use std::{marker::PhantomData, ops::Deref, sync::Arc};
|
||||
|
||||
use connection::Connection;
|
||||
use futures::{Future, FutureExt};
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::RwLock;
|
||||
use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
|
||||
use thread_local::ThreadLocal;
|
||||
|
||||
use crate::{
|
||||
connection,
|
||||
connection::Connection,
|
||||
domain::{Domain, Migrator},
|
||||
util::UnboundedSyncSender,
|
||||
};
|
||||
|
||||
type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
|
||||
|
||||
lazy_static! {
|
||||
static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
|
||||
Default::default();
|
||||
}
|
||||
|
||||
pub struct ThreadSafeConnection<M: Migrator> {
|
||||
uri: Option<Arc<str>>,
|
||||
uri: Arc<str>,
|
||||
persistent: bool,
|
||||
initialize_query: Option<&'static str>,
|
||||
connection: Arc<ThreadLocal<Connection>>,
|
||||
_pd: PhantomData<M>,
|
||||
connections: Arc<ThreadLocal<Connection>>,
|
||||
_migrator: PhantomData<M>,
|
||||
}
|
||||
|
||||
unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
|
||||
unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
|
||||
|
||||
impl<M: Migrator> ThreadSafeConnection<M> {
|
||||
pub fn new(uri: Option<&str>, persistent: bool) -> Self {
|
||||
if persistent == true && uri == None {
|
||||
// This panic is securing the unwrap in open_file(), don't remove it!
|
||||
panic!("Cannot create a persistent connection without a URI")
|
||||
}
|
||||
pub fn new(uri: &str, persistent: bool) -> Self {
|
||||
Self {
|
||||
uri: uri.map(|str| Arc::from(str)),
|
||||
uri: Arc::from(uri),
|
||||
persistent,
|
||||
initialize_query: None,
|
||||
connection: Default::default(),
|
||||
_pd: PhantomData,
|
||||
connections: Default::default(),
|
||||
_migrator: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -46,13 +51,13 @@ impl<M: Migrator> ThreadSafeConnection<M> {
|
|||
/// If opening fails, the connection falls back to a shared memory connection
|
||||
fn open_file(&self) -> Connection {
|
||||
// This unwrap is secured by a panic in the constructor. Be careful if you remove it!
|
||||
Connection::open_file(self.uri.as_ref().unwrap())
|
||||
Connection::open_file(self.uri.as_ref())
|
||||
}
|
||||
|
||||
/// Opens a shared memory connection using the file path as the identifier. This unwraps
|
||||
/// as we expect it always to succeed
|
||||
fn open_shared_memory(&self) -> Connection {
|
||||
Connection::open_memory(self.uri.as_ref().map(|str| str.deref()))
|
||||
Connection::open_memory(Some(self.uri.as_ref()))
|
||||
}
|
||||
|
||||
// Open a new connection for the given domain, leaving this
|
||||
|
@ -62,10 +67,74 @@ impl<M: Migrator> ThreadSafeConnection<M> {
|
|||
uri: self.uri.clone(),
|
||||
persistent: self.persistent,
|
||||
initialize_query: self.initialize_query,
|
||||
connection: Default::default(),
|
||||
_pd: PhantomData,
|
||||
connections: Default::default(),
|
||||
_migrator: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write<T: 'static + Send + Sync>(
|
||||
&self,
|
||||
callback: impl 'static + Send + FnOnce(&Connection) -> T,
|
||||
) -> impl Future<Output = T> {
|
||||
// Startup write thread for this database if one hasn't already
|
||||
// been started and insert a channel to queue work for it
|
||||
if !QUEUES.read().contains_key(&self.uri) {
|
||||
use std::sync::mpsc::channel;
|
||||
|
||||
let (sender, reciever) = channel::<QueuedWrite>();
|
||||
let mut write_connection = self.create_connection();
|
||||
// Enable writes for this connection
|
||||
write_connection.write = true;
|
||||
thread::spawn(move || {
|
||||
while let Ok(write) = reciever.recv() {
|
||||
write(&write_connection)
|
||||
}
|
||||
});
|
||||
|
||||
let mut queues = QUEUES.write();
|
||||
queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
|
||||
}
|
||||
|
||||
// Grab the queue for this database
|
||||
let queues = QUEUES.read();
|
||||
let write_channel = queues.get(&self.uri).unwrap();
|
||||
|
||||
// Create a one shot channel for the result of the queued write
|
||||
// so we can await on the result
|
||||
let (sender, reciever) = futures::channel::oneshot::channel();
|
||||
write_channel
|
||||
.send(Box::new(move |connection| {
|
||||
sender.send(callback(connection)).ok();
|
||||
}))
|
||||
.expect("Could not send write action to background thread");
|
||||
|
||||
reciever.map(|response| response.expect("Background thread unexpectedly closed"))
|
||||
}
|
||||
|
||||
pub(crate) fn create_connection(&self) -> Connection {
|
||||
let mut connection = if self.persistent {
|
||||
self.open_file()
|
||||
} else {
|
||||
self.open_shared_memory()
|
||||
};
|
||||
|
||||
// Enable writes for the migrations and initialization queries
|
||||
connection.write = true;
|
||||
|
||||
if let Some(initialize_query) = self.initialize_query {
|
||||
connection.exec(initialize_query).expect(&format!(
|
||||
"Initialize query failed to execute: {}",
|
||||
initialize_query
|
||||
))()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
M::migrate(&connection).expect("Migrations failed");
|
||||
|
||||
// Disable db writes for normal thread local connection
|
||||
connection.write = false;
|
||||
connection
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Domain> Clone for ThreadSafeConnection<D> {
|
||||
|
@ -74,8 +143,8 @@ impl<D: Domain> Clone for ThreadSafeConnection<D> {
|
|||
uri: self.uri.clone(),
|
||||
persistent: self.persistent,
|
||||
initialize_query: self.initialize_query.clone(),
|
||||
connection: self.connection.clone(),
|
||||
_pd: PhantomData,
|
||||
connections: self.connections.clone(),
|
||||
_migrator: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -88,25 +157,7 @@ impl<M: Migrator> Deref for ThreadSafeConnection<M> {
|
|||
type Target = Connection;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.connection.get_or(|| {
|
||||
let connection = if self.persistent {
|
||||
self.open_file()
|
||||
} else {
|
||||
self.open_shared_memory()
|
||||
};
|
||||
|
||||
if let Some(initialize_query) = self.initialize_query {
|
||||
connection.exec(initialize_query).expect(&format!(
|
||||
"Initialize query failed to execute: {}",
|
||||
initialize_query
|
||||
))()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
M::migrate(&connection).expect("Migrations failed");
|
||||
|
||||
connection
|
||||
})
|
||||
self.connections.get_or(|| self.create_connection())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -151,7 +202,7 @@ mod test {
|
|||
}
|
||||
}
|
||||
|
||||
let _ = ThreadSafeConnection::<TestWorkspace>::new(None, false)
|
||||
let _ = ThreadSafeConnection::<TestWorkspace>::new("wild_zed_lost_failure", false)
|
||||
.with_initialize_query("PRAGMA FOREIGN_KEYS=true")
|
||||
.deref();
|
||||
}
|
||||
|
|
28
crates/sqlez/src/util.rs
Normal file
28
crates/sqlez/src/util.rs
Normal file
|
@ -0,0 +1,28 @@
|
|||
use std::ops::Deref;
|
||||
use std::sync::mpsc::Sender;
|
||||
|
||||
use parking_lot::Mutex;
|
||||
use thread_local::ThreadLocal;
|
||||
|
||||
pub struct UnboundedSyncSender<T: Send> {
|
||||
clonable_sender: Mutex<Sender<T>>,
|
||||
local_senders: ThreadLocal<Sender<T>>,
|
||||
}
|
||||
|
||||
impl<T: Send> UnboundedSyncSender<T> {
|
||||
pub fn new(sender: Sender<T>) -> Self {
|
||||
Self {
|
||||
clonable_sender: Mutex::new(sender),
|
||||
local_senders: ThreadLocal::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send> Deref for UnboundedSyncSender<T> {
|
||||
type Target = Sender<T>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.local_senders
|
||||
.get_or(|| self.clonable_sender.lock().clone())
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue