diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py
index cda64a281d..db8df37fae 100644
--- a/torchtitan/components/quantization/mx.py
+++ b/torchtitan/components/quantization/mx.py
@@ -53,12 +53,18 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
 
         # Configure MXFP8
         from torchao.prototype.mx_formats.config import (
+            MXFP8Dim0CastKernelChoice,
             MXFP8Dim1CastKernelChoice,
             MXLinearConfig as TorchAOMXLinearConfig,
         )
 
         mx_job_config = job_config.quantize.linear.mx
         config = TorchAOMXLinearConfig.from_recipe_name(mx_job_config.recipe_name)
+        # pyrefly: ignore [missing-attribute]
+        config.mxfp8_dim0_cast_kernel_choice = MXFP8Dim0CastKernelChoice[
+            mx_job_config.mxfp8_dim0_cast_kernel_choice.upper()
+        ]
+        # pyrefly: ignore [missing-attribute]
         config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[
             mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
         ]
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 53bda77754..461f0da86f 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -757,7 +757,17 @@ class Float8GroupedMM:
 
 @dataclass
 class MXLinear:
-    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
+    mxfp8_dim0_cast_kernel_choice: Literal["triton", "torch"] = "triton"
+    """
+    Temp work around for inductor performance gap.
+
+    * triton is recommended for best performance for recipe "mxfp8_cublas_rceil" (rceil scale rounding mode)
+    * torch is recommended for best performance for recipe "mxfp8_cublas" (floor scale rounding mode)
+
+    Example: --quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="torch"
+    """
+
+    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "cuda"
     """
     Temp work around for inductor performance gap.
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 8324080389..9325a3cf4e 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -639,10 +639,6 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b
         # by wrapping each submod's forward instead of their __call__
         moe = submod
         for attr_name, submod in moe.named_children():
-            if attr_name == "experts":
-                # NOTE: We don't compile token dispatch and token combine due to an issue on B200:
-                # https://github.com/pytorch/torchtitan/issues/1940
-                continue
             setattr(
                 moe,
                 attr_name,
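
Note for reviewers: the new dim0 wiring mirrors the existing dim1 path — the lowercase string stored in JobConfig is resolved to a torchao enum member by name via Enum[name.upper()]. Below is a minimal, self-contained sketch of that lookup pattern using only the stdlib enum module; KernelChoice is a hypothetical stand-in for torchao's MXFP8Dim0CastKernelChoice, not the real class.

from enum import Enum

class KernelChoice(Enum):
    # Hypothetical stand-in for MXFP8Dim0CastKernelChoice; the real enum
    # lives in torchao.prototype.mx_formats.config.
    TRITON = "triton"
    TORCH = "torch"

# JobConfig holds the lowercase literal (e.g. from
# --quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="torch");
# indexing the Enum by the uppercased string selects the member by name.
user_value = "torch"
choice = KernelChoice[user_value.upper()]
assert choice is KernelChoice.TORCH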