diff --git a/impl/src/fmt/display.rs b/impl/src/fmt/display.rs index 0610a145..cb664cdd 100644 --- a/impl/src/fmt/display.rs +++ b/impl/src/fmt/display.rs @@ -32,7 +32,17 @@ pub fn expand(input: &syn::DeriveInput, trait_name: &str) -> syn::Result = input + .generics + .params + .iter() + .filter_map(|p| match p { + syn::GenericParam::Type(t) => Some(&t.ident), + syn::GenericParam::Const(..) | syn::GenericParam::Lifetime(..) => None, + }) + .collect(); + + let ctx: ExpansionCtx = (&attrs, &type_params, ident, &trait_ident, &attr_name); let (bounds, body) = match &input.data { syn::Data::Struct(s) => expand_struct(s, ctx), syn::Data::Enum(e) => expand_enum(e, ctx), @@ -62,6 +72,7 @@ pub fn expand(input: &syn::DeriveInput, trait_name: &str) -> syn::Result syn::Result = ( &'a ContainerAttributes, + &'a [&'a syn::Ident], &'a syn::Ident, &'a syn::Ident, &'a syn::Ident, @@ -77,12 +89,13 @@ type ExpansionCtx<'a> = ( /// Expands a [`fmt::Display`]-like derive macro for the provided struct. fn expand_struct( s: &syn::DataStruct, - (attrs, ident, trait_ident, _): ExpansionCtx<'_>, + (attrs, type_params, ident, trait_ident, _): ExpansionCtx<'_>, ) -> syn::Result<(Vec, TokenStream)> { let s = Expansion { shared_attr: None, attrs, fields: &s.fields, + type_params, trait_ident, ident, }; @@ -111,7 +124,7 @@ fn expand_struct( /// Expands a [`fmt`]-like derive macro for the provided enum. fn expand_enum( e: &syn::DataEnum, - (container_attrs, _, trait_ident, attr_name): ExpansionCtx<'_>, + (container_attrs, type_params, _, trait_ident, attr_name): ExpansionCtx<'_>, ) -> syn::Result<(Vec, TokenStream)> { if let Some(shared_fmt) = &container_attrs.fmt { if shared_fmt @@ -153,6 +166,7 @@ fn expand_enum( shared_attr: container_attrs.fmt.as_ref(), attrs: &attrs, fields: &variant.fields, + type_params, trait_ident, ident, }; @@ -190,7 +204,7 @@ fn expand_enum( /// Expands a [`fmt::Display`]-like derive macro for the provided union. fn expand_union( u: &syn::DataUnion, - (attrs, _, _, attr_name): ExpansionCtx<'_>, + (attrs, _, _, _, attr_name): ExpansionCtx<'_>, ) -> syn::Result<(Vec, TokenStream)> { let fmt = &attrs.fmt.as_ref().ok_or_else(|| { syn::Error::new( @@ -227,6 +241,9 @@ struct Expansion<'a> { /// Struct or enum [`syn::Fields`]. fields: &'a syn::Fields, + /// Type parameters in this struct or enum. + type_params: &'a [&'a syn::Ident], + /// [`fmt`] trait [`syn::Ident`]. /// /// [`syn::Ident`]: struct@syn::Ident @@ -343,34 +360,150 @@ impl<'a> Expansion<'a> { if let Some(fmt) = &self.attrs.fmt { bounds.extend( fmt.bounded_types(self.fields) - .map(|(ty, trait_name)| { + .filter_map(|(ty, trait_name)| { + if !self.contains_generic_param(ty) { + return None; + } let trait_ident = format_ident!("{trait_name}"); - parse_quote! { #ty: derive_more::core::fmt::#trait_ident } + Some(parse_quote! { #ty: derive_more::core::fmt::#trait_ident }) }) .chain(self.attrs.bounds.0.clone()), ); } else { - bounds.extend(self.fields.iter().next().map(|f| { - let ty = &f.ty; - let trait_ident = &self.trait_ident; - parse_quote! { #ty: derive_more::core::fmt::#trait_ident } - })) + bounds.extend( + self.fields + .iter() + .next() + .map(|f| { + let ty = &f.ty; + if !self.contains_generic_param(ty) { + return vec![]; + } + let trait_ident = &self.trait_ident; + vec![parse_quote! { #ty: derive_more::core::fmt::#trait_ident }] + }) + .unwrap_or_default(), + ); }; } if let Some(shared_fmt) = &self.shared_attr { - bounds.extend(shared_fmt.bounded_types(self.fields).map( + bounds.extend(shared_fmt.bounded_types(self.fields).filter_map( |(ty, trait_name)| { + if !self.contains_generic_param(ty) { + return None; + } let trait_ident = format_ident!("{trait_name}"); - parse_quote! { #ty: derive_more::core::fmt::#trait_ident } + Some(parse_quote! { #ty: derive_more::core::fmt::#trait_ident }) }, )); } bounds } + + /// Checks whether the provided [`syn::Path`] contains any of these [`Expansion::type_params`]. + fn path_contains_generic_param(&self, path: &syn::Path) -> bool { + path.segments + .iter() + .any(|segment| match &segment.arguments { + syn::PathArguments::None => false, + syn::PathArguments::AngleBracketed( + syn::AngleBracketedGenericArguments { args, .. }, + ) => args.iter().any(|generic| match generic { + syn::GenericArgument::Type(ty) + | syn::GenericArgument::AssocType(syn::AssocType { ty, .. }) => { + self.contains_generic_param(ty) + } + + syn::GenericArgument::Lifetime(_) + | syn::GenericArgument::Const(_) + | syn::GenericArgument::AssocConst(_) + | syn::GenericArgument::Constraint(_) => false, + _ => unimplemented!( + "syntax is not supported by `derive_more`, please report a bug", + ), + }), + syn::PathArguments::Parenthesized( + syn::ParenthesizedGenericArguments { inputs, output, .. }, + ) => { + inputs.iter().any(|ty| self.contains_generic_param(ty)) + || match output { + syn::ReturnType::Default => false, + syn::ReturnType::Type(_, ty) => { + self.contains_generic_param(ty) + } + } + } + }) + } + + /// Checks whether the provided [`syn::Type`] contains any of these [`Expansion::type_params`]. + fn contains_generic_param(&self, ty: &syn::Type) -> bool { + if self.type_params.is_empty() { + return false; + } + match ty { + syn::Type::Path(syn::TypePath { qself, path }) => { + if let Some(qself) = qself { + if self.contains_generic_param(&qself.ty) { + return true; + } + } + + if let Some(ident) = path.get_ident() { + self.type_params.iter().any(|param| *param == ident) + } else { + self.path_contains_generic_param(path) + } + } + + syn::Type::Array(syn::TypeArray { elem, .. }) + | syn::Type::Group(syn::TypeGroup { elem, .. }) + | syn::Type::Paren(syn::TypeParen { elem, .. }) + | syn::Type::Ptr(syn::TypePtr { elem, .. }) + | syn::Type::Reference(syn::TypeReference { elem, .. }) + | syn::Type::Slice(syn::TypeSlice { elem, .. }) => { + self.contains_generic_param(elem) + } + + syn::Type::BareFn(syn::TypeBareFn { inputs, output, .. }) => { + inputs + .iter() + .any(|arg| self.contains_generic_param(&arg.ty)) + || match output { + syn::ReturnType::Default => false, + syn::ReturnType::Type(_, ty) => self.contains_generic_param(ty), + } + } + syn::Type::Tuple(syn::TypeTuple { elems, .. }) => { + elems.iter().any(|ty| self.contains_generic_param(ty)) + } + + syn::Type::ImplTrait(_) => false, + syn::Type::Infer(_) => false, + syn::Type::Macro(_) => false, + syn::Type::Never(_) => false, + syn::Type::TraitObject(syn::TypeTraitObject { bounds, .. }) => { + bounds.iter().any(|bound| match bound { + syn::TypeParamBound::Trait(syn::TraitBound { path, .. }) => { + self.path_contains_generic_param(path) + } + syn::TypeParamBound::Lifetime(_) => false, + syn::TypeParamBound::Verbatim(_) => false, + _ => unimplemented!( + "syntax is not supported by `derive_more`, please report a bug", + ), + }) + } + syn::Type::Verbatim(_) => false, + _ => unimplemented!( + "syntax is not supported by `derive_more`, please report a bug", + ), + } + } } /// Matches the provided derive macro `name` to appropriate actual trait name. diff --git a/tests/display.rs b/tests/display.rs index bd0f45ae..493206e5 100644 --- a/tests/display.rs +++ b/tests/display.rs @@ -2279,3 +2279,184 @@ mod generic { } } } + +// See: https://github.com/JelteF/derive_more/issues/363 +mod type_variables { + mod our_alloc { + #[cfg(not(feature = "std"))] + pub use alloc::{boxed::Box, format, vec::Vec}; + #[cfg(feature = "std")] + pub use std::{boxed::Box, format, vec::Vec}; + } + + use our_alloc::{format, Box}; + + // We want Vec in scope to test that code generation works if it is + #[allow(unused_imports)] + use our_alloc::Vec; + + use derive_more::Display; + + #[derive(Display, Debug)] + #[display("{inner:?}")] + #[display(bounds(T: Display))] + struct OptionalBox { + inner: Option>, + } + + #[derive(Display, Debug)] + #[display("{next}")] + struct ItemStruct { + next: OptionalBox, + } + + #[derive(Display)] + #[derive(Debug)] + struct ItemTuple(OptionalBox); + + #[derive(Display)] + #[derive(Debug)] + #[display("Item({_0})")] + struct ItemTupleContainerFmt(OptionalBox); + + #[derive(Display, Debug)] + #[display("{next}")] + enum ItemEnumOuterFormat { + Variant1 { + next: OptionalBox, + }, + Variant2 { + next: OptionalBox, + }, + } + + #[derive(Display, Debug)] + enum ItemEnumInnerFormat { + #[display("{next} {inner}")] + Node { + next: OptionalBox, + inner: i32, + }, + #[display("{inner}")] + Leaf { inner: i32 }, + } + + #[derive(Display)] + #[derive(Debug)] + #[display("{next:?}, {real:?}")] + struct VecMeansDifferent { + next: our_alloc::Vec, + real: Vec, + } + + #[derive(Display)] + #[derive(Debug)] + #[display("{t:?}")] + struct Array { + t: [T; 10], + } + + mod parens { + #![allow(unused_parens)] // test that type is found even in parentheses + + use derive_more::Display; + + #[derive(Display)] + struct Paren { + t: (T), + } + } + + #[derive(Display)] + struct ParenthesizedGenericArgumentsInput { + t: dyn Fn(T) -> i32, + } + + #[derive(Display)] + struct ParenthesizedGenericArgumentsOutput { + t: dyn Fn(i32) -> T, + } + + #[derive(Display)] + struct Ptr { + t: *const T, + } + + #[derive(Display)] + struct Reference<'a, T> { + t: &'a T, + } + + #[derive(Display)] + struct Slice<'a, T> { + t: &'a [T], + } + + #[derive(Display)] + struct BareFn { + t: Box T>, + } + + #[derive(Display)] + struct Tuple { + t: Box<(T, T)>, + } + + trait MyTrait {} + + #[derive(Display)] + struct TraitObject { + t: Box>, + } + + #[test] + fn assert() { + assert_eq!( + format!( + "{}", + ItemStruct { + next: OptionalBox { + inner: Some(Box::new(ItemStruct { + next: OptionalBox { inner: None } + })) + } + }, + ), + "Some(ItemStruct { next: OptionalBox { inner: None } })", + ); + + assert_eq!( + format!( + "{}", + ItemTuple(OptionalBox { + inner: Some(Box::new(ItemTuple(OptionalBox { inner: None }))) + }), + ), + "Some(ItemTuple(OptionalBox { inner: None }))", + ); + + assert_eq!( + format!( + "{}", + ItemTupleContainerFmt(OptionalBox { + inner: Some(Box::new(ItemTupleContainerFmt(OptionalBox { + inner: None + }))) + }), + ), + "Item(Some(ItemTupleContainerFmt(OptionalBox { inner: None })))", + ); + + let item = ItemEnumOuterFormat::Variant1 { + next: OptionalBox { + inner: Some(Box::new(ItemEnumOuterFormat::Variant2 { + next: OptionalBox { inner: None }, + })), + }, + }; + assert_eq!( + format!("{item}"), + "Some(Variant2 { next: OptionalBox { inner: None } })", + ) + } +}