 from torchtitan.config.job_config import JobConfig
 from torchtitan.config import TORCH_DTYPE_MAP
-from torchtitan.distributed import ParallelDims
+from torchtitan.distributed import ParallelDims, NoParallel
+from torchtitan.distributed.expert_parallel import ExpertParallel, ReordererSequenceParallel
 from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
+from torchtitan.experiments.llama4.infra.parallelize import (
+    apply_fsdp,
+    apply_moe_ep_tp,
+)
+
 from torchtitan.tools.logging import logger
 
-from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Partial, Replicate, Shard
 
 from .expert_parallel import (
-    ExpertParallel,
     ExpertTensorParallel,
     TensorParallel,
 )
-from torchtitan.distributed import NoParallel
 
 
 # for selective op activation checkpointing
-_save_list = {
+_op_sac_save_list = {
     torch.ops.aten.mm.default,
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
@@ -87,7 +90,7 @@ def parallelize_gptoss(
         apply_non_moe_tp(
             model,
             world_mesh["tp"],
-            loss_parallel=parallel_dims.loss_parallel_enabled,
+            loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=False,
             enable_async_tp=False,
         )
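For context, loss parallel keeps the final output sharded on the vocab dimension and computes cross-entropy directly on that sharded DTensor instead of all-gathering the logits first. The trainer-side wiring is outside this diff; the sketch below only illustrates the mechanism the flag gates, with compute_loss, pred, and labels as illustrative names.

import contextlib

import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel


def compute_loss(pred, labels, loss_parallel_enabled: bool):
    # Under loss parallel, `pred` stays a DTensor sharded on the vocab dim and
    # cross-entropy runs without materializing the full logits on each rank.
    ctx = loss_parallel() if loss_parallel_enabled else contextlib.nullcontext()
    with ctx:
        return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))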
@@ -99,9 +102,10 @@ def parallelize_gptoss(
             ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
             ep_tp_mesh=(
                 world_mesh["ep", "tp"]
-                if parallel_dims.tp_enabled and parallel_dims.ep_enabled
+                if parallel_dims.tp_enabled and parallel_dims.ep_enabled and parallel_dims.etp_enabled
                 else None
             ),
+            etp_enabled=parallel_dims.etp_enabled,
         )
 
     model_compile_enabled = (
@@ -114,7 +118,7 @@ def parallelize_gptoss(
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
             use_flex_attn=use_flex_attn,
-            save_list=_save_list,
+            save_list=_op_sac_save_list,
         )
 
     dp_mesh: DeviceMesh | None = None
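For reference, a save-list like _op_sac_save_list is how selective op activation checkpointing (SAC) decides which op outputs to keep and which to recompute during backward. The actual wiring lives in apply_ac (imported from the llama3 parallelize module, not shown in this diff); the sketch below is a minimal illustration of that pattern using PyTorch's create_selective_checkpoint_contexts, with sac_forward and the trimmed save list as illustrative stand-ins rather than torchtitan's implementation.

from functools import partial

import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

_op_sac_save_list = {torch.ops.aten.mm.default}  # trimmed down for the sketch


def _sac_policy(ctx, op, *args, **kwargs):
    # Keep the outputs of the listed (expensive) ops; recompute everything else.
    return (
        CheckpointPolicy.MUST_SAVE
        if op in _op_sac_save_list
        else CheckpointPolicy.PREFER_RECOMPUTE
    )


def sac_forward(block, *inputs):
    # use_reentrant=False is required when passing a context_fn.
    return checkpoint(
        block,
        *inputs,
        use_reentrant=False,
        context_fn=partial(create_selective_checkpoint_contexts, _sac_policy),
    )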
@@ -263,83 +267,18 @@ def apply_non_moe_tp(
     )
 
 
-def apply_fsdp(
-    model: nn.Module,
-    dp_mesh: DeviceMesh,
-    param_dtype: torch.dtype,
-    reduce_dtype: torch.dtype,
-    pp_enabled: bool,
-    cpu_offload: bool = False,
-    reshard_after_forward_policy: str = "default",
-    dp_mod_ep_mesh: DeviceMesh | None = None,
-):
-    """
-    Apply data parallelism (via FSDP2) to the model.
-
-    Args:
-        model (nn.Module): The model to apply data parallelism to.
-        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
-        param_dtype (torch.dtype): The data type to use for model parameters.
-        reduce_dtype (torch.dtype): The data type to use for reduction operations.
-        pp_enabled (bool): Whether pipeline parallelism is enabled.
-        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
-        reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default".
-            Other options: "never", "always".
-            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
-            - "always" will enable `reshard_after_forward` for all forward passes.
-            - "never" will disable `reshard_after_forward` for all forward passes.
-
-    """
-    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
-    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-    if cpu_offload:
-        fsdp_config["offload_policy"] = CPUOffloadPolicy()
-
-    for layer_id, transformer_block in model.layers.items():
-        if reshard_after_forward_policy == "always":
-            reshard_after_forward = True
-        elif reshard_after_forward_policy == "never":
-            reshard_after_forward = False
-        elif reshard_after_forward_policy == "default":
-            if pp_enabled:
-                # For PP, do not reshard after forward to avoid per-microbatch
-                # all-gathers, which can be expensive and non-overlapped
-                reshard_after_forward = False
-            else:
-                # As an optimization, do not reshard after forward for the last
-                # transformer block since FSDP would prefetch it immediately
-                reshard_after_forward = int(layer_id) < len(model.layers) - 1
-        else:
-            raise ValueError(
-                f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
-            )
-
-        # NOTE: in an MoE layer, the router and the shared experts
-        # are sharded together with the TransformerBlock
-        if dp_mod_ep_mesh:
-            fsdp_mod_ep_config = fsdp_config.copy()
-            fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh
-            fully_shard(
-                transformer_block.moe.experts,
-                **fsdp_mod_ep_config,
-                reshard_after_forward=reshard_after_forward,
-            )
-
-        fully_shard(
-            transformer_block,
-            **fsdp_config,
-            reshard_after_forward=reshard_after_forward,
-        )
-    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
-
-
+# NOTE(jianiw): This function cannot be reused for now because ExpertTensorParallel has been reimplemented here.
 def apply_moe_ep_tp(
     model: nn.Module,
     tp_mesh: DeviceMesh | None,
     ep_mesh: DeviceMesh | None,
     ep_tp_mesh: DeviceMesh | None,
+    etp_enabled: bool,
 ):
     for transformer_block in model.layers.values():
+        if not transformer_block.moe_enabled:
+            continue
+
         if tp_mesh is not None:
             moe_layer_plan = {
                 # input / output sharding on the seqlen dim
@@ -354,13 +293,28 @@ def apply_moe_ep_tp(
                 # replicate computation for the router
                 "moe.router.gate": NoParallel(),
             }
+            if ep_mesh is not None and not etp_enabled:
+                # If TP is borrowed for EP, then split the tokens across TP ranks so that
+                # the reorderer, the all-to-all comms, and routed experts computation
+                # are effectively running Sequence Parallel (split along the folded bs*slen dim)
+                moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()})
+            if transformer_block.moe.shared_experts is not None:
+                # input Replicate, output Partial
+                moe_layer_plan.update(
+                    {
+                        "moe.shared_experts.w1": ColwiseParallel(),
+                        "moe.shared_experts.w2": RowwiseParallel(
+                            output_layouts=Partial()
+                        ),
+                        "moe.shared_experts.w3": ColwiseParallel(),
+                    }
+                )
             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
                 parallelize_plan=moe_layer_plan,
             )
 
-        # if ep_mesh is not None:
         experts_mesh, experts_plan = None, None
         if ep_mesh is None:
             experts_mesh = tp_mesh
@@ -370,9 +324,13 @@ def apply_moe_ep_tp(
             experts_mesh = ep_mesh
             # input / output sharding on the batch / tokens dim
             experts_plan = ExpertParallel()
-        else:
+        elif etp_enabled:
             experts_mesh = ep_tp_mesh
             experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
+        else:
+            experts_mesh = ep_mesh
+            experts_plan = ExpertParallel()
+
         parallelize_module(
             module=transformer_block.moe.experts,
             device_mesh=experts_mesh,
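To summarize the four-way dispatch above, the helper below restates it as a standalone function. It is a hypothetical sketch, not part of the patch: it assumes the module-level imports shown earlier (TensorParallel, ExpertParallel, ExpertTensorParallel), and the TP-only branch's TensorParallel() plan is inferred from those imports since that line falls between the hunks shown here.

def _experts_mesh_and_plan(tp_mesh, ep_mesh, ep_tp_mesh, etp_enabled):
    # Which mesh and ParallelStyle the routed experts get for each TP/EP/ETP combination.
    if ep_mesh is None:
        # TP only: plain tensor parallelism over the experts (assumed branch, see lead-in)
        return tp_mesh, TensorParallel()
    if tp_mesh is None:
        # EP only: shard experts and tokens over the EP mesh
        return ep_mesh, ExpertParallel()
    if etp_enabled:
        # EP + TP with expert tensor parallel: 2D sharding over the ep x tp mesh
        return ep_tp_mesh, ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
    # EP + TP with ETP disabled: TP ranks are borrowed for EP, so experts see only the EP mesh
    return ep_mesh, ExpertParallel()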