
Conversation

@asunderwood
Collaborator

Motivation

TritonBLAS aims to provide a torch-like API so that users can switch between PyTorch's own functions and tritonBLAS's with as little friction as possible. Ideally this includes support for key parts of the torch ecosystem: full autograd support for training, compatibility with torch.compile, and the ability to choose between in-place and out-of-place operations, all from one unified API.

Technical Details

This PR adds torch.compile support, autograd support, and in-place/out-of-place support to the matmul() function.

The PR also adds the addmm() function to support full GEMM operations (D = A @ B + C) with the same torch.compile, autograd, and in-place/out-of-place support.
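
For illustration, here is a minimal usage sketch. The `tritonblas` module name and exact call signatures are assumptions written to mirror torch.matmul and torch.addmm as described above, not the finalized API.

import torch
import tritonblas  # module name assumed for this sketch

a = torch.randn(512, 256, device="cuda", dtype=torch.float16, requires_grad=True)
b = torch.randn(256, 128, device="cuda", dtype=torch.float16, requires_grad=True)
c = torch.randn(512, 128, device="cuda", dtype=torch.float16)

# Out-of-place, autograd-enabled GEMM: D = A @ B + C
d = tritonblas.addmm(c, a, b)
d.sum().backward()

# In-place variant via out=; autograd must be disabled for mutated outputs
out = torch.empty(512, 128, device="cuda", dtype=torch.float16)
with torch.no_grad():
    tritonblas.matmul(a, b, out=out)

# The same calls are intended to work under torch.compile as well
compiled_mm = torch.compile(tritonblas.matmul)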

Test Plan

Two new batches of tests are added:

  • matmul() correctness tests, which exercise the front-end matmul() API on a variety of matrix sizes (including several that were major pain points during development) across permutations of with/without torch.compile and with/without autograd. Error checks are also included for invalid situations such as in-place autograd. A sketch of the test structure follows this list.
  • addmm() correctness tests, which run the same gauntlet as above against the front-end addmm() API.
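
The sketch below shows the general shape these tests take; the matrix sizes, tolerances, and `tritonblas` import name are illustrative, not the actual test code.

import pytest
import torch
import tritonblas  # module name assumed for this sketch

@pytest.mark.parametrize("m,n,k", [(1, 1, 1), (128, 128, 128), (513, 255, 1024)])
@pytest.mark.parametrize("use_compile", [False, True])
@pytest.mark.parametrize("use_autograd", [False, True])
def test_matmul_correctness(m, n, k, use_compile, use_autograd):
    a = torch.randn(m, k, device="cuda", dtype=torch.float16, requires_grad=use_autograd)
    b = torch.randn(k, n, device="cuda", dtype=torch.float16, requires_grad=use_autograd)
    fn = torch.compile(tritonblas.matmul) if use_compile else tritonblas.matmul
    out = fn(a, b)
    torch.testing.assert_close(out, a @ b, rtol=1e-2, atol=1e-2)
    if use_autograd:
        out.sum().backward()

def test_matmul_inplace_autograd_raises():
    a = torch.randn(64, 64, device="cuda", requires_grad=True)
    b = torch.randn(64, 64, device="cuda", requires_grad=True)
    out = torch.empty(64, 64, device="cuda")
    with pytest.raises(RuntimeError):
        tritonblas.matmul(a, b, out=out)  # in-place + autograd is invalid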

Test Result

$ pytest tests/test_matmul_correctness.py tests/test_addmm_correctness.py

~ snip ~

===== 3 failed, 265 passed, 15 warnings in 180.48s (0:03:00) =====

The failed tests are explained under Known Issues below. The warnings are about compiler caches being disabled (done deliberately so each test run recompiles every time) and a deprecated way of calling torch.compile, which seems to surface when pytest calls the non-deprecated APIs.

Known Issues

This work necessitated a change in how the bias vector is handled in the Triton code (the previous handling appears to have been incorrect, given that the bias is typically provided as a 1D tensor). I was able to fix it in the persistent_matmul path but do not yet have a fix for the streamk_matmul path, so streamk does not currently work with these changes (as the test results show).
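
For reference, the standard Triton pattern for applying a length-N, 1D bias in a tile epilogue looks like the standalone sketch below. This is an illustration of the approach, not the project's persistent_matmul code.

import triton
import triton.language as tl

@triton.jit
def _add_bias_kernel(c_ptr, bias_ptr, M, N, stride_cm, stride_cn,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tile = tl.load(c_ptrs, mask=mask, other=0.0)
    # N-centric: index the 1D bias with the output-column offsets...
    bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0)
    # ...then broadcast it across the rows (M dimension) of the tile
    tl.store(c_ptrs, tile + bias[None, :], mask=mask)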

Once a fix is implemented I will change the tests to permute streamk=True/False for broader coverage, but for now it is easier to debug with the known issue isolated to a small test at the end of the file.

The a8w8 tests are also failing, but as far as I can tell that is unrelated to the changes made in this PR.

Submission Checklist

Most of torch's APIs allocate new tensor objects when high-level
functions are called and also provide the option to perform in-place
operations.  To keep argument names consistent and better match what
torch does, this commit moves the project in that direction: the output
matrix is now called `out` inside the matmul() and addmm() functions
and, unless that argument is set, the output tensor is allocated before
the triton kernels are called.
Despite this addition, the backward pass does not currently work when
`out=` is supplied, because autograd is incompatible with mutated input
arguments.
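
A sketch of that convention (shape and dtype handling here are assumptions, not the actual implementation):

import torch

def matmul(a, b, out=None):
    # Allocate the result only when the caller did not supply one.
    if out is None:
        out = torch.empty((a.shape[0], b.shape[1]),
                          device=a.device,
                          dtype=torch.promote_types(a.dtype, b.dtype))
    elif torch.is_grad_enabled() and (a.requires_grad or b.requires_grad):
        # Autograd cannot record an op that mutates a user-provided output.
        raise RuntimeError("out= is not supported when autograd is enabled")
    # ...launch the triton kernel, writing into `out`...
    return out
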
The version of triton that torch.compile uses has an issue with the way
tl.multiple_of generates load operations: in some situations it can
produce an illegal vectorized load setup which results in incorrect
operation results without any warning or error.  This does not happen
outside of torch.compile, where a mainline version of triton is used
instead of torch's pinned version.  This commit adds a workaround only
for the torch.compile case, which will be removed once the torch.compile
triton version is updated to fix this issue.
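
A sketch of the kind of gate such a workaround can use; the flag name and
the host-side check are assumptions, and the real kernel differs.

import torch
import triton
import triton.language as tl

@triton.jit
def _load_block(ptr, offs, USE_MULTIPLE_OF_HINT: tl.constexpr):
    if USE_MULTIPLE_OF_HINT:
        # Alignment hint that lets triton emit wider vectorized loads.
        offs = tl.multiple_of(offs, 16)
    return tl.load(ptr + offs)

def _hint_is_safe() -> bool:
    # Host-side gate: skip the hint when running under torch.compile,
    # where the bundled triton can mis-vectorize the resulting load.
    # torch.compiler.is_compiling() is available in recent torch releases.
    return not torch.compiler.is_compiling()
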
The following changes needed to be made for this merge to be successful:
- Remove previous commit's _vectorized_load_fix as it's no longer an
  issue
- Perform type promotion on several more values in triton as
  torch.compile often passes around SymInts during Dynamo passes
- Create a factory for schedules to handle type promotion there
- Swap the ordering of slicing used on the bias vector (implemented in
  this branch previously but location for the fix changed in main)
- Change bias input args to be N-centric instead of M-centric
There is a bug in testing at the moment which causes CUDA RuntimeErrors
due to multiple forked processes trying to initialize the GPU runtime
during torch.compile runs.  The temporary fix is to force Inductor to
run single-threaded, after which the error is gone and all tests pass,
but changing the multiprocessing start method from 'fork' to 'spawn' may
be a better long-term solution (or submitting a bug report to torch,
because this wasn't happening before).
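
Concretely, the temporary fix amounts to something like the sketch below,
with the 'spawn' alternative shown commented out.

import torch._inductor.config as inductor_config

# Temporary fix: keep Inductor compilation single-threaded so forked
# worker processes never touch the already-initialized CUDA runtime.
# Setting TORCHINDUCTOR_COMPILE_THREADS=1 in the environment should have
# the same effect.
inductor_config.compile_threads = 1

# Possible longer-term alternative: start workers with 'spawn' so they
# begin with a clean, uninitialized CUDA state.
# import multiprocessing as mp
# mp.set_start_method("spawn", force=True)
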
@asunderwood
Collaborator Author

asunderwood commented Feb 6, 2026

CI failed, but the failures are the known-problematic StreamK tests mentioned in the OP:

============ 3 failed, 476 passed, 15 warnings in 407.00s (0:06:46) ============

The rest of the tests are passing, including all of the persistent kernel tests.

@triton.jit
def make_schedule_context(M, N, K, ctx: GemmContext, streamk_tiles=0):
    """
    Create a ScheduleContext from a GemmContext.
    """

This API is required because M, N, and K can be constexpr or non-constexpr depending on their value (Triton specializes a value of 1 as constexpr), so we type-promote here to match the strong typing in the aggregates.
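
A sketch of the idea; the use of tl.cast and the int32 target are assumptions rather than the actual code. Since Triton specializes integer arguments whose value is 1 into tl.constexpr, the factory promotes each dimension to one concrete runtime type before building the aggregate.

import triton
import triton.language as tl

@triton.jit
def make_schedule_context_sketch(M, N, K, streamk_tiles=0):
    # M, N, K may arrive as tl.constexpr (when the value is 1) or as
    # runtime scalars; promote them all to int32 so the aggregate's
    # fields always share one concrete type.
    m = tl.cast(M, tl.int32)
    n = tl.cast(N, tl.int32)
    k = tl.cast(K, tl.int32)
    tiles = tl.cast(streamk_tiles, tl.int32)
    # ...build and return the ScheduleContext from m, n, k and tiles...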

@asunderwood asunderwood merged commit cd11927 into main Feb 6, 2026
3 of 4 checks passed
asunderwood added a commit that referenced this pull request Feb 6, 2026
@asunderwood
Collaborator Author

Alright, things aren't totally solved: the bug that c8731b3 provided a workaround for in the tests is showing up in real-world problems.

Unfortunately that may be a torch-triton-runtime issue requiring upstream fixing...
