
Commit 114b2b1

Implementation of matmul, tensordot, and vecdot per array API (#1490)
* Adds `ThreeOffsets_CombinedIndexer`, enabling strided data processing by gemm kernels
* Removes unused `elementwise_functions.cpp`
* Implements `matmul`, `vecdot`, and `tensordot` through a common `py_dot` binding, part of a new tensor submodule `_tensor_linalg_impl`
* Tweaks `matmul` and gemm kernels: fixes a missing indexer in the gemm functor with threading along the `nm` dimensions, and fixes `matmul` broadcasting, which was broadcasting in some unnecessary cases
* Removes double-counting of the batch offset in gemm batch tree reduction
* Fixes a missing dependency in `vecdot`: when the first argument would not be cast but the second argument would be, the copy dependency was not appended to the list of dependencies, creating race conditions
* Runs `test_matmul_simple2` on Windows before the full test suite, as part of triaging crashes on Windows
* Tests removing `test_matmul_simple`, leaving only `test_matmul_simple2`
* Fixes incorrect comments throughout the gemm kernels, which stated that the third argument to `scale_gemm_k_parameters` is modified by reference
* Drastically reduces parameters used for gemm kernels that thread over `k` (experimental change to see if this stabilizes CI)
* Tests removing the k-threading gemm kernel that writes to multiple outputs atomically
* Refactors `gemm_tree_impl` into two smaller functions, `gemm_tree_k_impl` and `gemm_tree_nm_impl`, for greater readability
* Reverses the order of numeric types passed to `test_matmul_simple2`, which may improve stability on CPU
* Refactors `gemm_contig_tree_impl` to call new functions `gemm_contig_tree_k_impl` and `gemm_contig_tree_nm_impl`
* Refactors the `gemm_batch_tree` functions, adding new functions for calling the `nm`-threading and `k`-threading kernels to improve readability
* Tests reversing data types for `test_matmul_strided`
* pre-commit fixes in `gemm.hpp`
* Checks whether `malloc_device` returns `nullptr` (#1493)
* Adds a step to the Linux conda package workflow to run `test_matmul_strided` under gdb, as part of triaging CPU crashes
* Removes unnecessary comments
* Adds a fast path for empty (`k = 0`) gemm kernels
* Adds logic that avoids certain kernels known to be problematic on CPU; specifically, always avoids paths that would call k-threaded functors on CPU with `m_groups > 1`
* Only accesses memory when indices are in range, preventing the out-of-bounds access responsible for crashes observed in CI
* Simplifies the computation of `m_id`/`gr_id` in kernels: there is no need to use both `it.get_global_linear_id()` and `it.get_group_linear_id()` to compute the batch id and group id
* Changes generic kernels to work for any value of `m_groups`, not just `m_groups = 2`
* Removes work-arounds/special-casing for CPUs
* Extends `test_matmul_strided`; reverts work-arounds
* Reverts remaining gemm work-arounds: removes the remaining checks for whether a kernel is called on CPU, and reverts gemm kernel hyperparameters to their original values
* Reverts tuning-down of `gemm` kernel parameters
* Removes logically dead code from `_linear_algebra_functions.py`
* Adds more tests to improve coverage of `_linear_algebra_functions`
* Fixes "UnboundLocalError: local variable 'buf1_dt' referenced before assignment" by initializing `buf1_dt` and `buf2_dt` to `None`
* More tests to improve coverage
* Removes more dead branches in `_linear_algebra_functions.py`
* `tensordot` now properly rejects negative `axes`; per the array API, negative axes are not permitted
* Adds `test_tensordot_type_matrix` to `test_usm_ndarray_linalg.py`
* Addresses flaws in gemm tree kernel logic: previously, assertions for calling a full tree reduction with only a single work-group of elements could be tripped; the kernel logic has been changed so that this is no longer possible
* Implements the `__matmul__`, `__imatmul__`, and `__rmatmul__` operators for `usm_ndarray`
* Makes `usm_ndarray` operator argument names consistent
* Test changes for `tensordot`: adds a test for axes errors in `tensordot` for negative axes, and incorporates a test for `tensordot` promotion of both inputs into `test_tensordot_type_promotion`
* Reverts running certain `matmul` tests under gdb
* Fixes a typo in `test_tensordot_promotion`
* Removes unnecessary input type checks in `matmul`
* Adds more tests to `test_usm_linalg.py`: several tests for `matmul`, and expanded `tensordot` and `vecdot` tests, to improve coverage
* Uses `result_type` with tensors to take device capability into account
* Uses the `order` keyword in the test of type promotion for `matmul`
* Makes generic k-threaded kernels handle arbitrary `m_groups`; also increases hyperparameters for k-threaded kernels to improve performance
* Adjusts dispatch logic for gemm kernels: now uses `m_groups = 4` when `m > 4` and `m_groups = 1` otherwise, to improve performance

---

Co-authored-by: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
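The three functions added here follow the array-API semantics for contraction. A minimal sketch of those semantics, using NumPy as a stand-in for `dpctl.tensor` (the dpctl functions implement the same behavior; `vecdot` is sketched via `einsum` since it may not exist in older NumPy):

```python
import numpy as np

a = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
b = np.arange(2 * 4 * 5, dtype=np.float32).reshape(2, 4, 5)

# matmul batches over leading dimensions and contracts the last axis of
# `a` with the second-to-last axis of `b`.
c = np.matmul(a, b)
assert c.shape == (2, 3, 5)

# tensordot contracts the listed axis pairs; per the array API (and this
# commit's validation in dpctl), negative axes are not permitted.
t = np.tensordot(a, b, axes=([0, 2], [0, 1]))
assert t.shape == (3, 5)

# vecdot reduces over the last axis after broadcasting; an einsum sketch:
x = np.ones((2, 3, 4), dtype=np.float32)
y = np.full((4,), 2.0, dtype=np.float32)  # broadcasts against x
v = np.einsum("...i,...i->...", x, np.broadcast_to(y, x.shape))
assert v.shape == (2, 3)  # each element is the dot of length-4 vectors
```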
1 parent 9a80b47 commit 114b2b1

File tree

14 files changed: +11194 −5193 lines


dpctl/tensor/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
```diff
@@ -156,6 +156,15 @@ set(_tensor_sorting_impl_sources
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
     ${_sorting_sources}
 )
+set(_linalg_sources
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp
+)
+set(_tensor_linalg_impl_sources
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_linalg.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
+    ${_linalg_sources}
+)

 set(_py_trgts)

@@ -179,6 +188,11 @@ pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_impl_sources}
 add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_sources})
 list(APPEND _py_trgts ${python_module_name})

+set(python_module_name _tensor_linalg_impl)
+pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources})
+add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources})
+list(APPEND _py_trgts ${python_module_name})
+
 set(_clang_prefix "")
 if (WIN32)
     set(_clang_prefix "/clang:")

@@ -193,6 +207,7 @@ list(APPEND _no_fast_math_sources
     ${_elementwise_sources}
     ${_reduction_sources}
     ${_sorting_sources}
+    ${_linalg_sources}
 )

 foreach(_src_fn ${_no_fast_math_sources})
```

dpctl/tensor/__init__.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -60,7 +60,12 @@
 from dpctl.tensor._device import Device
 from dpctl.tensor._dlpack import from_dlpack
 from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take
-from dpctl.tensor._linear_algebra_functions import matrix_transpose
+from dpctl.tensor._linear_algebra_functions import (
+    matmul,
+    matrix_transpose,
+    tensordot,
+    vecdot,
+)
 from dpctl.tensor._manipulation_functions import (
     broadcast_arrays,
     broadcast_to,

@@ -356,4 +361,7 @@
     "unique_counts",
     "unique_inverse",
     "unique_values",
+    "matmul",
+    "tensordot",
+    "vecdot",
 ]
```
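With the functions exported and the commit's `__matmul__`/`__rmatmul__`/`__imatmul__` operators on `usm_ndarray`, the `@` operator dispatches to `matmul`. NumPy arrays follow the same operator protocol, so a hedged NumPy illustration of the dispatch:

```python
import numpy as np

a = np.eye(3, dtype=np.float32)
b = np.arange(9, dtype=np.float32).reshape(3, 3)

# `a @ b` calls a.__matmul__(b), equivalent to matmul(a, b).
out = a @ b
assert np.array_equal(out, b)

# A plain list has no __matmul__, so Python falls back to
# b.__rmatmul__(lhs) -- the reflected operator added by this commit.
lhs = [[1.0, 0.0, 0.0],
       [0.0, 1.0, 0.0],
       [0.0, 0.0, 1.0]]
out2 = lhs @ b
assert np.array_equal(out2, b)
```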
