Skip to content

Commit

Permalink
finished serialization and added a serialize test
Browse files Browse the repository at this point in the history
  • Loading branch information
alexge233 committed Jul 13, 2017
1 parent 232b963 commit d5a8b40
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 60 deletions.
19 changes: 14 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,15 @@ set(CMAKE_MACOSX_RPATH 1)
set(SRC ${SRC} src)
include_directories(${SRC})

# add -DUSING_BOOST_SERIALIZATION in order to save policies on disk
# add_definitions(-DUSING_BOOST_SERIALIZATION)
# find_package(Boost 1.55 COMPONENTS serialization system)
# include_directories(${Boost_INCLUDE_DIR})

# build examples
set(EXAMPLES ${EXAMPLES} examples)
add_executable(ex_gridworld_offline ${EXAMPLES}/gridworld_offline.cpp)
add_executable(ex_gridworld_online ${EXAMPLES}/gridworld_online.cpp)
add_executable(ex_blackjack ${EXAMPLES}/blackjack.cpp)

# TODO: add an example for serialization of policy, state, action
# which uses the definition:

# set output
set(CMAKE_COLOR_MAKEFILE on)
set(CMAKE_VERBOSE_MAKEFILE off)
Expand All @@ -47,6 +45,7 @@ message(STATUS "CXX Linker: " ${CMAKE_EXE_LINKER_FLAGS})

# if building tests
if (BUILD_TESTS)

message(STATUS "Checking ${PROJECT_SOURCE_DIR}/tests/catch.hpp")
if (EXISTS "${PROJECT_SOURCE_DIR}/tests/catch.hpp")
message(STATUS "catch.hpp exists, not downloading again - make sure it is up to date!")
Expand All @@ -62,4 +61,14 @@ if (BUILD_TESTS)
add_test(unit_test test_classes)
add_executable(test_logic tests/logic_test.cpp)
add_test(logic_test test_logic)

# Optional serialization test: only configured when tests are being built
# and USING_BOOST_SERIALIZATION is switched on.
if (USING_BOOST_SERIALIZATION)
    # expose the switch to the sources, which guard on `#if USING_BOOST_SERIALIZATION`
    add_definitions(-DUSING_BOOST_SERIALIZATION)
    # REQUIRED: fail at configure time with a clear message when Boost (or one
    # of the listed components) is missing, instead of failing later on
    # missing headers or unresolved symbols
    find_package(Boost 1.55 REQUIRED COMPONENTS serialization system)
    # Boost_INCLUDE_DIRS is the documented result variable of find_package(Boost)
    include_directories(${Boost_INCLUDE_DIRS})
    add_executable(test_serialize tests/serialize_test.cpp)
    target_link_libraries(test_serialize ${Boost_LIBRARIES})
    add_test(serialize_test test_serialize)
endif()
endif()
2 changes: 1 addition & 1 deletion examples/blackjack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
*
* @see the header (blackjack_header.hpp) for implementation details
*/
#include "blackhack_header.hpp"
#include "blackjack_header.hpp"

//
int main(void)
Expand Down
61 changes: 48 additions & 13 deletions src/relearn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,13 @@ struct hasher<state<state_trait>>
{
std::size_t operator()(const state<state_trait> & arg) const;
};

#if USING_BOOST_SERIALIZATION
/// hash functor specialisation for `state_serial<state_class>` keys,
/// only declared when Boost serialization is enabled
template <typename state_class>
struct hasher<state_serial<state_class>>
{
    /// @return the hash value for @p wrapped
    std::size_t operator()(const state_serial<state_class> & wrapped) const;
};
#endif
/**
* @brief an action class - wraps around your class or pdt
* @class action
Expand Down Expand Up @@ -144,7 +150,7 @@ class action
/// return trait
action_trait trait() const;
private:
/// action descriptor - object/value wrapped
// action descriptor - object/value wrapped
action_trait __trait__;
#if USING_BOOST_SERIALIZATION
friend class boost::serialization::access;
Expand All @@ -158,9 +164,15 @@ class action
template <class action_trait>
struct hasher<action<action_trait>>
{
std::size_t operator()(const action<action_trait> &arg) const;
std::size_t operator()(const action<action_trait> & arg) const;
};

#if USING_BOOST_SERIALIZATION
/// hash functor specialisation for `action_serial<action_class>` keys,
/// only declared when Boost serialization is enabled
// NOTE: the template header has to live inside the guard; left outside, it
// dangles in front of the next declaration whenever the macro is not set
template <class action_class>
struct hasher<action_serial<action_class>>
{
    /// @return the hash value for @p arg
    std::size_t operator()(const action_serial<action_class> & arg) const;
};
#endif
/**
* @struct link
* @brief a simple `link` or pair for joining state-actions in the MDP
Expand Down Expand Up @@ -230,20 +242,20 @@ class policy
*/
std::pair<std::unique_ptr<action_class>,value_type> best(state_class s_t);
private:
// if using boost serialization the policies used
// are wrappers defined in internal header `serialize.tpl`
#if USING_BOOST_SERIALIZATION
friend class boost::serialization::access;
// serialize method
template <typename archive>
void serialize(archive & ar, const unsigned int version);
std::unordered_map<state_serial
std::unordered_map<action_serial,
// actual policies use the `_serial` wrapper from `serialize.tpl`
std::unordered_map<state_serial<state_class>,
std::unordered_map<action_serial<action_class>,
value_type,
hasher<action_serial>>,
hasher<state_serial>
hasher<action_serial<action_class>>>,
hasher<state_serial<state_class>>
> __policies__;
// else we're using the actual `state_class` and `action_class` types
#else
// policies maps is: [state][action][state_next] => Q-value
std::unordered_map<state_class,
std::unordered_map<action_class,
value_type,
Expand All @@ -252,7 +264,6 @@ class policy
> __policies__;
#endif
};

