6 changes: 6 additions & 0 deletions torchtitan/components/quantization/mx.py
@@ -53,12 +53,18 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):

        # Configure MXFP8
        from torchao.prototype.mx_formats.config import (
            MXFP8Dim0CastKernelChoice,
            MXFP8Dim1CastKernelChoice,
            MXLinearConfig as TorchAOMXLinearConfig,
        )

        mx_job_config = job_config.quantize.linear.mx
        config = TorchAOMXLinearConfig.from_recipe_name(mx_job_config.recipe_name)
        # pyrefly: ignore [missing-attribute]
        config.mxfp8_dim0_cast_kernel_choice = MXFP8Dim0CastKernelChoice[
            mx_job_config.mxfp8_dim0_cast_kernel_choice.upper()
        ]
        # pyrefly: ignore [missing-attribute]
        config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[
            mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
        ]
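As an aside (not part of this diff): the kernel selection above is just an enum lookup by member name, Enum[name.upper()]. A minimal self-contained sketch of that pattern, where CastKernelChoice and select_kernel are illustrative placeholders rather than torchao APIs:

from enum import Enum, auto

# Illustrative stand-in for torchao's MXFP8 cast-kernel-choice enums; the member
# names mirror the config strings but are placeholders, not the real class.
class CastKernelChoice(Enum):
    TRITON = auto()
    CUDA = auto()
    TORCH = auto()

def select_kernel(name: str) -> CastKernelChoice:
    # Same pattern as the diff above: uppercase the config string and index the
    # enum by member name; an unrecognized name raises a KeyError.
    try:
        return CastKernelChoice[name.upper()]
    except KeyError:
        valid = ", ".join(m.name.lower() for m in CastKernelChoice)
        raise ValueError(f"unknown cast kernel choice {name!r}; expected one of: {valid}") from None

print(select_kernel("triton"))  # CastKernelChoice.TRITON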
12 changes: 11 additions & 1 deletion torchtitan/config/job_config.py
@@ -757,7 +757,17 @@ class Float8GroupedMM:

@dataclass
class MXLinear:
    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
    mxfp8_dim0_cast_kernel_choice: Literal["triton", "torch"] = "triton"
Contributor comment: default recipe is mxfp8_cublas for which torch kernel is recommended -- can we make the default consistent?

"""
Temp work around for inductor performance gap.

* triton is recommended for best performance for recipe "mxfp8_cublas_rceil" (rceil scale rounding mode)
* torch is recommended for best performance for recipe "mxfp8_cublas" (floor scale rounding mode)

Example: --quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="torch"
"""

    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "cuda"
Contributor comment: do we need to keep triton and torch? If so, when are they recommended?

"""
Temp work around for inductor performance gap.

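For clarity (not part of this PR), the recommendation in the docstring above can be read as a simple recipe-to-kernel mapping. A hedged sketch, assuming only the two recipes named in the docstring and falling back to the config default for everything else:

# Illustrative helper, not torchtitan code: pick the dim0 cast kernel that the
# docstring above recommends for a given MX recipe name.
def recommended_dim0_cast_kernel(recipe_name: str) -> str:
    if recipe_name == "mxfp8_cublas_rceil":
        # rceil scale rounding mode -> triton is currently fastest
        return "triton"
    if recipe_name == "mxfp8_cublas":
        # floor scale rounding mode -> torch is currently fastest
        return "torch"
    # Other recipes: fall back to the dataclass default ("triton").
    return "triton"

assert recommended_dim0_cast_kernel("mxfp8_cublas") == "torch"

On the command line this corresponds to the docstring's own example, e.g. --quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="torch" when training with the "mxfp8_cublas" recipe.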
4 changes: 0 additions & 4 deletions torchtitan/models/llama4/infra/parallelize.py
@@ -639,10 +639,6 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b
        # by wrapping each submod's forward instead of their __call__
        moe = submod
        for attr_name, submod in moe.named_children():
            if attr_name == "experts":
Contributor comment: is this issue gone?
                # NOTE: We don't compile token dispatch and token combine due to an issue on B200:
                # https://github.com/pytorch/torchtitan/issues/1940
                continue
            setattr(
                moe,
                attr_name,
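For context, a minimal sketch of the surrounding pattern (assumed shape, not the exact torchtitan implementation): with the "experts" skip removed, every child of the MoE module is compiled and swapped back into the parent.

import torch
import torch.nn as nn

def compile_moe_children(moe: nn.Module) -> None:
    # Iterate a snapshot of the children so replacing attributes is safe,
    # then swap each child for its torch.compile-wrapped version.
    for attr_name, child in list(moe.named_children()):
        setattr(moe, attr_name, torch.compile(child))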