-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcfr-inl.h
155 lines (140 loc) · 5.48 KB
/
cfr-inl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#ifndef GTO_CFR_INL_H_
#define GTO_CFR_INL_H_
#include <algorithm>
#include <string>
#include <vector>
#include "cfr.h"
#include "dealer_interface.h"
namespace GTO {
// Print all the strategies used at each node under "node" for
// "player". "states" is a vector containing all the possible states
// for the given player. Thus, the state id varies between 0 and
// states.size()-1.
// REQUIRES: the State class should implement StateInterface.
template<class State>
void
TreePrint(const Node& node,
Node::Player player,
const std::vector<State>& states)
{
if (node.isleaf())
return;
if (node.active_player() == player) {
Array<double> strat = node.AverageStrategy();
printf("%s", State::Name().c_str());
for (auto c : node.children())
printf(" | %s", c->name().c_str());
putchar('\n');
for (size_t s = 0; s < strat.num_rows(); s++) {
printf("%s", states[s].ToString().c_str());
for (size_t a = 0; a < strat.num_cols(); a++)
printf("\t%.4f", strat.get(s, a));
putchar('\n');
}
}
for (auto c : node.children())
TreePrint(*c, player, states);
}
// A (state label, probability) pair listed under an action by FlatPrint.
// operator< is intentionally reversed: a record is "less" than another when
// its probability is HIGHER, so an ascending sort yields records ordered
// from the most to the least likely.
struct Record {
  string state;  // printable label of the state
  double prob;   // probability of taking the action in that state

  explicit Record(const string& label, double probability)
    : state(label), prob(probability)
  {}

  // Descending-by-probability ordering (see comment above).
  bool operator<(const Record& rhs) const
  {
    return prob > rhs.prob;
  }
};
// For each final action of "player" under "root" (the nodes where "player"
// is last active), print:
//   * the action's "range": the share of the player's combos that take it,
//     i.e. sum(NumCombos * P(action | state)) / sum(NumCombos), as a
//     percentage;
//   * every state whose probability of taking the action is at least
//     kMinListedProb, sorted from the most to the least likely.
// "states" holds every possible state of "player" during the game; a state's
// id is its index. "names" is the vector of final action names filled in by
// Node::GetFinalActionNames for this player; column n of the probabilities
// filled in by Node::GetFinalActionProbs corresponds to names[n].
// REQUIRES: the State class should implement StateInterface.
template<class State>
void
FlatPrint(const Node& root,
          Node::Player player,
          const vector<State>& states,
          const vector<string>& names)
{
  // States below this probability are left out of the per-action listing.
  const double kMinListedProb = 0.05;

  // Row = state id, column = final action index.
  Array<double> probs(states.size(), names.size());
  vector<Record> records;
  double total_states = 0.0;
  for (size_t id = 0; id < states.size(); id++)
    total_states += states[id].NumCombos();
  Node::GetFinalActionProbs(root, player, probs);

  // At most one record per state; reuse the buffer across actions.
  records.reserve(states.size());
  for (size_t n = 0; n < names.size(); n++) {
    double total_range = 0.0;
    records.clear();
    for (size_t id = 0; id < states.size(); id++) {
      const double p = probs.get(id, n);
      total_range += states[id].NumCombos()*p;
      if (p >= kMinListedProb)
        records.emplace_back(states[id].ToString(), p);
    }
    // Record::operator< is reversed, so this sorts by descending probability.
    std::sort(records.begin(), records.end());
    // Guard the division so an empty/zero-combo state set prints 0, not nan.
    const double range_pct =
        total_states > 0.0 ? total_range/total_states*100 : 0.0;
    // NOTE(review): the plural suffix compares the weighted combo count to
    // exactly 1, not the printed percentage; kept as-is.
    printf("%s range: %.2f%% %s%c\n",
           names[n].c_str(),
           range_pct,
           State::Name().c_str(),
           total_range == 1 ? ' ' : 's');
    printf("%s\tProb\n", State::Name().c_str());
    for (const auto& r : records)
      printf("%s\t%.4f\n", r.state.c_str(), r.prob);
  }
}
// Report, through err::quit, that no utility value is available for
// "player" at the node named "name". The player's label comes from
// Node::player_names.
inline void
UtilError(Node::Player player, const string& name)
{
  const auto& who = Node::player_names[player];
  const char* where = name.c_str();
  err::quit("Don't have utility for %s at the node %s.", who, where);
}
// Use CFR to update "node" by playing "Node::VILLAIN" against "Node::HERO"
// for "num_iter" iterations, then print the results to standard output.
//
// Each iteration:
//   1. asks "dealer" for a (hero_id, vill_id) pair and runs node.CFR for
//      Node::VILLAIN;
//   2. deals again and runs node.CFR for Node::HERO.
// The values returned by node.CFR are summed per player. Every kReportEvery
// iterations, a progress line with the running averages goes to stderr.
//
// At the end, each player's average value is printed, followed by that
// player's FlatPrint action table: villain first, then hero.
// "hero_states" / "vill_states" hold every possible state of Node::HERO /
// Node::VILLAIN. "hero_name" / "vill_name" label the players in the output.
// REQUIRES: the State class should implement StateInterface.
template<class State>
void
Train(size_t num_iter,
      const vector<State>& hero_states,
      const vector<State>& vill_states,
      const string& hero_name,
      const string& vill_name,
      DealerInterface& dealer,
      Node& node)
{
  // Number of iterations between two progress lines on stderr.
  const size_t kReportEvery = 1000000;

  double vutil = 0.0;
  double hutil = 0.0;
  size_t hero_id = 0;
  size_t vill_id = 0;
  for (size_t i = 1; i <= num_iter; i++) {
    dealer.Deal(hero_id, vill_id);
    vutil += node.CFR(Node::VILLAIN, vill_id, hero_id);
    dealer.Deal(hero_id, vill_id);
    hutil += node.CFR(Node::HERO, hero_id, vill_id);
    if (i % kReportEvery == 0)
      fprintf(stderr, "%05.2f%% %s: %.8f, %s: %.8f\n",
              static_cast<double>(i)/num_iter*100,
              vill_name.c_str(),
              vutil/i,
              hero_name.c_str(),
              hutil/i);
  }
  // With zero iterations there is nothing to average; avoid 0/0 == nan.
  if (num_iter > 0) {
    hutil /= static_cast<double>(num_iter);
    vutil /= static_cast<double>(num_iter);
  }
  vector<string> hnames;
  vector<string> vnames;
  Node::GetFinalActionNames(node, hnames, vnames);
  printf("%s: %.4f\n", vill_name.c_str(), vutil);
  FlatPrint(node, Node::VILLAIN, vill_states, vnames);
  printf("\n%s: %.4f\n", hero_name.c_str(), hutil);
  FlatPrint(node, Node::HERO, hero_states, hnames);
}
} // namespace GTO
#endif // !GTO_CFR_INL_H_