NOTE(review): this patch was recovered from a copy whose line breaks,
indentation, and template arguments (e.g. `std::vector<...>`, `.as<...>()`)
had been stripped; they were reconstructed here and should be confirmed
against the tree. The `index` hashes are those of the original patch.

diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp
index 06fa56bb085..33b5ce0b345 100644
--- a/csrc/ir/nodes.cpp
+++ b/csrc/ir/nodes.cpp
@@ -4399,6 +4399,51 @@ std::vector<PolymorphicValue> CatOp::evaluate(
   return {at::cat(unpadded_inputs, concat_dim)};
 }
 
+namespace {
+
+// Given a tensorview, compute the strides according to the allocation domain
+// for re-striding the corresponding ATen tensor.
+//
+// `sizes` holds the sizes of that ATen tensor. The permutation computed
+// between tv's logical and allocation domains (both filtered through
+// TensorDomain::noReductions) is applied to `sizes`, contiguous strides are
+// computed over the permuted sizes (the last one gets stride 1), and the
+// stride computed at position idx is stored at out_order[idx] of the result.
+//
+// Fails NVF_CHECK if ir_utils::computePermutation returns no permutation.
+std::vector<int64_t> computeStrides(
+    TensorView* tv,
+    const c10::IntArrayRef sizes) {
+  const auto& logical_domain = tv->getLogicalDomain();
+  const auto& allocation_domain = tv->getMaybeAllocationDomain();
+
+  std::optional<std::vector<int64_t>> out_order = ir_utils::computePermutation(
+      TensorDomain::noReductions(logical_domain),
+      TensorDomain::noReductions(allocation_domain));
+  NVF_CHECK(
+      out_order.has_value(),
+      "Valid permute from logical to allocation domain was not found.");
+
+  auto rank = sizes.size();
+  if (rank == 0) {
+    // Nothing to stride; also keeps `rank - 1` below from wrapping around.
+    return {};
+  }
+  std::vector<int64_t> sorted_strides(rank);
+  auto permuted_sizes = ir_utils::applyPermutation(sizes.vec(), *out_order);
+  sorted_strides[rank - 1] = 1;
+  for (int64_t idx = static_cast<int64_t>(rank) - 2; idx >= 0; idx--) {
+    sorted_strides[idx] = permuted_sizes[idx + 1] * sorted_strides[idx + 1];
+  }
+  // Rearrange the strides in correct order of allocation
+  std::vector<int64_t> strides(rank);
+  for (auto idx : c10::irange(rank)) {
+    strides[out_order.value()[idx]] = sorted_strides[idx];
+  }
+  return strides;
+}
+} // namespace
+
 MatmulOp::MatmulOp(IrBuilderPasskey passkey, Val* out, Val* in_a, Val* in_b)
     : Expr(passkey) {
   addOutput(out);
@@ -4425,7 +4470,19 @@ std::vector<PolymorphicValue> MatmulOp::evaluate(
     const std::vector<PolymorphicValue>& inputs) const {
   const auto a = inputs.at(0).as<at::Tensor>();
   const auto b = inputs.at(1).as<at::Tensor>();
-  return {at::matmul(a, b)};
+
+  auto matmul_out = at::matmul(a, b);
+  if (ir_utils::hasTrivialAllocationDomain(out())) {
+    return {matmul_out};
+  }
+  // The output's allocation domain is not trivial: copy the result into a
+  // tensor allocated with the strides computed from that domain.
+  auto matmul_sizes = matmul_out.sizes();
+  auto strides = computeStrides(out(), matmul_sizes);
+  auto strided_matmul_out =
+      at::empty_strided(matmul_sizes, strides, matmul_out.options());
+  strided_matmul_out.copy_(matmul_out);
+  return {strided_matmul_out};
 }
 
 LinearOp::LinearOp(
diff --git a/tests/python/test_matmul.py b/tests/python/test_matmul.py
index 14176154e1f..ed1b12c2063 100644
--- a/tests/python/test_matmul.py
+++ b/tests/python/test_matmul.py
@@ -4,7 +4,7 @@
 # Owner(s): ["module: nvfuser"]
 
 import torch
-from utils import NVFuserTest, is_pre_volta
+from utils import NVFuserTest, is_pre_volta, verify_stride_order
 from nvfuser import FusionDefinition, DataType
 import pytest
 from functools import partial
@@ -201,3 +201,30 @@ def fusion_func(fd: FusionDefinition) -> None:
         ]
         outputs, _ = self.exec_nvfuser(fusion_func, inputs)
         assert outputs[0].ndim == 3
+
+    def test_matmul_stride(self):
+        """Check matmul output strides for every stride_order permutation."""
+        n, h, l, s, e = 4, 8, 16, 16, 8
+        inputs = [
+            torch.randn(
+                n, h, l, e, device="cuda", dtype=torch.float16, requires_grad=True
+            ),
+            torch.randn(
+                n, h, s, e, device="cuda", dtype=torch.float16, requires_grad=True
+            ),
+        ]
+        for perm in itertools.permutations(range(4), 4):
+
+            def fusion_func(fd: FusionDefinition) -> None:
+                q = fd.from_pytorch(inputs[0])
+                k = fd.from_pytorch(inputs[1])
+                k_t = fd.ops.permute(k, [0, 1, 3, 2])
+                out = fd.ops.matmul(q, k_t)
+                fd.add_output(out, stride_order=perm)
+
+            with FusionDefinition() as fd:
+                fusion_func(fd)
+            nvf_out = fd.execute(inputs)
+            eager_out = torch.matmul(inputs[0], torch.transpose(inputs[1], -2, -1))
+            verify_stride_order(nvf_out[0].stride(), perm)
+            torch.testing.assert_close(nvf_out[0], eager_out)