/**
* @class q_learning This is the **deterministic** Q-Learning algorithm
* @brief Q-Learning update algorithm sets policies using episodes (`markov_chain`)
Expand Down Expand Up @@ -512,7 +523,15 @@ template <class state_class,
typename policy<state_class,action_class,value_type>::action_map
policy<state_class,action_class,value_type>::actions(state_class s_t)
{
#if USING_BOOST_SERIALIZATION
std::unordered_map<action_class, value_type, hasher<action_class>> retval;
for (const auto & item : __policies__[s_t]) {
retval[static_cast<action_class>(item.first)] = item.second;
}
return retval;
#else
return __policies__[s_t];
#endif
}

template <class state_class,
Expand Down Expand Up @@ -582,6 +601,20 @@ void policy<state_class,action_class,value_type>::serialize(archive & ar,
{
ar & __policies__;
}

/// Hash a serializable state wrapper by delegating to the `hash()` member
/// callable on @p wrapped.
template <typename state_class>
std::size_t hasher<state_serial<state_class>>::operator()(
    const state_serial<state_class> & wrapped) const
{
    const std::size_t digest = wrapped.hash();
    return digest;
}

/// Hash a serializable action wrapper by delegating to the `hash()` member
/// callable on @p wrapped.
template <typename action_class>
std::size_t hasher<action_serial<action_class>>::operator()(
    const action_serial<action_class> & wrapped) const
{
    const std::size_t digest = wrapped.hash();
    return digest;
}
#endif

template <class state_class,
Expand Down Expand Up @@ -673,6 +706,8 @@ void q_probabilistic<state_class,action_class,markov_chain,value_type
std::get<2>(triplet));
}
}

#if USING_BOOST_SERIALIZATION
#include "serialize.tpl"
#endif
} // end of namespace
#endif
83 changes: 42 additions & 41 deletions src/serialize.tpl
Original file line number Diff line number Diff line change
@@ -1,62 +1,63 @@
#if USING_BOOST_SERIALIZATION
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/access.hpp>
#include <boost/serialization/unordered_map.hpp>

template <class state_class>
struct state_serial : private state_class
struct state_serial : public state_class
{
state_serial();
operator state_class() const;
state_serial(state_class);
operator state_class();
};

template <class action_class>
struct action_serial : private action_class
struct action_serial : public action_class
{
action_serial();
operator action_class() const;
};

action_serial(action_class);
operator action_class();
};

/*
* Template serialization implementation
*
* |^| | |
* | |_____| |
* | _____ |
* | | | |
* | |_____| |
* |_|_____|_|
*
*/
template <class state_class>
struct cast_state
{
state_serial<state_class> operator()(const state_class arg) const;
};
state_serial<state_class>::state_serial()
: state_class()
{}

template <class action_class>
struct cast_action
{
action_serial<action_class> operator()(const action_class arg) const;
};
template <class state_class>
state_serial<state_class>::state_serial(state_class arg)
: state_class(arg)
{}

template <class state_class>
struct hasher<state_serial<state_class>>
template <class state_class>
state_serial<state_class>::operator state_class()
{
std::size_t operator()(const state_serial<state_class> & arg) const;
};
return static_cast<state_class>(*this);
}

template <class action_class>
struct hasher<action_serial<action_class>>
{
std::size_t operator()(const action_serial<action_class> & arg) const;
};
template <class action_class>
action_serial<action_class>::action_serial()
: action_class()
{}

