Start moving towards using sea-query to construct queries

This commit is contained in:
Antonio Scandurra 2022-11-29 13:55:08 +01:00
parent d525cfd697
commit ac24600a40
7 changed files with 168 additions and 63 deletions

34
Cargo.lock generated
View file

@ -1065,6 +1065,8 @@ dependencies = [
"reqwest", "reqwest",
"rpc", "rpc",
"scrypt", "scrypt",
"sea-query",
"sea-query-binder",
"serde", "serde",
"serde_json", "serde_json",
"settings", "settings",
@ -5121,6 +5123,38 @@ dependencies = [
"untrusted", "untrusted",
] ]
[[package]]
name = "sea-query"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4f0fc4d8e44e1d51c739a68d336252a18bc59553778075d5e32649be6ec92ed"
dependencies = [
"sea-query-derive",
]
[[package]]
name = "sea-query-binder"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c2585b89c985cfacfe0ec9fc9e7bb055b776c1a2581c4e3c6185af2b8bf8865"
dependencies = [
"sea-query",
"sqlx",
]
[[package]]
name = "sea-query-derive"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34cdc022b4f606353fe5dc85b09713a04e433323b70163e81513b141c6ae6eb5"
dependencies = [
"heck 0.3.3",
"proc-macro2",
"quote",
"syn",
"thiserror",
]
[[package]] [[package]]
name = "seahash" name = "seahash"
version = "4.1.0" version = "4.1.0"

View file

@ -67,6 +67,7 @@ rand = { version = "0.8" }
[patch.crates-io] [patch.crates-io]
tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "366210ae925d7ea0891bc7a0c738f60c77c04d7b" } tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "366210ae925d7ea0891bc7a0c738f60c77c04d7b" }
async-task = { git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e" } async-task = { git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e" }
sqlx = { git = "https://github.com/launchbadge/sqlx", rev = "4b7053807c705df312bcb9b6281e184bf7534eb3" }
# TODO - Remove when a version is released with this PR: https://github.com/servo/core-foundation-rs/pull/457 # TODO - Remove when a version is released with this PR: https://github.com/servo/core-foundation-rs/pull/457
cocoa = { git = "https://github.com/servo/core-foundation-rs", rev = "079665882507dd5e2ff77db3de5070c1f6c0fb85" } cocoa = { git = "https://github.com/servo/core-foundation-rs", rev = "079665882507dd5e2ff77db3de5070c1f6c0fb85" }

View file

@ -36,9 +36,12 @@ prometheus = "0.13"
rand = "0.8" rand = "0.8"
reqwest = { version = "0.11", features = ["json"], optional = true } reqwest = { version = "0.11", features = ["json"], optional = true }
scrypt = "0.7" scrypt = "0.7"
sea-query = { version = "0.27", features = ["derive"] }
sea-query-binder = { version = "0.2", features = ["sqlx-postgres"] }
serde = { version = "1.0", features = ["derive", "rc"] } serde = { version = "1.0", features = ["derive", "rc"] }
serde_json = "1.0" serde_json = "1.0"
sha-1 = "0.9" sha-1 = "0.9"
sqlx = { version = "0.6", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid"] }
time = { version = "0.3", features = ["serde", "serde-well-known"] } time = { version = "0.3", features = ["serde", "serde-well-known"] }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tokio-tungstenite = "0.17" tokio-tungstenite = "0.17"
@ -49,11 +52,6 @@ tracing = "0.1.34"
tracing-log = "0.1.3" tracing-log = "0.1.3"
tracing-subscriber = { version = "0.3.11", features = ["env-filter", "json"] } tracing-subscriber = { version = "0.3.11", features = ["env-filter", "json"] }
[dependencies.sqlx]
git = "https://github.com/launchbadge/sqlx"
rev = "4b7053807c705df312bcb9b6281e184bf7534eb3"
features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid"]
[dev-dependencies] [dev-dependencies]
collections = { path = "../collections", features = ["test-support"] } collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] }
@ -76,13 +74,10 @@ env_logger = "0.9"
log = { version = "0.4.16", features = ["kv_unstable_serde"] } log = { version = "0.4.16", features = ["kv_unstable_serde"] }
util = { path = "../util" } util = { path = "../util" }
lazy_static = "1.4" lazy_static = "1.4"
sea-query-binder = { version = "0.2", features = ["sqlx-sqlite"] }
serde_json = { version = "1.0", features = ["preserve_order"] } serde_json = { version = "1.0", features = ["preserve_order"] }
sqlx = { version = "0.6", features = ["sqlite"] }
unindent = "0.1" unindent = "0.1"
[dev-dependencies.sqlx]
git = "https://github.com/launchbadge/sqlx"
rev = "4b7053807c705df312bcb9b6281e184bf7534eb3"
features = ["sqlite"]
[features] [features]
seed-support = ["clap", "lipsum", "reqwest"] seed-support = ["clap", "lipsum", "reqwest"]

