Added sql! proc macro which checks syntax errors on sql code and displays them with reasonable underline locations

Co-Authored-By: Mikayla Maki <mikayla@zed.dev>
This commit is contained in:
Kay Simmons 2022-11-28 17:42:18 -08:00 committed by Mikayla Maki
parent 260164a711
commit dd9d20be25
15 changed files with 342 additions and 211 deletions

View file

@ -2,6 +2,7 @@ use std::{
ffi::{CStr, CString},
marker::PhantomData,
path::Path,
ptr,
};
use anyhow::{anyhow, Result};
@ -85,6 +86,45 @@ impl Connection {
self.backup_main(&destination)
}
pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
let sql = CString::new(sql).unwrap();
let mut remaining_sql = sql.as_c_str();
let sql_start = remaining_sql.as_ptr();
unsafe {
while {
let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
remaining_sql_str != ";" && !remaining_sql_str.is_empty()
} {
let mut raw_statement = 0 as *mut sqlite3_stmt;
let mut remaining_sql_ptr = ptr::null();
sqlite3_prepare_v2(
self.sqlite3,
remaining_sql.as_ptr(),
-1,
&mut raw_statement,
&mut remaining_sql_ptr,
);
let res = sqlite3_errcode(self.sqlite3);
let offset = sqlite3_error_offset(self.sqlite3);
if res == 1 && offset >= 0 {
let message = sqlite3_errmsg(self.sqlite3);
let err_msg =
String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
.into_owned();
let sub_statement_correction =
remaining_sql.as_ptr() as usize - sql_start as usize;
return Some((err_msg, offset as usize + sub_statement_correction));
}
remaining_sql = CStr::from_ptr(remaining_sql_ptr);
}
}
None
}
pub(crate) fn last_error(&self) -> Result<()> {
unsafe {
let code = sqlite3_errcode(self.sqlite3);
@ -259,10 +299,31 @@ mod test {
assert_eq!(
connection
.select_row::<usize>("SELECt * FROM test")
.select_row::<usize>("SELECT * FROM test")
.unwrap()()
.unwrap(),
Some(2)
);
}
#[test]
fn test_sql_has_syntax_errors() {
let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
let first_stmt =
"CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
let second_stmt = "SELECT FROM";
let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;
let res = connection
.sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
.map(|(_, offset)| offset);
assert_eq!(
res,
Some(first_stmt.len() + second_offset + 1) // TODO: This value is wrong!
);
panic!("{:?}", res)
}
}