Add torch-like addmm and matmul support #63
Conversation
Most of torch's APIs allocate new tensor objects when calling high-level functions and provide the option to perform in-place operations. To keep argument names consistent and better match what torch does, this commit moves the project in that direction: the output matrix is now called `out` inside the matmul() and addmm() functions and, unless that argument is set, the output tensor is allocated before the triton kernels are called.
Despite this addition, the backward pass doesn't currently work because autograd is incompatible with mutated input args (`out=`).
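For reference, a minimal sketch of the `out=` convention described above, assuming a plain PyTorch front end; `_launch_matmul_kernel` is a hypothetical placeholder for the real Triton launch, not the tritonBLAS implementation:

```python
import torch

def _launch_matmul_kernel(a, b, out):
    # Placeholder for the real Triton kernel launch.
    out.copy_(a @ b)

def matmul(a: torch.Tensor, b: torch.Tensor, *, out: torch.Tensor | None = None) -> torch.Tensor:
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dimensions must match"
    if out is None:
        # Out-of-place path: allocate the result before calling the kernel,
        # matching torch.matmul's default behaviour.
        out = torch.empty((M, N), device=a.device, dtype=a.dtype)
    _launch_matmul_kernel(a, b, out)
    return out
```

Calling `matmul(a, b)` allocates a fresh tensor, while `matmul(a, b, out=c)` writes into `c` in place (which, as noted above, autograd currently cannot track).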
The version of triton that torch.compile uses has an issue with the way tl.multiple_of generates load operations: in some situations it can produce an illegal vectorized load setup that yields incorrect results without any warning or error. This does not happen outside of torch.compile, where a mainline version of triton is used instead of torch's bundled version. This commit adds a workaround only for the torch.compile case, which will be removed once torch.compile's triton version is updated to fix this issue.
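As a rough illustration of this kind of gating (the kernel and flag names are hypothetical, not the actual workaround in this commit), the alignment hint can be guarded by a constexpr flag so the torch.compile path skips it:

```python
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr, USE_ALIGN_HINT: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK
    if USE_ALIGN_HINT:
        # Tell Triton the block start is a multiple of BLOCK so loads can be
        # vectorized; the torch.compile path would pass USE_ALIGN_HINT=False
        # to avoid the mis-vectorized loads in its bundled Triton.
        block_start = tl.multiple_of(block_start, BLOCK)
    offs = block_start + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(src_ptr + offs, mask=mask)
    tl.store(dst_ptr + offs, x, mask=mask)
```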
The following changes needed to be made for this merge to be successful:
- Remove the previous commit's _vectorized_load_fix, as it's no longer an issue
- Perform type promotion on several more values in triton, as torch.compile often passes around SymInts during Dynamo passes
- Create a factory for schedules to handle type promotion there
- Swap the ordering of slicing used on the bias vector (implemented in this branch previously, but the location for the fix changed in main)
- Change bias input args to be N-centric instead of M-centric
There is a bug in testing at the moment which causes CUDA RuntimeErrors due to multiple forked processes trying to initialize the GPU runtime during torch.compile runs. The temporary fix is to force Inductor to run single-threaded, after which the error is gone and all tests pass, but changing the multiprocessing methodology to 'spawn' over 'fork' may be a better long-term solution (or submitting a bug with torch, because this wasn't happening before).
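A minimal sketch of the temporary fix, assuming the test suite configures Inductor directly (where exactly the PR sets this may differ):

```python
import torch._inductor.config as inductor_config

# Disable Inductor's parallel compile workers so forked test processes do not
# race to initialize the CUDA runtime.
inductor_config.compile_threads = 1

# Possible longer-term alternative (not used here): switch the worker start
# method from 'fork' to 'spawn'.
# import multiprocessing as mp
# mp.set_start_method("spawn", force=True)
```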
CI failed, but it's the known-problematic StreamK tests mentioned in the OP. The rest of the tests are passing, including all of the persistent kernels.
@triton.jit
def make_schedule_context(M, N, K, ctx: GemmContext, streamk_tiles=0):
    """
    Create a ScheduleContext from a GemmContext.
This API is required because M, N, and K can be const or non-const depending on their value (const if it equals 1), so we type-promote here to match the strong typing in the aggregates.
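A rough sketch of the idea with a hypothetical helper name (the actual promotion lives in make_schedule_context): because a dimension equal to 1 arrives as a compile-time constant while other values arrive as runtime scalars, casting gives the aggregate's fields one consistent type.

```python
import triton
import triton.language as tl

@triton.jit
def _promote_dim(x):
    # `x` may be a compile-time constant (e.g. a dimension equal to 1) or a
    # runtime scalar; either way it comes back as an int32 value suitable for
    # a strongly-typed aggregate field.
    return tl.cast(x, tl.int32)
```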
This reverts commit cd11927.
Alright, things aren't totally solved: the bug that c8731b3 provided a workaround for in the tests is showing up in real-world problems. Unfortunately that may be a torch-triton runtime issue requiring an upstream fix...
Motivation
TritonBLAS aims to have a torch-like API to aid in switching between PyTorch's own functions and tritonBLAS's with as little impact on the user as possible. Ideally this would include support for various parts of the torch ecosystem including full autograd support for training, compatibility with torch.compile, and the ability to pick between in-place and out-of-place operations all from one, unified API.
Technical Details
This PR adds torch.compile support, autograd support, and in-place/out-of-place support to the `matmul()` function. It also adds the `addmm()` function to support full GEMM operations (D = A @ B + C) with the same torch.compile, autograd, and in-place/out-of-place support.
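As an illustration of the intended front-end usage (the `tritonblas` module name is assumed here; the real entry points may be named differently):

```python
import torch
import tritonblas  # assumed module name

a = torch.randn(256, 128, device="cuda")
b = torch.randn(128, 512, device="cuda")
bias = torch.randn(512, device="cuda")

c = tritonblas.matmul(a, b)                  # out-of-place matmul
tritonblas.matmul(a, b, out=c)               # in-place, reusing c

d = tritonblas.addmm(bias, a, b)             # full GEMM: d = a @ b + bias

compiled = torch.compile(tritonblas.matmul)  # torch.compile path
c2 = compiled(a, b)
```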
Test Plan
Two new batches of tests are added:
- `matmul()` correctness tests, which exercise the front-end `matmul()` API on a variety of matrix sizes (including several that were major pain points during development) along with permutations of with/without torch.compile and with/without autograd. Error checks are also done for invalid situations such as in-place autograd.
- `addmm()` correctness tests, which run the same gauntlet as above but for the front-end `addmm()` API.
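The tests roughly take the following shape (a hypothetical sketch, not the PR's actual test code; the `tritonblas` module name is assumed):

```python
import pytest
import torch
import tritonblas  # assumed module name

@pytest.mark.parametrize("use_compile", [False, True])
@pytest.mark.parametrize("use_autograd", [False, True])
@pytest.mark.parametrize("shape", [(128, 128, 128), (1, 512, 256)])
def test_matmul_correctness(shape, use_autograd, use_compile):
    M, K, N = shape
    a = torch.randn(M, K, device="cuda", requires_grad=use_autograd)
    b = torch.randn(K, N, device="cuda", requires_grad=use_autograd)
    fn = torch.compile(tritonblas.matmul) if use_compile else tritonblas.matmul
    out = fn(a, b)
    torch.testing.assert_close(out, a @ b, rtol=1e-2, atol=1e-2)
    if use_autograd:
        out.sum().backward()  # exercises the backward pass
```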
Test Result
The failed tests are explained below. The warnings are about compiler caches being disabled (to force each test run to recompile every time) and a deprecated way of calling torch.compile, which seems to come from pytest even though the tests call the non-deprecated APIs.
Known Issues
This work necessitated a change in the way the bias vector is handled in the triton code (it seems it was incorrect before given that the bias is often provided as a 1D tensor). I was able to fix it in the persistent_matmul path but do not have the fix for the streamk_matmul path. Thus, streamk does not currently work with these changes (and the test results show this).
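For illustration, a minimal sketch of N-centric 1D bias handling (helper and argument names are hypothetical, not the persistent_matmul code):

```python
import triton
import triton.language as tl

@triton.jit
def _add_bias(acc, bias_ptr, offs_n, N):
    # acc has shape (BLOCK_M, BLOCK_N); the bias is a 1D tensor with one value
    # per output column, so it is indexed with the N offsets and broadcast
    # across the M dimension.
    bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0)
    return acc + bias[None, :]
```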
When a fix is implemented I will change the tests to permute streamk=True/False to improve test coverage further, but for now it's easier to debug with the known issue isolated to a small test at the end of the file.
a8w8 tests are also failing and that appears to be unrelated to the changes made in this PR as far as I can tell.
Submission Checklist