View file

@ -1,3 +1,7 @@
mod schema;
#[cfg(test)]
mod tests;
use crate::{Error, Result}; use crate::{Error, Result};
use anyhow::anyhow; use anyhow::anyhow;
use axum::http::StatusCode; use axum::http::StatusCode;
@ -5,6 +9,8 @@ use collections::{BTreeMap, HashMap, HashSet};
use dashmap::DashMap; use dashmap::DashMap;
use futures::{future::BoxFuture, FutureExt, StreamExt}; use futures::{future::BoxFuture, FutureExt, StreamExt};
use rpc::{proto, ConnectionId}; use rpc::{proto, ConnectionId};
use sea_query::{Expr, Query};
use sea_query_binder::SqlxBinder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{ use sqlx::{
migrate::{Migrate as _, Migration, MigrationSource}, migrate::{Migrate as _, Migration, MigrationSource},
@ -89,6 +95,23 @@ impl BeginTransaction for Db<sqlx::Sqlite> {
} }
} }
pub trait BuildQuery {
fn build_query<T: SqlxBinder>(&self, query: &T) -> (String, sea_query_binder::SqlxValues);
}
impl BuildQuery for Db<sqlx::Postgres> {
fn build_query<T: SqlxBinder>(&self, query: &T) -> (String, sea_query_binder::SqlxValues) {
query.build_sqlx(sea_query::PostgresQueryBuilder)
}
}
#[cfg(test)]
impl BuildQuery for Db<sqlx::Sqlite> {
fn build_query<T: SqlxBinder>(&self, query: &T) -> (String, sea_query_binder::SqlxValues) {
query.build_sqlx(sea_query::SqliteQueryBuilder)
}
}
pub trait RowsAffected { pub trait RowsAffected {
fn rows_affected(&self) -> u64; fn rows_affected(&self) -> u64;
} }
@ -595,10 +618,11 @@ impl Db<sqlx::Postgres> {
impl<D> Db<D> impl<D> Db<D>
where where
Self: BeginTransaction<Database = D>, Self: BeginTransaction<Database = D> + BuildQuery,
D: sqlx::Database + sqlx::migrate::MigrateDatabase, D: sqlx::Database + sqlx::migrate::MigrateDatabase,
D::Connection: sqlx::migrate::Migrate, D::Connection: sqlx::migrate::Migrate,
for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>, for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
for<'a> sea_query_binder::SqlxValues: sqlx::IntoArguments<'a, D>,
for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>, for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>, for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
D::QueryResult: RowsAffected, D::QueryResult: RowsAffected,
@ -1537,63 +1561,66 @@ where
worktrees: &[proto::WorktreeMetadata], worktrees: &[proto::WorktreeMetadata],
) -> Result<RoomGuard<(ProjectId, proto::Room)>> { ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
self.transact(|mut tx| async move { self.transact(|mut tx| async move {
let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( let (sql, values) = self.build_query(
" Query::select()
SELECT room_id, user_id .columns([
FROM room_participants schema::room_participant::Definition::RoomId,
WHERE answering_connection_id = $1 schema::room_participant::Definition::UserId,
", ])
) .from(schema::room_participant::Definition::Table)
.bind(connection_id.0 as i32) .and_where(
Expr::col(schema::room_participant::Definition::AnsweringConnectionId)
.eq(connection_id.0),
),
);
let (room_id, user_id) = sqlx::query_as_with::<_, (RoomId, UserId), _>(&sql, values)
.fetch_one(&mut tx) .fetch_one(&mut tx)
.await?; .await?;
if room_id != expected_room_id { if room_id != expected_room_id {
return Err(anyhow!("shared project on unexpected room"))?; return Err(anyhow!("shared project on unexpected room"))?;
} }
let project_id: ProjectId = sqlx::query_scalar( let (sql, values) = self.build_query(
" Query::insert()
INSERT INTO projects (room_id, host_user_id, host_connection_id) .into_table(schema::project::Definition::Table)
VALUES ($1, $2, $3) .columns([
RETURNING id schema::project::Definition::RoomId,
", schema::project::Definition::HostUserId,
) schema::project::Definition::HostConnectionId,
.bind(room_id) ])
.bind(user_id) .values_panic([room_id.into(), user_id.into(), connection_id.0.into()])
.bind(connection_id.0 as i32) .returning_col(schema::project::Definition::Id),
);
let project_id: ProjectId = sqlx::query_scalar_with(&sql, values)
.fetch_one(&mut tx) .fetch_one(&mut tx)
.await?; .await?;
if !worktrees.is_empty() { if !worktrees.is_empty() {
let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len()); let mut query = Query::insert()
params.pop(); .into_table(schema::worktree::Definition::Table)
let query = format!( .columns([
" schema::worktree::Definition::ProjectId,
INSERT INTO worktrees ( schema::worktree::Definition::Id,
project_id, schema::worktree::Definition::RootName,
id, schema::worktree::Definition::AbsPath,
root_name, schema::worktree::Definition::Visible,
abs_path, schema::worktree::Definition::ScanId,
visible, schema::worktree::Definition::IsComplete,
scan_id, ])
is_complete .to_owned();
)
VALUES {params}
"
);
let mut query = sqlx::query(&query);
for worktree in worktrees { for worktree in worktrees {
query = query query.values_panic([
.bind(project_id) project_id.into(),
.bind(worktree.id as i32) worktree.id.into(),
.bind(&worktree.root_name) worktree.root_name.clone().into(),
.bind(&worktree.abs_path) worktree.abs_path.clone().into(),
.bind(worktree.visible) worktree.visible.into(),
.bind(0) 0.into(),
.bind(false); false.into(),
]);
} }
query.execute(&mut tx).await?; let (sql, values) = self.build_query(&query);
sqlx::query_with(&sql, values).execute(&mut tx).await?;
} }
sqlx::query( sqlx::query(
@ -2648,6 +2675,12 @@ macro_rules! id_type {
self.0.fmt(f) self.0.fmt(f)
} }
} }
impl From<$name> for sea_query::Value {
fn from(value: $name) -> Self {
sea_query::Value::Int(Some(value.0))
}
}
}; };
} }
@ -2692,6 +2725,7 @@ id_type!(WorktreeId);
#[derive(Clone, Debug, Default, FromRow, PartialEq)] #[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeRow { struct WorktreeRow {
pub id: WorktreeId, pub id: WorktreeId,
pub project_id: ProjectId,
pub abs_path: String, pub abs_path: String,
pub root_name: String, pub root_name: String,
pub visible: bool, pub visible: bool,

View file

@ -0,0 +1,43 @@
pub mod project {
use sea_query::Iden;
#[derive(Iden)]
pub enum Definition {
#[iden = "projects"]
Table,
Id,
RoomId,
HostUserId,
HostConnectionId,
}
}
pub mod worktree {
use sea_query::Iden;
#[derive(Iden)]
pub enum Definition {
#[iden = "worktrees"]
Table,
Id,
ProjectId,
AbsPath,
RootName,
Visible,
ScanId,
IsComplete,
}
}
pub mod room_participant {
use sea_query::Iden;
#[derive(Iden)]
pub enum Definition {
#[iden = "room_participants"]
Table,
RoomId,
UserId,
AnsweringConnectionId,
}
}

View file

@ -1,4 +1,4 @@
use super::db::*; use super::*;
use gpui::executor::{Background, Deterministic}; use gpui::executor::{Background, Deterministic};
use std::sync::Arc; use std::sync::Arc;

View file

@ -4,8 +4,6 @@ mod db;
mod env; mod env;
mod rpc; mod rpc;
#[cfg(test)]
mod db_tests;
#[cfg(test)] #[cfg(test)]
mod integration_tests; mod integration_tests;