make thread safe connection more thread safe
Co-Authored-By: Mikayla Maki <mikayla@zed.dev>
This commit is contained in:
parent
9cd6894dc5
commit
a29ccb4ff8
12 changed files with 196 additions and 124 deletions
|
@ -9,6 +9,7 @@ edition = "2021"
|
|||
anyhow = { version = "1.0.38", features = ["backtrace"] }
|
||||
indoc = "1.0.7"
|
||||
libsqlite3-sys = { version = "0.25.2", features = ["bundled"] }
|
||||
smol = "1.2"
|
||||
thread_local = "1.1.4"
|
||||
lazy_static = "1.4"
|
||||
parking_lot = "0.11.1"
|
||||
|
|
|
@ -15,9 +15,9 @@ impl Connection {
|
|||
// Setup the migrations table unconditionally
|
||||
self.exec(indoc! {"
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
domain TEXT,
|
||||
step INTEGER,
|
||||
migration TEXT
|
||||
domain TEXT,
|
||||
step INTEGER,
|
||||
migration TEXT
|
||||
)"})?()?;
|
||||
|
||||
let completed_migrations =
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use futures::{Future, FutureExt};
|
||||
use futures::{channel::oneshot, Future, FutureExt};
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::RwLock;
|
||||
use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
|
||||
|
@ -10,17 +10,25 @@ use crate::{
|
|||
util::UnboundedSyncSender,
|
||||
};
|
||||
|
||||
type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
|
||||
const MIGRATION_RETRIES: usize = 10;
|
||||
|
||||
type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
|
||||
lazy_static! {
|
||||
/// List of queues of tasks by database uri. This lets us serialize writes to the database
|
||||
/// and have a single worker thread per db file. This means many thread safe connections
|
||||
/// (possibly with different migrations) could all be communicating with the same background
|
||||
/// thread.
|
||||
static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
|
||||
Default::default();
|
||||
}
|
||||
|
||||
/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
|
||||
/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
|
||||
/// may be accessed by passing a callback to the `write` function which will queue the callback
|
||||
pub struct ThreadSafeConnection<M: Migrator = ()> {
|
||||
uri: Arc<str>,
|
||||
persistent: bool,
|
||||
initialize_query: Option<&'static str>,
|
||||
connection_initialize_query: Option<&'static str>,
|
||||
connections: Arc<ThreadLocal<Connection>>,
|
||||
_migrator: PhantomData<M>,
|
||||
}
|
||||
|
@ -28,87 +36,125 @@ pub struct ThreadSafeConnection<M: Migrator = ()> {
|
|||
unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
|
||||
unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
|
||||
|
||||
impl<M: Migrator> ThreadSafeConnection<M> {
|
||||
pub fn new(uri: &str, persistent: bool) -> Self {
|
||||
Self {
|
||||
uri: Arc::from(uri),
|
||||
persistent,
|
||||
initialize_query: None,
|
||||
connections: Default::default(),
|
||||
_migrator: PhantomData,
|
||||
}
|
||||
pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
|
||||
db_initialize_query: Option<&'static str>,
|
||||
connection: ThreadSafeConnection<M>,
|
||||
}
|
||||
|
||||
impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
|
||||
/// Sets the query to run every time a connection is opened. This must
|
||||
/// be infallible (EG only use pragma statements) and not cause writes.
|
||||
/// to the db or it will panic.
|
||||
pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
|
||||
self.connection.connection_initialize_query = Some(initialize_query);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the query to run every time a connection is opened. This must
|
||||
/// be infallible (EG only use pragma statements)
|
||||
pub fn with_initialize_query(mut self, initialize_query: &'static str) -> Self {
|
||||
self.initialize_query = Some(initialize_query);
|
||||
/// Queues an initialization query for the database file. This must be infallible
|
||||
/// but may cause changes to the database file such as with `PRAGMA journal_mode`
|
||||
pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
|
||||
self.db_initialize_query = Some(initialize_query);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn build(self) -> ThreadSafeConnection<M> {
|
||||
let db_initialize_query = self.db_initialize_query;
|
||||
|
||||
self.connection
|
||||
.write(move |connection| {
|
||||
if let Some(db_initialize_query) = db_initialize_query {
|
||||
connection.exec(db_initialize_query).expect(&format!(
|
||||
"Db initialize query failed to execute: {}",
|
||||
db_initialize_query
|
||||
))()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let mut failure_result = None;
|
||||
for _ in 0..MIGRATION_RETRIES {
|
||||
failure_result = Some(M::migrate(connection));
|
||||
if failure_result.as_ref().unwrap().is_ok() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
failure_result.unwrap().expect("Migration failed");
|
||||
})
|
||||
.await;
|
||||
|
||||
self.connection
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: Migrator> ThreadSafeConnection<M> {
|
||||
pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
|
||||
ThreadSafeConnectionBuilder::<M> {
|
||||
db_initialize_query: None,
|
||||
connection: Self {
|
||||
uri: Arc::from(uri),
|
||||
persistent,
|
||||
connection_initialize_query: None,
|
||||
connections: Default::default(),
|
||||
_migrator: PhantomData,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Opens a new db connection with the initialized file path. This is internal and only
|
||||
/// called from the deref function.
|
||||
/// If opening fails, the connection falls back to a shared memory connection
|
||||
fn open_file(&self) -> Connection {
|
||||
// This unwrap is secured by a panic in the constructor. Be careful if you remove it!
|
||||
Connection::open_file(self.uri.as_ref())
|
||||
}
|
||||
|
||||
/// Opens a shared memory connection using the file path as the identifier. This unwraps
|
||||
/// as we expect it always to succeed
|
||||
/// Opens a shared memory connection using the file path as the identifier. This is internal
|
||||
/// and only called from the deref function.
|
||||
fn open_shared_memory(&self) -> Connection {
|
||||
Connection::open_memory(Some(self.uri.as_ref()))
|
||||
}
|
||||
|
||||
// Open a new connection for the given domain, leaving this
|
||||
// connection intact.
|
||||
pub fn for_domain<D2: Domain>(&self) -> ThreadSafeConnection<D2> {
|
||||
ThreadSafeConnection {
|
||||
uri: self.uri.clone(),
|
||||
persistent: self.persistent,
|
||||
initialize_query: self.initialize_query,
|
||||
connections: Default::default(),
|
||||
_migrator: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write<T: 'static + Send + Sync>(
|
||||
&self,
|
||||
callback: impl 'static + Send + FnOnce(&Connection) -> T,
|
||||
) -> impl Future<Output = T> {
|
||||
fn queue_write_task(&self, callback: QueuedWrite) {
|
||||
// Startup write thread for this database if one hasn't already
|
||||
// been started and insert a channel to queue work for it
|
||||
if !QUEUES.read().contains_key(&self.uri) {
|
||||
use std::sync::mpsc::channel;
|
||||
|
||||
let (sender, reciever) = channel::<QueuedWrite>();
|
||||
let mut write_connection = self.create_connection();
|
||||
// Enable writes for this connection
|
||||
write_connection.write = true;
|
||||
thread::spawn(move || {
|
||||
while let Ok(write) = reciever.recv() {
|
||||
write(&write_connection)
|
||||
}
|
||||
});
|
||||
|
||||
let mut queues = QUEUES.write();
|
||||
queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
|
||||
if !queues.contains_key(&self.uri) {
|
||||
use std::sync::mpsc::channel;
|
||||
|
||||
let (sender, reciever) = channel::<QueuedWrite>();
|
||||
let mut write_connection = self.create_connection();
|
||||
// Enable writes for this connection
|
||||
write_connection.write = true;
|
||||
thread::spawn(move || {
|
||||
while let Ok(write) = reciever.recv() {
|
||||
write(&write_connection)
|
||||
}
|
||||
});
|
||||
|
||||
queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
|
||||
}
|
||||
}
|
||||
|
||||
// Grab the queue for this database
|
||||
let queues = QUEUES.read();
|
||||
let write_channel = queues.get(&self.uri).unwrap();
|
||||
|
||||
write_channel
|
||||
.send(callback)
|
||||
.expect("Could not send write action to backgorund thread");
|
||||
}
|
||||
|
||||
pub fn write<T: 'static + Send + Sync>(
|
||||
&self,
|
||||
callback: impl 'static + Send + FnOnce(&Connection) -> T,
|
||||
) -> impl Future<Output = T> {
|
||||
// Create a one shot channel for the result of the queued write
|
||||
// so we can await on the result
|
||||
let (sender, reciever) = futures::channel::oneshot::channel();
|
||||
write_channel
|
||||
.send(Box::new(move |connection| {
|
||||
sender.send(callback(connection)).ok();
|
||||
}))
|
||||
.expect("Could not send write action to background thread");
|
||||
let (sender, reciever) = oneshot::channel();
|
||||
self.queue_write_task(Box::new(move |connection| {
|
||||
sender.send(callback(connection)).ok();
|
||||
}));
|
||||
|
||||
reciever.map(|response| response.expect("Background thread unexpectedly closed"))
|
||||
reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
|
||||
}
|
||||
|
||||
pub(crate) fn create_connection(&self) -> Connection {
|
||||
|
@ -118,10 +164,11 @@ impl<M: Migrator> ThreadSafeConnection<M> {
|
|||
self.open_shared_memory()
|
||||
};
|
||||
|
||||
// Enable writes for the migrations and initialization queries
|
||||
connection.write = true;
|
||||
// Disallow writes on the connection. The only writes allowed for thread safe connections
|
||||
// are from the background thread that can serialize them.
|
||||
connection.write = false;
|
||||
|
||||
if let Some(initialize_query) = self.initialize_query {
|
||||
if let Some(initialize_query) = self.connection_initialize_query {
|
||||
connection.exec(initialize_query).expect(&format!(
|
||||
"Initialize query failed to execute: {}",
|
||||
initialize_query
|
||||
|
@ -129,20 +176,34 @@ impl<M: Migrator> ThreadSafeConnection<M> {
|
|||
.unwrap()
|
||||
}
|
||||
|
||||
M::migrate(&connection).expect("Migrations failed");
|
||||
|
||||
// Disable db writes for normal thread local connection
|
||||
connection.write = false;
|
||||
connection
|
||||
}
|
||||
}
|
||||
|
||||
impl ThreadSafeConnection<()> {
|
||||
/// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
|
||||
/// This allows construction to be infallible and not write to the db.
|
||||
pub fn new(
|
||||
uri: &str,
|
||||
persistent: bool,
|
||||
connection_initialize_query: Option<&'static str>,
|
||||
) -> Self {
|
||||
Self {
|
||||
uri: Arc::from(uri),
|
||||
persistent,
|
||||
connection_initialize_query,
|
||||
connections: Default::default(),
|
||||
_migrator: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Domain> Clone for ThreadSafeConnection<D> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
uri: self.uri.clone(),
|
||||
persistent: self.persistent,
|
||||
initialize_query: self.initialize_query.clone(),
|
||||
connection_initialize_query: self.connection_initialize_query.clone(),
|
||||
connections: self.connections.clone(),
|
||||
_migrator: PhantomData,
|
||||
}
|
||||
|
@ -163,11 +224,11 @@ impl<M: Migrator> Deref for ThreadSafeConnection<M> {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::{fs, ops::Deref, thread};
|
||||
use indoc::indoc;
|
||||
use lazy_static::__Deref;
|
||||
use std::thread;
|
||||
|
||||
use crate::domain::Domain;
|
||||
|
||||
use super::ThreadSafeConnection;
|
||||
use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};
|
||||
|
||||
#[test]
|
||||
fn many_initialize_and_migrate_queries_at_once() {
|
||||
|
@ -185,27 +246,22 @@ mod test {
|
|||
|
||||
for _ in 0..100 {
|
||||
handles.push(thread::spawn(|| {
|
||||
let _ = ThreadSafeConnection::<TestDomain>::new("annoying-test.db", false)
|
||||
.with_initialize_query(
|
||||
"
|
||||
PRAGMA journal_mode=WAL;
|
||||
PRAGMA synchronous=NORMAL;
|
||||
PRAGMA busy_timeout=1;
|
||||
PRAGMA foreign_keys=TRUE;
|
||||
PRAGMA case_sensitive_like=TRUE;
|
||||
",
|
||||
)
|
||||
.deref();
|
||||
let builder =
|
||||
ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
|
||||
.with_db_initialization_query("PRAGMA journal_mode=WAL")
|
||||
.with_connection_initialize_query(indoc! {"
|
||||
PRAGMA synchronous=NORMAL;
|
||||
PRAGMA busy_timeout=1;
|
||||
PRAGMA foreign_keys=TRUE;
|
||||
PRAGMA case_sensitive_like=TRUE;
|
||||
"});
|
||||
let _ = smol::block_on(builder.build()).deref();
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
let _ = handle.join();
|
||||
}
|
||||
|
||||
// fs::remove_file("annoying-test.db").unwrap();
|
||||
// fs::remove_file("annoying-test.db-shm").unwrap();
|
||||
// fs::remove_file("annoying-test.db-wal").unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -241,8 +297,10 @@ mod test {
|
|||
}
|
||||
}
|
||||
|
||||
let _ = ThreadSafeConnection::<TestWorkspace>::new("wild_zed_lost_failure", false)
|
||||
.with_initialize_query("PRAGMA FOREIGN_KEYS=true")
|
||||
.deref();
|
||||
let builder =
|
||||
ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
|
||||
.with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");
|
||||
|
||||
smol::block_on(builder.build());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,10 @@ use std::sync::mpsc::Sender;
|
|||
use parking_lot::Mutex;
|
||||
use thread_local::ThreadLocal;
|
||||
|
||||
/// Unbounded standard library sender which is stored per thread to get around
|
||||
/// the lack of sync on the standard library version while still being unbounded
|
||||
/// Note: this locks on the cloneable sender, but its done once per thread, so it
|
||||
/// shouldn't result in too much contention
|
||||
pub struct UnboundedSyncSender<T: Send> {
|
||||
clonable_sender: Mutex<Sender<T>>,
|
||||
local_senders: ThreadLocal<Sender<T>>,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue