From ac14c05ff92d72627e1387b38e8b162998b3c099 Mon Sep 17 00:00:00 2001
From: leeyngdo
Date: Mon, 17 Mar 2025 05:09:31 +0000
Subject: [PATCH] fix errata in simba architecture

---
 scale_rl/agents/simba/simba_network.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/scale_rl/agents/simba/simba_network.py b/scale_rl/agents/simba/simba_network.py
index 0b084f7..adc541d 100755
--- a/scale_rl/agents/simba/simba_network.py
+++ b/scale_rl/agents/simba/simba_network.py
@@ -23,8 +23,9 @@ def setup(self):
         )
         self.encoder = nn.Sequential(
             [
-                PreLNResidualBlock(hidden_dim=self.hidden_dim)
-                for _ in range(self.num_blocks)
+                *[PreLNResidualBlock(hidden_dim=self.hidden_dim)
+                for _ in range(self.num_blocks)],
+                nn.LayerNorm(),
             ]
         )
         self.predictor = NormalTanhPolicy(self.action_dim)
@@ -51,8 +52,9 @@ def setup(self):
         )
         self.encoder = nn.Sequential(
             [
-                PreLNResidualBlock(hidden_dim=self.hidden_dim)
-                for _ in range(self.num_blocks)
+                *[PreLNResidualBlock(hidden_dim=self.hidden_dim)
+                for _ in range(self.num_blocks)],
+                nn.LayerNorm(),
             ]
         )
         self.predictor = LinearCritic()
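
The patch changes both the actor and critic encoders from a bare stack of pre-LayerNorm residual blocks to the same stack followed by a final nn.LayerNorm before the predictor head. Below is a minimal, self-contained sketch of the resulting encoder layout, assuming Flax; the Encoder module, its hyperparameter values, and the simplified PreLNResidualBlock here are illustrative stand-ins, not the exact definitions in simba_network.py.

import flax.linen as nn
import jax
import jax.numpy as jnp


class PreLNResidualBlock(nn.Module):
    # Simplified stand-in for the block in simba_network.py:
    # LayerNorm -> MLP, added back onto the residual stream.
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        res = x
        x = nn.LayerNorm()(x)
        x = nn.Dense(self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(self.hidden_dim)(x)
        return res + x


class Encoder(nn.Module):
    # Encoder layout after this patch: num_blocks residual blocks,
    # then a final LayerNorm on the encoder output.
    hidden_dim: int
    num_blocks: int

    def setup(self):
        self.encoder = nn.Sequential(
            [
                *[PreLNResidualBlock(hidden_dim=self.hidden_dim)
                  for _ in range(self.num_blocks)],
                nn.LayerNorm(),  # final normalization added by the patch
            ]
        )

    def __call__(self, x):
        return self.encoder(x)


if __name__ == "__main__":
    # Quick shape check with hypothetical sizes.
    model = Encoder(hidden_dim=256, num_blocks=2)
    x = jnp.ones((1, 256))
    params = model.init(jax.random.PRNGKey(0), x)
    print(model.apply(params, x).shape)  # (1, 256)

The original list comprehension produced only the residual blocks; unpacking it with * lets the trailing nn.LayerNorm() be appended to the same nn.Sequential list, which is what the added lines in both hunks do.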