We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9ca89e9 commit 0b1d2dbCopy full SHA for 0b1d2db
transformer_engine/pytorch/optimizers/fused_adam.py
@@ -373,9 +373,9 @@ def _initialize_state(
373
"""
374
dtype = self.name_to_dtype_map[state_name]
375
if store_param_remainders:
376
- data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
+ data = torch.zeros_like(param, dtype=torch.int16, device=param.device)
377
else:
378
- data = torch.empty(param.shape, dtype=dtype, device=param.device)
+ data = torch.empty_like(param, dtype=dtype, device=param.device)
379
if zero_buffer:
380
data.zero_()
381
0 commit comments