working serialized writes with panics on failure. Everything seems to be working

This commit is contained in:
Kay Simmons 2022-11-23 01:53:58 -08:00 committed by Mikayla Maki
parent b01243109e
commit 1cc3e4820a
34 changed files with 669 additions and 312 deletions

View file

@ -1,36 +1,41 @@
use std::{marker::PhantomData, ops::Deref, sync::Arc};
use connection::Connection;
use futures::{Future, FutureExt};
use lazy_static::lazy_static;
use parking_lot::RwLock;
use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
use thread_local::ThreadLocal;
use crate::{
connection,
connection::Connection,
domain::{Domain, Migrator},
util::UnboundedSyncSender,
};
type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
lazy_static! {
static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
Default::default();
}
pub struct ThreadSafeConnection<M: Migrator> {
uri: Option<Arc<str>>,
uri: Arc<str>,
persistent: bool,
initialize_query: Option<&'static str>,
connection: Arc<ThreadLocal<Connection>>,
_pd: PhantomData<M>,
connections: Arc<ThreadLocal<Connection>>,
_migrator: PhantomData<M>,
}
unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
impl<M: Migrator> ThreadSafeConnection<M> {
pub fn new(uri: Option<&str>, persistent: bool) -> Self {
if persistent == true && uri == None {
// This panic is securing the unwrap in open_file(), don't remove it!
panic!("Cannot create a persistent connection without a URI")
}
pub fn new(uri: &str, persistent: bool) -> Self {
Self {
uri: uri.map(|str| Arc::from(str)),
uri: Arc::from(uri),
persistent,
initialize_query: None,
connection: Default::default(),
_pd: PhantomData,
connections: Default::default(),
_migrator: PhantomData,
}
}
@ -46,13 +51,13 @@ impl<M: Migrator> ThreadSafeConnection<M> {
/// If opening fails, the connection falls back to a shared memory connection
fn open_file(&self) -> Connection {
// This unwrap is secured by a panic in the constructor. Be careful if you remove it!
Connection::open_file(self.uri.as_ref().unwrap())
Connection::open_file(self.uri.as_ref())
}
/// Opens a shared memory connection using the file path as the identifier. This unwraps
/// as we expect it always to succeed
fn open_shared_memory(&self) -> Connection {
Connection::open_memory(self.uri.as_ref().map(|str| str.deref()))
Connection::open_memory(Some(self.uri.as_ref()))
}
// Open a new connection for the given domain, leaving this
@ -62,10 +67,74 @@ impl<M: Migrator> ThreadSafeConnection<M> {
uri: self.uri.clone(),
persistent: self.persistent,
initialize_query: self.initialize_query,
connection: Default::default(),
_pd: PhantomData,
connections: Default::default(),
_migrator: PhantomData,
}
}
pub fn write<T: 'static + Send + Sync>(
&self,
callback: impl 'static + Send + FnOnce(&Connection) -> T,
) -> impl Future<Output = T> {
// Startup write thread for this database if one hasn't already
// been started and insert a channel to queue work for it
if !QUEUES.read().contains_key(&self.uri) {
use std::sync::mpsc::channel;
let (sender, reciever) = channel::<QueuedWrite>();
let mut write_connection = self.create_connection();
// Enable writes for this connection
write_connection.write = true;
thread::spawn(move || {
while let Ok(write) = reciever.recv() {
write(&write_connection)
}
});
let mut queues = QUEUES.write();
queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
}
// Grab the queue for this database
let queues = QUEUES.read();
let write_channel = queues.get(&self.uri).unwrap();
// Create a one shot channel for the result of the queued write
// so we can await on the result
let (sender, reciever) = futures::channel::oneshot::channel();
write_channel
.send(Box::new(move |connection| {
sender.send(callback(connection)).ok();
}))
.expect("Could not send write action to background thread");
reciever.map(|response| response.expect("Background thread unexpectedly closed"))
}
pub(crate) fn create_connection(&self) -> Connection {
let mut connection = if self.persistent {
self.open_file()
} else {
self.open_shared_memory()
};
// Enable writes for the migrations and initialization queries
connection.write = true;
if let Some(initialize_query) = self.initialize_query {
connection.exec(initialize_query).expect(&format!(
"Initialize query failed to execute: {}",
initialize_query
))()
.unwrap();
}
M::migrate(&connection).expect("Migrations failed");
// Disable db writes for normal thread local connection
connection.write = false;
connection
}
}
impl<D: Domain> Clone for ThreadSafeConnection<D> {
@ -74,8 +143,8 @@ impl<D: Domain> Clone for ThreadSafeConnection<D> {
uri: self.uri.clone(),
persistent: self.persistent,
initialize_query: self.initialize_query.clone(),
connection: self.connection.clone(),
_pd: PhantomData,
connections: self.connections.clone(),
_migrator: PhantomData,
}
}
}
@ -88,25 +157,7 @@ impl<M: Migrator> Deref for ThreadSafeConnection<M> {
type Target = Connection;
fn deref(&self) -> &Self::Target {
self.connection.get_or(|| {
let connection = if self.persistent {
self.open_file()
} else {
self.open_shared_memory()
};
if let Some(initialize_query) = self.initialize_query {
connection.exec(initialize_query).expect(&format!(
"Initialize query failed to execute: {}",
initialize_query
))()
.unwrap();
}
M::migrate(&connection).expect("Migrations failed");
connection
})
self.connections.get_or(|| self.create_connection())
}
}
@ -151,7 +202,7 @@ mod test {
}
}
let _ = ThreadSafeConnection::<TestWorkspace>::new(None, false)
let _ = ThreadSafeConnection::<TestWorkspace>::new("wild_zed_lost_failure", false)
.with_initialize_query("PRAGMA FOREIGN_KEYS=true")
.deref();
}