Made dev tools not break everything about the db

Also improved multi statements to allow out of order parameter binding in statements
Ensured that all statements are run for maybe_row and single, and that of all statements only 1 of them returns only 1 row
Made bind and column calls add useful context to errors

Co-authored-by: kay@zed.dev
This commit is contained in:
Mikayla Maki 2022-11-21 13:42:26 -08:00
parent 2dc1130902
commit 3e0f9d27a7
13 changed files with 219 additions and 110 deletions

View file

@ -19,8 +19,6 @@ pub struct Statement<'a> {
pub enum StepResult {
Row,
Done,
Misuse,
Other(i32),
}
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
@ -40,12 +38,14 @@ impl<'a> Statement<'a> {
connection,
phantom: PhantomData,
};
unsafe {
let sql = CString::new(query.as_ref())?;
let sql = CString::new(query.as_ref()).context("Error creating cstr")?;
let mut remaining_sql = sql.as_c_str();
while {
let remaining_sql_str = remaining_sql.to_str()?.trim();
let remaining_sql_str = remaining_sql
.to_str()
.context("Parsing remaining sql")?
.trim();
remaining_sql_str != ";" && !remaining_sql_str.is_empty()
} {
let mut raw_statement = 0 as *mut sqlite3_stmt;
@ -92,116 +92,136 @@ impl<'a> Statement<'a> {
}
}
fn bind_index_with(&self, index: i32, bind: impl Fn(&*mut sqlite3_stmt) -> ()) -> Result<()> {
let mut any_succeed = false;
unsafe {
for raw_statement in self.raw_statements.iter() {
if index <= sqlite3_bind_parameter_count(*raw_statement) {
bind(raw_statement);
self.connection
.last_error()
.with_context(|| format!("Failed to bind value at index {index}"))?;
any_succeed = true;
} else {
continue;
}
}
}
if any_succeed {
Ok(())
} else {
Err(anyhow!("Failed to bind parameters"))
}
}
pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
let index = index as c_int;
let blob_pointer = blob.as_ptr() as *const _;
let len = blob.len() as c_int;
unsafe {
for raw_statement in self.raw_statements.iter() {
sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
}
}
self.connection.last_error()
self.bind_index_with(index, |raw_statement| unsafe {
sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
})
}
pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
let index = index as c_int;
let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
self.connection.last_error()?;
self.connection
.last_error()
.with_context(|| format!("Failed to read blob at index {index}"))?;
if pointer.is_null() {
return Ok(&[]);
}
let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
self.connection.last_error()?;
self.connection
.last_error()
.with_context(|| format!("Failed to read length of blob at index {index}"))?;
unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
}
pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
let index = index as c_int;
unsafe {
for raw_statement in self.raw_statements.iter() {
sqlite3_bind_double(*raw_statement, index, double);
}
}
self.connection.last_error()
self.bind_index_with(index, |raw_statement| unsafe {
sqlite3_bind_double(*raw_statement, index, double);
})
}
pub fn column_double(&self, index: i32) -> Result<f64> {
let index = index as c_int;
let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
self.connection.last_error()?;
self.connection
.last_error()
.with_context(|| format!("Failed to read double at index {index}"))?;
Ok(result)
}
pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
let index = index as c_int;
unsafe {
for raw_statement in self.raw_statements.iter() {
sqlite3_bind_int(*raw_statement, index, int);
}
};
self.connection.last_error()
self.bind_index_with(index, |raw_statement| unsafe {
sqlite3_bind_int(*raw_statement, index, int);
})
}
pub fn column_int(&self, index: i32) -> Result<i32> {
let index = index as c_int;
let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
self.connection.last_error()?;
self.connection
.last_error()
.with_context(|| format!("Failed to read int at index {index}"))?;
Ok(result)
}
pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
let index = index as c_int;
unsafe {
for raw_statement in self.raw_statements.iter() {
sqlite3_bind_int64(*raw_statement, index, int);
}
}
self.connection.last_error()
self.bind_index_with(index, |raw_statement| unsafe {
sqlite3_bind_int64(*raw_statement, index, int);
})
}
pub fn column_int64(&self, index: i32) -> Result<i64> {
let index = index as c_int;
let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
self.connection.last_error()?;
self.connection
.last_error()
.with_context(|| format!("Failed to read i64 at index {index}"))?;
Ok(result)
}
pub fn bind_null(&self, index: i32) -> Result<()> {
let index = index as c_int;
unsafe {
for raw_statement in self.raw_statements.iter() {
sqlite3_bind_null(*raw_statement, index);
}
}
self.connection.last_error()
self.bind_index_with(index, |raw_statement| unsafe {
sqlite3_bind_null(*raw_statement, index);
})
}
pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
let index = index as c_int;
let text_pointer = text.as_ptr() as *const _;
let len = text.len() as c_int;
unsafe {
for raw_statement in self.raw_statements.iter() {
sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
}
}
self.connection.last_error()
self.bind_index_with(index, |raw_statement| unsafe {
sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
})
}
pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
let index = index as c_int;
let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
self.connection.last_error()?;
self.connection
.last_error()
.with_context(|| format!("Failed to read text from column {index}"))?;
if pointer.is_null() {
return Ok("");
}
let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
self.connection.last_error()?;
self.connection
.last_error()
.with_context(|| format!("Failed to read text length at {index}"))?;
let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
Ok(str::from_utf8(slice)?)
@ -247,11 +267,11 @@ impl<'a> Statement<'a> {
self.step()
}
}
SQLITE_MISUSE => Ok(StepResult::Misuse),
other => self
.connection
.last_error()
.map(|_| StepResult::Other(other)),
SQLITE_MISUSE => Err(anyhow!("Statement step returned SQLITE_MISUSE")),
_other_error => {
self.connection.last_error()?;
unreachable!("Step returned error code and last error failed to catch it");
}
}
}
}
@ -293,11 +313,17 @@ impl<'a> Statement<'a> {
callback: impl FnOnce(&mut Statement) -> Result<R>,
) -> Result<R> {
if this.step()? != StepResult::Row {
return Err(anyhow!("single called with query that returns no rows."));
}
let result = callback(this)?;
if this.step()? != StepResult::Done {
return Err(anyhow!(
"Single(Map) called with query that returns no rows."
"single called with a query that returns more than one row."
));
}
callback(this)
Ok(result)
}
let result = logic(self, callback);
self.reset();
@ -316,10 +342,21 @@ impl<'a> Statement<'a> {
this: &mut Statement,
callback: impl FnOnce(&mut Statement) -> Result<R>,
) -> Result<Option<R>> {
if this.step()? != StepResult::Row {
if this.step().context("Failed on step call")? != StepResult::Row {
return Ok(None);
}
callback(this).map(|r| Some(r))
let result = callback(this)
.map(|r| Some(r))
.context("Failed to parse row result")?;
if this.step().context("Second step call")? != StepResult::Done {
return Err(anyhow!(
"maybe called with a query that returns more than one row."
));
}
Ok(result)
}
let result = logic(self, callback);
self.reset();
@ -350,6 +387,38 @@ mod test {
statement::{Statement, StepResult},
};
#[test]
fn binding_multiple_statements_with_parameter_gaps() {
let connection =
Connection::open_memory(Some("binding_multiple_statements_with_parameter_gaps"));
connection
.exec(indoc! {"
CREATE TABLE test (
col INTEGER
)"})
.unwrap()()
.unwrap();
let statement = Statement::prepare(
&connection,
indoc! {"
INSERT INTO test(col) VALUES (?3);
SELECT * FROM test WHERE col = ?1"},
)
.unwrap();
statement
.bind_int(1, 1)
.expect("Could not bind parameter to first index");
statement
.bind_int(2, 2)
.expect("Could not bind parameter to second index");
statement
.bind_int(3, 3)
.expect("Could not bind parameter to third index");
}
#[test]
fn blob_round_trips() {
let connection1 = Connection::open_memory(Some("blob_round_trips"));