update use of optimizer to avoid runtime errors on sharded gradients #54
Applied the FSDP2 `fully_shard` wrapper so the runtime uses the new API and the toggles now take effect.

Changes:

- Switched the global wrapper to FSDP2 `fully_shard` with a `DeviceMesh` built from the active process group ranks in `distributed.py:14-399` (first sketch after this list).
- Updated `MultiTaskModelMP` to use FSDP2 on the encoder/decoder with per-group meshes in `MultiTaskModelMP.py:8-276` (second sketch after this list).
- Added FSDP2 detection on the save/load paths so checkpointing doesn't call FSDP1-only APIs in `model.py:20-325` (third sketch after this list).
- Added a per-rank runtime log for FSDP2 activation in `distributed.py:392-399`; it prints `FSDP2 active on rank <rank>: <ModelClass>` when `HYDRAGNN_USE_FSDP=1`.
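For reference, a minimal sketch of the global wrap described above, assuming a recent PyTorch where `fully_shard` and `init_device_mesh` are exposed as shown (the actual wrapping logic in `distributed.py` differs):

```python
# Minimal sketch of wrapping a model with FSDP2's fully_shard over a 1-D
# DeviceMesh spanning every rank of the default process group.
# Assumes torch.distributed is already initialized (init_process_group).
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # lives under torch.distributed._composable.fsdp on older releases


def wrap_with_fsdp2(model: torch.nn.Module) -> torch.nn.Module:
    # One mesh dimension covering the whole job; FSDP2 shards parameters
    # (as DTensors) across this dimension.
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    # Shard the submodules first, then the root, so each wrapped module
    # manages its own all-gather/reshard group.
    for child in model.children():
        fully_shard(child, mesh=mesh)
    fully_shard(model, mesh=mesh)
    return model
```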
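Similarly, a sketch of the per-group mesh idea for the encoder/decoder split; `encoder_ranks`/`decoder_ranks` are illustrative names, and `DeviceMesh.from_group` is assumed to be available in the installed PyTorch:

```python
# Sketch: give the encoder and decoder their own 1-D meshes derived from
# model-parallel rank groups, then shard each branch independently.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import fully_shard


def shard_encoder_decoder(encoder, decoder, encoder_ranks, decoder_ranks):
    # new_group() is collective: every rank must call it for both groups.
    enc_group = dist.new_group(ranks=encoder_ranks)
    dec_group = dist.new_group(ranks=decoder_ranks)
    rank = dist.get_rank()

    if rank in encoder_ranks:
        fully_shard(encoder, mesh=DeviceMesh.from_group(enc_group, "cuda"))
    if rank in decoder_ranks:
        fully_shard(decoder, mesh=DeviceMesh.from_group(dec_group, "cuda"))
```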
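And a sketch of the FSDP2 detection on the checkpoint path; `gather_full_state_dict` is a made-up helper name, and the distributed-checkpoint state-dict helpers are assumed to be present in the installed PyTorch:

```python
# Sketch: detect FSDP2 wrapping so checkpointing never calls FSDP1-only APIs
# (e.g. FSDP.state_dict_type) on a fully_shard-ed module.
import torch
from torch.distributed.fsdp import FSDPModule  # class that fully_shard mixes into the wrapped module
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict


def is_fsdp2(module: torch.nn.Module) -> bool:
    return isinstance(module, FSDPModule)


def gather_full_state_dict(model: torch.nn.Module) -> dict:
    if is_fsdp2(model):
        # Gather the DTensor shards into a full, CPU-resident state dict.
        return get_model_state_dict(
            model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
        )
    # Non-FSDP2 path falls back to the plain state dict.
    return model.state_dict()
```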
Notes:

- FSDP2 currently supports FULL_SHARD only; the code now warns about and ignores other `HYDRAGNN_FSDP_STRATEGY` values (first sketch after this list).
- The `set_reshard_after_backward(False/True)` toggles in `train_validate_test.py:70-958` now apply to FSDP2 and should address the `autograd.grad()` storage error (second sketch after this list).
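A rough sketch of the strategy guard from the first note; the helper name is hypothetical and only mirrors the behavior described above:

```python
import logging
import os

_FSDP2_SUPPORTED = {"FULL_SHARD"}


def resolve_fsdp_strategy() -> str:
    # Hypothetical helper: the FSDP2 path only honors FULL_SHARD, so any other
    # requested strategy is warned about and replaced.
    requested = os.environ.get("HYDRAGNN_FSDP_STRATEGY", "FULL_SHARD").upper()
    if requested not in _FSDP2_SUPPORTED:
        logging.warning(
            "HYDRAGNN_FSDP_STRATEGY=%s is not supported with FSDP2; falling back to FULL_SHARD",
            requested,
        )
    return "FULL_SHARD"
```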
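And a sketch of the reshard toggle around `autograd.grad()` from the second note, assuming the model root is an FSDP2 `FSDPModule` and a PyG-style batch with a `pos` field (both assumptions, not the exact code in `train_validate_test.py`):

```python
# Sketch: keep parameters unsharded while taking gradients w.r.t. inputs,
# then restore the memory-saving resharding behavior afterwards.
import torch
from torch.distributed.fsdp import FSDPModule


def predict_with_input_grads(model, data):
    is_fsdp2 = isinstance(model, FSDPModule)
    if is_fsdp2:
        # Do not free the unsharded parameters after backward of this step.
        model.set_reshard_after_backward(False)

    data.pos.requires_grad_(True)
    pred = model(data)  # assumed to return a per-graph tensor here
    (grad_pos,) = torch.autograd.grad(pred.sum(), data.pos, create_graph=False)

    if is_fsdp2:
        # Re-enable resharding so steady-state memory use stays low.
        model.set_reshard_after_backward(True)
    return pred, grad_pos
```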