Number of dims and results of reindexed AffineMap doesn't match on Vectorization #17591

Open
jinchen62 opened this issue Jun 6, 2024 · 12 comments
Labels
bug 🐞 Something isn't working


@jinchen62
Contributor

jinchen62 commented Jun 6, 2024

What happened?

dispatch: https://gist.github.com/jinchen62/5e2af98f9b5bfc3b55e949f964459815
error log: https://gist.github.com/jinchen62/df2038b5a43ed4680804a3d7d0647d95

The failing op dumped at https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp#L336 is

%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<1x4xf32>) outs(%arg2 : tensor<1x1xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0], [1, 0], [0, 4], [0, 0]]>} {
^bb0(%in: f32, %out: f32):
  %11 = arith.addf %in, %out : f32
  linalg.yield %11 : f32
} -> tensor<1x1xf32>

At the failing assertion (https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L474), the map is reindexed from (d0, d1) -> (0, d0) to (d0) -> (0, d0), so the number of dims (1) no longer matches the number of results (2).
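The mismatch can be reproduced with a small Python model. Based on the map change described above, it assumes that the reindexing step drops every dim that no result uses, in the spirit of MLIR's compressUnusedDims. It then requires the number of dims to equal the number of results. `AffineMap`, `compress_unused_dims` and `reindex` are hypothetical stand-ins, not MLIR's C++ or Python API.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Hypothetical encoding: a result is ("d", i) for dim d_i or ("c", k) for constant k.
Result = Tuple[str, int]


@dataclass
class AffineMap:
    num_dims: int
    results: List[Result]

    def __str__(self) -> str:
        dims = ", ".join(f"d{i}" for i in range(self.num_dims))
        res = ", ".join(f"d{v}" if kind == "d" else str(v) for kind, v in self.results)
        return f"({dims}) -> ({res})"


def compress_unused_dims(m: AffineMap) -> AffineMap:
    """Drop dims that no result references and renumber the survivors."""
    used = sorted({v for kind, v in m.results if kind == "d"})
    renumber = {old: new for new, old in enumerate(used)}
    results = [("d", renumber[v]) if kind == "d" else (kind, v) for kind, v in m.results]
    return AffineMap(len(used), results)


def reindex(m: AffineMap) -> AffineMap:
    """Toy model of the reindexing step: compress, then require #dims == #results."""
    res = compress_unused_dims(m)
    if res.num_dims != len(res.results):
        raise AssertionError(
            f"{m} reindexed to {res}: {res.num_dims} dims vs {len(res.results)} results")
    return res


# Output map of the failing op: (d0, d1) -> (0, d0).
bad = AffineMap(2, [("c", 0), ("d", 0)])
print(compress_unused_dims(bad))  # (d0) -> (0, d0)
try:
    reindex(bad)
except AssertionError as e:
    print("assertion:", e)

# A map without the constant result, e.g. (d0, d1) -> (d0), passes the check.
print(reindex(AffineMap(2, [("d", 0)])))  # (d0) -> (d0)
```

In this model, the constant 0 result survives compression and leaves one result too many. A map like (d0, d1) -> (d0) compresses cleanly.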

Steps to reproduce your issue

Run iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu dispatch_1.mlir -o test.vmfb 2> dump.mlir with top-of-main (TOM) IREE.

What component(s) does this issue relate to?

No response

Version information

No response

Additional context

No response

@jinchen62 jinchen62 added the bug 🐞 Something isn't working label Jun 6, 2024
@jinchen62 jinchen62 changed the title Number of dims and results of reindexed AffineMap don't match on Vectorization Number of dims and results of reindexed AffineMap doesn't match on Vectorization Jun 6, 2024
@hanhanW hanhanW self-assigned this Jun 6, 2024
@hanhanW
Contributor

hanhanW commented Jun 6, 2024

Inlining the MLIR input below. At first, I thought that the (d0, d1) -> (0, d0) map was generated during codegen, but that is not the case: the (d0, d1) -> (0, d0) affine_map is already present in codegen's input. @jinchen62 do you know how the input is generated? It would be very helpful if you could track it back to a small set of linalg ops or tosa/torch ops. The 0 should be folded away by FoldUnitExtentDimsPass at the global opt or flow level, i.e., the map should be (d0, d1) -> (d0) by the time it reaches codegen.

hal.executable public @main_graph$async_dispatch_1 {
  hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 16 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>) {
    hal.executable.export public @main_graph$async_dispatch_1_generic_9x1024_f32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]} {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @main_graph$async_dispatch_1_generic_9x1024_f32() {
        %cst = arith.constant 0.000000e+00 : f32
        %0 = hal.interface.constant.load[0] : i32
        %1 = hal.interface.constant.load[1] : i32
        %2 = arith.index_castui %0 : i32 to index
        %3 = arith.index_castui %1 : i32 to index
        %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9x1024xf32>>
        %5 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%3) : !flow.dispatch.tensor<writeonly:tensor<1x9xf32>>
        %6 = flow.dispatch.tensor.load %4, offsets = [0, 0], sizes = [9, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9x1024xf32>> -> tensor<9x1024xf32>
        %7 = tensor.empty() : tensor<1x9xf32>
        %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x9xf32>) -> tensor<1x9xf32>
        %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d0)>], iterator_types = ["parallel", "reduction"]} ins(%6 : tensor<9x1024xf32>) outs(%8 : tensor<1x9xf32>) {
        ^bb0(%in: f32, %out: f32):
          %10 = arith.addf %in, %out : f32
          linalg.yield %10 : f32
        } -> tensor<1x9xf32>
        flow.dispatch.tensor.store %9, %5, offsets = [0, 0], sizes = [1, 9], strides = [1, 1] : tensor<1x9xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x9xf32>>
        return
      }
    }
  }
}
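The fold that the comment above expects can be sketched as a toy model. Rank-reduce an operand by dropping each dim that is statically 1 and indexed by the constant 0. Then (d0, d1) -> (0, d0) on tensor<1x9xf32> becomes (d0, d1) -> (d0) on tensor<9xf32>. `fold_unit_results` is a hypothetical helper and covers only this one case. It is not FoldUnitExtentDimsPass itself.

```python
from typing import List, Optional, Tuple

Shape = List[Optional[int]]  # None stands for a dynamic '?' size


def fold_unit_results(shape: Shape, results: List[str]) -> Tuple[Shape, List[str]]:
    """Drop operand dims that are static size 1 and indexed by the constant 0.

    Results are written as strings such as "d0" or "0". Toy model only.
    """
    kept = [(s, r) for s, r in zip(shape, results) if not (s == 1 and r == "0")]
    return [s for s, _ in kept], [r for _, r in kept]


# outs operand as in the dispatch above: tensor<1x9xf32> with (d0, d1) -> (0, d0)
print(fold_unit_results([1, 9], ["0", "d0"]))     # ([9], ['d0'])
# with a dynamic leading size, as in tensor<?x9xf32>, this model does not fold
print(fold_unit_results([None, 9], ["0", "d0"]))  # ([None, 9], ['0', 'd0'])
```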

@hanhanW hanhanW assigned jinchen62 and unassigned hanhanW Jun 6, 2024
@hanhanW
Contributor

hanhanW commented Jun 7, 2024

I worked with @jinchen62 and we got a smaller repro: https://gist.github.com/hanhanW/b3652f5887b93fb8f0df6c6c39c1ef87

To repro, run iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-fold-unit-extent-dims))" ~/repro.mlir.

Then you'll see affine_map<(d0, d1) -> (0, d0)> in the result.

#map2 = affine_map<(d0, d1) -> (d0, d1)>
#map8 = affine_map<(d0, d1) -> (0, d0)>
// ...
    %29 = linalg.generic {indexing_maps = [#map2, #map8], iterator_types = ["parallel", "reduction"]} ins(%collapsed_12 : tensor<9x1024xf32>) outs(%28 : tensor<?x9xf32>) {
    ^bb0(%in: f32, %out: f32):
      %35 = arith.addf %in, %out : f32
      linalg.yield %35 : f32
    } -> tensor<?x9xf32>
// ...

@hanhanW
Contributor

hanhanW commented Jun 7, 2024

Actually, the input reduction op looks weird: the sizes of d0 don't match. One is 1 and the other is ?. It looks like there is a bug in the frontend lowering. @jinchen62 you can add -mlir-print-debuginfo to iree-compile, and it will tell you where the op is lowered from. My guess is that there is a bug in the XXX->Linalg lowering.

#map5 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map10 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
    %25 = tensor.empty(%12) : tensor<?x9x1xf32>
    %26 = linalg.fill ins(%cst_7 : f32) outs(%25 : tensor<?x9x1xf32>) -> tensor<?x9x1xf32>
    %27 = linalg.generic {indexing_maps = [#map5, #map10], iterator_types = ["parallel", "parallel", "reduction"]} ins(%24 : tensor<1x9x1024xf32>) outs(%26 : tensor<?x9x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %31 = arith.addf %in, %out : f32
      linalg.yield %31 : f32
    } -> tensor<?x9x1xf32>
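The mismatch pointed out above can be checked mechanically. Collect the size that every operand binds to each loop dim through its indexing map, then flag any dim bound to more than one size. `loop_dim_sizes` and `conflicting_dims` are hypothetical helpers, not an MLIR or IREE API. They deliberately treat '?' as distinct from 1, which makes them stricter than a plain compatibility check.

```python
from typing import Dict, List, Optional, Set, Tuple

Shape = List[Optional[int]]        # None stands for a dynamic '?' size
Operand = Tuple[Shape, List[str]]  # (tensor shape, map results such as "d0" or "0")


def loop_dim_sizes(operands: List[Operand]) -> Dict[int, Set[Optional[int]]]:
    """Collect every size that each loop dim d_i is bound to across operands."""
    sizes: Dict[int, Set[Optional[int]]] = {}
    for shape, results in operands:
        for size, r in zip(shape, results):
            if r.startswith("d"):  # constant results such as "0" bind no loop dim
                sizes.setdefault(int(r[1:]), set()).add(size)
    return sizes


def conflicting_dims(operands: List[Operand]) -> List[int]:
    """Loop dims bound to more than one size (e.g. 1 in one operand, '?' in another)."""
    return sorted(d for d, s in loop_dim_sizes(operands).items() if len(s) > 1)


# ins tensor<1x9x1024xf32> via (d0, d1, d2) -> (d0, d1, d2)
ins = ([1, 9, 1024], ["d0", "d1", "d2"])
# outs tensor<?x9x1xf32> via (d0, d1, d2) -> (d0, d1, 0)
outs = ([None, 9, 1], ["d0", "d1", "0"])
print(conflicting_dims([ins, outs]))                               # [0]
print(conflicting_dims([ins, ([1, 9, 1], ["d0", "d1", "0"])]))     # []
```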

@jinchen62
Contributor Author

smaller repro: https://gist.github.com/jinchen62/91e216fb39abbb9ba4c0461346d2bb5a

command:
iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-fold-unit-extent-dims))" repro.mlir
or
iree-compile --iree-hal-target-backends=llvm-cpu repro.mlir -o test.vmfb --mlir-print-ir-after-all 2> dump.mlir

@hanhanW
Contributor

hanhanW commented Jun 11, 2024

@jinchen62 did you get a chance to see which op is generating the IR? The generic op looks invalid to me, as I explained in the comment above.

@jinchen62
Contributor Author

I think it's

%237 = torch.aten.sum.dim_IntList %235, %236, %true, %none : !torch.vtensor<[?,9,1024],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?,9,1],f32>

@hanhanW
Contributor

hanhanW commented Jun 11, 2024

I'd suggest checking whether there are bugs in the torch -> linalg lowering, or in the lowering from other high-level dialects to torch.

@jinchen62
Contributor Author

torch level repro: https://gist.github.com/jinchen62/601cfce290b81e037383fc49b604a68a

iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu --iree-util-zero-fill-elided-attrs repro_torch.mlir -o test.vmfb

@jinchen62
Contributor Author

jinchen62 commented Jun 12, 2024

Part of the dump for the torch repro:
After ExpandOps (memref-expand) -> After Canonicalizer (canonicalize)
https://gist.github.com/jinchen62/ae856e42b0660d0b41426e910039fb9a

@hanhanW I think that with a tensor.cast op, the reduction op that you found weird should compile fine, as at line 381 of the dump. But after the Canonicalizer pass, the cast appears to be missing, as at line 817. The following repro compiles, but without the cast op at the end it fails with the same error we are facing. Does that make sense?

#map = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
module {
  func.func @repro2(%arg0: tensor<1x9x1024xf32>) -> tensor<1x9x1xf32> {
    %cst = arith.constant dense<[false, true]> : tensor<2xi1>
    %cst_0 = arith.constant dense<1> : tensor<2xi32>
    %cst_1 = arith.constant dense<[1, -1]> : tensor<2xi32>
    %cst_2 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<2xi32>
    %1 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%cst, %cst_0, %cst_1 : tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) outs(%0 : tensor<2xi32>) {
    ^bb0(%in: i1, %in_3: i32, %in_4: i32, %out: i32):
      %6 = arith.select %in, %in_3, %in_4 : i32
      linalg.yield %6 : i32
    } -> tensor<2xi32>
    %extracted_slice = tensor.extract_slice %1[0] [1] [1] : tensor<2xi32> to tensor<1xi32>
    %collapsed = tensor.collapse_shape %extracted_slice [] : tensor<1xi32> into tensor<i32>
    %extracted = tensor.extract %collapsed[] : tensor<i32>
    %2 = arith.index_cast %extracted : i32 to index
    %3 = tensor.empty(%2) : tensor<?x9x1xf32>
    %4 = linalg.fill ins(%cst_2 : f32) outs(%3 : tensor<?x9x1xf32>) -> tensor<?x9x1xf32>
    %5 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<1x9x1024xf32>) outs(%4 : tensor<?x9x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %6 = arith.addf %in, %out : f32
      linalg.yield %6 : f32
    } -> tensor<?x9x1xf32>
    %cast = tensor.cast %5 : tensor<?x9x1xf32> to tensor<1x9x1xf32>
    return %cast : tensor<1x9x1xf32>
  }
}

@hanhanW
Contributor

hanhanW commented Jun 12, 2024

I'm not convinced that tensor.cast is the issue. There are shape inference passes/patterns in MLIR dialects, and they create tensor.cast ops to spell out static shapes. Given that hint, the compiler is smart enough to fold the shape information into the linalg op, which seems reasonable to me. Those patterns and passes work at the Linalg level, so the only explanation I can think of is that the frontend is generating invalid ops.

I don't know why we're still triaging the issue at the model level; perhaps I did not make it clear. Let me put it this way: instead of compiling the whole model, are you able to compile a single %237 = torch.aten.sum.dim_IntList %235, %236, %true, %none : !torch.vtensor<[?,9,1024],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?,9,1],f32> op?

@jinchen62
Contributor Author

I don't think it's a lowering issue. The torch.aten.sum.dim_IntList op compiles on its own, and I traced back through the onnx->torch lowering and didn't find a lowering bug in any onnx op.

@raikonenfnu and I think there might be an optimization bug in the canonicalize pass that runs after memref-expand. We saw the reduction generic op change from ins(%146 : tensor<?x9x1024xf32>) outs(%149 : tensor<?x?x1xf32>) to ins(%55 : tensor<1x9x1024xf32>) outs(%57 : tensor<?x9x1xf32>) when the tensor.cast op was folded. We would expect it to change to either ins(%55 : tensor<?x9x1024xf32>) outs(%57 : tensor<?x9x1xf32>) or ins(%55 : tensor<1x9x1024xf32>) outs(%57 : tensor<1x9x1xf32>). The dump IR is here.

@AmosLewis
Contributor

AmosLewis commented Jun 21, 2024

> I don't think it's a lowering issue. The torch.aten.sum.dim_IntList op compiles, and I traced up to the onnx->torch and didn't find a lowering bug of any onnx op.
>
> @raikonenfnu and I think there might be an optimization bug in canonicalize pass after memref-expand. We saw the generic op with reduction dim changing from ins(%146 : tensor<?x9x1024xf32>) outs(%149 : tensor<?x?x1xf32>) to ins(%55 : tensor<1x9x1024xf32>) outs(%57 : tensor<?x9x1xf32>) with folding the tensor.cast op. We might want to see it changes to ins(%55 : tensor<?x9x1024xf32>) outs(%57 : tensor<?x9x1xf32>) or ins(%55 : tensor<1x9x1024xf32>) outs(%57 : tensor<1x9x1xf32>). The dump ir is here.

@jinchen62 So what's the plan to fix this issue? The bart-large model needs this anyway.
