From b614143d6598fdc67d2282cef69286f672c8f167 Mon Sep 17 00:00:00 2001 From: Mac L Date: Mon, 10 Jun 2024 14:27:58 +1000 Subject: [PATCH] Add stable transparent enums --- tree_hash/tests/tests.rs | 27 +++++++++++++++++++++++++-- tree_hash_derive/src/lib.rs | 19 ++++++++++++++++--- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/tree_hash/tests/tests.rs b/tree_hash/tests/tests.rs index 01fd6c5..e94ad6f 100644 --- a/tree_hash/tests/tests.rs +++ b/tree_hash/tests/tests.rs @@ -138,7 +138,7 @@ struct Shape { radius: Option, } -#[derive(TreeHash)] +#[derive(TreeHash, Clone)] #[tree_hash(struct_behaviour = "profile")] #[tree_hash(max_fields = "typenum::U8")] struct Square { @@ -148,7 +148,7 @@ struct Square { color: u8, } -#[derive(TreeHash)] +#[derive(TreeHash, Clone)] #[tree_hash(struct_behaviour = "profile")] #[tree_hash(max_fields = "typenum::U8")] struct Circle { @@ -158,6 +158,13 @@ struct Circle { radius: u16, } +#[derive(TreeHash)] +#[tree_hash(enum_behaviour = "transparent_stable")] +enum ShapeEnum { + SquareVariant(Square), + CircleVariant(Circle), +} + #[test] fn shape_1() { let shape_1 = Shape { @@ -186,3 +193,19 @@ fn shape_2() { assert_eq!(shape_2.tree_hash_root(), circle.tree_hash_root()); } + +#[test] +fn shape_enum() { + let square = Square { side: 16, color: 2 }; + + let circle = Circle { + color: 1, + radius: 14, + }; + + let enum_square = ShapeEnum::SquareVariant(square.clone()); + let enum_circle = ShapeEnum::CircleVariant(circle.clone()); + + assert_eq!(square.tree_hash_root(), enum_square.tree_hash_root()); + assert_eq!(circle.tree_hash_root(), enum_circle.tree_hash_root()); +} diff --git a/tree_hash_derive/src/lib.rs b/tree_hash_derive/src/lib.rs index b4517fe..d12a50e 100644 --- a/tree_hash_derive/src/lib.rs +++ b/tree_hash_derive/src/lib.rs @@ -35,6 +35,7 @@ const STRUCT_PROFILE: &str = "profile"; const STRUCT_VARIANTS: &[&str] = &[STRUCT_CONTAINER, STRUCT_STABLE_CONTAINER, STRUCT_PROFILE]; const ENUM_TRANSPARENT: &str = "transparent"; +const ENUM_TRANSPARENT_STABLE: &str = "transparent_stable"; const ENUM_UNION: &str = "union"; const ENUM_VARIANTS: &[&str] = &[ENUM_TRANSPARENT, ENUM_UNION]; const NO_ENUM_BEHAVIOUR_ERROR: &str = "enums require an \"enum_behaviour\" attribute, \ @@ -62,6 +63,7 @@ impl StructBehaviour { enum EnumBehaviour { Transparent, + TransparentStable, Union, } @@ -69,6 +71,7 @@ impl EnumBehaviour { pub fn new(s: Option) -> Option { s.map(|s| match s.as_ref() { ENUM_TRANSPARENT => EnumBehaviour::Transparent, + ENUM_TRANSPARENT_STABLE => EnumBehaviour::TransparentStable, ENUM_UNION => EnumBehaviour::Union, other => panic!( "{} is an invalid enum_behaviour, use either {:?}", @@ -238,7 +241,16 @@ pub fn tree_hash_derive(input: TokenStream) -> TokenStream { panic!("cannot use \"struct_behaviour\" for an enum"); } match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) { - EnumBehaviour::Transparent => tree_hash_derive_enum_transparent(&item, s), + EnumBehaviour::Transparent => tree_hash_derive_enum_transparent( + &item, + s, + syn::parse_str("Container").unwrap(), + ), + EnumBehaviour::TransparentStable => tree_hash_derive_enum_transparent( + &item, + s, + syn::parse_str("StableContainer").unwrap(), + ), EnumBehaviour::Union => tree_hash_derive_enum_union(&item, s), } } @@ -454,6 +466,7 @@ fn tree_hash_derive_struct_profile( fn tree_hash_derive_enum_transparent( derive_input: &DeriveInput, enum_data: &DataEnum, + inner_container_type: Expr, ) -> TokenStream { let name = &derive_input.ident; let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl(); @@ -486,11 +499,11 @@ fn tree_hash_derive_enum_transparent( #( assert_eq!( #type_exprs, - tree_hash::TreeHashType::Container, + tree_hash::TreeHashType::#inner_container_type, "all variants must be of container type" ); )* - tree_hash::TreeHashType::Container + tree_hash::TreeHashType::#inner_container_type } fn tree_hash_packed_encoding(&self) -> tree_hash::PackedEncoding {