diff --git a/crates/db/src/db.rs b/crates/db/src/db.rs index bde69fead7..b3370db753 100644 --- a/crates/db/src/db.rs +++ b/crates/db/src/db.rs @@ -6,17 +6,11 @@ pub use indoc::indoc; pub use lazy_static; pub use sqlez; -#[cfg(any(test, feature = "test-support"))] -use anyhow::Result; -#[cfg(any(test, feature = "test-support"))] -use sqlez::connection::Connection; -#[cfg(any(test, feature = "test-support"))] -use sqlez::domain::Domain; - use sqlez::domain::Migrator; use sqlez::thread_safe_connection::ThreadSafeConnection; use std::fs::{create_dir_all, remove_dir_all}; use std::path::Path; +use std::sync::atomic::{AtomicBool, Ordering}; use util::channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}; use util::paths::DB_DIR; @@ -28,13 +22,21 @@ const INITIALIZE_QUERY: &'static str = indoc! {" PRAGMA case_sensitive_like=TRUE; "}; +lazy_static::lazy_static! { + static ref DB_WIPED: AtomicBool = AtomicBool::new(false); +} + /// Open or create a database at the given directory path. pub fn open_file_db() -> ThreadSafeConnection { // Use 0 for now. Will implement incrementing and clearing of old db files soon TM let current_db_dir = (*DB_DIR).join(Path::new(&format!("0-{}", *RELEASE_CHANNEL_NAME))); - if *RELEASE_CHANNEL == ReleaseChannel::Dev && std::env::var("WIPE_DB").is_ok() { + if *RELEASE_CHANNEL == ReleaseChannel::Dev + && std::env::var("WIPE_DB").is_ok() + && !DB_WIPED.load(Ordering::Acquire) + { remove_dir_all(¤t_db_dir).ok(); + DB_WIPED.store(true, Ordering::Relaxed); } create_dir_all(¤t_db_dir).expect("Should be able to create the database directory"); @@ -48,15 +50,6 @@ pub fn open_memory_db(db_name: Option<&str>) -> ThreadSafeConnectio ThreadSafeConnection::new(db_name, false).with_initialize_query(INITIALIZE_QUERY) } -#[cfg(any(test, feature = "test-support"))] -pub fn write_db_to>( - conn: &ThreadSafeConnection, - dest: P, -) -> Result<()> { - let destination = Connection::open_file(dest.as_ref().to_string_lossy().as_ref()); - conn.backup_main(&destination) -} - /// Implements a basic DB wrapper for a given domain #[macro_export] macro_rules! connection { @@ -155,11 +148,11 @@ macro_rules! sql_method { } }; - ($id:ident() -> Result<$return_type:ty>>: $sql:expr) => { + ($id:ident() -> Result<$return_type:ty>: $sql:expr) => { pub fn $id(&self) -> $crate::sqlez::anyhow::Result<$return_type> { use $crate::anyhow::Context; - self.select_row::<$return_type>($sql)?(($($arg),+)) + self.select_row::<$return_type>($sql)?() .context(::std::format!( "Error in {}, select_row_bound failed to execute or parse for: {}", ::std::stringify!($id), @@ -172,7 +165,7 @@ macro_rules! sql_method { )) } }; - ($id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty>>: $sql:expr) => { + ($id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty>: $sql:expr) => { pub fn $id(&self, $($arg: $arg_type),+) -> $crate::sqlez::anyhow::Result<$return_type> { use $crate::anyhow::Context; diff --git a/crates/editor/src/persistence.rs b/crates/editor/src/persistence.rs index 5747558700..a77eec7fd1 100644 --- a/crates/editor/src/persistence.rs +++ b/crates/editor/src/persistence.rs @@ -32,7 +32,7 @@ impl Domain for Editor { impl EditorDb { sql_method! { - get_path(item_id: ItemId, workspace_id: WorkspaceId) -> Result>: + get_path(item_id: ItemId, workspace_id: WorkspaceId) -> Result: indoc! {" SELECT path FROM editors WHERE item_id = ? AND workspace_id = ?"} diff --git a/crates/sqlez/src/bindable.rs b/crates/sqlez/src/bindable.rs index 18c4acedad..51f67dd03f 100644 --- a/crates/sqlez/src/bindable.rs +++ b/crates/sqlez/src/bindable.rs @@ -5,7 +5,7 @@ use std::{ sync::Arc, }; -use anyhow::Result; +use anyhow::{Context, Result}; use crate::statement::{SqlType, Statement}; @@ -19,61 +19,82 @@ pub trait Column: Sized { impl Bind for bool { fn bind(&self, statement: &Statement, start_index: i32) -> Result { - statement.bind(self.then_some(1).unwrap_or(0), start_index) + statement + .bind(self.then_some(1).unwrap_or(0), start_index) + .with_context(|| format!("Failed to bind bool at index {start_index}")) } } impl Column for bool { fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { - i32::column(statement, start_index).map(|(i, next_index)| (i != 0, next_index)) + i32::column(statement, start_index) + .map(|(i, next_index)| (i != 0, next_index)) + .with_context(|| format!("Failed to read bool at index {start_index}")) } } impl Bind for &[u8] { fn bind(&self, statement: &Statement, start_index: i32) -> Result { - statement.bind_blob(start_index, self)?; + statement + .bind_blob(start_index, self) + .with_context(|| format!("Failed to bind &[u8] at index {start_index}"))?; Ok(start_index + 1) } } impl Bind for &[u8; C] { fn bind(&self, statement: &Statement, start_index: i32) -> Result { - statement.bind_blob(start_index, self.as_slice())?; + statement + .bind_blob(start_index, self.as_slice()) + .with_context(|| format!("Failed to bind &[u8; C] at index {start_index}"))?; Ok(start_index + 1) } } impl Bind for Vec { fn bind(&self, statement: &Statement, start_index: i32) -> Result { - statement.bind_blob(start_index, self)?; + statement + .bind_blob(start_index, self) + .with_context(|| format!("Failed to bind Vec at index {start_index}"))?; Ok(start_index + 1) } } impl Column for Vec { fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { - let result = statement.column_blob(start_index)?; + let result = statement + .column_blob(start_index) + .with_context(|| format!("Failed to read Vec at index {start_index}"))?; + Ok((Vec::from(result), start_index + 1)) } } impl Bind for f64 { fn bind(&self, statement: &Statement, start_index: i32) -> Result { - statement.bind_double(start_index, *self)?; + statement + .bind_double(start_index, *self) + .with_context(|| format!("Failed to bind f64 at index {start_index}"))?; Ok(start_index + 1) } } impl Column for f64 { fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { - let result = statement.column_double(start_index)?; + let result = statement + .column_double(start_index) + .with_context(|| format!("Failed to parse f64 at index {start_index}"))?; + Ok((result, start_index + 1)) } } impl Bind for i32 { fn bind(&self, statement: &Statement, start_index: i32) -> Result { - statement.bind_int(start_index, *self)?; + statement + .bind_int(start_index, *self) + .with_context(|| format!("Failed to bind i32 at index {start_index}"))?; + Ok(start_index + 1) } } @@ -87,7 +108,9 @@ impl Column for i32 { impl Bind for i64 { fn bind(&self, statement: &Statement, start_index: i32) -> Result { - statement.bind_int64(start_index, *self)?; + statement + .bind_int64(start_index, *self) + .with_context(|| format!("Failed to bind i64 at index {start_index}"))?; Ok(start_index + 1) } } @@ -101,7 +124,9 @@ impl Column for i64 { impl Bind for usize { fn bind(&self, statement: &Statement, start_index: i32) -> Result { - (*self as i64).bind(statement, start_index) + (*self as i64) + .bind(statement, start_index) + .with_context(|| format!("Failed to bind usize at index {start_index}")) } } diff --git a/crates/sqlez/src/connection.rs b/crates/sqlez/src/connection.rs index 1eaeb090e1..5a71cefb52 100644 --- a/crates/sqlez/src/connection.rs +++ b/crates/sqlez/src/connection.rs @@ -1,6 +1,7 @@ use std::{ ffi::{CStr, CString}, marker::PhantomData, + path::Path, }; use anyhow::{anyhow, Result}; @@ -73,6 +74,11 @@ impl Connection { } } + pub fn backup_main_to(&self, destination: impl AsRef) -> Result<()> { + let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref()); + self.backup_main(&destination) + } + pub(crate) fn last_error(&self) -> Result<()> { unsafe { let code = sqlite3_errcode(self.sqlite3); diff --git a/crates/sqlez/src/statement.rs b/crates/sqlez/src/statement.rs index 164929010b..0a7305c6ed 100644 --- a/crates/sqlez/src/statement.rs +++ b/crates/sqlez/src/statement.rs @@ -19,8 +19,6 @@ pub struct Statement<'a> { pub enum StepResult { Row, Done, - Misuse, - Other(i32), } #[derive(Clone, Copy, PartialEq, Eq, Debug)] @@ -40,12 +38,14 @@ impl<'a> Statement<'a> { connection, phantom: PhantomData, }; - unsafe { - let sql = CString::new(query.as_ref())?; + let sql = CString::new(query.as_ref()).context("Error creating cstr")?; let mut remaining_sql = sql.as_c_str(); while { - let remaining_sql_str = remaining_sql.to_str()?.trim(); + let remaining_sql_str = remaining_sql + .to_str() + .context("Parsing remaining sql")? + .trim(); remaining_sql_str != ";" && !remaining_sql_str.is_empty() } { let mut raw_statement = 0 as *mut sqlite3_stmt; @@ -92,116 +92,136 @@ impl<'a> Statement<'a> { } } + fn bind_index_with(&self, index: i32, bind: impl Fn(&*mut sqlite3_stmt) -> ()) -> Result<()> { + let mut any_succeed = false; + unsafe { + for raw_statement in self.raw_statements.iter() { + if index <= sqlite3_bind_parameter_count(*raw_statement) { + bind(raw_statement); + self.connection + .last_error() + .with_context(|| format!("Failed to bind value at index {index}"))?; + any_succeed = true; + } else { + continue; + } + } + } + if any_succeed { + Ok(()) + } else { + Err(anyhow!("Failed to bind parameters")) + } + } + pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> { let index = index as c_int; let blob_pointer = blob.as_ptr() as *const _; let len = blob.len() as c_int; - unsafe { - for raw_statement in self.raw_statements.iter() { - sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT()); - } - } - self.connection.last_error() + + self.bind_index_with(index, |raw_statement| unsafe { + sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT()); + }) } pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> { let index = index as c_int; let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) }; - self.connection.last_error()?; + self.connection + .last_error() + .with_context(|| format!("Failed to read blob at index {index}"))?; if pointer.is_null() { return Ok(&[]); } let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize }; - self.connection.last_error()?; + self.connection + .last_error() + .with_context(|| format!("Failed to read length of blob at index {index}"))?; + unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) } } pub fn bind_double(&self, index: i32, double: f64) -> Result<()> { let index = index as c_int; - unsafe { - for raw_statement in self.raw_statements.iter() { - sqlite3_bind_double(*raw_statement, index, double); - } - } - self.connection.last_error() + self.bind_index_with(index, |raw_statement| unsafe { + sqlite3_bind_double(*raw_statement, index, double); + }) } pub fn column_double(&self, index: i32) -> Result { let index = index as c_int; let result = unsafe { sqlite3_column_double(self.current_statement(), index) }; - self.connection.last_error()?; + self.connection + .last_error() + .with_context(|| format!("Failed to read double at index {index}"))?; Ok(result) } pub fn bind_int(&self, index: i32, int: i32) -> Result<()> { let index = index as c_int; - - unsafe { - for raw_statement in self.raw_statements.iter() { - sqlite3_bind_int(*raw_statement, index, int); - } - }; - self.connection.last_error() + self.bind_index_with(index, |raw_statement| unsafe { + sqlite3_bind_int(*raw_statement, index, int); + }) } pub fn column_int(&self, index: i32) -> Result { let index = index as c_int; let result = unsafe { sqlite3_column_int(self.current_statement(), index) }; - self.connection.last_error()?; + self.connection + .last_error() + .with_context(|| format!("Failed to read int at index {index}"))?; Ok(result) } pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> { let index = index as c_int; - unsafe { - for raw_statement in self.raw_statements.iter() { - sqlite3_bind_int64(*raw_statement, index, int); - } - } - self.connection.last_error() + self.bind_index_with(index, |raw_statement| unsafe { + sqlite3_bind_int64(*raw_statement, index, int); + }) } pub fn column_int64(&self, index: i32) -> Result { let index = index as c_int; let result = unsafe { sqlite3_column_int64(self.current_statement(), index) }; - self.connection.last_error()?; + self.connection + .last_error() + .with_context(|| format!("Failed to read i64 at index {index}"))?; Ok(result) } pub fn bind_null(&self, index: i32) -> Result<()> { let index = index as c_int; - unsafe { - for raw_statement in self.raw_statements.iter() { - sqlite3_bind_null(*raw_statement, index); - } - } - self.connection.last_error() + self.bind_index_with(index, |raw_statement| unsafe { + sqlite3_bind_null(*raw_statement, index); + }) } pub fn bind_text(&self, index: i32, text: &str) -> Result<()> { let index = index as c_int; let text_pointer = text.as_ptr() as *const _; let len = text.len() as c_int; - unsafe { - for raw_statement in self.raw_statements.iter() { - sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT()); - } - } - self.connection.last_error() + + self.bind_index_with(index, |raw_statement| unsafe { + sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT()); + }) } pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> { let index = index as c_int; let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) }; - self.connection.last_error()?; + self.connection + .last_error() + .with_context(|| format!("Failed to read text from column {index}"))?; if pointer.is_null() { return Ok(""); } let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize }; - self.connection.last_error()?; + self.connection + .last_error() + .with_context(|| format!("Failed to read text length at {index}"))?; let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) }; Ok(str::from_utf8(slice)?) @@ -247,11 +267,11 @@ impl<'a> Statement<'a> { self.step() } } - SQLITE_MISUSE => Ok(StepResult::Misuse), - other => self - .connection - .last_error() - .map(|_| StepResult::Other(other)), + SQLITE_MISUSE => Err(anyhow!("Statement step returned SQLITE_MISUSE")), + _other_error => { + self.connection.last_error()?; + unreachable!("Step returned error code and last error failed to catch it"); + } } } } @@ -293,11 +313,17 @@ impl<'a> Statement<'a> { callback: impl FnOnce(&mut Statement) -> Result, ) -> Result { if this.step()? != StepResult::Row { + return Err(anyhow!("single called with query that returns no rows.")); + } + let result = callback(this)?; + + if this.step()? != StepResult::Done { return Err(anyhow!( - "Single(Map) called with query that returns no rows." + "single called with a query that returns more than one row." )); } - callback(this) + + Ok(result) } let result = logic(self, callback); self.reset(); @@ -316,10 +342,21 @@ impl<'a> Statement<'a> { this: &mut Statement, callback: impl FnOnce(&mut Statement) -> Result, ) -> Result> { - if this.step()? != StepResult::Row { + if this.step().context("Failed on step call")? != StepResult::Row { return Ok(None); } - callback(this).map(|r| Some(r)) + + let result = callback(this) + .map(|r| Some(r)) + .context("Failed to parse row result")?; + + if this.step().context("Second step call")? != StepResult::Done { + return Err(anyhow!( + "maybe called with a query that returns more than one row." + )); + } + + Ok(result) } let result = logic(self, callback); self.reset(); @@ -350,6 +387,38 @@ mod test { statement::{Statement, StepResult}, }; + #[test] + fn binding_multiple_statements_with_parameter_gaps() { + let connection = + Connection::open_memory(Some("binding_multiple_statements_with_parameter_gaps")); + + connection + .exec(indoc! {" + CREATE TABLE test ( + col INTEGER + )"}) + .unwrap()() + .unwrap(); + + let statement = Statement::prepare( + &connection, + indoc! {" + INSERT INTO test(col) VALUES (?3); + SELECT * FROM test WHERE col = ?1"}, + ) + .unwrap(); + + statement + .bind_int(1, 1) + .expect("Could not bind parameter to first index"); + statement + .bind_int(2, 2) + .expect("Could not bind parameter to second index"); + statement + .bind_int(3, 3) + .expect("Could not bind parameter to third index"); + } + #[test] fn blob_round_trips() { let connection1 = Connection::open_memory(Some("blob_round_trips")); diff --git a/crates/sqlez/src/typed_statements.rs b/crates/sqlez/src/typed_statements.rs index 98f51b970a..c7d8b20aa5 100644 --- a/crates/sqlez/src/typed_statements.rs +++ b/crates/sqlez/src/typed_statements.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{Context, Result}; use crate::{ bindable::{Bind, Column}, @@ -49,6 +49,12 @@ impl Connection { query: &str, ) -> Result Result>> { let mut statement = Statement::prepare(&self, query)?; - Ok(move |bindings| statement.with_bindings(bindings)?.maybe_row::()) + Ok(move |bindings| { + statement + .with_bindings(bindings) + .context("Bindings failed")? + .maybe_row::() + .context("Maybe row failed") + }) } } diff --git a/crates/terminal/src/persistence.rs b/crates/terminal/src/persistence.rs index 384dcc18e0..07bca0c66f 100644 --- a/crates/terminal/src/persistence.rs +++ b/crates/terminal/src/persistence.rs @@ -29,15 +29,21 @@ impl Domain for Terminal { impl TerminalDb { sql_method! { - save_working_directory(item_id: ItemId, workspace_id: WorkspaceId, working_directory: &Path) -> Result<()>: - "INSERT OR REPLACE INTO terminals(item_id, workspace_id, working_directory) - VALUES (?1, ?2, ?3)" + save_working_directory(item_id: ItemId, + workspace_id: WorkspaceId, + working_directory: &Path) -> Result<()>: + indoc!{" + INSERT OR REPLACE INTO terminals(item_id, workspace_id, working_directory) + VALUES (?1, ?2, ?3) + "} } sql_method! { get_working_directory(item_id: ItemId, workspace_id: WorkspaceId) -> Result>: - "SELECT working_directory - FROM terminals - WHERE item_id = ? AND workspace_id = ?" + indoc!{" + SELECT working_directory + FROM terminals + WHERE item_id = ? AND workspace_id = ? + "} } } diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index a4073d27d3..477e5a4960 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -152,7 +152,7 @@ impl WorkspaceDb { "})?((&workspace.location, workspace.id)) .context("clearing out old locations")?; - // Update or insert + // Upsert self.exec_bound(indoc! { "INSERT INTO workspaces(workspace_id, workspace_location, dock_visible, dock_anchor, timestamp) @@ -190,8 +190,8 @@ impl WorkspaceDb { .log_err(); } - sql_method! { - next_id() -> Result>: + sql_method!{ + next_id() -> Result: "INSERT INTO workspaces DEFAULT VALUES RETURNING workspace_id" } @@ -402,6 +402,10 @@ mod tests { .unwrap(); let id = db.next_id().unwrap(); + // Assert the empty row got inserted + assert_eq!(Some(id), db.select_row_bound:: + ("SELECT workspace_id FROM workspaces WHERE workspace_id = ?").unwrap() + (id).unwrap()); db.exec_bound("INSERT INTO test_table(text, workspace_id) VALUES (?, ?)") .unwrap()(("test-text-1", id)) diff --git a/crates/workspace/src/persistence/model.rs b/crates/workspace/src/persistence/model.rs index 111a6904c6..2f0bc050d2 100644 --- a/crates/workspace/src/persistence/model.rs +++ b/crates/workspace/src/persistence/model.rs @@ -3,7 +3,7 @@ use std::{ sync::Arc, }; -use anyhow::Result; +use anyhow::{Context, Result}; use async_recursion::async_recursion; use gpui::{AsyncAppContext, Axis, ModelHandle, Task, ViewHandle}; @@ -52,7 +52,7 @@ impl Column for WorkspaceLocation { fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { let blob = statement.column_blob(start_index)?; Ok(( - WorkspaceLocation(bincode::deserialize(blob)?), + WorkspaceLocation(bincode::deserialize(blob).context("Bincode failed")?), start_index + 1, )) } diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 0a4a6c8740..155c95e4e8 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -633,11 +633,11 @@ impl Workspace { active_call = Some((call, subscriptions)); } - let id = if let Some(id) = serialized_workspace.as_ref().map(|ws| ws.id) { - id - } else { - DB.next_id().log_err().flatten().unwrap_or(0) - }; + let database_id = serialized_workspace + .as_ref() + .map(|ws| ws.id) + .or_else(|| DB.next_id().log_err()) + .unwrap_or(0); let mut this = Workspace { modal: None, @@ -666,7 +666,7 @@ impl Workspace { last_leaders_by_pane: Default::default(), window_edited: false, active_call, - database_id: id, + database_id, _observe_current_user, }; this.project_remote_id_changed(project.read(cx).remote_id(), cx); diff --git a/dest-term.db b/dest-term.db new file mode 100644 index 0000000000..d6115b0670 Binary files /dev/null and b/dest-term.db differ diff --git a/dest-workspace.db b/dest-workspace.db new file mode 100644 index 0000000000..90682f8642 Binary files /dev/null and b/dest-workspace.db differ diff --git a/dest.db b/dest.db new file mode 100644 index 0000000000..e378341661 Binary files /dev/null and b/dest.db differ