Hello author, thanks for the great work!
I've been reading the paper and exploring the code, and I noticed that CtxEncoder supports several network types, including mlp, gru, attention, and transformer. However, I didn't see any experimental results or comparisons for these variants in the paper.
I was wondering whether you performed any ablation studies comparing these encoder architectures. I'm curious about the performance of CtxEncoder with mlp or gru backbones: do they generalize better than DreamerV3, or are they comparable to the transformer version?
If you have any insights or experimental notes on this, that would be very helpful for my understanding.
Thanks again!
DALI/dreamerv3_compat/dreamerv3/nets.py, lines 52 to 101 at 34374fb:
```python
def __call__(self, inputs):
  # Extract inputs for the context encoding
  feat = self._inputs(inputs)
  if self._symlog_inputs:
    feat = jaxutils.symlog(feat)
  x = jaxutils.cast_to_compute(feat)
  batch_size, batch_len, batch_dim = x.shape[0], x.shape[1], x.shape[2]
  if self._input_permutation:
    x = jax.random.permutation(nj.rng(), x, axis=1)
  if self.network_type == "mlp":
    x = x.reshape([batch_size, -1])
    for i in range(self.mlp_opts["layers"]):
      mlp_opts = {k: v for k, v in self.mlp_opts.items() if k not in ["layers"]}
      x = self.get(f'ctx_linear{i}', Linear, **mlp_opts)(x)
  elif self.network_type == "gru":
    x = self.get('gru_lin_in', Linear, **self.gru_opts["linear"])(x)
    current_state = jnp.zeros([batch_size, self.gru_opts["units"]], f32)
    hidden_states = []
    for t in range(batch_len):
      current_state, _ = self._gru(x[:, t], current_state)
      hidden_states.append(current_state)
    hidden_history = jnp.stack(hidden_states, axis=0)
    final_state = hidden_history[-1]
    x = self.get('gru_lin_out', Linear, **self.gru_opts["linear"])(final_state)
  elif self.network_type == "attention":
    # x = x.reshape([batch_size, -1])
    x = self.get('proj', Linear, **self.attn_opts["linear"])(x)
    x = self.get("norm", Norm, "layer")(x)
    x = self.get('attn', Attention, self.attn_opts["heads"], self.attn_opts["units"])(x, x, x)
    x = x.reshape([batch_size, -1])
  elif self.network_type == "transformer":
    # Following encoder block
    # https://github.com/jlin816/dynalang/blob/0da77173ee4aeb975bd8a65c76ddb187fde8de81/dynalang/nets.py#L917
    # x = x.reshape([batch_size, -1])
    x = self.get('proj', Linear, **self.attn_opts["linear"])(x)
    skip = x
    x = self.get("norm1", Norm, "layer")(x)
    x = self.get('attn', Attention, self.attn_opts["heads"], self.attn_opts["units"])(x, x, x)
    x += skip
    skip = x
    x = self.get("norm2", Norm, "layer")(x)
    x = self.get('ff1', Linear, **self.attn_opts["linear"])(x)
    x = self.get('ff2', Linear, **self.attn_opts["linear"])(x)
    x += skip
    x = x.reshape([batch_size, -1])
  ctx_out = self.get('ctx_out', Linear, **self._kw["linear_ctx_out"])(x)
  return jnp.broadcast_to(ctx_out[:, None, :], (batch_size, batch_len, ctx_out.shape[-1]))
```
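For anyone else reading this, here is a minimal, self-contained sketch of how I understand the backbone families above to differ. It is written with Flax for illustration only and is not the repo's implementation (the repo uses its own ninjax-based Linear/Attention/Norm modules); the layer sizes, activations, two-layer MLP depth, and module choices are my assumptions, and the plain attention branch is omitted since it is essentially the transformer block without the residual/feed-forward parts.

```python
# Illustrative sketch only: mlp flattens time into features, gru rolls over time
# and keeps the final hidden state, transformer applies one pre-norm self-attention
# block before flattening. All sizes are made up for the example.
import jax
import jax.numpy as jnp
import flax.linen as nn  # assumes a recent Flax version (GRUCell takes `features`)


class TinyCtxEncoder(nn.Module):
  network_type: str = "transformer"
  units: int = 64

  @nn.compact
  def __call__(self, x):  # x: (batch, time, feat)
    B, T, _ = x.shape
    if self.network_type == "mlp":
      # Flatten the time axis into the feature axis, then stack dense layers.
      h = x.reshape(B, -1)
      for i in range(2):
        h = nn.relu(nn.Dense(self.units, name=f"ctx_linear{i}")(h))
    elif self.network_type == "gru":
      # Roll a GRU over time and keep only the final hidden state.
      h = nn.Dense(self.units, name="gru_lin_in")(x)
      cell = nn.GRUCell(features=self.units)
      state = jnp.zeros((B, self.units))
      for t in range(T):
        state, _ = cell(state, h[:, t])
      h = nn.Dense(self.units, name="gru_lin_out")(state)
    else:  # "transformer": one pre-norm self-attention block, then flatten over time
      h = nn.Dense(self.units, name="proj")(x)
      skip = h
      h = nn.LayerNorm(name="norm1")(h)
      h = nn.SelfAttention(num_heads=4, name="attn")(h) + skip
      skip = h
      h = nn.LayerNorm(name="norm2")(h)
      h = nn.Dense(self.units, name="ff")(h) + skip
      h = h.reshape(B, -1)
    # Shared output head, broadcast back along the time axis so every step
    # receives the same context vector.
    ctx = nn.Dense(self.units, name="ctx_out")(h)
    return jnp.broadcast_to(ctx[:, None, :], (B, T, ctx.shape[-1]))


# Quick shape check: all three variants map (2, 8, 16) -> (2, 8, 64).
x = jnp.zeros((2, 8, 16))
for kind in ["mlp", "gru", "transformer"]:
  model = TinyCtxEncoder(network_type=kind)
  params = model.init(jax.random.PRNGKey(0), x)
  print(kind, model.apply(params, x).shape)
```

So the variants differ mainly in how the window is summarized into one context vector: plain concatenation (mlp), a recurrent final state (gru), or self-attention over the whole window (attention/transformer), which is why I'm curious whether that choice shows up in the generalization results.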