"""
Code for solving Mario Kart DS Figure 8 Circuit (uses DeSmuMe Nintendo DS Emulator)
Control keys:
X: accelerate
Q: Item
Right arrow key: Turn right
Left arrow key: Turn left
W: Drift
Network is a Convolutional DDQN
It takes 5 contiguous frames of a resized minimap (obtained via screenshot)
and decides between one of six actions (or twelve if drift is set to True):
1. Turn left
2. Turn right
3. Go straight
4. Turn left and use item
5. Turn right and use item
6. Go straight and use item
(for drift, the other six actions go in the same exact order, but the bot also drifts)
Uses one pixel of bottom screen to determine speed and direction (reward)
Bottom screen pixel made by custom Lua Script file that runs in tandem to this code
(more details on Lua Script in the Lua Script file)
One race is one episode (uses more reference pixels to determine when finished)
"""
import random
import time

import matplotlib.pyplot as plt
import numpy as np
from PIL import ImageGrab
from pynput.keyboard import Key, Controller
from tensorflow.keras.layers import Dense, Input, Flatten, Conv3D, MaxPool3D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
# Initializing keyboard
keyboard = Controller()
# Setting seed
random.seed(1)
np.random.seed(1)
def get_screen():  # Retrieves a screenshot of the DS window
    screen = ImageGrab.grab(bbox=(625, 125, 1275, 1125))
    return screen
def is_equal(lis, lis2):  # For RGB value comparison
    # Compare every channel of the reference colour (the original skipped the last one)
    for l in range(len(lis2)):
        if lis[l] != lis2[l]:
            return False
    return True
def get_speed_n_dir(pixel):  # Gets the checkpoint direction and the speed used for the reward
    if pixel[2] == 100:  # Going in the right direction
        if pixel[1] == 200:
            speed_dir = 0.75
        elif pixel[1] == 150:
            speed_dir = 0.5
        elif pixel[1] == 100:
            speed_dir = 0.25
        else:
            speed_dir = 0.125
    else:  # Going in the wrong direction
        if pixel[1] == 200:
            speed_dir = -0.75
        elif pixel[1] == 150:
            speed_dir = -0.5
        elif pixel[1] == 100:
            speed_dir = -0.25
        else:
            speed_dir = -0.125
    return speed_dir
def is_finished(screen):  # Checks reference pixels to determine whether the episode has ended
    screen = np.array(screen)
    if is_equal(screen[536][316], [69, 69, 158]) and is_equal(screen[531][197], [255, 250, 80]) and \
            is_equal(screen[542][51], [184, 103, 20]):
        return True
    return False
def actt(DQN_output):  # Presses/releases the appropriate keys for the chosen action (includes drift)
    # Action layout: 0-5 are straight, straight+item, right, right+item, left, left+item;
    # 6-11 repeat the same pattern while also holding drift.
    drift = DQN_output >= 6
    sub = DQN_output % 6
    item = sub % 2 == 1
    right = sub in (2, 3)
    left = sub in (4, 5)
    for key, pressed in (('q', item), (Key.right, right), (Key.left, left), ('w', drift)):
        if pressed:
            keyboard.press(key)
        else:
            keyboard.release(key)
def get_reward(speed):  # Reward function
    reward = speed * 2  # Reward scales with signed speed
    return reward
