diff --git a/alf/algorithms/sac_algorithm.py b/alf/algorithms/sac_algorithm.py
index c83286d8e..4d97f281b 100644
--- a/alf/algorithms/sac_algorithm.py
+++ b/alf/algorithms/sac_algorithm.py
@@ -287,11 +287,13 @@ def __init__(self,
                 observation_spec=observation_spec,
                 action_spec=action_spec,
                 reward_spec=reward_spec,
+                debug_summaries=debug_summaries,
                 config=config)
             target_repr_alg = repr_alg_ctor(
                 observation_spec=observation_spec,
                 action_spec=action_spec,
                 reward_spec=reward_spec,
+                debug_summaries=debug_summaries,
                 config=config)
             assert hasattr(repr_alg,
                            'output_spec'), "repr_alg must have output_spec"
@@ -884,7 +886,7 @@ def train_step(self, inputs: TimeStep, state: SacState,
             # usage can be reduced because its computation graph will not be kept.
             with torch.no_grad():
                 tgt_repr_step = self._target_repr_alg.predict_step(
-                    inputs, rollout_info.repr)
+                    inputs, state.target_repr)
                 target_observation = tgt_repr_step.output
                 target_repr_state = tgt_repr_step.state
         else:
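
For context, a minimal sketch of the state-threading pattern the second hunk restores: the target representation network should consume its own recurrent state carried in the algorithm state (`state.target_repr`), not the representation info recorded at rollout time (`rollout_info.repr`), which came from the online network. Only the field and method names below come from the diff; the classes are hypothetical stand-ins, not ALF's actual implementation.

```python
from collections import namedtuple

import torch

# Hypothetical state containers; ALF's real SacState has more fields.
SacState = namedtuple('SacState', ['repr', 'target_repr'])
ReprStep = namedtuple('ReprStep', ['output', 'state'])


class ReprAlg(torch.nn.Module):
    """Toy recurrent representation module standing in for repr_alg_ctor's output."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.cell = torch.nn.GRUCell(dim, dim)

    def predict_step(self, inputs: torch.Tensor,
                     state: torch.Tensor) -> ReprStep:
        new_state = self.cell(inputs, state)
        return ReprStep(output=new_state, state=new_state)


repr_alg = ReprAlg()
target_repr_alg = ReprAlg()  # separate (e.g. slowly-updated) target copy

obs = torch.randn(2, 4)
state = SacState(repr=torch.zeros(2, 4), target_repr=torch.zeros(2, 4))

# The target branch runs under no_grad, as in the diff, so its computation
# graph is never kept and memory usage stays low.
with torch.no_grad():
    # The fix: thread the target network's own state, not rollout_info.repr.
    tgt_repr_step = target_repr_alg.predict_step(obs, state.target_repr)
target_observation = tgt_repr_step.output

# Carry the updated target state forward to the next step.
state = state._replace(target_repr=tgt_repr_step.state)
```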