-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
finished serialization and added a serialize test
- Loading branch information
Showing
5 changed files
with
184 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,62 +1,63 @@ | ||
#if USING_BOOST_SERIALIZATION | ||
#include <boost/serialization/serialization.hpp> | ||
#include <boost/serialization/access.hpp> | ||
#include <boost/serialization/unordered_map.hpp> | ||
|
||
template <class state_class> | ||
struct state_serial : private state_class | ||
struct state_serial : public state_class | ||
{ | ||
state_serial(); | ||
operator state_class() const; | ||
state_serial(state_class); | ||
operator state_class(); | ||
}; | ||
|
||
template <class action_class> | ||
struct action_serial : private action_class | ||
struct action_serial : public action_class | ||
{ | ||
action_serial(); | ||
operator action_class() const; | ||
}; | ||
|
||
action_serial(action_class); | ||
operator action_class(); | ||
}; | ||
|
||
/* | ||
* Template serialization implementation | ||
* | ||
* |^| | | | ||
* | |_____| | | ||
* | _____ | | ||
* | | | | | ||
* | |_____| | | ||
* |_|_____|_| | ||
* | ||
*/ | ||
template <class state_class> | ||
struct cast_state | ||
{ | ||
state_serial<state_class> operator()(const state_class arg) const; | ||
}; | ||
state_serial<state_class>::state_serial() | ||
: state_class() | ||
{} | ||
|
||
template <class action_class> | ||
struct cast_action | ||
{ | ||
action_serial<action_class> operator()(const action_class arg) const; | ||
}; | ||
template <class state_class> | ||
state_serial<state_class>::state_serial(state_class arg) | ||
: state_class(arg) | ||
{} | ||
|
||
template <class state_class> | ||
struct hasher<state_serial<state_class>> | ||
template <class state_class> | ||
state_serial<state_class>::operator state_class() | ||
{ | ||
std::size_t operator()(const state_serial<state_class> & arg) const; | ||
}; | ||
return static_cast<state_class>(*this); | ||
} | ||
|
||
template <class action_class> | ||
struct hasher<action_serial<action_class>> | ||
{ | ||
std::size_t operator()(const action_serial<action_class> & arg) const; | ||
}; | ||
template <class action_class> | ||
action_serial<action_class>::action_serial() | ||
: action_class() | ||
{} | ||
|
||
template <class action_class, | ||
typename value_type> | ||
struct hasher<std::unordered_map<action_class,value_type>> | ||
{ | ||
std::size_t operator()(const std::unordered_map<action_class,value_type> &arg) const; | ||
}; | ||
template <class action_class> | ||
action_serial<action_class>::action_serial(action_class arg) | ||
: action_class(arg) | ||
{} | ||
|
||
template <class action_class, | ||
typename value_type> | ||
struct hasher<std::unordered_map<action_serial<action_class>,value_type>> | ||
template <class action_class> | ||
action_serial<action_class>::operator action_class() | ||
{ | ||
std::size_t operator()(const std::unordered_map<action_serial<action_class>, | ||
value_type> &arg) const; | ||
}; | ||
|
||
|
||
/// TODO: implement those methods above | ||
return static_cast<action_class>(*this); | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
#include "relearn.hpp" | ||
#include <fstream> | ||
#include <boost/archive/text_oarchive.hpp> | ||
#include <boost/archive/text_iarchive.hpp> | ||
#define CATCH_CONFIG_MAIN | ||
#include "catch.hpp" | ||
|
||
TEST_CASE("state class serialization test", "[state_class_serialization_test]") | ||
{ | ||
auto s_x = relearn::state<std::string>("hello"); | ||
auto s_y = relearn::state<std::string>("world"); | ||
// create and open a character archive for output | ||
std::ofstream ofs("serialize_test_state_class"); | ||
boost::archive::text_oarchive oa(ofs); | ||
// write class instance to archive | ||
oa << s_x; | ||
ofs.close(); | ||
// load data from the archive now | ||
std::ifstream ifs("serialize_test_state_class"); | ||
boost::archive::text_iarchive ia(ifs); | ||
ia >> s_y; | ||
// if serialization failed, then s_x will not equal s_y! | ||
REQUIRE(s_x == s_y); | ||
} | ||
|
||
TEST_CASE("action class serialization test", "[action_class_serialization_test]") | ||
{ | ||
auto a_x = relearn::action<int>(0); | ||
auto a_y = relearn::action<int>(1); | ||
// create and open a character archive for output | ||
std::ofstream ofs("serialize_test_action_class"); | ||
boost::archive::text_oarchive oa(ofs); | ||
// write class instance to archive | ||
oa << a_x; | ||
ofs.close(); | ||
// load data from the archive now | ||
std::ifstream ifs("serialize_test_action_class"); | ||
boost::archive::text_iarchive ia(ifs); | ||
ia >> a_y; | ||
// if serialization failed, then a_x will not equal a_y! | ||
REQUIRE(a_x == a_y); | ||
} | ||
|
||
SCENARIO("policy class serialization test", "[policy_class_serialize_test]") | ||
{ | ||
using state = relearn::state<std::string>; | ||
using action = relearn::action<std::string>; | ||
using link = relearn::link<state,action>; | ||
relearn::policy<state,action> memory; | ||
GIVEN("hard-coded episode with positive reward") | ||
{ | ||
std::deque<link> episode = { | ||
{state("hello"), action("hi!")}, | ||
{state("how are you?"), action("I'm fine, and you?")}, | ||
{state(1, "not too bad! what you doing here?"), | ||
action("I'm taking over the world!")}, | ||
}; | ||
WHEN("policy is trained and saved to disk") { | ||
auto learner = relearn::q_learning<state,action>{0.9, 0.9}; | ||
for (int k = 0; k < 10; k++) { | ||
learner(episode, memory); | ||
} | ||
std::ofstream ofs("serialize_test_policy_class"); | ||
boost::archive::text_oarchive oa(ofs); | ||
oa << memory; | ||
ofs.close(); | ||
THEN("we load a new policy and expect the same output") { | ||
relearn::policy<state,action> policies; | ||
std::ifstream ifs("serialize_test_policy_class"); | ||
boost::archive::text_iarchive ia(ifs); | ||
ia >> policies; | ||
REQUIRE(policies.best_action(state("hello"))->trait() == "hi!"); | ||
REQUIRE(policies.best_action(state("how are you?"))->trait() == "I'm fine, and you?"); | ||
REQUIRE(policies.best_action(state("not too bad! what you doing here?"))->trait() | ||
== "I'm taking over the world!"); | ||
} | ||
} | ||
} | ||
} |