From d5a8b40dea9163f6c4605d3121e3ee54f5f5b6dc Mon Sep 17 00:00:00 2001 From: Alex Ge Date: Thu, 13 Jul 2017 17:50:09 +0100 Subject: [PATCH] finished serialization and added a serialize test --- CMakeLists.txt | 19 ++++++--- examples/blackjack.cpp | 2 +- src/relearn.hpp | 61 ++++++++++++++++++++++------- src/serialize.tpl | 83 ++++++++++++++++++++-------------------- tests/serialize_test.cpp | 79 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 184 insertions(+), 60 deletions(-) create mode 100644 tests/serialize_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index eae2fe5..afbc18d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,17 +20,15 @@ set(CMAKE_MACOSX_RPATH 1) set(SRC ${SRC} src) include_directories(${SRC}) -# add -DUSING_BOOST_SERIALIZATION in order to save policies on disk -# add_definitions(-DUSING_BOOST_SERIALIZATION) -# find_package(Boost 1.55 COMPONENTS serialization system) -# include_directories(${Boost_INCLUDE_DIR}) - # build examples set(EXAMPLES ${EXAMPLES} examples) add_executable(ex_gridworld_offline ${EXAMPLES}/gridworld_offline.cpp) add_executable(ex_gridworld_online ${EXAMPLES}/gridworld_online.cpp) add_executable(ex_blackjack ${EXAMPLES}/blackjack.cpp) +# TODO: add an example for serialization of policy, state, action +# which uses the definition: + # set output set(CMAKE_COLOR_MAKEFILE on) set(CMAKE_VERBOSE_MAKEFILE off) @@ -47,6 +45,7 @@ message(STATUS "CXX Linker: " ${CMAKE_EXE_LINKER_FLAGS}) # if building tests if (BUILD_TESTS) + message(STATUS "Checking ${PROJECT_SOURCE_DIR}/tests/catch.hpp") if (EXISTS "${PROJECT_SOURCE_DIR}/tests/catch.hpp") message(STATUS "catch.hpp exists, not downloading again - make sure it is up to date!") @@ -62,4 +61,14 @@ if (BUILD_TESTS) add_test(unit_test test_classes) add_executable(test_logic tests/logic_test.cpp) add_test(logic_test test_logic) + + # if building tests & testing serialization + if (USING_BOOST_SERIALIZATION) + add_definitions(-DUSING_BOOST_SERIALIZATION) + find_package(Boost 1.55 COMPONENTS serialization system) + include_directories(${Boost_INCLUDE_DIR}) + add_executable(test_serialize tests/serialize_test.cpp) + target_link_libraries(test_serialize ${Boost_LIBRARIES}) + add_test(serialize_test test_serialize) + endif() endif() diff --git a/examples/blackjack.cpp b/examples/blackjack.cpp index 88bd0fb..5088e06 100644 --- a/examples/blackjack.cpp +++ b/examples/blackjack.cpp @@ -8,7 +8,7 @@ * * @see the header (blackjack_header.hpp) for implementation details */ -#include "blackhack_header.hpp" +#include "blackjack_header.hpp" // int main(void) diff --git a/src/relearn.hpp b/src/relearn.hpp index fbffd39..2ed15a6 100644 --- a/src/relearn.hpp +++ b/src/relearn.hpp @@ -114,7 +114,13 @@ struct hasher> { std::size_t operator()(const state & arg) const; }; - +#if USING_BOOST_SERIALIZATION +template +struct hasher> +{ + std::size_t operator()(const state_serial & arg) const; +}; +#endif /** * @brief an action class - wraps around your class or pdt * @class action @@ -144,7 +150,7 @@ class action /// return trait action_trait trait() const; private: - /// action descriptor - object/value wrapped + // action descriptor - object/value wrapped action_trait __trait__; #if USING_BOOST_SERIALIZATION friend class boost::serialization::access; @@ -158,9 +164,15 @@ class action template struct hasher> { - std::size_t operator()(const action &arg) const; + std::size_t operator()(const action & arg) const; }; - +template +#if USING_BOOST_SERIALIZATION +struct hasher> +{ + std::size_t operator()(const action_serial & arg) const; +}; +#endif /** * @struct link * @brief a simple `link` or pair for joining state-actions in the MDP @@ -230,20 +242,20 @@ class policy */ std::pair,value_type> best(state_class s_t); private: -// if using boost serialization the policies used -// are wrappers defined in internal header `serialize.tpl` #if USING_BOOST_SERIALIZATION friend class boost::serialization::access; + // serialize method template void serialize(archive & ar, const unsigned int version); - std::unordered_map, + std::unordered_map, value_type, - hasher>, - hasher + hasher>>, + hasher> > __policies__; -// else we're using the actual `state_class` and `action_class` types #else + // policies maps is: [state][action][state_next] => Q-value std::unordered_map __policies__; #endif }; - /** * @class q_learning This is the **deterministic** Q-Learning algorithm * @brief Q-Learning update algorithm sets policies using episodes (`markov_chain`) @@ -512,7 +523,15 @@ template ::action_map policy::actions(state_class s_t) { +#if USING_BOOST_SERIALIZATION + std::unordered_map> retval; + for (const auto & item : __policies__[s_t]) { + retval[static_cast(item.first)] = item.second; + } + return retval; +#else return __policies__[s_t]; +#endif } template ::serialize(archive & ar, { ar & __policies__; } + +template +std::size_t hasher + >::operator()(const state_serial & arg) const +{ + return arg.hash(); +} + +template +std::size_t hasher + >::operator()(const action_serial & arg) const +{ + return arg.hash(); +} #endif template (triplet)); } } - +#if USING_BOOST_SERIALIZATION +#include "serialize.tpl" +#endif } // end of namespace #endif diff --git a/src/serialize.tpl b/src/serialize.tpl index e475103..ae721c2 100644 --- a/src/serialize.tpl +++ b/src/serialize.tpl @@ -1,62 +1,63 @@ -#if USING_BOOST_SERIALIZATION #include #include #include template -struct state_serial : private state_class +struct state_serial : public state_class { state_serial(); - operator state_class() const; + state_serial(state_class); + operator state_class(); }; template -struct action_serial : private action_class +struct action_serial : public action_class { action_serial(); - operator action_class() const; -}; - + action_serial(action_class); + operator action_class(); +}; + +/* + * Template serialization implementation + * + * |^| | | + * | |_____| | + * | _____ | + * | | | | + * | |_____| | + * |_|_____|_| + * + */ template -struct cast_state -{ - state_serial operator()(const state_class arg) const; -}; +state_serial::state_serial() +: state_class() +{} -template -struct cast_action -{ - action_serial operator()(const action_class arg) const; -}; +template +state_serial::state_serial(state_class arg) +: state_class(arg) +{} -template -struct hasher> +template +state_serial::operator state_class() { - std::size_t operator()(const state_serial & arg) const; -}; + return static_cast(*this); +} -template -struct hasher> -{ - std::size_t operator()(const action_serial & arg) const; -}; +template +action_serial::action_serial() +: action_class() +{} -template -struct hasher> -{ - std::size_t operator()(const std::unordered_map &arg) const; -}; +template +action_serial::action_serial(action_class arg) +: action_class(arg) +{} -template -struct hasher,value_type>> +template +action_serial::operator action_class() { - std::size_t operator()(const std::unordered_map, - value_type> &arg) const; -}; - - -/// TODO: implement those methods above + return static_cast(*this); +} -#endif diff --git a/tests/serialize_test.cpp b/tests/serialize_test.cpp new file mode 100644 index 0000000..a81ffe1 --- /dev/null +++ b/tests/serialize_test.cpp @@ -0,0 +1,79 @@ +#include "relearn.hpp" +#include +#include +#include +#define CATCH_CONFIG_MAIN +#include "catch.hpp" + +TEST_CASE("state class serialization test", "[state_class_serialization_test]") +{ + auto s_x = relearn::state("hello"); + auto s_y = relearn::state("world"); + // create and open a character archive for output + std::ofstream ofs("serialize_test_state_class"); + boost::archive::text_oarchive oa(ofs); + // write class instance to archive + oa << s_x; + ofs.close(); + // load data from the archive now + std::ifstream ifs("serialize_test_state_class"); + boost::archive::text_iarchive ia(ifs); + ia >> s_y; + // if serialization failed, then s_x will not equal s_y! + REQUIRE(s_x == s_y); +} + +TEST_CASE("action class serialization test", "[action_class_serialization_test]") +{ + auto a_x = relearn::action(0); + auto a_y = relearn::action(1); + // create and open a character archive for output + std::ofstream ofs("serialize_test_action_class"); + boost::archive::text_oarchive oa(ofs); + // write class instance to archive + oa << a_x; + ofs.close(); + // load data from the archive now + std::ifstream ifs("serialize_test_action_class"); + boost::archive::text_iarchive ia(ifs); + ia >> a_y; + // if serialization failed, then a_x will not equal a_y! + REQUIRE(a_x == a_y); +} + +SCENARIO("policy class serialization test", "[policy_class_serialize_test]") +{ + using state = relearn::state; + using action = relearn::action; + using link = relearn::link; + relearn::policy memory; + GIVEN("hard-coded episode with positive reward") + { + std::deque episode = { + {state("hello"), action("hi!")}, + {state("how are you?"), action("I'm fine, and you?")}, + {state(1, "not too bad! what you doing here?"), + action("I'm taking over the world!")}, + }; + WHEN("policy is trained and saved to disk") { + auto learner = relearn::q_learning{0.9, 0.9}; + for (int k = 0; k < 10; k++) { + learner(episode, memory); + } + std::ofstream ofs("serialize_test_policy_class"); + boost::archive::text_oarchive oa(ofs); + oa << memory; + ofs.close(); + THEN("we load a new policy and expect the same output") { + relearn::policy policies; + std::ifstream ifs("serialize_test_policy_class"); + boost::archive::text_iarchive ia(ifs); + ia >> policies; + REQUIRE(policies.best_action(state("hello"))->trait() == "hi!"); + REQUIRE(policies.best_action(state("how are you?"))->trait() == "I'm fine, and you?"); + REQUIRE(policies.best_action(state("not too bad! what you doing here?"))->trait() + == "I'm taking over the world!"); + } + } + } +}