6 changes: 6 additions & 0 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -362,6 +362,12 @@ def apply_fsdp(
transformer_block.moe.experts.set_gradient_divide_factor(
gradient_divide_factor,
)
+else:
+    fully_shard(
+        transformer_block._checkpoint_wrapped_module.feed_forward if hasattr(transformer_block, "_checkpoint_wrapped_module") else transformer_block.feed_forward,
+        **fsdp_config,
+        reshard_after_forward=reshard_after_forward,
+    )

fully_shard(
transformer_block,
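Note: both the hunk above and the matching llama3 change later in this diff reach the FFN through an inline conditional on _checkpoint_wrapped_module, the attribute under which checkpoint_wrapper keeps the original module. A small helper along the following lines could shorten both call sites; this is only a sketch, not part of the diff, and the name resolve_feed_forward is hypothetical.

import torch.nn as nn


def resolve_feed_forward(transformer_block: nn.Module) -> nn.Module:
    # Hypothetical helper (not in this PR): return the block's feed_forward,
    # looking through an activation-checkpoint wrapper when one is applied.
    inner = getattr(transformer_block, "_checkpoint_wrapped_module", transformer_block)
    return inner.feed_forward

With such a helper, each call site would reduce to fully_shard(resolve_feed_forward(transformer_block), **fsdp_config, reshard_after_forward=reshard_after_forward).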
4 changes: 2 additions & 2 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -83,8 +83,8 @@
dim=2048,
inter_dim=10944,
moe_inter_dim=1408,
-n_layers=27,
-n_dense_layers=1,
+n_layers=5,
+n_dense_layers=5,
n_heads=16,
moe_args=MoEArgs(
num_experts=64,
3 changes: 3 additions & 0 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -156,6 +156,9 @@ def parallelize_deepseekv3(
else:
logger.info("Applied FSDP to the model")

+# import fbvscode
+# fbvscode.set_trace()
+
if parallel_dims.cp_enabled:
logger.info("Applied Context Parallel to the model")

@@ -4,10 +4,10 @@ description = "DeepSeek-V3 16B model training"
print_args = false

[profiling]
-enable_profiling = false
+enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
-enable_memory_snapshot = false
+enable_memory_snapshot = true
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
@@ -35,10 +35,10 @@ decay_type = "cosine"
min_lr_factor = 0.1

[training]
-local_batch_size = 8
-seq_len = 4096
+local_batch_size = 1
+seq_len = 4
max_norm = 1.0 # grad norm clipping
-steps = 1000
+steps = 20
dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
@@ -49,7 +49,7 @@ tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
pipeline_parallel_schedule = "Interleaved1F1B"
-expert_parallel_degree = 8
+expert_parallel_degree = 1
expert_tensor_parallel_degree = 1

[checkpoint]
@@ -61,11 +61,11 @@ export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]"

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
mode = "full" # ["none", "selective", "full"]
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]
-enable=true
+enable=false
components = ["loss"] # ["model", "loss"]

[float8]
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/__init__.py
@@ -42,7 +42,7 @@
),
"8B": TransformerModelArgs(
dim=4096,
-n_layers=32,
+n_layers=3,
n_heads=32,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
5 changes: 5 additions & 0 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -298,6 +298,11 @@ def apply_fsdp(
reshard_after_forward=reshard_after_forward,
)
for layer_id, transformer_block in model.layers.items():
+fully_shard(
+    transformer_block.feed_forward if not hasattr(transformer_block, "_checkpoint_wrapped_module") else transformer_block._checkpoint_wrapped_module.feed_forward,
+    **fsdp_config,
+    reshard_after_forward=reshard_after_forward,
+)
fully_shard(
transformer_block,
**fsdp_config,
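Note on ordering: with FSDP2, a nested fully_shard call claims its module's parameters first, and the outer call only groups what remains, so sharding feed_forward before the enclosing transformer_block gives the FFN its own all-gather/reshard unit. The standalone sketch below (a toy block, not torchtitan code; it assumes the FSDP2 fully_shard API and an initialized process group, e.g. under torchrun) illustrates that order.

import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard


class ToyBlock(nn.Module):
    # Stand-in for a transformer block with a separate feed_forward submodule.
    def __init__(self, dim: int = 256) -> None:
        super().__init__()
        self.attention = nn.Linear(dim, dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.feed_forward(self.attention(x))


def shard_block(block: ToyBlock, reshard_after_forward: bool = True) -> ToyBlock:
    # Shard the inner module first: it becomes its own FSDP unit with its own
    # all-gather and optional reshard after forward.
    fully_shard(block.feed_forward, reshard_after_forward=reshard_after_forward)
    # The outer call then only manages the parameters not already claimed by a
    # nested fully_shard call (here, the attention projection).
    fully_shard(block, reshard_after_forward=reshard_after_forward)
    return block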
10 changes: 6 additions & 4 deletions torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -7,7 +7,9 @@ description = "Llama 3 8B training"
[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
-profile_freq = 100
+profile_freq = 10
+enable_memory_snapshot = true
+save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 10
@@ -30,9 +32,9 @@ warmup_steps = 200 # lr scheduler warm up

[training]
local_batch_size = 1
-seq_len = 8192
+seq_len = 4
max_norm = 1.0 # grad norm clipping
-steps = 1000
+steps = 100
dataset = "c4"

[parallelism]
@@ -55,7 +57,7 @@ enable=false
components = ["model", "loss"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
mode = "full" # ["none", "selective", "full"]
selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy

[float8]