
Conversation

@danielvegamyhre (Contributor) commented on Jan 18, 2026

Stacked PRs:


[mxfp8 training] add new configurable params now exposed by torchao

Summary

Expose a job config param to control the mxfp8 dim0 quantization kernel used in torchao (see the illustrative sketch after the options list).

  • dim0 mxfp8 quantization kernel options
    • torch: ~90% memory bandwidth utilization (requires torch.compile for perf)
    • triton: ~82% memory bandwidth utilization (does not require torch.compile)
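
To make the trade-off concrete, here is a minimal, purely illustrative sketch (the helper name and signature are hypothetical, not part of this PR; the real choice is made via the job config param):

```python
from typing import Literal

KernelChoice = Literal["triton", "torch"]


def pick_dim0_cast_kernel(compile_enabled: bool) -> KernelChoice:
    # "torch" reaches ~90% memory bandwidth utilization, but only under
    # torch.compile; eager runs are better served by "triton" (~82% MBU,
    # no compile requirement).
    return "torch" if compile_enabled else "triton"
```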

Tests

(This test depends on the full PR stack.)

CONFIG_FILE=/home/dev/torchtitan/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml ./run_train.sh --metrics.log_freq=10 \
--training.steps=1500  \
--parallelism.data_parallel_shard_degree=4 \
--parallelism.expert_parallel_degree=4 \
--parallelism.tensor_parallel_degree=2 \
--parallelism.expert_tensor_parallel_degree=1 \
--training.seq_len=8192 \
--activation_checkpoint.mode=full \
--model.print_after_conversion \
--training.local_batch_size=16 \
--quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="triton" --quantize.linear.mx.mxfp8_dim1_cast_kernel_choice="cuda" \
--quantize.grouped_mm.mx.fqns="experts" --quantize.grouped_mm.mx.recipe_name="mxfp8_wgrad_with_hp" \
--compile.enable --compile.components="model,loss" --debug.moe_force_load_balance \
--model.converters="quantize.grouped_mm.mx"

danielvegamyhre added a commit that referenced this pull request Jan 18, 2026
stack-info: PR: #2251, branch: danielvegamyhre/stack/5
@danielvegamyhre force-pushed the danielvegamyhre/stack/5 branch from cdce43e to a071d1c on January 18, 2026 at 23:21
@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Jan 18, 2026
@danielvegamyhre changed the title from "[mxfp8 training] add new configurable params now exposed by torchao" to "[mxfp8 training] add mxfp8 dim0 cast kernel choice to MXLinear" on Jan 18, 2026
    Example: --quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="torch"
    """

    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
Contributor:
I might have asked before but forgot -- if CUDA is recommended, why do we set triton as the default?

Contributor Author:
We had triton as the default because previously using the CUDA kernels required building torchao from source. Now that the CUDA kernels are shipped in the wheels, we could actually update the default to CUDA.

@danielvegamyhre marked this pull request as draft on January 20, 2026 at 20:40
@danielvegamyhre changed the base branch from danielvegamyhre/stack/4 to main on January 20, 2026 at 20:40
danielvegamyhre added a commit that referenced this pull request Jan 20, 2026
stack-info: PR: #2251, branch: danielvegamyhre/stack/5
@danielvegamyhre force-pushed the danielvegamyhre/stack/5 branch from a071d1c to 84cf472 on January 20, 2026 at 20:40
@danielvegamyhre changed the title from "[mxfp8 training] add mxfp8 dim0 cast kernel choice to MXLinear" back to "[mxfp8 training] add new configurable params now exposed by torchao" on Jan 20, 2026
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/4 on January 20, 2026 at 20:40
@danielvegamyhre marked this pull request as ready for review on January 20, 2026 at 20:40

@dataclass
class MXLinear:
    mxfp8_dim0_cast_kernel_choice: Literal["triton", "torch"] = "torch"
Contributor:
Describe why we need multiple options and when each is applicable, e.g. the torch kernel needs compile enabled, etc.
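
For illustration, a hedged sketch of what such documentation might look like (field names and defaults are taken from the diff; the comment wording and the performance notes from the PR summary are illustrative, not the final text):

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class MXLinear:
    # Quantization kernel for the dim0 cast:
    #   "torch":  ~90% memory bandwidth utilization, but requires
    #             torch.compile to reach that performance.
    #   "triton": ~82% memory bandwidth utilization; works in eager
    #             mode, no torch.compile required.
    mxfp8_dim0_cast_kernel_choice: Literal["triton", "torch"] = "torch"

    # Quantization kernel for the dim1 cast; "cuda" is the recommended
    # option now that the CUDA kernels ship in the torchao wheels.
    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
```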

    Example: --quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="torch"
    """

    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
Contributor:
Similar comment. Also, would you like to change the default to cuda?

@danielvegamyhre marked this pull request as draft on January 20, 2026 at 22:49
@danielvegamyhre changed the base branch from danielvegamyhre/stack/4 to main on January 20, 2026 at 22:49
danielvegamyhre added a commit that referenced this pull request Jan 20, 2026
stack-info: PR: #2251, branch: danielvegamyhre/stack/5
@danielvegamyhre force-pushed the danielvegamyhre/stack/5 branch from 84cf472 to 67f4155 on January 20, 2026 at 22:49
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/4 on January 20, 2026 at 22:50
@danielvegamyhre marked this pull request as ready for review on January 20, 2026 at 22:50
@danielvegamyhre marked this pull request as draft on January 21, 2026 at 04:22
@danielvegamyhre changed the base branch from danielvegamyhre/stack/4 to main on January 21, 2026 at 04:22
danielvegamyhre added a commit that referenced this pull request Jan 21, 2026
stack-info: PR: #2251, branch: danielvegamyhre/stack/5
@danielvegamyhre force-pushed the danielvegamyhre/stack/5 branch from 67f4155 to 3289e3c on January 21, 2026 at 04:22
@danielvegamyhre changed the base branch from danielvegamyhre/stack/4 to main on January 21, 2026 at 21:51
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/4 on January 21, 2026 at 21:51
@danielvegamyhre marked this pull request as ready for review on January 21, 2026 at 21:52
@dataclass
class MXLinear:
    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
    mxfp8_dim0_cast_kernel_choice: Literal["triton", "torch"] = "triton"
Contributor:
The default recipe is mxfp8_cublas, for which the torch kernel is recommended -- can we make the default consistent?

    Example: --quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="torch"
    """

    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "cuda"
Contributor:
Do we need to keep triton and torch? If so, when are they recommended?

# by wrapping each submod's forward instead of their __call__
moe = submod
for attr_name, submod in moe.named_children():
    if attr_name == "experts":
Contributor:
is this issue gone?
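
For context, a minimal self-contained sketch of the forward-wrapping pattern this hunk refers to (the function name and the cast placeholder are hypothetical, not the PR's actual implementation):

```python
import torch.nn as nn


def wrap_experts_forward(moe: nn.Module) -> None:
    # Patch `forward` on the experts submodule rather than `__call__`,
    # so nn.Module.__call__ (and any registered hooks) still run around
    # the wrapper.
    for attr_name, submod in moe.named_children():
        if attr_name == "experts":
            orig_forward = submod.forward  # bound method

            def wrapped_forward(*args, _orig=orig_forward, **kwargs):
                # A real implementation would quantize/cast inputs here
                # before delegating to the original expert computation.
                return _orig(*args, **kwargs)

            submod.forward = wrapped_forward
```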

@danielvegamyhre marked this pull request as draft on February 3, 2026 at 20:03
@danielvegamyhre changed the base branch from danielvegamyhre/stack/4 to main on February 3, 2026 at 20:03
@danielvegamyhre force-pushed the danielvegamyhre/stack/5 branch from 3289e3c to 9045425 on February 3, 2026 at 20:03
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/4 on February 3, 2026 at 20:03
@danielvegamyhre marked this pull request as ready for review on February 3, 2026 at 20:03
@danielvegamyhre marked this pull request as draft on February 4, 2026 at 18:49
@danielvegamyhre changed the base branch from danielvegamyhre/stack/4 to main on February 4, 2026 at 18:49
@danielvegamyhre force-pushed the danielvegamyhre/stack/5 branch from 9045425 to 57646ab on February 4, 2026 at 18:49
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/4 on February 4, 2026 at 18:49
@danielvegamyhre marked this pull request as ready for review on February 4, 2026 at 18:50
@danielvegamyhre marked this pull request as draft on February 4, 2026 at 18:56
@danielvegamyhre changed the base branch from danielvegamyhre/stack/4 to main on February 4, 2026 at 18:56
@danielvegamyhre force-pushed the danielvegamyhre/stack/5 branch from 57646ab to d6c15de on February 4, 2026 at 18:56
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/4 on February 4, 2026 at 18:56
@danielvegamyhre marked this pull request as ready for review on February 4, 2026 at 18:56
@danielvegamyhre marked this pull request as draft on February 4, 2026 at 19:16
@danielvegamyhre changed the base branch from danielvegamyhre/stack/4 to main on February 4, 2026 at 19:16
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/4 on February 4, 2026 at 19:16
@danielvegamyhre marked this pull request as ready for review on February 4, 2026 at 19:16
stack-info: PR: #2251, branch: danielvegamyhre/stack/5
@danielvegamyhre marked this pull request as draft on February 6, 2026 at 22:59
@danielvegamyhre changed the base branch from danielvegamyhre/stack/4 to main on February 6, 2026 at 22:59
@danielvegamyhre force-pushed the danielvegamyhre/stack/5 branch from d6c15de to e985503 on February 6, 2026 at 22:59
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/4 on February 6, 2026 at 23:00
@danielvegamyhre marked this pull request as ready for review on February 6, 2026 at 23:00
Labels

ciflow/8gpu, CLA Signed