diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index b5c87b4815..935af8ee0e 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -373,9 +373,9 @@ def _initialize_state( """ dtype = self.name_to_dtype_map[state_name] if store_param_remainders: - data = torch.zeros(param.shape, dtype=torch.int16, device=param.device) + data = torch.zeros_like(param, dtype=torch.int16, device=param.device) else: - data = torch.empty(param.shape, dtype=dtype, device=param.device) + data = torch.empty_like(param, dtype=dtype, device=param.device) if zero_buffer: data.zero_()