
Commit ca57b78

FSDP works, TP doesn't work
1 parent 4fc78a3 commit ca57b78


6 files changed (+51, -87 lines)

torchtitan/distributed/utils.py

Lines changed: 1 addition & 1 deletion

@@ -186,7 +186,7 @@ def create_context_parallel_ctx(


 def get_train_context(
-    enable_loss_parallel: bool, enable_compiled_autograd: bool, use_sdpa: bool = True
+    enable_loss_parallel: bool, enable_compiled_autograd: bool
 ) -> Generator[None, None, None]:
     @contextlib.contextmanager
     def context(cp_context: Generator[None, None, None] | None = None):
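
For reference, here is a minimal sketch of how a train context with the new two-argument signature can compose its optional sub-contexts (loss parallel, compiled autograd, and an optional context-parallel context) through an ExitStack. This is an illustrative stand-in with nullcontext placeholders, not torchtitan's actual implementation.

import contextlib
from typing import Generator


def get_train_context(
    enable_loss_parallel: bool, enable_compiled_autograd: bool
) -> Generator[None, None, None]:
    @contextlib.contextmanager
    def context(cp_context: Generator[None, None, None] | None = None):
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                # placeholder for e.g. torch.distributed.tensor.parallel.loss_parallel()
                stack.enter_context(contextlib.nullcontext())
            if enable_compiled_autograd:
                # placeholder for a compiled-autograd enabling context
                stack.enter_context(contextlib.nullcontext())
            if cp_context is not None:
                stack.enter_context(cp_context)
            yield

    return context


# Usage (illustrative): build the callable once, enter it per training step.
train_context = get_train_context(enable_loss_parallel=True, enable_compiled_autograd=False)
with train_context():
    pass  # forward / backward would run here

Dropping the use_sdpa flag means the call site in torchtitan/train.py (see the last diff below) now passes only the two remaining arguments.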

torchtitan/experiments/gpt_oss/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@
             route_scale=1.0,
             score_before_experts=False,
             top_k=4,
-            use_grouped_mm=False,
+            use_grouped_mm=True,
             load_balance_coeff=1e-3,
         ),
         use_flex_attn=True,
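
Flipping use_grouped_mm to True switches the routed experts from a per-expert loop of small matmuls to a single grouped computation over all experts' tokens. The contrast below is a self-contained illustration with plain torch ops (shapes and names are invented; torchtitan's grouped path uses a fused grouped-GEMM kernel rather than the padded bmm emulation shown here).

import torch

num_experts, dim, hidden = 4, 8, 16
tokens_per_expert = torch.tensor([3, 5, 2, 6])         # tokens routed to each expert
x = torch.randn(int(tokens_per_expert.sum()), dim)     # token activations, sorted by expert
w = torch.randn(num_experts, dim, hidden)              # one weight matrix per expert

# use_grouped_mm=False: loop over experts, one small GEMM each
outs, start = [], 0
for e, n in enumerate(tokens_per_expert.tolist()):
    outs.append(x[start : start + n] @ w[e])
    start += n
looped = torch.cat(outs)

# use_grouped_mm=True (conceptually): one batched computation over all experts.
# Emulated here by padding each expert's tokens to a common length and using bmm;
# a real grouped GEMM avoids the padding entirely.
max_n = int(tokens_per_expert.max())
padded = x.new_zeros(num_experts, max_n, dim)
start = 0
for e, n in enumerate(tokens_per_expert.tolist()):
    padded[e, :n] = x[start : start + n]
    start += n
grouped = torch.bmm(padded, w)                          # (num_experts, max_n, hidden)

# The unpadded rows match the looped result.
unpadded = torch.cat([grouped[e, :n] for e, n in enumerate(tokens_per_expert.tolist())])
print(torch.allclose(looped, unpadded, atol=1e-5))      # True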

torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 38 additions & 80 deletions

@@ -17,23 +17,26 @@

 from torchtitan.config.job_config import JobConfig
 from torchtitan.config import TORCH_DTYPE_MAP
-from torchtitan.distributed import ParallelDims
+from torchtitan.distributed import ParallelDims, NoParallel
+from torchtitan.distributed.expert_parallel import ExpertParallel, ReordererSequenceParallel
 from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
+from torchtitan.experiments.llama4.infra.parallelize import (
+    apply_fsdp,
+    apply_moe_ep_tp,
+)
+
 from torchtitan.tools.logging import logger

-from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Partial, Replicate, Shard

 from .expert_parallel import (
-    ExpertParallel,
     ExpertTensorParallel,
     TensorParallel,
 )
-from torchtitan.distributed import NoParallel


 # for selective op activation checkpointing
-_save_list = {
+_op_sac_save_list = {
     torch.ops.aten.mm.default,
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,

@@ -87,7 +90,7 @@ def parallelize_gptoss(
         apply_non_moe_tp(
             model,
             world_mesh["tp"],
-            loss_parallel=parallel_dims.loss_parallel_enabled,
+            loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=False,
             enable_async_tp=False,
         )

@@ -99,9 +102,10 @@ def parallelize_gptoss(
             ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
             ep_tp_mesh=(
                 world_mesh["ep", "tp"]
-                if parallel_dims.tp_enabled and parallel_dims.ep_enabled
+                if parallel_dims.tp_enabled and parallel_dims.ep_enabled and parallel_dims.etp_enabled
                 else None
             ),
+            etp_enabled=parallel_dims.etp_enabled,
         )

     model_compile_enabled = (

@@ -114,7 +118,7 @@ def parallelize_gptoss(
             job_config.activation_checkpoint,
             model_compile_enabled=model_compile_enabled,
             use_flex_attn=use_flex_attn,
-            save_list=_save_list,
+            save_list=_op_sac_save_list,
         )

     dp_mesh: DeviceMesh | None = None

@@ -263,83 +267,18 @@ def apply_non_moe_tp(
     )


-def apply_fsdp(
-    model: nn.Module,
-    dp_mesh: DeviceMesh,
-    param_dtype: torch.dtype,
-    reduce_dtype: torch.dtype,
-    pp_enabled: bool,
-    cpu_offload: bool = False,
-    reshard_after_forward_policy: str = "default",
-    dp_mod_ep_mesh: DeviceMesh | None = None,
-):
-    """
-    Apply data parallelism (via FSDP2) to the model.
-
-    Args:
-        model (nn.Module): The model to apply data parallelism to.
-        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
-        param_dtype (torch.dtype): The data type to use for model parameters.
-        reduce_dtype (torch.dtype): The data type to use for reduction operations.
-        pp_enabled (bool): Whether pipeline parallelism is enabled.
-        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
-        reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default".
-            Other options: "never", "always".
-            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
-            - "always" will enable `reshard_after_forward` for all forward passes.
-            - "never" will disable `reshard_after_forward` for all forward passes.
-
-    """
-    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
-    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-    if cpu_offload:
-        fsdp_config["offload_policy"] = CPUOffloadPolicy()
-
-    for layer_id, transformer_block in model.layers.items():
-        if reshard_after_forward_policy == "always":
-            reshard_after_forward = True
-        elif reshard_after_forward_policy == "never":
-            reshard_after_forward = False
-        elif reshard_after_forward_policy == "default":
-            if pp_enabled:
-                # For PP, do not reshard after forward to avoid per-microbatch
-                # all-gathers, which can be expensive and non-overlapped
-                reshard_after_forward = False
-            else:
-                # As an optimization, do not reshard after forward for the last
-                # transformer block since FSDP would prefetch it immediately
-                reshard_after_forward = int(layer_id) < len(model.layers) - 1
-        else:
-            raise ValueError(
-                f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
-            )
-
-        # NOTE: in an MoE layer, the router and the shared experts
-        # are sharded together with the TransformerBlock
-        if dp_mod_ep_mesh:
-            fsdp_mod_ep_config = fsdp_config.copy()
-            fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh
-            fully_shard(
-                transformer_block.moe.experts,
-                **fsdp_mod_ep_config,
-                reshard_after_forward=reshard_after_forward,
-            )
-
-        fully_shard(
-            transformer_block,
-            **fsdp_config,
-            reshard_after_forward=reshard_after_forward,
-        )
-    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
-
-
+# NOTE(jianiw): This function cannot be reused yet because ExpertTensorParallel has been reimplemented.
 def apply_moe_ep_tp(
     model: nn.Module,
     tp_mesh: DeviceMesh | None,
     ep_mesh: DeviceMesh | None,
     ep_tp_mesh: DeviceMesh | None,
+    etp_enabled: bool,
 ):
     for transformer_block in model.layers.values():
+        if not transformer_block.moe_enabled:
+            continue
+
         if tp_mesh is not None:
             moe_layer_plan = {
                 # input / output sharding on the seqlen dim

@@ -354,13 +293,28 @@ def apply_moe_ep_tp(
                 # replicate computation for the router
                 "moe.router.gate": NoParallel(),
             }
+            if ep_mesh is not None and not etp_enabled:
+                # If TP is borrowed for EP, then split the tokens across TP ranks so that
+                # the reorderer, the all-to-all comms, and routed experts computation
+                # are effectively running Sequence Parallel (split along the folded bs*slen dim)
+                moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()})
+            if transformer_block.moe.shared_experts is not None:
+                # input Replicate, output Partial
+                moe_layer_plan.update(
+                    {
+                        "moe.shared_experts.w1": ColwiseParallel(),
+                        "moe.shared_experts.w2": RowwiseParallel(
+                            output_layouts=Partial()
+                        ),
+                        "moe.shared_experts.w3": ColwiseParallel(),
+                    }
+                )
             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
                 parallelize_plan=moe_layer_plan,
             )

-        # if ep_mesh is not None:
         experts_mesh, experts_plan = None, None
         if ep_mesh is None:
             experts_mesh = tp_mesh

@@ -370,9 +324,13 @@ def apply_moe_ep_tp(
             experts_mesh = ep_mesh
             # input / output sharding on the batch / tokens dim
             experts_plan = ExpertParallel()
-        else:
+        elif etp_enabled:
             experts_mesh = ep_tp_mesh
             experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
+        else:
+            experts_mesh = ep_mesh
+            experts_plan = ExpertParallel()
+
         parallelize_module(
             module=transformer_block.moe.experts,
             device_mesh=experts_mesh,
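
To summarize the new branching in apply_moe_ep_tp, the sketch below is a standalone decision table for which mesh and parallel style the routed experts end up on. The helper is illustrative only and not part of the commit; the plan used in the TP-only branch falls outside the shown hunks, so it is marked as an assumption.

# Illustrative decision table mirroring the if/elif/else structure above.
# The strings stand in for the actual DeviceMesh objects and ParallelStyle classes.
def pick_experts_parallelism(tp: bool, ep: bool, etp_enabled: bool) -> tuple[str, str]:
    if not ep:
        # TP only: experts live on the tp mesh (assumed TensorParallel plan,
        # which is imported above but assigned outside the shown hunks)
        return "tp_mesh", "TensorParallel"
    if not tp:
        # EP only: tokens sharded across experts on the ep mesh
        return "ep_mesh", "ExpertParallel"
    if etp_enabled:
        # EP + ETP: 2D sharding of experts over the combined ep x tp mesh
        return "ep_tp_mesh", "ExpertTensorParallel"
    # EP with TP "borrowed": plain ExpertParallel on the ep mesh, while the token
    # reorderer runs ReordererSequenceParallel as part of the TP layer plan
    return "ep_mesh", "ExpertParallel"


for combo in [(True, False, False), (False, True, False), (True, True, True), (True, True, False)]:
    print(combo, "->", pick_experts_parallelism(*combo))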

torchtitan/experiments/gpt_oss/model/model.py

Lines changed: 10 additions & 3 deletions

@@ -99,7 +99,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     )


-# TODO(jianw): This is eager version from HuggingFace
+# TODO(jianw): This is the eager version from HuggingFace. Remove it once FlexAttention is ready.
 def eager_attention_forward(
     query: torch.Tensor,
     key: torch.Tensor,

@@ -109,8 +109,15 @@ def eager_attention_forward(
     scaling: float,
     dropout: float = 0.0,
     **kwargs,
-):
-    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+):
+    key_values = key.transpose(2, 3)  # When TP is enabled, key should be shard()
+    print(f"key_values : {key_values.placements} {key_values.shape}")
+    print(f"query : {query.placements} {query.shape}")
+
+    # [rank0]:key_values : (Shard(dim=1),) torch.Size([8, 64, 64, 2048])
+    # [rank0]:query : (Shard(dim=1),) torch.Size([8, 64, 2048, 64])
+
+    attn_weights = query @ key_values * scaling
     if attention_mask is not None:
         # attention_mask can be [Tq, Tk] or [B, H, Tq, Tk]
         # Convert boolean "allowed" -> additive mask
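
The debug prints added above show both DTensors arriving sharded on dim 1 (the heads dimension) under TP. For context, here is a plain-tensor sketch of the eager scaled-dot-product attention this function computes; it is illustrative only, since the remainder of eager_attention_forward is not shown in this diff.

import torch
import torch.nn.functional as F


def eager_sdpa(query, key, value, attention_mask=None, scaling=1.0, dropout=0.0):
    # query / key / value: [batch, heads, seq, head_dim]
    attn_weights = query @ key.transpose(2, 3) * scaling        # [B, H, Tq, Tk]
    if attention_mask is not None:
        # boolean "allowed" mask -> -inf where attention is disallowed
        attn_weights = attn_weights.masked_fill(~attention_mask, float("-inf"))
    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = F.dropout(attn_weights, p=dropout, training=False)
    return attn_weights @ value                                  # [B, H, Tq, head_dim]


q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
print(eager_sdpa(q, k, v, scaling=16 ** -0.5).shape)             # torch.Size([2, 4, 8, 16])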

torchtitan/experiments/gpt_oss/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default"  # default / never / always
-tensor_parallel_degree = 1
+tensor_parallel_degree = 2
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
 context_parallel_degree = 1
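
With data_parallel_shard_degree = -1, the FSDP shard degree is inferred from whatever is left of the launch world size after the other degrees are applied. The helper below is hypothetical and only illustrates the arithmetic, assuming world_size = dp_replicate * dp_shard * cp * tp * pp.

def infer_dp_shard(world_size: int, dp_replicate: int, cp: int, tp: int, pp: int) -> int:
    # Hypothetical helper: solve for dp_shard given the other parallel degrees.
    denom = dp_replicate * cp * tp * pp
    assert world_size % denom == 0, "parallel degrees must divide the world size"
    return world_size // denom


# With tensor_parallel_degree bumped to 2, an 8-GPU debug run would shard FSDP over 4 ranks.
print(infer_dp_shard(world_size=8, dp_replicate=1, cp=1, tp=2, pp=1))  # 4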

torchtitan/train.py

Lines changed: 0 additions & 1 deletion

@@ -329,7 +329,6 @@ def __init__(self, job_config: JobConfig):
         self.train_context = dist_utils.get_train_context(
             loss_parallel_enabled,
             parallelism_config.enable_compiled_autograd,
-            use_sdpa=not getattr(model_args, "use_flex_attn", False),
         )
         self.maybe_enable_amp = dist_utils.maybe_enable_amp(
             parallel_dims,
