Introduce Refinement trait and derive macro
This commit is contained in:
parent
19ccb19c96
commit
9b74dc196e
22 changed files with 6164 additions and 244 deletions
14
crates/refineable/Cargo.toml
Normal file
14
crates/refineable/Cargo.toml
Normal file
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "refineable"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
path = "src/refineable.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
syn = "1.0.72"
|
||||
quote = "1.0.9"
|
||||
proc-macro2 = "1.0.66"
|
||||
derive_refineable = { path = "./derive_refineable" }
|
14
crates/refineable/derive_refineable/Cargo.toml
Normal file
14
crates/refineable/derive_refineable/Cargo.toml
Normal file
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "derive_refineable"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
path = "src/derive_refineable.rs"
|
||||
proc-macro = true
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
syn = "1.0.72"
|
||||
quote = "1.0.9"
|
||||
proc-macro2 = "1.0.66"
|
162
crates/refineable/derive_refineable/src/derive_refineable.rs
Normal file
162
crates/refineable/derive_refineable/src/derive_refineable.rs
Normal file
|
@ -0,0 +1,162 @@
|
|||
use proc_macro::TokenStream;
|
||||
use proc_macro2::TokenStream as TokenStream2;
|
||||
use quote::{format_ident, quote};
|
||||
use syn::{
|
||||
parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
|
||||
Type, TypeParamBound, WhereClause, WherePredicate,
|
||||
};
|
||||
|
||||
#[proc_macro_derive(Refineable, attributes(refineable))]
|
||||
pub fn derive_refineable(input: TokenStream) -> TokenStream {
|
||||
let DeriveInput {
|
||||
ident,
|
||||
data,
|
||||
generics,
|
||||
..
|
||||
} = parse_macro_input!(input);
|
||||
|
||||
let refinement_ident = format_ident!("{}Refinement", ident);
|
||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||
|
||||
let fields = match data {
|
||||
syn::Data::Struct(syn::DataStruct {
|
||||
fields: syn::Fields::Named(FieldsNamed { named, .. }),
|
||||
..
|
||||
}) => named.into_iter().collect::<Vec<Field>>(),
|
||||
_ => panic!("This derive macro only supports structs with named fields"),
|
||||
};
|
||||
|
||||
let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
|
||||
let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
|
||||
let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
|
||||
|
||||
// Create trait bound that each wrapped type must implement Clone & Default
|
||||
let type_param_bounds: Vec<_> = wrapped_types
|
||||
.iter()
|
||||
.map(|ty| {
|
||||
WherePredicate::Type(PredicateType {
|
||||
lifetimes: None,
|
||||
bounded_ty: ty.clone(),
|
||||
colon_token: Default::default(),
|
||||
bounds: {
|
||||
let mut punctuated = syn::punctuated::Punctuated::new();
|
||||
punctuated.push_value(TypeParamBound::Trait(TraitBound {
|
||||
paren_token: None,
|
||||
modifier: syn::TraitBoundModifier::None,
|
||||
lifetimes: None,
|
||||
path: parse_quote!(std::clone::Clone),
|
||||
}));
|
||||
punctuated.push_punct(syn::token::Add::default());
|
||||
punctuated.push_value(TypeParamBound::Trait(TraitBound {
|
||||
paren_token: None,
|
||||
modifier: syn::TraitBoundModifier::None,
|
||||
lifetimes: None,
|
||||
path: parse_quote!(std::default::Default),
|
||||
}));
|
||||
punctuated
|
||||
},
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Append to where_clause or create a new one if it doesn't exist
|
||||
let where_clause = match where_clause.cloned() {
|
||||
Some(mut where_clause) => {
|
||||
where_clause
|
||||
.predicates
|
||||
.extend(type_param_bounds.into_iter());
|
||||
where_clause.clone()
|
||||
}
|
||||
None => WhereClause {
|
||||
where_token: Default::default(),
|
||||
predicates: type_param_bounds.into_iter().collect(),
|
||||
},
|
||||
};
|
||||
|
||||
let field_initializations: Vec<TokenStream2> = fields
|
||||
.iter()
|
||||
.map(|field| {
|
||||
let name = &field.ident;
|
||||
let is_refineable = is_refineable_field(field);
|
||||
let is_optional = is_optional_field(field);
|
||||
|
||||
if is_refineable {
|
||||
quote! {
|
||||
clone.#name = self.#name.refine(&refinement.#name);
|
||||
}
|
||||
} else if is_optional {
|
||||
quote! {
|
||||
if let Some(ref value) = &refinement.#name {
|
||||
clone.#name = Some(value.clone());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
if let Some(ref value) = &refinement.#name {
|
||||
clone.#name = value.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let gen = quote! {
|
||||
#[derive(Default, Clone)]
|
||||
pub struct #refinement_ident #impl_generics {
|
||||
#( #field_visibilities #field_names: #wrapped_types ),*
|
||||
}
|
||||
|
||||
impl #impl_generics Refineable for #ident #ty_generics
|
||||
#where_clause
|
||||
{
|
||||
type Refinement = #refinement_ident #ty_generics;
|
||||
|
||||
fn refine(&self, refinement: &Self::Refinement) -> Self {
|
||||
let mut clone = self.clone();
|
||||
#( #field_initializations )*
|
||||
clone
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
println!("{}", gen);
|
||||
|
||||
gen.into()
|
||||
}
|
||||
|
||||
fn is_refineable_field(f: &Field) -> bool {
|
||||
f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
|
||||
}
|
||||
|
||||
fn is_optional_field(f: &Field) -> bool {
|
||||
if let Type::Path(typepath) = &f.ty {
|
||||
if typepath.qself.is_none() {
|
||||
let segments = &typepath.path.segments;
|
||||
if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
|
||||
if is_refineable_field(field) {
|
||||
let struct_name = if let Type::Path(tp) = ty {
|
||||
tp.path.segments.last().unwrap().ident.clone()
|
||||
} else {
|
||||
panic!("Expected struct type for a refineable field");
|
||||
};
|
||||
let refinement_struct_name = format_ident!("{}Refinement", struct_name);
|
||||
let generics = if let Type::Path(tp) = ty {
|
||||
&tp.path.segments.last().unwrap().arguments
|
||||
} else {
|
||||
&syn::PathArguments::None
|
||||
};
|
||||
parse_quote!(#refinement_struct_name #generics)
|
||||
} else if is_optional_field(field) {
|
||||
ty.clone()
|
||||
} else {
|
||||
parse_quote!(Option<#ty>)
|
||||
}
|
||||
}
|
13
crates/refineable/src/refineable.rs
Normal file
13
crates/refineable/src/refineable.rs
Normal file
|
@ -0,0 +1,13 @@
|
|||
pub use derive_refineable::Refineable;
|
||||
|
||||
pub trait Refineable {
|
||||
type Refinement;
|
||||
|
||||
fn refine(&self, refinement: &Self::Refinement) -> Self;
|
||||
fn from_refinement(refinement: &Self::Refinement) -> Self
|
||||
where
|
||||
Self: Sized + Default,
|
||||
{
|
||||
Self::default().refine(refinement)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue