Added sql! proc macro which checks syntax errors on sql code and displays them with reasonable underline locations
Co-Authored-By: Mikayla Maki <mikayla@zed.dev>
This commit is contained in:
parent
260164a711
commit
dd9d20be25
15 changed files with 342 additions and 211 deletions
|
@ -2,6 +2,7 @@ use std::{
|
|||
ffi::{CStr, CString},
|
||||
marker::PhantomData,
|
||||
path::Path,
|
||||
ptr,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
|
@ -85,6 +86,45 @@ impl Connection {
|
|||
self.backup_main(&destination)
|
||||
}
|
||||
|
||||
pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
|
||||
let sql = CString::new(sql).unwrap();
|
||||
let mut remaining_sql = sql.as_c_str();
|
||||
let sql_start = remaining_sql.as_ptr();
|
||||
|
||||
unsafe {
|
||||
while {
|
||||
let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
|
||||
remaining_sql_str != ";" && !remaining_sql_str.is_empty()
|
||||
} {
|
||||
let mut raw_statement = 0 as *mut sqlite3_stmt;
|
||||
let mut remaining_sql_ptr = ptr::null();
|
||||
sqlite3_prepare_v2(
|
||||
self.sqlite3,
|
||||
remaining_sql.as_ptr(),
|
||||
-1,
|
||||
&mut raw_statement,
|
||||
&mut remaining_sql_ptr,
|
||||
);
|
||||
|
||||
let res = sqlite3_errcode(self.sqlite3);
|
||||
let offset = sqlite3_error_offset(self.sqlite3);
|
||||
|
||||
if res == 1 && offset >= 0 {
|
||||
let message = sqlite3_errmsg(self.sqlite3);
|
||||
let err_msg =
|
||||
String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
|
||||
.into_owned();
|
||||
let sub_statement_correction =
|
||||
remaining_sql.as_ptr() as usize - sql_start as usize;
|
||||
|
||||
return Some((err_msg, offset as usize + sub_statement_correction));
|
||||
}
|
||||
remaining_sql = CStr::from_ptr(remaining_sql_ptr);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn last_error(&self) -> Result<()> {
|
||||
unsafe {
|
||||
let code = sqlite3_errcode(self.sqlite3);
|
||||
|
@ -259,10 +299,31 @@ mod test {
|
|||
|
||||
assert_eq!(
|
||||
connection
|
||||
.select_row::<usize>("SELECt * FROM test")
|
||||
.select_row::<usize>("SELECT * FROM test")
|
||||
.unwrap()()
|
||||
.unwrap(),
|
||||
Some(2)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sql_has_syntax_errors() {
|
||||
let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
|
||||
let first_stmt =
|
||||
"CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
|
||||
let second_stmt = "SELECT FROM";
|
||||
|
||||
let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;
|
||||
|
||||
let res = connection
|
||||
.sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
|
||||
.map(|(_, offset)| offset);
|
||||
|
||||
assert_eq!(
|
||||
res,
|
||||
Some(first_stmt.len() + second_offset + 1) // TODO: This value is wrong!
|
||||
);
|
||||
|
||||
panic!("{:?}", res)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,12 @@ pub trait Migrator {
|
|||
fn migrate(connection: &Connection) -> anyhow::Result<()>;
|
||||
}
|
||||
|
||||
impl Migrator for () {
|
||||
fn migrate(_connection: &Connection) -> anyhow::Result<()> {
|
||||
Ok(()) // Do nothing
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Domain> Migrator for D {
|
||||
fn migrate(connection: &Connection) -> anyhow::Result<()> {
|
||||
connection.migrate(Self::name(), Self::migrations())
|
||||
|
|
|
@ -489,76 +489,3 @@ mod test {
|
|||
);
|
||||
}
|
||||
}
|
||||
|
||||
mod syntax_check {
|
||||
use std::{
|
||||
ffi::{CStr, CString},
|
||||
ptr,
|
||||
};
|
||||
|
||||
use libsqlite3_sys::{
|
||||
sqlite3_close, sqlite3_errmsg, sqlite3_error_offset, sqlite3_extended_errcode,
|
||||
sqlite3_extended_result_codes, sqlite3_finalize, sqlite3_open_v2, sqlite3_prepare_v2,
|
||||
sqlite3_stmt, SQLITE_OPEN_CREATE, SQLITE_OPEN_NOMUTEX, SQLITE_OPEN_READWRITE,
|
||||
};
|
||||
|
||||
fn syntax_errors(sql: &str) -> Option<(String, i32)> {
|
||||
let mut sqlite3 = 0 as *mut _;
|
||||
let mut raw_statement = 0 as *mut sqlite3_stmt;
|
||||
|
||||
let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
|
||||
unsafe {
|
||||
let memory_str = CString::new(":memory:").unwrap();
|
||||
sqlite3_open_v2(memory_str.as_ptr(), &mut sqlite3, flags, 0 as *const _);
|
||||
|
||||
let sql = CString::new(sql).unwrap();
|
||||
|
||||
// Turn on extended error codes
|
||||
sqlite3_extended_result_codes(sqlite3, 1);
|
||||
|
||||
sqlite3_prepare_v2(
|
||||
sqlite3,
|
||||
sql.as_c_str().as_ptr(),
|
||||
-1,
|
||||
&mut raw_statement,
|
||||
&mut ptr::null(),
|
||||
);
|
||||
|
||||
let res = sqlite3_extended_errcode(sqlite3);
|
||||
let offset = sqlite3_error_offset(sqlite3);
|
||||
|
||||
if res == 1 && offset != -1 {
|
||||
let message = sqlite3_errmsg(sqlite3);
|
||||
let err_msg =
|
||||
String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
|
||||
.into_owned();
|
||||
|
||||
sqlite3_finalize(*&mut raw_statement);
|
||||
sqlite3_close(sqlite3);
|
||||
|
||||
return Some((err_msg, offset));
|
||||
} else {
|
||||
sqlite3_finalize(*&mut raw_statement);
|
||||
sqlite3_close(sqlite3);
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::syntax_errors;
|
||||
|
||||
#[test]
|
||||
fn test_check_syntax() {
|
||||
assert!(syntax_errors("SELECT FROM").is_some());
|
||||
|
||||
assert!(syntax_errors("SELECT col FROM table_t;").is_none());
|
||||
|
||||
assert!(syntax_errors("CREATE TABLE t(col TEXT,) STRICT;").is_some());
|
||||
|
||||
assert!(syntax_errors("CREATE TABLE t(col TEXT) STRICT;").is_none());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ lazy_static! {
|
|||
Default::default();
|
||||
}
|
||||
|
||||
pub struct ThreadSafeConnection<M: Migrator> {
|
||||
pub struct ThreadSafeConnection<M: Migrator = ()> {
|
||||
uri: Arc<str>,
|
||||
persistent: bool,
|
||||
initialize_query: Option<&'static str>,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue