From feaea2ac7c7295a77a0f134f9b6a47a901b269f9 Mon Sep 17 00:00:00 2001 From: Alex Ge Date: Tue, 4 Jul 2017 15:54:23 +0100 Subject: [PATCH] adding serialize tpl --- examples/gridworld.cpp | 1 - src/relearn.hpp | 41 ++++++++++++++-------------- src/serialize.tpl | 62 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 22 deletions(-) create mode 100644 src/serialize.tpl diff --git a/examples/gridworld.cpp b/examples/gridworld.cpp index e2ba6c6..bba0b26 100644 --- a/examples/gridworld.cpp +++ b/examples/gridworld.cpp @@ -327,4 +327,3 @@ int main() on_policy(w, policies, start); return 0; } - diff --git a/src/relearn.hpp b/src/relearn.hpp index 8f7a141..28f71b2 100644 --- a/src/relearn.hpp +++ b/src/relearn.hpp @@ -22,9 +22,7 @@ #include #include #if USING_BOOST_SERIALIZATION -#include -#include -#include +#include "serialize.tpl" #endif /** * @brief relearn C++ reinforcement learning library @@ -79,12 +77,6 @@ class state value_type reward() const; /// @return trait state_trait trait() const; -#if USING_BOOST_SERIALIZATION - /// default empty state - used by boost serialization only - /// @warning creating an empty state on purpose will break the - /// learning algorithm - state() = default; -#endif private: // state reward value_type __reward__; @@ -95,6 +87,8 @@ class state // @warning - template parameter `state_trait` must be serializable template void serialize(archive & ar, const unsigned int version); +protected: + state() = default; #endif }; @@ -129,12 +123,6 @@ class action std::size_t hash() const; /// return trait action_trait trait() const; -#if USING_BOOST_SERIALIZATION - /// default empty action - used by boost serialization only - /// @warning creating an empty action class on purpose - /// will break the algorithm - action() = default; -#endif private: /// action descriptor - object/value wrapped action_trait __trait__; @@ -143,6 +131,8 @@ class action // @warning - template parameter `action_trait` must be serializable template void serialize(archive & ar, const unsigned int version); +protected: + action() = default; #endif }; @@ -210,23 +200,32 @@ class policy /// @warning if none are found, returns nullptr std::unique_ptr best_action(state_class s_t); private: - // internal structure mapping states => (map of actions => values) +// if using boost serialization the policies used +// are wrappers defined in internal header `serialize.tpl` +#if USING_BOOST_SERIALIZATION + friend class boost::serialization::access; + template + void serialize(archive & ar, const unsigned int version); + std::unordered_map>, + hasher + > __policies__; +// else we're using the actual `state_class` and `action_class` types +#else std::unordered_map>, hasher > __policies__; -#if USING_BOOST_SERIALIZATION - friend class boost::serialization::access; - template - void serialize(archive & ar, const unsigned int version); #endif }; /// @brief definition of hash functor for state template + typename value_type> struct hasher> { std::size_t operator()(const std::unordered_map &arg) const; diff --git a/src/serialize.tpl b/src/serialize.tpl new file mode 100644 index 0000000..2d7fbc5 --- /dev/null +++ b/src/serialize.tpl @@ -0,0 +1,62 @@ +#if USING_BOOST_SERIALIZATION +#include +#include +#include + +template +struct state_wrapper : private state_class +{ + state_wrapper(); + operator state_class() const; +}; + +template +struct action_wrapper : private action_class +{ + action_wrapper(); + operator action_class() const; +}; + +template +struct cast_state +{ + state_wrapper operator()(const state_class arg) const; +}; + +template +struct cast_action +{ + action_wrapper operator()(const action_class arg) const; +}; + +template +struct hasher> +{ + std::size_t operator()(const state_wrapper & arg) const; +}; + +template +struct hasher> +{ + std::size_t operator()(const action_wrapper & arg) const; +}; + +template +struct hasher> +{ + std::size_t operator()(const std::unordered_map &arg) const; +}; + +template +struct hasher,value_type>> +{ + std::size_t operator()(const std::unordered_map, + value_type> &arg) const; +}; + + +/// TODO: implement those methods above + +#endif