PR tensorflow#21886: [ROCM][NFC] BlasLt interface refactoring & simplifying: part I
Imported from GitHub PR openxla/xla#21886
After PR tensorflow#73926 is merged, we can remove the unnecessary low-level DoMatmul functions from the GpuBlasLt interface (which otherwise looks scary and unnecessarily complicated).
Furthermore, we can also remove the **ValidateInputs** function from the interface and the derived classes, since the high-level **ExecuteOnStream** function already handles data types correctly. This greatly simplifies the code.
I have also packed the input arguments of the ExecuteOnStream calls into a **MemoryArgs** struct, which simplifies argument passing in derived classes and improves code readability (see the sketch below).
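As a rough illustration of this change, here is a minimal sketch; the struct, field, and function names below are illustrative stand-ins, not the actual XLA/stream_executor declarations:

```cpp
#include <cstdint>

// Stand-in for stream_executor::DeviceMemoryBase (illustrative).
struct DeviceMemoryBase {
  void* opaque = nullptr;
  uint64_t size = 0;
};

// Hypothetical grouping of the buffers a single matmul call needs.
// Passing one struct instead of a long positional argument list keeps the
// ExecuteOnStream signatures short and easy to forward in derived classes.
struct MemoryArgs {
  DeviceMemoryBase a, b, c, d;  // operands and output
  DeviceMemoryBase bias;        // optional epilogue input
  DeviceMemoryBase workspace;   // scratch space for the chosen algorithm
};

struct Stream {};  // stand-in for stream_executor::Stream

// Before (illustrative): ExecuteOnStream(stream, a, b, c, d, bias, workspace, ...)
// After (illustrative): a single MemoryArgs parameter.
int ExecuteOnStream(Stream* stream, const MemoryArgs& args) {
  (void)stream;
  (void)args;
  return 0;  // status placeholder
}
```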
Finally, in the original GpuBlasLt PR openxla/xla#5911, I made a mistake by adding a reference to **blas_lt** to the MatmulPlan class [here](https://github.com/openxla/xla/blob/main/xla/stream_executor/rocm/hip_blas_lt.h#L135), thereby binding each MatmulPlan to a **particular BlasLt instance**. This led to several follow-up bugfixes and, most importantly, complicated the GpuBlasLt cache design in gpublas_lt_matmul_thunk.cc/.h. In this PR, I remove that reference from the MatmulPlan class; in the next NFC PR the cache mechanics can be simplified as well (see the sketch below).
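A minimal sketch of the before/after shape, using simplified stand-in types rather than the actual XLA/stream_executor classes:

```cpp
struct Stream {};  // stand-in for stream_executor::Stream
struct BlasLt {};  // stand-in for the per-executor BlasLt support object

// Before (illustrative): the plan stores a reference to the BlasLt instance
// that created it, so a cached plan can only be replayed against that one
// instance, which complicated the cache in gpublas_lt_matmul_thunk.cc/.h.
class BoundMatmulPlan {
 public:
  explicit BoundMatmulPlan(const BlasLt& blas_lt) : blas_lt_(blas_lt) {}
  void ExecuteOnStream(Stream* stream) {
    (void)stream;
    (void)blas_lt_;  // the call is tied to this one stored instance
  }

 private:
  const BlasLt& blas_lt_;  // binds the plan to a particular BlasLt instance
};

// After (illustrative): no stored reference; whatever BlasLt handle is needed
// is resolved per call (e.g. from the stream or passed explicitly), so plans
// are instance-agnostic and the caching layer can be simplified.
class MatmulPlan {
 public:
  void ExecuteOnStream(Stream* stream, const BlasLt& blas_lt) {
    (void)stream;
    (void)blas_lt;  // used for this call only
  }
};
```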
Unfortunately, this change also requires a tandem PR for TensorFlow: tensorflow#85835

@xla-rotation Would you please have a look?
Copybara import of the project:
--
e96bb2fbedab3f53b31ef0e1748582c76e9fb105 by Pavel Emeliyanenko <pavel.emeliyanenko@amd.com>:
blaslt interface refactoring: removing blas_lt_ref
added cuda adaptions
cuda-side adaptions
cuda side adaptions
fix
fixing pointers
Merging this change closes tensorflow#21886
PiperOrigin-RevId: 727898957