'''
This module contains the class implementing a Reinforcement Learning
algorithm using the Q-learning technique.
'''
import numpy as np
import random
import json
import os
from ai_base import SystemState, AI_Base, DecayingFloat
from snake import GameOutcome

class AI_RLQ(AI_Base):
    '''
    This is the implementation of the Reinforcement Learning algorithm.
    At the beginning, the algorithm looks for a `q-table-learned.json`
    file which contains a previously learned Q-table. If it is found, the
    algorithm loads it and initializes its Q-table from the data stored
    in the file. If it is not found, the algorithm initializes an empty
    Q-table.
    When a termination signal is received, the algorithm stores its
    Q-table in a JSON file named `q-table.json`.
    The constructor takes one input parameter.

    Parameters
    ----------
    training_mode : bool, optional, default=True
        Specify if this algorithm is in training mode (or online learning
        mode). If not, then this algorithm will make decisions based on
        the established Q-table and won't perform any updates to it.
    '''
    class Action:
        '''
        This is an inner class providing the three possible actions,
        which are FRONT, LEFT and RIGHT.
        '''
        LEFT = 0
        FRONT = 1
        RIGHT = 2
        ALL = [LEFT, FRONT, RIGHT]

        def __init__(self):
            self.action = None

        def __eq__(self, action:int) -> bool:
            return self.action==action

        def __int__(self) -> int:
            return self.action

        def set_action(self, action:int):
            self.action = action

        def get_action(self):
            return self.action

        def to_xy(self, x:int, y:int) -> (int,int):
            '''It translates the relative action into an absolute movement
            and returns the movement as a tuple. The inputs x,y are the
            current movement direction, which is needed for the translation.'''
            if self.action==self.FRONT:
                pass
            elif self.action==self.LEFT:    # left of (x,y) direction
                if x!=0:
                    y = -x; x = 0
                else:
                    x = y; y = 0
            elif self.action==self.RIGHT:   # right of (x,y) direction
                if x!=0:
                    y = x; x = 0
                else:
                    x = -y; y = 0
            return (x,y)
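
    ## a minimal illustration (not executed here) of how Action.to_xy()
    ## maps a relative action onto an absolute movement, assuming the
    ## screen convention used in this module that y grows towards the south:
    ##   a = AI_RLQ.Action()
    ##   a.set_action(AI_RLQ.Action.LEFT)
    ##   a.to_xy(+1, 0)    # moving east, turning left  -> (0, -1), i.e. north
    ##   a.set_action(AI_RLQ.Action.RIGHT)
    ##   a.to_xy(0, -1)    # moving north, turning right -> (+1, 0), i.e. east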

    ## system state: inheriting from the SystemState class,
    ## but translated to be relative to the movement of the snake
    class State(SystemState):
        '''
        This is an inner class for the translated system state. It
        translates the absolute directions (north/east/south/west) given
        by the environment into directions relative to the movement of
        the snake (front/back/left/right).
        '''
        def __init__(self, other:SystemState):
            ## translating north/east/south/west to front/back/left/right
            self.obj_front = None
            self.obj_left = None
            self.obj_right = None
            self.food_front = None
            self.food_back = None
            self.food_left = None
            self.food_right = None
            self.dir_x = other.dir_x
            self.dir_y = other.dir_y
            if other.dir_x==+1:    # moving east
                self.obj_front = other.obj_east
                self.obj_left = other.obj_north
                self.obj_right = other.obj_south
                self.food_front = other.food_east
                self.food_back = other.food_west
                self.food_left = other.food_north
                self.food_right = other.food_south
            elif other.dir_x==-1:  # moving west
                self.obj_front = other.obj_west
                self.obj_left = other.obj_south
                self.obj_right = other.obj_north
                self.food_front = other.food_west
                self.food_back = other.food_east
                self.food_left = other.food_south
                self.food_right = other.food_north
            elif other.dir_y==+1:  # moving south
                self.obj_front = other.obj_south
                self.obj_left = other.obj_east
                self.obj_right = other.obj_west
                self.food_front = other.food_south
                self.food_back = other.food_north
                self.food_left = other.food_east
                self.food_right = other.food_west
            elif other.dir_y==-1:  # moving north
                self.obj_front = other.obj_north
                self.obj_left = other.obj_west
                self.obj_right = other.obj_east
                self.food_front = other.food_north
                self.food_back = other.food_south
                self.food_left = other.food_west
                self.food_right = other.food_east

        def __eq__(self, other):
            return isinstance(other, SystemState) and str(self)==str(other)

        def __hash__(self):
            return hash(str(self))

        def __str__(self):
            return "["+("<" if self.food_left else " ") \
                      +("^" if self.food_front else " ") \
                      +(">" if self.food_right else " ") \
                      +("v" if self.food_back else " ") + "]," \
                      + "[%+d,%+d,%+d]"%(self.obj_left,self.obj_front,self.obj_right)
            ## the following state info doesn't appear to help,
            ## so it has been removed
            #+ "-%s"%("N" if self.dir_y==-1 else "S" if self.dir_y==1 else \
            #         "W" if self.dir_x==-1 else "E")

    def __init__(self, training_mode:bool=True):
        '''Default constructor.'''
        super().__init__()
        self._name = "Q-Learning " \
                     + ("" if training_mode else "(testing mode)")
        ## episode related hyperparameters
        ## note: our program control flow is environment oriented, so we
        ## can't control the number of episodes and their length here;
        ## the following two values are ignored.
        ## - num_episodes: The environment sets this to infinity by default,
        ##                 but the user can terminate at any time by ^C or
        ##                 choose not to continue on the popup dialog when
        ##                 the snake crashes.
        ## - len_episodes: The environment sets this to infinity by default,
        ##                 so the only terminating condition is a snake crash.
        self.num_episodes: int = 2000   # number of episodes
        self.len_episodes: int = 10000  # max number of steps in each episode
        ## learning related hyperparameters
        self.alpha: float = 0.2     # learning rate
        self.gamma: float = 0.9     # discount factor
        self.epsilon: float = 0.05  # exploration weight, 0=none; 1=full
        #self.epsilon = DecayingFloat(1.0, mode="exp", factor=0.9, minval=0.1)
        self.training_mode: bool = training_mode  # training mode (T/F)?
        if not self.training_mode:
            self.epsilon = 0.0  # if not in training, zero exploration
        ## reward settings
        self.food_reward: int = 10    # reward for getting the snake to eat the food
        self.crash_reward: int = -10  # negative reward for crashing
        ## Q-table: q_table[s:State][a:Action]; it is a dict
        self.q_table = dict()
        ## current state & action
        self.current_state = None
        self.current_action = None
        ## load Q-table
        self.load_table()
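
    ## a minimal usage sketch (assuming the surrounding environment accepts
    ## any AI_Base-derived agent): train with the defaults, then run the
    ## learned policy without exploration or further updates:
    ##   agent = AI_RLQ()                     # training mode, epsilon = 0.05
    ##   agent = AI_RLQ(training_mode=False)  # testing mode, epsilon forced to 0.0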

    def load_table(self):
        '''Load Q-table from `q-table-learned.json`. This is used internally.'''
        filename_q_table = "q-table-learned.json"
        if os.path.exists(filename_q_table):
            with open(filename_q_table, "r") as fp:
                self.q_table = json.load(fp)
            if len(self.q_table)!=0:
                print("- loaded '%s' which contains %d states"
                      %(filename_q_table,len(self.q_table)))
        else:
            print("- '%s' not found, no experience is used"%filename_q_table)

    def save_table(self):
        '''Save Q-table to `q-table.json`. This is used internally.'''
        class NpEncoder(json.JSONEncoder):
            '''JSON encoder that converts numpy types to plain Python types.'''
            def default(self, obj):
                if isinstance(obj, np.integer):
                    return int(obj)
                elif isinstance(obj, np.floating):
                    return float(obj)
                elif isinstance(obj, np.ndarray):
                    return obj.tolist()
                else:
                    return super(NpEncoder, self).default(obj)
        ## write the Q-table to the JSON file;
        ## this way, we don't lose the training data
        with open("q-table.json", "w") as fp:
            json.dump(self.q_table, fp, cls=NpEncoder, indent=4)
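
    ## illustrative (hypothetical) content of `q-table.json`: each key is a
    ## state string (see State.__str__) and each value holds the Q-values
    ## for [LEFT, FRONT, RIGHT]:
    ##   {
    ##       "[ ^  ],[+0,+0,+1]": [0.0, 2.36, -2.0],
    ##       "[<   ],[+0,+1,+0]": [1.8, -2.0, 0.0]
    ##   }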

    def state_str(self, state:SystemState) -> str:
        '''It returns the string representation of the system state
        observed by this algorithm. This implementation uses the
        translated system state; see the `AI_RLQ.State` inner class.

        Returns
        -------
        str
            The string representation of the translated system state.
        '''
        return str(self.State(state))

    ## helper function, easy access to the Q-table
    def q(self, state):
        '''It provides easy access to the Q-table, i.e. use `q(s)[a]` to
        access the Q-value of state `s` and action `a`.

        Parameters
        ----------
        state : AI_RLQ.State
            The translated system state instance.
        '''
        s = str(state)  # we use str to index the Q-table, easier to debug
        if s not in self.q_table:
            ## create a row for this new state in the Q-table
            self.q_table[s] = np.zeros(len(self.Action.ALL))
        return self.q_table[s]
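
    ## access pattern sketch: reading or updating a single Q-value, e.g.
    ##   value = self.q(s)[self.Action.FRONT]   # read Q(s, FRONT)
    ##   self.q(s)[self.Action.LEFT] = 0.5      # overwrite Q(s, LEFT)
    ## unseen states are lazily added to the table as rows of zeros.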

    def callback_take_action(self, state:SystemState) -> (int,int):
        '''Here we implement the Q-learning exploration-exploitation
        strategy. For exploration, a random action is picked. For
        exploitation, the best action (that is, the action carrying the
        highest Q-value in the current state) is picked.'''
        ## setup the current state 's'
        s = self.State(state)
        self.current_state = s  # keep the state
        ## step 1: choose action 'a' based on the system state
        ## exploration or exploitation?
        a = self.Action()
        possible_actions = []
        if random.uniform(0, 1) < self.epsilon:
            ## exploration: include all actions
            possible_actions = self.Action.ALL.copy()
        else:
            ## exploitation: limit the choice based on the optimal policy,
            ## i.e. pi_star(s) = argmax_a(Q_star(s,a));
            ## several actions may share the same max value
            max_value = np.max(self.q(s))   # find the max value first
            for i in self.Action.ALL:
                if self.q(s)[i]==max_value:
                    possible_actions.append(i)  # add all carrying the max value
        a.set_action(random.choice(possible_actions))
        self.current_action = a  # keep the action
        ## step 2:
        ## now we need to return our action to the environment
        ## so that the environment can take the action and call us back via
        ## 'callback_action_outcome()' to inform us of the outcome.
        ## .to_xy() translates back from FRONT/LEFT/RIGHT to an (x,y) direction
        return a.to_xy(s.dir_x,s.dir_y)
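
    ## exploitation example (hypothetical Q-values): if q(s) == [1.5, 1.5, 0.2]
    ## for [LEFT, FRONT, RIGHT], both LEFT and FRONT carry the max value 1.5,
    ## so possible_actions == [LEFT, FRONT] and one of them is picked at random;
    ## with epsilon = 0.05, roughly 1 step in 20 is exploratory instead.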

    def callback_action_outcome(self, state:SystemState, outcome:GameOutcome):
        '''Here we implement the update of the Q-table based on the outcome.
        This lets the algorithm learn how good its previous action was.
        The update is done using the Bellman equation.'''
        ## ...continuing from 'callback_take_action()'
        ## retrieve: state, action -> next_state
        s = self.current_state   # was the state before our action
        a = self.current_action  # was our action FRONT/LEFT/RIGHT
        s1 = self.State(state)   # is the state after our action
        ## step 3: calculate the reward
        if outcome==GameOutcome.CRASHED_TO_BODY or \
           outcome==GameOutcome.CRASHED_TO_WALL:
            reward = self.crash_reward
        elif outcome==GameOutcome.REACHED_FOOD:
            reward = self.food_reward
        else:
            reward = 0  # no reward for this time step
        ## step 4: update the Q-table using the Bellman equation
        ## Q_next(s,a) = Q(s,a)
        ##               + alpha * (reward + gamma*max_a(Q(s_next,a)) - Q(s,a))
        ##             = (1-alpha) * Q(s,a)
        ##               + alpha * (reward + gamma*max_a(Q(s_next,a)))
        ## update the Q-table only if we're in training mode
        if self.training_mode:
            a = int(a)  # 'a' needs to be an integer now to index the Q-table
            self.q(s)[a] = self.q(s)[a] \
                + self.alpha * (reward + self.gamma*np.max(self.q(s1)) - self.q(s)[a])
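
    ## a worked (hypothetical) update with the default hyperparameters
    ## alpha=0.2 and gamma=0.9: suppose Q(s,a)=0, the snake reaches the food
    ## (reward=+10) and the best Q-value of the next state is 2.0, then
    ##   Q_next(s,a) = 0 + 0.2 * (10 + 0.9*2.0 - 0) = 2.36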

    def callback_terminating(self):
        '''This is a listener for the termination signal. When triggered,
        it saves the Q-table.'''
        self.save_table()
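
## sketch of how the environment presumably drives this agent (the exact
## environment API lives outside this module, so the calling loop below is
## only indicative):
##   agent = AI_RLQ()                                    # or AI_RLQ(training_mode=False)
##   for every time step:
##       (dx, dy) = agent.callback_take_action(state)    # agent picks a move
##       ... environment moves the snake by (dx, dy) ...
##       agent.callback_action_outcome(new_state, outcome)  # agent updates Q-table
##   agent.callback_terminating()                        # on exit, Q-table is saved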