class DQN():
    def __init__(self, ddqn, drift=False, episodes=2000, load=False):
        # experience buffer
        self.memory = []
        # discount rate
        self.gamma = 0.9
        # initially 90% exploration, 10% exploitation
        self.epsilon = 0.9
        # iteratively applying decay until 10% exploration / 90% exploitation
        self.epsilon_min = 0.1
        self.epsilon_decay = self.epsilon_min / self.epsilon
        self.epsilon_decay = self.epsilon_decay ** (1. / float(episodes))
        # Q Network weights filename
        self.weights_file = 'ddqn_MKDS.h5' if ddqn else 'dqn_MKDS.h5'
        self.n_outputs = 12 if drift else 6
        # Q Network for training
        self.q_model = self.build_model(self.n_outputs)
        self.q_model.compile(loss='mse', optimizer=Adam())
        # target Q Network
        self.target_q_model = self.build_model(self.n_outputs)
        # copy Q Network params to target Q Network
        self.update_weights()
        self.replay_counter = 0
        self.ddqn = bool(ddqn)
        if self.ddqn:
            print("----------Double DQN--------")
        else:
            print("-------------DQN------------")
        if load:  # Optionally load an existing weights file to continue training
            try:
                self.target_q_model.load_weights(self.weights_file)
                self.q_model.load_weights(self.weights_file)
            except OSError:  # load_weights raises OSError when the .h5 file does not exist
                print("There isn't a weights file to be loaded")
    def build_model(self, n_outputs):  # Network architecture
        # Input is (frames, height, width, channels); np.array(PIL image) is height-first,
        # so each screenshot resized to 84 wide x 140 tall arrives as a (140, 84, 3) array.
        inputs = Input(shape=(5, 140, 84, 3), name='state')
        conv = Conv3D(64, (2, 4, 4), activation='relu')(inputs)
        conv = Conv3D(64, (1, 4, 4), activation='relu')(conv)
        conv = MaxPool3D((1, 2, 2))(conv)
        conv = Conv3D(64, (2, 3, 3), activation='relu')(conv)
        conv = Conv3D(64, (1, 3, 3), activation='relu')(conv)
        conv = MaxPool3D((1, 2, 2))(conv)
        conv = Conv3D(64, (2, 2, 2), activation='relu')(conv)
        conv = Conv3D(32, (1, 2, 2), activation='relu')(conv)
        conv = Conv3D(16, (2, 2, 2), activation='relu')(conv)
        x = Flatten()(conv)
        x = Dense(16, activation='relu')(x)
        x = Dense(256, activation='relu')(x)
        x = Dense(256, activation='relu')(x)
        x = Dense(256, activation='relu')(x)
        x = Dense(256, activation='relu')(x)
        x = Dense(n_outputs, activation='linear', name='action')(x)
        q_model = Model(inputs, x)
        q_model.summary()
        return q_model
    # save Q Network params to a file
    def save_weights(self):
        self.q_model.save_weights(self.weights_file)

    def update_weights(self):
        self.target_q_model.set_weights(self.q_model.get_weights())
    # eps-greedy policy
    def act(self, state):
        if np.random.rand() < self.epsilon:
            # explore: sample uniformly over every available action (12 when drift is enabled)
            rand_action = np.random.choice(self.n_outputs)
            actt(rand_action)
            return rand_action
        # exploit
        q_values = self.target_q_model.predict(state)
        best_action = np.argmax(q_values[0])
        actt(best_action)
        return best_action
    # store experiences in the replay buffer
    def remember(self, state, action, reward, next_state, done):
        item = (state, action, reward, next_state, done)
        self.memory.append(item)

    # drop the oldest `length` experiences from the replay buffer
    def forget(self, length):
        del self.memory[:length]

    # compute Q_max
    # use of the target Q Network addresses the non-stationarity problem
    def get_target_q_value(self, next_state, reward):
        # max Q value among next state's actions
        if self.ddqn:
            # current Q Network selects the action
            # a'_max = argmax_a' Q(s', a')
            action = np.argmax(self.q_model.predict(next_state)[0])
            # target Q Network evaluates the action
            # Q_max = Q_target(s', a'_max)
            q_value = self.target_q_model.predict(next_state)[0][action]
        else:
            q_value = np.amax(self.target_q_model.predict(next_state)[0])
        # Q_max = reward + gamma * Q_max
        q_value *= self.gamma
        q_value += reward
        return q_value
    # experience replay addresses the correlation issue between samples
    def replay(self, batch_size):
        # sars = state, action, reward, state' (next_state)
        sars_batch = random.sample(self.memory, batch_size)
        state_batch, q_values_batch = [], []
        for state, action, reward, next_state, done in sars_batch:
            # policy prediction for a given state
            q_values = self.q_model.predict(state)
            # get Q_max
            q_value = self.get_target_q_value(next_state, reward)
            # correction on the Q value for the action taken
            q_values[0][action] = reward if done else q_value
            # collect batch state-q_value mapping
            state_batch.append(state[0])
            q_values_batch.append(q_values[0])
        # train the Q-network
        self.q_model.fit(np.array(state_batch),
                         np.array(q_values_batch),
                         batch_size=batch_size,
                         epochs=1,
                         verbose=True)
        # update exploration-exploitation probability
        self.update_epsilon()
        # copy new params onto the target network after every 5 training updates
        if self.replay_counter % 5 == 0:
            self.update_weights()
        self.replay_counter += 1
    # decrease exploration, increase exploitation
    def update_epsilon(self):
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
episode_count = 2000
batch_size = 300
scores = []
running = False
race_length = []
agent = DQN(ddqn=True)
for episode in range(episode_count):  # Main training loop
    test_done = False
    total_reward = 0
    while not test_done:
        screen = np.array(get_screen())
        # Doesn't start predicting until the pre-race black screen shows up
        if is_equal(screen[234][316], [0, 0, 0]) and is_equal(screen[531][197], [0, 0, 0]):
            running = True
        while running:
            time.sleep(5.25)  # Timed so the kart gets the start boost (makes training easier)
            keyboard.press('x')
            time.sleep(2)
            frame = 1
            state = []
            screen = np.array(get_screen())
            same_reward = 0
            while not is_finished(screen):
                screen = get_screen()
                etat = np.array(screen.resize((84, 140)))  # One frame, array shape (140, 84, 3)
                state.append(etat)
                if frame % 5 == 0:  # Every 5th frame, act on the stacked state
                    # Stack matches the network input (frames, height, width, channels)
                    state = np.array(state).reshape((1, 5, 140, 84, 3))
                    action = agent.act(state)
                    screen = get_screen()
                    speed = get_speed_n_dir(np.array(screen)[515][24])  # [515][24] is the reference pixel
                    if frame == 5:  # Not enough information yet (no previous transition to store)
                        pass
                    elif frame == 10:
                        reward = get_reward(speed)
                        max_reward = reward
                        agent.remember(prev_state, prev_action, reward, state, is_finished(screen))
                        total_reward += reward
                    elif frame > 10:
                        reward = get_reward(speed)
                        if reward > max_reward:
                            max_reward = reward
                        agent.remember(prev_state, prev_action, reward, state, is_finished(screen))
                        total_reward += reward
                    prev_state = state
                    prev_action = action
                    state = []
                frame += 1
            running = False
            test_done = True
    keyboard.release('x')
    keyboard.release('q')
    keyboard.release(Key.right)
    keyboard.release(Key.left)
    keyboard.release('w')
    scores.append(total_reward)
    mean_score = np.mean(scores)
    # Displaying some statistics
    print('Episode ' + str(episode + 1) + ':')
    print('Highest reward value attained: ' + str(max_reward))
    print('Total score: ' + str(scores[-1]))
    print('Average score: ' + str(mean_score))
    # call experience replay
    item_lis = []
    for i in range(3):  # Forgets the last frames recorded after the finish but before the code noticed
        item_lis.append(agent.memory[-1])
        agent.memory.pop(-1)
    # Re-add the earliest of the popped transitions, marked as terminal
    agent.remember(item_lis[-1][0], item_lis[-1][1], item_lis[-1][2], item_lis[-1][3], True)
    race_length.append(len(agent.memory) - sum(race_length))
    if episode > 9:  # Only remember the last ten races to keep training fast; better or worse hardware may change this value
        agent.forget(race_length[0])
        race_length.pop(0)
    print(len(agent.memory))
    if len(agent.memory) >= batch_size:
        agent.replay(batch_size)
        agent.save_weights()  # Save after every training step so there is always something to fall back on; overwrite is True by default
    for i in range(2):  # Sequence of key presses to select Figure-8 Circuit again
        keyboard.press('x')
        time.sleep(0.5)
        keyboard.release('x')
        time.sleep(0.5)
    time.sleep(3)
    keyboard.press('x')
    time.sleep(0.5)
    keyboard.release('x')
    time.sleep(0.5)
    for i in range(2):
        keyboard.press('x')
        time.sleep(0.1)
        keyboard.release('x')
        time.sleep(0.1)
# Plotting the score graph
episodes = list(range(episode_count))
plt.plot(episodes, scores)
plt.xlabel('Episode')
plt.ylabel('Score')
plt.show()
# Saving the weights of the completed training run
agent.save_weights()