fix test failures
This commit is contained in:
parent
a29ccb4ff8
commit
1b225fa37c
2 changed files with 64 additions and 34 deletions
|
@ -4,6 +4,7 @@ pub mod kvp;
|
||||||
pub use anyhow;
|
pub use anyhow;
|
||||||
pub use indoc::indoc;
|
pub use indoc::indoc;
|
||||||
pub use lazy_static;
|
pub use lazy_static;
|
||||||
|
use parking_lot::Mutex;
|
||||||
pub use smol;
|
pub use smol;
|
||||||
pub use sqlez;
|
pub use sqlez;
|
||||||
pub use sqlez_macros;
|
pub use sqlez_macros;
|
||||||
|
@ -59,6 +60,14 @@ pub async fn open_memory_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<
|
||||||
ThreadSafeConnection::<M>::builder(db_name, false)
|
ThreadSafeConnection::<M>::builder(db_name, false)
|
||||||
.with_db_initialization_query(DB_INITIALIZE_QUERY)
|
.with_db_initialization_query(DB_INITIALIZE_QUERY)
|
||||||
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
|
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
|
||||||
|
// Serialize queued writes via a mutex and run them synchronously
|
||||||
|
.with_write_queue_constructor(Box::new(|connection| {
|
||||||
|
let connection = Mutex::new(connection);
|
||||||
|
Box::new(move |queued_write| {
|
||||||
|
let connection = connection.lock();
|
||||||
|
queued_write(&connection)
|
||||||
|
})
|
||||||
|
}))
|
||||||
.build()
|
.build()
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,12 +13,14 @@ use crate::{
|
||||||
const MIGRATION_RETRIES: usize = 10;
|
const MIGRATION_RETRIES: usize = 10;
|
||||||
|
|
||||||
type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
|
type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
|
||||||
|
type WriteQueueConstructor =
|
||||||
|
Box<dyn 'static + Send + FnMut(Connection) -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
/// List of queues of tasks by database uri. This lets us serialize writes to the database
|
/// List of queues of tasks by database uri. This lets us serialize writes to the database
|
||||||
/// and have a single worker thread per db file. This means many thread safe connections
|
/// and have a single worker thread per db file. This means many thread safe connections
|
||||||
/// (possibly with different migrations) could all be communicating with the same background
|
/// (possibly with different migrations) could all be communicating with the same background
|
||||||
/// thread.
|
/// thread.
|
||||||
static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
|
static ref QUEUES: RwLock<HashMap<Arc<str>, Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>> =
|
||||||
Default::default();
|
Default::default();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,6 +40,7 @@ unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
|
||||||
|
|
||||||
pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
|
pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
|
||||||
db_initialize_query: Option<&'static str>,
|
db_initialize_query: Option<&'static str>,
|
||||||
|
write_queue_constructor: Option<WriteQueueConstructor>,
|
||||||
connection: ThreadSafeConnection<M>,
|
connection: ThreadSafeConnection<M>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,6 +53,18 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Specifies how the thread safe connection should serialize writes. If provided
|
||||||
|
/// the connection will call the write_queue_constructor for each database file in
|
||||||
|
/// this process. The constructor is responsible for setting up a background thread or
|
||||||
|
/// async task which handles queued writes with the provided connection.
|
||||||
|
pub fn with_write_queue_constructor(
|
||||||
|
mut self,
|
||||||
|
write_queue_constructor: WriteQueueConstructor,
|
||||||
|
) -> Self {
|
||||||
|
self.write_queue_constructor = Some(write_queue_constructor);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Queues an initialization query for the database file. This must be infallible
|
/// Queues an initialization query for the database file. This must be infallible
|
||||||
/// but may cause changes to the database file such as with `PRAGMA journal_mode`
|
/// but may cause changes to the database file such as with `PRAGMA journal_mode`
|
||||||
pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
|
pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
|
||||||
|
@ -58,6 +73,38 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn build(self) -> ThreadSafeConnection<M> {
|
pub async fn build(self) -> ThreadSafeConnection<M> {
|
||||||
|
if !QUEUES.read().contains_key(&self.connection.uri) {
|
||||||
|
let mut queues = QUEUES.write();
|
||||||
|
if !queues.contains_key(&self.connection.uri) {
|
||||||
|
let mut write_connection = self.connection.create_connection();
|
||||||
|
// Enable writes for this connection
|
||||||
|
write_connection.write = true;
|
||||||
|
if let Some(mut write_queue_constructor) = self.write_queue_constructor {
|
||||||
|
let write_channel = write_queue_constructor(write_connection);
|
||||||
|
queues.insert(self.connection.uri.clone(), write_channel);
|
||||||
|
} else {
|
||||||
|
use std::sync::mpsc::channel;
|
||||||
|
|
||||||
|
let (sender, reciever) = channel::<QueuedWrite>();
|
||||||
|
thread::spawn(move || {
|
||||||
|
while let Ok(write) = reciever.recv() {
|
||||||
|
write(&write_connection)
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let sender = UnboundedSyncSender::new(sender);
|
||||||
|
queues.insert(
|
||||||
|
self.connection.uri.clone(),
|
||||||
|
Box::new(move |queued_write| {
|
||||||
|
sender
|
||||||
|
.send(queued_write)
|
||||||
|
.expect("Could not send write action to backgorund thread");
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let db_initialize_query = self.db_initialize_query;
|
let db_initialize_query = self.db_initialize_query;
|
||||||
|
|
||||||
self.connection
|
self.connection
|
||||||
|
@ -90,6 +137,7 @@ impl<M: Migrator> ThreadSafeConnection<M> {
|
||||||
pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
|
pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
|
||||||
ThreadSafeConnectionBuilder::<M> {
|
ThreadSafeConnectionBuilder::<M> {
|
||||||
db_initialize_query: None,
|
db_initialize_query: None,
|
||||||
|
write_queue_constructor: None,
|
||||||
connection: Self {
|
connection: Self {
|
||||||
uri: Arc::from(uri),
|
uri: Arc::from(uri),
|
||||||
persistent,
|
persistent,
|
||||||
|
@ -112,48 +160,21 @@ impl<M: Migrator> ThreadSafeConnection<M> {
|
||||||
Connection::open_memory(Some(self.uri.as_ref()))
|
Connection::open_memory(Some(self.uri.as_ref()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn queue_write_task(&self, callback: QueuedWrite) {
|
|
||||||
// Startup write thread for this database if one hasn't already
|
|
||||||
// been started and insert a channel to queue work for it
|
|
||||||
if !QUEUES.read().contains_key(&self.uri) {
|
|
||||||
let mut queues = QUEUES.write();
|
|
||||||
if !queues.contains_key(&self.uri) {
|
|
||||||
use std::sync::mpsc::channel;
|
|
||||||
|
|
||||||
let (sender, reciever) = channel::<QueuedWrite>();
|
|
||||||
let mut write_connection = self.create_connection();
|
|
||||||
// Enable writes for this connection
|
|
||||||
write_connection.write = true;
|
|
||||||
thread::spawn(move || {
|
|
||||||
while let Ok(write) = reciever.recv() {
|
|
||||||
write(&write_connection)
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Grab the queue for this database
|
|
||||||
let queues = QUEUES.read();
|
|
||||||
let write_channel = queues.get(&self.uri).unwrap();
|
|
||||||
|
|
||||||
write_channel
|
|
||||||
.send(callback)
|
|
||||||
.expect("Could not send write action to backgorund thread");
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn write<T: 'static + Send + Sync>(
|
pub fn write<T: 'static + Send + Sync>(
|
||||||
&self,
|
&self,
|
||||||
callback: impl 'static + Send + FnOnce(&Connection) -> T,
|
callback: impl 'static + Send + FnOnce(&Connection) -> T,
|
||||||
) -> impl Future<Output = T> {
|
) -> impl Future<Output = T> {
|
||||||
|
let queues = QUEUES.read();
|
||||||
|
let write_channel = queues
|
||||||
|
.get(&self.uri)
|
||||||
|
.expect("Queues are inserted when build is called. This should always succeed");
|
||||||
|
|
||||||
// Create a one shot channel for the result of the queued write
|
// Create a one shot channel for the result of the queued write
|
||||||
// so we can await on the result
|
// so we can await on the result
|
||||||
let (sender, reciever) = oneshot::channel();
|
let (sender, reciever) = oneshot::channel();
|
||||||
self.queue_write_task(Box::new(move |connection| {
|
write_channel(Box::new(move |connection| {
|
||||||
sender.send(callback(connection)).ok();
|
sender.send(callback(connection)).ok();
|
||||||
}));
|
}));
|
||||||
|
|
||||||
reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
|
reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue