From 9b61174f1c78b8040ab21e64813009cb71413798 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Tue, 3 Mar 2026 08:39:44 +0000 Subject: [PATCH] [FLYDSL]: Bug fixes for algebra not being the simplest --- .../flydsl/Dialect/Fly/Utils/IntTupleUtils.h | 5 ++ .../flydsl/Dialect/Fly/Utils/LayoutUtils.h | 17 +++- test/LayoutAlgebra/divide.mlir | 55 ++++++++++++- test/LayoutAlgebra/product.mlir | 79 ++++++++++++++++++- 4 files changed, 145 insertions(+), 11 deletions(-) diff --git a/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h b/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h index e3cc50e9..3191c66c 100644 --- a/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h @@ -845,6 +845,11 @@ std::pair intTupleZip2ByImpl(const IntTupleBuilder assert(t.rank() == 2 && "intTupleZip2By expects rank-2 tuple at terminal"); return {builder.at(t, 0), builder.at(t, 1)}; } + // Canonicalize singleton guide wrappers so 1D profiles behave as leaf guides. + // This keeps zip2By robust after singleton unwrapping in product/divide type canonicalization. + if (guide.rank() == 1) { + return intTupleZip2ByImpl(builder, t, guide.at(0)); + } Collector firsts; Collector seconds; diff --git a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h index d2635a77..b52e7a01 100644 --- a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h @@ -27,6 +27,11 @@ std::pair canonicalizeStridePair(const IntTupleBuilder (4, 4):(1, 4). + if (shape.rank() == 1) { + return canonicalizeStridePair(builder, builder.at(shape, 0), builder.at(stride, 0)); + } typename IntTupleBuilder::ElemCollector shapeElems; typename IntTupleBuilder::ElemCollector strideElems; for (int i = 0; i < shape.rank(); ++i) { @@ -769,7 +774,9 @@ Layout layoutComposition(LayoutBuilder &builder, Layout outerLayout, Lay auto [retShape, retStride] = detail::compositionImpl(builder, coalShape, coalStride, builder.getShape(innerLayout), builder.getStride(innerLayout)); - return builder.makeLayout(retShape, retStride); + auto [canonicalShape, canonicalStride] = + detail::canonicalizeStridePair(builder, retShape, retStride); + return builder.makeLayout(canonicalShape, canonicalStride); } template Layout layoutComposition(LayoutBuilder &builder, Layout outerLayout, @@ -815,7 +822,9 @@ Layout layoutComposition(LayoutBuilder &builder, Layout outerLayout, retStride.push_back(builder.at(lhsStride, i)); } } - return builder.makeLayout(builder.makeTuple(retShape), builder.makeTuple(retStride)); + auto [canonicalShape, canonicalStride] = + detail::canonicalizeStridePair(builder, builder.makeTuple(retShape), builder.makeTuple(retStride)); + return builder.makeLayout(canonicalShape, canonicalStride); } template @@ -1214,7 +1223,9 @@ Layout layoutLogicalProduct(LayoutBuilder &builder, Layout blockLayout, retStrideElems.push_back(builder.getStride(blockLayout)); retStrideElems.push_back(builder.getStride(composed)); - return builder.makeLayout(builder.makeTuple(retShapeElems), builder.makeTuple(retStrideElems)); + auto [canonicalShape, canonicalStride] = detail::canonicalizeStridePair( + builder, builder.makeTuple(retShapeElems), builder.makeTuple(retStrideElems)); + return builder.makeLayout(canonicalShape, canonicalStride); } template diff --git a/test/LayoutAlgebra/divide.mlir b/test/LayoutAlgebra/divide.mlir index db73ec89..802b769a 100644 --- a/test/LayoutAlgebra/divide.mlir +++ b/test/LayoutAlgebra/divide.mlir @@ -50,14 +50,61 @@ func.func @test_flat_divide(%layout: !fly.layout<(4, 8) : (1, 4)>, } // CHECK-LABEL: @test_logical_divide_1d -func.func @test_logical_divide_1d() -> !fly.layout<((4), 4) : ((1), 4)> { - // Divide a 1D contiguous layout: (16):(1) / (4):(1) -> ((4),4):((1),4) +func.func @test_logical_divide_1d() -> !fly.layout<(4, 4) : (1, 4)> { + // Divide a 1D contiguous layout: (16):(1) / (4):(1) -> (4,4):(1,4) %s = fly.static {elems = [16 : i32]} : () -> !fly.int_tuple<(16)> %d = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> %layout = fly.make_layout(%s, %d) : (!fly.int_tuple<(16)>, !fly.int_tuple<(1)>) -> !fly.layout<(16) : (1)> %ds = fly.static {elems = [4 : i32]} : () -> !fly.int_tuple<(4)> %dd = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> %divisor = fly.make_layout(%ds, %dd) : (!fly.int_tuple<(4)>, !fly.int_tuple<(1)>) -> !fly.layout<(4) : (1)> - %result = fly.logical_divide(%layout, %divisor) : (!fly.layout<(16) : (1)>, !fly.layout<(4) : (1)>) -> !fly.layout<((4), 4) : ((1), 4)> - return %result : !fly.layout<((4), 4) : ((1), 4)> + %result = fly.logical_divide(%layout, %divisor) : (!fly.layout<(16) : (1)>, !fly.layout<(4) : (1)>) -> !fly.layout<(4, 4) : (1, 4)> + return %result : !fly.layout<(4, 4) : (1, 4)> +} + +// CHECK-LABEL: @test_zipped_divide_1d +func.func @test_zipped_divide_1d() -> !fly.layout<(4, 4) : (1, 4)> { + %s = fly.static {elems = [16 : i32]} : () -> !fly.int_tuple<(16)> + %d = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %layout = fly.make_layout(%s, %d) : (!fly.int_tuple<(16)>, !fly.int_tuple<(1)>) -> !fly.layout<(16) : (1)> + %ds = fly.static {elems = [4 : i32]} : () -> !fly.int_tuple<(4)> + %dd = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %divisor = fly.make_layout(%ds, %dd) : (!fly.int_tuple<(4)>, !fly.int_tuple<(1)>) -> !fly.layout<(4) : (1)> + %result = fly.zipped_divide(%layout, %divisor) : (!fly.layout<(16) : (1)>, !fly.layout<(4) : (1)>) -> !fly.layout<(4, 4) : (1, 4)> + return %result : !fly.layout<(4, 4) : (1, 4)> +} + +// CHECK-LABEL: @test_tiled_divide_1d +func.func @test_tiled_divide_1d() -> !fly.layout<(4, 4) : (1, 4)> { + %s = fly.static {elems = [16 : i32]} : () -> !fly.int_tuple<(16)> + %d = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %layout = fly.make_layout(%s, %d) : (!fly.int_tuple<(16)>, !fly.int_tuple<(1)>) -> !fly.layout<(16) : (1)> + %ds = fly.static {elems = [4 : i32]} : () -> !fly.int_tuple<(4)> + %dd = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %divisor = fly.make_layout(%ds, %dd) : (!fly.int_tuple<(4)>, !fly.int_tuple<(1)>) -> !fly.layout<(4) : (1)> + %result = fly.tiled_divide(%layout, %divisor) : (!fly.layout<(16) : (1)>, !fly.layout<(4) : (1)>) -> !fly.layout<(4, 4) : (1, 4)> + return %result : !fly.layout<(4, 4) : (1, 4)> +} + +// CHECK-LABEL: @test_flat_divide_1d +func.func @test_flat_divide_1d() -> !fly.layout<(4, 4) : (1, 4)> { + %s = fly.static {elems = [16 : i32]} : () -> !fly.int_tuple<(16)> + %d = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %layout = fly.make_layout(%s, %d) : (!fly.int_tuple<(16)>, !fly.int_tuple<(1)>) -> !fly.layout<(16) : (1)> + %ds = fly.static {elems = [4 : i32]} : () -> !fly.int_tuple<(4)> + %dd = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %divisor = fly.make_layout(%ds, %dd) : (!fly.int_tuple<(4)>, !fly.int_tuple<(1)>) -> !fly.layout<(4) : (1)> + %result = fly.flat_divide(%layout, %divisor) : (!fly.layout<(16) : (1)>, !fly.layout<(4) : (1)>) -> !fly.layout<(4, 4) : (1, 4)> + return %result : !fly.layout<(4, 4) : (1, 4)> +} + +// CHECK-LABEL: @test_logical_divide_wrapped_tuple_1d +func.func @test_logical_divide_wrapped_tuple_1d( + %layout: !fly.layout<((16, 1)) : ((1, 16))>, + %divisor: !fly.layout<((4, 1)) : ((1, 4))>) -> !fly.layout<((4, 1), 4) : ((1, 0), 4)> { + // Outer singleton wrappers are accepted and handled in inference. + %result = fly.logical_divide(%layout, %divisor) + : (!fly.layout<((16, 1)) : ((1, 16))>, !fly.layout<((4, 1)) : ((1, 4))>) + -> !fly.layout<((4, 1), 4) : ((1, 0), 4)> + return %result : !fly.layout<((4, 1), 4) : ((1, 0), 4)> } diff --git a/test/LayoutAlgebra/product.mlir b/test/LayoutAlgebra/product.mlir index 832688c7..742210b2 100644 --- a/test/LayoutAlgebra/product.mlir +++ b/test/LayoutAlgebra/product.mlir @@ -68,14 +68,85 @@ func.func @test_raked_product(%base: !fly.layout<(4, 8) : (1, 4)>, } // CHECK-LABEL: @test_logical_product_1d -func.func @test_logical_product_1d() -> !fly.layout<((8), (4)) : ((1), (8))> { - // 1D base with 1D tile preserves nesting structure +func.func @test_logical_product_1d() -> !fly.layout<(8, 4) : (1, 8)> { + // 1D base with 1D tile canonicalizes singleton tuple wrappers %s1 = fly.static {elems = [8 : i32]} : () -> !fly.int_tuple<(8)> %d1 = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> %base = fly.make_layout(%s1, %d1) : (!fly.int_tuple<(8)>, !fly.int_tuple<(1)>) -> !fly.layout<(8) : (1)> %s2 = fly.static {elems = [4 : i32]} : () -> !fly.int_tuple<(4)> %d2 = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> %tile = fly.make_layout(%s2, %d2) : (!fly.int_tuple<(4)>, !fly.int_tuple<(1)>) -> !fly.layout<(4) : (1)> - %result = fly.logical_product(%base, %tile) : (!fly.layout<(8) : (1)>, !fly.layout<(4) : (1)>) -> !fly.layout<((8), (4)) : ((1), (8))> - return %result : !fly.layout<((8), (4)) : ((1), (8))> + %result = fly.logical_product(%base, %tile) : (!fly.layout<(8) : (1)>, !fly.layout<(4) : (1)>) -> !fly.layout<(8, 4) : (1, 8)> + return %result : !fly.layout<(8, 4) : (1, 8)> +} + +// CHECK-LABEL: @test_zipped_product_1d +func.func @test_zipped_product_1d() -> !fly.layout<(8, 4) : (1, 8)> { + %s1 = fly.static {elems = [8 : i32]} : () -> !fly.int_tuple<(8)> + %d1 = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %base = fly.make_layout(%s1, %d1) : (!fly.int_tuple<(8)>, !fly.int_tuple<(1)>) -> !fly.layout<(8) : (1)> + %s2 = fly.static {elems = [4 : i32]} : () -> !fly.int_tuple<(4)> + %d2 = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %tile = fly.make_layout(%s2, %d2) : (!fly.int_tuple<(4)>, !fly.int_tuple<(1)>) -> !fly.layout<(4) : (1)> + %result = fly.zipped_product(%base, %tile) : (!fly.layout<(8) : (1)>, !fly.layout<(4) : (1)>) -> !fly.layout<(8, 4) : (1, 8)> + return %result : !fly.layout<(8, 4) : (1, 8)> +} + +// CHECK-LABEL: @test_tiled_product_1d +func.func @test_tiled_product_1d() -> !fly.layout<(8, 4) : (1, 8)> { + %s1 = fly.static {elems = [8 : i32]} : () -> !fly.int_tuple<(8)> + %d1 = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %base = fly.make_layout(%s1, %d1) : (!fly.int_tuple<(8)>, !fly.int_tuple<(1)>) -> !fly.layout<(8) : (1)> + %s2 = fly.static {elems = [4 : i32]} : () -> !fly.int_tuple<(4)> + %d2 = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %tile = fly.make_layout(%s2, %d2) : (!fly.int_tuple<(4)>, !fly.int_tuple<(1)>) -> !fly.layout<(4) : (1)> + %result = fly.tiled_product(%base, %tile) : (!fly.layout<(8) : (1)>, !fly.layout<(4) : (1)>) -> !fly.layout<(8, 4) : (1, 8)> + return %result : !fly.layout<(8, 4) : (1, 8)> +} + +// CHECK-LABEL: @test_flat_product_1d +func.func @test_flat_product_1d() -> !fly.layout<(8, 4) : (1, 8)> { + %s1 = fly.static {elems = [8 : i32]} : () -> !fly.int_tuple<(8)> + %d1 = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %base = fly.make_layout(%s1, %d1) : (!fly.int_tuple<(8)>, !fly.int_tuple<(1)>) -> !fly.layout<(8) : (1)> + %s2 = fly.static {elems = [4 : i32]} : () -> !fly.int_tuple<(4)> + %d2 = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %tile = fly.make_layout(%s2, %d2) : (!fly.int_tuple<(4)>, !fly.int_tuple<(1)>) -> !fly.layout<(4) : (1)> + %result = fly.flat_product(%base, %tile) : (!fly.layout<(8) : (1)>, !fly.layout<(4) : (1)>) -> !fly.layout<(8, 4) : (1, 8)> + return %result : !fly.layout<(8, 4) : (1, 8)> +} + +// CHECK-LABEL: @test_blocked_product_1d +func.func @test_blocked_product_1d() -> !fly.layout<(8, 4) : (1, 8)> { + %s1 = fly.static {elems = [8 : i32]} : () -> !fly.int_tuple<(8)> + %d1 = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %base = fly.make_layout(%s1, %d1) : (!fly.int_tuple<(8)>, !fly.int_tuple<(1)>) -> !fly.layout<(8) : (1)> + %s2 = fly.static {elems = [4 : i32]} : () -> !fly.int_tuple<(4)> + %d2 = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %tile = fly.make_layout(%s2, %d2) : (!fly.int_tuple<(4)>, !fly.int_tuple<(1)>) -> !fly.layout<(4) : (1)> + %result = fly.blocked_product(%base, %tile) : (!fly.layout<(8) : (1)>, !fly.layout<(4) : (1)>) -> !fly.layout<(8, 4) : (1, 8)> + return %result : !fly.layout<(8, 4) : (1, 8)> +} + +// CHECK-LABEL: @test_raked_product_1d +func.func @test_raked_product_1d() -> !fly.layout<(4, 8) : (8, 1)> { + %s1 = fly.static {elems = [8 : i32]} : () -> !fly.int_tuple<(8)> + %d1 = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %base = fly.make_layout(%s1, %d1) : (!fly.int_tuple<(8)>, !fly.int_tuple<(1)>) -> !fly.layout<(8) : (1)> + %s2 = fly.static {elems = [4 : i32]} : () -> !fly.int_tuple<(4)> + %d2 = fly.static {elems = [1 : i32]} : () -> !fly.int_tuple<(1)> + %tile = fly.make_layout(%s2, %d2) : (!fly.int_tuple<(4)>, !fly.int_tuple<(1)>) -> !fly.layout<(4) : (1)> + %result = fly.raked_product(%base, %tile) : (!fly.layout<(8) : (1)>, !fly.layout<(4) : (1)>) -> !fly.layout<(4, 8) : (8, 1)> + return %result : !fly.layout<(4, 8) : (8, 1)> +} + +// CHECK-LABEL: @test_logical_product_wrapped_tuple_1d +func.func @test_logical_product_wrapped_tuple_1d( + %base: !fly.layout<((8, 1)) : ((1, 8))>, + %tile: !fly.layout<((4, 1)) : ((1, 4))>) -> !fly.layout<((8, 1), (4, 1)) : ((1, 0), (8, 0))> { + // Outer singleton wrappers are accepted and handled in inference. + %result = fly.logical_product(%base, %tile) + : (!fly.layout<((8, 1)) : ((1, 8))>, !fly.layout<((4, 1)) : ((1, 4))>) + -> !fly.layout<((8, 1), (4, 1)) : ((1, 0), (8, 0))> + return %result : !fly.layout<((8, 1), (4, 1)) : ((1, 0), (8, 0))> }