Skip to content

Commit

Permalink
Resolve backward-compatibility issues so the code runs in the cluster's Docker environment
Browse files (browse the repository at this point in the history)
  • Loading branch information
runjerry committed Jan 17, 2025
1 parent bbe28a1 commit 3b2dcf2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
5 changes: 3 additions & 2 deletions alf/algorithms/rl_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,8 +793,9 @@ def _unroll_iter_off_policy(self):
(config.num_env_steps == 0
or self.get_step_metrics()[1].result() < config.num_env_steps)):
unrolled = True
with (torch.set_grad_enabled(config.unroll_with_grad),
torch.cuda.amp.autocast(config.enable_amp)):
with torch.set_grad_enabled(
config.unroll_with_grad), torch.cuda.amp.autocast(
config.enable_amp):
with record_time("time/unroll"):
self.eval()
# The period of performing unroll may not be an integer
Expand Down
19 changes: 14 additions & 5 deletions alf/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,14 @@ def _update_accumulator_(path, mask, step, acc, val):
"""In-place update of the accumulators and mask."""
val_valid = torch.isfinite(val)
# If at any step the value is valid, then the acc value becomes valid
mask[:] = torch.where(is_first, 0, mask | val_valid)
mask[:] = torch.where(
is_first, torch.tensor(
0, dtype=mask.dtype, device=mask.device), mask | val_valid)
# Only step+1 if the value is valid
step[:] = torch.where(is_first, 0,
step + val_valid.to(self._dtype))
step[:] = torch.where(
is_first, torch.tensor(
0, dtype=step.dtype, device=step.device),
step + val_valid.to(self._dtype))

if path.endswith("@max"):
# Don't max invalid values
Expand All @@ -283,11 +287,16 @@ def _update_accumulator_(path, mask, step, acc, val):
torch.maximum(acc, val.to(self._dtype)))
else:
# Don't sum invalid values
val = torch.where(val_valid, val, 0)
val = torch.where(
val_valid, val,
torch.tensor(0, dtype=val.dtype, device=val.device))
# Zero out batch indices where a new episode is starting.
# Update with new values; Ignores first step whose reward comes from
# the boundary transition of the last step from the previous episode.
acc[:] = torch.where(is_first, 0, acc + val.to(self._dtype))
acc[:] = torch.where(
is_first,
torch.tensor(0, dtype=acc.dtype, device=acc.device),
acc + val.to(self._dtype))

alf.nest.py_map_structure_with_path(_update_accumulator_, self._mask,
self._steps, self._accumulator,
Expand Down

0 comments on commit 3b2dcf2

Please sign in to comment.