template <class action_class,
typename value_type>
struct hasher<std::unordered_map<action_class,value_type>>
{
std::size_t operator()(const std::unordered_map<action_class,value_type> &arg) const;
};
template <class action_class>
action_serial<action_class>::action_serial(action_class arg)
: action_class(arg)
{}

template <class action_class,
typename value_type>
struct hasher<std::unordered_map<action_serial<action_class>,value_type>>
template <class action_class>
action_serial<action_class>::operator action_class()
{
std::size_t operator()(const std::unordered_map<action_serial<action_class>,
value_type> &arg) const;
};


/// TODO: implement those methods above
return static_cast<action_class>(*this);
}

#endif
79 changes: 79 additions & 0 deletions tests/serialize_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include "relearn.hpp"
#include <fstream>
#include <boost/archive/text_oarchive.hpp>
#include <boost/archive/text_iarchive.hpp>
#define CATCH_CONFIG_MAIN
#include "catch.hpp"

// Round-trips a single state through a Boost text archive on disk and checks
// that the loaded object equals the one that was saved.
TEST_CASE("state class serialization test", "[state_class_serialization_test]")
{
    auto s_x = relearn::state<std::string>("hello");
    auto s_y = relearn::state<std::string>("world");
    {
        // create and open a character archive for output
        std::ofstream ofs("serialize_test_state_class");
        boost::archive::text_oarchive oa(ofs);
        // write class instance to archive
        oa << s_x;
        // leaving the scope destroys the archive first and only then the
        // stream (which flushes and closes the file), so the archive never
        // outlives an already-closed stream
    }
    {
        // load data from the archive now
        std::ifstream ifs("serialize_test_state_class");
        boost::archive::text_iarchive ia(ifs);
        ia >> s_y;
    }
    // if serialization failed, then s_x will not equal s_y!
    REQUIRE(s_x == s_y);
}

// Round-trips a single action through a Boost text archive on disk and checks
// that the loaded object equals the one that was saved.
TEST_CASE("action class serialization test", "[action_class_serialization_test]")
{
    auto a_x = relearn::action<int>(0);
    auto a_y = relearn::action<int>(1);
    {
        // create and open a character archive for output
        std::ofstream ofs("serialize_test_action_class");
        boost::archive::text_oarchive oa(ofs);
        // write class instance to archive
        oa << a_x;
        // leaving the scope destroys the archive first and only then the
        // stream (which flushes and closes the file), so the archive never
        // outlives an already-closed stream
    }
    {
        // load data from the archive now
        std::ifstream ifs("serialize_test_action_class");
        boost::archive::text_iarchive ia(ifs);
        ia >> a_y;
    }
    // if serialization failed, then a_x will not equal a_y!
    REQUIRE(a_x == a_y);
}

// Trains a policy on a fixed episode, saves it to a Boost text archive,
// loads it into a fresh policy and checks the best action for each state.
SCENARIO("policy class serialization test", "[policy_class_serialize_test]")
{
    using state = relearn::state<std::string>;
    using action = relearn::action<std::string>;
    using link = relearn::link<state,action>;
    relearn::policy<state,action> memory;
    GIVEN("hard-coded episode with positive reward")
    {
        // NOTE(review): the last state is built with an extra leading `1`
        // argument, while the lookup in the THEN section builds it from the
        // string alone - confirm that both compare and hash as equal
        std::deque<link> episode = {
            {state("hello"), action("hi!")},
            {state("how are you?"), action("I'm fine, and you?")},
            {state(1, "not too bad! what you doing here?"),
             action("I'm taking over the world!")},
        };
        WHEN("policy is trained and saved to disk") {
            auto learner = relearn::q_learning<state,action>{0.9, 0.9};
            for (int k = 0; k < 10; k++) {
                learner(episode, memory);
            }
            {
                std::ofstream ofs("serialize_test_policy_class");
                boost::archive::text_oarchive oa(ofs);
                oa << memory;
                // leaving the scope destroys the archive first and only then
                // the stream (which flushes and closes the file)
            }
            THEN("we load a new policy and expect the same output") {
                relearn::policy<state,action> policies;
                {
                    std::ifstream ifs("serialize_test_policy_class");
                    boost::archive::text_iarchive ia(ifs);
                    ia >> policies;
                }
                REQUIRE(policies.best_action(state("hello"))->trait() == "hi!");
                REQUIRE(policies.best_action(state("how are you?"))->trait() == "I'm fine, and you?");
                REQUIRE(policies.best_action(state("not too bad! what you doing here?"))->trait()
                        == "I'm taking over the world!");
            }
        }
    }
}

0 comments on commit d5a8b40

Please sign in to comment.