Reworked thread safe connection be threadsafer,,,, again

Co-Authored-By: kay@zed.dev
This commit is contained in:
Mikayla Maki 2022-12-01 18:31:05 -08:00
parent 189a820113
commit 5e240f98f0
12 changed files with 741 additions and 584 deletions

View file

@ -5,17 +5,13 @@ use parking_lot::{Mutex, RwLock};
use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
use thread_local::ThreadLocal;
use crate::{
connection::Connection,
domain::{Domain, Migrator},
util::UnboundedSyncSender,
};
use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
const MIGRATION_RETRIES: usize = 10;
type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
type WriteQueueConstructor =
Box<dyn 'static + Send + FnMut(Connection) -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
Box<dyn 'static + Send + FnMut() -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
lazy_static! {
/// List of queues of tasks by database uri. This lets us serialize writes to the database
/// and have a single worker thread per db file. This means many thread safe connections
@ -28,18 +24,18 @@ lazy_static! {
/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
/// may be accessed by passing a callback to the `write` function which will queue the callback
pub struct ThreadSafeConnection<M: Migrator = ()> {
pub struct ThreadSafeConnection<M: Migrator + 'static = ()> {
uri: Arc<str>,
persistent: bool,
connection_initialize_query: Option<&'static str>,
connections: Arc<ThreadLocal<Connection>>,
_migrator: PhantomData<M>,
_migrator: PhantomData<*mut M>,
}
unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
unsafe impl<M: Migrator> Send for ThreadSafeConnection<M> {}
unsafe impl<M: Migrator> Sync for ThreadSafeConnection<M> {}
pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
db_initialize_query: Option<&'static str>,
write_queue_constructor: Option<WriteQueueConstructor>,
connection: ThreadSafeConnection<M>,
@ -54,6 +50,13 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
self
}
/// Queues an initialization query for the database file. This must be infallible
/// but may cause changes to the database file such as with `PRAGMA journal_mode`
pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
self.db_initialize_query = Some(initialize_query);
self
}
/// Specifies how the thread safe connection should serialize writes. If provided
/// the connection will call the write_queue_constructor for each database file in
/// this process. The constructor is responsible for setting up a background thread or
@ -66,13 +69,6 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
self
}
/// Queues an initialization query for the database file. This must be infallible
/// but may cause changes to the database file such as with `PRAGMA journal_mode`
pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
self.db_initialize_query = Some(initialize_query);
self
}
pub async fn build(self) -> anyhow::Result<ThreadSafeConnection<M>> {
self.connection
.initialize_queues(self.write_queue_constructor);
@ -100,6 +96,7 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
.with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
if migration_result.is_ok() {
println!("Migration succeded");
break;
}
}
@ -113,38 +110,17 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
}
impl<M: Migrator> ThreadSafeConnection<M> {
fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) {
fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
if !QUEUES.read().contains_key(&self.uri) {
let mut queues = QUEUES.write();
if !queues.contains_key(&self.uri) {
let mut write_connection = self.create_connection();
// Enable writes for this connection
write_connection.write = true;
if let Some(mut write_queue_constructor) = write_queue_constructor {
let write_channel = write_queue_constructor(write_connection);
queues.insert(self.uri.clone(), write_channel);
} else {
use std::sync::mpsc::channel;
let (sender, reciever) = channel::<QueuedWrite>();
thread::spawn(move || {
while let Ok(write) = reciever.recv() {
write(&write_connection)
}
});
let sender = UnboundedSyncSender::new(sender);
queues.insert(
self.uri.clone(),
Box::new(move |queued_write| {
sender
.send(queued_write)
.expect("Could not send write action to backgorund thread");
}),
);
}
let mut write_queue_constructor =
write_queue_constructor.unwrap_or(background_thread_queue());
queues.insert(self.uri.clone(), write_queue_constructor());
return true;
}
}
return false;
}
pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
@ -163,20 +139,21 @@ impl<M: Migrator> ThreadSafeConnection<M> {
/// Opens a new db connection with the initialized file path. This is internal and only
/// called from the deref function.
fn open_file(&self) -> Connection {
Connection::open_file(self.uri.as_ref())
fn open_file(uri: &str) -> Connection {
Connection::open_file(uri)
}
/// Opens a shared memory connection using the file path as the identifier. This is internal
/// and only called from the deref function.
fn open_shared_memory(&self) -> Connection {
Connection::open_memory(Some(self.uri.as_ref()))
fn open_shared_memory(uri: &str) -> Connection {
Connection::open_memory(Some(uri))
}
pub fn write<T: 'static + Send + Sync>(
&self,
callback: impl 'static + Send + FnOnce(&Connection) -> T,
) -> impl Future<Output = T> {
// Check and invalidate queue and maybe recreate queue
let queues = QUEUES.read();
let write_channel = queues
.get(&self.uri)
@ -185,24 +162,32 @@ impl<M: Migrator> ThreadSafeConnection<M> {
// Create a one shot channel for the result of the queued write
// so we can await on the result
let (sender, reciever) = oneshot::channel();
write_channel(Box::new(move |connection| {
sender.send(callback(connection)).ok();
let thread_safe_connection = (*self).clone();
write_channel(Box::new(move || {
let connection = thread_safe_connection.deref();
let result = connection.with_write(|connection| callback(connection));
sender.send(result).ok();
}));
reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
}
pub(crate) fn create_connection(&self) -> Connection {
let mut connection = if self.persistent {
self.open_file()
pub(crate) fn create_connection(
persistent: bool,
uri: &str,
connection_initialize_query: Option<&'static str>,
) -> Connection {
let mut connection = if persistent {
Self::open_file(uri)
} else {
self.open_shared_memory()
Self::open_shared_memory(uri)
};
// Disallow writes on the connection. The only writes allowed for thread safe connections
// are from the background thread that can serialize them.
connection.write = false;
*connection.write.get_mut() = false;
if let Some(initialize_query) = self.connection_initialize_query {
if let Some(initialize_query) = connection_initialize_query {
connection.exec(initialize_query).expect(&format!(
"Initialize query failed to execute: {}",
initialize_query
@ -236,7 +221,7 @@ impl ThreadSafeConnection<()> {
}
}
impl<D: Domain> Clone for ThreadSafeConnection<D> {
impl<M: Migrator> Clone for ThreadSafeConnection<M> {
fn clone(&self) -> Self {
Self {
uri: self.uri.clone(),
@ -252,16 +237,41 @@ impl<M: Migrator> Deref for ThreadSafeConnection<M> {
type Target = Connection;
fn deref(&self) -> &Self::Target {
self.connections.get_or(|| self.create_connection())
self.connections.get_or(|| {
Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
})
}
}
pub fn locking_queue() -> WriteQueueConstructor {
Box::new(|connection| {
let connection = Mutex::new(connection);
pub fn background_thread_queue() -> WriteQueueConstructor {
use std::sync::mpsc::channel;
Box::new(|| {
let (sender, reciever) = channel::<QueuedWrite>();
thread::spawn(move || {
while let Ok(write) = reciever.recv() {
write()
}
});
let sender = UnboundedSyncSender::new(sender);
Box::new(move |queued_write| {
let connection = connection.lock();
queued_write(&connection)
sender
.send(queued_write)
.expect("Could not send write action to background thread");
})
})
}
pub fn locking_queue() -> WriteQueueConstructor {
Box::new(|| {
let mutex = Mutex::new(());
Box::new(move |queued_write| {
eprintln!("Write started");
let _ = mutex.lock();
queued_write();
eprintln!("Write finished");
})
})
}
@ -269,7 +279,8 @@ pub fn locking_queue() -> WriteQueueConstructor {
#[cfg(test)]
mod test {
use indoc::indoc;
use std::ops::Deref;
use lazy_static::__Deref;
use std::thread;
use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};