Merge branch 'main' into reconnections-2
This commit is contained in:
commit
9a62150dce
86 changed files with 7454 additions and 2211 deletions
2
crates/sqlez/.gitignore
vendored
Normal file
2
crates/sqlez/.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
debug/
|
||||
target/
|
150
crates/sqlez/Cargo.lock
generated
Normal file
150
crates/sqlez/Cargo.lock
generated
Normal file
|
@ -0,0 +1,150 @@
|
|||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "addr2line"
|
||||
version = "0.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b"
|
||||
dependencies = [
|
||||
"gimli",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adler"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.66"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "216261ddc8289130e551ddcd5ce8a064710c0d064a4d2895c67151c92b5443f6"
|
||||
dependencies = [
|
||||
"backtrace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.66"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cab84319d616cfb654d03394f38ab7e6f0919e181b1b57e1fd15e7fb4077d9a7"
|
||||
dependencies = [
|
||||
"addr2line",
|
||||
"cc",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"miniz_oxide",
|
||||
"object",
|
||||
"rustc-demangle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.73"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "gimli"
|
||||
version = "0.26.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "22030e2c5a68ec659fde1e949a745124b48e6fa8b045b7ed5bd1fe4ccc5c4e5d"
|
||||
|
||||
[[package]]
|
||||
name = "indoc"
|
||||
version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "adab1eaa3408fb7f0c777a73e7465fd5656136fc93b670eb6df3c88c2c1344e3"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.137"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89"
|
||||
|
||||
[[package]]
|
||||
name = "libsqlite3-sys"
|
||||
version = "0.25.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "29f835d03d717946d28b1d1ed632eb6f0e24a299388ee623d0c23118d3e8a7fa"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.5.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96590ba8f175222643a85693f33d26e9c8a015f599c216509b1a6894af675d34"
|
||||
dependencies = [
|
||||
"adler",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "object"
|
||||
version = "0.29.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "21158b2c33aa6d4561f1c0a6ea283ca92bc54802a93b263e910746d679a7eb53"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1"
|
||||
|
||||
[[package]]
|
||||
name = "pkg-config"
|
||||
version = "0.3.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342"
|
||||
|
||||
[[package]]
|
||||
name = "sqlez"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"indoc",
|
||||
"libsqlite3-sys",
|
||||
"thread_local",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thread_local"
|
||||
version = "1.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "vcpkg"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
16
crates/sqlez/Cargo.toml
Normal file
16
crates/sqlez/Cargo.toml
Normal file
|
@ -0,0 +1,16 @@
|
|||
[package]
|
||||
name = "sqlez"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = { version = "1.0.38", features = ["backtrace"] }
|
||||
indoc = "1.0.7"
|
||||
libsqlite3-sys = { version = "0.24", features = ["bundled"] }
|
||||
smol = "1.2"
|
||||
thread_local = "1.1.4"
|
||||
lazy_static = "1.4"
|
||||
parking_lot = "0.11.1"
|
||||
futures = "0.3"
|
352
crates/sqlez/src/bindable.rs
Normal file
352
crates/sqlez/src/bindable.rs
Normal file
|
@ -0,0 +1,352 @@
|
|||
use std::{
|
||||
ffi::OsStr,
|
||||
os::unix::prelude::OsStrExt,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
|
||||
use crate::statement::{SqlType, Statement};
|
||||
|
||||
pub trait Bind {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32>;
|
||||
}
|
||||
|
||||
pub trait Column: Sized {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)>;
|
||||
}
|
||||
|
||||
impl Bind for bool {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
statement
|
||||
.bind(self.then_some(1).unwrap_or(0), start_index)
|
||||
.with_context(|| format!("Failed to bind bool at index {start_index}"))
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for bool {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
i32::column(statement, start_index)
|
||||
.map(|(i, next_index)| (i != 0, next_index))
|
||||
.with_context(|| format!("Failed to read bool at index {start_index}"))
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for &[u8] {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
statement
|
||||
.bind_blob(start_index, self)
|
||||
.with_context(|| format!("Failed to bind &[u8] at index {start_index}"))?;
|
||||
Ok(start_index + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const C: usize> Bind for &[u8; C] {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
statement
|
||||
.bind_blob(start_index, self.as_slice())
|
||||
.with_context(|| format!("Failed to bind &[u8; C] at index {start_index}"))?;
|
||||
Ok(start_index + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for Vec<u8> {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
statement
|
||||
.bind_blob(start_index, self)
|
||||
.with_context(|| format!("Failed to bind Vec<u8> at index {start_index}"))?;
|
||||
Ok(start_index + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for Vec<u8> {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let result = statement
|
||||
.column_blob(start_index)
|
||||
.with_context(|| format!("Failed to read Vec<u8> at index {start_index}"))?;
|
||||
|
||||
Ok((Vec::from(result), start_index + 1))
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for f64 {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
statement
|
||||
.bind_double(start_index, *self)
|
||||
.with_context(|| format!("Failed to bind f64 at index {start_index}"))?;
|
||||
Ok(start_index + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for f64 {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let result = statement
|
||||
.column_double(start_index)
|
||||
.with_context(|| format!("Failed to parse f64 at index {start_index}"))?;
|
||||
|
||||
Ok((result, start_index + 1))
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for i32 {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
statement
|
||||
.bind_int(start_index, *self)
|
||||
.with_context(|| format!("Failed to bind i32 at index {start_index}"))?;
|
||||
|
||||
Ok(start_index + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for i32 {
|
||||
fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let result = statement.column_int(start_index)?;
|
||||
Ok((result, start_index + 1))
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for i64 {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
statement
|
||||
.bind_int64(start_index, *self)
|
||||
.with_context(|| format!("Failed to bind i64 at index {start_index}"))?;
|
||||
Ok(start_index + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for i64 {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let result = statement.column_int64(start_index)?;
|
||||
Ok((result, start_index + 1))
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for usize {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
(*self as i64)
|
||||
.bind(statement, start_index)
|
||||
.with_context(|| format!("Failed to bind usize at index {start_index}"))
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for usize {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let result = statement.column_int64(start_index)?;
|
||||
Ok((result as usize, start_index + 1))
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for &str {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
statement.bind_text(start_index, self)?;
|
||||
Ok(start_index + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for Arc<str> {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
statement.bind_text(start_index, self.as_ref())?;
|
||||
Ok(start_index + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for String {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
statement.bind_text(start_index, self)?;
|
||||
Ok(start_index + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for Arc<str> {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let result = statement.column_text(start_index)?;
|
||||
Ok((Arc::from(result), start_index + 1))
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for String {
|
||||
fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let result = statement.column_text(start_index)?;
|
||||
Ok((result.to_owned(), start_index + 1))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Bind> Bind for Option<T> {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
if let Some(this) = self {
|
||||
this.bind(statement, start_index)
|
||||
} else {
|
||||
statement.bind_null(start_index)?;
|
||||
Ok(start_index + 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Column> Column for Option<T> {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
if let SqlType::Null = statement.column_type(start_index)? {
|
||||
Ok((None, start_index + 1))
|
||||
} else {
|
||||
T::column(statement, start_index).map(|(result, next_index)| (Some(result), next_index))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Bind, const COUNT: usize> Bind for [T; COUNT] {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
let mut current_index = start_index;
|
||||
for binding in self {
|
||||
current_index = binding.bind(statement, current_index)?
|
||||
}
|
||||
|
||||
Ok(current_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Column + Default + Copy, const COUNT: usize> Column for [T; COUNT] {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let mut array = [Default::default(); COUNT];
|
||||
let mut current_index = start_index;
|
||||
for i in 0..COUNT {
|
||||
(array[i], current_index) = T::column(statement, current_index)?;
|
||||
}
|
||||
Ok((array, current_index))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Bind> Bind for Vec<T> {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
let mut current_index = start_index;
|
||||
for binding in self.iter() {
|
||||
current_index = binding.bind(statement, current_index)?
|
||||
}
|
||||
|
||||
Ok(current_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Bind> Bind for &[T] {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
let mut current_index = start_index;
|
||||
for binding in *self {
|
||||
current_index = binding.bind(statement, current_index)?
|
||||
}
|
||||
|
||||
Ok(current_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for &Path {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
self.as_os_str().as_bytes().bind(statement, start_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for Arc<Path> {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
self.as_ref().bind(statement, start_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for PathBuf {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
(self.as_ref() as &Path).bind(statement, start_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for PathBuf {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let blob = statement.column_blob(start_index)?;
|
||||
|
||||
Ok((
|
||||
PathBuf::from(OsStr::from_bytes(blob).to_owned()),
|
||||
start_index + 1,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Unit impls do nothing. This simplifies query macros
|
||||
impl Bind for () {
|
||||
fn bind(&self, _statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
Ok(start_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for () {
|
||||
fn column(_statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
Ok(((), start_index))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1: Bind, T2: Bind> Bind for (T1, T2) {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
let next_index = self.0.bind(statement, start_index)?;
|
||||
self.1.bind(statement, next_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1: Column, T2: Column> Column for (T1, T2) {
|
||||
fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let (first, next_index) = T1::column(statement, start_index)?;
|
||||
let (second, next_index) = T2::column(statement, next_index)?;
|
||||
Ok(((first, second), next_index))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1: Bind, T2: Bind, T3: Bind> Bind for (T1, T2, T3) {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
let next_index = self.0.bind(statement, start_index)?;
|
||||
let next_index = self.1.bind(statement, next_index)?;
|
||||
self.2.bind(statement, next_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1: Column, T2: Column, T3: Column> Column for (T1, T2, T3) {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let (first, next_index) = T1::column(statement, start_index)?;
|
||||
let (second, next_index) = T2::column(statement, next_index)?;
|
||||
let (third, next_index) = T3::column(statement, next_index)?;
|
||||
Ok(((first, second, third), next_index))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1: Bind, T2: Bind, T3: Bind, T4: Bind> Bind for (T1, T2, T3, T4) {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
let next_index = self.0.bind(statement, start_index)?;
|
||||
let next_index = self.1.bind(statement, next_index)?;
|
||||
let next_index = self.2.bind(statement, next_index)?;
|
||||
self.3.bind(statement, next_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1: Column, T2: Column, T3: Column, T4: Column> Column for (T1, T2, T3, T4) {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let (first, next_index) = T1::column(statement, start_index)?;
|
||||
let (second, next_index) = T2::column(statement, next_index)?;
|
||||
let (third, next_index) = T3::column(statement, next_index)?;
|
||||
let (fourth, next_index) = T4::column(statement, next_index)?;
|
||||
Ok(((first, second, third, fourth), next_index))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1: Bind, T2: Bind, T3: Bind, T4: Bind, T5: Bind> Bind for (T1, T2, T3, T4, T5) {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
let next_index = self.0.bind(statement, start_index)?;
|
||||
let next_index = self.1.bind(statement, next_index)?;
|
||||
let next_index = self.2.bind(statement, next_index)?;
|
||||
let next_index = self.3.bind(statement, next_index)?;
|
||||
self.4.bind(statement, next_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1: Column, T2: Column, T3: Column, T4: Column, T5: Column> Column for (T1, T2, T3, T4, T5) {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let (first, next_index) = T1::column(statement, start_index)?;
|
||||
let (second, next_index) = T2::column(statement, next_index)?;
|
||||
let (third, next_index) = T3::column(statement, next_index)?;
|
||||
let (fourth, next_index) = T4::column(statement, next_index)?;
|
||||
let (fifth, next_index) = T5::column(statement, next_index)?;
|
||||
Ok(((first, second, third, fourth, fifth), next_index))
|
||||
}
|
||||
}
|
334
crates/sqlez/src/connection.rs
Normal file
334
crates/sqlez/src/connection.rs
Normal file
|
@ -0,0 +1,334 @@
|
|||
use std::{
|
||||
cell::RefCell,
|
||||
ffi::{CStr, CString},
|
||||
marker::PhantomData,
|
||||
path::Path,
|
||||
ptr,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use libsqlite3_sys::*;
|
||||
|
||||
pub struct Connection {
|
||||
pub(crate) sqlite3: *mut sqlite3,
|
||||
persistent: bool,
|
||||
pub(crate) write: RefCell<bool>,
|
||||
_sqlite: PhantomData<sqlite3>,
|
||||
}
|
||||
unsafe impl Send for Connection {}
|
||||
|
||||
impl Connection {
|
||||
pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
|
||||
let mut connection = Self {
|
||||
sqlite3: 0 as *mut _,
|
||||
persistent,
|
||||
write: RefCell::new(true),
|
||||
_sqlite: PhantomData,
|
||||
};
|
||||
|
||||
let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
|
||||
unsafe {
|
||||
sqlite3_open_v2(
|
||||
CString::new(uri)?.as_ptr(),
|
||||
&mut connection.sqlite3,
|
||||
flags,
|
||||
0 as *const _,
|
||||
);
|
||||
|
||||
// Turn on extended error codes
|
||||
sqlite3_extended_result_codes(connection.sqlite3, 1);
|
||||
|
||||
connection.last_error()?;
|
||||
}
|
||||
|
||||
Ok(connection)
|
||||
}
|
||||
|
||||
/// Attempts to open the database at uri. If it fails, a shared memory db will be opened
|
||||
/// instead.
|
||||
pub fn open_file(uri: &str) -> Self {
|
||||
Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
|
||||
}
|
||||
|
||||
pub fn open_memory(uri: Option<&str>) -> Self {
|
||||
let in_memory_path = if let Some(uri) = uri {
|
||||
format!("file:{}?mode=memory&cache=shared", uri)
|
||||
} else {
|
||||
":memory:".to_string()
|
||||
};
|
||||
|
||||
Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
|
||||
}
|
||||
|
||||
pub fn persistent(&self) -> bool {
|
||||
self.persistent
|
||||
}
|
||||
|
||||
pub fn can_write(&self) -> bool {
|
||||
*self.write.borrow()
|
||||
}
|
||||
|
||||
pub fn backup_main(&self, destination: &Connection) -> Result<()> {
|
||||
unsafe {
|
||||
let backup = sqlite3_backup_init(
|
||||
destination.sqlite3,
|
||||
CString::new("main")?.as_ptr(),
|
||||
self.sqlite3,
|
||||
CString::new("main")?.as_ptr(),
|
||||
);
|
||||
sqlite3_backup_step(backup, -1);
|
||||
sqlite3_backup_finish(backup);
|
||||
destination.last_error()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
|
||||
let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
|
||||
self.backup_main(&destination)
|
||||
}
|
||||
|
||||
pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
|
||||
let sql = CString::new(sql).unwrap();
|
||||
let mut remaining_sql = sql.as_c_str();
|
||||
let sql_start = remaining_sql.as_ptr();
|
||||
|
||||
unsafe {
|
||||
while {
|
||||
let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
|
||||
remaining_sql_str != ";" && !remaining_sql_str.is_empty()
|
||||
} {
|
||||
let mut raw_statement = 0 as *mut sqlite3_stmt;
|
||||
let mut remaining_sql_ptr = ptr::null();
|
||||
sqlite3_prepare_v2(
|
||||
self.sqlite3,
|
||||
remaining_sql.as_ptr(),
|
||||
-1,
|
||||
&mut raw_statement,
|
||||
&mut remaining_sql_ptr,
|
||||
);
|
||||
|
||||
let res = sqlite3_errcode(self.sqlite3);
|
||||
let offset = sqlite3_error_offset(self.sqlite3);
|
||||
let message = sqlite3_errmsg(self.sqlite3);
|
||||
|
||||
sqlite3_finalize(raw_statement);
|
||||
|
||||
if res == 1 && offset >= 0 {
|
||||
let err_msg =
|
||||
String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
|
||||
.into_owned();
|
||||
let sub_statement_correction =
|
||||
remaining_sql.as_ptr() as usize - sql_start as usize;
|
||||
|
||||
return Some((err_msg, offset as usize + sub_statement_correction));
|
||||
}
|
||||
remaining_sql = CStr::from_ptr(remaining_sql_ptr);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn last_error(&self) -> Result<()> {
|
||||
unsafe {
|
||||
let code = sqlite3_errcode(self.sqlite3);
|
||||
const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
|
||||
if NON_ERROR_CODES.contains(&code) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let message = sqlite3_errmsg(self.sqlite3);
|
||||
let message = if message.is_null() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
|
||||
.into_owned(),
|
||||
)
|
||||
};
|
||||
|
||||
Err(anyhow!(
|
||||
"Sqlite call failed with code {} and message: {:?}",
|
||||
code as isize,
|
||||
message
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_write<T>(&self, callback: impl FnOnce(&Connection) -> T) -> T {
|
||||
*self.write.borrow_mut() = true;
|
||||
let result = callback(self);
|
||||
*self.write.borrow_mut() = false;
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Connection {
|
||||
fn drop(&mut self) {
|
||||
unsafe { sqlite3_close(self.sqlite3) };
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use anyhow::Result;
|
||||
use indoc::indoc;
|
||||
|
||||
use crate::connection::Connection;
|
||||
|
||||
#[test]
|
||||
fn string_round_trips() -> Result<()> {
|
||||
let connection = Connection::open_memory(Some("string_round_trips"));
|
||||
connection
|
||||
.exec(indoc! {"
|
||||
CREATE TABLE text (
|
||||
text TEXT
|
||||
);"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
let text = "Some test text";
|
||||
|
||||
connection
|
||||
.exec_bound("INSERT INTO text (text) VALUES (?);")
|
||||
.unwrap()(text)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
|
||||
Some(text.to_string())
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tuple_round_trips() {
|
||||
let connection = Connection::open_memory(Some("tuple_round_trips"));
|
||||
connection
|
||||
.exec(indoc! {"
|
||||
CREATE TABLE test (
|
||||
text TEXT,
|
||||
integer INTEGER,
|
||||
blob BLOB
|
||||
);"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
|
||||
let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);
|
||||
|
||||
let mut insert = connection
|
||||
.exec_bound::<(String, usize, Vec<u8>)>(
|
||||
"INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
insert(tuple1.clone()).unwrap();
|
||||
insert(tuple2.clone()).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
connection
|
||||
.select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
|
||||
.unwrap()()
|
||||
.unwrap(),
|
||||
vec![tuple1, tuple2]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bool_round_trips() {
|
||||
let connection = Connection::open_memory(Some("bool_round_trips"));
|
||||
connection
|
||||
.exec(indoc! {"
|
||||
CREATE TABLE bools (
|
||||
t INTEGER,
|
||||
f INTEGER
|
||||
);"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
connection
|
||||
.exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
|
||||
.unwrap()((true, false))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
connection
|
||||
.select_row::<(bool, bool)>("SELECT * FROM bools;")
|
||||
.unwrap()()
|
||||
.unwrap(),
|
||||
Some((true, false))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn backup_works() {
|
||||
let connection1 = Connection::open_memory(Some("backup_works"));
|
||||
connection1
|
||||
.exec(indoc! {"
|
||||
CREATE TABLE blobs (
|
||||
data BLOB
|
||||
);"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
|
||||
connection1
|
||||
.exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
|
||||
.unwrap()(blob.clone())
|
||||
.unwrap();
|
||||
|
||||
// Backup connection1 to connection2
|
||||
let connection2 = Connection::open_memory(Some("backup_works_other"));
|
||||
connection1.backup_main(&connection2).unwrap();
|
||||
|
||||
// Delete the added blob and verify its deleted on the other side
|
||||
let read_blobs = connection1
|
||||
.select::<Vec<u8>>("SELECT * FROM blobs;")
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
assert_eq!(read_blobs, vec![blob]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multi_step_statement_works() {
|
||||
let connection = Connection::open_memory(Some("multi_step_statement_works"));
|
||||
|
||||
connection
|
||||
.exec(indoc! {"
|
||||
CREATE TABLE test (
|
||||
col INTEGER
|
||||
)"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
connection
|
||||
.exec(indoc! {"
|
||||
INSERT INTO test(col) VALUES (2)"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
connection
|
||||
.select_row::<usize>("SELECT * FROM test")
|
||||
.unwrap()()
|
||||
.unwrap(),
|
||||
Some(2)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sql_has_syntax_errors() {
|
||||
let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
|
||||
let first_stmt =
|
||||
"CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
|
||||
let second_stmt = "SELECT FROM";
|
||||
|
||||
let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;
|
||||
|
||||
let res = connection
|
||||
.sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
|
||||
.map(|(_, offset)| offset);
|
||||
|
||||
assert_eq!(res, Some(first_stmt.len() + second_offset + 1));
|
||||
}
|
||||
}
|
56
crates/sqlez/src/domain.rs
Normal file
56
crates/sqlez/src/domain.rs
Normal file
|
@ -0,0 +1,56 @@
|
|||
use crate::connection::Connection;
|
||||
|
||||
pub trait Domain: 'static {
|
||||
fn name() -> &'static str;
|
||||
fn migrations() -> &'static [&'static str];
|
||||
}
|
||||
|
||||
pub trait Migrator: 'static {
|
||||
fn migrate(connection: &Connection) -> anyhow::Result<()>;
|
||||
}
|
||||
|
||||
impl Migrator for () {
|
||||
fn migrate(_connection: &Connection) -> anyhow::Result<()> {
|
||||
Ok(()) // Do nothing
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Domain> Migrator for D {
|
||||
fn migrate(connection: &Connection) -> anyhow::Result<()> {
|
||||
connection.migrate(Self::name(), Self::migrations())
|
||||
}
|
||||
}
|
||||
|
||||
impl<D1: Domain, D2: Domain> Migrator for (D1, D2) {
|
||||
fn migrate(connection: &Connection) -> anyhow::Result<()> {
|
||||
D1::migrate(connection)?;
|
||||
D2::migrate(connection)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D1: Domain, D2: Domain, D3: Domain> Migrator for (D1, D2, D3) {
|
||||
fn migrate(connection: &Connection) -> anyhow::Result<()> {
|
||||
D1::migrate(connection)?;
|
||||
D2::migrate(connection)?;
|
||||
D3::migrate(connection)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D1: Domain, D2: Domain, D3: Domain, D4: Domain> Migrator for (D1, D2, D3, D4) {
|
||||
fn migrate(connection: &Connection) -> anyhow::Result<()> {
|
||||
D1::migrate(connection)?;
|
||||
D2::migrate(connection)?;
|
||||
D3::migrate(connection)?;
|
||||
D4::migrate(connection)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D1: Domain, D2: Domain, D3: Domain, D4: Domain, D5: Domain> Migrator for (D1, D2, D3, D4, D5) {
|
||||
fn migrate(connection: &Connection) -> anyhow::Result<()> {
|
||||
D1::migrate(connection)?;
|
||||
D2::migrate(connection)?;
|
||||
D3::migrate(connection)?;
|
||||
D4::migrate(connection)?;
|
||||
D5::migrate(connection)
|
||||
}
|
||||
}
|
11
crates/sqlez/src/lib.rs
Normal file
11
crates/sqlez/src/lib.rs
Normal file
|
@ -0,0 +1,11 @@
|
|||
pub mod bindable;
|
||||
pub mod connection;
|
||||
pub mod domain;
|
||||
pub mod migrations;
|
||||
pub mod savepoint;
|
||||
pub mod statement;
|
||||
pub mod thread_safe_connection;
|
||||
pub mod typed_statements;
|
||||
mod util;
|
||||
|
||||
pub use anyhow;
|
260
crates/sqlez/src/migrations.rs
Normal file
260
crates/sqlez/src/migrations.rs
Normal file
|
@ -0,0 +1,260 @@
|
|||
// Migrations are constructed by domain, and stored in a table in the connection db with domain name,
|
||||
// effected tables, actual query text, and order.
|
||||
// If a migration is run and any of the query texts don't match, the app panics on startup (maybe fallback
|
||||
// to creating a new db?)
|
||||
// Otherwise any missing migrations are run on the connection
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use indoc::{formatdoc, indoc};
|
||||
|
||||
use crate::connection::Connection;
|
||||
|
||||
impl Connection {
|
||||
pub fn migrate(&self, domain: &'static str, migrations: &[&'static str]) -> Result<()> {
|
||||
self.with_savepoint("migrating", || {
|
||||
// Setup the migrations table unconditionally
|
||||
self.exec(indoc! {"
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
domain TEXT,
|
||||
step INTEGER,
|
||||
migration TEXT
|
||||
)"})?()?;
|
||||
|
||||
let completed_migrations =
|
||||
self.select_bound::<&str, (String, usize, String)>(indoc! {"
|
||||
SELECT domain, step, migration FROM migrations
|
||||
WHERE domain = ?
|
||||
ORDER BY step
|
||||
"})?(domain)?;
|
||||
|
||||
let mut store_completed_migration = self
|
||||
.exec_bound("INSERT INTO migrations (domain, step, migration) VALUES (?, ?, ?)")?;
|
||||
|
||||
for (index, migration) in migrations.iter().enumerate() {
|
||||
if let Some((_, _, completed_migration)) = completed_migrations.get(index) {
|
||||
if completed_migration != migration {
|
||||
return Err(anyhow!(formatdoc! {"
|
||||
Migration changed for {} at step {}
|
||||
|
||||
Stored migration:
|
||||
{}
|
||||
|
||||
Proposed migration:
|
||||
{}", domain, index, completed_migration, migration}));
|
||||
} else {
|
||||
// Migration already run. Continue
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
self.exec(migration)?()?;
|
||||
store_completed_migration((domain, index, *migration))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use indoc::indoc;
|
||||
|
||||
use crate::connection::Connection;
|
||||
|
||||
#[test]
|
||||
fn test_migrations_are_added_to_table() {
|
||||
let connection = Connection::open_memory(Some("migrations_are_added_to_table"));
|
||||
|
||||
// Create first migration with a single step and run it
|
||||
connection
|
||||
.migrate(
|
||||
"test",
|
||||
&[indoc! {"
|
||||
CREATE TABLE test1 (
|
||||
a TEXT,
|
||||
b TEXT
|
||||
)"}],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Verify it got added to the migrations table
|
||||
assert_eq!(
|
||||
&connection
|
||||
.select::<String>("SELECT (migration) FROM migrations")
|
||||
.unwrap()()
|
||||
.unwrap()[..],
|
||||
&[indoc! {"
|
||||
CREATE TABLE test1 (
|
||||
a TEXT,
|
||||
b TEXT
|
||||
)"}],
|
||||
);
|
||||
|
||||
// Add another step to the migration and run it again
|
||||
connection
|
||||
.migrate(
|
||||
"test",
|
||||
&[
|
||||
indoc! {"
|
||||
CREATE TABLE test1 (
|
||||
a TEXT,
|
||||
b TEXT
|
||||
)"},
|
||||
indoc! {"
|
||||
CREATE TABLE test2 (
|
||||
c TEXT,
|
||||
d TEXT
|
||||
)"},
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Verify it is also added to the migrations table
|
||||
assert_eq!(
|
||||
&connection
|
||||
.select::<String>("SELECT (migration) FROM migrations")
|
||||
.unwrap()()
|
||||
.unwrap()[..],
|
||||
&[
|
||||
indoc! {"
|
||||
CREATE TABLE test1 (
|
||||
a TEXT,
|
||||
b TEXT
|
||||
)"},
|
||||
indoc! {"
|
||||
CREATE TABLE test2 (
|
||||
c TEXT,
|
||||
d TEXT
|
||||
)"},
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_migration_setup_works() {
|
||||
let connection = Connection::open_memory(Some("migration_setup_works"));
|
||||
|
||||
connection
|
||||
.exec(indoc! {"
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
domain TEXT,
|
||||
step INTEGER,
|
||||
migration TEXT
|
||||
);"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
let mut store_completed_migration = connection
|
||||
.exec_bound::<(&str, usize, String)>(indoc! {"
|
||||
INSERT INTO migrations (domain, step, migration)
|
||||
VALUES (?, ?, ?)"})
|
||||
.unwrap();
|
||||
|
||||
let domain = "test_domain";
|
||||
for i in 0..5 {
|
||||
// Create a table forcing a schema change
|
||||
connection
|
||||
.exec(&format!("CREATE TABLE table{} ( test TEXT );", i))
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
store_completed_migration((domain, i, i.to_string())).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn migrations_dont_rerun() {
|
||||
let connection = Connection::open_memory(Some("migrations_dont_rerun"));
|
||||
|
||||
// Create migration which clears a tabl
|
||||
|
||||
// Manually create the table for that migration with a row
|
||||
connection
|
||||
.exec(indoc! {"
|
||||
CREATE TABLE test_table (
|
||||
test_column INTEGER
|
||||
);"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
connection
|
||||
.exec(indoc! {"
|
||||
INSERT INTO test_table (test_column) VALUES (1);"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
connection
|
||||
.select_row::<usize>("SELECT * FROM test_table")
|
||||
.unwrap()()
|
||||
.unwrap(),
|
||||
Some(1)
|
||||
);
|
||||
|
||||
// Run the migration verifying that the row got dropped
|
||||
connection
|
||||
.migrate("test", &["DELETE FROM test_table"])
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
connection
|
||||
.select_row::<usize>("SELECT * FROM test_table")
|
||||
.unwrap()()
|
||||
.unwrap(),
|
||||
None
|
||||
);
|
||||
|
||||
// Recreate the dropped row
|
||||
connection
|
||||
.exec("INSERT INTO test_table (test_column) VALUES (2)")
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
// Run the same migration again and verify that the table was left unchanged
|
||||
connection
|
||||
.migrate("test", &["DELETE FROM test_table"])
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
connection
|
||||
.select_row::<usize>("SELECT * FROM test_table")
|
||||
.unwrap()()
|
||||
.unwrap(),
|
||||
Some(2)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn changed_migration_fails() {
|
||||
let connection = Connection::open_memory(Some("changed_migration_fails"));
|
||||
|
||||
// Create a migration with two steps and run it
|
||||
connection
|
||||
.migrate(
|
||||
"test migration",
|
||||
&[
|
||||
indoc! {"
|
||||
CREATE TABLE test (
|
||||
col INTEGER
|
||||
)"},
|
||||
indoc! {"
|
||||
INSERT INTO test (col) VALUES (1)"},
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Create another migration with the same domain but different steps
|
||||
let second_migration_result = connection.migrate(
|
||||
"test migration",
|
||||
&[
|
||||
indoc! {"
|
||||
CREATE TABLE test (
|
||||
color INTEGER
|
||||
)"},
|
||||
indoc! {"
|
||||
INSERT INTO test (color) VALUES (1)"},
|
||||
],
|
||||
);
|
||||
|
||||
// Verify new migration returns error when run
|
||||
assert!(second_migration_result.is_err())
|
||||
}
|
||||
}
|
148
crates/sqlez/src/savepoint.rs
Normal file
148
crates/sqlez/src/savepoint.rs
Normal file
|
@ -0,0 +1,148 @@
|
|||
use anyhow::Result;
|
||||
use indoc::formatdoc;
|
||||
|
||||
use crate::connection::Connection;
|
||||
|
||||
impl Connection {
|
||||
// Run a set of commands within the context of a `SAVEPOINT name`. If the callback
|
||||
// returns Err(_), the savepoint will be rolled back. Otherwise, the save
|
||||
// point is released.
|
||||
pub fn with_savepoint<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<R>
|
||||
where
|
||||
F: FnOnce() -> Result<R>,
|
||||
{
|
||||
let name = name.as_ref();
|
||||
self.exec(&format!("SAVEPOINT {name}"))?()?;
|
||||
let result = f();
|
||||
match result {
|
||||
Ok(_) => {
|
||||
self.exec(&format!("RELEASE {name}"))?()?;
|
||||
}
|
||||
Err(_) => {
|
||||
self.exec(&formatdoc! {"
|
||||
ROLLBACK TO {name};
|
||||
RELEASE {name}"})?()?;
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
// Run a set of commands within the context of a `SAVEPOINT name`. If the callback
|
||||
// returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
|
||||
// point is released.
|
||||
pub fn with_savepoint_rollback<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<Option<R>>
|
||||
where
|
||||
F: FnOnce() -> Result<Option<R>>,
|
||||
{
|
||||
let name = name.as_ref();
|
||||
self.exec(&format!("SAVEPOINT {name}"))?()?;
|
||||
let result = f();
|
||||
match result {
|
||||
Ok(Some(_)) => {
|
||||
self.exec(&format!("RELEASE {name}"))?()?;
|
||||
}
|
||||
Ok(None) | Err(_) => {
|
||||
self.exec(&formatdoc! {"
|
||||
ROLLBACK TO {name};
|
||||
RELEASE {name}"})?()?;
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::connection::Connection;
|
||||
use anyhow::Result;
|
||||
use indoc::indoc;
|
||||
|
||||
#[test]
|
||||
fn test_nested_savepoints() -> Result<()> {
|
||||
let connection = Connection::open_memory(Some("nested_savepoints"));
|
||||
|
||||
connection
|
||||
.exec(indoc! {"
|
||||
CREATE TABLE text (
|
||||
text TEXT,
|
||||
idx INTEGER
|
||||
);"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
let save1_text = "test save1";
|
||||
let save2_text = "test save2";
|
||||
|
||||
connection.with_savepoint("first", || {
|
||||
connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((save1_text, 1))?;
|
||||
|
||||
assert!(connection
|
||||
.with_savepoint("second", || -> Result<Option<()>, anyhow::Error> {
|
||||
connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((
|
||||
save2_text, 2,
|
||||
))?;
|
||||
|
||||
assert_eq!(
|
||||
connection
|
||||
.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?(
|
||||
)?,
|
||||
vec![save1_text, save2_text],
|
||||
);
|
||||
|
||||
anyhow::bail!("Failed second save point :(")
|
||||
})
|
||||
.err()
|
||||
.is_some());
|
||||
|
||||
assert_eq!(
|
||||
connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
|
||||
vec![save1_text],
|
||||
);
|
||||
|
||||
connection.with_savepoint_rollback::<(), _>("second", || {
|
||||
connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((
|
||||
save2_text, 2,
|
||||
))?;
|
||||
|
||||
assert_eq!(
|
||||
connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
|
||||
vec![save1_text, save2_text],
|
||||
);
|
||||
|
||||
Ok(None)
|
||||
})?;
|
||||
|
||||
assert_eq!(
|
||||
connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
|
||||
vec![save1_text],
|
||||
);
|
||||
|
||||
connection.with_savepoint_rollback("second", || {
|
||||
connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((
|
||||
save2_text, 2,
|
||||
))?;
|
||||
|
||||
assert_eq!(
|
||||
connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
|
||||
vec![save1_text, save2_text],
|
||||
);
|
||||
|
||||
Ok(Some(()))
|
||||
})?;
|
||||
|
||||
assert_eq!(
|
||||
connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
|
||||
vec![save1_text, save2_text],
|
||||
);
|
||||
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
assert_eq!(
|
||||
connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
|
||||
vec![save1_text, save2_text],
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
491
crates/sqlez/src/statement.rs
Normal file
491
crates/sqlez/src/statement.rs
Normal file
|
@ -0,0 +1,491 @@
|
|||
use std::ffi::{c_int, CStr, CString};
|
||||
use std::marker::PhantomData;
|
||||
use std::{ptr, slice, str};
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use libsqlite3_sys::*;
|
||||
|
||||
use crate::bindable::{Bind, Column};
|
||||
use crate::connection::Connection;
|
||||
|
||||
pub struct Statement<'a> {
|
||||
raw_statements: Vec<*mut sqlite3_stmt>,
|
||||
current_statement: usize,
|
||||
connection: &'a Connection,
|
||||
phantom: PhantomData<sqlite3_stmt>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
pub enum StepResult {
|
||||
Row,
|
||||
Done,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
pub enum SqlType {
|
||||
Text,
|
||||
Integer,
|
||||
Blob,
|
||||
Float,
|
||||
Null,
|
||||
}
|
||||
|
||||
impl<'a> Statement<'a> {
|
||||
pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
|
||||
let mut statement = Self {
|
||||
raw_statements: Default::default(),
|
||||
current_statement: 0,
|
||||
connection,
|
||||
phantom: PhantomData,
|
||||
};
|
||||
unsafe {
|
||||
let sql = CString::new(query.as_ref()).context("Error creating cstr")?;
|
||||
let mut remaining_sql = sql.as_c_str();
|
||||
while {
|
||||
let remaining_sql_str = remaining_sql
|
||||
.to_str()
|
||||
.context("Parsing remaining sql")?
|
||||
.trim();
|
||||
remaining_sql_str != ";" && !remaining_sql_str.is_empty()
|
||||
} {
|
||||
let mut raw_statement = 0 as *mut sqlite3_stmt;
|
||||
let mut remaining_sql_ptr = ptr::null();
|
||||
sqlite3_prepare_v2(
|
||||
connection.sqlite3,
|
||||
remaining_sql.as_ptr(),
|
||||
-1,
|
||||
&mut raw_statement,
|
||||
&mut remaining_sql_ptr,
|
||||
);
|
||||
|
||||
remaining_sql = CStr::from_ptr(remaining_sql_ptr);
|
||||
statement.raw_statements.push(raw_statement);
|
||||
|
||||
connection.last_error().with_context(|| {
|
||||
format!("Prepare call failed for query:\n{}", query.as_ref())
|
||||
})?;
|
||||
|
||||
if !connection.can_write() && sqlite3_stmt_readonly(raw_statement) == 0 {
|
||||
let sql = CStr::from_ptr(sqlite3_sql(raw_statement));
|
||||
|
||||
bail!(
|
||||
"Write statement prepared with connection that is not write capable. SQL:\n{} ",
|
||||
sql.to_str()?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(statement)
|
||||
}
|
||||
|
||||
fn current_statement(&self) -> *mut sqlite3_stmt {
|
||||
*self.raw_statements.get(self.current_statement).unwrap()
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
unsafe {
|
||||
for raw_statement in self.raw_statements.iter() {
|
||||
sqlite3_reset(*raw_statement);
|
||||
}
|
||||
}
|
||||
self.current_statement = 0;
|
||||
}
|
||||
|
||||
pub fn parameter_count(&self) -> i32 {
|
||||
unsafe {
|
||||
self.raw_statements
|
||||
.iter()
|
||||
.map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement))
|
||||
.max()
|
||||
.unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn bind_index_with(&self, index: i32, bind: impl Fn(&*mut sqlite3_stmt) -> ()) -> Result<()> {
|
||||
let mut any_succeed = false;
|
||||
unsafe {
|
||||
for raw_statement in self.raw_statements.iter() {
|
||||
if index <= sqlite3_bind_parameter_count(*raw_statement) {
|
||||
bind(raw_statement);
|
||||
self.connection
|
||||
.last_error()
|
||||
.with_context(|| format!("Failed to bind value at index {index}"))?;
|
||||
any_succeed = true;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
if any_succeed {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("Failed to bind parameters"))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
|
||||
let index = index as c_int;
|
||||
let blob_pointer = blob.as_ptr() as *const _;
|
||||
let len = blob.len() as c_int;
|
||||
|
||||
self.bind_index_with(index, |raw_statement| unsafe {
|
||||
sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
|
||||
})
|
||||
}
|
||||
|
||||
pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
|
||||
let index = index as c_int;
|
||||
let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
|
||||
|
||||
self.connection
|
||||
.last_error()
|
||||
.with_context(|| format!("Failed to read blob at index {index}"))?;
|
||||
if pointer.is_null() {
|
||||
return Ok(&[]);
|
||||
}
|
||||
let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
|
||||
self.connection
|
||||
.last_error()
|
||||
.with_context(|| format!("Failed to read length of blob at index {index}"))?;
|
||||
|
||||
unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
|
||||
}
|
||||
|
||||
pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
|
||||
let index = index as c_int;
|
||||
|
||||
self.bind_index_with(index, |raw_statement| unsafe {
|
||||
sqlite3_bind_double(*raw_statement, index, double);
|
||||
})
|
||||
}
|
||||
|
||||
pub fn column_double(&self, index: i32) -> Result<f64> {
|
||||
let index = index as c_int;
|
||||
let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
|
||||
self.connection
|
||||
.last_error()
|
||||
.with_context(|| format!("Failed to read double at index {index}"))?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
|
||||
let index = index as c_int;
|
||||
self.bind_index_with(index, |raw_statement| unsafe {
|
||||
sqlite3_bind_int(*raw_statement, index, int);
|
||||
})
|
||||
}
|
||||
|
||||
pub fn column_int(&self, index: i32) -> Result<i32> {
|
||||
let index = index as c_int;
|
||||
let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
|
||||
self.connection
|
||||
.last_error()
|
||||
.with_context(|| format!("Failed to read int at index {index}"))?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
|
||||
let index = index as c_int;
|
||||
self.bind_index_with(index, |raw_statement| unsafe {
|
||||
sqlite3_bind_int64(*raw_statement, index, int);
|
||||
})
|
||||
}
|
||||
|
||||
pub fn column_int64(&self, index: i32) -> Result<i64> {
|
||||
let index = index as c_int;
|
||||
let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
|
||||
self.connection
|
||||
.last_error()
|
||||
.with_context(|| format!("Failed to read i64 at index {index}"))?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn bind_null(&self, index: i32) -> Result<()> {
|
||||
let index = index as c_int;
|
||||
self.bind_index_with(index, |raw_statement| unsafe {
|
||||
sqlite3_bind_null(*raw_statement, index);
|
||||
})
|
||||
}
|
||||
|
||||
pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
|
||||
let index = index as c_int;
|
||||
let text_pointer = text.as_ptr() as *const _;
|
||||
let len = text.len() as c_int;
|
||||
|
||||
self.bind_index_with(index, |raw_statement| unsafe {
|
||||
sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
|
||||
})
|
||||
}
|
||||
|
||||
pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
|
||||
let index = index as c_int;
|
||||
let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
|
||||
|
||||
self.connection
|
||||
.last_error()
|
||||
.with_context(|| format!("Failed to read text from column {index}"))?;
|
||||
if pointer.is_null() {
|
||||
return Ok("");
|
||||
}
|
||||
let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
|
||||
self.connection
|
||||
.last_error()
|
||||
.with_context(|| format!("Failed to read text length at {index}"))?;
|
||||
|
||||
let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
|
||||
Ok(str::from_utf8(slice)?)
|
||||
}
|
||||
|
||||
pub fn bind<T: Bind>(&self, value: T, index: i32) -> Result<i32> {
|
||||
debug_assert!(index > 0);
|
||||
value.bind(self, index)
|
||||
}
|
||||
|
||||
pub fn column<T: Column>(&mut self) -> Result<T> {
|
||||
let (result, _) = T::column(self, 0)?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
|
||||
let result = unsafe { sqlite3_column_type(self.current_statement(), index) };
|
||||
self.connection.last_error()?;
|
||||
match result {
|
||||
SQLITE_INTEGER => Ok(SqlType::Integer),
|
||||
SQLITE_FLOAT => Ok(SqlType::Float),
|
||||
SQLITE_TEXT => Ok(SqlType::Text),
|
||||
SQLITE_BLOB => Ok(SqlType::Blob),
|
||||
SQLITE_NULL => Ok(SqlType::Null),
|
||||
_ => Err(anyhow!("Column type returned was incorrect ")),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_bindings(&mut self, bindings: impl Bind) -> Result<&mut Self> {
|
||||
self.bind(bindings, 1)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
fn step(&mut self) -> Result<StepResult> {
|
||||
unsafe {
|
||||
match sqlite3_step(self.current_statement()) {
|
||||
SQLITE_ROW => Ok(StepResult::Row),
|
||||
SQLITE_DONE => {
|
||||
if self.current_statement >= self.raw_statements.len() - 1 {
|
||||
Ok(StepResult::Done)
|
||||
} else {
|
||||
self.current_statement += 1;
|
||||
self.step()
|
||||
}
|
||||
}
|
||||
SQLITE_MISUSE => Err(anyhow!("Statement step returned SQLITE_MISUSE")),
|
||||
_other_error => {
|
||||
self.connection.last_error()?;
|
||||
unreachable!("Step returned error code and last error failed to catch it");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn exec(&mut self) -> Result<()> {
|
||||
fn logic(this: &mut Statement) -> Result<()> {
|
||||
while this.step()? == StepResult::Row {}
|
||||
Ok(())
|
||||
}
|
||||
let result = logic(self);
|
||||
self.reset();
|
||||
result
|
||||
}
|
||||
|
||||
pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
|
||||
fn logic<R>(
|
||||
this: &mut Statement,
|
||||
mut callback: impl FnMut(&mut Statement) -> Result<R>,
|
||||
) -> Result<Vec<R>> {
|
||||
let mut mapped_rows = Vec::new();
|
||||
while this.step()? == StepResult::Row {
|
||||
mapped_rows.push(callback(this)?);
|
||||
}
|
||||
Ok(mapped_rows)
|
||||
}
|
||||
|
||||
let result = logic(self, callback);
|
||||
self.reset();
|
||||
result
|
||||
}
|
||||
|
||||
pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
|
||||
self.map(|s| s.column::<R>())
|
||||
}
|
||||
|
||||
pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
|
||||
fn logic<R>(
|
||||
this: &mut Statement,
|
||||
callback: impl FnOnce(&mut Statement) -> Result<R>,
|
||||
) -> Result<R> {
|
||||
if this.step()? != StepResult::Row {
|
||||
return Err(anyhow!("single called with query that returns no rows."));
|
||||
}
|
||||
let result = callback(this)?;
|
||||
|
||||
if this.step()? != StepResult::Done {
|
||||
return Err(anyhow!(
|
||||
"single called with a query that returns more than one row."
|
||||
));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
let result = logic(self, callback);
|
||||
self.reset();
|
||||
result
|
||||
}
|
||||
|
||||
pub fn row<R: Column>(&mut self) -> Result<R> {
|
||||
self.single(|this| this.column::<R>())
|
||||
}
|
||||
|
||||
pub fn maybe<R>(
|
||||
&mut self,
|
||||
callback: impl FnOnce(&mut Statement) -> Result<R>,
|
||||
) -> Result<Option<R>> {
|
||||
fn logic<R>(
|
||||
this: &mut Statement,
|
||||
callback: impl FnOnce(&mut Statement) -> Result<R>,
|
||||
) -> Result<Option<R>> {
|
||||
if this.step().context("Failed on step call")? != StepResult::Row {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let result = callback(this)
|
||||
.map(|r| Some(r))
|
||||
.context("Failed to parse row result")?;
|
||||
|
||||
if this.step().context("Second step call")? != StepResult::Done {
|
||||
return Err(anyhow!(
|
||||
"maybe called with a query that returns more than one row."
|
||||
));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
let result = logic(self, callback);
|
||||
self.reset();
|
||||
result
|
||||
}
|
||||
|
||||
pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
|
||||
self.maybe(|this| this.column::<R>())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for Statement<'a> {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
for raw_statement in self.raw_statements.iter() {
|
||||
sqlite3_finalize(*raw_statement);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use indoc::indoc;
|
||||
|
||||
use crate::{
|
||||
connection::Connection,
|
||||
statement::{Statement, StepResult},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn binding_multiple_statements_with_parameter_gaps() {
|
||||
let connection =
|
||||
Connection::open_memory(Some("binding_multiple_statements_with_parameter_gaps"));
|
||||
|
||||
connection
|
||||
.exec(indoc! {"
|
||||
CREATE TABLE test (
|
||||
col INTEGER
|
||||
)"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
let statement = Statement::prepare(
|
||||
&connection,
|
||||
indoc! {"
|
||||
INSERT INTO test(col) VALUES (?3);
|
||||
SELECT * FROM test WHERE col = ?1"},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
statement
|
||||
.bind_int(1, 1)
|
||||
.expect("Could not bind parameter to first index");
|
||||
statement
|
||||
.bind_int(2, 2)
|
||||
.expect("Could not bind parameter to second index");
|
||||
statement
|
||||
.bind_int(3, 3)
|
||||
.expect("Could not bind parameter to third index");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blob_round_trips() {
|
||||
let connection1 = Connection::open_memory(Some("blob_round_trips"));
|
||||
connection1
|
||||
.exec(indoc! {"
|
||||
CREATE TABLE blobs (
|
||||
data BLOB
|
||||
)"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
let blob = &[0, 1, 2, 4, 8, 16, 32, 64];
|
||||
|
||||
let mut write =
|
||||
Statement::prepare(&connection1, "INSERT INTO blobs (data) VALUES (?)").unwrap();
|
||||
write.bind_blob(1, blob).unwrap();
|
||||
assert_eq!(write.step().unwrap(), StepResult::Done);
|
||||
|
||||
// Read the blob from the
|
||||
let connection2 = Connection::open_memory(Some("blob_round_trips"));
|
||||
let mut read = Statement::prepare(&connection2, "SELECT * FROM blobs").unwrap();
|
||||
assert_eq!(read.step().unwrap(), StepResult::Row);
|
||||
assert_eq!(read.column_blob(0).unwrap(), blob);
|
||||
assert_eq!(read.step().unwrap(), StepResult::Done);
|
||||
|
||||
// Delete the added blob and verify its deleted on the other side
|
||||
connection2.exec("DELETE FROM blobs").unwrap()().unwrap();
|
||||
let mut read = Statement::prepare(&connection1, "SELECT * FROM blobs").unwrap();
|
||||
assert_eq!(read.step().unwrap(), StepResult::Done);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn maybe_returns_options() {
|
||||
let connection = Connection::open_memory(Some("maybe_returns_options"));
|
||||
connection
|
||||
.exec(indoc! {"
|
||||
CREATE TABLE texts (
|
||||
text TEXT
|
||||
)"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
assert!(connection
|
||||
.select_row::<String>("SELECT text FROM texts")
|
||||
.unwrap()()
|
||||
.unwrap()
|
||||
.is_none());
|
||||
|
||||
let text_to_insert = "This is a test";
|
||||
|
||||
connection
|
||||
.exec_bound("INSERT INTO texts VALUES (?)")
|
||||
.unwrap()(text_to_insert)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
connection.select_row("SELECT text FROM texts").unwrap()().unwrap(),
|
||||
Some(text_to_insert.to_string())
|
||||
);
|
||||
}
|
||||
}
|
359
crates/sqlez/src/thread_safe_connection.rs
Normal file
359
crates/sqlez/src/thread_safe_connection.rs
Normal file
|
@ -0,0 +1,359 @@
|
|||
use anyhow::Context;
|
||||
use futures::{channel::oneshot, Future, FutureExt};
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
|
||||
use thread_local::ThreadLocal;
|
||||
|
||||
use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
|
||||
|
||||
const MIGRATION_RETRIES: usize = 10;
|
||||
|
||||
type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
|
||||
type WriteQueueConstructor =
|
||||
Box<dyn 'static + Send + FnMut() -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
|
||||
lazy_static! {
|
||||
/// List of queues of tasks by database uri. This lets us serialize writes to the database
|
||||
/// and have a single worker thread per db file. This means many thread safe connections
|
||||
/// (possibly with different migrations) could all be communicating with the same background
|
||||
/// thread.
|
||||
static ref QUEUES: RwLock<HashMap<Arc<str>, Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>> =
|
||||
Default::default();
|
||||
}
|
||||
|
||||
/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
|
||||
/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
|
||||
/// may be accessed by passing a callback to the `write` function which will queue the callback
|
||||
pub struct ThreadSafeConnection<M: Migrator + 'static = ()> {
|
||||
uri: Arc<str>,
|
||||
persistent: bool,
|
||||
connection_initialize_query: Option<&'static str>,
|
||||
connections: Arc<ThreadLocal<Connection>>,
|
||||
_migrator: PhantomData<*mut M>,
|
||||
}
|
||||
|
||||
unsafe impl<M: Migrator> Send for ThreadSafeConnection<M> {}
|
||||
unsafe impl<M: Migrator> Sync for ThreadSafeConnection<M> {}
|
||||
|
||||
pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
|
||||
db_initialize_query: Option<&'static str>,
|
||||
write_queue_constructor: Option<WriteQueueConstructor>,
|
||||
connection: ThreadSafeConnection<M>,
|
||||
}
|
||||
|
||||
impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
|
||||
/// Sets the query to run every time a connection is opened. This must
|
||||
/// be infallible (EG only use pragma statements) and not cause writes.
|
||||
/// to the db or it will panic.
|
||||
pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
|
||||
self.connection.connection_initialize_query = Some(initialize_query);
|
||||
self
|
||||
}
|
||||
|
||||
/// Queues an initialization query for the database file. This must be infallible
|
||||
/// but may cause changes to the database file such as with `PRAGMA journal_mode`
|
||||
pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
|
||||
self.db_initialize_query = Some(initialize_query);
|
||||
self
|
||||
}
|
||||
|
||||
/// Specifies how the thread safe connection should serialize writes. If provided
|
||||
/// the connection will call the write_queue_constructor for each database file in
|
||||
/// this process. The constructor is responsible for setting up a background thread or
|
||||
/// async task which handles queued writes with the provided connection.
|
||||
pub fn with_write_queue_constructor(
|
||||
mut self,
|
||||
write_queue_constructor: WriteQueueConstructor,
|
||||
) -> Self {
|
||||
self.write_queue_constructor = Some(write_queue_constructor);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn build(self) -> anyhow::Result<ThreadSafeConnection<M>> {
|
||||
self.connection
|
||||
.initialize_queues(self.write_queue_constructor);
|
||||
|
||||
let db_initialize_query = self.db_initialize_query;
|
||||
|
||||
self.connection
|
||||
.write(move |connection| {
|
||||
if let Some(db_initialize_query) = db_initialize_query {
|
||||
connection.exec(db_initialize_query).with_context(|| {
|
||||
format!(
|
||||
"Db initialize query failed to execute: {}",
|
||||
db_initialize_query
|
||||
)
|
||||
})?()?;
|
||||
}
|
||||
|
||||
// Retry failed migrations in case they were run in parallel from different
|
||||
// processes. This gives a best attempt at migrating before bailing
|
||||
let mut migration_result =
|
||||
anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
|
||||
|
||||
for _ in 0..MIGRATION_RETRIES {
|
||||
migration_result = connection
|
||||
.with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
|
||||
|
||||
if migration_result.is_ok() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
migration_result
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(self.connection)
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: Migrator> ThreadSafeConnection<M> {
|
||||
fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
|
||||
if !QUEUES.read().contains_key(&self.uri) {
|
||||
let mut queues = QUEUES.write();
|
||||
if !queues.contains_key(&self.uri) {
|
||||
let mut write_queue_constructor =
|
||||
write_queue_constructor.unwrap_or(background_thread_queue());
|
||||
queues.insert(self.uri.clone(), write_queue_constructor());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
|
||||
ThreadSafeConnectionBuilder::<M> {
|
||||
db_initialize_query: None,
|
||||
write_queue_constructor: None,
|
||||
connection: Self {
|
||||
uri: Arc::from(uri),
|
||||
persistent,
|
||||
connection_initialize_query: None,
|
||||
connections: Default::default(),
|
||||
_migrator: PhantomData,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Opens a new db connection with the initialized file path. This is internal and only
|
||||
/// called from the deref function.
|
||||
fn open_file(uri: &str) -> Connection {
|
||||
Connection::open_file(uri)
|
||||
}
|
||||
|
||||
/// Opens a shared memory connection using the file path as the identifier. This is internal
|
||||
/// and only called from the deref function.
|
||||
fn open_shared_memory(uri: &str) -> Connection {
|
||||
Connection::open_memory(Some(uri))
|
||||
}
|
||||
|
||||
pub fn write<T: 'static + Send + Sync>(
|
||||
&self,
|
||||
callback: impl 'static + Send + FnOnce(&Connection) -> T,
|
||||
) -> impl Future<Output = T> {
|
||||
// Check and invalidate queue and maybe recreate queue
|
||||
let queues = QUEUES.read();
|
||||
let write_channel = queues
|
||||
.get(&self.uri)
|
||||
.expect("Queues are inserted when build is called. This should always succeed");
|
||||
|
||||
// Create a one shot channel for the result of the queued write
|
||||
// so we can await on the result
|
||||
let (sender, reciever) = oneshot::channel();
|
||||
|
||||
let thread_safe_connection = (*self).clone();
|
||||
write_channel(Box::new(move || {
|
||||
let connection = thread_safe_connection.deref();
|
||||
let result = connection.with_write(|connection| callback(connection));
|
||||
sender.send(result).ok();
|
||||
}));
|
||||
reciever.map(|response| response.expect("Write queue unexpectedly closed"))
|
||||
}
|
||||
|
||||
pub(crate) fn create_connection(
|
||||
persistent: bool,
|
||||
uri: &str,
|
||||
connection_initialize_query: Option<&'static str>,
|
||||
) -> Connection {
|
||||
let mut connection = if persistent {
|
||||
Self::open_file(uri)
|
||||
} else {
|
||||
Self::open_shared_memory(uri)
|
||||
};
|
||||
|
||||
// Disallow writes on the connection. The only writes allowed for thread safe connections
|
||||
// are from the background thread that can serialize them.
|
||||
*connection.write.get_mut() = false;
|
||||
|
||||
if let Some(initialize_query) = connection_initialize_query {
|
||||
connection.exec(initialize_query).expect(&format!(
|
||||
"Initialize query failed to execute: {}",
|
||||
initialize_query
|
||||
))()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
connection
|
||||
}
|
||||
}
|
||||
|
||||
impl ThreadSafeConnection<()> {
|
||||
/// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
|
||||
/// This allows construction to be infallible and not write to the db.
|
||||
pub fn new(
|
||||
uri: &str,
|
||||
persistent: bool,
|
||||
connection_initialize_query: Option<&'static str>,
|
||||
write_queue_constructor: Option<WriteQueueConstructor>,
|
||||
) -> Self {
|
||||
let connection = Self {
|
||||
uri: Arc::from(uri),
|
||||
persistent,
|
||||
connection_initialize_query,
|
||||
connections: Default::default(),
|
||||
_migrator: PhantomData,
|
||||
};
|
||||
|
||||
connection.initialize_queues(write_queue_constructor);
|
||||
connection
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: Migrator> Clone for ThreadSafeConnection<M> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
uri: self.uri.clone(),
|
||||
persistent: self.persistent,
|
||||
connection_initialize_query: self.connection_initialize_query.clone(),
|
||||
connections: self.connections.clone(),
|
||||
_migrator: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: Migrator> Deref for ThreadSafeConnection<M> {
|
||||
type Target = Connection;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.connections.get_or(|| {
|
||||
Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn background_thread_queue() -> WriteQueueConstructor {
|
||||
use std::sync::mpsc::channel;
|
||||
|
||||
Box::new(|| {
|
||||
let (sender, reciever) = channel::<QueuedWrite>();
|
||||
|
||||
thread::spawn(move || {
|
||||
while let Ok(write) = reciever.recv() {
|
||||
write()
|
||||
}
|
||||
});
|
||||
|
||||
let sender = UnboundedSyncSender::new(sender);
|
||||
Box::new(move |queued_write| {
|
||||
sender
|
||||
.send(queued_write)
|
||||
.expect("Could not send write action to background thread");
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn locking_queue() -> WriteQueueConstructor {
|
||||
Box::new(|| {
|
||||
let write_mutex = Mutex::new(());
|
||||
Box::new(move |queued_write| {
|
||||
let _lock = write_mutex.lock();
|
||||
queued_write();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use indoc::indoc;
|
||||
use lazy_static::__Deref;
|
||||
|
||||
use std::thread;
|
||||
|
||||
use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};
|
||||
|
||||
#[test]
|
||||
fn many_initialize_and_migrate_queries_at_once() {
|
||||
let mut handles = vec![];
|
||||
|
||||
enum TestDomain {}
|
||||
impl Domain for TestDomain {
|
||||
fn name() -> &'static str {
|
||||
"test"
|
||||
}
|
||||
fn migrations() -> &'static [&'static str] {
|
||||
&["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
|
||||
}
|
||||
}
|
||||
|
||||
for _ in 0..100 {
|
||||
handles.push(thread::spawn(|| {
|
||||
let builder =
|
||||
ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
|
||||
.with_db_initialization_query("PRAGMA journal_mode=WAL")
|
||||
.with_connection_initialize_query(indoc! {"
|
||||
PRAGMA synchronous=NORMAL;
|
||||
PRAGMA busy_timeout=1;
|
||||
PRAGMA foreign_keys=TRUE;
|
||||
PRAGMA case_sensitive_like=TRUE;
|
||||
"});
|
||||
|
||||
let _ = smol::block_on(builder.build()).unwrap().deref();
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
let _ = handle.join();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn wild_zed_lost_failure() {
|
||||
enum TestWorkspace {}
|
||||
impl Domain for TestWorkspace {
|
||||
fn name() -> &'static str {
|
||||
"workspace"
|
||||
}
|
||||
|
||||
fn migrations() -> &'static [&'static str] {
|
||||
&["
|
||||
CREATE TABLE workspaces(
|
||||
workspace_id INTEGER PRIMARY KEY,
|
||||
dock_visible INTEGER, -- Boolean
|
||||
dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
|
||||
dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
|
||||
timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
|
||||
FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
|
||||
) STRICT;
|
||||
|
||||
CREATE TABLE panes(
|
||||
pane_id INTEGER PRIMARY KEY,
|
||||
workspace_id INTEGER NOT NULL,
|
||||
active INTEGER NOT NULL, -- Boolean
|
||||
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
|
||||
ON DELETE CASCADE
|
||||
ON UPDATE CASCADE
|
||||
) STRICT;
|
||||
"]
|
||||
}
|
||||
}
|
||||
|
||||
let builder =
|
||||
ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
|
||||
.with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");
|
||||
|
||||
smol::block_on(builder.build()).unwrap();
|
||||
}
|
||||
}
|
60
crates/sqlez/src/typed_statements.rs
Normal file
60
crates/sqlez/src/typed_statements.rs
Normal file
|
@ -0,0 +1,60 @@
|
|||
use anyhow::{Context, Result};
|
||||
|
||||
use crate::{
|
||||
bindable::{Bind, Column},
|
||||
connection::Connection,
|
||||
statement::Statement,
|
||||
};
|
||||
|
||||
impl Connection {
|
||||
pub fn exec<'a>(&'a self, query: &str) -> Result<impl 'a + FnMut() -> Result<()>> {
|
||||
let mut statement = Statement::prepare(&self, query)?;
|
||||
Ok(move || statement.exec())
|
||||
}
|
||||
|
||||
pub fn exec_bound<'a, B: Bind>(
|
||||
&'a self,
|
||||
query: &str,
|
||||
) -> Result<impl 'a + FnMut(B) -> Result<()>> {
|
||||
let mut statement = Statement::prepare(&self, query)?;
|
||||
Ok(move |bindings| statement.with_bindings(bindings)?.exec())
|
||||
}
|
||||
|
||||
pub fn select<'a, C: Column>(
|
||||
&'a self,
|
||||
query: &str,
|
||||
) -> Result<impl 'a + FnMut() -> Result<Vec<C>>> {
|
||||
let mut statement = Statement::prepare(&self, query)?;
|
||||
Ok(move || statement.rows::<C>())
|
||||
}
|
||||
|
||||
pub fn select_bound<'a, B: Bind, C: Column>(
|
||||
&'a self,
|
||||
query: &str,
|
||||
) -> Result<impl 'a + FnMut(B) -> Result<Vec<C>>> {
|
||||
let mut statement = Statement::prepare(&self, query)?;
|
||||
Ok(move |bindings| statement.with_bindings(bindings)?.rows::<C>())
|
||||
}
|
||||
|
||||
pub fn select_row<'a, C: Column>(
|
||||
&'a self,
|
||||
query: &str,
|
||||
) -> Result<impl 'a + FnMut() -> Result<Option<C>>> {
|
||||
let mut statement = Statement::prepare(&self, query)?;
|
||||
Ok(move || statement.maybe_row::<C>())
|
||||
}
|
||||
|
||||
pub fn select_row_bound<'a, B: Bind, C: Column>(
|
||||
&'a self,
|
||||
query: &str,
|
||||
) -> Result<impl 'a + FnMut(B) -> Result<Option<C>>> {
|
||||
let mut statement = Statement::prepare(&self, query)?;
|
||||
Ok(move |bindings| {
|
||||
statement
|
||||
.with_bindings(bindings)
|
||||
.context("Bindings failed")?
|
||||
.maybe_row::<C>()
|
||||
.context("Maybe row failed")
|
||||
})
|
||||
}
|
||||
}
|
32
crates/sqlez/src/util.rs
Normal file
32
crates/sqlez/src/util.rs
Normal file
|
@ -0,0 +1,32 @@
|
|||
use std::ops::Deref;
|
||||
use std::sync::mpsc::Sender;
|
||||
|
||||
use parking_lot::Mutex;
|
||||
use thread_local::ThreadLocal;
|
||||
|
||||
/// Unbounded standard library sender which is stored per thread to get around
|
||||
/// the lack of sync on the standard library version while still being unbounded
|
||||
/// Note: this locks on the cloneable sender, but its done once per thread, so it
|
||||
/// shouldn't result in too much contention
|
||||
pub struct UnboundedSyncSender<T: Send> {
|
||||
clonable_sender: Mutex<Sender<T>>,
|
||||
local_senders: ThreadLocal<Sender<T>>,
|
||||
}
|
||||
|
||||
impl<T: Send> UnboundedSyncSender<T> {
|
||||
pub fn new(sender: Sender<T>) -> Self {
|
||||
Self {
|
||||
clonable_sender: Mutex::new(sender),
|
||||
local_senders: ThreadLocal::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send> Deref for UnboundedSyncSender<T> {
|
||||
type Target = Sender<T>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.local_senders
|
||||
.get_or(|| self.clonable_sender.lock().clone())
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue