Skip to content

Commit

Permalink
adding serialize tpl
Browse files Browse the repository at this point in the history
  • Loading branch information
alexge233 committed Jul 4, 2017
1 parent 63cd9c3 commit feaea2a
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 22 deletions.
1 change: 0 additions & 1 deletion examples/gridworld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,4 +327,3 @@ int main()
on_policy(w, policies, start);
return 0;
}

41 changes: 20 additions & 21 deletions src/relearn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
#include <memory>
#include <cassert>
#if USING_BOOST_SERIALIZATION
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/access.hpp>
#include <boost/serialization/unordered_map.hpp>
#include "serialize.tpl"
#endif
/**
* @brief relearn C++ reinforcement learning library
Expand Down Expand Up @@ -79,12 +77,6 @@ class state
value_type reward() const;
/// @return trait
state_trait trait() const;
#if USING_BOOST_SERIALIZATION
/// default empty state - used by boost serialization only
/// @warning creating an empty state on purpose will break the
/// learning algorithm
state() = default;
#endif
private:
// state reward
value_type __reward__;
Expand All @@ -95,6 +87,8 @@ class state
// @warning - template parameter `state_trait` must be serializable
template <typename archive>
void serialize(archive & ar, const unsigned int version);
protected:
state() = default;
#endif
};

Expand Down Expand Up @@ -129,12 +123,6 @@ class action
std::size_t hash() const;
/// return trait
action_trait trait() const;
#if USING_BOOST_SERIALIZATION
/// default empty action - used by boost serialization only
/// @warning creating an empty action class on purpose
/// will break the algorithm
action() = default;
#endif
private:
/// action descriptor - object/value wrapped
action_trait __trait__;
Expand All @@ -143,6 +131,8 @@ class action
// @warning - template parameter `action_trait` must be serializable
template <typename archive>
void serialize(archive & ar, const unsigned int version);
protected:
action() = default;
#endif
};

Expand Down Expand Up @@ -210,23 +200,32 @@ class policy
/// @warning if none are found, returns nullptr
std::unique_ptr<action_class> best_action(state_class s_t);
private:
// internal structure mapping states => (map of actions => values)
// if using boost serialization the policies used
// are wrappers defined in internal header `serialize.tpl`
#if USING_BOOST_SERIALIZATION
friend class boost::serialization::access;
template <typename archive>
void serialize(archive & ar, const unsigned int version);
std::unordered_map<state_wrapper
std::unordered_map<action_wrapper,
value_type,
hasher<action_wrapper>>,
hasher<state_class>
> __policies__;
// else we're using the actual `state_class` and `action_class` types
#else
std::unordered_map<state_class,
std::unordered_map<action_class,
value_type,
hasher<action_class>>,
hasher<state_class>
> __policies__;
#if USING_BOOST_SERIALIZATION
friend class boost::serialization::access;
template <typename archive>
void serialize(archive & ar, const unsigned int version);
#endif
};

/// @brief definition of hash functor for state<S,A>
template <class action_class,
typename value_type>
typename value_type>
struct hasher<std::unordered_map<action_class,value_type>>
{
std::size_t operator()(const std::unordered_map<action_class,value_type> &arg) const;
Expand Down
62 changes: 62 additions & 0 deletions src/serialize.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#if USING_BOOST_SERIALIZATION
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/access.hpp>
#include <boost/serialization/unordered_map.hpp>

template <class state_class>
struct state_wrapper : private state_class
{
state_wrapper();
operator state_class() const;
};

template <class action_class>
struct action_wrapper : private action_class
{
action_wrapper();
operator action_class() const;
};

template <class state_class>
struct cast_state
{
state_wrapper<state_class> operator()(const state_class arg) const;
};

template <class action_class>
struct cast_action
{
action_wrapper<action_class> operator()(const action_class arg) const;
};

template <class state_class>
struct hasher<state_wrapper<state_class>>
{
std::size_t operator()(const state_wrapper<state_class> & arg) const;
};

template <class action_class>
struct hasher<action_wrapper<action_class>>
{
std::size_t operator()(const action_wrapper<action_class> & arg) const;
};

template <class action_class,
typename value_type>
struct hasher<std::unordered_map<action_class,value_type>>
{
std::size_t operator()(const std::unordered_map<action_class,value_type> &arg) const;
};

template <class action_class,
typename value_type>
struct hasher<std::unordered_map<action_wrapper<action_class>,value_type>>
{
std::size_t operator()(const std::unordered_map<action_wrapper<action_class>,
value_type> &arg) const;
};


/// TODO: implement those methods above

#endif

0 comments on commit feaea2a

Please sign in to comment.