diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index bb264d5adb..cc29245697 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -4,20 +4,19 @@ mod tables; #[cfg(test)] pub mod tests; -use crate::{Error, Result, executor::Executor}; +use crate::{Error, Result}; use anyhow::{Context as _, anyhow}; use collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use dashmap::DashMap; use futures::StreamExt; use project_repository_statuses::StatusKind; -use rand::{Rng, SeedableRng, prelude::StdRng}; use rpc::ExtensionProvides; use rpc::{ ConnectionId, ExtensionMetadata, proto::{self}, }; use sea_orm::{ - ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr, + ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement, TransactionTrait, entity::prelude::*, @@ -33,7 +32,6 @@ use std::{ ops::{Deref, DerefMut}, rc::Rc, sync::Arc, - time::Duration, }; use time::PrimitiveDateTime; use tokio::sync::{Mutex, OwnedMutexGuard}; @@ -58,6 +56,7 @@ pub use tables::*; #[cfg(test)] pub struct DatabaseTestOptions { + pub executor: gpui::BackgroundExecutor, pub runtime: tokio::runtime::Runtime, pub query_failure_probability: parking_lot::Mutex, } @@ -69,8 +68,6 @@ pub struct Database { pool: DatabaseConnection, rooms: DashMap>>, projects: DashMap>>, - rng: Mutex, - executor: Executor, notification_kinds_by_id: HashMap, notification_kinds_by_name: HashMap, #[cfg(test)] @@ -81,17 +78,15 @@ pub struct Database { // separate files in the `queries` folder. impl Database { /// Connects to the database with the given options - pub async fn new(options: ConnectOptions, executor: Executor) -> Result { + pub async fn new(options: ConnectOptions) -> Result { sqlx::any::install_default_drivers(); Ok(Self { options: options.clone(), pool: sea_orm::Database::connect(options).await?, rooms: DashMap::with_capacity(16384), projects: DashMap::with_capacity(16384), - rng: Mutex::new(StdRng::seed_from_u64(0)), notification_kinds_by_id: HashMap::default(), notification_kinds_by_name: HashMap::default(), - executor, #[cfg(test)] test_options: None, }) @@ -107,48 +102,13 @@ impl Database { self.projects.clear(); } - /// Transaction runs things in a transaction. If you want to call other methods - /// and pass the transaction around you need to reborrow the transaction at each - /// call site with: `&*tx`. pub async fn transaction(&self, f: F) -> Result where F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, { let body = async { - let mut i = 0; - loop { - let (tx, result) = self.with_transaction(&f).await?; - match result { - Ok(result) => match tx.commit().await.map_err(Into::into) { - Ok(()) => return Ok(result), - Err(error) => { - if !self.retry_on_serialization_error(&error, i).await { - return Err(error); - } - } - }, - Err(error) => { - tx.rollback().await?; - if !self.retry_on_serialization_error(&error, i).await { - return Err(error); - } - } - } - i += 1; - } - }; - - self.run(body).await - } - - pub async fn weak_transaction(&self, f: F) -> Result - where - F: Send + Fn(TransactionHandle) -> Fut, - Fut: Send + Future>, - { - let body = async { - let (tx, result) = self.with_weak_transaction(&f).await?; + let (tx, result) = self.with_transaction(&f).await?; match result { Ok(result) => match tx.commit().await.map_err(Into::into) { Ok(()) => Ok(result), @@ -174,44 +134,28 @@ impl Database { Fut: Send + Future>>, { let body = async { - let mut i = 0; - loop { - let (tx, result) = self.with_transaction(&f).await?; - match result { - Ok(Some((room_id, data))) => { - let lock = self.rooms.entry(room_id).or_default().clone(); - let _guard = lock.lock_owned().await; - match tx.commit().await.map_err(Into::into) { - Ok(()) => { - return Ok(Some(TransactionGuard { - data, - _guard, - _not_send: PhantomData, - })); - } - Err(error) => { - if !self.retry_on_serialization_error(&error, i).await { - return Err(error); - } - } - } - } - Ok(None) => match tx.commit().await.map_err(Into::into) { - Ok(()) => return Ok(None), - Err(error) => { - if !self.retry_on_serialization_error(&error, i).await { - return Err(error); - } - } - }, - Err(error) => { - tx.rollback().await?; - if !self.retry_on_serialization_error(&error, i).await { - return Err(error); - } + let (tx, result) = self.with_transaction(&f).await?; + match result { + Ok(Some((room_id, data))) => { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + match tx.commit().await.map_err(Into::into) { + Ok(()) => Ok(Some(TransactionGuard { + data, + _guard, + _not_send: PhantomData, + })), + Err(error) => Err(error), } } - i += 1; + Ok(None) => match tx.commit().await.map_err(Into::into) { + Ok(()) => Ok(None), + Err(error) => Err(error), + }, + Err(error) => { + tx.rollback().await?; + Err(error) + } } }; @@ -229,38 +173,26 @@ impl Database { { let room_id = Database::room_id_for_project(self, project_id).await?; let body = async { - let mut i = 0; - loop { - let lock = if let Some(room_id) = room_id { - self.rooms.entry(room_id).or_default().clone() - } else { - self.projects.entry(project_id).or_default().clone() - }; - let _guard = lock.lock_owned().await; - let (tx, result) = self.with_transaction(&f).await?; - match result { - Ok(data) => match tx.commit().await.map_err(Into::into) { - Ok(()) => { - return Ok(TransactionGuard { - data, - _guard, - _not_send: PhantomData, - }); - } - Err(error) => { - if !self.retry_on_serialization_error(&error, i).await { - return Err(error); - } - } - }, - Err(error) => { - tx.rollback().await?; - if !self.retry_on_serialization_error(&error, i).await { - return Err(error); - } - } + let lock = if let Some(room_id) = room_id { + self.rooms.entry(room_id).or_default().clone() + } else { + self.projects.entry(project_id).or_default().clone() + }; + let _guard = lock.lock_owned().await; + let (tx, result) = self.with_transaction(&f).await?; + match result { + Ok(data) => match tx.commit().await.map_err(Into::into) { + Ok(()) => Ok(TransactionGuard { + data, + _guard, + _not_send: PhantomData, + }), + Err(error) => Err(error), + }, + Err(error) => { + tx.rollback().await?; + Err(error) } - i += 1; } }; @@ -280,34 +212,22 @@ impl Database { Fut: Send + Future>, { let body = async { - let mut i = 0; - loop { - let lock = self.rooms.entry(room_id).or_default().clone(); - let _guard = lock.lock_owned().await; - let (tx, result) = self.with_transaction(&f).await?; - match result { - Ok(data) => match tx.commit().await.map_err(Into::into) { - Ok(()) => { - return Ok(TransactionGuard { - data, - _guard, - _not_send: PhantomData, - }); - } - Err(error) => { - if !self.retry_on_serialization_error(&error, i).await { - return Err(error); - } - } - }, - Err(error) => { - tx.rollback().await?; - if !self.retry_on_serialization_error(&error, i).await { - return Err(error); - } - } + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + let (tx, result) = self.with_transaction(&f).await?; + match result { + Ok(data) => match tx.commit().await.map_err(Into::into) { + Ok(()) => Ok(TransactionGuard { + data, + _guard, + _not_send: PhantomData, + }), + Err(error) => Err(error), + }, + Err(error) => { + tx.rollback().await?; + Err(error) } - i += 1; } }; @@ -315,28 +235,6 @@ impl Database { } async fn with_transaction(&self, f: &F) -> Result<(DatabaseTransaction, Result)> - where - F: Send + Fn(TransactionHandle) -> Fut, - Fut: Send + Future>, - { - let tx = self - .pool - .begin_with_config(Some(IsolationLevel::Serializable), None) - .await?; - - let mut tx = Arc::new(Some(tx)); - let result = f(TransactionHandle(tx.clone())).await; - let tx = Arc::get_mut(&mut tx) - .and_then(|tx| tx.take()) - .context("couldn't complete transaction because it's still in use")?; - - Ok((tx, result)) - } - - async fn with_weak_transaction( - &self, - f: &F, - ) -> Result<(DatabaseTransaction, Result)> where F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, @@ -361,13 +259,13 @@ impl Database { { #[cfg(test)] { + use rand::prelude::*; + let test_options = self.test_options.as_ref().unwrap(); - if let Executor::Deterministic(executor) = &self.executor { - executor.simulate_random_delay().await; - let fail_probability = *test_options.query_failure_probability.lock(); - if executor.rng().gen_bool(fail_probability) { - return Err(anyhow!("simulated query failure"))?; - } + test_options.executor.simulate_random_delay().await; + let fail_probability = *test_options.query_failure_probability.lock(); + if test_options.executor.rng().gen_bool(fail_probability) { + return Err(anyhow!("simulated query failure"))?; } test_options.runtime.block_on(future) @@ -378,46 +276,6 @@ impl Database { future.await } } - - async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: usize) -> bool { - // If the error is due to a failure to serialize concurrent transactions, then retry - // this transaction after a delay. With each subsequent retry, double the delay duration. - // Also vary the delay randomly in order to ensure different database connections retry - // at different times. - const SLEEPS: [f32; 10] = [10., 20., 40., 80., 160., 320., 640., 1280., 2560., 5120.]; - if is_serialization_error(error) && prev_attempt_count < SLEEPS.len() { - let base_delay = SLEEPS[prev_attempt_count]; - let randomized_delay = base_delay * self.rng.lock().await.gen_range(0.5..=2.0); - log::warn!( - "retrying transaction after serialization error. delay: {} ms.", - randomized_delay - ); - self.executor - .sleep(Duration::from_millis(randomized_delay as u64)) - .await; - true - } else { - false - } - } -} - -fn is_serialization_error(error: &Error) -> bool { - const SERIALIZATION_FAILURE_CODE: &str = "40001"; - match error { - Error::Database( - DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) - | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), - ) if error - .as_database_error() - .and_then(|error| error.code()) - .as_deref() - == Some(SERIALIZATION_FAILURE_CODE) => - { - true - } - _ => false, - } } /// A handle to a [`DatabaseTransaction`]. diff --git a/crates/collab/src/db/queries/billing_customers.rs b/crates/collab/src/db/queries/billing_customers.rs index eaa3edf7c0..ead9e6cd32 100644 --- a/crates/collab/src/db/queries/billing_customers.rs +++ b/crates/collab/src/db/queries/billing_customers.rs @@ -20,7 +20,7 @@ impl Database { &self, params: &CreateBillingCustomerParams, ) -> Result { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let customer = billing_customer::Entity::insert(billing_customer::ActiveModel { user_id: ActiveValue::set(params.user_id), stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()), @@ -40,7 +40,7 @@ impl Database { id: BillingCustomerId, params: &UpdateBillingCustomerParams, ) -> Result<()> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { billing_customer::Entity::update(billing_customer::ActiveModel { id: ActiveValue::set(id), user_id: params.user_id.clone(), @@ -61,7 +61,7 @@ impl Database { &self, id: BillingCustomerId, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { Ok(billing_customer::Entity::find() .filter(billing_customer::Column::Id.eq(id)) .one(&*tx) @@ -75,7 +75,7 @@ impl Database { &self, user_id: UserId, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { Ok(billing_customer::Entity::find() .filter(billing_customer::Column::UserId.eq(user_id)) .one(&*tx) @@ -89,7 +89,7 @@ impl Database { &self, stripe_customer_id: &str, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { Ok(billing_customer::Entity::find() .filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id)) .one(&*tx) diff --git a/crates/collab/src/db/queries/billing_preferences.rs b/crates/collab/src/db/queries/billing_preferences.rs index 55a9dd20a2..1a6fbe946a 100644 --- a/crates/collab/src/db/queries/billing_preferences.rs +++ b/crates/collab/src/db/queries/billing_preferences.rs @@ -22,7 +22,7 @@ impl Database { &self, user_id: UserId, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { Ok(billing_preference::Entity::find() .filter(billing_preference::Column::UserId.eq(user_id)) .one(&*tx) @@ -37,7 +37,7 @@ impl Database { user_id: UserId, params: &CreateBillingPreferencesParams, ) -> Result { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel { user_id: ActiveValue::set(user_id), max_monthly_llm_usage_spending_in_cents: ActiveValue::set( @@ -65,7 +65,7 @@ impl Database { user_id: UserId, params: &UpdateBillingPreferencesParams, ) -> Result { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let preferences = billing_preference::Entity::update_many() .set(billing_preference::ActiveModel { max_monthly_llm_usage_spending_in_cents: params diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index 88b208751f..f25d0abeaa 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -35,7 +35,7 @@ impl Database { &self, params: &CreateBillingSubscriptionParams, ) -> Result { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel { billing_customer_id: ActiveValue::set(params.billing_customer_id), kind: ActiveValue::set(params.kind), @@ -64,7 +64,7 @@ impl Database { id: BillingSubscriptionId, params: &UpdateBillingSubscriptionParams, ) -> Result<()> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { billing_subscription::Entity::update(billing_subscription::ActiveModel { id: ActiveValue::set(id), billing_customer_id: params.billing_customer_id.clone(), @@ -90,7 +90,7 @@ impl Database { &self, id: BillingSubscriptionId, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { Ok(billing_subscription::Entity::find_by_id(id) .one(&*tx) .await?) @@ -103,7 +103,7 @@ impl Database { &self, stripe_subscription_id: &str, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { Ok(billing_subscription::Entity::find() .filter( billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id), @@ -118,7 +118,7 @@ impl Database { &self, user_id: UserId, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { Ok(billing_subscription::Entity::find() .inner_join(billing_customer::Entity) .filter(billing_customer::Column::UserId.eq(user_id)) @@ -152,7 +152,7 @@ impl Database { &self, user_id: UserId, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let subscriptions = billing_subscription::Entity::find() .inner_join(billing_customer::Entity) .filter(billing_customer::Column::UserId.eq(user_id)) @@ -169,7 +169,7 @@ impl Database { &self, user_ids: HashSet, ) -> Result> { - self.weak_transaction(|tx| { + self.transaction(|tx| { let user_ids = user_ids.clone(); async move { let mut rows = billing_subscription::Entity::find() @@ -201,7 +201,7 @@ impl Database { &self, user_ids: HashSet, ) -> Result> { - self.weak_transaction(|tx| { + self.transaction(|tx| { let user_ids = user_ids.clone(); async move { let mut rows = billing_subscription::Entity::find() @@ -236,7 +236,7 @@ impl Database { /// Returns the count of the active billing subscriptions for the user with the specified ID. pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let count = billing_subscription::Entity::find() .inner_join(billing_customer::Entity) .filter( diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index 9a370bb73b..5e296e0a3b 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -501,10 +501,8 @@ impl Database { /// Returns all channels for the user with the given ID. pub async fn get_channels_for_user(&self, user_id: UserId) -> Result { - self.weak_transaction( - |tx| async move { self.get_user_channels(user_id, None, true, &tx).await }, - ) - .await + self.transaction(|tx| async move { self.get_user_channels(user_id, None, true, &tx).await }) + .await } /// Returns all channels for the user with the given ID that are descendants diff --git a/crates/collab/src/db/queries/contacts.rs b/crates/collab/src/db/queries/contacts.rs index e1e063ce23..8521814bdb 100644 --- a/crates/collab/src/db/queries/contacts.rs +++ b/crates/collab/src/db/queries/contacts.rs @@ -15,7 +15,7 @@ impl Database { user_b_busy: bool, } - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let user_a_participant = Alias::new("user_a_participant"); let user_b_participant = Alias::new("user_b_participant"); let mut db_contacts = contact::Entity::find() @@ -91,7 +91,7 @@ impl Database { /// Returns whether the given user is a busy (on a call). pub async fn is_user_busy(&self, user_id: UserId) -> Result { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let participant = room_participant::Entity::find() .filter(room_participant::Column::UserId.eq(user_id)) .one(&*tx) diff --git a/crates/collab/src/db/queries/contributors.rs b/crates/collab/src/db/queries/contributors.rs index 76e5267d91..6f675a5fe7 100644 --- a/crates/collab/src/db/queries/contributors.rs +++ b/crates/collab/src/db/queries/contributors.rs @@ -9,7 +9,7 @@ pub enum ContributorSelector { impl Database { /// Retrieves the GitHub logins of all users who have signed the CLA. pub async fn get_contributors(&self) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryGithubLogin { GithubLogin, @@ -32,7 +32,7 @@ impl Database { &self, selector: &ContributorSelector, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let condition = match selector { ContributorSelector::GitHubUserId { github_user_id } => { user::Column::GithubUserId.eq(*github_user_id) @@ -69,7 +69,7 @@ impl Database { github_user_created_at: DateTimeUtc, initial_channel_id: Option, ) -> Result<()> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let user = self .update_or_create_user_by_github_account_tx( github_login, diff --git a/crates/collab/src/db/queries/embeddings.rs b/crates/collab/src/db/queries/embeddings.rs index d901b59659..6ae8013284 100644 --- a/crates/collab/src/db/queries/embeddings.rs +++ b/crates/collab/src/db/queries/embeddings.rs @@ -8,7 +8,7 @@ impl Database { model: &str, digests: &[Vec], ) -> Result, Vec>> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let embeddings = { let mut db_embeddings = embedding::Entity::find() .filter( @@ -52,7 +52,7 @@ impl Database { model: &str, embeddings: &HashMap, Vec>, ) -> Result<()> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { embedding::Entity::insert_many(embeddings.iter().map(|(digest, dimensions)| { let now_offset_datetime = OffsetDateTime::now_utc(); let retrieved_at = @@ -78,7 +78,7 @@ impl Database { } pub async fn purge_old_embeddings(&self) -> Result<()> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { embedding::Entity::delete_many() .filter( embedding::Column::RetrievedAt diff --git a/crates/collab/src/db/queries/extensions.rs b/crates/collab/src/db/queries/extensions.rs index 90f88179c5..7d8aad2be4 100644 --- a/crates/collab/src/db/queries/extensions.rs +++ b/crates/collab/src/db/queries/extensions.rs @@ -15,7 +15,7 @@ impl Database { max_schema_version: i32, limit: usize, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let mut condition = Condition::all() .add( extension::Column::LatestVersion @@ -43,7 +43,7 @@ impl Database { ids: &[&str], constraints: Option<&ExtensionVersionConstraints>, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let extensions = extension::Entity::find() .filter(extension::Column::ExternalId.is_in(ids.iter().copied())) .all(&*tx) @@ -123,7 +123,7 @@ impl Database { &self, extension_id: &str, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let condition = extension::Column::ExternalId .eq(extension_id) .into_condition(); @@ -162,7 +162,7 @@ impl Database { extension_id: &str, constraints: Option<&ExtensionVersionConstraints>, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let extension = extension::Entity::find() .filter(extension::Column::ExternalId.eq(extension_id)) .one(&*tx) @@ -187,7 +187,7 @@ impl Database { extension_id: &str, version: &str, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let extension = extension::Entity::find() .filter(extension::Column::ExternalId.eq(extension_id)) .filter(extension_version::Column::Version.eq(version)) @@ -204,7 +204,7 @@ impl Database { } pub async fn get_known_extension_versions(&self) -> Result>> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let mut extension_external_ids_by_id = HashMap::default(); let mut rows = extension::Entity::find().stream(&*tx).await?; @@ -242,7 +242,7 @@ impl Database { &self, versions_by_extension_id: &HashMap<&str, Vec>, ) -> Result<()> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { for (external_id, versions) in versions_by_extension_id { if versions.is_empty() { continue; @@ -349,7 +349,7 @@ impl Database { } pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryId { Id, diff --git a/crates/collab/src/db/queries/processed_stripe_events.rs b/crates/collab/src/db/queries/processed_stripe_events.rs index 8e92cff98f..f14ad480e0 100644 --- a/crates/collab/src/db/queries/processed_stripe_events.rs +++ b/crates/collab/src/db/queries/processed_stripe_events.rs @@ -13,7 +13,7 @@ impl Database { &self, params: &CreateProcessedStripeEventParams, ) -> Result<()> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel { stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()), stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()), @@ -35,7 +35,7 @@ impl Database { &self, event_id: &str, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { Ok(processed_stripe_event::Entity::find_by_id(event_id) .one(&*tx) .await?) @@ -48,7 +48,7 @@ impl Database { &self, event_ids: &[&str], ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { Ok(processed_stripe_event::Entity::find() .filter( processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()), diff --git a/crates/collab/src/db/queries/projects.rs b/crates/collab/src/db/queries/projects.rs index ae244b2516..ba22a7b4e3 100644 --- a/crates/collab/src/db/queries/projects.rs +++ b/crates/collab/src/db/queries/projects.rs @@ -112,7 +112,7 @@ impl Database { } pub async fn delete_project(&self, project_id: ProjectId) -> Result<()> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { project::Entity::delete_by_id(project_id).exec(&*tx).await?; Ok(()) }) diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index 33eaa95aa2..cb805786dd 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -80,7 +80,7 @@ impl Database { &self, user_id: UserId, ) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { let pending_participant = room_participant::Entity::find() .filter( room_participant::Column::UserId diff --git a/crates/collab/src/db/queries/users.rs b/crates/collab/src/db/queries/users.rs index 12587c0faf..4b0f66fcbe 100644 --- a/crates/collab/src/db/queries/users.rs +++ b/crates/collab/src/db/queries/users.rs @@ -382,7 +382,7 @@ impl Database { /// Returns the active flags for the user. pub async fn get_user_flags(&self, user: UserId) -> Result> { - self.weak_transaction(|tx| async move { + self.transaction(|tx| async move { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryAs { Flag, diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 2fc00fd13c..9404e2670c 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -17,11 +17,15 @@ use crate::migrations::run_database_migrations; use super::*; use gpui::BackgroundExecutor; use parking_lot::Mutex; +use rand::prelude::*; use sea_orm::ConnectionTrait; use sqlx::migrate::MigrateDatabase; -use std::sync::{ - Arc, - atomic::{AtomicI32, AtomicU32, Ordering::SeqCst}, +use std::{ + sync::{ + Arc, + atomic::{AtomicI32, AtomicU32, Ordering::SeqCst}, + }, + time::Duration, }; pub struct TestDb { @@ -41,9 +45,7 @@ impl TestDb { let mut db = runtime.block_on(async { let mut options = ConnectOptions::new(url); options.max_connections(5); - let mut db = Database::new(options, Executor::Deterministic(executor.clone())) - .await - .unwrap(); + let mut db = Database::new(options).await.unwrap(); let sql = include_str!(concat!( env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite/20221109000000_test_schema.sql" @@ -60,6 +62,7 @@ impl TestDb { }); db.test_options = Some(DatabaseTestOptions { + executor, runtime, query_failure_probability: parking_lot::Mutex::new(0.0), }); @@ -93,9 +96,7 @@ impl TestDb { options .max_connections(5) .idle_timeout(Duration::from_secs(0)); - let mut db = Database::new(options, Executor::Deterministic(executor.clone())) - .await - .unwrap(); + let mut db = Database::new(options).await.unwrap(); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); run_database_migrations(db.options(), migrations_path) .await @@ -105,6 +106,7 @@ impl TestDb { }); db.test_options = Some(DatabaseTestOptions { + executor, runtime, query_failure_probability: parking_lot::Mutex::new(0.0), }); diff --git a/crates/collab/src/db/tests/embedding_tests.rs b/crates/collab/src/db/tests/embedding_tests.rs index bfc238dd9a..367e89f87b 100644 --- a/crates/collab/src/db/tests/embedding_tests.rs +++ b/crates/collab/src/db/tests/embedding_tests.rs @@ -49,7 +49,7 @@ async fn test_purge_old_embeddings(cx: &mut gpui::TestAppContext) { db.save_embeddings(model, &embeddings).await.unwrap(); // Reach into the DB and change the retrieved at to be > 60 days - db.weak_transaction(|tx| { + db.transaction(|tx| { let digest = digest.clone(); async move { let sixty_days_ago = OffsetDateTime::now_utc().sub(Duration::days(61)); diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 95922f411c..2b20c8f080 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -285,7 +285,7 @@ impl AppState { pub async fn new(config: Config, executor: Executor) -> Result> { let mut db_options = db::ConnectOptions::new(config.database_url.clone()); db_options.max_connections(config.database_max_connections); - let mut db = Database::new(db_options, Executor::Production).await?; + let mut db = Database::new(db_options).await?; db.initialize_notification_kinds().await?; let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 6bdff74938..6a78049b3f 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -59,7 +59,7 @@ async fn main() -> Result<()> { let config = envy::from_env::().expect("error loading config"); let db_options = db::ConnectOptions::new(config.database_url.clone()); - let mut db = Database::new(db_options, Executor::Production).await?; + let mut db = Database::new(db_options).await?; db.initialize_notification_kinds().await?; collab::seed::seed(&config, &db, false).await?; @@ -253,7 +253,7 @@ async fn main() -> Result<()> { async fn setup_app_database(config: &Config) -> Result<()> { let db_options = db::ConnectOptions::new(config.database_url.clone()); - let mut db = Database::new(db_options, Executor::Production).await?; + let mut db = Database::new(db_options).await?; let migrations_path = config.migrations_path.as_deref().unwrap_or_else(|| { #[cfg(feature = "sqlite")]