Skip to content

Conversation

anshul-si
Copy link

@anshul-si anshul-si commented Sep 15, 2025

Summary: During this experiment to integrate the new replicate function into torchtitan, I used pytorch/pytorch#162021, which has not been landed. However, since this is more about making replicate more efficient rather than changing replicate's core code, pytorch/pytorch#160135, which has landed, should be fine. pytorch/pytorch#160133 is the last time replicate_with_fsdp.py and its replicate api were touched.

In order to enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include dp_shard dim = 1 only if dp_replicate > 1, and created device mesh that I pass down in apply_ddp.

Below is a link comparing the loss curves for Llama3.1-8B models: one configured with dimension sharding (2) and tensor parallelism (4), and the other with dimension replication (2) and sharding (4).

image

https://fburl.com/mlhub/btkos8ok

Test Case

  1. CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh

Expected output of this experiment should be something like:
[rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]
[rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters
[rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%)
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json. Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)

Stack from ghstack (oldest at bottom):

anshul-si added a commit that referenced this pull request Sep 15, 2025
…torchtitan

ghstack-source-id: ea5b964
Pull Request resolved: #1714
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 15, 2025
@anshul-si anshul-si marked this pull request as draft September 15, 2025 23:12
@anshul-si anshul-si requested review from mori360 and removed request for fegin September 15, 2025 23:12
…ation with torchtitan"

**Summary:** During this experiment to integrate the new replicate function into torchtitan, I used pytorch/pytorch#162021, which has not been landed. However, since this is more about making replicate more efficient rather than changing replicate's core code, pytorch/pytorch#160135, which has landed, should be fine. pytorch/pytorch#160133 is the last time replicate_with_fsdp.py and its replicate api were touched. 

In order to enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include dp_shard dim = 1 only if dp_replicate > 1, and created device mesh that I pass down in apply_ddp. 

**Test Case**
1. CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh

Expected output of this experiment should be something like:
[rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
**[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]**
[rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters
[rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model
**[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model**
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%)
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json.                     Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)




[ghstack-poisoned]
anshul-si added a commit that referenced this pull request Sep 16, 2025
…torchtitan

ghstack-source-id: 19a48b7
Pull Request resolved: #1714
…ation with torchtitan"

**Summary:** During this experiment to integrate the new replicate function into torchtitan, I used pytorch/pytorch#162021, which has not been landed. However, since this is more about making replicate more efficient rather than changing replicate's core code, pytorch/pytorch#160135, which has landed, should be fine. pytorch/pytorch#160133 is the last time replicate_with_fsdp.py and its replicate api were touched. 

In order to enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include dp_shard dim = 1 only if dp_replicate > 1, and created device mesh that I pass down in apply_ddp. 

**Test Case**
1. CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh

Expected output of this experiment should be something like:
[rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
**[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]**
[rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters
[rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model
**[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model**
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%)
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json.                     Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)




[ghstack-poisoned]
anshul-si added a commit that referenced this pull request Sep 16, 2025
…torchtitan

ghstack-source-id: 7bba1f6
Pull Request resolved: #1714
torch._dynamo.config.optimize_ddp = "ddp_optimizer"

replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
replicate(model, device_mesh=dp_mesh)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replicate is meant to be equivalent to FSDP1 NO_SHARD, instead of DDP

the key difference is DDP has bucket_cap_mb to overlap all_reduce with compute. replicate does not have such overlapping

the big assumption is user uses fsdp2, but want to avoid all-gathers for small modules, that's the time we use replicate

cc @tianyu-l in case we are counting replicate for DDP + EP

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replicate is meant to be equivalent to FSDP1 NO_SHARD, instead of DDP

What does this mean? I thought we are doing replicate only because we can replace DDP with replicate so that we can do DDP + TP or other parallelisms.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for overlappings in replicate, we should use nested wrapping

for layer in model.layers:
    replicate(layer)
    
replicate(model)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replicate is meant to be equivalent to FSDP1 NO_SHARD, instead of DDP

What does this mean? I thought we are doing replicate only because we can replace DDP with replicate so that we can do DDP + TP or other parallelisms.

synced in chat that we can replace DDP with replicate in titan, just need the per-layer wrapping for replicate

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the big assumption is user uses fsdp2, but want to avoid all-gathers for small modules, that's the time we use replicate

Does this means we could wrap some of small modules in the model with Replicate(), while other modules with fully_shard()? That's very flexible

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's right

…ation with torchtitan"

**Summary:** During this experiment to integrate the new replicate function into torchtitan, I used pytorch/pytorch#162021, which has not been landed. However, since this is more about making replicate more efficient rather than changing replicate's core code, pytorch/pytorch#160135, which has landed, should be fine. pytorch/pytorch#160133 is the last time replicate_with_fsdp.py and its replicate api were touched. 

In order to enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include dp_shard dim = 1 only if dp_replicate > 1, and created device mesh that I pass down in apply_ddp. 

**Test Case**
1. CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh

Expected output of this experiment should be something like:
[rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
**[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]**
[rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters
[rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model
**[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model**
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%)
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json.                     Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)




[ghstack-poisoned]
anshul-si added a commit that referenced this pull request Sep 23, 2025
…torchtitan

ghstack-source-id: bfe9ee3
Pull Request resolved: #1714
@anshul-si anshul-si marked this pull request as ready for review September 23, 2025 20:16
Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Had some comments

elif parallel_dims.dp_replicate_enabled:
if world_mesh.ndim > 1:
raise RuntimeError("DDP has not supported > 1D parallelism")
# if world_mesh.ndim > 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We always have >= 2D mesh as you would enable dp_shard==1 anyway?
We should:

  1. verify DDP+TP and DDP+PP work with correct numerics (see instructions at https://github.com/pytorch/torchtitan/blob/main/docs/debugging.md#seed-checkpoint-based-reproducibility)
  2. remove this comment
  3. change all other occurrences as they depend on this function https://github.com/search?q=repo%3Apytorch%2Ftorchtitan%20apply_ddp&type=code

…ation with torchtitan"

**Summary:** During this experiment to integrate the new replicate function into torchtitan, I used pytorch/pytorch#162021, which has not been landed. However, since this is more about making replicate more efficient rather than changing replicate's core code, pytorch/pytorch#160135, which has landed, should be fine. pytorch/pytorch#160133 is the last time replicate_with_fsdp.py and its replicate api were touched. 

In order to enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include dp_shard dim = 1 only if dp_replicate > 1, and created device mesh that I pass down in apply_ddp. 

Below is a link comparing the loss curves for Llama3.1-8B models: one configured with dimension sharding (2) and tensor parallelism (4), and the other with dimension replication (2) and sharding (4).

<img width="1266" height="483" alt="image" src="https://github.com/user-attachments/assets/40198bc5-5e3f-486b-be56-12111e010e0c" />

https://fburl.com/mlhub/btkos8ok

**Test Case**
1. CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh

Expected output of this experiment should be something like:
[rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
**[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]**
[rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters
[rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model
**[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model**
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%)
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json.                     Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)




[ghstack-poisoned]
anshul-si added a commit that referenced this pull request Sep 23, 2025
…torchtitan

ghstack-source-id: 1ab4103
Pull Request resolved: #1714
@EquationWalker
Copy link

EquationWalker commented Sep 24, 2025

In this case, Mixed precision of replicate_with_fsdp should be handled by fully_shard instead of AMP. This means that we need to modify torchtitan/distributed/utils.py/maybe_enable_amp() to accommodate replicate_with_fsdp .
By the way, DistributedDataParallel has experimentally supported native mixed precision, similar to MixedPrecisionPolicy of FSDP2. This means that perhaps torchtitan can remove torchtitan/distributed/utils.py/maybe_enable_amp() completely. See at DDP native mixed precision #92882.
cc @weifengpy @tianyu-l

@tianyu-l
Copy link
Contributor

@EquationWalker

In this case, Mixed precision of replicate_with_fsdp should be handled by fully_shard instead of AMP. This means that we need to modify torchtitan/distributed/utils.py/maybe_enable_amp() to accommodate replicate_with_fsdp .

Great point!
@anshul-si Let's accommodate.

apply_ddp(
model,
world_mesh,
world_mesh[tuple(dp_mesh_dim_names)],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this dp_mesh 2d? for user facing apis, it should be a 1d mesh or default 1d world mesh

inside replicate api, we can do 2d

Copy link
Author

@anshul-si anshul-si Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently replicate api doesn't have a method to convert the 1d mesh to 2d. I can leave a TODO to change this when the unflatten method for device mesh is complete, but until then I think it makes more sense to leave it like this. Also the user technically believes they are creating a 1d mesh because they are only changing replicate dim. I think from their perspective, they still are only creating a 1d mesh

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember @fduwjj mentioned a mesh api to convert 1d to 2d. this needs to be done before landing. it's a user contract. if we change 2d back to 1d later, that becomes bc-breaking

@weifengpy
Copy link
Contributor

In this case, Mixed precision of replicate_with_fsdp should be handled by fully_shard instead of AMP. This means that we need to modify torchtitan/distributed/utils.py/maybe_enable_amp() to accommodate replicate_with_fsdp .
By the way, DistributedDataParallel has experimentally supported native mixed precision, similar to MixedPrecisionPolicy of FSDP2.

@EquationWalker good catch!

Copy link
Contributor

@weifengpy weifengpy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my request changes is mainly on 2d mesh. we should target 1d mesh for landing. it's a user contract in public facing api

@EquationWalker
Copy link

EquationWalker commented Sep 25, 2025

my request changes is mainly on 2d mesh. we should target 1d mesh for landing. it's a user contract in public facing api

I think the use of 2D mesh has something to do with the FSDPParamGroup user contract. When passing a 2D mesh, FSDPParamGroup treats it as an HSDP and then shard parameters in the second dimension and replicate parameters in the first dimension. If you pass a 1D Mesh, FSDPParamGroup will shard parameters on this mesh instead of replicating them.

pytorch/issues#159013 mentioned adding reshape operation to mesh, but it seems that pytorch has not implemented it.

One solution would be to recreate a new 2D mesh (N,1) using the 1D mesh shape information (N,) inside apply_ddp, but this would create new communication groups, and I'm not sure if this is an expensive operation.

If replicate_with_fsdp could internally convert any 2D mesh shape (N,M) to (N*M,1), and any 1D mesh shape (N,) to (N,1), perhaps we can use fully_shard and replicate_with_fsdp in combination.

cc @anshul-si @tianyu-l

@fegin
Copy link
Contributor

fegin commented Sep 25, 2025

The 1D mesh as the input has been in the roadmap but there are two issues.

  1. We need DeviceMesh unflatten support support which the PR is being reviewed but is not landed yet. But this is just a WIP.

  2. iirc, there is a concern about the input mesh being 1D but named_parameters() and state_dict() will actually return 2D DeviceMesh, which can cause inconsistent views for users.

@anshul-si @mori360 Please correct me if I'm wrong for the second issue.

@weifengpy
Copy link
Contributor

weifengpy commented Sep 25, 2025

2. iirc, there is a concern about the input mesh being 1D but named_parameters() and state_dict() will actually return 2D DeviceMesh, which can cause inconsistent views for users.

I was thinking named_parameters() and state_dict() returning 1D

We need DeviceMesh unflatten support support which the PR is being reviewed but is not landed yet

which PR? I don't mind creating a new 2D mesh inside replicate to unblock for now, but the user contract has to be right on day 1

  • input: 1D mesh
  • output: model.parameters() and model.state_dict() being 1D

…ation with torchtitan"

**Summary:** During this experiment to integrate the new replicate function into torchtitan, I used pytorch/pytorch#162021, which has not been landed. However, since this is more about making replicate more efficient rather than changing replicate's core code, pytorch/pytorch#160135, which has landed, should be fine. pytorch/pytorch#160133 is the last time replicate_with_fsdp.py and its replicate api were touched. 

In order to enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include dp_shard dim = 1 only if dp_replicate > 1, and created device mesh that I pass down in apply_ddp. 

Below is a link comparing the loss curves for Llama3.1-8B models: one configured with dimension sharding (2) and tensor parallelism (4), and the other with dimension replication (2) and sharding (4).

<img width="1266" height="483" alt="image" src="https://github.com/user-attachments/assets/40198bc5-5e3f-486b-be56-12111e010e0c" />

https://fburl.com/mlhub/btkos8ok

**Test Case**
1. CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh

Expected output of this experiment should be something like:
[rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
**[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]**
[rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters
[rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model
**[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model**
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%)
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json.                     Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)




[ghstack-poisoned]
anshul-si added a commit that referenced this pull request Sep 29, 2025
…torchtitan

ghstack-source-id: 88bd5ed
Pull Request resolved: #1714
…ation with torchtitan"

**Summary:** During this experiment to integrate the new replicate function into torchtitan, I used pytorch/pytorch#162021, which has not been landed. However, since this is more about making replicate more efficient rather than changing replicate's core code, pytorch/pytorch#160135, which has landed, should be fine. pytorch/pytorch#160133 is the last time replicate_with_fsdp.py and its replicate api were touched. 

In order to enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include dp_shard dim = 1 only if dp_replicate > 1, and created device mesh that I pass down in apply_ddp. 

Below is a link comparing the loss curves for Llama3.1-8B models: one configured with dimension sharding (2) and tensor parallelism (4), and the other with dimension replication (2) and sharding (4).

<img width="1266" height="483" alt="image" src="https://github.com/user-attachments/assets/40198bc5-5e3f-486b-be56-12111e010e0c" />

https://fburl.com/mlhub/btkos8ok

**Test Case**
1. CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh

Expected output of this experiment should be something like:
[rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
**[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]**
[rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters
[rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model
**[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model**
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%)
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json.                     Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)




[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants