diff --git a/examples/checkpoint_example_user_traits.hpp b/examples/checkpoint_example_user_traits.hpp index 8add6179..c0b64097 100644 --- a/examples/checkpoint_example_user_traits.hpp +++ b/examples/checkpoint_example_user_traits.hpp @@ -11,9 +11,9 @@ namespace test { TestObj() {} - template::value>* = nullptr> + template::type = nullptr> void serialize(SerT& s){ - if constexpr(SerT::template has_traits::value){ + if constexpr(SerT::template has_traits_v){ if(s.isSizing()) printf("Customizing serialization for checkpoint\n"); s | a; } else { @@ -26,13 +26,13 @@ namespace test { } namespace test { - template::value>* = nullptr> + template::type = nullptr> void serialize(SerT& s, TestObj& myObj){ if(s.isSizing()) printf("Inserting random extra object serialization step! "); myObj.serialize(s); } - template::value>* = nullptr> + template::type = nullptr> void serialize(SerT& s, TestObj& myObj){ if(s.isSizing()) printf("Removing shallow trait before passing along!\n"); auto newS = s.template withoutTraits(); @@ -41,7 +41,7 @@ namespace test { } namespace misc { - template::value>* = nullptr> + template::type = nullptr> void serialize(SerT& s, test::TestObj& myObj){ if(s.isSizing()) printf("Serializers in other namespaces don't usually get found "); myObj.serialize(s); @@ -49,7 +49,7 @@ namespace misc { const struct namespace_trait {} NamespaceTrait; - template::value>* = nullptr> + template::type = nullptr> void serialize(SerT& s, test::TestObj& myObj){ if(s.isSizing()) printf("A misc:: trait means we can serialize from misc:: too: "); myObj.serialize(s); @@ -57,7 +57,7 @@ namespace misc { const struct hook_all_trait {} HookAllTrait; - template::value>* = nullptr> + template::type = nullptr> void serialize(SerT& s, T& myObj){ if(s.isSizing()) printf("We can even add on a generic pre-serialize hook: "); auto newS = s.template withoutTraits(); diff --git a/src/checkpoint/dispatch/dispatch.impl.h b/src/checkpoint/dispatch/dispatch.impl.h index 568848cb..20eb65c2 100644 --- a/src/checkpoint/dispatch/dispatch.impl.h +++ b/src/checkpoint/dispatch/dispatch.impl.h @@ -172,17 +172,9 @@ TraverserT Traverse::with(T& target, Args&&... args) { #if !defined(SERIALIZATION_ERROR_CHECKING) using CleanT = typename CleanType::CleanT; #endif - + TraverserT t_base(std::forward(args)...); auto t = SerializerRef(&t_base, Traits{}); - - //std::optional t_opt; - //if constexpr(is_serializer_ref) { - // t_opt.emplace(std::in_place, std::forward(args)...); - //} else { - // t_opt.emplace(std::forward(args)...); - //} - //TraverserT& t = t_opt.value(); #if !defined(SERIALIZATION_ERROR_CHECKING) withTypeIdx(t); diff --git a/src/checkpoint/dispatch/vrt/virtual_serialize.h b/src/checkpoint/dispatch/vrt/virtual_serialize.h index 99f8c81f..a2e66b4c 100644 --- a/src/checkpoint/dispatch/vrt/virtual_serialize.h +++ b/src/checkpoint/dispatch/vrt/virtual_serialize.h @@ -61,7 +61,7 @@ template void virtualSerialize(T*& base, SerializerT& s) { //We can't support traited serializing with virtual types. static_assert(std::is_same_v, "User Traits are incompatible with virtual serialization"); - + // Get the real base in case this is called on a derived type using BaseT = ::checkpoint::dispatch::vrt::checkpoint_base_type_t; auto serializer_idx = serializer_registry::makeObjIdx(); diff --git a/src/checkpoint/serializers/serializer_ref.h b/src/checkpoint/serializers/serializer_ref.h index 9afbb6c2..32339954 100644 --- a/src/checkpoint/serializers/serializer_ref.h +++ b/src/checkpoint/serializers/serializer_ref.h @@ -54,6 +54,22 @@ namespace checkpoint { +namespace { + // Cuda does not play nicely with templated static constexpr + // member variables in the SerializerRef class, so we make a + // helper struct to hold non-templated static constexpr members. + template + struct bool_enable_if { + static constexpr bool value = false; + }; + + template<> + struct bool_enable_if { + static constexpr bool value = true; + using type = void*; + }; +} + template> struct SerializerRef { @@ -112,33 +128,14 @@ struct SerializerRef //Big block of helpers for conveniently checking traits in different contexts. template - using has_traits = typename TraitHolder::template has; - template - using has_any_traits = typename TraitHolder::template has_any; - - template - using has_not_traits = std::integral_constant::value)>; - template - using has_not_any_traits = std::integral_constant::value)>; - - template - static constexpr bool has_traits_v = has_traits::value; - template - static constexpr bool has_any_traits_v = has_any_traits::value; + using has_traits = bool_enable_if::value>; template - static constexpr bool has_not_traits_v = has_not_traits::value; - template - static constexpr bool has_not_any_traits_v = has_not_any_traits::value; + using has_any_traits = bool_enable_if::value>; template - using has_traits_t = std::enable_if_t>; - template - using has_any_traits_t = std::enable_if_t>; + using has_not_traits = bool_enable_if::value)>; template - using has_not_traits_t = std::enable_if_t>; - template - using has_not_any_traits_t = std::enable_if_t>; - + using has_not_any_traits = bool_enable_if::value)>; //Helpers for converting between traits using TraitlessT = SerializerRef; @@ -147,7 +144,6 @@ struct SerializerRef template auto withTraits(UserTraitHolder = {}){ using NewTraitHolder = typename TraitHolder::template with; - //return setTraits(NewTraitHolder {}); return SerializerRef(*this); } @@ -168,7 +164,7 @@ struct SerializerRef protected: SerT *const impl; - + template friend struct SerializerRef; }; diff --git a/tests/unit/test_user_traits.cc b/tests/unit/test_user_traits.cc index 5b3ed098..9f47b1ee 100644 --- a/tests/unit/test_user_traits.cc +++ b/tests/unit/test_user_traits.cc @@ -72,13 +72,16 @@ struct UserObjectA { : name('A'), val_a(val) {}; explicit UserObjectA(int val, char name_) : name(name_), val_a(val) {}; - + template void serialize(S& s) { + std::cout << "A: serializing with type " + << abi::__cxa_demangle(typeid(s).name(), nullptr, nullptr, nullptr) + << std::endl; EXPECT_FALSE((S::template has_traits_v)); - + EXPECT_EQ( - (S::template has_traits_v), + (S::template has_traits_v), (S::template has_traits_v) ); @@ -93,20 +96,20 @@ struct UserObjectA { }; template< - typename SerT, - typename = typename SerT::template has_traits_t + typename S, + typename = typename S::template has_traits::type > -void serialize(SerT& s, UserObjectA& obj){ +void serialize(S& s, UserObjectA& obj){ s | obj.name; obj.serialize(s); } namespace CheckpointNamespace { template< - typename SerT, - typename = typename SerT::template has_traits_t + typename S, + typename = typename S::template has_traits::type > -void serialize(SerT& s, UserObjectA& obj){ +void serialize(S& s, UserObjectA& obj){ s | obj.name; obj.serialize(s); } @@ -150,8 +153,11 @@ struct UserObjectB : public UserObjectA { template void serialize(S& s) { + std::cout << "B: serializing with type " + << abi::__cxa_demangle(typeid(s).name(), nullptr, nullptr, nullptr) + << std::endl; auto new_s = s.template withoutTraits(); - if constexpr(S::template has_traits_v){ + if (S::template has_traits::value){ auto newer_s = new_s.template withTraits(); UserObjectA::serialize(newer_s); } else { @@ -168,7 +174,7 @@ TEST(TestUserTraits, test_trait_removal) { UserObjectB objB(u_val); int old_b_val = objB.val_b, old_a_val = objB.val_a; int new_b_val = old_b_val+1, new_a_val = old_a_val+1; - + auto serB = checkpoint::serialize(objB); objB.val_b = new_b_val; objB.val_a = new_a_val; @@ -183,7 +189,7 @@ TEST(TestUserTraits, test_trait_addition) { UserObjectB objB(u_val); int old_b_val = objB.val_b, old_a_val = objB.val_a; int new_b_val = old_b_val+1, new_a_val = old_a_val+1; - + auto serB = checkpoint::serialize(objB); objB.val_b = new_b_val; objB.val_a = new_a_val;