diff --git a/alf/utils/data_buffer.py b/alf/utils/data_buffer.py index d1caa6680..31f9f5e2c 100644 --- a/alf/utils/data_buffer.py +++ b/alf/utils/data_buffer.py @@ -146,8 +146,6 @@ def _create_buffer(spec_path, tensor_spec): self._buffer = alf.nest.py_map_structure_with_path( _create_buffer, data_spec) - self._flattened_buffer = alf.nest.map_structure( - lambda x: x.view(-1, *x.shape[2:]), self._buffer) if allow_multiprocess: self.share_memory() @@ -272,8 +270,8 @@ def _set(buf, bat): indices = env_ids * self._max_length + self.circular( current_pos) alf.nest.map_structure( - lambda buf, bat: buf.__setitem__(indices, bat.detach()), - self._flattened_buffer, batch) + lambda buf, bat: buf.view(-1, *buf.shape[2:]).__setitem__( + indices, bat.detach()), self._buffer, batch) self._current_pos[env_ids] += 1 current_size = self._current_size[env_ids]