Skip to content

Commit

Permalink
#344 Workaround for NVCC bug
Browse files Browse the repository at this point in the history
For NVCC < 11.7.1, the templated static constexpr member variables give
the wrong results (has_trait_v replaced with has_trait::value).
  • Loading branch information
Matthew-Whitlock committed Sep 24, 2024
1 parent 14b22f3 commit 2445350
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 54 deletions.
14 changes: 7 additions & 7 deletions examples/checkpoint_example_user_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ namespace test {

TestObj() {}

template<typename SerT, typename std::enable_if_t<SerT::template has_not_traits<shallow_trait>::value>* = nullptr>
template<typename SerT, typename SerT::template has_not_traits<shallow_trait>::type = nullptr>
void serialize(SerT& s){
if constexpr(SerT::template has_traits<checkpoint_trait>::value){
if constexpr(SerT::template has_traits_v<checkpoint_trait>){
if(s.isSizing()) printf("Customizing serialization for checkpoint\n");
s | a;
} else {
Expand All @@ -26,13 +26,13 @@ namespace test {
}

namespace test {
template<typename SerT, typename std::enable_if_t<SerT::template has_traits<random_trait>::value>* = nullptr>
template<typename SerT, typename SerT::template has_traits<random_trait>::type = nullptr>
void serialize(SerT& s, TestObj& myObj){
if(s.isSizing()) printf("Inserting random extra object serialization step! ");
myObj.serialize(s);
}

template<typename SerT, typename std::enable_if_t<SerT::template has_traits<shallow_trait>::value>* = nullptr>
template<typename SerT, typename SerT::template has_traits<shallow_trait>::type = nullptr>
void serialize(SerT& s, TestObj& myObj){
if(s.isSizing()) printf("Removing shallow trait before passing along!\n");
auto newS = s.template withoutTraits<shallow_trait>();
Expand All @@ -41,23 +41,23 @@ namespace test {
}

namespace misc {
template<typename SerT, typename std::enable_if_t<SerT::template has_traits<test::random_trait>::value>* = nullptr>
template<typename SerT, typename SerT::template has_traits<test::random_trait>::type = nullptr>
void serialize(SerT& s, test::TestObj& myObj){
if(s.isSizing()) printf("Serializers in other namespaces don't usually get found ");
myObj.serialize(s);
}


const struct namespace_trait {} NamespaceTrait;
template<typename SerT, typename std::enable_if_t<SerT::template has_traits<namespace_trait>::value>* = nullptr>
template<typename SerT, typename SerT::template has_traits<namespace_trait>::type = nullptr>
void serialize(SerT& s, test::TestObj& myObj){
if(s.isSizing()) printf("A misc:: trait means we can serialize from misc:: too: ");
myObj.serialize(s);
}


const struct hook_all_trait {} HookAllTrait;
template<typename SerT, typename T, typename std::enable_if_t<SerT::template has_traits<hook_all_trait>::value>* = nullptr>
template<typename SerT, typename T, typename SerT::template has_traits<hook_all_trait>::type = nullptr>
void serialize(SerT& s, T& myObj){
if(s.isSizing()) printf("We can even add on a generic pre-serialize hook: ");
auto newS = s.template withoutTraits<hook_all_trait>();
Expand Down
10 changes: 1 addition & 9 deletions src/checkpoint/dispatch/dispatch.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,9 @@ TraverserT Traverse::with(T& target, Args&&... args) {
#if !defined(SERIALIZATION_ERROR_CHECKING)
using CleanT = typename CleanType<T>::CleanT;
#endif

TraverserT t_base(std::forward<Args>(args)...);
auto t = SerializerRef(&t_base, Traits{});

//std::optional<TraverserT> t_opt;
//if constexpr(is_serializer_ref<TraverserT>) {
// t_opt.emplace(std::in_place, std::forward<Args>(args)...);
//} else {
// t_opt.emplace(std::forward<Args>(args)...);
//}
//TraverserT& t = t_opt.value();

#if !defined(SERIALIZATION_ERROR_CHECKING)
withTypeIdx<CleanT>(t);
Expand Down
2 changes: 1 addition & 1 deletion src/checkpoint/dispatch/vrt/virtual_serialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ template <typename T, typename SerializerT>
void virtualSerialize(T*& base, SerializerT& s) {
//We can't support traited serializing with virtual types.
static_assert(std::is_same_v<SerializerT, typename SerializerT::TraitlessT>, "User Traits are incompatible with virtual serialization");

// Get the real base in case this is called on a derived type
using BaseT = ::checkpoint::dispatch::vrt::checkpoint_base_type_t<T>;
auto serializer_idx = serializer_registry::makeObjIdx<BaseT, SerializerT>();
Expand Down
46 changes: 21 additions & 25 deletions src/checkpoint/serializers/serializer_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@

namespace checkpoint {

namespace {
// Cuda does not play nicely with templated static constexpr
// member variables in the SerializerRef class, so we make a
// helper struct to hold non-templated static constexpr members.
template<bool B>
struct bool_enable_if {
static constexpr bool value = false;
};

template<>
struct bool_enable_if<true> {
static constexpr bool value = true;
using type = void*;
};
}

template<typename SerT, typename UserTraits = UserTraitHolder<>>
struct SerializerRef
{
Expand Down Expand Up @@ -112,33 +128,14 @@ struct SerializerRef

//Big block of helpers for conveniently checking traits in different contexts.
template<typename... Traits>
using has_traits = typename TraitHolder::template has<Traits...>;
template<typename... Traits>
using has_any_traits = typename TraitHolder::template has_any<Traits...>;

template<typename... Traits>
using has_not_traits = std::integral_constant<bool, !(has_traits<Traits...>::value)>;
template<typename... Traits>
using has_not_any_traits = std::integral_constant<bool, !(has_any_traits<Traits...>::value)>;

template<typename... Traits>
static constexpr bool has_traits_v = has_traits<Traits...>::value;
template<typename... Traits>
static constexpr bool has_any_traits_v = has_any_traits<Traits...>::value;
using has_traits = bool_enable_if<TraitHolder::template has<Traits...>::value>;
template<typename... Traits>
static constexpr bool has_not_traits_v = has_not_traits<Traits...>::value;
template<typename... Traits>
static constexpr bool has_not_any_traits_v = has_not_any_traits<Traits...>::value;
using has_any_traits = bool_enable_if<TraitHolder::template has_any<Traits...>::value>;

template<typename... Traits>
using has_traits_t = std::enable_if_t<has_traits_v<Traits...>>;
template<typename... Traits>
using has_any_traits_t = std::enable_if_t<has_any_traits_v<Traits...>>;
using has_not_traits = bool_enable_if<!(has_traits<Traits...>::value)>;
template<typename... Traits>
using has_not_traits_t = std::enable_if_t<has_not_traits_v<Traits...>>;
template<typename... Traits>
using has_not_any_traits_t = std::enable_if_t<has_not_any_traits_v<Traits...>>;

using has_not_any_traits = bool_enable_if<!(has_any_traits<Traits...>::value)>;

//Helpers for converting between traits
using TraitlessT = SerializerRef<SerT>;
Expand All @@ -147,7 +144,6 @@ struct SerializerRef
template<typename Trait, typename... Traits>
auto withTraits(UserTraitHolder<Trait, Traits...> = {}){
using NewTraitHolder = typename TraitHolder::template with<Trait, Traits...>;
//return setTraits(NewTraitHolder {});
return SerializerRef<SerT, NewTraitHolder>(*this);
}

Expand All @@ -168,7 +164,7 @@ struct SerializerRef

protected:
SerT *const impl;

template<typename, typename>
friend struct SerializerRef;
};
Expand Down
30 changes: 18 additions & 12 deletions tests/unit/test_user_traits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,16 @@ struct UserObjectA {
: name('A'), val_a(val) {};
explicit UserObjectA(int val, char name_)
: name(name_), val_a(val) {};

template <typename S>
void serialize(S& s) {
std::cout << "A: serializing with type "
<< abi::__cxa_demangle(typeid(s).name(), nullptr, nullptr, nullptr)
<< std::endl;
EXPECT_FALSE((S::template has_traits_v<ShallowTrait>));

EXPECT_EQ(
(S::template has_traits_v<TraitPairA>),
(S::template has_traits_v<TraitPairA>),
(S::template has_traits_v<TraitPairB>)
);

Expand All @@ -93,20 +96,20 @@ struct UserObjectA {
};

template<
typename SerT,
typename = typename SerT::template has_traits_t<CheckpointTraitNonintrusive>
typename S,
typename = typename S::template has_traits<CheckpointTraitNonintrusive>::type
>
void serialize(SerT& s, UserObjectA& obj){
void serialize(S& s, UserObjectA& obj){
s | obj.name;
obj.serialize(s);
}

namespace CheckpointNamespace {
template<
typename SerT,
typename = typename SerT::template has_traits_t<CheckpointTraitNamespaced>
typename S,
typename = typename S::template has_traits<CheckpointTraitNamespaced>::type
>
void serialize(SerT& s, UserObjectA& obj){
void serialize(S& s, UserObjectA& obj){
s | obj.name;
obj.serialize(s);
}
Expand Down Expand Up @@ -150,8 +153,11 @@ struct UserObjectB : public UserObjectA {

template <typename S>
void serialize(S& s) {
std::cout << "B: serializing with type "
<< abi::__cxa_demangle(typeid(s).name(), nullptr, nullptr, nullptr)
<< std::endl;
auto new_s = s.template withoutTraits<ShallowTrait>();
if constexpr(S::template has_traits_v<TraitPairA>){
if (S::template has_traits<TraitPairA>::value){
auto newer_s = new_s.template withTraits<TraitPairB>();
UserObjectA::serialize(newer_s);
} else {
Expand All @@ -168,7 +174,7 @@ TEST(TestUserTraits, test_trait_removal) {
UserObjectB objB(u_val);
int old_b_val = objB.val_b, old_a_val = objB.val_a;
int new_b_val = old_b_val+1, new_a_val = old_a_val+1;

auto serB = checkpoint::serialize<UserObjectB, ShallowTrait>(objB);
objB.val_b = new_b_val;
objB.val_a = new_a_val;
Expand All @@ -183,7 +189,7 @@ TEST(TestUserTraits, test_trait_addition) {
UserObjectB objB(u_val);
int old_b_val = objB.val_b, old_a_val = objB.val_a;
int new_b_val = old_b_val+1, new_a_val = old_a_val+1;

auto serB = checkpoint::serialize<UserObjectB, TraitPairA>(objB);
objB.val_b = new_b_val;
objB.val_a = new_a_val;
Expand Down

0 comments on commit 2445350

Please sign in to comment.