# dqn.py - Implementing Deep Q-Learning with Experience Replay

import numpy as np

class DQN(object):

    # INTRODUCING AND INITIALIZING ALL THE PARAMETERS AND VARIABLES OF THE DQN
    def __init__(self, max_memory = 100, discount = 0.9):
        self.memory = list()
        self.max_memory = max_memory
        self.discount = discount

    # MAKING A METHOD THAT BUILDS THE MEMORY IN EXPERIENCE REPLAY
    def remember(self, transition, game_over):
        # transition = [current_state, action, reward, next_state]
        self.memory.append([transition, game_over])
        if len(self.memory) > self.max_memory:
            del self.memory[0]

    # MAKING A METHOD THAT BUILDS TWO BATCHES OF INPUTS AND TARGETS BY EXTRACTING TRANSITIONS FROM THE MEMORY
    def get_batch(self, model, batch_size = 10):
        len_memory = len(self.memory)
        num_inputs = self.memory[0][0][0].shape[1]
        num_outputs = model.output_shape[-1]
        inputs = np.zeros((min(len_memory, batch_size), num_inputs))
        targets = np.zeros((min(len_memory, batch_size), num_outputs))
        for i, idx in enumerate(np.random.randint(0, len_memory, size = min(len_memory, batch_size))):
            current_state, action, reward, next_state = self.memory[idx][0]
            game_over = self.memory[idx][1]
            inputs[i] = current_state
            targets[i] = model.predict(current_state)[0]
            # Q_sa is the highest Q-value the model predicts for the next state
            Q_sa = np.max(model.predict(next_state)[0])
            # Bellman target: terminal transitions keep only the reward,
            # otherwise the discounted best next Q-value is added
            if game_over:
                targets[i, action] = reward
            else:
                targets[i, action] = reward + self.discount * Q_sa
        return inputs, targets
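
# ---------------------------------------------------------------------------
# USAGE SKETCH (not part of the original file): a minimal example of how this
# replay memory could be wired into a training loop. It assumes a small Keras
# model and a hypothetical `env` object exposing reset() and step(action),
# where states are 2D arrays of shape (1, num_inputs), as get_batch() expects.
# ---------------------------------------------------------------------------
def example_training_loop(env, num_inputs, num_actions, num_episodes = 100):
    # Local import so the sketch only needs TensorFlow when actually called
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense
    # A small fully-connected network mapping a state to one Q-value per action
    model = Sequential([
        Dense(24, activation = 'relu', input_shape = (num_inputs,)),
        Dense(num_actions)
    ])
    model.compile(optimizer = 'adam', loss = 'mse')
    dqn = DQN(max_memory = 1000, discount = 0.95)
    epsilon = 0.1
    for episode in range(num_episodes):
        current_state = env.reset()               # expected shape: (1, num_inputs)
        game_over = False
        while not game_over:
            # Epsilon-greedy action selection
            if np.random.rand() < epsilon:
                action = np.random.randint(0, num_actions)
            else:
                action = np.argmax(model.predict(current_state)[0])
            next_state, reward, game_over = env.step(action)
            # Store the transition, then train on a replayed batch
            dqn.remember([current_state, action, reward, next_state], game_over)
            inputs, targets = dqn.get_batch(model, batch_size = 32)
            model.train_on_batch(inputs, targets)
            current_state = next_state
    return model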