Introduce Refinement trait and derive macro

This commit is contained in:
Nathan Sobo 2023-08-18 01:03:46 -06:00
parent 19ccb19c96
commit 9b74dc196e
22 changed files with 6164 additions and 244 deletions

View file

@ -0,0 +1,14 @@
[package]
name = "refineable"
version = "0.1.0"
edition = "2021"
[lib]
path = "src/refineable.rs"
doctest = false
[dependencies]
syn = "1.0.72"
quote = "1.0.9"
proc-macro2 = "1.0.66"
derive_refineable = { path = "./derive_refineable" }

View file

@ -0,0 +1,14 @@
[package]
name = "derive_refineable"
version = "0.1.0"
edition = "2021"
[lib]
path = "src/derive_refineable.rs"
proc-macro = true
doctest = false
[dependencies]
syn = "1.0.72"
quote = "1.0.9"
proc-macro2 = "1.0.66"

View file

@ -0,0 +1,162 @@
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{
parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
Type, TypeParamBound, WhereClause, WherePredicate,
};
#[proc_macro_derive(Refineable, attributes(refineable))]
pub fn derive_refineable(input: TokenStream) -> TokenStream {
let DeriveInput {
ident,
data,
generics,
..
} = parse_macro_input!(input);
let refinement_ident = format_ident!("{}Refinement", ident);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let fields = match data {
syn::Data::Struct(syn::DataStruct {
fields: syn::Fields::Named(FieldsNamed { named, .. }),
..
}) => named.into_iter().collect::<Vec<Field>>(),
_ => panic!("This derive macro only supports structs with named fields"),
};
let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
// Create trait bound that each wrapped type must implement Clone & Default
let type_param_bounds: Vec<_> = wrapped_types
.iter()
.map(|ty| {
WherePredicate::Type(PredicateType {
lifetimes: None,
bounded_ty: ty.clone(),
colon_token: Default::default(),
bounds: {
let mut punctuated = syn::punctuated::Punctuated::new();
punctuated.push_value(TypeParamBound::Trait(TraitBound {
paren_token: None,
modifier: syn::TraitBoundModifier::None,
lifetimes: None,
path: parse_quote!(std::clone::Clone),
}));
punctuated.push_punct(syn::token::Add::default());
punctuated.push_value(TypeParamBound::Trait(TraitBound {
paren_token: None,
modifier: syn::TraitBoundModifier::None,
lifetimes: None,
path: parse_quote!(std::default::Default),
}));
punctuated
},
})
})
.collect();
// Append to where_clause or create a new one if it doesn't exist
let where_clause = match where_clause.cloned() {
Some(mut where_clause) => {
where_clause
.predicates
.extend(type_param_bounds.into_iter());
where_clause.clone()
}
None => WhereClause {
where_token: Default::default(),
predicates: type_param_bounds.into_iter().collect(),
},
};
let field_initializations: Vec<TokenStream2> = fields
.iter()
.map(|field| {
let name = &field.ident;
let is_refineable = is_refineable_field(field);
let is_optional = is_optional_field(field);
if is_refineable {
quote! {
clone.#name = self.#name.refine(&refinement.#name);
}
} else if is_optional {
quote! {
if let Some(ref value) = &refinement.#name {
clone.#name = Some(value.clone());
}
}
} else {
quote! {
if let Some(ref value) = &refinement.#name {
clone.#name = value.clone();
}
}
}
})
.collect();
let gen = quote! {
#[derive(Default, Clone)]
pub struct #refinement_ident #impl_generics {
#( #field_visibilities #field_names: #wrapped_types ),*
}
impl #impl_generics Refineable for #ident #ty_generics
#where_clause
{
type Refinement = #refinement_ident #ty_generics;
fn refine(&self, refinement: &Self::Refinement) -> Self {
let mut clone = self.clone();
#( #field_initializations )*
clone
}
}
};
println!("{}", gen);
gen.into()
}
fn is_refineable_field(f: &Field) -> bool {
f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
}
fn is_optional_field(f: &Field) -> bool {
if let Type::Path(typepath) = &f.ty {
if typepath.qself.is_none() {
let segments = &typepath.path.segments;
if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
return true;
}
}
}
false
}
fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
if is_refineable_field(field) {
let struct_name = if let Type::Path(tp) = ty {
tp.path.segments.last().unwrap().ident.clone()
} else {
panic!("Expected struct type for a refineable field");
};
let refinement_struct_name = format_ident!("{}Refinement", struct_name);
let generics = if let Type::Path(tp) = ty {
&tp.path.segments.last().unwrap().arguments
} else {
&syn::PathArguments::None
};
parse_quote!(#refinement_struct_name #generics)
} else if is_optional_field(field) {
ty.clone()
} else {
parse_quote!(Option<#ty>)
}
}

View file

@ -0,0 +1,13 @@
pub use derive_refineable::Refineable;
pub trait Refineable {
type Refinement;
fn refine(&self, refinement: &Self::Refinement) -> Self;
fn from_refinement(refinement: &Self::Refinement) -> Self
where
Self: Sized + Default,
{
Self::default().refine(refinement)
}
}