
Conversation

@shjwudp (Contributor) commented Nov 26, 2025

Description

Recent modifications to FusedAdam have made it incompatible with DTensor. Specifically, during optimizer state initialization, state tensors are now created from the DTensor's global shape instead of as DTensors that match the parameter's sharding.

To maintain compatibility with DTensor, the state tensors should be initialized using zeros_like(param) or empty_like(param) instead of zeros(param.shape) or empty(param.shape).
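The fix can be sketched as follows. This is a minimal illustration with a plain tensor (the actual change lives in FusedAdam's state initialization, and the `exp_avg_old`/`exp_avg_new` names are illustrative):

```python
import torch

param = torch.nn.Parameter(torch.randn(8, 4, dtype=torch.float32))

# Before: builds a fresh plain tensor from the shape alone. For a DTensor
# parameter, param.shape is the *global* shape, so the state would be an
# unsharded plain tensor of the wrong (global) size on every rank.
exp_avg_old = torch.zeros(param.shape, dtype=param.dtype, device=param.device)

# After: *_like() mirrors the input tensor itself, so a DTensor parameter
# yields a DTensor state with the same sharding, local shape, and device mesh.
exp_avg_new = torch.zeros_like(param)
```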

Fixes #2424

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Initialize FusedAdam optimizer states with zeros_like(param)/empty_like(param) instead of zeros(param.shape)/empty(param.shape), so that DTensor parameters produce matching DTensor states

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…ike(param)/empty_like(param) to support DTensor

Signed-off-by: jianbinc <shjwudp@gmail.com>
@shjwudp shjwudp force-pushed the fused_adam_dtensor_issue branch from 0b1d2db to 629c786 Compare November 26, 2025 03:11
@greptile-apps bot (Contributor) commented Nov 26, 2025

Greptile Overview

Greptile Summary

Fixed DTensor compatibility issue in FusedAdam optimizer by replacing torch.zeros(param.shape) and torch.empty(param.shape) with torch.zeros_like(param) and torch.empty_like(param) in the _initialize_state method.

Key Changes:

  • When using param.shape with DTensor parameters, the optimizer was creating state tensors with the global shape instead of preserving the DTensor's distributed structure
  • Using *_like() functions ensures that optimizer states (exp_avg, exp_avg_sq, master_param) maintain the same tensor type and distribution as the parameters
  • This fix enables FusedAdam to work correctly with FSDP2 and other distributed training frameworks that use DTensor

Note: Line 388 still uses param.shape in quantizer.make_empty(param.shape) for FP8 state initialization, which may need similar attention in the future if FP8 DTensor support is required.
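Why *_like() preserves the distributed structure: DTensor is a torch.Tensor subclass, and the *_like() factories dispatch through the input tensor's `__torch_function__`, while shape-based factories only ever see a plain `torch.Size`. A single-process stand-in using a trivial subclass (a real DTensor needs an initialized process group, so `ShardedStandIn` here is just an analogy):

```python
import torch

class ShardedStandIn(torch.Tensor):
    """Trivial torch.Tensor subclass standing in for DTensor."""

t = torch.randn(4, 4).as_subclass(ShardedStandIn)

# *_like() dispatches through the subclass's __torch_function__,
# so the tensor type survives.
state_like = torch.zeros_like(t)

# A shape-based factory only receives t.shape (a plain torch.Size)
# and returns a base torch.Tensor.
state_shape = torch.zeros(t.shape)
```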

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The fix is minimal, well-targeted, and addresses the root cause correctly by using *_like() functions instead of manually constructing tensors with param.shape, which preserves DTensor structure
  • No files require special attention

Important Files Changed

File Analysis

Filename | Score | Overview
--- | --- | ---
transformer_engine/pytorch/optimizers/fused_adam.py | 5/5 | Replaced torch.zeros(param.shape) and torch.empty(param.shape) with torch.zeros_like(param) and torch.empty_like(param) to properly handle DTensor parameters

Sequence Diagram

```mermaid
sequenceDiagram
    participant Optimizer as FusedAdam
    participant Param as Parameter/DTensor
    participant State as Optimizer State

    Note over Optimizer,State: State Initialization Flow

    Optimizer->>Optimizer: initialize_state(param)
    Optimizer->>Optimizer: _initialize_state(param, "exp_avg")

    alt Before this PR (broken for DTensor)
        Note over Optimizer,Param: torch.zeros(param.shape)<br/>returns global shape tensor
        Optimizer->>Param: param.shape
        Param-->>Optimizer: global_shape (e.g., [1024, 512])
        Optimizer->>State: torch.zeros(global_shape)
        Note over State: Creates tensor with global shape<br/>instead of local DTensor shape
    end

    alt After this PR (DTensor compatible)
        Note over Optimizer,Param: torch.zeros_like(param)<br/>preserves DTensor structure
        Optimizer->>Param: torch.zeros_like(param)
        Param-->>Optimizer: DTensor with correct local shape
        Optimizer->>State: Store DTensor state
        Note over State: Correctly creates DTensor state<br/>matching parameter structure
    end

    Note over Optimizer,State: State is now compatible with<br/>distributed tensor parallelism
```

@greptile-apps bot left a comment

1 file reviewed, no comments

"""
dtype = self.name_to_dtype_map[state_name]
if store_param_remainders:
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
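The store_param_remainders branch above holds int16 remainder bits rather than the parameter's own dtype, but the same *_like pattern still applies because torch.zeros_like accepts a dtype override. A sketch of how this branch could be made DTensor-safe (an illustration, not necessarily the exact patch in the PR):

```python
import torch

param = torch.nn.Parameter(torch.randn(8, 4, dtype=torch.float32))

# zeros_like with a dtype override keeps the input's tensor type (and, for a
# DTensor parameter, its sharding and device mesh) while producing int16
# storage for the parameter-remainder state.
data = torch.zeros_like(param, dtype=torch.int16)
```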
@vthumbe1503 (Collaborator) commented Nov 26, 2025

Could we also change run_fsdp2_model.py to use TE's FusedAdam optimizer instead of torch Adam, so we don't break this again in the future?
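A sketch of the suggested swap (the exact lines in run_fsdp2_model.py are assumptions; FusedAdam largely mirrors torch.optim.Adam's constructor signature, so the change should be close to a drop-in replacement):

```diff
-optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+from transformer_engine.pytorch.optimizers import FusedAdam
+optimizer = FusedAdam(model.parameters(), lr=args.lr)
```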



Development

Successfully merging this pull request may close these issues.

Compatibility issues between FusedAdam and DTensor

2 participants