[mxfp8 training] add new configurable params now exposed by torchao #2251
base: danielvegamyhre/stack/4
Conversation
stack-info: PR: #2251, branch: danielvegamyhre/stack/5
Force-pushed from cdce43e to a071d1c
torchtitan/config/job_config.py (Outdated)
Example: --quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="torch"
"""

mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
I might have asked before but forgot -- if CUDA is recommended, why do we put triton as the default?
We had triton as the default because previously using the CUDA kernels required building torchao from source. Now that the CUDA kernels are shipped in the wheels, we could actually update the default to CUDA.
stack-info: PR: #2251, branch: danielvegamyhre/stack/5
Force-pushed from a071d1c to 84cf472
torchtitan/config/job_config.py (Outdated)
@dataclass
class MXLinear:
    mxfp8_dim0_cast_kernel_choice: Literal["triton", "torch"] = "torch"
Describe why we need multiple options and when each is applicable, e.g. needing to enable compile for the torch kernel, etc.
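One way the requested documentation could look, as a sketch only (the wording is illustrative, and the bandwidth numbers are taken from this PR's summary rather than from torchao docs):

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class MXLinear:
    mxfp8_dim0_cast_kernel_choice: Literal["triton", "torch"] = "torch"
    """
    Kernel used for the mxfp8 dim0 cast.
    - "torch": higher memory bandwidth utilization (~90%), but requires
      torch.compile to reach that performance.
    - "triton": somewhat lower memory bandwidth utilization (~82%), and
      does not require torch.compile.
    Example: --quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="torch"
    """
```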
torchtitan/config/job_config.py (Outdated)
Example: --quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="torch"
"""

mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
Similar comment here. Also, would you like to change the default to cuda?
stack-info: PR: #2251, branch: danielvegamyhre/stack/5
Force-pushed from 84cf472 to 67f4155
stack-info: PR: #2251, branch: danielvegamyhre/stack/5
Force-pushed from 67f4155 to 3289e3c
@dataclass
class MXLinear:
    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
    mxfp8_dim0_cast_kernel_choice: Literal["triton", "torch"] = "triton"
The default recipe is mxfp8_cublas, for which the torch kernel is recommended -- can we make the default consistent?
Example: --quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="torch"
"""

mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "cuda"
Do we need to keep triton and torch? If so, when is each recommended?
# by wrapping each submod's forward instead of their __call__
moe = submod
for attr_name, submod in moe.named_children():
    if attr_name == "experts":
Is this issue gone?
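For context, a minimal sketch of the pattern the snippet above is part of: walking an MoE module's children and wrapping the relevant submodule's forward. The wrap_forward helper and the exact handling inside it are assumptions for illustration, not the actual torchtitan implementation.

```python
import torch.nn as nn


def wrap_forward(mod: nn.Module) -> None:
    """Wrap mod.forward in place; assumed hook point for the mxfp8 conversion."""
    orig_forward = mod.forward

    def wrapped(*args, **kwargs):
        # conversion-specific pre/post processing would go here (assumption)
        return orig_forward(*args, **kwargs)

    # wrap forward rather than __call__, matching the comment in the diff above
    mod.forward = wrapped


def wrap_moe(moe: nn.Module) -> None:
    for attr_name, submod in moe.named_children():
        if attr_name == "experts":
            # only the experts submodule gets wrapped; other children
            # (router, shared experts, ...) are left untouched
            wrap_forward(submod)
```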
Force-pushed from 3289e3c to 9045425
Force-pushed from 9045425 to 57646ab
Force-pushed from 57646ab to d6c15de
stack-info: PR: #2251, branch: danielvegamyhre/stack/5
Force-pushed from d6c15de to e985503
Stacked PRs:
[mxfp8 training] add new configurable params now exposed by torchao
Summary
Expose a job config param to control the mxfp8 dim0 quantization kernel used in torchao.
- torch: ~90% memory bandwidth utilization (requires torch.compile for perf)
- triton: ~82% memory bandwidth utilization (does not require torch.compile)

Tests
(test depends on full PR stack)
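As a usage note, the new dim0 param can be set from the command line with the flag shown in the diffs above (example flags only; the surrounding training command is omitted):

```
# torch kernel: best observed bandwidth, needs torch.compile enabled
--quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="torch"

# triton kernel: no torch.compile required
--quantize.linear.mx.mxfp8_dim0_cast_kernel_choice="triton"
```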