From 232b96387e9eeae158085ec4fd51ba05f6f9d65e Mon Sep 17 00:00:00 2001 From: Alex Ge Date: Thu, 13 Jul 2017 15:59:38 +0100 Subject: [PATCH] cleaned-up examples, moved structs into headers --- examples/blackjack.cpp | 254 +------------------------------- examples/blackjack_header.hpp | 261 +++++++++++++++++++++++++++++++++ examples/gridworld_header.hpp | 165 +++++++++++++++++++++ examples/gridworld_offline.cpp | 165 +-------------------- examples/gridworld_online.cpp | 178 +++------------------- src/relearn.hpp | 26 +++- 6 files changed, 479 insertions(+), 570 deletions(-) create mode 100644 examples/blackjack_header.hpp create mode 100644 examples/gridworld_header.hpp diff --git a/examples/blackjack.cpp b/examples/blackjack.cpp index 02548f4..88bd0fb 100644 --- a/examples/blackjack.cpp +++ b/examples/blackjack.cpp @@ -5,250 +5,11 @@ * @version 0.1.0 * @author Alex Giokas * @date 19.11.2016 + * + * @see the header (blackjack_header.hpp) for implementation details */ -#include -#include -#include -#include -#include -#include -#include -#include "../src/relearn.hpp" +#include "blackhack_header.hpp" -// a simple card structure -struct card -{ - std::string name; - std::string label; - std::vector value; - - void print() const - { - std::cout << name << " " - << label << " "; - } - - bool operator==(const card & rhs) const - { - return this->name == rhs.name && - this->label == rhs.label && - this->value == rhs.value; - - } -}; - -// a 52 playing card constant vector with unicode symbols :-D -const std::deque cards { - {"Ace", "♠", {1, 11}}, {"Ace", "♥", {1, 11}}, {"Ace", "♦", {1, 11}}, {"Ace", "♣", {1, 11}}, - {"Two", "♠", {2}}, {"Two", "♥", {2}}, {"Two", "♦", {2}}, {"Two", "♣", {2}}, - {"Three","♠", {3}}, {"Three","♥", {3}}, {"Three","♦", {3}}, {"Three","♣", {3}}, - {"Four", "♠", {4}}, {"Four", "♥", {4}}, {"Four", "♦", {4}}, {"Four", "♣", {4}}, - {"Five", "♠", {5}}, {"Five", "♥", {5}}, {"Five", "♦", {5}}, {"Five", "♣", {5}}, - {"Six", "♠", {6}}, {"Six", "♥", {6}}, {"Six", "♦", {6}}, {"Six", "♣", {6}}, - {"Seven","♠", {7}}, {"Seven","♥", {7}}, {"Seven","♦", {7}}, {"Seven","♣", {7}}, - {"Eight","♠", {8}}, {"Eight","♥", {8}}, {"Eight","♦", {8}}, {"Eight","♣", {8}}, - {"Nine", "♠", {9}}, {"Nine", "♥", {9}}, {"Nine", "♦", {9}}, {"Nine", "♣", {9}}, - {"Ten", "♠", {10}}, {"Ten", "♥", {10}}, {"Ten", "♦", {10}}, {"Ten", "♣", {10}}, - {"Jack", "♠", {10}}, {"Jack", "♥", {10}}, {"Jack", "♦", {10}}, {"Jack", "♣", {10}}, - {"Queen","♠", {10}}, {"Queen","♥", {10}}, {"Queen","♦", {10}}, {"Queen","♣", {10}}, - {"King", "♠", {10}}, {"King", "♥", {10}}, {"King", "♦", {10}}, {"King", "♣", {10}} -}; - -bool card_compare(const card & lhs, const card & rhs) -{ - return lhs.name == rhs.name && - lhs.label == rhs.label && - lhs.value == rhs.value; -} - -// hand is the currently held cards -struct hand -{ - hand() = default; - hand(const hand &) = default; - - // calculate value of hand - use max value of hand - unsigned int max_value() const - { - unsigned int result = 0; - for (const card & k : cards) { - result += *std::max_element(k.value.begin(), k.value.end()); - } - return result; - } - - // calculate value of hand - use min value (e.g., when hold an Ace) - unsigned int min_value() const - { - unsigned int result = 0; - for (const card & k : cards) { - result += *std::min_element(k.value.begin(), k.value.end()); - } - return result; - } - - // print on stdout - void print() const - { - for (card k : cards) { - k.print(); - } - std::cout << std::endl; - } - - // add new card - void insert(card arg) - { - 
cards.push_back(arg); - } - - // clear hand - void clear() - { - cards.clear(); - } - - // hand is blackjack - bool blackjack() const - { - std::vector twoblacks = {{"Ace", "♠", {1, 11}}, - {"Ace", "♣", {1, 11}}}; - return std::is_permutation(twoblacks.begin(), twoblacks.end(), - cards.begin(), card_compare); - } - - // hash this hand for relearn - std::size_t hash() const - { - std::size_t seed = 0; - for (auto & k : cards) { - for (auto & v : k.value) { - relearn::hash_combine(seed, v); - } - } - return seed; - } - - bool operator==(const hand & rhs) const - { - return this->cards == rhs.cards; - } - -private: - std::vector cards; -}; - -namespace std -{ -template <> struct hash -{ - std::size_t operator()(hand const& arg) const - { - return arg.hash(); - } -}; -} - -/// compare hands (return true for lhs wins) -bool hand_compare(const hand & lhs, const hand & rhs) -{ - if (lhs.blackjack()) return true; - else if (rhs.blackjack()) return false; - - if (lhs.min_value() > 21) return false; - else if (rhs.min_value() > 21 && lhs.min_value() < 21) return true; - - if (lhs.max_value() > rhs.max_value()) return true; - else return false; -} - -// Base class for all players -struct player : public hand -{ - virtual bool draw() - { - return false; - } -}; - -// House/dealer only uses simple rules to draw or stay -struct house : public player -{ - house(std::deque cards, std::mt19937 & prng) - : cards(cards), gen(prng) - {} - - // draw a card based on current hand - house always draws until 17 is reached - bool draw() - { - return (min_value() < 17 || max_value() < 17); - } - - // deal a card using current deck - or reset and deal - card deal() - { - if (deck.size() > 0) { - auto obj = deck.front(); - deck.pop_front(); - return obj; - } - else { - reset_deck(); - return deal(); - } - } - - // shuffle cards randomly - void reset_deck() - { - deck = cards; - std::shuffle(std::begin(deck), std::end(deck), gen); - } - -private: - std::deque deck; - const std::deque cards; - std::mt19937 gen; -}; - -// -// our learning adaptive player -struct client : public player -{ - // decide on drawing or staying - bool draw(std::mt19937 & prng, - relearn::state s_t, - relearn::policy, - relearn::action> & map) - { - auto a_t = map.best_action(s_t); - auto q_v = map.best_value(s_t); - std::uniform_real_distribution dist(0, 1); - // there exists a "best action" and it is positive - if (a_t && q_v > 0) { - sum_q += q_v; - policy_actions++; - return a_t->trait(); - } - // there does not exist a "best action" - else { - random_actions++; - return (dist(prng) > 0.5 ? 
true : false); - } - } - - // return a state by casting self to base class - relearn::state state() const - { - return relearn::state(*this); - } - - float random_actions = 0; - float policy_actions = 0; - float sum_q = 0; -}; - -// // int main(void) { @@ -276,7 +37,7 @@ int main(void) << std::endl; start: // play 10 rounds - then stop - for (int i = 0; i < 10; i++) { + for (int i = 0; i < 100; i++) { sum++; std::deque episode; // one card to dealer/house @@ -327,13 +88,11 @@ int main(void) agent->clear(); dealer->clear(); experience.push_back(episode); + std::cout << "\twin ratio: " << wins / sum << std::endl; std::cout << "\ton-policy ratio: " << agent->policy_actions / (agent->policy_actions + agent->random_actions) << std::endl; - std::cout << "\tavg Q-value: " - << (agent->sum_q / agent->policy_actions) - << std::endl; } // at this point, we have some playing experience, which we're going to use @@ -346,6 +105,9 @@ int main(void) } // clear experience - we'll add new ones! experience.clear(); + agent->reset(); + sum = 0; + wins = 0; goto start; return 0; diff --git a/examples/blackjack_header.hpp b/examples/blackjack_header.hpp new file mode 100644 index 0000000..580c651 --- /dev/null +++ b/examples/blackjack_header.hpp @@ -0,0 +1,261 @@ +#ifndef BLACKJACK_HPP +#define BLACKJACK_HPP + +#include +#include +#include +#include +#include +#include +#include +#include "../src/relearn.hpp" + +/** + * The basic Blackjack header structures: + * - card + * - hand + * - player + * - house + * - client + * + * All that is minimally needed in order to + * create the Blackjack example + */ + +// a simple card structure +struct card +{ + std::string name; + std::string label; + std::vector value; + + void print() const + { + std::cout << name << " " + << label << " "; + } + + bool operator==(const card & rhs) const + { + return this->name == rhs.name && + this->label == rhs.label && + this->value == rhs.value; + + } +}; + +// a 52 playing card constant vector with unicode symbols :-D +const std::deque cards { + {"Ace", "♠", {1, 11}}, {"Ace", "♥", {1, 11}}, {"Ace", "♦", {1, 11}}, {"Ace", "♣", {1, 11}}, + {"Two", "♠", {2}}, {"Two", "♥", {2}}, {"Two", "♦", {2}}, {"Two", "♣", {2}}, + {"Three","♠", {3}}, {"Three","♥", {3}}, {"Three","♦", {3}}, {"Three","♣", {3}}, + {"Four", "♠", {4}}, {"Four", "♥", {4}}, {"Four", "♦", {4}}, {"Four", "♣", {4}}, + {"Five", "♠", {5}}, {"Five", "♥", {5}}, {"Five", "♦", {5}}, {"Five", "♣", {5}}, + {"Six", "♠", {6}}, {"Six", "♥", {6}}, {"Six", "♦", {6}}, {"Six", "♣", {6}}, + {"Seven","♠", {7}}, {"Seven","♥", {7}}, {"Seven","♦", {7}}, {"Seven","♣", {7}}, + {"Eight","♠", {8}}, {"Eight","♥", {8}}, {"Eight","♦", {8}}, {"Eight","♣", {8}}, + {"Nine", "♠", {9}}, {"Nine", "♥", {9}}, {"Nine", "♦", {9}}, {"Nine", "♣", {9}}, + {"Ten", "♠", {10}}, {"Ten", "♥", {10}}, {"Ten", "♦", {10}}, {"Ten", "♣", {10}}, + {"Jack", "♠", {10}}, {"Jack", "♥", {10}}, {"Jack", "♦", {10}}, {"Jack", "♣", {10}}, + {"Queen","♠", {10}}, {"Queen","♥", {10}}, {"Queen","♦", {10}}, {"Queen","♣", {10}}, + {"King", "♠", {10}}, {"King", "♥", {10}}, {"King", "♦", {10}}, {"King", "♣", {10}} +}; + +bool card_compare(const card & lhs, const card & rhs) +{ + return lhs.name == rhs.name && + lhs.label == rhs.label && + lhs.value == rhs.value; +} + +// hand is the currently held cards +struct hand +{ + hand() = default; + hand(const hand &) = default; + + // calculate value of hand - use max value of hand + unsigned int max_value() const + { + unsigned int result = 0; + for (const card & k : cards) { + result += 
*std::max_element(k.value.begin(), k.value.end()); + } + return result; + } + + // calculate value of hand - use min value (e.g., when hold an Ace) + unsigned int min_value() const + { + unsigned int result = 0; + for (const card & k : cards) { + result += *std::min_element(k.value.begin(), k.value.end()); + } + return result; + } + + // print on stdout + void print() const + { + for (card k : cards) { + k.print(); + } + std::cout << std::endl; + } + + // add new card + void insert(card arg) + { + cards.push_back(arg); + } + + // clear hand + void clear() + { + cards.clear(); + } + + // hand is blackjack + bool blackjack() const + { + std::vector twoblacks = {{"Ace", "♠", {1, 11}}, + {"Ace", "♣", {1, 11}}}; + return std::is_permutation(twoblacks.begin(), twoblacks.end(), + cards.begin(), card_compare); + } + + // hash this hand for relearn + std::size_t hash() const + { + std::size_t seed = 0; + for (auto & k : cards) { + for (auto & v : k.value) { + relearn::hash_combine(seed, v); + } + } + return seed; + } + + bool operator==(const hand & rhs) const + { + return this->cards == rhs.cards; + } + +private: + std::vector cards; +}; + +namespace std +{ +template <> struct hash +{ + std::size_t operator()(hand const& arg) const + { + return arg.hash(); + } +}; +} + +/// compare hands (return true for lhs wins) +bool hand_compare(const hand & lhs, const hand & rhs) +{ + if (lhs.blackjack()) return true; + else if (rhs.blackjack()) return false; + + if (lhs.min_value() > 21) return false; + else if (rhs.min_value() > 21 && lhs.min_value() < 21) return true; + + if (lhs.max_value() > rhs.max_value()) return true; + else return false; +} + +// Base class for all players +struct player : public hand +{ + virtual bool draw() + { + return false; + } +}; + +// House/dealer only uses simple rules to draw or stay +struct house : public player +{ + house(std::deque cards, std::mt19937 & prng) + : cards(cards), gen(prng) + {} + + // draw a card based on current hand - house always draws until 17 is reached + bool draw() + { + return (min_value() < 17 || max_value() < 17); + } + + // deal a card using current deck - or reset and deal + card deal() + { + if (deck.size() > 0) { + auto obj = deck.front(); + deck.pop_front(); + return obj; + } + else { + reset_deck(); + return deal(); + } + } + + // shuffle cards randomly + void reset_deck() + { + deck = cards; + std::shuffle(std::begin(deck), std::end(deck), gen); + } + +private: + std::deque deck; + const std::deque cards; + std::mt19937 gen; +}; + +// +// our learning adaptive player +struct client : public player +{ + // decide on drawing or staying + bool draw(std::mt19937 & prng, + relearn::state s_t, + relearn::policy, + relearn::action> & map) + { + auto pair = map.best(s_t); + std::uniform_real_distribution dist(0, 1); + // there exists a "best action" and it is positive + if (pair.first && pair.second > 0) { + policy_actions++; + return pair.first->trait(); + } + // there does not exist a "best action" + else { + random_actions++; + return (dist(prng) > 0.5 ? 
true : false); + } + } + + // return a state by casting self to base class + relearn::state state() const + { + return relearn::state(*this); + } + + void reset() + { + random_actions = 0; + policy_actions = 0; + } + + float random_actions = 0; + float policy_actions = 0; +}; + +#endif diff --git a/examples/gridworld_header.hpp b/examples/gridworld_header.hpp new file mode 100644 index 0000000..e479d41 --- /dev/null +++ b/examples/gridworld_header.hpp @@ -0,0 +1,165 @@ +#ifndef GRIDWORLD_HPP +#define GRIDWORLD_HPP +#include +#include +#include +#include +#include +#include +#include +#include +#include "../src/relearn.hpp" +/** + * This header contains the basic structures and operations + * needed to demonstrate the Gridworld ecosystem in a simple + * yet fully functional manner. + * + * 1. State space: GridWorld has 10x10 = 100 distinct states. + * The start state is the top left cell. + * The gray cells are walls and cannot be moved to. + * + * 2. Actions: The agent can choose from up to 4 actions to move around. + * + * 3. Rewards: The agent receives +1 reward when it is in + * the center square (the one that shows R 1.0), + * and -1 reward in a the boundary states (R -1.0 is shown for these). + * The state with +1.0 reward is the goal state. + */ + + +/** + * A grid block is simply a coordinate (x,y) + * A grid *may* have a reward R. + * This struct functions as the block upon which the gridworld problem is based. + * We also use this as the `state_trait` descriptor S for state and action + */ +struct grid +{ + unsigned int x = 0; + unsigned int y = 0; + double R = 0; + bool occupied = false; + + bool operator==(const grid & arg) const + { + return (this->x == arg.x) && + (this->y == arg.y); + } +}; + +/** + * A move in the grid world is simply a number. + * 0 for left, 1 for top, 2 for right, 3 for down. + * We also use this as the `action_trait` descriptor A for state and action + */ +struct direction +{ + unsigned int dir; + + bool operator==(const direction & arg) const + { + return (this->dir == arg.dir); + } +}; + +/** + * Hash specialisations in the STD namespace for structs grid and direction. + * Those are **required** because the underlying relearn library + * uses unordered_map and unordered_set, which use hashing functions for the classes + * which are mapped/stored internally. + */ +namespace std +{ +template <> struct hash +{ + std::size_t operator()(grid const& arg) const + { + std::size_t seed = 0; + relearn::hash_combine(seed, arg.x); + relearn::hash_combine(seed, arg.y); + return seed; + } +}; +template <> struct hash +{ + std::size_t operator()(direction const& arg) const + { + std::size_t seed = 0; + relearn::hash_combine(seed, arg.dir); + return seed; + } +}; +} + +/** + * The gridworld struct simply contains the grid blocks. + * Each block is uniquely identified by its coordinates. 
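
The std::hash specialisations above only combine the (x, y) coordinates, mirroring grid::operator==, so two blocks that differ only in reward or occupancy collapse to the same state key inside relearn's unordered containers. A minimal sketch of that behaviour, assuming the gridworld_header.hpp added by this patch is on the include path (the sketch itself is not part of the diff):

```cpp
#include "gridworld_header.hpp"
#include <cassert>
#include <iostream>

int main()
{
    grid a{2, 3, 0.0, false};
    grid b{2, 3, 1.0, true};   // same coordinates, different reward/occupancy
    // equality and hashing only look at (x, y)
    assert(a == b);
    assert(std::hash<grid>()(a) == std::hash<grid>()(b));
    std::cout << "hash(grid{2,3})    = " << std::hash<grid>()(a) << "\n"
              << "hash(direction{1}) = " << std::hash<direction>()(direction{1}) << "\n";
    return 0;
}
```
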
+ */ +struct world +{ + std::unordered_set blocks; +}; + +using state = relearn::state; +using action = relearn::action; + +/// load the gridworld from the text file +/// boundaries are `occupied` e.g., can't move into them +/// fire/danger blocks are marked with a reward -1 +world populate() +{ + std::ifstream infile("../examples/gridworld.txt"); + world environment = {}; + std::string line; + while (std::getline(infile, line)) + { + std::istringstream iss(line); + unsigned int x; + unsigned int y; + double r; + bool occupied; + if (iss >> x >> y >> occupied >> r) { + environment.blocks.insert({x, y, r, occupied}); + } + else break; + } + return environment; +} + +/// Decide on a stochastic (random) direction and return the next grid block +struct rand_direction +{ + std::pair operator()(std::mt19937 & prng, + world gridworld, + grid current) + { + std::uniform_int_distribution dist(0, 3); + unsigned int x = current.x; + unsigned int y = current.y; + // randomly decide on next grid - we map numbers to a direction + unsigned int d = dist(prng); + switch (d) { + case 0 : y--; + break; + case 1 : x++; + break; + case 2 : y++; + break; + case 3 : x--; + break; + } + auto it = std::find_if(gridworld.blocks.begin(), + gridworld.blocks.end(), + [&](const auto b) { + return b.x == x && b.y == y; + }); + if (it == gridworld.blocks.end()) { + return rand_direction()(prng, gridworld, current); + } + if (it->occupied) { + return rand_direction()(prng, gridworld, current); + } + return std::make_pair(direction{d}, *it); + } +}; +#endif diff --git a/examples/gridworld_offline.cpp b/examples/gridworld_offline.cpp index 183c528..4b26c99 100644 --- a/examples/gridworld_offline.cpp +++ b/examples/gridworld_offline.cpp @@ -4,21 +4,8 @@ * as a toy model in the Reinforcement Learning literature. * In this particular case: * - * 1. State space: GridWorld has 10x10 = 100 distinct states. - * The start state is the top left cell. - * The gray cells are walls and cannot be moved to. - * - * 2. Actions: The agent can choose from up to 4 actions to move around. - * - * 3. Environment Dynamics: GridWorld is deterministic, - * leading to the same new state given each state and action. - * Non-deterministic environments nead a probabilistic approach. - * - * 4. Rewards: The agent receives +1 reward when it is in - * the center square (the one that shows R 1.0), - * and -1 reward in a the boundary states (R -1.0 is shown for these). - * The state with +1.0 reward is the goal state. - * + * Environment Dynamics: GridWorld is deterministic, + * leading to the same new state given each state and action. * This is a deterministic, finite Markov Decision Process (MDP) * and the goal is to find an agent policy that maximizes * the future discounted reward. @@ -32,153 +19,7 @@ * version can get stuck into repeating the same actions over and over again, * therefore if it is running for longer than a minute, feel free to CTRL-C it. */ -#include -#include -#include -#include -#include -#include -#include -#include -#include "../src/relearn.hpp" - -/** - * A grid block is simply a coordinate (x,y) - * A grid *may* have a reward R. - * This struct functions as the block upon which the gridworld problem is based. 
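
populate() reads each row as "x y occupied reward" (in that order, matching the iss >> x >> y >> occupied >> r chain) and stops at the first malformed line; the hard-coded path ../examples/gridworld.txt is resolved relative to wherever the example is run from. A short usage sketch, not part of the patch, assuming that header and data file are in place:

```cpp
#include "gridworld_header.hpp"
#include <iostream>
#include <random>

int main()
{
    std::mt19937 prng{std::random_device{}()};
    world env = populate();
    if (env.blocks.empty()) {            // file missing or unreadable
        std::cerr << "could not load gridworld.txt\n";
        return 1;
    }
    grid start{1, 1, 0, false};          // hypothetical start cell
    // one stochastic step: rand_direction retries until it lands on a
    // block that exists and is not a wall
    auto step = rand_direction()(prng, env, start);
    std::cout << "moved " << step.first.dir << " to ("
              << step.second.x << "," << step.second.y
              << ") with reward " << step.second.R << "\n";
    return 0;
}
```
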
- * We also use this as the `state_trait` descriptor S for state and action - */ -struct grid -{ - unsigned int x = 0; - unsigned int y = 0; - double R = 0; - bool occupied = false; - - bool operator==(const grid & arg) const - { - return (this->x == arg.x) && - (this->y == arg.y); - } -}; - -/** - * A move in the grid world is simply a number. - * 0 for left, 1 for top, 2 for right, 3 for down. - * We also use this as the `action_trait` descriptor A for state and action - */ -struct direction -{ - unsigned int dir; - - bool operator==(const direction & arg) const - { - return (this->dir == arg.dir); - } -}; - -/** - * Hash specialisations in the STD namespace for structs grid and direction. - * Those are **required** because the underlying relearn library - * uses unordered_map and unordered_set, which use hashing functions for the classes - * which are mapped/stored internally. - */ -namespace std -{ -template <> struct hash -{ - std::size_t operator()(grid const& arg) const - { - std::size_t seed = 0; - relearn::hash_combine(seed, arg.x); - relearn::hash_combine(seed, arg.y); - return seed; - } -}; -template <> struct hash -{ - std::size_t operator()(direction const& arg) const - { - std::size_t seed = 0; - relearn::hash_combine(seed, arg.dir); - return seed; - } -}; -} - -/** - * The gridworld struct simply contains the grid blocks. - * Each block is uniquely identified by its coordinates. - */ -struct world -{ - std::unordered_set blocks; -}; - -using state = relearn::state; -using action = relearn::action; - -/// load the gridworld from the text file -/// boundaries are `occupied` e.g., can't move into them -/// fire/danger blocks are marked with a reward -1 -world populate() -{ - std::ifstream infile("../examples/gridworld.txt"); - world environment = {}; - std::string line; - while (std::getline(infile, line)) - { - std::istringstream iss(line); - unsigned int x; - unsigned int y; - double r; - bool occupied; - if (iss >> x >> y >> occupied >> r) { - environment.blocks.insert({x, y, r, occupied}); - } - else break; - } - return environment; -} - -/// Decide on a stochastic (random) direction and return the next grid block -struct rand_direction -{ - std::pair operator()(std::mt19937 & prng, - world gridworld, - grid current) - { - std::uniform_int_distribution dist(0, 3); - unsigned int x = current.x; - unsigned int y = current.y; - // randomly decide on next grid - we map numbers to a direction - unsigned int d = dist(prng); - switch (d) { - case 0 : y--; - break; - case 1 : x++; - break; - case 2 : y++; - break; - case 3 : x--; - break; - } - auto it = std::find_if(gridworld.blocks.begin(), - gridworld.blocks.end(), - [&](const auto b) { - return b.x == x && b.y == y; - }); - if (it == gridworld.blocks.end()) { - //std::cerr << "tried to move off the grid at: " << x << "," << y << std::endl; - return rand_direction()(prng, gridworld, current); - } - if (it->occupied) { - //std::cerr << "occupied block: " << x << "," << y << std::endl; - return rand_direction()(prng, gridworld, current); - } - return std::make_pair(direction{d}, *it); - } -}; +#include "gridworld_header.hpp" /** * Exploration technique is based on Monte-Carlo, e.g.: stochastic search. diff --git a/examples/gridworld_online.cpp b/examples/gridworld_online.cpp index 0465d74..b3621a4 100644 --- a/examples/gridworld_online.cpp +++ b/examples/gridworld_online.cpp @@ -28,160 +28,15 @@ * policy, unless it has a bad value, or if it is unknown, in which case * it takes a random action. 
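
Both gridworld examples rely on the same numeric encoding of moves: dir 0 decrements y, 1 increments x, 2 increments y and 3 decrements x, exactly as in the switch inside rand_direction and, later, in the on-policy explore() loop. A small helper, shown here purely as an illustration and not part of the patch, makes that mapping explicit:

```cpp
#include "gridworld_header.hpp"
#include <algorithm>

// apply a direction to a grid block; stays put if the move would leave
// the grid or enter an occupied (wall) block. Staying put is a design
// choice for this sketch, whereas rand_direction simply retries with a
// new random move.
grid apply_move(const world & w, grid from, direction d)
{
    unsigned int x = from.x;
    unsigned int y = from.y;
    switch (d.dir) {
        case 0: y--; break;
        case 1: x++; break;
        case 2: y++; break;
        case 3: x--; break;
    }
    auto it = std::find_if(w.blocks.begin(), w.blocks.end(),
                           [&](const auto & b) { return b.x == x && b.y == y; });
    return (it == w.blocks.end() || it->occupied) ? from : *it;
}
```
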
*/ -#include -#include -#include -#include -#include -#include -#include -#include -#include "../src/relearn.hpp" - -/** - * A grid block is simply a coordinate (x,y) - * A grid *may* have a reward R. - * This struct functions as the block upon which the gridworld problem is based. - * We also use this as the `state_trait` descriptor S for state and action - */ -struct grid -{ - unsigned int x = 0; - unsigned int y = 0; - double R = 0; - bool occupied = false; - - bool operator==(const grid & arg) const - { - return (this->x == arg.x) && - (this->y == arg.y); - } -}; - +#include "gridworld_header.hpp" /** - * A move in the grid world is simply a number. - * 0 for left, 1 for top, 2 for right, 3 for down. - * We also use this as the `action_trait` descriptor A for state and action - */ -struct direction -{ - unsigned int dir; - - bool operator==(const direction & arg) const - { - return (this->dir == arg.dir); - } -}; - -/** - * Hash specialisations in the STD namespace for structs grid and direction. - * Those are **required** because the underlying relearn library - * uses unordered_map and unordered_set, which use hashing functions for the classes - * which are mapped/stored internally. - */ -namespace std -{ -template <> struct hash -{ - std::size_t operator()(grid const& arg) const - { - std::size_t seed = 0; - relearn::hash_combine(seed, arg.x); - relearn::hash_combine(seed, arg.y); - return seed; - } -}; -template <> struct hash -{ - std::size_t operator()(direction const& arg) const - { - std::size_t seed = 0; - relearn::hash_combine(seed, arg.dir); - return seed; - } -}; -} - -/** - * The gridworld struct simply contains the grid blocks. - * Each block is uniquely identified by its coordinates. - */ -struct world -{ - std::unordered_set blocks; -}; - -using state = relearn::state; -using action = relearn::action; - -/// -/// load the gridworld from the text file -/// boundaries are `occupied` e.g., can't move into them -/// fire/danger blocks are marked with a reward -1 -/// -world populate() -{ - std::ifstream infile("../examples/gridworld.txt"); - world environment = {}; - std::string line; - while (std::getline(infile, line)) - { - std::istringstream iss(line); - unsigned int x; - unsigned int y; - double r; - bool occupied; - if (iss >> x >> y >> occupied >> r) { - environment.blocks.insert({x, y, r, occupied}); - } - else break; - } - return environment; -} - -/// -/// Decide on a stochastic (random) direction and return the next grid block -/// -struct rand_direction -{ - std::pair operator()(std::mt19937 & prng, - world gridworld, - grid current) - { - std::uniform_int_distribution dist(0, 3); - unsigned int x = current.x; - unsigned int y = current.y; - // randomly decide on next grid - we map numbers to a direction - unsigned int d = dist(prng); - switch (d) { - case 0 : y--; - break; - case 1 : x++; - break; - case 2 : y++; - break; - case 3 : x--; - break; - } - auto it = std::find_if(gridworld.blocks.begin(), - gridworld.blocks.end(), - [&](const auto b) { - return b.x == x && b.y == y; - }); - if (it == gridworld.blocks.end()) { - return rand_direction()(prng, gridworld, current); - } - if (it->occupied) { - return rand_direction()(prng, gridworld, current); - } - return std::make_pair(direction{d}, *it); - } -}; - -/** - * Exploration technique is based on Monte-Carlo, e.g.: stochastic search. - * The `agent` will randomly search the world experiencing different blocks. - * The agent will internally map its experience using the State/Action pairs. 
+ * Exploration technique is `online` meaning the agent will follow + * policies, if they exist for a particular state. It will revert to a + * stochastic (random) approach only if policies don't exist, or if those + * policies have a negative value. + * + * For this type of approach to work, in between episodes, the agent must + * be rewarded and re-trained. */ template @@ -206,11 +61,9 @@ std::deque> explore(const world & w, // if there exists a policy, then stay on it! while (!stop) { - auto action = policy_map.best_action(state_now); - auto q_val = policy_map.best_value(state_now); - if (action && q_val > 0) - { - switch (action->trait().dir) { + auto pair = policy_map.best(state_now); + if (pair.first && pair.second > 0) { + switch (pair.first->trait().dir) { case 0 : curr.y--; break; case 1 : curr.x++; @@ -299,5 +152,14 @@ int main() } stop = (episode.back().state.reward() == 1 ? true : false); } + // + // The catch here is that although we explore and stay on policy at the same + // time, if we do not explore long enough, we may find "a solution" which + // is not necessarily the optimal or best one! + // the way to avoid doing this, is to not exit the loop as soon + // as a solution is found, and in conjunction with this, + // to also allow random actions, even when an optimal policy exists (this + // method is known as e-Greedy or explorative-Greedy). + // return 0; } diff --git a/src/relearn.hpp b/src/relearn.hpp index afea60c..fbffd39 100644 --- a/src/relearn.hpp +++ b/src/relearn.hpp @@ -224,6 +224,11 @@ class policy * @warning if none are found, returns nullptr */ std::unique_ptr best_action(state_class s_t); + /** + * @return a pair of action/value if one exists or a + * pair of if one doesn't exist + */ + std::pair,value_type> best(state_class s_t); private: // if using boost serialization the policies used // are wrappers defined in internal header `serialize.tpl` @@ -546,14 +551,27 @@ template policy::best_action(state_class s_t) { - auto it = std::max_element(__policies__[s_t].begin(), __policies__[s_t].end(), - [&](const auto &lhs, const auto &rhs) { - return lhs.second < rhs.second; - }); + auto it = std::max_element(__policies__[s_t].begin(), + __policies__[s_t].end(), + [&](const auto &lhs, const auto &rhs) { return lhs.second < rhs.second; }); return it != __policies__[s_t].end() ? std::move(std::make_unique(it->first)) : nullptr; } +template +std::pair,value_type> + policy::best(state_class s_t) +{ + auto it = std::max_element(__policies__[s_t].begin(), + __policies__[s_t].end(), + [&](const auto &lhs, const auto &rhs) { return lhs.second < rhs.second; }); + return it != __policies__[s_t].end() ? + std::make_pair(std::make_unique(it->first), it->second) : + std::make_pair(nullptr, 0); +} + #if USING_BOOST_SERIALIZATION template
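
The new policy::best() makes the e-Greedy variant mentioned in the added comment straightforward to write. The sketch below is one possible reading of that suggestion, not code from this patch; the epsilon parameter and the fallback to rand_direction are illustrative choices, and it uses only the state/action aliases from gridworld_header.hpp plus the policy::best() call introduced here:

```cpp
#include "gridworld_header.hpp"
#include <random>

// pick a move: exploit the learnt policy with probability 1 - epsilon,
// otherwise (or when no positive-value policy exists) explore at random
direction e_greedy(relearn::policy<state, action> & policy_map,
                   const state & s_t,
                   const world & w,
                   const grid & curr,
                   std::mt19937 & prng,
                   double epsilon = 0.1)
{
    std::uniform_real_distribution<double> coin(0, 1);
    auto pair = policy_map.best(s_t);
    if (pair.first && pair.second > 0 && coin(prng) > epsilon) {
        return pair.first->trait();                    // on-policy move
    }
    return rand_direction()(prng, w, curr).first;      // random exploration
}
```
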