Add typed statements
This commit is contained in:
parent
64ac84fdf4
commit
4a00f0b062
14 changed files with 388 additions and 394 deletions
|
@ -1,6 +1,6 @@
|
|||
use std::ffi::{c_int, CString};
|
||||
use std::ffi::{c_int, CStr, CString};
|
||||
use std::marker::PhantomData;
|
||||
use std::{slice, str};
|
||||
use std::{ptr, slice, str};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use libsqlite3_sys::*;
|
||||
|
@ -9,7 +9,8 @@ use crate::bindable::{Bind, Column};
|
|||
use crate::connection::Connection;
|
||||
|
||||
pub struct Statement<'a> {
|
||||
raw_statement: *mut sqlite3_stmt,
|
||||
raw_statements: Vec<*mut sqlite3_stmt>,
|
||||
current_statement: usize,
|
||||
connection: &'a Connection,
|
||||
phantom: PhantomData<sqlite3_stmt>,
|
||||
}
|
||||
|
@ -34,19 +35,31 @@ pub enum SqlType {
|
|||
impl<'a> Statement<'a> {
|
||||
pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
|
||||
let mut statement = Self {
|
||||
raw_statement: 0 as *mut _,
|
||||
raw_statements: Default::default(),
|
||||
current_statement: 0,
|
||||
connection,
|
||||
phantom: PhantomData,
|
||||
};
|
||||
|
||||
unsafe {
|
||||
sqlite3_prepare_v2(
|
||||
connection.sqlite3,
|
||||
CString::new(query.as_ref())?.as_ptr(),
|
||||
-1,
|
||||
&mut statement.raw_statement,
|
||||
0 as *mut _,
|
||||
);
|
||||
let sql = CString::new(query.as_ref())?;
|
||||
let mut remaining_sql = sql.as_c_str();
|
||||
while {
|
||||
let remaining_sql_str = remaining_sql.to_str()?;
|
||||
remaining_sql_str.trim() != ";" && !remaining_sql_str.is_empty()
|
||||
} {
|
||||
let mut raw_statement = 0 as *mut sqlite3_stmt;
|
||||
let mut remaining_sql_ptr = ptr::null();
|
||||
sqlite3_prepare_v2(
|
||||
connection.sqlite3,
|
||||
remaining_sql.as_ptr(),
|
||||
-1,
|
||||
&mut raw_statement,
|
||||
&mut remaining_sql_ptr,
|
||||
);
|
||||
remaining_sql = CStr::from_ptr(remaining_sql_ptr);
|
||||
statement.raw_statements.push(raw_statement);
|
||||
}
|
||||
|
||||
connection
|
||||
.last_error()
|
||||
|
@ -56,131 +69,138 @@ impl<'a> Statement<'a> {
|
|||
Ok(statement)
|
||||
}
|
||||
|
||||
fn current_statement(&self) -> *mut sqlite3_stmt {
|
||||
*self.raw_statements.get(self.current_statement).unwrap()
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
unsafe {
|
||||
sqlite3_reset(self.raw_statement);
|
||||
for raw_statement in self.raw_statements.iter() {
|
||||
sqlite3_reset(*raw_statement);
|
||||
}
|
||||
}
|
||||
self.current_statement = 0;
|
||||
}
|
||||
|
||||
pub fn parameter_count(&self) -> i32 {
|
||||
unsafe { sqlite3_bind_parameter_count(self.raw_statement) }
|
||||
unsafe {
|
||||
self.raw_statements
|
||||
.iter()
|
||||
.map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement))
|
||||
.max()
|
||||
.unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
|
||||
// dbg!("bind blob", index);
|
||||
let index = index as c_int;
|
||||
let blob_pointer = blob.as_ptr() as *const _;
|
||||
let len = blob.len() as c_int;
|
||||
unsafe {
|
||||
sqlite3_bind_blob(
|
||||
self.raw_statement,
|
||||
index,
|
||||
blob_pointer,
|
||||
len,
|
||||
SQLITE_TRANSIENT(),
|
||||
);
|
||||
for raw_statement in self.raw_statements.iter() {
|
||||
sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
|
||||
}
|
||||
}
|
||||
self.connection.last_error()
|
||||
}
|
||||
|
||||
pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
|
||||
let index = index as c_int;
|
||||
let pointer = unsafe { sqlite3_column_blob(self.raw_statement, index) };
|
||||
let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
|
||||
|
||||
self.connection.last_error()?;
|
||||
if pointer.is_null() {
|
||||
return Ok(&[]);
|
||||
}
|
||||
let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
|
||||
let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
|
||||
self.connection.last_error()?;
|
||||
unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
|
||||
}
|
||||
|
||||
pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
|
||||
// dbg!("bind double", index);
|
||||
let index = index as c_int;
|
||||
|
||||
unsafe {
|
||||
sqlite3_bind_double(self.raw_statement, index, double);
|
||||
for raw_statement in self.raw_statements.iter() {
|
||||
sqlite3_bind_double(*raw_statement, index, double);
|
||||
}
|
||||
}
|
||||
self.connection.last_error()
|
||||
}
|
||||
|
||||
pub fn column_double(&self, index: i32) -> Result<f64> {
|
||||
let index = index as c_int;
|
||||
let result = unsafe { sqlite3_column_double(self.raw_statement, index) };
|
||||
let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
|
||||
self.connection.last_error()?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
|
||||
// dbg!("bind int", index);
|
||||
let index = index as c_int;
|
||||
|
||||
unsafe {
|
||||
sqlite3_bind_int(self.raw_statement, index, int);
|
||||
for raw_statement in self.raw_statements.iter() {
|
||||
sqlite3_bind_int(*raw_statement, index, int);
|
||||
}
|
||||
};
|
||||
self.connection.last_error()
|
||||
}
|
||||
|
||||
pub fn column_int(&self, index: i32) -> Result<i32> {
|
||||
let index = index as c_int;
|
||||
let result = unsafe { sqlite3_column_int(self.raw_statement, index) };
|
||||
let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
|
||||
self.connection.last_error()?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
|
||||
// dbg!("bind int64", index);
|
||||
let index = index as c_int;
|
||||
unsafe {
|
||||
sqlite3_bind_int64(self.raw_statement, index, int);
|
||||
for raw_statement in self.raw_statements.iter() {
|
||||
sqlite3_bind_int64(*raw_statement, index, int);
|
||||
}
|
||||
}
|
||||
self.connection.last_error()
|
||||
}
|
||||
|
||||
pub fn column_int64(&self, index: i32) -> Result<i64> {
|
||||
let index = index as c_int;
|
||||
let result = unsafe { sqlite3_column_int64(self.raw_statement, index) };
|
||||
let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
|
||||
self.connection.last_error()?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn bind_null(&self, index: i32) -> Result<()> {
|
||||
// dbg!("bind null", index);
|
||||
let index = index as c_int;
|
||||
unsafe {
|
||||
sqlite3_bind_null(self.raw_statement, index);
|
||||
for raw_statement in self.raw_statements.iter() {
|
||||
sqlite3_bind_null(*raw_statement, index);
|
||||
}
|
||||
}
|
||||
self.connection.last_error()
|
||||
}
|
||||
|
||||
pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
|
||||
// dbg!("bind text", index, text);
|
||||
let index = index as c_int;
|
||||
let text_pointer = text.as_ptr() as *const _;
|
||||
let len = text.len() as c_int;
|
||||
unsafe {
|
||||
sqlite3_bind_text(
|
||||
self.raw_statement,
|
||||
index,
|
||||
text_pointer,
|
||||
len,
|
||||
SQLITE_TRANSIENT(),
|
||||
);
|
||||
for raw_statement in self.raw_statements.iter() {
|
||||
sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
|
||||
}
|
||||
}
|
||||
self.connection.last_error()
|
||||
}
|
||||
|
||||
pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
|
||||
let index = index as c_int;
|
||||
let pointer = unsafe { sqlite3_column_text(self.raw_statement, index) };
|
||||
let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
|
||||
|
||||
self.connection.last_error()?;
|
||||
if pointer.is_null() {
|
||||
return Ok("");
|
||||
}
|
||||
let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
|
||||
let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
|
||||
self.connection.last_error()?;
|
||||
|
||||
let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
|
||||
|
@ -198,7 +218,7 @@ impl<'a> Statement<'a> {
|
|||
}
|
||||
|
||||
pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
|
||||
let result = unsafe { sqlite3_column_type(self.raw_statement, index) }; // SELECT <FRIEND> FROM TABLE
|
||||
let result = unsafe { sqlite3_column_type(self.current_statement(), index) };
|
||||
self.connection.last_error()?;
|
||||
match result {
|
||||
SQLITE_INTEGER => Ok(SqlType::Integer),
|
||||
|
@ -217,9 +237,16 @@ impl<'a> Statement<'a> {
|
|||
|
||||
fn step(&mut self) -> Result<StepResult> {
|
||||
unsafe {
|
||||
match sqlite3_step(self.raw_statement) {
|
||||
match sqlite3_step(self.current_statement()) {
|
||||
SQLITE_ROW => Ok(StepResult::Row),
|
||||
SQLITE_DONE => Ok(StepResult::Done),
|
||||
SQLITE_DONE => {
|
||||
if self.current_statement >= self.raw_statements.len() - 1 {
|
||||
Ok(StepResult::Done)
|
||||
} else {
|
||||
self.current_statement += 1;
|
||||
self.step()
|
||||
}
|
||||
}
|
||||
SQLITE_MISUSE => Ok(StepResult::Misuse),
|
||||
other => self
|
||||
.connection
|
||||
|
@ -311,7 +338,11 @@ impl<'a> Statement<'a> {
|
|||
|
||||
impl<'a> Drop for Statement<'a> {
|
||||
fn drop(&mut self) {
|
||||
unsafe { sqlite3_finalize(self.raw_statement) };
|
||||
unsafe {
|
||||
for raw_statement in self.raw_statements.iter() {
|
||||
sqlite3_finalize(*raw_statement);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -319,7 +350,10 @@ impl<'a> Drop for Statement<'a> {
|
|||
mod test {
|
||||
use indoc::indoc;
|
||||
|
||||
use crate::{connection::Connection, statement::StepResult};
|
||||
use crate::{
|
||||
connection::Connection,
|
||||
statement::{Statement, StepResult},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn blob_round_trips() {
|
||||
|
@ -327,28 +361,28 @@ mod test {
|
|||
connection1
|
||||
.exec(indoc! {"
|
||||
CREATE TABLE blobs (
|
||||
data BLOB
|
||||
);"})
|
||||
.unwrap();
|
||||
data BLOB
|
||||
)"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
let blob = &[0, 1, 2, 4, 8, 16, 32, 64];
|
||||
|
||||
let mut write = connection1
|
||||
.prepare("INSERT INTO blobs (data) VALUES (?);")
|
||||
.unwrap();
|
||||
let mut write =
|
||||
Statement::prepare(&connection1, "INSERT INTO blobs (data) VALUES (?)").unwrap();
|
||||
write.bind_blob(1, blob).unwrap();
|
||||
assert_eq!(write.step().unwrap(), StepResult::Done);
|
||||
|
||||
// Read the blob from the
|
||||
let connection2 = Connection::open_memory("blob_round_trips");
|
||||
let mut read = connection2.prepare("SELECT * FROM blobs;").unwrap();
|
||||
let mut read = Statement::prepare(&connection2, "SELECT * FROM blobs").unwrap();
|
||||
assert_eq!(read.step().unwrap(), StepResult::Row);
|
||||
assert_eq!(read.column_blob(0).unwrap(), blob);
|
||||
assert_eq!(read.step().unwrap(), StepResult::Done);
|
||||
|
||||
// Delete the added blob and verify its deleted on the other side
|
||||
connection2.exec("DELETE FROM blobs;").unwrap();
|
||||
let mut read = connection1.prepare("SELECT * FROM blobs;").unwrap();
|
||||
connection2.exec("DELETE FROM blobs").unwrap()().unwrap();
|
||||
let mut read = Statement::prepare(&connection1, "SELECT * FROM blobs").unwrap();
|
||||
assert_eq!(read.step().unwrap(), StepResult::Done);
|
||||
}
|
||||
|
||||
|
@ -359,32 +393,25 @@ mod test {
|
|||
.exec(indoc! {"
|
||||
CREATE TABLE texts (
|
||||
text TEXT
|
||||
);"})
|
||||
.unwrap();
|
||||
)"})
|
||||
.unwrap()()
|
||||
.unwrap();
|
||||
|
||||
assert!(connection
|
||||
.prepare("SELECT text FROM texts")
|
||||
.unwrap()
|
||||
.maybe_row::<String>()
|
||||
.unwrap()
|
||||
.is_none());
|
||||
.select_row::<String>("SELECT text FROM texts")
|
||||
.unwrap()()
|
||||
.unwrap()
|
||||
.is_none());
|
||||
|
||||
let text_to_insert = "This is a test";
|
||||
|
||||
connection
|
||||
.prepare("INSERT INTO texts VALUES (?)")
|
||||
.unwrap()
|
||||
.with_bindings(text_to_insert)
|
||||
.unwrap()
|
||||
.exec()
|
||||
.unwrap();
|
||||
.exec_bound("INSERT INTO texts VALUES (?)")
|
||||
.unwrap()(text_to_insert)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
connection
|
||||
.prepare("SELECT text FROM texts")
|
||||
.unwrap()
|
||||
.maybe_row::<String>()
|
||||
.unwrap(),
|
||||
connection.select_row("SELECT text FROM texts").unwrap()().unwrap(),
|
||||
Some(text_to_insert.to_string())
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue