diff --git a/src/common/bound.rs b/src/common/bound.rs index 6ee7a74..1ca85f2 100644 --- a/src/common/bound.rs +++ b/src/common/bound.rs @@ -1,7 +1,6 @@ use syn::{punctuated::Punctuated, token::Comma, GenericParam, Meta, Path, Type, WherePredicate}; use crate::common::where_predicates_bool::{ - create_where_predicates_from_generic_parameters, create_where_predicates_from_generic_parameters_check_types, meta_2_where_predicates, WherePredicates, WherePredicatesOrBool, }; @@ -33,34 +32,20 @@ impl Bound { } impl Bound { - #[inline] - pub(crate) fn into_where_predicates_by_generic_parameters( - self, - params: &Punctuated, - bound_trait: &Path, - ) -> Punctuated { - match self { - Self::Disabled => Punctuated::new(), - Self::Auto => create_where_predicates_from_generic_parameters(params, bound_trait), - Self::Custom(where_predicates) => where_predicates, - } - } - #[inline] pub(crate) fn into_where_predicates_by_generic_parameters_check_types( self, - params: &Punctuated, + _params: &Punctuated, bound_trait: &Path, types: &[&Type], - recursive: Option<(bool, bool, bool)>, + supertraits: &[proc_macro2::TokenStream], ) -> Punctuated { match self { Self::Disabled => Punctuated::new(), Self::Auto => create_where_predicates_from_generic_parameters_check_types( - params, bound_trait, types, - recursive, + supertraits, ), Self::Custom(where_predicates) => where_predicates, } diff --git a/src/common/type.rs b/src/common/type.rs index e6fd1d0..c17ed73 100644 --- a/src/common/type.rs +++ b/src/common/type.rs @@ -1,9 +1,7 @@ -use std::collections::HashSet; - use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, - GenericArgument, Ident, Meta, Path, PathArguments, Token, Type, TypeParamBound, + Meta, Token, Type, }; pub(crate) struct TypeWithPunctuatedMeta { @@ -34,115 +32,6 @@ impl Parse for TypeWithPunctuatedMeta { } } -/// recursive (dereference, de_ptr, de_param) -#[inline] -pub(crate) fn find_idents_in_path<'a>( - set: &mut HashSet<&'a Ident>, - path: &'a Path, - recursive: Option<(bool, bool, bool)>, -) { - if let Some((_, _, de_param)) = recursive { - if de_param { - if let Some(segment) = path.segments.iter().last() { - if let PathArguments::AngleBracketed(a) = &segment.arguments { - // the ident is definitely not a generic parameter, so we don't insert it - - for arg in a.args.iter() { - match arg { - GenericArgument::Type(ty) => { - find_idents_in_type(set, ty, recursive); - }, - GenericArgument::AssocType(ty) => { - find_idents_in_type(set, &ty.ty, recursive); - }, - _ => (), - } - } - - return; - } - } - } - } - - if let Some(ty) = path.get_ident() { - set.insert(ty); - } -} - -/// recursive (dereference, de_ptr, de_param) -#[inline] -pub(crate) fn find_idents_in_type<'a>( - set: &mut HashSet<&'a Ident>, - ty: &'a Type, - recursive: Option<(bool, bool, bool)>, -) { - match ty { - Type::Array(ty) => { - if recursive.is_some() { - find_idents_in_type(set, ty.elem.as_ref(), recursive); - } - }, - Type::Group(ty) => { - if recursive.is_some() { - find_idents_in_type(set, ty.elem.as_ref(), recursive); - } - }, - Type::ImplTrait(ty) => { - // always recursive - for b in &ty.bounds { - if let TypeParamBound::Trait(ty) = b { - find_idents_in_path(set, &ty.path, recursive); - } - } - }, - Type::Macro(ty) => { - if recursive.is_some() { - find_idents_in_path(set, &ty.mac.path, recursive); - } - }, - Type::Paren(ty) => { - if recursive.is_some() { - find_idents_in_type(set, ty.elem.as_ref(), recursive); - } - }, - Type::Path(ty) => { - find_idents_in_path(set, &ty.path, recursive); - }, - Type::Ptr(ty) => { - if let Some((_, true, _)) = recursive { - find_idents_in_type(set, ty.elem.as_ref(), recursive); - } - }, - Type::Reference(ty) => { - if let Some((true, ..)) = recursive { - find_idents_in_type(set, ty.elem.as_ref(), recursive); - } - }, - Type::Slice(ty) => { - if recursive.is_some() { - find_idents_in_type(set, ty.elem.as_ref(), recursive); - } - }, - Type::TraitObject(ty) => { - // always recursive - for b in &ty.bounds { - if let TypeParamBound::Trait(ty) = b { - find_idents_in_path(set, &ty.path, recursive); - } - } - }, - Type::Tuple(ty) => { - if recursive.is_some() { - for ty in &ty.elems { - find_idents_in_type(set, ty, recursive) - } - } - }, - _ => (), - } -} - #[inline] pub(crate) fn dereference(ty: &Type) -> &Type { if let Type::Reference(ty) = ty { diff --git a/src/common/where_predicates_bool.rs b/src/common/where_predicates_bool.rs index a72e6b8..98e49a5 100644 --- a/src/common/where_predicates_bool.rs +++ b/src/common/where_predicates_bool.rs @@ -1,5 +1,3 @@ -use std::collections::HashSet; - use quote::{quote, ToTokens}; use syn::{ parse::{Parse, ParseStream}, @@ -9,7 +7,7 @@ use syn::{ Expr, GenericParam, Lit, Meta, MetaNameValue, Path, Token, Type, WherePredicate, }; -use super::{path::path_to_string, r#type::find_idents_in_type}; +use super::path::path_to_string; pub(crate) type WherePredicates = Punctuated; @@ -82,7 +80,8 @@ pub(crate) fn meta_2_where_predicates(meta: &Meta) -> syn::Result, bound_trait: &Path, ) -> WherePredicates { @@ -101,27 +100,18 @@ pub(crate) fn create_where_predicates_from_generic_parameters( #[inline] pub(crate) fn create_where_predicates_from_generic_parameters_check_types( - params: &Punctuated, bound_trait: &Path, types: &[&Type], - recursive: Option<(bool, bool, bool)>, + supertraits: &[proc_macro2::TokenStream], ) -> WherePredicates { let mut where_predicates = Punctuated::new(); - let mut set = HashSet::new(); - for t in types { - find_idents_in_type(&mut set, t, recursive); + where_predicates.push(syn::parse2(quote! { #t: #bound_trait }).unwrap()); } - for param in params { - if let GenericParam::Type(ty) = param { - let ident = &ty.ident; - - if set.contains(ident) { - where_predicates.push(syn::parse2(quote! { #ident: #bound_trait }).unwrap()); - } - } + for supertrait in supertraits { + where_predicates.push(syn::parse2(quote! { Self: #supertrait }).unwrap()); } where_predicates diff --git a/src/trait_handlers/clone/clone_enum.rs b/src/trait_handlers/clone/clone_enum.rs index 8cd7d1b..19b4d24 100644 --- a/src/trait_handlers/clone/clone_enum.rs +++ b/src/trait_handlers/clone/clone_enum.rs @@ -222,7 +222,7 @@ impl TraitHandler for CloneEnumHandler { }) .unwrap(), &clone_types, - Some((false, false, false)), + &[], ); } diff --git a/src/trait_handlers/clone/clone_struct.rs b/src/trait_handlers/clone/clone_struct.rs index 6f07f72..12bc279 100644 --- a/src/trait_handlers/clone/clone_struct.rs +++ b/src/trait_handlers/clone/clone_struct.rs @@ -145,7 +145,7 @@ impl TraitHandler for CloneStructHandler { }) .unwrap(), &clone_types, - Some((false, false, false)), + &[], ); } diff --git a/src/trait_handlers/clone/clone_union.rs b/src/trait_handlers/clone/clone_union.rs index 6d2986b..9b2968c 100644 --- a/src/trait_handlers/clone/clone_union.rs +++ b/src/trait_handlers/clone/clone_union.rs @@ -21,8 +21,11 @@ impl TraitHandler for CloneUnionHandler { } .build_from_clone_meta(meta)?; + let mut field_types = vec![]; + if let Data::Union(data) = &ast.data { for field in data.fields.named.iter() { + field_types.push(&field.ty); let _ = FieldAttributeBuilder { enable_method: false } @@ -32,9 +35,11 @@ impl TraitHandler for CloneUnionHandler { let ident = &ast.ident; - let bound = type_attribute.bound.into_where_predicates_by_generic_parameters( + let bound = type_attribute.bound.into_where_predicates_by_generic_parameters_check_types( &ast.generics.params, &syn::parse2(quote!(::core::marker::Copy)).unwrap(), + &field_types, + &[], ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/copy/mod.rs b/src/trait_handlers/copy/mod.rs index 8c3259f..02e0e62 100644 --- a/src/trait_handlers/copy/mod.rs +++ b/src/trait_handlers/copy/mod.rs @@ -29,11 +29,14 @@ impl TraitHandler for CopyHandler { } .build_from_copy_meta(meta)?; + let mut field_types = vec![]; + // if `contains_clone` is true, the implementation is handled by the `Clone` attribute, and field attributes is also handled by the `Clone` attribute if !contains_clone { match &ast.data { Data::Struct(data) => { for field in data.fields.iter() { + field_types.push(&field.ty); let _ = FieldAttributeBuilder.build_from_attributes(&field.attrs, traits)?; } @@ -46,6 +49,7 @@ impl TraitHandler for CopyHandler { .build_from_attributes(&variant.attrs, traits)?; for field in variant.fields.iter() { + field_types.push(&field.ty); let _ = FieldAttributeBuilder .build_from_attributes(&field.attrs, traits)?; } @@ -53,6 +57,7 @@ impl TraitHandler for CopyHandler { }, Data::Union(data) => { for field in data.fields.named.iter() { + field_types.push(&field.ty); let _ = FieldAttributeBuilder.build_from_attributes(&field.attrs, traits)?; } @@ -61,22 +66,13 @@ impl TraitHandler for CopyHandler { let ident = &ast.ident; - /* - #[derive(Clone)] - struct B { - f1: PhantomData, - } - - impl Copy for B { - - } - - // The above code will throw a compile error because T have to be bound to `Copy`. However, it seems not to be necessary logically. - */ - let bound = type_attribute.bound.into_where_predicates_by_generic_parameters( - &ast.generics.params, - &syn::parse2(quote!(::core::marker::Copy)).unwrap(), - ); + let bound = + type_attribute.bound.into_where_predicates_by_generic_parameters_check_types( + &ast.generics.params, + &syn::parse2(quote!(::core::marker::Copy)).unwrap(), + &field_types, + &[quote! {::core::clone::Clone}], + ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/debug/common.rs b/src/trait_handlers/debug/common.rs index 142eb10..e2c59a9 100644 --- a/src/trait_handlers/debug/common.rs +++ b/src/trait_handlers/debug/common.rs @@ -1,9 +1,5 @@ -use std::collections::HashSet; - use quote::quote; -use syn::{punctuated::Punctuated, token::Comma, GenericParam, Path, Type}; - -use crate::common::r#type::{dereference, find_idents_in_type}; +use syn::{DeriveInput, Path, Type}; #[inline] pub(crate) fn create_debug_map_builder() -> proc_macro2::TokenStream { @@ -24,41 +20,36 @@ pub(crate) fn create_debug_map_builder() -> proc_macro2::TokenStream { #[inline] pub(crate) fn create_format_arg( - params: &Punctuated, - ty: &Type, + ast: &DeriveInput, + field_ty: &Type, format_method: &Path, - field: proc_macro2::TokenStream, + field_expr: proc_macro2::TokenStream, ) -> proc_macro2::TokenStream { - let ty = dereference(ty); - - let mut idents = HashSet::new(); - find_idents_in_type(&mut idents, ty, Some((true, true, false))); - - // simply support one level generics (without considering bounds that use other generics) - let mut filtered_params: Punctuated = Punctuated::new(); + let ty_ident = &ast.ident; - for param in params.iter() { - if let GenericParam::Type(ty) = param { - let ident = &ty.ident; - - if idents.contains(ident) { - filtered_params.push(param.clone()); - } - } - } + // We use the complete original generics, not filtered by field, + // and include a PhantomData in our wrapper struct to use the generics. + // + // This avoids having to try to calculate the right *subset* of the generics + // relevant for this field, which is nontrivial and maybe impossible. + let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); quote!( let arg = { - struct MyDebug<'a, #filtered_params>(&'a #ty); + #[allow(non_camel_case_types)] // We're using __ to help avoid clashes. + struct Educe__DebugField(V, ::core::marker::PhantomData); - impl<'a, #filtered_params> ::core::fmt::Debug for MyDebug<'a, #filtered_params> { + impl #impl_generics ::core::fmt::Debug + for Educe__DebugField<&#field_ty, #ty_ident #ty_generics> + #where_clause + { #[inline] - fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { - #format_method(self.0, f) + fn fmt(&self, educe__f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { + #format_method(self.0, educe__f) } } - MyDebug(#field) + Educe__DebugField(#field_expr, ::core::marker::PhantomData::) }; ) } diff --git a/src/trait_handlers/debug/debug_enum.rs b/src/trait_handlers/debug/debug_enum.rs index f0dd818..6dc3364 100644 --- a/src/trait_handlers/debug/debug_enum.rs +++ b/src/trait_handlers/debug/debug_enum.rs @@ -110,7 +110,7 @@ impl TraitHandler for DebugEnumHandler { if let Some(method) = field_attribute.method { block_token_stream.extend(super::common::create_format_arg( - &ast.generics.params, + ast, ty, &method, quote!(#field_name_var), @@ -162,7 +162,7 @@ impl TraitHandler for DebugEnumHandler { if let Some(method) = field_attribute.method { block_token_stream.extend(super::common::create_format_arg( - &ast.generics.params, + ast, ty, &method, quote!(#field_name_var), @@ -230,7 +230,7 @@ impl TraitHandler for DebugEnumHandler { if let Some(method) = field_attribute.method { block_token_stream.extend(super::common::create_format_arg( - &ast.generics.params, + ast, ty, &method, quote!(#field_name_var), @@ -280,7 +280,7 @@ impl TraitHandler for DebugEnumHandler { if let Some(method) = field_attribute.method { block_token_stream.extend(super::common::create_format_arg( - &ast.generics.params, + ast, ty, &method, quote!(#field_name_var), @@ -336,7 +336,7 @@ impl TraitHandler for DebugEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::fmt::Debug)).unwrap(), &debug_types, - Some((true, false, false)), + &[], ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/debug/debug_struct.rs b/src/trait_handlers/debug/debug_struct.rs index ccf99ce..574a1d9 100644 --- a/src/trait_handlers/debug/debug_struct.rs +++ b/src/trait_handlers/debug/debug_struct.rs @@ -80,7 +80,7 @@ impl TraitHandler for DebugStructHandler { if let Some(method) = field_attribute.method { builder_token_stream.extend(super::common::create_format_arg( - &ast.generics.params, + ast, ty, &method, quote!(&self.#field_name), @@ -129,7 +129,7 @@ impl TraitHandler for DebugStructHandler { if let Some(method) = field_attribute.method { builder_token_stream.extend(super::common::create_format_arg( - &ast.generics.params, + ast, ty, &method, quote!(&self.#field_name), @@ -157,7 +157,7 @@ impl TraitHandler for DebugStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::fmt::Debug)).unwrap(), &debug_types, - Some((true, false, false)), + &[], ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/default/default_enum.rs b/src/trait_handlers/default/default_enum.rs index f1721b3..bac7767 100644 --- a/src/trait_handlers/default/default_enum.rs +++ b/src/trait_handlers/default/default_enum.rs @@ -166,7 +166,7 @@ impl TraitHandler for DefaultEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::default::Default)).unwrap(), &default_types, - Some((false, false, false)), + &[], ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/default/default_struct.rs b/src/trait_handlers/default/default_struct.rs index bf29c26..ece7720 100644 --- a/src/trait_handlers/default/default_struct.rs +++ b/src/trait_handlers/default/default_struct.rs @@ -111,7 +111,7 @@ impl TraitHandler for DefaultStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::default::Default)).unwrap(), &default_types, - Some((false, false, false)), + &[], ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/default/default_union.rs b/src/trait_handlers/default/default_union.rs index 18b3760..371d38e 100644 --- a/src/trait_handlers/default/default_union.rs +++ b/src/trait_handlers/default/default_union.rs @@ -115,7 +115,7 @@ impl TraitHandler for DefaultUnionHandler { &ast.generics.params, &syn::parse2(quote!(::core::default::Default)).unwrap(), &default_types, - Some((false, false, false)), + &[], ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/eq/mod.rs b/src/trait_handlers/eq/mod.rs index 1c3b2a5..ff722e1 100644 --- a/src/trait_handlers/eq/mod.rs +++ b/src/trait_handlers/eq/mod.rs @@ -29,11 +29,14 @@ impl TraitHandler for EqHandler { } .build_from_eq_meta(meta)?; + let mut field_types = vec![]; + // if `contains_partial_eq` is true, the implementation is handled by the `PartialEq` attribute, and field attributes is also handled by the `PartialEq` attribute if !contains_partial_eq { match &ast.data { Data::Struct(data) => { for field in data.fields.iter() { + field_types.push(&field.ty); let _ = FieldAttributeBuilder.build_from_attributes(&field.attrs, traits)?; } @@ -46,6 +49,7 @@ impl TraitHandler for EqHandler { .build_from_attributes(&variant.attrs, traits)?; for field in variant.fields.iter() { + field_types.push(&field.ty); let _ = FieldAttributeBuilder .build_from_attributes(&field.attrs, traits)?; } @@ -53,6 +57,7 @@ impl TraitHandler for EqHandler { }, Data::Union(data) => { for field in data.fields.named.iter() { + field_types.push(&field.ty); let _ = FieldAttributeBuilder.build_from_attributes(&field.attrs, traits)?; } @@ -73,10 +78,13 @@ impl TraitHandler for EqHandler { // The above code will throw a compile error because T have to be bound to `PartialEq`. However, it seems not to be necessary logically. */ - let bound = type_attribute.bound.into_where_predicates_by_generic_parameters( - &ast.generics.params, - &syn::parse2(quote!(::core::cmp::PartialEq)).unwrap(), - ); + let bound = + type_attribute.bound.into_where_predicates_by_generic_parameters_check_types( + &ast.generics.params, + &syn::parse2(quote!(::core::cmp::PartialEq)).unwrap(), + &field_types, + &[quote! {::core::cmp::PartialEq}], + ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/hash/hash_enum.rs b/src/trait_handlers/hash/hash_enum.rs index 0107056..06b04dc 100644 --- a/src/trait_handlers/hash/hash_enum.rs +++ b/src/trait_handlers/hash/hash_enum.rs @@ -143,7 +143,7 @@ impl TraitHandler for HashEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::hash::Hash)).unwrap(), &hash_types, - Some((true, false, false)), + &[], ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/hash/hash_struct.rs b/src/trait_handlers/hash/hash_struct.rs index d6c30a6..bafb6c5 100644 --- a/src/trait_handlers/hash/hash_struct.rs +++ b/src/trait_handlers/hash/hash_struct.rs @@ -62,7 +62,7 @@ impl TraitHandler for HashStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::hash::Hash)).unwrap(), &hash_types, - Some((true, false, false)), + &[], ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/into/into_enum.rs b/src/trait_handlers/into/into_enum.rs index 46d72dd..12fcd19 100644 --- a/src/trait_handlers/into/into_enum.rs +++ b/src/trait_handlers/into/into_enum.rs @@ -192,7 +192,7 @@ impl TraitHandlerMultiple for IntoEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::convert::Into<#target_ty>)).unwrap(), &into_types, - None, + &[], ); // clone generics in order to not to affect other Into implementations diff --git a/src/trait_handlers/into/into_struct.rs b/src/trait_handlers/into/into_struct.rs index bd9eb36..cac39c8 100644 --- a/src/trait_handlers/into/into_struct.rs +++ b/src/trait_handlers/into/into_struct.rs @@ -138,7 +138,7 @@ impl TraitHandlerMultiple for IntoStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::convert::Into<#target_ty>)).unwrap(), &into_types, - None, + &[], ); // clone generics in order to not to affect other Into implementations diff --git a/src/trait_handlers/ord/mod.rs b/src/trait_handlers/ord/mod.rs index 0913507..2b37631 100644 --- a/src/trait_handlers/ord/mod.rs +++ b/src/trait_handlers/ord/mod.rs @@ -3,6 +3,7 @@ mod ord_enum; mod ord_struct; mod panic; +use quote::quote; use syn::{Data, DeriveInput, Meta}; use super::TraitHandler; @@ -31,3 +32,18 @@ impl TraitHandler for OrdHandler { } } } + +fn supertraits(#[allow(unused_variables)] traits: &[Trait]) -> Vec { + let mut supertraits = vec![]; + supertraits.push(quote! {::core::cmp::Eq}); + + // We mustn't add the PartialOrd bound to the educed PartialOrd impl. + // When we're educing PartialOrd we can leave it off the Ord impl too, + // since we *know* Self is going to be PartialOrd. + #[cfg(feature = "PartialOrd")] + if !traits.contains(&Trait::PartialOrd) { + supertraits.push(quote! {::core::cmp::PartialOrd}); + }; + + supertraits +} diff --git a/src/trait_handlers/ord/ord_enum.rs b/src/trait_handlers/ord/ord_enum.rs index 2ea3a1e..e9461b9 100644 --- a/src/trait_handlers/ord/ord_enum.rs +++ b/src/trait_handlers/ord/ord_enum.rs @@ -245,7 +245,7 @@ impl TraitHandler for OrdEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::cmp::Ord)).unwrap(), &ord_types, - Some((true, false, false)), + &crate::trait_handlers::ord::supertraits(traits), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/ord/ord_struct.rs b/src/trait_handlers/ord/ord_struct.rs index ecfaf47..64d6630 100644 --- a/src/trait_handlers/ord/ord_struct.rs +++ b/src/trait_handlers/ord/ord_struct.rs @@ -83,7 +83,7 @@ impl TraitHandler for OrdStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::cmp::Ord)).unwrap(), &ord_types, - Some((true, false, false)), + &crate::trait_handlers::ord::supertraits(traits), ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/partial_eq/partial_eq_enum.rs b/src/trait_handlers/partial_eq/partial_eq_enum.rs index 8ef8c6f..905c776 100644 --- a/src/trait_handlers/partial_eq/partial_eq_enum.rs +++ b/src/trait_handlers/partial_eq/partial_eq_enum.rs @@ -182,7 +182,7 @@ impl TraitHandler for PartialEqEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::cmp::PartialEq)).unwrap(), &partial_eq_types, - Some((true, false, false)), + &[], ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/partial_eq/partial_eq_struct.rs b/src/trait_handlers/partial_eq/partial_eq_struct.rs index 318e71b..0fd6195 100644 --- a/src/trait_handlers/partial_eq/partial_eq_struct.rs +++ b/src/trait_handlers/partial_eq/partial_eq_struct.rs @@ -67,7 +67,7 @@ impl TraitHandler for PartialEqStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::cmp::PartialEq)).unwrap(), &partial_eq_types, - Some((true, false, false)), + &[], ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/partial_ord/partial_ord_enum.rs b/src/trait_handlers/partial_ord/partial_ord_enum.rs index e0d9432..3e22a87 100644 --- a/src/trait_handlers/partial_ord/partial_ord_enum.rs +++ b/src/trait_handlers/partial_ord/partial_ord_enum.rs @@ -250,7 +250,7 @@ impl TraitHandler for PartialOrdEnumHandler { &ast.generics.params, &syn::parse2(quote!(::core::cmp::PartialOrd)).unwrap(), &partial_ord_types, - Some((true, false, false)), + &[quote! {::core::cmp::PartialEq}], ); let where_clause = ast.generics.make_where_clause(); diff --git a/src/trait_handlers/partial_ord/partial_ord_struct.rs b/src/trait_handlers/partial_ord/partial_ord_struct.rs index da158d7..2fa9e60 100644 --- a/src/trait_handlers/partial_ord/partial_ord_struct.rs +++ b/src/trait_handlers/partial_ord/partial_ord_struct.rs @@ -85,7 +85,7 @@ impl TraitHandler for PartialOrdStructHandler { &ast.generics.params, &syn::parse2(quote!(::core::cmp::PartialOrd)).unwrap(), &partial_ord_types, - Some((true, false, false)), + &[quote! {::core::cmp::PartialEq}], ); let where_clause = ast.generics.make_where_clause(); diff --git a/tests/copy_clone_struct.rs b/tests/copy_clone_struct.rs index abaec67..6853e9c 100644 --- a/tests/copy_clone_struct.rs +++ b/tests/copy_clone_struct.rs @@ -2,6 +2,8 @@ #![no_std] #![allow(clippy::clone_on_copy)] +use core::marker::PhantomData; + use educe::Educe; #[test] @@ -111,3 +113,72 @@ fn bound_3() { assert_eq!(1, s.f1); assert_eq!(1, t.0); } + +#[test] +fn bound_4() { + #[derive(Educe)] + #[educe(Copy, Clone)] + struct Struct { + f1: Option, + f2: PhantomData, + } + + #[derive(Educe)] + #[educe(Copy, Clone)] + struct Tuple(Option, PhantomData); + + let s = Struct { + f1: Some(1), f2: PhantomData:: + } + .clone(); + let t = Tuple(Some(1), PhantomData::).clone(); + + assert_eq!(Some(1), s.f1); + assert_eq!(Some(1), t.0); +} + +#[test] +fn bound_5() { + trait Suitable {} + struct SuitableNotClone; + impl Suitable for SuitableNotClone {} + let phantom = PhantomData::; + + fn copy(t: &T) -> T { + *t + } + + #[derive(Educe)] + #[educe(Copy)] + struct Struct { + f1: Option, + f2: PhantomData, + } + + impl Clone for Struct { + fn clone(&self) -> Self { + Struct { + f1: self.f1.clone(), f2: PhantomData + } + } + } + + #[derive(Educe)] + #[educe(Copy)] + struct Tuple(Option, PhantomData); + + impl Clone for Tuple { + fn clone(&self) -> Self { + Tuple(self.0.clone(), PhantomData) + } + } + + let s = copy(&Struct { + f1: Some(1), f2: phantom + }); + + let t = copy(&Tuple(Some(1), phantom)); + + assert_eq!(Some(1), s.f1); + assert_eq!(Some(1), t.0); +} diff --git a/tests/debug_struct.rs b/tests/debug_struct.rs index 23fceff..4cd8573 100644 --- a/tests/debug_struct.rs +++ b/tests/debug_struct.rs @@ -4,6 +4,11 @@ #[macro_use] extern crate alloc; +use core::{ + fmt::{self, Debug, Display}, + marker::PhantomData, +}; + use educe::Educe; #[test] @@ -532,3 +537,57 @@ fn bound_3() { assert_eq!("Tuple(1)", format!("{:?}", Tuple(1))); } + +#[test] +fn bound_4() { + use core::cell::RefCell; + + #[derive(Educe)] + #[educe(Debug)] + struct Struct { + f1: RefCell, + } + + assert_eq!( + "Struct { f1: RefCell { value: 1 } }", + format!("{:?}", Struct { + f1: RefCell::new(1) + }) + ); + + #[derive(Educe)] + #[educe(Debug(bound(T: core::fmt::Debug)))] + struct Tuple(RefCell); + + assert_eq!("Tuple(RefCell { value: 1 })", format!("{:?}", Tuple(RefCell::new(1)))); +} + +#[test] +fn bound_5() { + struct DebugAsDisplay(T); + + struct NotDebug; + + impl Debug for DebugAsDisplay { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + Display::fmt(&self.0, f) + } + } + + #[derive(Educe)] + #[educe(Debug)] + struct Struct { + f1: Option, + f2: DebugAsDisplay, + f3: PhantomData, + } + + assert_eq!( + "Struct { f1: Some(1), f2: lit, f3: PhantomData }", + format!("{:?}", Struct { + f1: Some(1), + f2: DebugAsDisplay("lit"), + f3: PhantomData::, + }) + ); +} diff --git a/tests/eq_struct.rs b/tests/eq_struct.rs index 6730fbc..e8d96a4 100644 --- a/tests/eq_struct.rs +++ b/tests/eq_struct.rs @@ -1,6 +1,8 @@ #![cfg(all(feature = "Eq", feature = "PartialEq"))] #![no_std] +use core::marker::PhantomData; + use educe::Educe; #[test] @@ -333,6 +335,56 @@ fn bound_3() { assert!(Tuple(1) != Tuple(2)); } +#[test] +fn bound_4() { + trait Suitable {} + struct SuitableNotEq; + impl Suitable for SuitableNotEq {} + let phantom = PhantomData::; + + #[derive(Educe)] + #[educe(Eq)] + struct Struct { + f1: T, + // PhantomData is Eq (all PhantomData are equal to all others) + f2: PhantomData, + } + + impl PartialEq for Struct { + fn eq(&self, other: &Struct) -> bool { + self.f1.eq(&other.f1) + } + } + + #[derive(Educe)] + #[educe(Eq)] + struct Tuple(T, PhantomData); + + impl PartialEq for Tuple { + fn eq(&self, other: &Tuple) -> bool { + self.0.eq(&other.0) + } + } + assert!( + Struct { + f1: 1, f2: phantom + } == Struct { + f1: 1, f2: phantom + } + ); + + assert!( + Struct { + f1: 1, f2: phantom + } != Struct { + f1: 2, f2: phantom + } + ); + + assert!(Tuple(1, phantom) == Tuple(1, phantom)); + assert!(Tuple(1, phantom) != Tuple(2, phantom)); +} + #[allow(dead_code)] #[test] fn use_partial_eq_attr_ignore() { diff --git a/tests/ord_struct.rs b/tests/ord_struct.rs index efe3046..aeedfcf 100644 --- a/tests/ord_struct.rs +++ b/tests/ord_struct.rs @@ -1,7 +1,7 @@ #![cfg(all(feature = "Ord", feature = "PartialOrd"))] #![no_std] -use core::cmp::Ordering; +use core::{cmp::Ordering, marker::PhantomData}; use educe::Educe; @@ -528,6 +528,67 @@ fn bound_3() { assert!(Tuple(1) < Tuple(2)); } +#[test] +fn bound_4() { + trait Suitable {} + struct SuitableNotEq; + impl Suitable for SuitableNotEq {} + let phantom = PhantomData::; + + #[derive(Educe)] + #[educe(Ord, PartialEq)] + struct Struct { + f1: T, + // PhantomData is Eq (all PhantomData are equal to all others) + f2: PhantomData, + } + + impl Eq for Struct {} + impl PartialOrd for Struct { + fn partial_cmp(&self, other: &Struct) -> Option { + self.f1.partial_cmp(&other.f1) + } + } + + #[derive(Educe)] + #[educe(Ord, PartialEq)] + struct Tuple(T, PhantomData); + + impl Eq for Tuple {} + impl PartialOrd for Tuple { + fn partial_cmp(&self, other: &Tuple) -> Option { + self.0.partial_cmp(&other.0) + } + } + + assert_eq!( + Ord::cmp( + &Struct { + f1: 1, f2: phantom + }, + &Struct { + f1: 1, f2: phantom + } + ), + Ordering::Equal + ); + + assert_eq!( + Ord::cmp( + &Struct { + f1: 1, f2: phantom + }, + &Struct { + f1: 2, f2: phantom + } + ), + Ordering::Less + ); + + assert_eq!(Ord::cmp(&Tuple(1, phantom), &Tuple(1, phantom)), Ordering::Equal); + assert_eq!(Ord::cmp(&Tuple(1, phantom), &Tuple(2, phantom)), Ordering::Less); +} + #[test] fn use_partial_ord_attr_ignore() { #[derive(PartialEq, Eq, Educe)] diff --git a/tests/partial_ord_struct.rs b/tests/partial_ord_struct.rs index bb84517..b84a862 100644 --- a/tests/partial_ord_struct.rs +++ b/tests/partial_ord_struct.rs @@ -1,7 +1,7 @@ #![cfg(feature = "PartialOrd")] #![no_std] -use core::cmp::Ordering; +use core::{cmp::Ordering, marker::PhantomData}; use educe::Educe; @@ -527,3 +527,53 @@ fn bound_3() { assert!(Tuple(2) > Tuple(1)); assert!(Tuple(1) < Tuple(2)); } + +#[test] +fn bound_4() { + trait Suitable {} + struct SuitableNotEq; + impl Suitable for SuitableNotEq {} + let phantom = PhantomData::; + + #[derive(Educe)] + #[educe(PartialOrd)] + struct Struct { + f1: T, + // PhantomData is Eq (all PhantomData are equal to all others) + f2: PhantomData, + } + + impl PartialEq for Struct { + fn eq(&self, other: &Struct) -> bool { + self.f1.eq(&other.f1) + } + } + + #[derive(Educe)] + #[educe(Eq)] + struct Tuple(T, PhantomData); + + impl PartialEq for Tuple { + fn eq(&self, other: &Tuple) -> bool { + self.0.eq(&other.0) + } + } + assert!( + Struct { + f1: 1, f2: phantom + } == Struct { + f1: 1, f2: phantom + } + ); + + assert!( + Struct { + f1: 1, f2: phantom + } != Struct { + f1: 2, f2: phantom + } + ); + + assert!(Tuple(1, phantom) == Tuple(1, phantom)); + assert!(Tuple(1, phantom) != Tuple(2, phantom)); +}