forked from thu-ml/tianshou
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathatari_network.py
More file actions
185 lines (161 loc) · 5.68 KB
/
atari_network.py
File metadata and controls
185 lines (161 loc) · 5.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch import nn
from tianshou.utils.net.discrete import NoisyLinear
class DQN(nn.Module):
    """Reference: Human-level control through deep reinforcement learning.

    Three convolutional layers followed by a flatten. By default a
    ``Linear(features, 512) -> ReLU -> Linear(512, prod(action_shape))`` head
    is appended so the network outputs one value per action. With
    ``features_only=True`` the head is omitted and the flattened conv features
    are returned instead (optionally projected to ``output_dim``).

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    :param c: number of input channels.
    :param h: input height.
    :param w: input width.
    :param action_shape: shape of the action space; the default head has
        ``prod(action_shape)`` outputs.
    :param device: device that ``forward`` moves observations to. The
        module's own parameters are not moved here; call ``.to(device)``.
    :param features_only: if True, skip the action head and output features.
    :param output_dim: only used when ``features_only`` is True; if given,
        append ``Linear(features, output_dim) -> ReLU`` to the extractor.
        Ignored when ``features_only`` is False.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        device: Union[str, int, torch.device] = "cpu",
        features_only: bool = False,
        output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.device = device
        self.net = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.Flatten(),
        )
        # Infer the flattened feature size with a dummy (1, c, h, w) pass
        # rather than hand-deriving the conv arithmetic for (h, w).
        with torch.no_grad():
            feature_shape = self.net(torch.zeros(1, c, h, w)).shape[1:]
        # np.prod returns a NumPy scalar; expose a plain Python int so the
        # public ``output_dim`` and downstream layer constructors get an int.
        self.output_dim = int(np.prod(feature_shape))
        if not features_only:
            action_dim = int(np.prod(action_shape))
            # Wrapping the extractor keeps state_dict keys as "net.0.*",
            # "net.1.*", "net.3.*".
            self.net = nn.Sequential(
                self.net,
                nn.Linear(self.output_dim, 512),
                nn.ReLU(inplace=True),
                nn.Linear(512, action_dim),
            )
            self.output_dim = action_dim
        elif output_dim is not None:
            self.net = nn.Sequential(
                self.net,
                nn.Linear(self.output_dim, output_dim),
                nn.ReLU(inplace=True),
            )
            self.output_dim = output_dim

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: s -> Q(s, \*).

        :param obs: observation batch; converted to a float32 tensor on
            ``self.device``. Expected shaped (batch, c, h, w) to match the
            dummy pass used at construction.
        :param state: passed through unchanged (this network keeps no state).
        :param info: unused; accepted for interface compatibility.
        :return: ``(output, state)`` where output has ``self.output_dim``
            columns.
        """
        obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
        return self.net(obs), state
class C51(DQN):
    """Reference: A distributional perspective on reinforcement learning.

    Reuses the :class:`DQN` body with an output head of
    ``prod(action_shape) * num_atoms`` units, then turns those outputs into a
    categorical distribution over ``num_atoms`` atoms for each action.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    :param c: number of input channels.
    :param h: input height.
    :param w: input width.
    :param action_shape: shape of the action space.
    :param num_atoms: number of atoms per action distribution.
    :param device: device that observations are moved to in ``forward``.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        num_atoms: int = 51,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        # Compute into a local first: attributes should not be set on an
        # nn.Module before Module.__init__ has run.
        action_num = int(np.prod(action_shape))
        super().__init__(c, h, w, [action_num * num_atoms], device)
        self.action_num = action_num
        self.num_atoms = num_atoms

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*).

        :return: ``(probs, state)``; probs has shape
            (batch, action_num, num_atoms) and sums to 1 over the last axis.
            ``state`` is passed through unchanged, as in :class:`DQN`.
        """
        logits, state = super().forward(obs, state=state, info=info)
        # Softmax over the atom axis independently for each (sample, action).
        probs = logits.view(-1, self.num_atoms).softmax(dim=-1)
        probs = probs.view(-1, self.action_num, self.num_atoms)
        return probs, state
