Finished implementing the workspace stuff

This commit is contained in:
Mikayla Maki 2022-11-01 15:58:23 -07:00
parent 395070cb92
commit 777f05eb76
10 changed files with 263 additions and 278 deletions

View file

@ -53,6 +53,15 @@ impl Connection {
self.persistent
}
pub(crate) fn last_insert_id(&self) -> i64 {
unsafe { sqlite3_last_insert_rowid(self.sqlite3) }
}
pub fn insert(&self, query: impl AsRef<str>) -> Result<i64> {
self.exec(query)?;
Ok(self.last_insert_id())
}
pub fn exec(&self, query: impl AsRef<str>) -> Result<()> {
unsafe {
sqlite3_exec(
@ -140,9 +149,9 @@ mod test {
connection
.prepare("INSERT INTO text (text) VALUES (?);")
.unwrap()
.bound(text)
.bind(text)
.unwrap()
.run()
.exec()
.unwrap();
assert_eq!(
@ -176,8 +185,8 @@ mod test {
.prepare("INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)")
.unwrap();
insert.bound(tuple1.clone()).unwrap().run().unwrap();
insert.bound(tuple2.clone()).unwrap().run().unwrap();
insert.bind(tuple1.clone()).unwrap().exec().unwrap();
insert.bind(tuple2.clone()).unwrap().exec().unwrap();
assert_eq!(
connection
@ -203,7 +212,7 @@ mod test {
.prepare("INSERT INTO blobs (data) VALUES (?);")
.unwrap();
write.bind_blob(1, blob).unwrap();
write.run().unwrap();
write.exec().unwrap();
// Backup connection1 to connection2
let connection2 = Connection::open_memory("backup_works_other");

View file

@ -22,6 +22,7 @@ const MIGRATIONS_MIGRATION: Migration = Migration::new(
"}],
);
#[derive(Debug)]
pub struct Migration {
domain: &'static str,
migrations: &'static [&'static str],
@ -46,7 +47,7 @@ impl Migration {
WHERE domain = ?
ORDER BY step
"})?
.bound(self.domain)?
.bind(self.domain)?
.rows::<(String, usize, String)>()?;
let mut store_completed_migration = connection
@ -71,8 +72,8 @@ impl Migration {
connection.exec(migration)?;
store_completed_migration
.bound((self.domain, index, *migration))?
.run()?;
.bind((self.domain, index, *migration))?
.exec()?;
}
Ok(())
@ -162,9 +163,9 @@ mod test {
.unwrap();
store_completed_migration
.bound((domain, i, i.to_string()))
.bind((domain, i, i.to_string()))
.unwrap()
.run()
.exec()
.unwrap();
}
}

View file

@ -3,10 +3,36 @@ use anyhow::Result;
use crate::connection::Connection;
impl Connection {
// Run a set of commands within the context of a `SAVEPOINT name`. If the callback
// returns Err(_), the savepoint will be rolled back. Otherwise, the save
// point is released.
pub fn with_savepoint<R, F>(&mut self, name: impl AsRef<str>, f: F) -> Result<R>
where
F: FnOnce(&mut Connection) -> Result<R>,
{
let name = name.as_ref().to_owned();
self.exec(format!("SAVEPOINT {}", &name))?;
let result = f(self);
match result {
Ok(_) => {
self.exec(format!("RELEASE {}", name))?;
}
Err(_) => {
self.exec(format!("ROLLBACK TO {}", name))?;
self.exec(format!("RELEASE {}", name))?;
}
}
result
}
// Run a set of commands within the context of a `SAVEPOINT name`. If the callback
// returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
// point is released.
pub fn with_savepoint<F, R>(&mut self, name: impl AsRef<str>, f: F) -> Result<Option<R>>
pub fn with_savepoint_rollback<R, F>(
&mut self,
name: impl AsRef<str>,
f: F,
) -> Result<Option<R>>
where
F: FnOnce(&mut Connection) -> Result<Option<R>>,
{
@ -50,15 +76,15 @@ mod tests {
connection.with_savepoint("first", |save1| {
save1
.prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
.bound((save1_text, 1))?
.run()?;
.bind((save1_text, 1))?
.exec()?;
assert!(save1
.with_savepoint("second", |save2| -> Result<Option<()>, anyhow::Error> {
save2
.prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
.bound((save2_text, 2))?
.run()?;
.bind((save2_text, 2))?
.exec()?;
assert_eq!(
save2
@ -79,11 +105,34 @@ mod tests {
vec![save1_text],
);
save1.with_savepoint("second", |save2| {
save1.with_savepoint_rollback::<(), _>("second", |save2| {
save2
.prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
.bound((save2_text, 2))?
.run()?;
.bind((save2_text, 2))?
.exec()?;
assert_eq!(
save2
.prepare("SELECT text FROM text ORDER BY text.idx ASC")?
.rows::<String>()?,
vec![save1_text, save2_text],
);
Ok(None)
})?;
assert_eq!(
save1
.prepare("SELECT text FROM text ORDER BY text.idx ASC")?
.rows::<String>()?,
vec![save1_text],
);
save1.with_savepoint_rollback("second", |save2| {
save2
.prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
.bind((save2_text, 2))?
.exec()?;
assert_eq!(
save2
@ -102,9 +151,16 @@ mod tests {
vec![save1_text, save2_text],
);
Ok(Some(()))
Ok(())
})?;
assert_eq!(
connection
.prepare("SELECT text FROM text ORDER BY text.idx ASC")?
.rows::<String>()?,
vec![save1_text, save2_text],
);
Ok(())
}
}

View file

@ -60,6 +60,10 @@ impl<'a> Statement<'a> {
}
}
pub fn parameter_count(&self) -> i32 {
unsafe { sqlite3_bind_parameter_count(self.raw_statement) }
}
pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
let index = index as c_int;
let blob_pointer = blob.as_ptr() as *const _;
@ -175,8 +179,9 @@ impl<'a> Statement<'a> {
Ok(str::from_utf8(slice)?)
}
pub fn bind<T: Bind>(&self, value: T) -> Result<()> {
value.bind(self, 1)?;
pub fn bind_value<T: Bind>(&self, value: T, idx: i32) -> Result<()> {
debug_assert!(idx > 0);
value.bind(self, idx)?;
Ok(())
}
@ -198,8 +203,8 @@ impl<'a> Statement<'a> {
}
}
pub fn bound(&mut self, bindings: impl Bind) -> Result<&mut Self> {
self.bind(bindings)?;
pub fn bind(&mut self, bindings: impl Bind) -> Result<&mut Self> {
self.bind_value(bindings, 1)?;
Ok(self)
}
@ -217,7 +222,12 @@ impl<'a> Statement<'a> {
}
}
pub fn run(&mut self) -> Result<()> {
pub fn insert(&mut self) -> Result<i64> {
self.exec()?;
Ok(self.connection.last_insert_id())
}
pub fn exec(&mut self) -> Result<()> {
fn logic(this: &mut Statement) -> Result<()> {
while this.step()? == StepResult::Row {}
Ok(())

View file

@ -3,12 +3,13 @@ use std::{ops::Deref, sync::Arc};
use connection::Connection;
use thread_local::ThreadLocal;
use crate::connection;
use crate::{connection, migrations::Migration};
pub struct ThreadSafeConnection {
uri: Arc<str>,
persistent: bool,
initialize_query: Option<&'static str>,
migrations: Option<&'static [Migration]>,
connection: Arc<ThreadLocal<Connection>>,
}
@ -18,6 +19,7 @@ impl ThreadSafeConnection {
uri: Arc::from(uri),
persistent,
initialize_query: None,
migrations: None,
connection: Default::default(),
}
}
@ -29,6 +31,11 @@ impl ThreadSafeConnection {
self
}
pub fn with_migrations(mut self, migrations: &'static [Migration]) -> Self {
self.migrations = Some(migrations);
self
}
/// Opens a new db connection with the initialized file path. This is internal and only
/// called from the deref function.
/// If opening fails, the connection falls back to a shared memory connection
@ -49,6 +56,7 @@ impl Clone for ThreadSafeConnection {
uri: self.uri.clone(),
persistent: self.persistent,
initialize_query: self.initialize_query.clone(),
migrations: self.migrations.clone(),
connection: self.connection.clone(),
}
}
@ -72,6 +80,14 @@ impl Deref for ThreadSafeConnection {
));
}
if let Some(migrations) = self.migrations {
for migration in migrations {
migration
.run(&connection)
.expect(&format!("Migrations failed to execute: {:?}", migration));
}
}
connection
})
}