-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdqn_model.py
39 lines (29 loc) · 1.04 KB
/
dqn_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
import torch.optim as optim
class DQN(nn.Module):
def __init__(self, obs_shape, num_actions,lr =1e-4):
super(DQN, self).__init__()
self.obs_shape = obs_shape
self.num_actions = num_actions
self.conv_net = torch.nn.Sequential(
torch.nn.Conv2d(4, 16, (8, 8), stride=(4, 4)),
torch.nn.ReLU(),
torch.nn.Conv2d(16, 32, (4, 4), stride=(2, 2)),
torch.nn.ReLU(),
)
with torch.no_grad():
dummy = torch.zeros((1, *obs_shape))
x = self.conv_net(dummy)
s = x.shape
fc_size = s[1] * s[2] * s[3]
self.fc_net = torch.nn.Sequential(
torch.nn.Linear(fc_size, 256),
torch.nn.ReLU(),
torch.nn.Linear(256, num_actions),
)
self.optimizer = optim.Adam(self.parameters(), lr=lr)
def forward(self, x):
x /= 255
conv_latent = self.conv_net(x)
return self.fc_net(conv_latent.view((conv_latent.shape[0], -1)))