Skip to content

Commit 0b1d2db

Browse files
committed
FusedAdam: replace zeros(param.shape)/empty(param.shape) with zeros_like(param)/empty_like(param) to support DTensor
1 parent 9ca89e9 commit 0b1d2db

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

transformer_engine/pytorch/optimizers/fused_adam.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,9 @@ def _initialize_state(
373373
"""
374374
dtype = self.name_to_dtype_map[state_name]
375375
if store_param_remainders:
376-
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
376+
data = torch.zeros_like(param, dtype=torch.int16, device=param.device)
377377
else:
378-
data = torch.empty(param.shape, dtype=dtype, device=param.device)
378+
data = torch.empty_like(param, dtype=dtype, device=param.device)
379379
if zero_buffer:
380380
data.zero_()
381381

0 commit comments

Comments
 (0)