class Rainbow(DQN):
    """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.

    Uses the :class:`DQN` convolutional feature extractor (``features_only``)
    followed by a distributional head ``Q`` and, when dueling, a value head
    ``V``. Head layers are ``NoisyLinear`` when ``is_noisy`` else
    ``nn.Linear``.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    :param c: number of input channels.
    :param h: input height.
    :param w: input width.
    :param action_shape: shape of the action space.
    :param num_atoms: number of atoms per action distribution.
    :param noisy_std: passed to ``NoisyLinear`` when ``is_noisy`` is True.
    :param device: device that observations are moved to in ``forward``.
    :param is_dueling: if True, combine ``Q`` and ``V`` heads (dueling).
    :param is_noisy: if True, build the heads from ``NoisyLinear`` layers.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        num_atoms: int = 51,
        noisy_std: float = 0.5,
        device: Union[str, int, torch.device] = "cpu",
        is_dueling: bool = True,
        is_noisy: bool = True,
    ) -> None:
        super().__init__(c, h, w, action_shape, device, features_only=True)
        self.action_num = int(np.prod(action_shape))
        self.num_atoms = num_atoms
        # Guard against a NumPy-scalar feature size from the base class.
        feature_dim = int(self.output_dim)

        def linear(in_features: int, out_features: int) -> nn.Module:
            # Layer factory shared by both heads.
            if is_noisy:
                return NoisyLinear(in_features, out_features, noisy_std)
            return nn.Linear(in_features, out_features)

        self.Q = nn.Sequential(
            linear(feature_dim, 512),
            nn.ReLU(inplace=True),
            linear(512, self.action_num * self.num_atoms),
        )
        self._is_dueling = is_dueling
        if self._is_dueling:
            self.V = nn.Sequential(
                linear(feature_dim, 512),
                nn.ReLU(inplace=True),
                linear(512, self.num_atoms),
            )
        self.output_dim = self.action_num * self.num_atoms

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*).

        :return: ``(probs, state)``; probs has shape
            (batch, action_num, num_atoms) and sums to 1 over the last axis.
            ``state`` is passed through unchanged, as in :class:`DQN`.
        """
        features, state = super().forward(obs, state=state, info=info)
        q = self.Q(features).view(-1, self.action_num, self.num_atoms)
        if self._is_dueling:
            v = self.V(features).view(-1, 1, self.num_atoms)
            # Dueling aggregation per atom: V + (Q - mean over actions of Q).
            logits = q - q.mean(dim=1, keepdim=True) + v
        else:
            logits = q
        probs = logits.softmax(dim=2)
        return probs, state
class QRDQN(DQN):
    """Reference: Distributional Reinforcement Learning with Quantile \
    Regression.

    Reuses the :class:`DQN` body with an output head of
    ``prod(action_shape) * num_quantiles`` units, reshaped so that each action
    gets ``num_quantiles`` raw (unnormalized) quantile values.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    :param c: number of input channels.
    :param h: input height.
    :param w: input width.
    :param action_shape: shape of the action space.
    :param num_quantiles: number of quantile outputs per action.
    :param device: device that observations are moved to in ``forward``.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        num_quantiles: int = 200,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        # Compute into a local first: attributes should not be set on an
        # nn.Module before Module.__init__ has run.
        action_num = int(np.prod(action_shape))
        super().__init__(c, h, w, [action_num * num_quantiles], device)
        self.action_num = action_num
        self.num_quantiles = num_quantiles

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*).

        :return: ``(quantiles, state)``; quantiles has shape
            (batch, action_num, num_quantiles). ``state`` is passed through
            unchanged, as in :class:`DQN`.
        """
        out, state = super().forward(obs, state=state, info=info)
        quantiles = out.view(-1, self.action_num, self.num_quantiles)
        return quantiles, state