Skip to content

Commit

Permalink
Stride MatmulOp according to set allocation domain (#3447)
Browse files Browse the repository at this point in the history
Resolves Issue #2427.
If the output of a `MatmulOp` has a stride order set from the Python frontend
(`fd.add_output(..., stride_order=...)` or `fd.ops.stride_order`), evaluation
returns a copy of the output with the specified memory layout.

`at::matmul_out` is not used since it does not allow inputs/outputs
which require gradients.

https://github.com/pytorch/pytorch/blob/1f3d8896bc9cea7f46c50ff92b69c6aa139defcb/aten/src/ATen/native/LinearAlgebra.cpp#L2018-L2025

---------

Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
  • Loading branch information
Priya2698 and wujingyue authored Dec 17, 2024
1 parent 1136753 commit d623221
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 2 deletions.
45 changes: 44 additions & 1 deletion csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4399,6 +4399,39 @@ std::vector<PolymorphicValue> CatOp::evaluate(
return {at::cat(unpadded_inputs, concat_dim)};
}

namespace {

// Given a tensorview, compute the strides according to its allocation domain
// for re-striding the corresponding ATen tensor.
//
// `tv`    - tensorview whose logical and (maybe) allocation domains are read.
// `sizes` - sizes of the ATen tensor that is going to be re-strided.
//
// The permutation from the logical domain to the allocation domain is
// computed with reduction axes removed from both; NVF_CHECK fails if no such
// permutation exists. `sizes` is permuted with it, dense (innermost stride 1)
// strides are computed over the permuted sizes, and each stride is then
// written back to the slot named by the permutation.
//
// Returns one stride per entry of `sizes`; an empty vector when `sizes` is
// empty (rank-0 tensor).
std::vector<int64_t> computeStrides(
    TensorView* tv,
    const c10::IntArrayRef sizes) {
  const auto& logical_domain = tv->getLogicalDomain();
  const auto& allocation_domain = tv->getMaybeAllocationDomain();

  std::optional<std::vector<int64_t>> out_order = ir_utils::computePermutation(
      TensorDomain::noReductions(logical_domain),
      TensorDomain::noReductions(allocation_domain));
  NVF_CHECK(
      out_order.has_value(),
      "Valid permute from logical to allocation domain was not found.");

  const auto rank = sizes.size();
  // A rank-0 tensor has no strides. Without this guard `rank - 1` below would
  // wrap around (unsigned) and index out of bounds into an empty vector.
  if (rank == 0) {
    return {};
  }

  const auto& order = *out_order;
  const auto permuted_sizes = ir_utils::applyPermutation(sizes.vec(), order);

  // Dense strides over the permuted sizes: the last entry is the fastest
  // varying one, every earlier entry spans all entries that follow it.
  std::vector<int64_t> sorted_strides(rank);
  sorted_strides[rank - 1] = 1;
  for (int64_t idx = static_cast<int64_t>(rank) - 2; idx >= 0; idx--) {
    sorted_strides[idx] = permuted_sizes[idx + 1] * sorted_strides[idx + 1];
  }

  // Rearrange the strides in correct order of allocation.
  std::vector<int64_t> strides(rank);
  for (auto idx : c10::irange(rank)) {
    strides[order[idx]] = sorted_strides[idx];
  }
  return strides;
}
} // namespace

MatmulOp::MatmulOp(IrBuilderPasskey passkey, Val* out, Val* in_a, Val* in_b)
: Expr(passkey) {
addOutput(out);
Expand All @@ -4425,7 +4458,17 @@ std::vector<PolymorphicValue> MatmulOp::evaluate(
const std::vector<PolymorphicValue>& inputs) const {
const auto a = inputs.at(0).as<at::Tensor>();
const auto b = inputs.at(1).as<at::Tensor>();
return {at::matmul(a, b)};

auto matmul_out = at::matmul(a, b);
if (ir_utils::hasTrivialAllocationDomain(out())) {
return {matmul_out};
}
auto matmul_sizes = matmul_out.sizes();
auto strides = computeStrides(out(), matmul_sizes);
auto strided_matmul_out =
at::empty_strided(matmul_sizes, strides, a.options());
strided_matmul_out = strided_matmul_out.copy_(matmul_out);
return {strided_matmul_out};
}

LinearOp::LinearOp(
Expand Down
28 changes: 27 additions & 1 deletion tests/python/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Owner(s): ["module: nvfuser"]

import torch
from utils import NVFuserTest, is_pre_volta
from utils import NVFuserTest, is_pre_volta, verify_stride_order
from nvfuser import FusionDefinition, DataType
import pytest
from functools import partial
Expand Down Expand Up @@ -201,3 +201,29 @@ def fusion_func(fd: FusionDefinition) -> None:
]
outputs, _ = self.exec_nvfuser(fusion_func, inputs)
assert outputs[0].ndim == 3

def test_matmul_stride(self):
    # Checks that a matmul output registered with an explicit stride order
    # comes back with that stride order and with the same values as eager
    # PyTorch, for every permutation of the four output dimensions.
    n, h, l, s, e = 4, 8, 16, 16, 8
    # Created in the same order as the fusion consumes them: q first, then k.
    inputs = [
        torch.randn(*shape, device="cuda", dtype=torch.float16, requires_grad=True)
        for shape in ((n, h, l, e), (n, h, s, e))
    ]

    for perm in itertools.permutations(range(4), 4):

        def fusion_func(fd: FusionDefinition) -> None:
            # out = q @ k^T, where k^T swaps the two innermost dims of k.
            query = fd.from_pytorch(inputs[0])
            key = fd.from_pytorch(inputs[1])
            key_transposed = fd.ops.permute(key, [0, 1, 3, 2])
            product = fd.ops.matmul(query, key_transposed)
            fd.add_output(product, stride_order=perm)

        with FusionDefinition() as fd:
            fusion_func(fd)
        nvf_out = fd.execute(inputs)

        # Eager reference for the same computation.
        key_t_eager = torch.transpose(inputs[1], -2, -1)
        eager_out = torch.matmul(inputs[0], key_t_eager)

        verify_stride_order(nvf_out[0].stride(), perm)
        torch.testing.assert_close(nvf_out[0], eager_out)

0 comments on commit d623221

Please sign in to comment.