From 8c261ff51ca29471a43a0868857e658b9c9b4360 Mon Sep 17 00:00:00 2001 From: Magic Len Date: Sat, 9 Dec 2023 22:47:43 +0800 Subject: [PATCH] improve auto bound --- Cargo.toml | 2 +- README.md | 2 +- src/common/bound.rs | 2 +- src/common/type.rs | 130 +++++++----------- src/common/where_predicates_bool.rs | 4 +- src/lib.rs | 2 +- src/trait_handlers/clone/clone_enum.rs | 2 +- src/trait_handlers/clone/clone_struct.rs | 2 +- src/trait_handlers/copy/mod.rs | 12 ++ src/trait_handlers/debug/common.rs | 2 +- src/trait_handlers/debug/debug_enum.rs | 2 +- src/trait_handlers/debug/debug_struct.rs | 2 +- src/trait_handlers/default/default_enum.rs | 2 +- src/trait_handlers/default/default_struct.rs | 2 +- src/trait_handlers/default/default_union.rs | 2 +- src/trait_handlers/eq/mod.rs | 12 ++ src/trait_handlers/hash/hash_enum.rs | 2 +- src/trait_handlers/hash/hash_struct.rs | 2 +- src/trait_handlers/ord/ord_enum.rs | 2 +- src/trait_handlers/ord/ord_struct.rs | 2 +- .../partial_eq/partial_eq_enum.rs | 2 +- .../partial_eq/partial_eq_struct.rs | 2 +- .../partial_ord/partial_ord_enum.rs | 2 +- .../partial_ord/partial_ord_struct.rs | 2 +- 24 files changed, 99 insertions(+), 99 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b397ad3..c0c18f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "educe" -version = "0.5.1" +version = "0.5.2" authors = ["Magic Len "] edition = "2021" rust-version = "1.56" diff --git a/README.md b/README.md index 63064b8..515eded 100644 --- a/README.md +++ b/README.md @@ -288,7 +288,7 @@ enum Enum { ###### Generic Parameters Bound to the `Clone` Trait or Others -Generic parameters will be automatically bound to the `Clone` trait if necessary. If the `#[educe(Copy)]` attribute exists, all generic parameters will bound to the `Copy` trait. +Generic parameters will be automatically bound to the `Clone` trait if necessary. If the `#[educe(Copy)]` attribute exists, they will be bound to the `Copy` trait. ```rust use educe::Educe; diff --git a/src/common/bound.rs b/src/common/bound.rs index cd86839..6ee7a74 100644 --- a/src/common/bound.rs +++ b/src/common/bound.rs @@ -52,7 +52,7 @@ impl Bound { params: &Punctuated, bound_trait: &Path, types: &[&Type], - recursive: Option<(bool, bool)>, + recursive: Option<(bool, bool, bool)>, ) -> Punctuated { match self { Self::Disabled => Punctuated::new(), diff --git a/src/common/type.rs b/src/common/type.rs index 645d2d4..e6fd1d0 100644 --- a/src/common/type.rs +++ b/src/common/type.rs @@ -3,7 +3,7 @@ use std::collections::HashSet; use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, - Ident, Meta, Token, Type, TypeParamBound, + GenericArgument, Ident, Meta, Path, PathArguments, Token, Type, TypeParamBound, }; pub(crate) struct TypeWithPunctuatedMeta { @@ -34,12 +34,48 @@ impl Parse for TypeWithPunctuatedMeta { } } -/// recursive (dereference, de_ptr) +/// recursive (dereference, de_ptr, de_param) +#[inline] +pub(crate) fn find_idents_in_path<'a>( + set: &mut HashSet<&'a Ident>, + path: &'a Path, + recursive: Option<(bool, bool, bool)>, +) { + if let Some((_, _, de_param)) = recursive { + if de_param { + if let Some(segment) = path.segments.iter().last() { + if let PathArguments::AngleBracketed(a) = &segment.arguments { + // the ident is definitely not a generic parameter, so we don't insert it + + for arg in a.args.iter() { + match arg { + GenericArgument::Type(ty) => { + find_idents_in_type(set, ty, recursive); + }, + GenericArgument::AssocType(ty) => { + find_idents_in_type(set, &ty.ty, recursive); + }, + _ => (), + } + } + + return; + } + } + } + } + + if let Some(ty) = path.get_ident() { + set.insert(ty); + } +} + +/// recursive (dereference, de_ptr, de_param) #[inline] pub(crate) fn find_idents_in_type<'a>( set: &mut HashSet<&'a Ident>, ty: &'a Type, - recursive: Option<(bool, bool)>, + recursive: Option<(bool, bool, bool)>, ) { match ty { Type::Array(ty) => { @@ -53,21 +89,16 @@ pub(crate) fn find_idents_in_type<'a>( } }, Type::ImplTrait(ty) => { - if recursive.is_some() { - for b in &ty.bounds { - if let TypeParamBound::Trait(ty) = b { - if let Some(ty) = ty.path.get_ident() { - set.insert(ty); - } - } + // always recursive + for b in &ty.bounds { + if let TypeParamBound::Trait(ty) = b { + find_idents_in_path(set, &ty.path, recursive); } } }, Type::Macro(ty) => { if recursive.is_some() { - if let Some(ty) = ty.mac.path.get_ident() { - set.insert(ty); - } + find_idents_in_path(set, &ty.mac.path, recursive); } }, Type::Paren(ty) => { @@ -76,22 +107,16 @@ pub(crate) fn find_idents_in_type<'a>( } }, Type::Path(ty) => { - if let Some(ty) = ty.path.get_ident() { - set.insert(ty); - } + find_idents_in_path(set, &ty.path, recursive); }, Type::Ptr(ty) => { - if let Some((_, de_ptr)) = recursive { - if de_ptr { - find_idents_in_type(set, ty.elem.as_ref(), recursive); - } + if let Some((_, true, _)) = recursive { + find_idents_in_type(set, ty.elem.as_ref(), recursive); } }, Type::Reference(ty) => { - if let Some((dereference, _)) = recursive { - if dereference { - find_idents_in_type(set, ty.elem.as_ref(), recursive); - } + if let Some((true, ..)) = recursive { + find_idents_in_type(set, ty.elem.as_ref(), recursive); } }, Type::Slice(ty) => { @@ -100,13 +125,10 @@ pub(crate) fn find_idents_in_type<'a>( } }, Type::TraitObject(ty) => { - if recursive.is_some() { - for b in &ty.bounds { - if let TypeParamBound::Trait(ty) = b { - if let Some(ty) = ty.path.get_ident() { - set.insert(ty); - } - } + // always recursive + for b in &ty.bounds { + if let TypeParamBound::Trait(ty) = b { + find_idents_in_path(set, &ty.path, recursive); } } }, @@ -121,52 +143,6 @@ pub(crate) fn find_idents_in_type<'a>( } } -#[inline] -pub(crate) fn find_derivable_idents_in_type<'a>(set: &mut HashSet<&'a Ident>, ty: &'a Type) { - match ty { - Type::Array(ty) => find_derivable_idents_in_type(set, ty.elem.as_ref()), - Type::Group(ty) => find_derivable_idents_in_type(set, ty.elem.as_ref()), - Type::ImplTrait(ty) => { - for b in &ty.bounds { - if let TypeParamBound::Trait(ty) = b { - if let Some(ty) = ty.path.get_ident() { - set.insert(ty); - } - } - } - }, - Type::Macro(ty) => { - if let Some(ty) = ty.mac.path.get_ident() { - set.insert(ty); - } - }, - Type::Paren(ty) => find_derivable_idents_in_type(set, ty.elem.as_ref()), - Type::Path(ty) => { - if let Some(ty) = ty.path.get_ident() { - set.insert(ty); - } - }, - Type::Ptr(_) => (), - Type::Reference(_) => (), - Type::Slice(ty) => find_derivable_idents_in_type(set, ty.elem.as_ref()), - Type::TraitObject(ty) => { - for b in &ty.bounds { - if let TypeParamBound::Trait(ty) = b { - if let Some(ty) = ty.path.get_ident() { - set.insert(ty); - } - } - } - }, - Type::Tuple(ty) => { - for ty in &ty.elems { - find_derivable_idents_in_type(set, ty) - } - }, - _ => (), - } -} - #[inline] pub(crate) fn dereference(ty: &Type) -> &Type { if let Type::Reference(ty) = ty { diff --git a/src/common/where_predicates_bool.rs b/src/common/where_predicates_bool.rs index 833fd66..1aa66de 100644 --- a/src/common/where_predicates_bool.rs +++ b/src/common/where_predicates_bool.rs @@ -111,14 +111,14 @@ pub(crate) fn create_where_predicates_from_generic_parameters_check_types( params: &Punctuated, bound_trait: &Path, types: &[&Type], - resursive: Option<(bool, bool)>, + recursive: Option<(bool, bool, bool)>, ) -> WherePredicates { let mut where_predicates = Punctuated::new(); let mut set = HashSet::new(); for t in types { - find_idents_in_type(&mut set, t, resursive); + find_idents_in_type(&mut set, t, recursive); } for param in params { diff --git a/src/lib.rs b/src/lib.rs index 02399e5..ec57aeb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -316,7 +316,7 @@ enum Enum { ###### Generic Parameters Bound to the `Clone` Trait or Others -Generic parameters will be automatically bound to the `Clone` trait if necessary. If the `#[educe(Copy)]` attribute exists, all generic parameters will bound to the `Copy` trait. +Generic parameters will be automatically bound to the `Clone` trait if necessary. If the `#[educe(Copy)]` attribute exists, they will be bound to the `Copy` trait. ```rust # #[cfg(feature = "Clone")] diff --git a/src/trait_handlers/clone/clone_enum.rs b/src/trait_handlers/clone/clone_enum.rs index 0e22851..ad89ff1 100644 --- a/src/trait_handlers/clone/clone_enum.rs +++ b/src/trait_handlers/clone/clone_enum.rs @@ -222,7 +222,7 @@ impl TraitHandler for CloneEnumHandler { }) .unwrap(), &clone_types, - Some((false, false)), + Some((false, false, false)), ); } diff --git a/src/trait_handlers/clone/clone_struct.rs b/src/trait_handlers/clone/clone_struct.rs index b5bfd1a..6f07f72 100644 --- a/src/trait_handlers/clone/clone_struct.rs +++ b/src/trait_handlers/clone/clone_struct.rs @@ -145,7 +145,7 @@ impl TraitHandler for CloneStructHandler { }) .unwrap(), &clone_types, - Some((false, false)), + Some((false, false, false)), ); } diff --git a/src/trait_handlers/copy/mod.rs b/src/trait_handlers/copy/mod.rs index e58e31d..99f03d2 100644 --- a/src/trait_handlers/copy/mod.rs +++ b/src/trait_handlers/copy/mod.rs @@ -60,6 +60,18 @@ impl TraitHandler for CopyHandler { let ident = &ast.ident; + /* + #[derive(Clone)] + struct B { + f1: PhantomData, + } + + impl Copy for B { + + } + + // The above code will throw a compile error because T have to be bound to `Copy`. However, it seems not to be necessary logically. + */ let bound = type_attribute.bound.into_where_predicates_by_generic_parameters( &ast.generics.params, &syn::parse2(quote!(::core::marker::Copy)).unwrap(), diff --git a/src/trait_handlers/debug/common.rs b/src/trait_handlers/debug/common.rs index 78cba1b..da02c5c 100644 --- a/src/trait_handlers/debug/common.rs +++ b/src/trait_handlers/debug/common.rs @@ -31,7 +31,7 @@ pub(crate) fn create_format_arg( let ty = dereference(ty); let mut idents = HashSet::new(); - find_idents_in_type(&mut idents, ty, Some((true, false))); + find_idents_in_type(&mut idents, ty, Some((true, true, false))); // simply support one level generics (without considering bounds that use other generics) let mut filtered_params: Punctuated = Punctuated::new(); diff --git a/src/trait_handlers/debug/debug_enum.rs b/src/trait_handlers/debug/debug_enum.rs index e9ffae7..ba45c15 100644 --- a/src/trait_handlers/debug/debug_enum.rs +++ b/src/trait_handlers/debug/debug_enum.rs @@ -332,7 +332,7 @@ impl TraitHandler for DebugEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::fmt::Debug)).unwrap(), &debug_types, - Some((true, false)), + Some((true, false, false)), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/debug/debug_struct.rs b/src/trait_handlers/debug/debug_struct.rs index f564956..fd83418 100644 --- a/src/trait_handlers/debug/debug_struct.rs +++ b/src/trait_handlers/debug/debug_struct.rs @@ -160,7 +160,7 @@ impl TraitHandler for DebugStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::fmt::Debug)).unwrap(), &debug_types, - Some((true, false)), + Some((true, false, false)), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/default/default_enum.rs b/src/trait_handlers/default/default_enum.rs index 1b16e44..f1721b3 100644 --- a/src/trait_handlers/default/default_enum.rs +++ b/src/trait_handlers/default/default_enum.rs @@ -166,7 +166,7 @@ impl TraitHandler for DefaultEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::default::Default)).unwrap(), &default_types, - Some((false, false)), + Some((false, false, false)), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/default/default_struct.rs b/src/trait_handlers/default/default_struct.rs index cefcbd2..bf29c26 100644 --- a/src/trait_handlers/default/default_struct.rs +++ b/src/trait_handlers/default/default_struct.rs @@ -111,7 +111,7 @@ impl TraitHandler for DefaultStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::default::Default)).unwrap(), &default_types, - Some((false, false)), + Some((false, false, false)), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/default/default_union.rs b/src/trait_handlers/default/default_union.rs index 09280c4..18b3760 100644 --- a/src/trait_handlers/default/default_union.rs +++ b/src/trait_handlers/default/default_union.rs @@ -115,7 +115,7 @@ impl TraitHandler for DefaultUnionHandler { &ast.generics.params, &syn::parse2(quote!(::core::default::Default)).unwrap(), &default_types, - Some((false, false)), + Some((false, false, false)), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/eq/mod.rs b/src/trait_handlers/eq/mod.rs index 1fafaee..1c3b2a5 100644 --- a/src/trait_handlers/eq/mod.rs +++ b/src/trait_handlers/eq/mod.rs @@ -61,6 +61,18 @@ impl TraitHandler for EqHandler { let ident = &ast.ident; + /* + #[derive(PartialEq)] + struct B { + f1: PhantomData, + } + + impl Eq for B { + + } + + // The above code will throw a compile error because T have to be bound to `PartialEq`. However, it seems not to be necessary logically. + */ let bound = type_attribute.bound.into_where_predicates_by_generic_parameters( &ast.generics.params, &syn::parse2(quote!(::core::cmp::PartialEq)).unwrap(), diff --git a/src/trait_handlers/hash/hash_enum.rs b/src/trait_handlers/hash/hash_enum.rs index 1e9e9fe..d62c88a 100644 --- a/src/trait_handlers/hash/hash_enum.rs +++ b/src/trait_handlers/hash/hash_enum.rs @@ -142,7 +142,7 @@ impl TraitHandler for HashEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::hash::Hash)).unwrap(), &hash_types, - Some((true, false)), + Some((true, false, false)), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/hash/hash_struct.rs b/src/trait_handlers/hash/hash_struct.rs index c483191..d6c30a6 100644 --- a/src/trait_handlers/hash/hash_struct.rs +++ b/src/trait_handlers/hash/hash_struct.rs @@ -62,7 +62,7 @@ impl TraitHandler for HashStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::hash::Hash)).unwrap(), &hash_types, - Some((true, false)), + Some((true, false, false)), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/ord/ord_enum.rs b/src/trait_handlers/ord/ord_enum.rs index c53814c..3df84ab 100644 --- a/src/trait_handlers/ord/ord_enum.rs +++ b/src/trait_handlers/ord/ord_enum.rs @@ -237,7 +237,7 @@ impl TraitHandler for OrdEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::cmp::Ord)).unwrap(), &ord_types, - Some((true, false)), + Some((true, false, false)), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/ord/ord_struct.rs b/src/trait_handlers/ord/ord_struct.rs index a3362de..ecfaf47 100644 --- a/src/trait_handlers/ord/ord_struct.rs +++ b/src/trait_handlers/ord/ord_struct.rs @@ -83,7 +83,7 @@ impl TraitHandler for OrdStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::cmp::Ord)).unwrap(), &ord_types, - Some((true, false)), + Some((true, false, false)), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/partial_eq/partial_eq_enum.rs b/src/trait_handlers/partial_eq/partial_eq_enum.rs index dec1e00..8c79914 100644 --- a/src/trait_handlers/partial_eq/partial_eq_enum.rs +++ b/src/trait_handlers/partial_eq/partial_eq_enum.rs @@ -182,7 +182,7 @@ impl TraitHandler for PartialEqEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::cmp::PartialEq)).unwrap(), &partial_eq_types, - Some((true, false)), + Some((true, false, false)), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/partial_eq/partial_eq_struct.rs b/src/trait_handlers/partial_eq/partial_eq_struct.rs index 1e9ea37..318e71b 100644 --- a/src/trait_handlers/partial_eq/partial_eq_struct.rs +++ b/src/trait_handlers/partial_eq/partial_eq_struct.rs @@ -67,7 +67,7 @@ impl TraitHandler for PartialEqStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::cmp::PartialEq)).unwrap(), &partial_eq_types, - Some((true, false)), + Some((true, false, false)), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/partial_ord/partial_ord_enum.rs b/src/trait_handlers/partial_ord/partial_ord_enum.rs index 298daf2..0f6e33d 100644 --- a/src/trait_handlers/partial_ord/partial_ord_enum.rs +++ b/src/trait_handlers/partial_ord/partial_ord_enum.rs @@ -242,7 +242,7 @@ impl TraitHandler for PartialOrdEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::cmp::PartialOrd)).unwrap(), &partial_ord_types, - Some((true, false)), + Some((true, false, false)), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/partial_ord/partial_ord_struct.rs b/src/trait_handlers/partial_ord/partial_ord_struct.rs index 346f969..da158d7 100644 --- a/src/trait_handlers/partial_ord/partial_ord_struct.rs +++ b/src/trait_handlers/partial_ord/partial_ord_struct.rs @@ -85,7 +85,7 @@ impl TraitHandler for PartialOrdStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::cmp::PartialOrd)).unwrap(), &partial_ord_types, - Some((true, false)), + Some((true, false, false)), ); let where_clause = ast.generics.make_where_clause();