From 5ffe768d86868a94ee8344807d074b740225f642 Mon Sep 17 00:00:00 2001 From: Affifboudaoud Date: Tue, 21 Oct 2025 17:05:09 +0200 Subject: [PATCH 1/7] Add 3D-4D broadcasting implementation and test --- dace/libraries/blas/nodes/batched_matmul.py | 127 ++++++++++++++------ dace/libraries/blas/nodes/matmul.py | 37 +++++- tests/library/batched_matmul_test.py | 123 +++++++++++++++++++ 3 files changed, 252 insertions(+), 35 deletions(-) diff --git a/dace/libraries/blas/nodes/batched_matmul.py b/dace/libraries/blas/nodes/batched_matmul.py index fc4d9dfdf5..8a29994b91 100644 --- a/dace/libraries/blas/nodes/batched_matmul.py +++ b/dace/libraries/blas/nodes/batched_matmul.py @@ -86,15 +86,22 @@ def make_sdfg(node, parent_state, parent_sdfg): if len(array_a.shape) == 2: memlet_a = '__im, __ik' else: - # Use output batch indices - a_batch_indices = ', '.join(['__i%d' % i for i in range(len(array_a.shape) - 2)]) + # Align input batch dims to output batch dims + num_a_batch = len(array_a.shape) - 2 + # Start from the rightmost batch dimension of output and work backwards + offset = num_batch_dims - num_a_batch + a_batch_indices = ', '.join(['__i%d' % (offset + i) for i in range(num_a_batch)]) memlet_a = f'{a_batch_indices}, __im, __ik' # For B: if 2D, use [K, N]; if 3D+, use [batch_indices..., K, N] if len(array_b.shape) == 2: memlet_b = '__ik, __in' else: - b_batch_indices = ', '.join(['__i%d' % i for i in range(len(array_b.shape) - 2)]) + # Align input batch dims to output batch dims + num_b_batch = len(array_b.shape) - 2 + # Start from the rightmost batch dimension of output and work backwards + offset = num_batch_dims - num_b_batch + b_batch_indices = ', '.join(['__i%d' % (offset + i) for i in range(num_b_batch)]) memlet_b = f'{b_batch_indices}, __ik, __in' # For C: always has batch dimensions @@ -172,8 +179,11 @@ def expansion(node, state, sdfg): const {dtype}** __mkl_BMM_B = new const {dtype}*[{BATCH}]; {dtype}** __mkl_BMM_C = new {dtype}*[{BATCH}]; for (int __ib = 0; __ib < {BATCH}; __ib++) {{ - __mkl_BMM_A[__ib] = (({dtype}*){x}) + __ib*{stride_a}; - __mkl_BMM_B[__ib] = (({dtype}*){y}) + __ib*{stride_b}; + // Handle broadcasting - compute correct index for inputs with fewer batch dimensions + int __a_idx = ({stride_a} > 0) ? (({a_batch_size} < {BATCH}) ? (__ib % {a_batch_size}) : __ib) : 0; + int __b_idx = ({stride_b} > 0) ? (({b_batch_size} < {BATCH}) ? (__ib % {b_batch_size}) : __ib) : 0; + __mkl_BMM_A[__ib] = (({dtype}*){x}) + __a_idx*{stride_a}; + __mkl_BMM_B[__ib] = (({dtype}*){y}) + __b_idx*{stride_b}; __mkl_BMM_C[__ib] = (({dtype}*)_c) + __ib*{stride_c}; }} @@ -227,9 +237,12 @@ def expansion(node, state, sdfg): code = ''' for (int __ib = 0; __ib < {BATCH}; ++__ib) {{ + // Handle broadcasting - compute correct index for inputs with fewer batch dimensions + int __a_idx = ({stride_a} > 0) ? (({a_batch_size} < {BATCH}) ? (__ib % {a_batch_size}) : __ib) : 0; + int __b_idx = ({stride_b} > 0) ? (({b_batch_size} < {BATCH}) ? 
(__ib % {b_batch_size}) : __ib) : 0; cblas_{func}(CblasColMajor, {ta}, {tb}, {M}, {N}, {K}, {alpha}, - (({dtype}*){x}) + __ib*{stride_a}, {lda}, - (({dtype}*){y}) + __ib*{stride_b}, {ldb}, + (({dtype}*){x}) + __a_idx*{stride_a}, {lda}, + (({dtype}*){y}) + __b_idx*{stride_b}, {ldb}, {beta}, (({dtype}*)_c) + __ib*{stride_c}, {ldc}); }}'''.format_map(opt) @@ -325,17 +338,38 @@ def expansion(node, state, sdfg): opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func) opt['array_prefix'] = '_' if needs_copy else '' + # Check if we need broadcasting (non-uniform strides) + needs_broadcasting = (opt.get('a_batch_size') and opt.get('b_batch_size') + and (opt['a_batch_size'] != opt['BATCH'] or opt['b_batch_size'] != opt['BATCH'])) + # Matrix multiplication if (node.compute_type is None and node.accumulator_type is None and node.algorithm is None): - call = '''cublas{func}StridedBatched(__dace_cublas_handle, - CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, - {M}, {N}, {K}, - {alpha}, - ({dtype}*){array_prefix}{x}, {lda}, {stride_a}, - ({dtype}*){array_prefix}{y}, {ldb}, {stride_b}, - {beta}, - ({dtype}*){array_prefix}_c, {ldc}, {stride_c}, - {BATCH});'''.format_map(opt) + if needs_broadcasting: + # Use manual loop for broadcasting cases + call = ''' + for (int __ib = 0; __ib < {BATCH}; ++__ib) {{ + int __a_idx = ({stride_a} > 0) ? (({a_batch_size} < {BATCH}) ? (__ib % {a_batch_size}) : __ib) : 0; + int __b_idx = ({stride_b} > 0) ? (({b_batch_size} < {BATCH}) ? (__ib % {b_batch_size}) : __ib) : 0; + cublas{func}(__dace_cublas_handle, + CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, + {M}, {N}, {K}, + {alpha}, + ({dtype}*){array_prefix}{x} + __a_idx*{stride_a}, {lda}, + ({dtype}*){array_prefix}{y} + __b_idx*{stride_b}, {ldb}, + {beta}, + ({dtype}*){array_prefix}_c + __ib*{stride_c}, {ldc}); + }}'''.format_map(opt) + else: + # Use StridedBatched for uniform case + call = '''cublas{func}StridedBatched(__dace_cublas_handle, + CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, + {M}, {N}, {K}, + {alpha}, + ({dtype}*){array_prefix}{x}, {lda}, {stride_a}, + ({dtype}*){array_prefix}{y}, {ldb}, {stride_b}, + {beta}, + ({dtype}*){array_prefix}_c, {ldc}, {stride_c}, + {BATCH});'''.format_map(opt) else: if node.compute_type is not None: acctype = node.compute_type @@ -349,24 +383,49 @@ def expansion(node, state, sdfg): if node.algorithm is not None: algorithm = node.algorithm - call = f''' - cublasGemmStridedBatchedEx(__dace_cublas_handle, - CUBLAS_OP_{opt['ta']}, CUBLAS_OP_{opt['tb']}, - {opt['M']}, {opt['N']}, {opt['K']}, - {alpha}, - {opt['array_prefix']}{opt['x']}, - {dtype_to_cudadatatype(opt['xdtype'])}, - {opt['lda']}, {opt['stride_a']}, - {opt['array_prefix']}{opt['y']}, - {dtype_to_cudadatatype(opt['ydtype'])}, - {opt['ldb']}, {opt['stride_b']}, - {beta}, - {opt['array_prefix']}_c, - {dtype_to_cudadatatype(opt['cdtype'])}, - {opt['ldc']}, {opt['stride_c']}, - {opt['BATCH']}, - {acctype}, {algorithm}); - ''' + if needs_broadcasting: + # Use manual loop for broadcasting cases with GemmEx + call = f''' + for (int __ib = 0; __ib < {opt['BATCH']}; ++__ib) {{{{ + int __a_idx = ({opt['stride_a']} > 0) ? (({opt['a_batch_size']} < {opt['BATCH']}) ? (__ib % {opt['a_batch_size']}) : __ib) : 0; + int __b_idx = ({opt['stride_b']} > 0) ? (({opt['b_batch_size']} < {opt['BATCH']}) ? 
(__ib % {opt['b_batch_size']}) : __ib) : 0; + cublasGemmEx(__dace_cublas_handle, + CUBLAS_OP_{opt['ta']}, CUBLAS_OP_{opt['tb']}, + {opt['M']}, {opt['N']}, {opt['K']}, + {alpha}, + {opt['array_prefix']}{opt['x']} + __a_idx*{opt['stride_a']}, + {dtype_to_cudadatatype(opt['xdtype'])}, + {opt['lda']}, + {opt['array_prefix']}{opt['y']} + __b_idx*{opt['stride_b']}, + {dtype_to_cudadatatype(opt['ydtype'])}, + {opt['ldb']}, + {beta}, + {opt['array_prefix']}_c + __ib*{opt['stride_c']}, + {dtype_to_cudadatatype(opt['cdtype'])}, + {opt['ldc']}, + {acctype}, {algorithm}); + }}}} + ''' + else: + # Use StridedBatchedEx for uniform case + call = f''' + cublasGemmStridedBatchedEx(__dace_cublas_handle, + CUBLAS_OP_{opt['ta']}, CUBLAS_OP_{opt['tb']}, + {opt['M']}, {opt['N']}, {opt['K']}, + {alpha}, + {opt['array_prefix']}{opt['x']}, + {dtype_to_cudadatatype(opt['xdtype'])}, + {opt['lda']}, {opt['stride_a']}, + {opt['array_prefix']}{opt['y']}, + {dtype_to_cudadatatype(opt['ydtype'])}, + {opt['ldb']}, {opt['stride_b']}, + {beta}, + {opt['array_prefix']}_c, + {dtype_to_cudadatatype(opt['cdtype'])}, + {opt['ldc']}, {opt['stride_c']}, + {opt['BATCH']}, + {acctype}, {algorithm}); + ''' code = call_prefix + call + call_suffix tasklet = dace.sdfg.nodes.Tasklet(node.name, diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index f5f308b2c8..471e8deb0e 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -104,6 +104,13 @@ def _get_batchmm_opts(a_shape, a_strides, b_shape, b_strides, c_shape, c_strides # Calculate strides for batched operations # For a tensor with shape [B1, B2, ..., M, K], the stride for batched operations # should be M*K (the size of each matrix) to iterate through all matrices in the flattened batch + # + # For broadcasting cases (e.g., A - [b1, b2, m, k] @ B - [b2, k, n]): + # - The flattened batch is b1*b2 + # - B needs special handling: we need to compute which of the b2 matrices to use + # For batch index i in [0, b1*b2), the B matrix index is (i % b2) + # This can be expressed as: if A has more batch dims than B, use modulo arithmetic + stride_a = 0 stride_b = 0 stride_c = 0 @@ -125,10 +132,35 @@ def _get_batchmm_opts(a_shape, a_strides, b_shape, b_strides, c_shape, c_strides if res is False: raise ValueError(f'Output batch dimension mismatch: {c_dim} vs {r_dim} at position {i}') + # For partial broadcasting (3D-4D cases), we need to track additional information + # to properly index into the smaller batch dimension tensor + a_batch_multiplier = 1 # How many times to cycle through A's batch + b_batch_multiplier = 1 # How many times to cycle through B's batch + + if len(a_batch_dims) < len(result_batch_dims): + # A has fewer batch dimensions, so it will be broadcast + # Calculate the size of the leading dimensions that A doesn't have + a_batch_multiplier = prod(result_batch_dims[:len(result_batch_dims) - len(a_batch_dims)]) + + if len(b_batch_dims) < len(result_batch_dims): + # B has fewer batch dimensions, so it will be broadcast + # Calculate the size of the leading dimensions that B doesn't have + b_batch_multiplier = prod(result_batch_dims[:len(result_batch_dims) - len(b_batch_dims)]) + if batch_size == 1 and not result_batch_dims: return {} - return {'sa': stride_a, 'sb': stride_b, 'sc': stride_c, 'b': batch_size, 'batch_dims': result_batch_dims} + return { + 'sa': stride_a, + 'sb': stride_b, + 'sc': stride_c, + 'b': batch_size, + 'batch_dims': result_batch_dims, + 'a_batch_size': prod(a_batch_dims) if a_batch_dims 
else 1, + 'b_batch_size': prod(b_batch_dims) if b_batch_dims else 1, + 'a_batch_multiplier': a_batch_multiplier, + 'b_batch_multiplier': b_batch_multiplier + } def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func) -> Dict[str, Any]: @@ -165,6 +197,7 @@ def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, if opt['swap']: if bopt: bopt['sa'], bopt['sb'] = bopt['sb'], bopt['sa'] + bopt['a_batch_size'], bopt['b_batch_size'] = bopt['b_batch_size'], bopt['a_batch_size'] opt['lda'], opt['ldb'] = opt['ldb'], opt['lda'] opt['x'], opt['y'] = opt['y'], opt['x'] opt['xdtype'], opt['ydtype'] = opt['ydtype'], opt['xdtype'] @@ -180,6 +213,8 @@ def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, opt['stride_b'] = sym2cpp(bopt['sb']) opt['stride_c'] = sym2cpp(bopt['sc']) opt['BATCH'] = sym2cpp(bopt['b']) + opt['a_batch_size'] = sym2cpp(bopt['a_batch_size']) + opt['b_batch_size'] = sym2cpp(bopt['b_batch_size']) else: opt['BATCH'] = None diff --git a/tests/library/batched_matmul_test.py b/tests/library/batched_matmul_test.py index 0583ca793c..1504724381 100644 --- a/tests/library/batched_matmul_test.py +++ b/tests/library/batched_matmul_test.py @@ -224,6 +224,111 @@ def bmm_4d_broadcast(A: dtype[m, k], B: dtype[b1, b2, k, n], C: dtype[b1, b2, m, assert np.allclose(ref, z) +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("pure", dace.float32), + pytest.param("pure", dace.float64), + pytest.param("MKL", dace.float32, marks=pytest.mark.mkl), + pytest.param("MKL", dace.float64, marks=pytest.mark.mkl), + pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu), + pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu), + pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack), + pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack) +]) +def test_batchmm_3d_4d_broadcast(implementation: str, dtype): + """Test 4D batched matmul with broadcast on LHS: [b2, m, k] @ [b1, b2, k, n]""" + b1, b2, m, n, k = 4, 2, 64, 128, 64 + + @dace.program + def bmm_3d_4d_broadcast(A: dtype[b2, m, k], B: dtype[b1, b2, k, n], C: dtype[b1, b2, m, n]): + C[:] = A @ B + + with change_default(blas, implementation): + sdfg = bmm_3d_4d_broadcast.to_sdfg() + sdfg.simplify() + sdfg.expand_library_nodes() + + x = np.random.rand(b2, m, k).astype(dtype.as_numpy_dtype()) + y = np.random.rand(b1, b2, k, n).astype(dtype.as_numpy_dtype()) + z = np.zeros([b1, b2, m, n]).astype(dtype.as_numpy_dtype()) + + csdfg = sdfg.compile() + csdfg(A=x, B=y, C=z) + + ref = x @ y + + assert np.allclose(ref, z) + + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("pure", dace.float32), + pytest.param("pure", dace.float64), + pytest.param("MKL", dace.float32, marks=pytest.mark.mkl), + pytest.param("MKL", dace.float64, marks=pytest.mark.mkl), + pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu), + pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu), + pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack), + pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack) +]) +def test_batchmm_4d_3d_broadcast(implementation: str, dtype): + """Test 4D batched matmul with broadcast on RHS: [b1, b2, m, k] @ [b2, k, n]""" + b1, b2, m, n, k = 4, 2, 64, 128, 64 + + @dace.program + def bmm_4d_3d_broadcast(A: dtype[b1, b2, m, k], B: dtype[b2, k, n], C: dtype[b1, b2, m, n]): + C[:] = A @ B + + with change_default(blas, implementation): + sdfg = bmm_4d_3d_broadcast.to_sdfg() + sdfg.simplify() + 
sdfg.expand_library_nodes() + + x = np.random.rand(b1, b2, m, k).astype(dtype.as_numpy_dtype()) + y = np.random.rand(b2, k, n).astype(dtype.as_numpy_dtype()) + z = np.zeros([b1, b2, m, n]).astype(dtype.as_numpy_dtype()) + + csdfg = sdfg.compile() + csdfg(A=x, B=y, C=z) + + ref = x @ y + + assert np.allclose(ref, z) + + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("pure", dace.float32), + pytest.param("pure", dace.float64), + pytest.param("MKL", dace.float32, marks=pytest.mark.mkl), + pytest.param("MKL", dace.float64, marks=pytest.mark.mkl), + pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu), + pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu), + pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack), + pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack) +]) +def test_batchmm_5d_3d_broadcast(implementation: str, dtype): + """Test 5D batched matmul with broadcast on RHS: [b1, b2, b3, m, k] @ [b3, k, n]""" + b1, b2, b3, m, n, k = 4, 2, 3, 64, 128, 64 + + @dace.program + def bmm_5d_3d_broadcast(A: dtype[b1, b2, b3, m, k], B: dtype[b3, k, n], C: dtype[b1, b2, b3, m, n]): + C[:] = A @ B + + with change_default(blas, implementation): + sdfg = bmm_5d_3d_broadcast.to_sdfg() + sdfg.simplify() + sdfg.expand_library_nodes() + + x = np.random.rand(b1, b2, b3, m, k).astype(dtype.as_numpy_dtype()) + y = np.random.rand(b3, k, n).astype(dtype.as_numpy_dtype()) + z = np.zeros([b1, b2, b3, m, n]).astype(dtype.as_numpy_dtype()) + + csdfg = sdfg.compile() + csdfg(A=x, B=y, C=z) + + ref = x @ y + + assert np.allclose(ref, z) + + if __name__ == "__main__": test_batchmm("pure", dace.float32) test_batchmm("pure", dace.float64) @@ -261,3 +366,21 @@ def bmm_4d_broadcast(A: dtype[m, k], B: dtype[b1, b2, k, n], C: dtype[b1, b2, m, test_batchmm_4d_broadcast_lhs("MKL", dace.float64) test_batchmm_4d_broadcast_lhs("cuBLAS", dace.float32) test_batchmm_4d_broadcast_lhs("cuBLAS", dace.float64) + test_batchmm_3d_4d_broadcast("pure", dace.float32) + test_batchmm_3d_4d_broadcast("pure", dace.float64) + test_batchmm_3d_4d_broadcast("MKL", dace.float32) + test_batchmm_3d_4d_broadcast("MKL", dace.float64) + test_batchmm_3d_4d_broadcast("cuBLAS", dace.float32) + test_batchmm_3d_4d_broadcast("cuBLAS", dace.float64) + test_batchmm_4d_3d_broadcast("pure", dace.float32) + test_batchmm_4d_3d_broadcast("pure", dace.float64) + test_batchmm_4d_3d_broadcast("MKL", dace.float32) + test_batchmm_4d_3d_broadcast("MKL", dace.float64) + test_batchmm_4d_3d_broadcast("cuBLAS", dace.float32) + test_batchmm_4d_3d_broadcast("cuBLAS", dace.float64) + test_batchmm_5d_3d_broadcast("pure", dace.float32) + test_batchmm_5d_3d_broadcast("pure", dace.float64) + test_batchmm_5d_3d_broadcast("MKL", dace.float32) + test_batchmm_5d_3d_broadcast("MKL", dace.float64) + test_batchmm_5d_3d_broadcast("cuBLAS", dace.float32) + test_batchmm_5d_3d_broadcast("cuBLAS", dace.float64) From 3a88e777e2ec5518f8ec61419735582687698215 Mon Sep 17 00:00:00 2001 From: Affifboudaoud Date: Wed, 22 Oct 2025 00:01:12 +0200 Subject: [PATCH 2/7] Add support for 1D batched-GEMV broadcasting --- dace/frontend/python/replacements/linalg.py | 28 ++ dace/libraries/blas/blas_helpers.py | 20 +- dace/libraries/blas/nodes/batched_matmul.py | 377 ++++++++++++++++++-- dace/libraries/blas/nodes/matmul.py | 31 +- tests/library/batched_matmul_test.py | 99 +++++ 5 files changed, 511 insertions(+), 44 deletions(-) diff --git a/dace/frontend/python/replacements/linalg.py b/dace/frontend/python/replacements/linalg.py index 
3c10d95d23..12e49a6e4e 100644 --- a/dace/frontend/python/replacements/linalg.py +++ b/dace/frontend/python/replacements/linalg.py @@ -86,6 +86,34 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op output_shape = (1, ) + elif len(arr1.shape) == 1 and len(arr2.shape) > 2: # vector @ batched matrix (e.g., [k] @ [b, k, n]) + + res = symbolic.equal(arr1.shape[0], arr2.shape[-2]) + if res is None: + warnings.warn( + f'Length of vector {arr1.shape[0]} and second-last dimension of tensor {arr2.shape[-2]} ' + f'may not match', UserWarning) + elif not res: + raise SyntaxError(f"Length of vector {arr1.shape[0]} must match " + f"second-last dimension of tensor {arr2.shape[-2]}") + + # Output has all batch dimensions plus the last dimension of arr2 + output_shape = arr2.shape[:-2] + (arr2.shape[-1], ) + + elif len(arr1.shape) > 2 and len(arr2.shape) == 1: # batched matrix @ vector (e.g., [b, m, k] @ [k]) + + res = symbolic.equal(arr1.shape[-1], arr2.shape[0]) + if res is None: + warnings.warn( + f'Last dimension of tensor {arr1.shape[-1]} and length of vector {arr2.shape[0]} ' + f'may not match', UserWarning) + elif not res: + raise SyntaxError(f"Last dimension of tensor {arr1.shape[-1]} must match " + f"length of vector {arr2.shape[0]}") + + # Output has all batch dimensions plus the second-last dimension of arr1 + output_shape = arr1.shape[:-1] + else: # Dunno what this is, bail raise SyntaxError("Cannot multiply arrays with shapes: {} and {}".format(arr1.shape, arr2.shape)) diff --git a/dace/libraries/blas/blas_helpers.py b/dace/libraries/blas/blas_helpers.py index 6a568f6e4a..b2d8766530 100644 --- a/dace/libraries/blas/blas_helpers.py +++ b/dace/libraries/blas/blas_helpers.py @@ -110,9 +110,23 @@ def get_gemm_opts(a_strides, b_strides, c_strides) -> Dict[str, Any]: # | | | # use these 3 to detect correct option - sAM, sAK = a_strides[-2:] - sBK, sBN = b_strides[-2:] - sCM, sCN = c_strides[-2:] + # Handle 1D inputs by treating them as column/row vectors + # [k] -> treat as [k, 1] with stride [1, k] for column vector + if len(a_strides) == 1: + sAM, sAK = a_strides[0], 1 # Treat as column vector [k, 1] + else: + sAM, sAK = a_strides[-2:] + + # Treat as row vector [1, k] -> transposed to [k, 1] + if len(b_strides) == 1: + sBK, sBN = 1, b_strides[0] + else: + sBK, sBN = b_strides[-2:] + + if len(c_strides) == 1: + sCM, sCN = c_strides[0], 1 + else: + sCM, sCN = c_strides[-2:] opts = { 'mkm': { diff --git a/dace/libraries/blas/nodes/batched_matmul.py b/dace/libraries/blas/nodes/batched_matmul.py index 8a29994b91..3e51a70501 100644 --- a/dace/libraries/blas/nodes/batched_matmul.py +++ b/dace/libraries/blas/nodes/batched_matmul.py @@ -1,5 +1,6 @@ # Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. 
from copy import deepcopy as dc +from math import prod from dace import dtypes, memlet as mm, properties, data as dt from dace.symbolic import symstr, equal import dace.library @@ -26,20 +27,34 @@ def make_sdfg(node, parent_state, parent_sdfg): cdesc = parent_sdfg.arrays[outedge.data.data] bopt = _get_batchmm_opts(shape_a, strides_a, shape_b, strides_b, cdesc.shape, cdesc.strides) - res = equal(shape_a[-1], shape_b[-2]) + # Handle 1D inputs - determine dimensions + is_a_1d = len(shape_a) == 1 + is_b_1d = len(shape_b) == 1 + + if is_a_1d: + # [k] treated as row vector for matmul + m_dim = 1 + k_dim = shape_a[0] + else: + m_dim = shape_a[-2] + k_dim = shape_a[-1] + + if is_b_1d: + # [k] treated as column vector for matmul + k_dim_b = shape_b[0] + n_dim = 1 + else: + k_dim_b = shape_b[-2] + n_dim = shape_b[-1] + + res = equal(k_dim, k_dim_b) if res is None: - warnings.warn(f"First matrix columns {shape_a[-1]} may not match second matrix rows {shape_b[-2]}", - UserWarning) + warnings.warn(f"K-dimensions {k_dim} may not match {k_dim_b}", UserWarning) elif not res: - raise SyntaxError("Matrix sizes must match") + raise SyntaxError(f"K-dimensions must match: {k_dim} vs {k_dim_b}") - # Determine output shape based on batch options - if bopt: - # Use batch dimensions from bopt (may be multi-dimensional) - batch_dims = bopt.get('batch_dims', [bopt['b']]) - shape_c = tuple(batch_dims) + (shape_a[-2], shape_b[-1]) - else: - shape_c = (shape_a[-2], shape_b[-1]) + # Determine output shape - the actual output shape from cdesc + shape_c = cdesc.shape dtype_a = outer_array_a.dtype.type dtype_b = outer_array_b.dtype.type @@ -69,23 +84,41 @@ def make_sdfg(node, parent_state, parent_sdfg): state = sdfg.add_state_after(init_state, node.label + "_state") # Calculate number of batch dimensions in output - num_batch_dims = len(shape_c) - 2 + # For 1D cases, output may have fewer dimensions + # e.g., [3, 32, 64] @ [64] = [3, 32] + if is_a_1d and is_b_1d: + # [k] @ [k] = scalar, this shouldn't happen in batched context + num_batch_dims = len(shape_c) + elif is_a_1d: + # [k] @ [batch..., k, n] = [batch..., n] + num_batch_dims = len(shape_c) - 1 # All dims except N + elif is_b_1d: + # [batch..., m, k] @ [k] = [batch..., m] + num_batch_dims = len(shape_c) - 1 # All dims except M + else: + # Regular case: [batch..., m, k] @ [batch..., k, n] = [batch..., m, n] + num_batch_dims = len(shape_c) - 2 # Build map parameters: batch dimensions + M, N, K map_params = {} for i in range(num_batch_dims): map_params['__i%d' % i] = '0:%s' % symstr(shape_c[i]) - # M, N, K dimensions - map_params['__im'] = '0:%s' % symstr(shape_a[-2]) - map_params['__in'] = '0:%s' % symstr(shape_b[-1]) - map_params['__ik'] = '0:%s' % symstr(shape_a[-1]) + # M, N, K dimensions - always create map parameters + map_params['__im'] = '0:%s' % symstr(m_dim) + map_params['__in'] = '0:%s' % symstr(n_dim) + map_params['__ik'] = '0:%s' % symstr(k_dim) # Build memlet access patterns - # For A: if 2D, use [M, K]; if 3D+, use [batch_indices..., M, K] - if len(array_a.shape) == 2: + # Handle 1D inputs specially - they only have __ik dimension + if is_a_1d: + # [k] input - just use __ik + memlet_a = '__ik' + elif len(array_a.shape) == 2: + # 2D input [M, K] memlet_a = '__im, __ik' else: + # 3D+ input [batch..., M, K] # Align input batch dims to output batch dims num_a_batch = len(array_a.shape) - 2 # Start from the rightmost batch dimension of output and work backwards @@ -93,10 +126,15 @@ def make_sdfg(node, parent_state, parent_sdfg): a_batch_indices = ', 
'.join(['__i%d' % (offset + i) for i in range(num_a_batch)]) memlet_a = f'{a_batch_indices}, __im, __ik' - # For B: if 2D, use [K, N]; if 3D+, use [batch_indices..., K, N] - if len(array_b.shape) == 2: + # For B: if 1D, use [K]; if 2D, use [K, N]; if 3D+, use [batch_indices..., K, N] + if is_b_1d: + # [k] input - just use __ik + memlet_b = '__ik' + elif len(array_b.shape) == 2: + # 2D input [K, N] memlet_b = '__ik, __in' else: + # 3D+ input [batch..., K, N] # Align input batch dims to output batch dims num_b_batch = len(array_b.shape) - 2 # Start from the rightmost batch dimension of output and work backwards @@ -104,8 +142,31 @@ def make_sdfg(node, parent_state, parent_sdfg): b_batch_indices = ', '.join(['__i%d' % (offset + i) for i in range(num_b_batch)]) memlet_b = f'{b_batch_indices}, __ik, __in' - # For C: always has batch dimensions - c_indices = ', '.join(['__i%d' % i for i in range(num_batch_dims)]) + ', __im, __in' + # For C: build indices matching the output shape + c_indices_parts = [] + if is_a_1d and is_b_1d: + # Scalar output - all batch dims + for i in range(num_batch_dims): + c_indices_parts.append(f'__i{i}') + elif is_a_1d: + # [k] @ [batch..., k, n] = [batch..., n] + # Output has batch dims + n + for i in range(num_batch_dims): + c_indices_parts.append(f'__i{i}') + c_indices_parts.append('__in') + elif is_b_1d: + # [batch..., m, k] @ [k] = [batch..., m] + # Output has batch dims + m + for i in range(num_batch_dims): + c_indices_parts.append(f'__i{i}') + c_indices_parts.append('__im') + else: + # Regular: [batch..., m, k] @ [batch..., k, n] = [batch..., m, n] + for i in range(num_batch_dims): + c_indices_parts.append(f'__i{i}') + c_indices_parts.append('__im') + c_indices_parts.append('__in') + c_indices = ', '.join(c_indices_parts) state.add_mapped_tasklet('_BatchedMatMult_', map_params, { @@ -129,6 +190,115 @@ class ExpandBatchedMatMulMKL(ExpandTransformation): environments = [environments.intel_mkl.IntelMKL] + @staticmethod + def _expand_gemv_loop(node, state, sdfg, adesc, bdesc, cdesc, ashape, bshape, astrides, bstrides, dtype, is_a_1d, + is_b_1d): + """Expand batched matrix-vector or vector-matrix multiplication using GEMV loops.""" + from dace.codegen.common import sym2cpp + + prefix = to_blastype(dtype.type).lower() + if dtype == dace.float32: + alpha = "1.0f" + beta = "0.0f" + elif dtype == dace.float64: + alpha = "1.0" + beta = "0.0" + elif dtype == dace.complex64: + alpha = "dace::blas::BlasConstants::Get().Complex64Pone()" + beta = "dace::blas::BlasConstants::Get().Complex64Zero()" + elif dtype == dace.complex128: + alpha = "dace::blas::BlasConstants::Get().Complex128Pone()" + beta = "dace::blas::BlasConstants::Get().Complex128Zero()" + else: + raise ValueError("Unsupported type for BLAS: " + str(dtype)) + + # Determine batch size and strides + cshape = cdesc.shape + + if is_a_1d and is_b_1d: + # Both 1D - shouldn't happen in batched context + raise ValueError("Both inputs are 1D - use dot product instead") + elif is_a_1d: + # [k] @ [batch..., k, n] = [batch..., n] + batch_size = prod(cshape[:-1]) + k = ashape[0] + n = bshape[-1] + + # Detect storage order from B's strides + # B has shape [batch..., k, n] with strides [..., stride_k, stride_n] + # Row-major: stride_k > stride_n (elements in same row are contiguous) + # Column-major: stride_k < stride_n (elements in same column are contiguous) + stride_k = bstrides[-2] + stride_n = bstrides[-1] + + if stride_n == 1: # Row-major: rightmost dimension has stride 1 + layout = 'CblasRowMajor' + trans = 'CblasTrans' # 
For a @ B[i], compute B[i]^T @ a + ldb = n # Leading dimension in row-major + stride_b = k * n # Stride between batches + else: # Column-major: leftmost matrix dimension has stride 1 + layout = 'CblasColMajor' + trans = 'CblasTrans' + ldb = k + stride_b = k * n + + stride_c = n # Output stride + + elif is_b_1d: + # [batch..., m, k] @ [k] = [batch..., m] + batch_size = prod(cshape[:-1]) + m = ashape[-2] + k = ashape[-1] + + # Detect storage order from A's strides + stride_k = astrides[-1] + + if stride_k == 1: # Row-major + layout = 'CblasRowMajor' + trans = 'CblasNoTrans' # For A[i] @ b, no transpose needed + lda = k # Leading dimension in row-major + stride_a = m * k + else: # Column-major + layout = 'CblasColMajor' + trans = 'CblasNoTrans' + lda = m + stride_a = m * k + + stride_c = m # Output stride + else: + raise ValueError("Unexpected case - neither input is 1D") + + # Generate code + if is_a_1d: + # [k] @ [batch..., k, n]: loop over batch, each time: c[i] = B[i]^T @ a + code = f''' + for (int __ib = 0; __ib < {sym2cpp(batch_size)}; ++__ib) {{ + cblas_{prefix}gemv({layout}, {trans}, {sym2cpp(k)}, {sym2cpp(n)}, + {alpha}, + (({dtype.ctype}*)_b) + __ib*{sym2cpp(stride_b)}, {sym2cpp(ldb)}, + ({dtype.ctype}*)_a, 1, + {beta}, + (({dtype.ctype}*)_c) + __ib*{sym2cpp(stride_c)}, 1); + }}''' + else: # is_b_1d + # [batch..., m, k] @ [k]: loop over batch, each time: c[i] = A[i] @ b + code = f''' + for (int __ib = 0; __ib < {sym2cpp(batch_size)}; ++__ib) {{ + cblas_{prefix}gemv({layout}, {trans}, {sym2cpp(m)}, {sym2cpp(k)}, + {alpha}, + (({dtype.ctype}*)_a) + __ib*{sym2cpp(stride_a)}, {sym2cpp(lda)}, + ({dtype.ctype}*)_b, 1, + {beta}, + (({dtype.ctype}*)_c) + __ib*{sym2cpp(stride_c)}, 1); + }}''' + + tasklet = dace.sdfg.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP) + return tasklet + @staticmethod def expansion(node, state, sdfg): node.validate(sdfg, state) @@ -137,6 +307,16 @@ def expansion(node, state, sdfg): cdesc: dt.Array = sdfg.arrays[state.out_edges(node)[0].data.data] check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc) dtype = cdesc.dtype.base_type + + # Check if we have 1D inputs (vector operations) + is_a_1d = len(ashape) == 1 + is_b_1d = len(bshape) == 1 + + # For 1D cases, use GEMV instead of batched GEMM + if is_a_1d or is_b_1d: + return ExpandBatchedMatMulMKL._expand_gemv_loop(node, state, sdfg, adesc, bdesc, cdesc, ashape, bshape, + astrides, bstrides, dtype, is_a_1d, is_b_1d) + func = to_blastype(dtype.type).lower() + 'gemm' if dtype == dace.float32: alpha = "1.0f" @@ -206,6 +386,115 @@ def expansion(node, state, sdfg): class ExpandBatchedMatMulOpenBLAS(ExpandTransformation): environments = [environments.openblas.OpenBLAS] + @staticmethod + def _expand_gemv_loop(node, state, sdfg, adesc, bdesc, cdesc, ashape, bshape, astrides, bstrides, dtype, is_a_1d, + is_b_1d): + """Expand batched matrix-vector or vector-matrix multiplication using GEMV loops.""" + from dace.codegen.common import sym2cpp + + prefix = to_blastype(dtype.type).lower() + if dtype == dace.float32: + alpha = "1.0f" + beta = "0.0f" + elif dtype == dace.float64: + alpha = "1.0" + beta = "0.0" + elif dtype == dace.complex64: + alpha = "dace::blas::BlasConstants::Get().Complex64Pone()" + beta = "dace::blas::BlasConstants::Get().Complex64Zero()" + elif dtype == dace.complex128: + alpha = "dace::blas::BlasConstants::Get().Complex128Pone()" + beta = "dace::blas::BlasConstants::Get().Complex128Zero()" + else: + raise 
ValueError("Unsupported type for BLAS: " + str(dtype)) + + # Determine batch size and strides + cshape = cdesc.shape + + if is_a_1d and is_b_1d: + # Both 1D - shouldn't happen in batched context + raise ValueError("Both inputs are 1D - use dot product instead") + elif is_a_1d: + # [k] @ [batch..., k, n] = [batch..., n] + batch_size = prod(cshape[:-1]) + k = ashape[0] + n = bshape[-1] + + # Detect storage order from B's strides + # B has shape [batch..., k, n] with strides [..., stride_k, stride_n] + # Row-major: stride_k > stride_n (elements in same row are contiguous) + # Column-major: stride_k < stride_n (elements in same column are contiguous) + stride_k = bstrides[-2] + stride_n = bstrides[-1] + + if stride_n == 1: # Row-major: rightmost dimension has stride 1 + layout = 'CblasRowMajor' + trans = 'CblasTrans' # For a @ B[i], compute B[i]^T @ a + ldb = n # Leading dimension in row-major + stride_b = k * n # Stride between batches + else: # Column-major: leftmost matrix dimension has stride 1 + layout = 'CblasColMajor' + trans = 'CblasTrans' + ldb = k + stride_b = k * n + + stride_c = n # Output stride + + elif is_b_1d: + # [batch..., m, k] @ [k] = [batch..., m] + batch_size = prod(cshape[:-1]) + m = ashape[-2] + k = ashape[-1] + + # Detect storage order from A's strides + stride_k = astrides[-1] + + if stride_k == 1: # Row-major + layout = 'CblasRowMajor' + trans = 'CblasNoTrans' # For A[i] @ b, no transpose needed + lda = k # Leading dimension in row-major + stride_a = m * k + else: # Column-major + layout = 'CblasColMajor' + trans = 'CblasNoTrans' + lda = m + stride_a = m * k + + stride_c = m # Output stride + else: + raise ValueError("Unexpected case - neither input is 1D") + + # Generate code + if is_a_1d: + # [k] @ [batch..., k, n]: loop over batch, each time: c[i] = B[i]^T @ a + code = f''' + for (int __ib = 0; __ib < {sym2cpp(batch_size)}; ++__ib) {{ + cblas_{prefix}gemv({layout}, {trans}, {sym2cpp(k)}, {sym2cpp(n)}, + {alpha}, + (({dtype.ctype}*)_b) + __ib*{sym2cpp(stride_b)}, {sym2cpp(ldb)}, + ({dtype.ctype}*)_a, 1, + {beta}, + (({dtype.ctype}*)_c) + __ib*{sym2cpp(stride_c)}, 1); + }}''' + else: + # [batch..., m, k] @ [k]: loop over batch, each time: c[i] = A[i] @ b + code = f''' + for (int __ib = 0; __ib < {sym2cpp(batch_size)}; ++__ib) {{ + cblas_{prefix}gemv({layout}, {trans}, {sym2cpp(m)}, {sym2cpp(k)}, + {alpha}, + (({dtype.ctype}*)_a) + __ib*{sym2cpp(stride_a)}, {sym2cpp(lda)}, + ({dtype.ctype}*)_b, 1, + {beta}, + (({dtype.ctype}*)_c) + __ib*{sym2cpp(stride_c)}, 1); + }}''' + + tasklet = dace.sdfg.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP) + return tasklet + @staticmethod def expansion(node, state, sdfg): node.validate(sdfg, state) @@ -214,6 +503,16 @@ def expansion(node, state, sdfg): cdesc = sdfg.arrays[state.out_edges(node)[0].data.data] check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc) dtype = cdesc.dtype.base_type + + # Check if we have 1D inputs (vector operations) + is_a_1d = len(ashape) == 1 + is_b_1d = len(bshape) == 1 + + # For 1D cases, use GEMV instead of batched GEMM + if is_a_1d or is_b_1d: + return ExpandBatchedMatMulOpenBLAS._expand_gemv_loop(node, state, sdfg, adesc, bdesc, cdesc, ashape, bshape, + astrides, bstrides, dtype, is_a_1d, is_b_1d) + func = to_blastype(dtype.type).lower() + 'gemm' if dtype == dace.float32: alpha = "1.0f" @@ -535,30 +834,32 @@ def validate(self, sdfg, state): "batched matrix-matrix product") out_memlet = out_edges[0].data - # Both inputs 
must be at least 2D
-        if len(size0) < 2:
-            raise ValueError(f"First input must be at least 2D, got shape with {len(size0)} dimensions")
-        if len(size1) < 2:
-            raise ValueError(f"Second input must be at least 2D, got shape with {len(size1)} dimensions")
+        # Valid cases: 1D@ND (N>=3), ND@1D (N>=3), ND@MD (N or M >=3)
+        if len(size0) < 1 or len(size1) < 1:
+            raise ValueError("Inputs must be at least 1D")

-        # At least one input must have batch dimensions (3D or higher) for batched operation
-        if len(size0) <= 2 and len(size1) <= 2:
+        # For batched operations, we need at least one operand to be 3D+ or one to be 1D
+        has_1d = (len(size0) == 1 or len(size1) == 1)
+        has_batch = (len(size0) >= 3 or len(size1) >= 3)
+
+        if not has_1d and not has_batch:
+            # Both operands are plain 2D matrices - that is a regular GEMM, not a batched product
             raise ValueError(
-                "Batched matrix-matrix product requires at least one input to have batch dimensions (3D or higher)")
+                "Batched operation requires at least one input to be 1D or have batch dimensions (3D or higher)")

         # Validate K-dimension compatibility
-        res = equal(size0[-1], size1[-2])
+        # For 1D inputs, the single dimension is the k-dimension
+        # For 2D+ inputs with matrix structure [..., M, K] or [..., K, N]
+        k_dim_a = size0[0] if len(size0) == 1 else size0[-1]
+        k_dim_b = size1[0] if len(size1) == 1 else size1[-2]
+
+        res = equal(k_dim_a, k_dim_b)
         if res is None:
             warnings.warn(
-                f'First tensor\'s last mode {size0[-1]} and second tensor\'s second-last mode {size1[-2]} '
+                f'K-dimension of first operand {k_dim_a} and k-dimension of second operand {k_dim_b} '
                 f'may not match', UserWarning)
         elif not res:
-            raise ValueError("Inputs to matrix-matrix product must agree in the k-dimension")
-
-        # Output must have batch dimensions
-        if len(out_memlet.subset) < 3:
-            raise ValueError(
-                f"Batched matrix-matrix product output must be at least 3D, got {len(out_memlet.subset)} dimensions")
+            raise ValueError(f"Inputs must agree in the k-dimension: {k_dim_a} vs {k_dim_b}")


 # Numpy replacement
diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py
index 471e8deb0e..a8824539e3 100644
--- a/dace/libraries/blas/nodes/matmul.py
+++ b/dace/libraries/blas/nodes/matmul.py
@@ -187,9 +187,25 @@ def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta,
     opt['xdtype'] = adesc.dtype
     opt['ydtype'] = bdesc.dtype
     opt['cdtype'] = cdesc.dtype
-    opt['M'] = sym2cpp(ashape[-2])
-    opt['N'] = sym2cpp(bshape[-1])
-    opt['K'] = sym2cpp(ashape[-1])
+
+    # Handle 1D inputs: [k] is treated as [1, k] or [k, 1] depending on context
+    if len(ashape) == 1:
+        # [k] @ [..., k, n] -> M=1, K=ashape[0]
+        opt['M'] = sym2cpp(1)
+        opt['K'] = sym2cpp(ashape[0])
+    else:
+        opt['M'] = sym2cpp(ashape[-2])
+        opt['K'] = sym2cpp(ashape[-1])
+
+    if len(bshape) == 1:
+        # [..., m, k] @ [k] -> N=1 (K has already been set from A above)
+        opt['N'] = sym2cpp(1)
+    else:
+        opt['N'] = sym2cpp(bshape[-1])
     opt['lda'] = sym2cpp(opt['lda'])
     opt['ldb'] = sym2cpp(opt['ldb'])
     opt['ldc'] = sym2cpp(opt['ldc'])
@@ -283,6 +299,15 @@ def expansion(node, state, sdfg):
             b[0].dst_conn = "_y"
             c[0].src_conn = "_result"
             result = Dot(node.name + 'dot', location=node.location)
+        elif len(size_a) == 1 and len(size_b) > 2:
+            # Vector @ batched matrix -> batched GEMV (e.g., [k] @ [b, k, n] = [b, n])
+            # This can be viewed as batched y = A^T @ x where x is the vector
+            from 
dace.libraries.blas.nodes.batched_matmul import BatchedMatMul + result = BatchedMatMul(node.name + 'bmm', location=node.location) + elif len(size_a) > 2 and len(size_b) == 1: + # Batched matrix @ vector -> batched GEMV (e.g., [b, m, k] @ [k] = [b, m]) + from dace.libraries.blas.nodes.batched_matmul import BatchedMatMul + result = BatchedMatMul(node.name + 'bmm', location=node.location) else: raise NotImplementedError("Matrix multiplication not implemented " "for shapes: {} and {}".format(size_a, size_b)) diff --git a/tests/library/batched_matmul_test.py b/tests/library/batched_matmul_test.py index 1504724381..890936f41a 100644 --- a/tests/library/batched_matmul_test.py +++ b/tests/library/batched_matmul_test.py @@ -329,6 +329,105 @@ def bmm_5d_3d_broadcast(A: dtype[b1, b2, b3, m, k], B: dtype[b3, k, n], C: dtype assert np.allclose(ref, z) +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("pure", dace.float32), + pytest.param("pure", dace.float64), + pytest.param("MKL", dace.float32, marks=pytest.mark.mkl), + pytest.param("MKL", dace.float64, marks=pytest.mark.mkl), + pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack), + pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack) +]) +def test_batchmm_1d_3d_broadcast(implementation: str, dtype): + """Test 1D-3D batched matmul with broadcast: [k] @ [b, k, n]""" + b, n, k = 3, 32, 64 + + @dace.program + def bmm_1d_3d_broadcast(A: dtype[k], B: dtype[b, k, n], C: dtype[b, n]): + C[:] = A @ B + + with change_default(blas, implementation): + sdfg = bmm_1d_3d_broadcast.to_sdfg() + sdfg.simplify() + sdfg.expand_library_nodes() + + x = np.random.rand(k).astype(dtype.as_numpy_dtype()) + y = np.random.rand(b, k, n).astype(dtype.as_numpy_dtype()) + z = np.zeros([b, n]).astype(dtype.as_numpy_dtype()) + + csdfg = sdfg.compile() + csdfg(A=x, B=y, C=z) + + ref = x @ y + + assert np.allclose(ref, z) + + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("pure", dace.float32), + pytest.param("pure", dace.float64), + pytest.param("MKL", dace.float32, marks=pytest.mark.mkl), + pytest.param("MKL", dace.float64, marks=pytest.mark.mkl), + pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack), + pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack) +]) +def test_batchmm_3d_1d_broadcast(implementation: str, dtype): + """Test 3D-1D batched matmul with broadcast: [b, m, k] @ [k]""" + b, m, k = 3, 32, 64 + + @dace.program + def bmm_3d_1d_broadcast(A: dtype[b, m, k], B: dtype[k], C: dtype[b, m]): + C[:] = A @ B + + with change_default(blas, implementation): + sdfg = bmm_3d_1d_broadcast.to_sdfg() + sdfg.simplify() + sdfg.expand_library_nodes() + + x = np.random.rand(b, m, k).astype(dtype.as_numpy_dtype()) + y = np.random.rand(k).astype(dtype.as_numpy_dtype()) + z = np.zeros([b, m]).astype(dtype.as_numpy_dtype()) + + csdfg = sdfg.compile() + csdfg(A=x, B=y, C=z) + + ref = x @ y + + assert np.allclose(ref, z) + + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("pure", dace.float32), + pytest.param("pure", dace.float64), + pytest.param("MKL", dace.float32, marks=pytest.mark.mkl), + pytest.param("MKL", dace.float64, marks=pytest.mark.mkl), + pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack), + pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack) +]) +def test_batchmm_4d_1d_broadcast(implementation: str, dtype): + """Test 4D-1D batched matmul with broadcast: [b1, b2, m, k] @ [k]""" + b1, b2, m, k = 4, 2, 32, 64 + + @dace.program + def 
bmm_4d_1d_broadcast(A: dtype[b1, b2, m, k], B: dtype[k], C: dtype[b1, b2, m]): + C[:] = A @ B + + with change_default(blas, implementation): + sdfg = bmm_4d_1d_broadcast.to_sdfg() + sdfg.simplify() + sdfg.expand_library_nodes() + + x = np.random.rand(b1, b2, m, k).astype(dtype.as_numpy_dtype()) + y = np.random.rand(k).astype(dtype.as_numpy_dtype()) + z = np.zeros([b1, b2, m]).astype(dtype.as_numpy_dtype()) + + csdfg = sdfg.compile() + csdfg(A=x, B=y, C=z) + + ref = x @ y + + assert np.allclose(ref, z) + + if __name__ == "__main__": test_batchmm("pure", dace.float32) test_batchmm("pure", dace.float64) From bf975026c479e7cfabdbe7db98e3e9434ebcd2a2 Mon Sep 17 00:00:00 2001 From: Affifboudaoud Date: Wed, 22 Oct 2025 14:54:20 +0200 Subject: [PATCH 3/7] Fix and test MatMul accumulation --- dace/libraries/blas/nodes/batched_matmul.py | 207 +++++++++++----- dace/libraries/blas/nodes/gemm.py | 21 +- dace/libraries/blas/nodes/gemv.py | 25 +- tests/library/test_matmul_accumulate.py | 260 ++++++++++++++++++++ 4 files changed, 455 insertions(+), 58 deletions(-) create mode 100644 tests/library/test_matmul_accumulate.py diff --git a/dace/libraries/blas/nodes/batched_matmul.py b/dace/libraries/blas/nodes/batched_matmul.py index 3e51a70501..73a753d8b9 100644 --- a/dace/libraries/blas/nodes/batched_matmul.py +++ b/dace/libraries/blas/nodes/batched_matmul.py @@ -71,17 +71,35 @@ def make_sdfg(node, parent_state, parent_sdfg): _, array_b = sdfg.add_array("_b", shape_b, dtype_b, strides=strides_b, storage=storage) _, array_c = sdfg.add_array("_c", shape_c, dtype_c, strides=cdata[-3], storage=storage) - # Add an initialization state - init_state = sdfg.add_state() - init_state.add_mapped_tasklet( - 'batched_matmul_init', { - '_o%d' % i: '0:%s' % symstr(d) - for i, d in enumerate(shape_c) - }, {}, - 'out = 0', {'out': dace.Memlet.simple('_c', ','.join(['_o%d' % i for i in range(len(shape_c))]))}, - external_edges=True) - - state = sdfg.add_state_after(init_state, node.label + "_state") + # Handle beta factor for C + # C_new = alpha * A @ B + beta * C_old + if node.beta == 0: + # Initialize C to 0 + init_state = sdfg.add_state() + init_state.add_mapped_tasklet( + 'batched_matmul_init', { + '_o%d' % i: '0:%s' % symstr(d) + for i, d in enumerate(shape_c) + }, {}, + 'out = 0', {'out': dace.Memlet.simple('_c', ','.join(['_o%d' % i for i in range(len(shape_c))]))}, + external_edges=True) + state = sdfg.add_state_after(init_state, node.label + "_state") + elif node.beta != 1: + # Scale C by beta before accumulation + init_state = sdfg.add_state() + beta_value = node.beta + init_state.add_mapped_tasklet( + 'batched_matmul_scale_c', { + '_o%d' % i: '0:%s' % symstr(d) + for i, d in enumerate(shape_c) + }, {'_in': dace.Memlet.simple('_c', ','.join(['_o%d' % i for i in range(len(shape_c))]))}, + f'_out = {beta_value} * _in', + {'_out': dace.Memlet.simple('_c', ','.join(['_o%d' % i for i in range(len(shape_c))]))}, + external_edges=True) + state = sdfg.add_state_after(init_state, node.label + "_state") + else: + # beta == 1: Just accumulate into existing C values + state = sdfg.add_state(node.label + "_state") # Calculate number of batch dimensions in output # For 1D cases, output may have fewer dimensions @@ -168,12 +186,19 @@ def make_sdfg(node, parent_state, parent_sdfg): c_indices_parts.append('__in') c_indices = ', '.join(c_indices_parts) + # Handle alpha factor in the multiplication + alpha_value = node.alpha + if alpha_value == 1: + tasklet_code = '__c = __a * __b' + else: + tasklet_code = f'__c = {alpha_value} * __a 
* __b'
+
     state.add_mapped_tasklet('_BatchedMatMult_',
                              map_params, {
                                  '__a': dace.Memlet.simple("_a", memlet_a),
                                  '__b': dace.Memlet.simple("_b", memlet_b)
                              },
-                             '__c = __a * __b',
+                             tasklet_code,
                              {'__c': dace.Memlet.simple("_c", c_indices, wcr_str='lambda x, y: x + y')},
                              external_edges=True)
@@ -197,18 +222,31 @@ def _expand_gemv_loop(node, state, sdfg, adesc, bdesc, cdesc, ashape, bshape, as
         from dace.codegen.common import sym2cpp

         prefix = to_blastype(dtype.type).lower()
+        # Use node's alpha and beta values
         if dtype == dace.float32:
-            alpha = "1.0f"
-            beta = "0.0f"
+            alpha = f"{float(node.alpha)}f"
+            beta = f"{float(node.beta)}f"
         elif dtype == dace.float64:
-            alpha = "1.0"
-            beta = "0.0"
+            alpha = f"{float(node.alpha)}"
+            beta = f"{float(node.beta)}"
         elif dtype == dace.complex64:
+            # CBLAS takes complex scalars by pointer and only the stock 1/0 constants
+            # are available here, so reject other complex factors loudly instead of
+            # generating invalid code (cuComplex helpers are CUDA-only).
+            if node.alpha != 1 or node.beta != 0:
+                raise NotImplementedError("Non-default complex alpha/beta is not supported in this expansion")
             alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
             beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
         elif dtype == dace.complex128:
+            if node.alpha != 1 or node.beta != 0:
+                raise NotImplementedError("Non-default complex alpha/beta is not supported in this expansion")
             alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
             beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
         else:
             raise ValueError("Unsupported type for BLAS: " + str(dtype))
@@ -318,21 +356,35 @@ def expansion(node, state, sdfg):
                                                             astrides, bstrides, dtype, is_a_1d, is_b_1d)

         func = to_blastype(dtype.type).lower() + 'gemm'
+
+        # Use node's alpha and beta values
         if dtype == dace.float32:
-            alpha = "1.0f"
-            beta = "0.0f"
+            alpha = f"{float(node.alpha)}f"
+            beta = f"{float(node.beta)}f"
             prefix = "s"
         elif dtype == dace.float64:
-            alpha = "1.0"
-            beta = "0.0"
+            alpha = f"{float(node.alpha)}"
+            beta = f"{float(node.beta)}"
             prefix = "d"
         elif dtype == dace.complex64:
+            # CBLAS takes complex scalars by pointer and only the stock 1/0 constants
+            # are available here, so reject other complex factors loudly.
+            if node.alpha != 1 or node.beta != 0:
+                raise NotImplementedError("Non-default complex alpha/beta is not supported in this expansion")
             alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
             beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
             prefix = "c"
         elif dtype == dace.complex128:
+            if node.alpha != 1 or node.beta != 0:
+                raise NotImplementedError("Non-default complex alpha/beta is not supported in this expansion")
             alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
             beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
             prefix = "z"
         else:
             raise ValueError("Unsupported type for BLAS dot product: " + str(dtype))
@@ -393,18 +445,31 @@ def _expand_gemv_loop(node, state, sdfg, adesc, bdesc, cdesc, ashape, bshape, as
         from dace.codegen.common import sym2cpp

         prefix = to_blastype(dtype.type).lower()
+        # Use node's alpha and beta values
         if dtype == dace.float32:
-            alpha = "1.0f"
-            beta = "0.0f"
+            alpha = f"{float(node.alpha)}f"
+            beta = f"{float(node.beta)}f"
         elif 
dtype == dace.float64:
-            alpha = "1.0"
-            beta = "0.0"
+            alpha = f"{float(node.alpha)}"
+            beta = f"{float(node.beta)}"
         elif dtype == dace.complex64:
+            # CBLAS takes complex scalars by pointer and only the stock 1/0 constants
+            # are available here, so reject other complex factors loudly.
+            if node.alpha != 1 or node.beta != 0:
+                raise NotImplementedError("Non-default complex alpha/beta is not supported in this expansion")
             alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
             beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
         elif dtype == dace.complex128:
+            if node.alpha != 1 or node.beta != 0:
+                raise NotImplementedError("Non-default complex alpha/beta is not supported in this expansion")
             alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
             beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
         else:
             raise ValueError("Unsupported type for BLAS: " + str(dtype))
@@ -514,18 +579,31 @@ def expansion(node, state, sdfg):
                                                                  astrides, bstrides, dtype, is_a_1d, is_b_1d)

         func = to_blastype(dtype.type).lower() + 'gemm'
+        # Use node's alpha and beta values
         if dtype == dace.float32:
-            alpha = "1.0f"
-            beta = "0.0f"
+            alpha = f"{float(node.alpha)}f"
+            beta = f"{float(node.beta)}f"
         elif dtype == dace.float64:
-            alpha = "1.0"
-            beta = "0.0"
+            alpha = f"{float(node.alpha)}"
+            beta = f"{float(node.beta)}"
         elif dtype == dace.complex64:
+            # CBLAS takes complex scalars by pointer and only the stock 1/0 constants
+            # are available here, so reject other complex factors loudly.
+            if node.alpha != 1 or node.beta != 0:
+                raise NotImplementedError("Non-default complex alpha/beta is not supported in this expansion")
             alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
             beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
         elif dtype == dace.complex128:
+            if node.alpha != 1 or node.beta != 0:
+                raise NotImplementedError("Non-default complex alpha/beta is not supported in this expansion")
             alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
             beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
         else:
             raise ValueError("Unsupported type for BLAS dot product: " + str(dtype))
         opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdesc.dtype.ctype, func)
@@ -612,26 +690,43 @@ def expansion(node, state, sdfg):
             1.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Pone()",
             0.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Zero()",
         }
+
+        # Handle alpha
         if node.alpha not in constants:
             # Deal with complex input constants
             if isinstance(node.alpha, complex):
-                alpha = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})'
+                alpha_val = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})'
             else:
-                alpha = f'{dtype.ctype}({node.alpha})'
+                alpha_val = f'{dtype.ctype}({node.alpha})'
+            use_host_mode_alpha = True
+        else:
+            alpha = constants[node.alpha]
+            use_host_mode_alpha = False

-            # Set pointer mode to host
-            call_prefix += f'''cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST);
-            {dtype.ctype} alpha = {alpha};
-            {dtype.ctype} beta = 0;
-            '''
+        # Handle beta
+        if node.beta not in constants:
+            # Deal with complex input constants
+            if 
isinstance(node.beta, complex):
+                beta_val = f'{dtype.ctype}({node.beta.real}, {node.beta.imag})'
+            else:
+                beta_val = f'{dtype.ctype}({node.beta})'
+            use_host_mode_beta = True
+        else:
+            beta = constants[node.beta]
+            use_host_mode_beta = False
+
+        # Set pointer mode to host if needed. The cuBLAS pointer mode applies to
+        # alpha and beta alike, so once either scalar has to live on the host,
+        # both must be passed as host pointers.
+        if use_host_mode_alpha or use_host_mode_beta:
+            if not use_host_mode_alpha:
+                alpha_val = f'{dtype.ctype}({node.alpha})'
+            if not use_host_mode_beta:
+                beta_val = f'{dtype.ctype}({node.beta})'
+            call_prefix += f'''cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST);
+            {dtype.ctype} alpha = {alpha_val};
+            {dtype.ctype} beta = {beta_val};
+            '''
+            alpha = f'({cdtype} *)&alpha'
+            beta = f'({cdtype} *)&beta'
             call_suffix += '''
             cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
             '''
-            beta = f'({cdtype} *)&beta'
-            alpha = f'({cdtype} *)&alpha'
-        else:
-            alpha = constants[node.alpha]
-            beta = "__state->cublas_handle.Constants(__dace_cuda_device).%sZero()" % factort

         # Set up options for code formatting
         opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func)
diff --git a/dace/libraries/blas/nodes/gemm.py b/dace/libraries/blas/nodes/gemm.py
index 003ab45bba..4aa5c7f52f 100644
--- a/dace/libraries/blas/nodes/gemm.py
+++ b/dace/libraries/blas/nodes/gemm.py
@@ -192,13 +192,32 @@ def expansion(node, state, sdfg):
             opt['alpha'] = '&__alpha'
             opt['beta'] = '&__beta'

+        # Handle the case when cin=True and beta != 0 (node has _c as both input and output)
+        # Since BLAS GEMM does an in-place read-modify-write on C, and tasklets cannot have
+        # duplicate connectors, we remove the input _c connector. The BLAS call will read
+        # from and write to the same memory location (the _c output).
+        #
+        # We also remove the incoming edge and orphaned access node to maintain graph validity.
+        in_connectors = {}
+        for k, v in node.in_connectors.items():
+            if k == '_c':
+                # Remove the incoming edge to _c and the source access node if it becomes isolated
+                for edge in list(state.in_edges_by_connector(node, '_c')):
+                    src_node = edge.src
+                    state.remove_edge(edge)
+                    # Remove the access node if it has no other edges
+                    if state.degree(src_node) == 0:
+                        state.remove_node(src_node)
+            else:
+                in_connectors[k] = v
+
         code += ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
                  "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
                  "_c, {ldc});").format_map(opt)

         tasklet = dace.sdfg.nodes.Tasklet(
             node.name,
-            node.in_connectors,
+            in_connectors,
             node.out_connectors,
             code,
             language=dace.dtypes.Language.CPP,
diff --git a/dace/libraries/blas/nodes/gemv.py b/dace/libraries/blas/nodes/gemv.py
index a791faba10..e3ad66642b 100644
--- a/dace/libraries/blas/nodes/gemv.py
+++ b/dace/libraries/blas/nodes/gemv.py
@@ -786,8 +786,31 @@ def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):

         code += f"""cblas_{func}({layout}, {trans}, {m}, {n}, {alpha}, _A, {lda},
                     _x, {strides_x[0]}, {beta}, _y, {strides_y[0]});"""

+        # Handle the case when beta != 0 (node has _y as both input and output)
+        # NOTE: This happens when the Gemv node is created with beta != 0 (see __init__ line 915).
+        # The pure implementation needs y as input to explicitly scale it, but BLAS implementations
+        # handle this internally.
+        #
+        # Since BLAS GEMV does an in-place read-modify-write on y, and tasklets cannot have
+        # duplicate connectors, we remove the input _y connector. The BLAS call will read
+        # from and write to the same memory location (the _y output).
+        #
+        # We also remove the incoming edge and orphaned access node to maintain graph validity.
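+        # For example, with alpha=1.5 and beta=2.0 the single cblas_?gemv call
+        # below computes y = 1.5*(A@x) + 2.0*y directly on the _y buffer, so no
+        # separate input copy of y is required.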
+ in_connectors = {} + for k, v in node.in_connectors.items(): + if k == '_y': + # Remove the incoming edge to _y and the source access node if it becomes isolated + for edge in list(state.in_edges_by_connector(node, '_y')): + src_node = edge.src + state.remove_edge(edge) + # Remove the access node if it has no other edges + if state.degree(src_node) == 0: + state.remove_node(src_node) + else: + in_connectors[k] = v + tasklet = dace.sdfg.nodes.Tasklet(node.name, - node.in_connectors, + in_connectors, node.out_connectors, code, language=dace.dtypes.Language.CPP) diff --git a/tests/library/test_matmul_accumulate.py b/tests/library/test_matmul_accumulate.py new file mode 100644 index 0000000000..27ebd75b93 --- /dev/null +++ b/tests/library/test_matmul_accumulate.py @@ -0,0 +1,260 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +"""Tests for matmul accumulation (beta factor) across all matmul implementations.""" +import pytest +import numpy as np +import dace +import dace.libraries.blas as blas +from dace.library import change_default + + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("pure", dace.float32), + pytest.param("pure", dace.float64), + pytest.param("MKL", dace.float32, marks=pytest.mark.mkl), + pytest.param("MKL", dace.float64, marks=pytest.mark.mkl), + pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu), + pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu), + pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack), + pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack) +]) +def test_batched_matmul_accumulate(implementation: str, dtype): + """Test batched matmul with non-zero beta (accumulation into existing output)""" + b, m, n, k = 3, 8, 8, 8 + + sdfg = dace.SDFG('batched_matmul_accumulate') + + # Add arrays + sdfg.add_array("A", [b, m, k], dtype) + sdfg.add_array("B", [b, k, n], dtype) + sdfg.add_array("C", [b, m, n], dtype) + + state = sdfg.add_state() + + a_in = state.add_read("A") + b_in = state.add_read("B") + c_out = state.add_write("C") + + bmm_node = blas.nodes.batched_matmul.BatchedMatMul('bmm', ) + bmm_node.alpha = 1.0 + bmm_node.beta = 2.0 + + state.add_node(bmm_node) + state.add_edge(a_in, None, bmm_node, '_a', dace.Memlet.from_array("A", sdfg.arrays["A"])) + state.add_edge(b_in, None, bmm_node, '_b', dace.Memlet.from_array("B", sdfg.arrays["B"])) + state.add_edge(bmm_node, '_c', c_out, None, dace.Memlet.from_array("C", sdfg.arrays["C"])) + + with change_default(blas, implementation): + sdfg.expand_library_nodes() + sdfg.validate() + + A = np.random.rand(b, m, k).astype(dtype.as_numpy_dtype()) + B = np.random.rand(b, k, n).astype(dtype.as_numpy_dtype()) + C_initial = np.random.rand(b, m, n).astype(dtype.as_numpy_dtype()) + C = C_initial.copy() + + csdfg = sdfg.compile() + csdfg(A=A, B=B, C=C) + + # C = alpha * A @ B + beta * C_initial = 1.0 * A @ B + 2.0 * C_initial + ref = A @ B + 2.0 * C_initial + + assert np.allclose(ref, C), f"Test failed for {implementation} with dtype {dtype}" + + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("pure", dace.float32), + pytest.param("pure", dace.float64), + pytest.param("MKL", dace.float32, marks=pytest.mark.mkl), + pytest.param("MKL", dace.float64, marks=pytest.mark.mkl), + pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu), + pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu), + pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack), + pytest.param("OpenBLAS", dace.float64, 
+@pytest.mark.parametrize("implementation, dtype", [
+    pytest.param("pure", dace.float32),
+    pytest.param("pure", dace.float64),
+    pytest.param("MKL", dace.float32, marks=pytest.mark.mkl),
+    pytest.param("MKL", dace.float64, marks=pytest.mark.mkl),
+    pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu),
+    pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu),
+    pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack),
+    pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack)
+])
+def test_gemm_accumulate(implementation: str, dtype):
+    """Test GEMM with non-zero beta (accumulation into existing output)"""
+    m, n, k = 16, 16, 16
+
+    sdfg = dace.SDFG('gemm_accumulate')
+
+    # Add arrays
+    sdfg.add_array("A", [m, k], dtype)
+    sdfg.add_array("B", [k, n], dtype)
+    sdfg.add_array("C", [m, n], dtype)
+
+    state = sdfg.add_state()
+
+    a_in = state.add_read("A")
+    b_in = state.add_read("B")
+    c_out = state.add_write("C")
+
+    # Create GEMM node with alpha=1.5 and beta=2.0
+    # For BLAS implementations, cin=False even when beta != 0, because BLAS handles reading C
+    gemm_node = blas.nodes.gemm.Gemm('gemm', alpha=1.5, beta=2.0, cin=False)
+
+    state.add_node(gemm_node)
+    state.add_edge(a_in, None, gemm_node, '_a', dace.Memlet.from_array("A", sdfg.arrays["A"]))
+    state.add_edge(b_in, None, gemm_node, '_b', dace.Memlet.from_array("B", sdfg.arrays["B"]))
+    state.add_edge(gemm_node, '_c', c_out, None, dace.Memlet.from_array("C", sdfg.arrays["C"]))
+
+    with change_default(blas, implementation):
+        sdfg.expand_library_nodes()
+        sdfg.validate()
+
+    A = np.random.rand(m, k).astype(dtype.as_numpy_dtype())
+    B = np.random.rand(k, n).astype(dtype.as_numpy_dtype())
+    C_initial = np.random.rand(m, n).astype(dtype.as_numpy_dtype())
+    C = C_initial.copy()
+
+    csdfg = sdfg.compile()
+    csdfg(A=A, B=B, C=C)
+
+    # C = alpha * A @ B + beta * C_initial = 1.5 * A @ B + 2.0 * C_initial
+    ref = 1.5 * (A @ B) + 2.0 * C_initial
+
+    assert np.allclose(ref, C), f"Test failed for {implementation} with dtype {dtype}"
+
+
+@pytest.mark.parametrize("implementation, dtype", [
+    pytest.param("pure", dace.float32),
+    pytest.param("pure", dace.float64),
+    pytest.param("MKL", dace.float32, marks=pytest.mark.mkl),
+    pytest.param("MKL", dace.float64, marks=pytest.mark.mkl),
+    pytest.param("cuBLAS", dace.float32, marks=pytest.mark.gpu),
+    pytest.param("cuBLAS", dace.float64, marks=pytest.mark.gpu),
+    pytest.param("OpenBLAS", dace.float32, marks=pytest.mark.lapack),
+    pytest.param("OpenBLAS", dace.float64, marks=pytest.mark.lapack)
+])
+def test_gemv_accumulate(implementation: str, dtype):
+    """Test GEMV with non-zero beta (accumulation into existing output)"""
+    m, n = 32, 32
+
+    sdfg = dace.SDFG('gemv_accumulate')
+
+    # Add arrays
+    sdfg.add_array("A", [m, n], dtype)
+    sdfg.add_array("x", [n], dtype)
+    sdfg.add_array("y", [m], dtype)
+
+    state = sdfg.add_state()
+
+    a_in = state.add_read("A")
+    x_in = state.add_read("x")
+    y_in = state.add_read("y")
+    y_out = state.add_write("y")
+
+    gemv_node = blas.nodes.gemv.Gemv('gemv', alpha=1.5, beta=2.0)
+
+    state.add_node(gemv_node)
+    state.add_edge(a_in, None, gemv_node, '_A', dace.Memlet.from_array("A", sdfg.arrays["A"]))
+    state.add_edge(x_in, None, gemv_node, '_x', dace.Memlet.from_array("x", sdfg.arrays["x"]))
+    # For GEMV, when beta != 0, _y is both an input and output
+    state.add_edge(y_in, None, gemv_node, '_y', dace.Memlet.from_array("y", sdfg.arrays["y"]))
+    state.add_edge(gemv_node, '_y', y_out, None, dace.Memlet.from_array("y", sdfg.arrays["y"]))
+
+    with change_default(blas, implementation):
+        sdfg.expand_library_nodes()
+        sdfg.validate()
+
+    A = np.random.rand(m, n).astype(dtype.as_numpy_dtype())
+    x = np.random.rand(n).astype(dtype.as_numpy_dtype())
+    y_initial = np.random.rand(m).astype(dtype.as_numpy_dtype())
+    y = y_initial.copy()
+
+    csdfg = sdfg.compile()
+    csdfg(A=A, x=x, y=y)
+
+    # y = alpha * A @ x + beta * y_initial = 1.5 * A @ x + 2.0 * y_initial
+    ref = 1.5 * (A @ x) + 2.0 * y_initial
+
+    assert np.allclose(ref, y), f"Test failed for {implementation} with dtype {dtype}"
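+
+# A single configuration can also be selected with pytest's keyword filter, e.g.:
+#     pytest tests/library/test_matmul_accumulate.py -k pure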
+
+if __name__ == "__main__":
+    test_batched_matmul_accumulate("pure", dace.float32)
+    test_batched_matmul_accumulate("pure", dace.float64)
+
+    test_batched_matmul_accumulate("MKL", dace.float32)
+    test_batched_matmul_accumulate("MKL", dace.float64)
+
+    test_batched_matmul_accumulate("OpenBLAS", dace.float32)
+    test_batched_matmul_accumulate("OpenBLAS", dace.float64)
+
+    test_gemm_accumulate("pure", dace.float32)
+    test_gemm_accumulate("pure", dace.float64)
+
+    test_gemm_accumulate("MKL", dace.float32)
+    test_gemm_accumulate("MKL", dace.float64)
+
+    test_gemm_accumulate("OpenBLAS", dace.float32)
+    test_gemm_accumulate("OpenBLAS", dace.float64)
+
+    test_gemv_accumulate("pure", dace.float32)
+    test_gemv_accumulate("pure", dace.float64)
+
+    test_gemv_accumulate("MKL", dace.float32)
+    test_gemv_accumulate("MKL", dace.float64)
+
+    test_gemv_accumulate("OpenBLAS", dace.float32)
+    test_gemv_accumulate("OpenBLAS", dace.float64)

From 2657c72b86ca0443b0fd9e1c295effdb85582b1b Mon Sep 17 00:00:00 2001
From: Affifboudaoud
Date: Wed, 29 Oct 2025 17:46:47 +0100
Subject: [PATCH 4/7] Fix formatting

---
 tests/library/test_matmul_accumulate.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/tests/library/test_matmul_accumulate.py b/tests/library/test_matmul_accumulate.py
index 27ebd75b93..12126e023a 100644
--- a/tests/library/test_matmul_accumulate.py
+++ b/tests/library/test_matmul_accumulate.py
@@ -169,6 +169,7 @@ def test_gemv_accumulate(implementation: str, dtype):
 
     assert np.allclose(ref, y), f"Test failed for {implementation} with dtype {dtype}"
 
+
 @pytest.mark.parametrize("implementation, dtype", [
     pytest.param("pure", dace.float32),
     pytest.param("pure", dace.float64),
@@ -230,7 +231,8 @@ def test_batched_matmul_accumulate(implementation: str, dtype):
     ref = A @ B + 2.0 * C_initial
 
     assert np.allclose(ref, C), f"Test failed for {implementation} with dtype {dtype}"
-
+
+
 if __name__ == "__main__":
     test_batched_matmul_accumulate("pure", dace.float32)
     test_batched_matmul_accumulate("pure", dace.float64)

From 22895d731447b8e28c7865d7160d67ed870a71d6 Mon Sep 17 00:00:00 2001
From: Affifboudaoud
Date: Sat, 8 Nov 2025 14:38:07 +0100
Subject: [PATCH 5/7] Fix for 1d broadcasting

---
 dace/libraries/blas/nodes/batched_matmul.py | 22 +++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/dace/libraries/blas/nodes/batched_matmul.py b/dace/libraries/blas/nodes/batched_matmul.py
index 73a753d8b9..9a77cebee5 100644
--- a/dace/libraries/blas/nodes/batched_matmul.py
+++ b/dace/libraries/blas/nodes/batched_matmul.py
@@ -141,7 +141,16 @@ def make_sdfg(node, parent_state, parent_sdfg):
         num_a_batch = len(array_a.shape) - 2
         # Start from the rightmost batch dimension of output and work backwards
         offset = num_batch_dims - num_a_batch
-        a_batch_indices = ', '.join(['__i%d' % (offset + i) for i in range(num_a_batch)])
+        # Handle broadcasting: if dimension is 1, use index 0 instead of loop variable
+        a_batch_indices_parts = []
+        for i in range(num_a_batch):
+            if array_a.shape[i] == 1:
+                # Broadcast dimension: always access index 0
+                a_batch_indices_parts.append('0')
+            else:
+                # Regular dimension: use loop variable
+                a_batch_indices_parts.append('__i%d' % (offset + i))
+        a_batch_indices = ', '.join(a_batch_indices_parts)
         memlet_a = f'{a_batch_indices}, __im, __ik'
 
     # For B: if 1D, use [K]; if 2D, use [K, N]; if 3D+, use [batch_indices..., K, N]
@@ -157,7 +166,16 @@ def make_sdfg(node, parent_state, parent_sdfg):
         num_b_batch = len(array_b.shape) - 2
         # Start from the rightmost batch dimension of output and work backwards
         offset = num_batch_dims - num_b_batch
-        b_batch_indices = ', '.join(['__i%d' % (offset + i) for i in range(num_b_batch)])
+        # Handle broadcasting: if dimension is 1, use index 0 instead of loop variable
+        b_batch_indices_parts = []
+        for i in range(num_b_batch):
+            if array_b.shape[i] == 1:
+                # Broadcast dimension: always access index 0
+                b_batch_indices_parts.append('0')
+            else:
+                # Regular dimension: use loop variable
+                b_batch_indices_parts.append('__i%d' % (offset + i))
+        b_batch_indices = ', '.join(b_batch_indices_parts)
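+        # Worked example (illustrative): if the output C has batch shape [4, 5]
+        # (num_batch_dims = 2) and B has shape [5, K, N] (num_b_batch = 1), then
+        # offset = 1 and B's single batch dimension aligns with __i1; if B instead
+        # has shape [1, K, N], the broadcast dimension is indexed with the constant 0.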
         memlet_b = f'{b_batch_indices}, __ik, __in'
 
     # For C: build indices matching the output shape

From 5edead30918f23fe66db3a3a35f3de77b9d07fa6 Mon Sep 17 00:00:00 2001
From: affifboudaoud
Date: Sun, 9 Nov 2025 15:38:47 +0100
Subject: [PATCH 6/7] Remove Memlet squeezing from BLAS library expansions

---
 dace/libraries/blas/nodes/batched_matmul.py |  8 ++-----
 dace/libraries/blas/nodes/dot.py            | 19 +++++++---------
 dace/libraries/blas/nodes/einsum.py         | 18 ++++------------
 dace/libraries/blas/nodes/gemv.py           | 24 ++++++++------------
 dace/libraries/blas/nodes/ger.py            | 18 +++++-----------
 dace/libraries/blas/nodes/matmul.py         | 12 ++---------
 6 files changed, 30 insertions(+), 69 deletions(-)

diff --git a/dace/libraries/blas/nodes/batched_matmul.py b/dace/libraries/blas/nodes/batched_matmul.py
index 9a77cebee5..3fff44291e 100644
--- a/dace/libraries/blas/nodes/batched_matmul.py
+++ b/dace/libraries/blas/nodes/batched_matmul.py
@@ -934,13 +934,9 @@ def validate(self, sdfg, state):
             raise ValueError("Expected exactly two inputs to batched matrix-matrix product")
         for _, _, _, dst_conn, memlet in state.in_edges(self):
             if dst_conn == '_a':
-                subset = dc(memlet.subset)
-                subset.squeeze()
-                size0 = subset.size()
+                size0 = memlet.subset.size()
             if dst_conn == '_b':
-                subset = dc(memlet.subset)
-                subset.squeeze()
-                size1 = subset.size()
+                size1 = memlet.subset.size()
         out_edges = state.out_edges(self)
         if len(out_edges) != 1:
             raise ValueError("Expected exactly one output from "
diff --git a/dace/libraries/blas/nodes/dot.py b/dace/libraries/blas/nodes/dot.py
index c994504048..b9e562db09 100644
--- a/dace/libraries/blas/nodes/dot.py
+++ b/dace/libraries/blas/nodes/dot.py
@@ -547,22 +547,19 @@ def validate(self, sdfg, state):
         if desc_x.dtype.base_type != desc_res.dtype.base_type:
             raise TypeError(f"Data types of input and output must be equal: {desc_x.dtype}, {desc_res.dtype}")
 
-        # Squeeze input memlets
-        squeezed1 = copy.deepcopy(in_memlets[0].subset)
-        squeezed2 = copy.deepcopy(in_memlets[1].subset)
-        sqdims1 = squeezed1.squeeze()
-        sqdims2 = squeezed2.squeeze()
+        # Get input sizes
+        size1 = in_memlets[0].subset.size()
+        size2 = in_memlets[1].subset.size()
 
-        if len(squeezed1.size()) != 1 or len(squeezed2.size()) != 1:
+        if len(size1) != 1 or len(size2) != 1:
             raise ValueError("dot product only supported on 1-dimensional arrays")
 
         if out_memlet.subset.num_elements() != 1:
             raise ValueError("Output of dot product must be a single element")
 
-        # We are guaranteed that there is only one non-squeezed dimension
-        stride_x = desc_x.strides[sqdims1[0]]
-        stride_y = desc_y.strides[sqdims2[0]]
-        n = squeezed1.num_elements()
-        if squeezed1.num_elements() != squeezed2.num_elements():
+        stride_x = desc_x.strides[0]
+        stride_y = desc_y.strides[0]
+        n = size1[0]
+        if size1[0] != size2[0]:
             raise ValueError('Size mismatch in inputs')
 
         return (desc_x, stride_x), (desc_y, stride_y), desc_res, n
diff --git a/dace/libraries/blas/nodes/einsum.py b/dace/libraries/blas/nodes/einsum.py
index 0c4fbe96b0..fe2d301f10 100644
--- a/dace/libraries/blas/nodes/einsum.py
+++ b/dace/libraries/blas/nodes/einsum.py
@@ -49,24 +49,14 @@ def expansion(node: Einsum, parent_state: SDFGState, parent_sdfg: SDFG) -> SDFG:
     for e in parent_state.in_edges(node):
         inputs.append(e.dst_conn)
         desc = parent_sdfg.arrays[e.data.data]
-        insubset = deepcopy(e.data.src_subset)
-        isqdim = insubset.squeeze()
-        sdfg.add_array(e.dst_conn,
-                       insubset.size(),
-                       desc.dtype,
-                       strides=[s for i, s in enumerate(desc.strides) if i in isqdim],
-                       storage=desc.storage)
+        insubset_size = e.data.src_subset.size()
+        sdfg.add_array(e.dst_conn, insubset_size, desc.dtype, strides=desc.strides, storage=desc.storage)
     for e in parent_state.out_edges(node):
         output = e.src_conn
        desc = parent_sdfg.arrays[e.data.data]
-        outsubset = deepcopy(e.data.dst_subset)
-        osqdim = outsubset.squeeze()
-        sdfg.add_array(output,
-                       outsubset.size(),
-                       desc.dtype,
-                       strides=[s for i, s in enumerate(desc.strides) if i in osqdim],
-                       storage=desc.storage)
+        outsubset_size = e.data.dst_subset.size()
+        sdfg.add_array(output, outsubset_size, desc.dtype, strides=desc.strides, storage=desc.storage)
 
     #######################################
     # Fill SDFG with einsum contents
diff --git a/dace/libraries/blas/nodes/gemv.py b/dace/libraries/blas/nodes/gemv.py
index e3ad66642b..cda06a466f 100644
--- a/dace/libraries/blas/nodes/gemv.py
+++ b/dace/libraries/blas/nodes/gemv.py
@@ -931,20 +931,14 @@ def validate(self, sdfg, state):
         size_y_in = None
         for _, _, _, dst_conn, memlet in state.in_edges(self):
             if dst_conn == "_A":
-                subset = copy.deepcopy(memlet.subset)
-                subset.squeeze()
-                size_a = subset.size()
+                size_a = memlet.subset.size()
             if dst_conn == "_x":
-                subset = copy.deepcopy(memlet.subset)
-                subset.squeeze()
-                size_x = subset.size()
+                size_x = memlet.subset.size()
             if dst_conn == "_y":
-                subset = copy.deepcopy(memlet.subset)
-                subset.squeeze()
-                size_y_in = subset.size()
+                size_y_in = memlet.subset.size()
 
         if len(size_a) != 2 or len(size_x) != 1:
-            raise ValueError("Matrix-vector product only supported on matrix-vector input")
+            raise ValueError("Matrix-vector product only supported on a 2D matrix and a 1D vector")
 
         a_cols = size_a[1] if not self.transA else size_a[0]
         a_rows = size_a[0] if not self.transA else size_a[1]
@@ -958,11 +952,11 @@ def validate(self, sdfg, state):
             raise ValueError("Expected exactly one output from matrix-vector product")
 
         out_memlet = out_edges[0].data
-        out_subset = copy.deepcopy(out_memlet.subset)
-        out_subset.squeeze()
-        size_y_out = out_subset.size()
-        if size_y_in is not None and size_y_in != size_y_out:
-            raise ValueError("Input y-vector must match output y-vector.")
+        size_y_out = out_memlet.subset.size()
+
+        if size_y_in is not None:
+            if size_y_in != size_y_out:
+                raise ValueError("Input y-vector must match output y-vector.")
 
         if (len(size_y_out) != 1 or size_y_out[0] != a_rows):
             raise ValueError("Vector input to GEMV must match matrix rows.")
diff --git a/dace/libraries/blas/nodes/ger.py b/dace/libraries/blas/nodes/ger.py
index c22f8f7010..17ce639b4a 100644
--- a/dace/libraries/blas/nodes/ger.py
+++ b/dace/libraries/blas/nodes/ger.py
@@ -266,19 +266,13 @@ def validate(self, sdfg, state):
         size_y = None
         for _, _, _, dst_conn, memlet in state.in_edges(self):
             if dst_conn == "_A":
-                subset = copy.deepcopy(memlet.subset)
-                subset.squeeze()
-                size_a = subset.size()
+                size_a = memlet.subset.size()
                 desc_a = sdfg.arrays[memlet.data]
             if dst_conn == "_x":
-                subset = copy.deepcopy(memlet.subset)
-                subset.squeeze()
-                size_x = subset.size()
+                size_x = memlet.subset.size()
                 desc_x = sdfg.arrays[memlet.data]
             if dst_conn == "_y":
-                subset = copy.deepcopy(memlet.subset)
-                subset.squeeze()
-                size_y = subset.size()
+                size_y = memlet.subset.size()
                 desc_y = sdfg.arrays[memlet.data]
 
         if size_a is None or size_x is None:
@@ -292,7 +286,7 @@ def validate(self, sdfg, state):
             return
 
         if len(size_a) != 2:
-            raise ValueError("A must be a matrix")
+            raise ValueError("A must be a 2-dimensional matrix")
         if len(size_x) != 1:
             raise ValueError("x must be a vector")
         if len(size_y) != 1:
@@ -306,9 +300,7 @@ def validate(self, sdfg, state):
             raise ValueError("Expected exactly one output from ger rank 1 operation.")
 
         out_memlet = out_edges[0].data
-        out_subset = copy.deepcopy(out_memlet.subset)
-        out_subset.squeeze()
-        size_out = out_subset.size()
+        size_out = out_memlet.subset.size()
 
         if (len(size_out) != 2 or size_out[0] != size_a[0] or size_out[1] != size_a[1]):
             raise ValueError("Output matrix must match input matrix a and outer product x*yT.")
diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py
index a8824539e3..ab5a33b9c2 100644
--- a/dace/libraries/blas/nodes/matmul.py
+++ b/dace/libraries/blas/nodes/matmul.py
@@ -14,26 +14,18 @@ def _get_matmul_operands(node, state, sdfg, name_lhs="_a", name_rhs="_b", name_o
     for edge in state.all_edges(node):
         if edge.dst_conn in [name_lhs, name_rhs]:
             size = edge.data.subset.size()
-            squeezed = dc(edge.data.subset)
-            squeezed_dims = squeezed.squeeze()
-            squeezed_size = squeezed.size()
             outer_array = sdfg.data(dace.sdfg.find_input_arraynode(state, edge).data)
             strides = list(outer_array.strides)
-            squeezed_strides = [s for i, s in enumerate(outer_array.strides) if i in squeezed_dims]
-            res = edge, outer_array, size, strides, squeezed_size, squeezed_strides
+            res = edge, outer_array, size, strides, size, strides
             if edge.dst_conn == name_lhs:
                 res_lhs = res
             else:
                 res_rhs = res
         elif edge.src_conn == name_out:
             size = edge.data.subset.size()
-            squeezed = dc(edge.data.subset)
-            squeezed_dims = squeezed.squeeze()
-            squeezed_size = squeezed.size()
             outer_array = sdfg.data(dace.sdfg.find_output_arraynode(state, edge).data)
             strides = list(outer_array.strides)
-            squeezed_strides = [s for i, s in enumerate(outer_array.strides) if i in squeezed_dims]
-            res_out = edge, outer_array, size, strides, squeezed_size, squeezed_strides
+            res_out = edge, outer_array, size, strides, size, strides
     for res, name in ((res_lhs, name_lhs), (res_rhs, name_rhs), (res_out, name_out)):
         if res is None:
             raise ValueError("Matrix multiplication connector "

From 6518e60ff5840b2e6f915beece7747fe7728fba5 Mon Sep 17 00:00:00 2001
From: affifboudaoud
Date: Sun, 9 Nov 2025 18:07:11 +0100
Subject: [PATCH 7/7] Fix Scalar bug in GEMM BLAS expansion

---
 dace/codegen/codegen.py           |  1 -
 dace/frontend/common/einsum.py    |  4 +++-
 dace/libraries/blas/nodes/gemm.py | 20 +++++++++++++++++---
 3 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py
index 3ccbb56dc6..555dd9a111 100644
--- a/dace/codegen/codegen.py
+++ b/dace/codegen/codegen.py
@@ -176,7 +176,6 @@ def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]:
             sdfg.save(f'{tmp_dir}/test.sdfg', hash=False)
             sdfg2 = SDFG.from_file(f'{tmp_dir}/test.sdfg')
             sdfg2.save(f'{tmp_dir}/test2.sdfg', hash=False)
-            print('Testing SDFG serialization...')
             if not filecmp.cmp(f'{tmp_dir}/test.sdfg', f'{tmp_dir}/test2.sdfg'):
                 with open(f'{tmp_dir}/test.sdfg', 'r') as f1:
                     with open(f'{tmp_dir}/test2.sdfg', 'r') as f2:
diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py
index df1c8de34e..8876b05a41 100644
--- a/dace/frontend/common/einsum.py
+++ b/dace/frontend/common/einsum.py
@@ -226,7 +226,9 @@ def _create_einsum_internal(sdfg: SDFG,
     for inp, inpname in zip(einsum.inputs, arrays):
         inparr = sdfg.arrays[inpname]
         if len(inp) != len(inparr.shape):
-            raise ValueError('Dimensionality mismatch in input "%s"' % inpname)
+            raise ValueError(f'Dimensionality mismatch in input "{inpname}": '
+                             f'einsum subscript has {len(inp)} dimensions but array has '
+                             f'{len(inparr.shape)} dimensions')
         for char, shp in zip(inp, inparr.shape):
             if char in chardict and shp != chardict[char]:
                 raise ValueError('Dimension mismatch in einsum expression')
diff --git a/dace/libraries/blas/nodes/gemm.py b/dace/libraries/blas/nodes/gemm.py
index 4aa5c7f52f..e1f260c212 100644
--- a/dace/libraries/blas/nodes/gemm.py
+++ b/dace/libraries/blas/nodes/gemm.py
@@ -12,6 +12,7 @@
 from .. import environments
 import numpy as np
 import warnings
+from dace.codegen.common import sym2cpp
 
 
 def _is_complex(dtype):
@@ -211,9 +212,22 @@ def expansion(node, state, sdfg):
             else:
                 in_connectors[k] = v
 
-        code += ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
-                 "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
-                 "_c, {ldc});").format_map(opt)
+        # Check if output is scalar-sized (1x1)
+        is_scalar_output = (sym2cpp(opt['M']) == '1' and sym2cpp(opt['N']) == '1')
+
+        if is_scalar_output:
+            # For scalar outputs, write into a local one-element array and copy the result out
+            code += f'''
+            {dtype.ctype} _c_array[1];
+            '''
+            code += ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
+                     "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
+                     "_c_array, {ldc});\n").format_map(opt)
+            code += "_c = _c_array[0];"
+        else:
+            code += ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
+                     "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
+                     "_c, {ldc});").format_map(opt)
 
         tasklet = dace.sdfg.nodes.Tasklet(
             node.name,