@@ -15,11 +15,11 @@ def __init__(
         self,
         state_dim,
         action_space,
-        n_atoms,
+        n_atoms=51,
         seed=0,
-        hidden_size=None,
-        init_weight_gain=np.sqrt(2),
-        init_bias=0
+        fc1_unit=128,
+        fc2_unit=128,
+        fc3_unit=128,
     ):
         """
         Initialize parameters and build model.
@@ -31,58 +31,26 @@ def __init__(
             fc1_unit (int): Number of nodes in first hidden layer
             fc2_unit (int): Number of nodes in second hidden layer
         """
-        super().__init__()
+        super().__init__()  ## calls __init__ method of nn.Module class
+        self.seed = torch.manual_seed(seed)
         self.action_space = action_space
         self.n_atoms = n_atoms
-        self.seed = torch.manual_seed(seed)
-        self.hidden_size = (100, 100, 100) if not hidden_size else hidden_size
-        self.bn = nn.BatchNorm1d(state_dim)
-
-        def init_weights(m):
-            if isinstance(m, nn.Linear):
-                nn.init.orthogonal_(m.weight, gain=init_weight_gain)
-                nn.init.constant_(m.bias, init_bias)
-
-        # note: The self.hidden_layers attribute is defined as a list of lists,
-        # note: but it should be a list of `nn.Sequential` objects.
-        # note: You can fix this by using `nn.Sequential` to define each layer.
-        # note: After using `nn.Sequential`, you need to define a list with
-        # note: `nn.ModuleList` to construct the model graph.
-        self.hidden_layers = nn.ModuleList([
-            nn.Sequential(nn.Linear(in_size, out_size), nn.LeakyReLU())
-            for in_size, out_size in zip((state_dim,) +
-                                         self.hidden_size, self.hidden_size)
-        ])
-        self.hidden_layers.apply(init_weights)
-
-        def init_output_weights(m):
-            if isinstance(m, nn.Linear):
+        self.fc1 = nn.Linear(state_dim, fc1_unit)
+        self.fc2 = nn.Linear(fc1_unit, fc2_unit)
+        self.fc3 = nn.Linear(fc2_unit, fc3_unit)
+        self.fc4 = nn.Linear(fc3_unit, action_space * n_atoms)
 
-                nn.init.orthogonal_(m.weight, gain=init_weight_gain)
-                nn.init.constant_(m.bias, init_bias)
-
-        self.output_layers = nn.ModuleList([
-            nn.Sequential(
-                nn.Linear(self.hidden_size[-1], n_atoms), nn.LeakyReLU(),
-                nn.Softmax(dim=-1)
-            ) for _ in range(action_space)
-        ])
-
-        self.output_layers.apply(init_output_weights)
-
-    def forward(self, state):
-        x = self.bn(state)
-        for hidden_layer in self.hidden_layers:
-            x = hidden_layer(x)
-        out = torch.concat([
-            torch.unsqueeze(output_layer(x), dim=1)
-            for output_layer in self.output_layers
-        ],
-                          dim=1)
-        # x = self.output_layer(x)
-        # x = torch.reshape(x, (-1, self.action_space, self.n_atoms))
-        # x = F.softmax(x, dim=-1)
-        return out
+    def forward(self, x):
+        """
+        Build a network that maps state -> action values.
+        """
+        x = F.leaky_relu(self.fc1(x))
+        x = F.leaky_relu(self.fc2(x))
+        x = F.leaky_relu(self.fc3(x))
+        x = self.fc4(x)
+        x = torch.reshape(x, (-1, self.action_space, self.n_atoms))
+        x = F.softmax(x, dim=-1)
+        return x
 
 
 # device = torch.device("cpu")
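For reference, the `+` side of this diff assembles into the self-contained module below. This is a sketch, not the file itself: the diff omits the file header, so the class name `C51Network` and the import lines are assumptions. Note also that, despite the docstring carried over in the diff, `forward` returns a probability distribution over `n_atoms` support points for each action (the C51 parameterization, hence the default `n_atoms=51`), not scalar action values.

import torch
import torch.nn as nn
import torch.nn.functional as F


class C51Network(nn.Module):  # class name assumed; not shown in the diff

    def __init__(self, state_dim, action_space, n_atoms=51, seed=0,
                 fc1_unit=128, fc2_unit=128, fc3_unit=128):
        super().__init__()  ## calls __init__ method of nn.Module class
        self.seed = torch.manual_seed(seed)
        self.action_space = action_space
        self.n_atoms = n_atoms
        self.fc1 = nn.Linear(state_dim, fc1_unit)
        self.fc2 = nn.Linear(fc1_unit, fc2_unit)
        self.fc3 = nn.Linear(fc2_unit, fc3_unit)
        # Single head of size action_space * n_atoms, reshaped per action below.
        self.fc4 = nn.Linear(fc3_unit, action_space * n_atoms)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = F.leaky_relu(self.fc3(x))
        x = self.fc4(x)
        # (batch, action_space * n_atoms) -> (batch, action_space, n_atoms)
        x = torch.reshape(x, (-1, self.action_space, self.n_atoms))
        # Softmax over the atom dimension: one probability distribution
        # over n_atoms return values per action.
        return F.softmax(x, dim=-1)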
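A quick usage sketch under the same assumptions: the concrete numbers (`state_dim=8`, 4 actions, the support bounds `v_min`/`v_max`) are illustrative and do not come from the diff. Expected Q-values are recovered by weighting the support atoms with the predicted probabilities.

# Hypothetical usage; dimensions and support range are made up for illustration.
net = C51Network(state_dim=8, action_space=4)

v_min, v_max = -10.0, 10.0                  # assumed return support bounds
support = torch.linspace(v_min, v_max, 51)  # the 51 atoms behind "C51"

state = torch.randn(1, 8)                   # dummy batch with one state
dist = net(state)                           # shape (1, 4, 51); each row sums to 1

q_values = (dist * support).sum(dim=-1)     # (1, 4) expected return per action
greedy_action = q_values.argmax(dim=-1)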