diff --git a/scale_rl/agents/simba/simba_network.py b/scale_rl/agents/simba/simba_network.py index 0b084f7..adc541d 100755 --- a/scale_rl/agents/simba/simba_network.py +++ b/scale_rl/agents/simba/simba_network.py @@ -23,8 +23,9 @@ def setup(self): ) self.encoder = nn.Sequential( [ - PreLNResidualBlock(hidden_dim=self.hidden_dim) - for _ in range(self.num_blocks) + *[PreLNResidualBlock(hidden_dim=self.hidden_dim) + for _ in range(self.num_blocks)], + nn.LayerNorm(), ] ) self.predictor = NormalTanhPolicy(self.action_dim) @@ -51,8 +52,9 @@ def setup(self): ) self.encoder = nn.Sequential( [ - PreLNResidualBlock(hidden_dim=self.hidden_dim) - for _ in range(self.num_blocks) + *[PreLNResidualBlock(hidden_dim=self.hidden_dim) + for _ in range(self.num_blocks)], + nn.LayerNorm(), ] ) self.predictor = LinearCritic()