Conversation

@weifengpy weifengpy commented Jan 28, 2026

command: CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh

FSDP2 supports per-param meshes: pytorch/pytorch#173509

This PR applies fully_shard to each transformer_block, sharding expert parameters on edp_mesh and all other parameters on dp_mesh. Each FSDPModule then schedules two all-gathers sequentially: the first for the transformer block's dense parameters, the second for the experts.

def _shard_placement_fn(param: nn.Parameter) -> ShardPlacementResult:
    if param in expert_params:
        # Expert parameters: use Shard(1) on edp_mesh
        return ShardPlacementResult(
            placement=Shard(1), mesh_info=edp_mesh_info
        )
    else:
        # Non-expert parameters: use Shard(0) on dp_mesh
        return ShardPlacementResult(
            placement=Shard(0), mesh_info=dp_mesh_info
        )
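The routing above boils down to an identity check per parameter. Here is a torch-free toy sketch of that logic, where `Placement` and `make_placement_fn` are hypothetical stand-ins for `ShardPlacementResult` and `_shard_placement_fn` (not the actual FSDP2 types):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Placement:
    """Toy stand-in for ShardPlacementResult: shard dim plus target mesh name."""
    dim: int
    mesh: str


def make_placement_fn(expert_params):
    # Collect expert params into an identity set, since parameters are
    # matched by object identity, not by value.
    expert_ids = {id(p) for p in expert_params}

    def placement_fn(param):
        if id(param) in expert_ids:
            # Expert parameters: Shard(1) on edp_mesh
            return Placement(dim=1, mesh="edp_mesh")
        # Non-expert parameters: Shard(0) on dp_mesh
        return Placement(dim=0, mesh="dp_mesh")

    return placement_fn
```

The identity set mirrors how `param in expert_params` behaves for `nn.Parameter` membership tests; everything else falls through to the default dp_mesh placement.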

This makes it possible to apply torch.compile to each transformer_block. I haven't enabled per-block compile yet because there is still a gap in torch.compile + AC + MoE: #2341

The all-gather (AG) order in forward is exactly the same before and after this change.
(Screenshot: forward AG trace, 2026-02-06)

The AG order in backward differs, but the new order is better:
(Screenshot: backward AG trace, 2026-02-06)

Explicit Backward AllGather Order
  layers.7       @ 118.83ms   (attention/ffn params)
  layers.6       @ 121.52ms   (attention/ffn params)
  layers.6.moe   @ 122.04ms   (MoE expert params)
  layers.7.moe   @ 125.81ms   (MoE expert params)  ← delayed!

Per-param Backward AllGather Order
  layers.7       @ 114.30ms   (first FSDP unit)
  layers.7       @ 115.14ms   (second FSDP unit, includes MoE)
  layers.6       @ 117.42ms   (first FSDP unit)
  layers.6       @ 117.89ms   (second FSDP unit, includes MoE)
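As a toy reproduction of the two traces above (this is not FSDP's actual prefetch algorithm, just a model of the observed orders): with explicit per-submodule wrapping the MoE gathers trail all the dense gathers, while per-param meshes keep each block's two FSDP units back-to-back:

```python
def backward_ag_order(layers, per_param_mesh):
    """Return the backward all-gather order as (layer, unit) pairs.

    `layers` is listed in backward order (last layer first), e.g. [7, 6].
    Matches the traces above, not FSDP2's real scheduling logic.
    """
    if per_param_mesh:
        # One FSDP module per block; its two param groups gather back-to-back.
        return [(l, u) for l in layers for u in ("dense", "moe")]
    # Separate wrapping for each block and its MoE submodule: in the trace,
    # the dense gathers run first and the MoE gathers trail in reverse order,
    # so layers.7.moe is delayed until after layers.6.moe.
    return [(l, "dense") for l in layers] + [(l, "moe") for l in reversed(layers)]
```

Running this for layers [7, 6] reproduces both orderings above, including the delayed `layers.7.moe` in the explicit-wrapping case.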

Numerics remain bitwise-identical with and without this change.

Loss Comparison

| Step | Old (0d93c63) | New (e1c47c8) | Match |
|------|---------------|---------------|-------|
| 1    | 8.01151657    | 8.01151657    | ✓     |
| 5    | 3.85572004    | 3.85572004    | ✓     |
| 10   | 3.15517211    | 3.15517211    | ✓     |
| 15   | 3.07873583    | 3.07873583    | ✓     |
| 20   | 2.92206621    | 2.92206621    | ✓     |
| 25   | 2.89102936    | 2.89102936    | ✓     |
| 30   | 2.81201696    | 2.81201696    | ✓     |
| 35   | 2.84123349    | 2.84123349    | ✓     |
| 40   | 2.76206398    | 2.76206398    | ✓     |
| 45   | 2.82969308    | 2.82969308    | ✓     |
| 50   | 2.77560568    | 2.77560568    | ✓     |
| 55   | 2.75578761    | 2.75578761    | ✓     |
| 60   | 2.75143075    | 2.75143075    | ✓     |
| 65   | 2.74203372    | 2.74203372    | ✓     |
| 70   | 2.71638918    | 2.71638918    | ✓     |
| 75   | 2.74999237    | 2.74999237    | ✓     |
| 80   | 2.75584078    | 2.75584078    | ✓     |
| 85   | 2.74837303    | 2.74837303    | ✓     |
| 90   | 2.72101045    | 2.72101045    | ✓     |
| 95   | 2.73645735    | 2.73645735    | ✓     |
| 100  | 2.70604038    | 2.70604038    | ✓     |

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jan 28, 2026
@weifengpy weifengpy marked this pull request as draft January 28, 2026 20:44
@weifengpy weifengpy changed the title [FSDP2] enable per-param mesh FSDP2 for MoE and per-layer compile [WIP][FSDP2] enable per-param mesh FSDP2 for MoE and per-layer compile Jan 28, 2026
@weifengpy weifengpy force-pushed the per-param-mesh branch 5 times, most recently from bd87f68 to 3c36e53 Compare February 7, 2026 01:45
@weifengpy weifengpy changed the title [WIP][FSDP2] enable per-param mesh FSDP2 for MoE and per-layer compile [FSDP2] enable per-param mesh FSDP2 for MoE and per-layer compile Feb 7, 2026
@weifengpy weifengpy changed the title [FSDP2] enable per-param mesh FSDP2 for MoE and per-layer compile [FSDP2] enable per-param mesh FSDP2 for MoE Feb 7, 2026
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 9, 2026

This PR applies fully_shard to each transformer_block, sharding expert parameters on edp_mesh and all other parameters on dp_mesh. Each FSDPModule then schedules two all-gathers sequentially: the first for the transformer block's dense parameters, the second for the experts.

See torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

Existing FSDP2 call sites won't be affected, because _shard_placement_fn -> ShardPlacementResult is a new code path.
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 9, 2026

This PR applies fully_shard to each transformer_block, sharding expert parameters on edp_mesh and all other parameters on dp_mesh. Each FSDPModule then schedules two all-gathers sequentially: the first for the transformer block's dense parameters, the second for the experts.

See torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

Existing FSDP2 call sites won't be affected, because _shard_placement_fn -> ShardPlacementResult is a new code path.

Checked backward compatibility:
* pytorch: fsdp2_mem_tracker.py is affected, but only if it is used with per-param meshes. I don't think it's a hard blocker.
* torchtitan: no usages of _fsdp_param_group (singular). Safe.
* torchao: no usages of _fsdp_param_group (singular). Safe.
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 9, 2026
this PR applies fully_shard on transformer_block, sharding experts on edp_mesh, and other params on dp_mesh. FSDPModule schedule 2 all-gather sequentially: 1st on transformer blocks, 2nd on experts

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing fsdp2 callsite won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward-compatiblibility
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                              
* torchao: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                                 




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 9, 2026
this PR applies fully_shard on transformer_block, sharding experts on edp_mesh, and other params on dp_mesh. FSDPModule schedule 2 all-gather sequentially: 1st on transformer blocks, 2nd on experts

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing fsdp2 callsite won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward-compatiblibility
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                              
* torchao: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                                 




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 9, 2026
this PR applies fully_shard on transformer_block, sharding experts on edp_mesh, and other params on dp_mesh. FSDPModule schedule 2 all-gather sequentially: 1st on transformer blocks, 2nd on experts

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing fsdp2 callsite won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward-compatiblibility
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                              
* torchao: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                                 




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 9, 2026
this PR applies fully_shard on transformer_block, sharding experts on edp_mesh, and other params on dp_mesh. FSDPModule schedule 2 all-gather sequentially: 1st on transformer blocks, 2nd on experts

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing fsdp2 callsite won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward-compatiblibility
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                              
* torchao: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                                 
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 9, 2026

weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 10, 2026

weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 10, 2026
Labels: ciflow/8gpu, CLA Signed