From 49f7eafc2b051853002562c9568fb1d153e263a1 Mon Sep 17 00:00:00 2001 From: Roland Fredenhagen Date: Fri, 1 Sep 2023 01:55:40 +0200 Subject: [PATCH] Add `TryFrom` to convert repr to enum --- CHANGELOG.md | 2 + Cargo.toml | 7 + README.md | 7 +- impl/Cargo.toml | 4 +- impl/doc/try_from.md | 37 +++++ impl/src/lib.rs | 4 + impl/src/parsing.rs | 4 +- impl/src/try_from.rs | 141 ++++++++++++++++++ src/convert.rs | 32 ++++ src/lib.rs | 12 +- tests/compile_fail/try_from/invalid_repr.rs | 7 + .../compile_fail/try_from/invalid_repr.stderr | 11 ++ tests/compile_fail/try_from/struct.rs | 4 + tests/compile_fail/try_from/struct.stderr | 5 + tests/compile_fail/try_from/union.rs | 6 + tests/compile_fail/try_from/union.stderr | 5 + tests/try_from.rs | 80 ++++++++++ 17 files changed, 361 insertions(+), 7 deletions(-) create mode 100644 impl/doc/try_from.md create mode 100644 impl/src/try_from.rs create mode 100644 tests/compile_fail/try_from/invalid_repr.rs create mode 100644 tests/compile_fail/try_from/invalid_repr.stderr create mode 100644 tests/compile_fail/try_from/struct.rs create mode 100644 tests/compile_fail/try_from/struct.stderr create mode 100644 tests/compile_fail/try_from/union.rs create mode 100644 tests/compile_fail/try_from/union.stderr create mode 100644 tests/try_from.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 87fa5252..c00829d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ([#279](https://github.com/JelteF/derive_more/pull/279)) - `derive_more::derive` module exporting only macros, without traits. ([#290](https://github.com/JelteF/derive_more/pull/290)) +- Add `TryFrom` derive for enums to convert from their discriminant. + ([#300](https://github.com/JelteF/derive_more/pull/300)) ### Changed diff --git a/Cargo.toml b/Cargo.toml index b32d0f17..2cedb5ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,6 +65,7 @@ mul_assign = ["derive_more-impl/mul_assign"] mul = ["derive_more-impl/mul"] not = ["derive_more-impl/not"] sum = ["derive_more-impl/sum"] +try_from = ["derive_more-impl/try_from"] try_into = ["derive_more-impl/try_into"] is_variant = ["derive_more-impl/is_variant"] unwrap = ["derive_more-impl/unwrap"] @@ -92,6 +93,7 @@ full = [ "mul_assign", "not", "sum", + "try_from", "try_into", "try_unwrap", "unwrap", @@ -204,6 +206,11 @@ name = "sum" path = "tests/sum.rs" required-features = ["sum"] +[[test]] +name = "try_from" +path = "tests/try_from.rs" +required-features = ["try_from"] + [[test]] name = "try_into" path = "tests/try_into.rs" diff --git a/README.md b/README.md index cec17fde..fc83b6cf 100644 --- a/README.md +++ b/README.md @@ -85,9 +85,10 @@ These are traits that are used to convert automatically between types. 1. [`From`] 2. [`Into`] 3. [`FromStr`] -4. [`TryInto`] -5. [`IntoIterator`] -6. [`AsRef`], [`AsMut`] +4. [`TryFrom`] +5. [`TryInto`] +6. [`IntoIterator`] +7. [`AsRef`], [`AsMut`] ### Formatting traits diff --git a/impl/Cargo.toml b/impl/Cargo.toml index 304a37ce..62307f7d 100644 --- a/impl/Cargo.toml +++ b/impl/Cargo.toml @@ -10,7 +10,6 @@ repository = "https://github.com/JelteF/derive_more" documentation = "https://docs.rs/derive_more" # explicitly no keywords or categories so it cannot be found easily - include = [ "src/**/*.rs", "doc/**/*.md", @@ -35,6 +34,7 @@ rustc_version = { version = "0.4", optional = true } [dev-dependencies] derive_more = { path = "..", features = ["full"] } itertools = "0.11.0" +rustversion = "1.0" [badges] github = { repository = "JelteF/derive_more", workflow = "CI" } @@ -66,6 +66,7 @@ mul = ["syn/extra-traits"] mul_assign = ["syn/extra-traits"] not = ["syn/extra-traits"] sum = [] +try_from = [] try_into = ["syn/extra-traits"] try_unwrap = ["dep:convert_case"] unwrap = ["dep:convert_case"] @@ -91,6 +92,7 @@ full = [ "mul_assign", "not", "sum", + "try_from", "try_into", "try_unwrap", "unwrap", diff --git a/impl/doc/try_from.md b/impl/doc/try_from.md new file mode 100644 index 00000000..0a6d2990 --- /dev/null +++ b/impl/doc/try_from.md @@ -0,0 +1,37 @@ +# What `#[derive(TryFrom)]` generates + +This derive allows you to convert enum discriminants into their corresponding variants. +By default a `TryFrom` is generated, matching the [type of the discriminant](https://doc.rust-lang.org/reference/items/enumerations.html#discriminants). +The type can be changed with a `#[repr(u/i*)]` attribute, e.g., `#[repr(u8)]` or `#[repr(i32)]`. +Only field-less variants can be constructed from their variant, therefor the `TryFrom` implementation will return an error for a discriminant representing a variant with fields. + +## Example usage + +```rust +# #[rustversion::since(1.66)] +# mod discriminant_on_non_unit_enum { +# use derive_more::TryFrom; +#[derive(TryFrom, Debug, PartialEq)] +#[repr(u32)] +enum Enum { + Implicit, + Explicit = 5, + Field(usize), + Empty{}, +} + +# #[rustversion::since(1.66)] +# pub fn test(){ +assert_eq!(Enum::Implicit, Enum::try_from(0).unwrap()); +assert_eq!(Enum::Explicit, Enum::try_from(5).unwrap()); +assert_eq!(Enum::Empty{}, Enum::try_from(7).unwrap()); + +// variants with fields are not supported +assert!(Enum::try_from(6).is_err()); +# } +# } +# // We need to use a `function` declaration, because we cannot put `rustversion` on a statement. +# #[rustversion::since(1.66)] use discriminant_on_non_unit_enum::test; +# #[rustversion::before(1.66)] fn test() {} +# test(); +``` diff --git a/impl/src/lib.rs b/impl/src/lib.rs index 62c32700..a229c848 100644 --- a/impl/src/lib.rs +++ b/impl/src/lib.rs @@ -64,6 +64,8 @@ mod not_like; pub(crate) mod parsing; #[cfg(feature = "sum")] mod sum_like; +#[cfg(feature = "try_from")] +mod try_from; #[cfg(feature = "try_into")] mod try_into; #[cfg(feature = "try_unwrap")] @@ -265,6 +267,8 @@ create_derive!("not", not_like, Neg, neg_derive); create_derive!("sum", sum_like, Sum, sum_derive); create_derive!("sum", sum_like, Product, product_derive); +create_derive!("try_from", try_from, TryFrom, try_from_derive, try_from); + create_derive!("try_into", try_into, TryInto, try_into_derive, try_into); create_derive!( diff --git a/impl/src/parsing.rs b/impl/src/parsing.rs index 42b66376..dfbd21e0 100644 --- a/impl/src/parsing.rs +++ b/impl/src/parsing.rs @@ -245,8 +245,8 @@ pub fn seq( move |c| { parsers .iter_mut() - .fold(Some((TokenStream::new(), c)), |out, parser| { - let (mut out, mut c) = out?; + .try_fold((TokenStream::new(), c), |out, parser| { + let (mut out, mut c) = out; let (stream, cursor) = parser(c)?; out.extend(stream); c = cursor; diff --git a/impl/src/try_from.rs b/impl/src/try_from.rs new file mode 100644 index 00000000..28656abe --- /dev/null +++ b/impl/src/try_from.rs @@ -0,0 +1,141 @@ +//! Implementation of a [`TryFrom`] derive macro. + +use proc_macro2::{Literal, Span, TokenStream}; +use quote::{format_ident, quote, ToTokens as _}; +use syn::{spanned::Spanned as _, Ident, Variant}; + +/// Expands a [`TryFrom`] derive macro. +pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { + match &input.data { + syn::Data::Struct(data) => Err(syn::Error::new( + data.struct_token.span(), + "`TryFrom` cannot be derived for structs", + )), + syn::Data::Enum(data) => Expansion { + repr: ReprAttribute::parse_attrs(&input.attrs)?, + ident: &input.ident, + variants: data.variants.iter().collect(), + generics: &input.generics, + } + .expand(), + syn::Data::Union(data) => Err(syn::Error::new( + data.union_token.span(), + "`TryFrom` cannot be derived for unions", + )), + } +} + +/// Representation of a [`Repr`] derive macro struct container attribute. +/// +/// ```rust,ignore +/// #[repr()] +/// ``` +struct ReprAttribute(Ident); + +impl ReprAttribute { + /// Parses a [`StructAttribute`] from the provided [`syn::Attribute`]s. + fn parse_attrs(attrs: impl AsRef<[syn::Attribute]>) -> syn::Result { + attrs + .as_ref() + .iter() + .filter(|attr| attr.path().is_ident("repr")) + .try_fold(None, |mut repr, attr| { + attr.parse_nested_meta(|meta| { + if let Some(ident) = meta.path.get_ident() { + if let "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | "i8" + | "i16" | "i32" | "i64" | "i128" | "isize" = + ident.to_string().as_str() + { + repr = Some(ident.clone()); + return Ok(()); + } + } + // ignore all other attributes that could have a body e.g. `align` + _ = meta.input.parse::(); + Ok(()) + }) + .map(|_| repr) + }) + // Default discriminant is interpreted as `isize` (https://doc.rust-lang.org/reference/items/enumerations.html#discriminants) + .map(|repr| repr.unwrap_or_else(|| Ident::new("isize", Span::call_site()))) + .map(Self) + } +} + +/// Expansion of a macro for generating [`TryFrom`] implementation of an enum +struct Expansion<'a> { + /// Enum `#[repr(u/i*)]` + repr: ReprAttribute, + /// Enum [`Ident`]. + ident: &'a Ident, + + /// Variant [`Ident`] in case of enum expansion. + variants: Vec<&'a syn::Variant>, + + /// Struct or enum [`syn::Generics`]. + generics: &'a syn::Generics, +} + +impl<'a> Expansion<'a> { + /// Expands [`TryFrom`] implementations for a struct. + fn expand(&self) -> syn::Result { + let ident = self.ident; + let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); + + let repr = &self.repr.0; + + let mut last_discriminant = quote! {0}; + let mut inc = 0usize; + let (consts, (discriminants, variants)): ( + Vec, + (Vec, Vec), + ) = self + .variants + .iter() + .filter_map( + |Variant { + ident, + fields, + discriminant, + .. + }| { + if let Some(discriminant) = discriminant { + last_discriminant = discriminant.1.to_token_stream(); + inc = 0; + } + let ret = { + let inc = Literal::usize_unsuffixed(inc); + fields.is_empty().then_some(( + format_ident!("__DISCRIMINANT_{ident}"), + ( + quote! {#last_discriminant + #inc}, + quote! {#ident #fields}, + ), + )) + }; + inc += 1; + ret + }, + ) + .unzip(); + + Ok(quote! { + #[automatically_derived] + impl #impl_generics + ::core::convert::TryFrom<#repr #ty_generics> for #ident + #where_clause + { + type Error = ::derive_more::TryFromError<#repr>; + + #[inline] + fn try_from(value: #repr) -> ::core::result::Result { + #(#[allow(non_upper_case_globals)] const #consts: #repr = #discriminants;)* + match value { + #(#consts => ::core::result::Result::Ok(#ident::#variants),)* + _ => ::core::result::Result::Err(::derive_more::TryFromError::new(value)), + } + } + } + }) + } +} diff --git a/src/convert.rs b/src/convert.rs index 4885ace2..7b06ed0a 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -45,3 +45,35 @@ impl fmt::Display for TryIntoError { #[cfg(feature = "std")] impl std::error::Error for TryIntoError {} + +/// Error returned by the derived [`TryFrom`] implementation. +/// +/// [`TryFrom`]: macro@crate::TryFrom +#[derive(Clone, Copy, Debug)] +pub struct TryFromError { + /// Original input value which failed to convert via the derived + /// [`TryFrom`] implementation. + /// + /// [`TryFrom`]: macro@crate::TryFrom + pub input: T, +} + +impl TryFromError { + #[doc(hidden)] + #[must_use] + #[inline] + pub const fn new(input: T) -> Self { + Self { input } + } +} + +// `T` should only be an integer type and therefor display +impl fmt::Display for TryFromError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "`{}` does not respond to a unit variant", self.input) + } +} + +#[cfg(feature = "std")] +// `T` should only be an integer type and therefor display and debug +impl std::error::Error for TryFromError {} diff --git a/src/lib.rs b/src/lib.rs index e38053a8..7c687375 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ //! [`From`]: macro@crate::From //! [`Into`]: macro@crate::Into //! [`FromStr`]: macro@crate::FromStr +//! [`TryFrom`]: macro@crate::TryInto //! [`TryInto`]: macro@crate::TryInto //! [`IntoIterator`]: macro@crate::IntoIterator //! [`AsRef`]: macro@crate::AsRef @@ -89,8 +90,11 @@ mod r#str; #[doc(inline)] pub use crate::r#str::FromStrError; -#[cfg(feature = "try_into")] +#[cfg(any(feature = "try_into", feature = "try_from"))] mod convert; +#[cfg(feature = "try_from")] +#[doc(inline)] +pub use crate::convert::TryFromError; #[cfg(feature = "try_into")] #[doc(inline)] pub use crate::convert::TryIntoError; @@ -203,6 +207,8 @@ re_export_traits!("not", not_traits, core::ops, Neg, Not); re_export_traits!("sum", sum_traits, core::iter, Product, Sum); +re_export_traits!("try_from", try_from_traits, core::convert, TryFrom); + re_export_traits!("try_into", try_into_traits, core::convert, TryInto); // Now re-export our own derives by their exact name to overwrite any derives that the trait @@ -271,6 +277,9 @@ pub use derive_more_impl::{Neg, Not}; #[cfg(feature = "sum")] pub use derive_more_impl::{Product, Sum}; +#[cfg(feature = "try_from")] +pub use derive_more_impl::TryFrom; + #[cfg(feature = "try_into")] pub use derive_more_impl::TryInto; @@ -303,6 +312,7 @@ pub use derive_more_impl::Unwrap; feature = "mul_assign", feature = "not", feature = "sum", + feature = "try_from", feature = "try_into", feature = "try_unwrap", feature = "unwrap", diff --git a/tests/compile_fail/try_from/invalid_repr.rs b/tests/compile_fail/try_from/invalid_repr.rs new file mode 100644 index 00000000..1b748d82 --- /dev/null +++ b/tests/compile_fail/try_from/invalid_repr.rs @@ -0,0 +1,7 @@ +#[derive(derive_more::TryFrom)] +#[repr(a + b)] +enum Enum { + Variant +} + +fn main() {} diff --git a/tests/compile_fail/try_from/invalid_repr.stderr b/tests/compile_fail/try_from/invalid_repr.stderr new file mode 100644 index 00000000..90e7dc7f --- /dev/null +++ b/tests/compile_fail/try_from/invalid_repr.stderr @@ -0,0 +1,11 @@ +error: expected `,` + --> tests/compile_fail/try_from/invalid_repr.rs:2:10 + | +2 | #[repr(a + b)] + | ^ + +error: expected one of `(`, `,`, `::`, or `=`, found `+` + --> tests/compile_fail/try_from/invalid_repr.rs:2:10 + | +2 | #[repr(a + b)] + | ^ expected one of `(`, `,`, `::`, or `=` diff --git a/tests/compile_fail/try_from/struct.rs b/tests/compile_fail/try_from/struct.rs new file mode 100644 index 00000000..f3609b19 --- /dev/null +++ b/tests/compile_fail/try_from/struct.rs @@ -0,0 +1,4 @@ +#[derive(derive_more::TryFrom)] +struct Struct; + +fn main() {} diff --git a/tests/compile_fail/try_from/struct.stderr b/tests/compile_fail/try_from/struct.stderr new file mode 100644 index 00000000..47dacfac --- /dev/null +++ b/tests/compile_fail/try_from/struct.stderr @@ -0,0 +1,5 @@ +error: `TryFrom` cannot be derived for structs + --> tests/compile_fail/try_from/struct.rs:2:1 + | +2 | struct Struct; + | ^^^^^^ diff --git a/tests/compile_fail/try_from/union.rs b/tests/compile_fail/try_from/union.rs new file mode 100644 index 00000000..25d66234 --- /dev/null +++ b/tests/compile_fail/try_from/union.rs @@ -0,0 +1,6 @@ +#[derive(derive_more::TryFrom)] +union Union { + field: i32, +} + +fn main() {} diff --git a/tests/compile_fail/try_from/union.stderr b/tests/compile_fail/try_from/union.stderr new file mode 100644 index 00000000..089a99bb --- /dev/null +++ b/tests/compile_fail/try_from/union.stderr @@ -0,0 +1,5 @@ +error: `TryFrom` cannot be derived for unions + --> tests/compile_fail/try_from/union.rs:2:1 + | +2 | union Union { + | ^^^^^ diff --git a/tests/try_from.rs b/tests/try_from.rs new file mode 100644 index 00000000..1891f44b --- /dev/null +++ b/tests/try_from.rs @@ -0,0 +1,80 @@ +#![cfg_attr(not(feature = "std"), no_std)] +#![allow(dead_code)] + +use derive_more::TryFrom; + +#[test] +fn test_with_repr() { + #[derive(TryFrom, Clone, Copy, Debug, Eq, PartialEq)] + #[repr(i16)] + enum Enum { + A, + B = -21, + C, + D, + } + assert_eq!(Enum::A, Enum::try_from(0i16).unwrap()); + assert_eq!(Enum::B, Enum::try_from(-21).unwrap()); + assert_eq!(Enum::C, Enum::try_from(-20).unwrap()); + assert_eq!(Enum::D, Enum::try_from(-19).unwrap()); + assert!(Enum::try_from(-1).is_err()); +} + +#[test] +fn enum_without_repr() { + #[derive(TryFrom, Clone, Copy, Debug, Eq, PartialEq)] + enum Enum { + A, + B = -21, + C, + D, + } + assert_eq!(Enum::A, Enum::try_from(0isize).unwrap()); + assert_eq!(Enum::B, Enum::try_from(-21).unwrap()); + assert_eq!(Enum::C, Enum::try_from(-20).unwrap()); + assert_eq!(Enum::D, Enum::try_from(-19).unwrap()); + assert!(Enum::try_from(-1).is_err()); +} + +#[test] +fn enum_with_complex_repr() { + #[derive(TryFrom, Clone, Copy, Debug, Eq, PartialEq)] + #[repr(align(16), i32)] + enum Enum { + A, + B = -21, + C, + D, + } + assert_eq!(Enum::A, Enum::try_from(0i32).unwrap()); + assert_eq!(Enum::B, Enum::try_from(-21).unwrap()); + assert_eq!(Enum::C, Enum::try_from(-20).unwrap()); + assert_eq!(Enum::D, Enum::try_from(-19).unwrap()); + assert!(Enum::try_from(-1).is_err()); +} + +#[rustversion::since(1.66)] +mod discriminants_on_enum_with_fields { + use super::*; + + #[derive(TryFrom, Clone, Copy, Debug, Eq, PartialEq)] + #[repr(i16)] + enum Enum { + A, + Discriminant = 5, + Field(usize), + Empty {}, + FieldWithDiscriminant(u8, i64) = -14, + EmptyTuple(), + } + + #[test] + fn test_discriminants_on_enum_with_fields() { + assert_eq!(Enum::A, Enum::try_from(0).unwrap()); + assert_eq!(Enum::Discriminant, Enum::try_from(5).unwrap()); + assert!(Enum::try_from(6).is_err()); + assert_eq!(Enum::Empty {}, Enum::try_from(7).unwrap()); + assert!(Enum::try_from(-14).is_err()); + assert_eq!(Enum::EmptyTuple(), Enum::try_from(-13).unwrap()); + } +}