ZIm/crates/sqlez/src/connection.rs
Mikayla Maki 3e0f9d27a7 Made dev tools not break everything about the db
Also improved multi statements to allow out of order parameter binding in statements
Ensured that all statements are run for maybe_row and single, and that of all statements only 1 of them returns only 1 row
Made bind and column calls add useful context to errors

Co-authored-by: kay@zed.dev
2022-12-03 16:06:01 -08:00

262 lines
7.1 KiB
Rust

use std::{
ffi::{CStr, CString},
marker::PhantomData,
path::Path,
};
use anyhow::{anyhow, Result};
use libsqlite3_sys::*;
/// An owned handle to a SQLite database connection.
///
/// The underlying handle is closed when the value is dropped.
pub struct Connection {
    /// Raw SQLite connection handle; exposed to the rest of the crate so other modules can
    /// issue SQLite calls against it.
    pub(crate) sqlite3: *mut sqlite3,
    /// `true` when the connection was opened by `open_file` without falling back to an
    /// in-memory database; `false` for connections created through `open_memory`.
    persistent: bool,
    // NOTE(review): presumably marks that this struct logically owns the `sqlite3` object
    // behind the raw pointer (for drop-check/auto-trait purposes) — confirm intent.
    phantom: PhantomData<sqlite3>,
}
// SAFETY: connections are opened with `SQLITE_OPEN_NOMUTEX` (SQLite's multi-thread mode),
// which allows a connection to be used from any thread as long as it is never used from two
// threads at the same time. The raw-pointer field keeps `Connection` from being `Sync`, so
// moving it to another thread does not by itself allow concurrent use.
// NOTE(review): this assumes SQLite is built with thread-safety enabled and that nothing else
// holding this raw handle (e.g. prepared statements) is used on another thread — confirm.
unsafe impl Send for Connection {}
impl Connection {
fn open(uri: &str, persistent: bool) -> Result<Self> {
let mut connection = Self {
sqlite3: 0 as *mut _,
persistent,
phantom: PhantomData,
};
let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
unsafe {
sqlite3_open_v2(
CString::new(uri)?.as_ptr(),
&mut connection.sqlite3,
flags,
0 as *const _,
);
// Turn on extended error codes
sqlite3_extended_result_codes(connection.sqlite3, 1);
connection.last_error()?;
}
Ok(connection)
}
/// Attempts to open the database at uri. If it fails, a shared memory db will be opened
/// instead.
pub fn open_file(uri: &str) -> Self {
Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
}
pub fn open_memory(uri: Option<&str>) -> Self {
let in_memory_path = if let Some(uri) = uri {
format!("file:{}?mode=memory&cache=shared", uri)
} else {
":memory:".to_string()
};
Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
}
pub fn persistent(&self) -> bool {
self.persistent
}
pub fn backup_main(&self, destination: &Connection) -> Result<()> {
unsafe {
let backup = sqlite3_backup_init(
destination.sqlite3,
CString::new("main")?.as_ptr(),
self.sqlite3,
CString::new("main")?.as_ptr(),
);
sqlite3_backup_step(backup, -1);
sqlite3_backup_finish(backup);
destination.last_error()
}
}
pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
self.backup_main(&destination)
}
pub(crate) fn last_error(&self) -> Result<()> {
unsafe {
let code = sqlite3_errcode(self.sqlite3);
const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
if NON_ERROR_CODES.contains(&code) {
return Ok(());
}
let message = sqlite3_errmsg(self.sqlite3);
let message = if message.is_null() {
None
} else {
Some(
String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
.into_owned(),
)
};
Err(anyhow!(
"Sqlite call failed with code {} and message: {:?}",
code as isize,
message
))
}
}
}
impl Drop for Connection {
    /// Closes the underlying SQLite handle when the `Connection` goes away.
    fn drop(&mut self) {
        let handle = self.sqlite3;
        // SAFETY: this `Connection` owns `handle` and nothing uses it after drop. The return
        // code is discarded because `drop` has no way to report a failure.
        unsafe {
            sqlite3_close(handle);
        }
    }
}
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap()()
            .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
            .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
            .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
                .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
            .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
            .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
                .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
            .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
            .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // The backup must contain the row written to the source database.
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
            .unwrap();
        assert_eq!(read_blobs, vec![blob.clone()]);

        // Deleting from the source must leave the backed-up copy untouched.
        connection1
            .exec("DELETE FROM blobs;")
            .unwrap()()
            .unwrap();
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
            .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn backup_to_self_fails() {
        // SQLite refuses a backup whose source and destination are the same connection.
        let connection = Connection::open_memory(Some("backup_to_self_fails"));
        assert!(connection.backup_main(&connection).is_err());
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
            .unwrap();

        connection
            .exec(indoc! {"
            INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
            .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
                .unwrap(),
            Some(2)
        );
    }
}