Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tensor wrapper subclass] Add support for torchao.float8 mlp #1585

Draft
wants to merge 1 commit into
base: tensor_subclass_2
Choose a base branch
from

Conversation

crcrpar
Copy link
Collaborator

@crcrpar crcrpar commented Dec 23, 2024

What does this PR do?

Multiple changes for thunder.jit to support a torchao.float8 MLP (see the test):

  • Add support of torch._scaled_mm
  • Update _general_jit_torch_autograd_function_apply_lookaside

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant