Skip to content

Commit

Permalink
Add stable transparent enums
Browse files Browse the repository at this point in the history
  • Loading branch information
macladson committed Jun 10, 2024
1 parent 3fcd1e6 commit b614143
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
27 changes: 25 additions & 2 deletions tree_hash/tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ struct Shape {
radius: Option<u16>,
}

#[derive(TreeHash)]
#[derive(TreeHash, Clone)]
#[tree_hash(struct_behaviour = "profile")]
#[tree_hash(max_fields = "typenum::U8")]
struct Square {
Expand All @@ -148,7 +148,7 @@ struct Square {
color: u8,
}

#[derive(TreeHash)]
#[derive(TreeHash, Clone)]
#[tree_hash(struct_behaviour = "profile")]
#[tree_hash(max_fields = "typenum::U8")]
struct Circle {
Expand All @@ -158,6 +158,13 @@ struct Circle {
radius: u16,
}

#[derive(TreeHash)]
#[tree_hash(enum_behaviour = "transparent_stable")]
enum ShapeEnum {
SquareVariant(Square),
CircleVariant(Circle),
}

#[test]
fn shape_1() {
let shape_1 = Shape {
Expand Down Expand Up @@ -186,3 +193,19 @@ fn shape_2() {

assert_eq!(shape_2.tree_hash_root(), circle.tree_hash_root());
}

#[test]
fn shape_enum() {
let square = Square { side: 16, color: 2 };

let circle = Circle {
color: 1,
radius: 14,
};

let enum_square = ShapeEnum::SquareVariant(square.clone());
let enum_circle = ShapeEnum::CircleVariant(circle.clone());

assert_eq!(square.tree_hash_root(), enum_square.tree_hash_root());
assert_eq!(circle.tree_hash_root(), enum_circle.tree_hash_root());
}
19 changes: 16 additions & 3 deletions tree_hash_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const STRUCT_PROFILE: &str = "profile";
const STRUCT_VARIANTS: &[&str] = &[STRUCT_CONTAINER, STRUCT_STABLE_CONTAINER, STRUCT_PROFILE];

const ENUM_TRANSPARENT: &str = "transparent";
const ENUM_TRANSPARENT_STABLE: &str = "transparent_stable";
const ENUM_UNION: &str = "union";
const ENUM_VARIANTS: &[&str] = &[ENUM_TRANSPARENT, ENUM_UNION];
const NO_ENUM_BEHAVIOUR_ERROR: &str = "enums require an \"enum_behaviour\" attribute, \
Expand Down Expand Up @@ -62,13 +63,15 @@ impl StructBehaviour {

enum EnumBehaviour {
Transparent,
TransparentStable,
Union,
}

impl EnumBehaviour {
pub fn new(s: Option<String>) -> Option<Self> {
s.map(|s| match s.as_ref() {
ENUM_TRANSPARENT => EnumBehaviour::Transparent,
ENUM_TRANSPARENT_STABLE => EnumBehaviour::TransparentStable,
ENUM_UNION => EnumBehaviour::Union,
other => panic!(
"{} is an invalid enum_behaviour, use either {:?}",
Expand Down Expand Up @@ -238,7 +241,16 @@ pub fn tree_hash_derive(input: TokenStream) -> TokenStream {
panic!("cannot use \"struct_behaviour\" for an enum");
}
match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) {
EnumBehaviour::Transparent => tree_hash_derive_enum_transparent(&item, s),
EnumBehaviour::Transparent => tree_hash_derive_enum_transparent(
&item,
s,
syn::parse_str("Container").unwrap(),
),
EnumBehaviour::TransparentStable => tree_hash_derive_enum_transparent(
&item,
s,
syn::parse_str("StableContainer").unwrap(),
),
EnumBehaviour::Union => tree_hash_derive_enum_union(&item, s),
}
}
Expand Down Expand Up @@ -454,6 +466,7 @@ fn tree_hash_derive_struct_profile(
fn tree_hash_derive_enum_transparent(
derive_input: &DeriveInput,
enum_data: &DataEnum,
inner_container_type: Expr,
) -> TokenStream {
let name = &derive_input.ident;
let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl();
Expand Down Expand Up @@ -486,11 +499,11 @@ fn tree_hash_derive_enum_transparent(
#(
assert_eq!(
#type_exprs,
tree_hash::TreeHashType::Container,
tree_hash::TreeHashType::#inner_container_type,
"all variants must be of container type"
);
)*
tree_hash::TreeHashType::Container
tree_hash::TreeHashType::#inner_container_type
}

fn tree_hash_packed_encoding(&self) -> tree_hash::PackedEncoding {
Expand Down

0 comments on commit b614143

Please sign in to comment.