Change Torch lowering to use divSI
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
MaheshRavishankar committed Jan 25, 2025
1 parent 66c536a commit b986346
Showing 5 changed files with 18 additions and 18 deletions.
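The functional change is the switch from unsigned to signed arith ops when lowering symbolic shape expressions. As a minimal sketch of why the spelling matters (standalone MLIR, not taken from this commit; the function name is illustrative), the signed and unsigned divides diverge as soon as an operand can be negative:

// Minimal sketch: how arith.divsi and arith.divui differ on the same operands.
func.func @signed_vs_unsigned_div(%a: i64, %b: i64) -> (i64, i64) {
  // arith.divsi treats the operands as two's-complement signed values and
  // rounds toward zero, e.g. -6 divided by 4 gives -1.
  %s = arith.divsi %a, %b : i64
  // arith.divui reinterprets the same bits as unsigned, so -6 divided by 4
  // yields a huge positive quotient instead of a small negative one.
  %u = arith.divui %a, %b : i64
  return %s, %u : i64, i64
}

The lit tests below are updated accordingly to expect the signed spellings (remsi, divsi) instead of the unsigned ones.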
@@ -281,11 +281,11 @@ class BindSymbolicShapesPass final
    case AffineExprKind::Mul:
      return builder.create<arith::MulIOp>(loc, lhs, rhs);
    case AffineExprKind::Mod:
-     return builder.create<arith::RemUIOp>(loc, lhs, rhs);
+     return builder.create<arith::RemSIOp>(loc, lhs, rhs);
    case AffineExprKind::FloorDiv:
-     return builder.create<arith::DivUIOp>(loc, lhs, rhs);
+     return builder.create<arith::DivSIOp>(loc, lhs, rhs);
    case AffineExprKind::CeilDiv:
-     return builder.create<arith::CeilDivUIOp>(loc, lhs, rhs);
+     return builder.create<arith::CeilDivSIOp>(loc, lhs, rhs);
    default:
      break;
    }
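
For illustration only (hypothetical IR modeled on the lit tests below; the SSA names and the floordiv constant are made up), a floordiv term in a bound symbolic shape is now expanded with the signed divide:

%s0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
torch.bind_symbolic_shape %arg0, [%s0], affine_map<()[s0] -> (s0 floordiv 4)> : !torch.vtensor<[?],f32>
// The `s0 floordiv 4` term is materialized roughly as
//   %q = arith.divsi %s0_value, %c4
// where the pass previously emitted arith.divui.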
@@ -113,8 +113,8 @@ module @add_expr
// CHECK-LABEL: @mod_expr
module @mod_expr {
  func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
-   // CHECK: remui
-   // CHECK-NOT: udiv
+   // CHECK: remsi
+   // CHECK-NOT: sdiv
    %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
    torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
@@ -127,8 +127,8 @@ module @mod_expr
// CHECK-LABEL: @floordiv_expr
module @floordiv_expr {
  func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
-   // CHECK: divui
-   // CHECK-NOT: udiv
+   // CHECK: divsi
+   // CHECK-NOT: sdiv
    %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
    torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
@@ -66,11 +66,11 @@ func.func @block_attention_dims()
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
// CHECK-DAG: %[[M:.+]] = flow.dispatch.workload.ordinal %{{.+}}, 0 : index
// CHECK-DAG: %[[K2:.+]] = flow.dispatch.workload.ordinal %{{.+}}, 1 : index
-// CHECK-DAG: %[[M_DYNAMIC:.+]] = arith.divui %[[M]], %[[C16]]
+// CHECK-DAG: %[[M_DYNAMIC:.+]] = arith.divsi %[[M]], %[[C16]]
// CHECK: %[[Q_BINDING:.+]] = hal.interface.binding.subspan
// CHECK-SAME: binding(0)
// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<4x?x16x32x128xf16>>{%[[M_DYNAMIC]]}
-// CHECK: %[[K2_DYNAMIC:.+]] = arith.divui %[[K2]], %[[C32]]
+// CHECK: %[[K2_DYNAMIC:.+]] = arith.divsi %[[K2]], %[[C32]]
// CHECK: %[[K_BINDING:.+]] = hal.interface.binding.subspan
// CHECK-SAME: binding(1)
// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<4x?x32x32x128xf16>>{%[[K2_DYNAMIC]]}
@@ -57,7 +57,7 @@ func.func @fold_expand_into_loads_dynamic() -> tensor<2x?x16x32xf32> {
// CHECK-LABEL: func @fold_expand_into_loads_dynamic()
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
// CHECK-DAG: %[[CONST:.+]] = hal.interface.constant.load
-// CHECK: %[[SHAPE:.+]] = arith.divui %[[CONST]], %[[C16]]
+// CHECK: %[[SHAPE:.+]] = arith.divsi %[[CONST]], %[[C16]]
// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan
// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<2x?x16x32xf32>>{%[[SHAPE]]}
// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[SUBSPAN]]
@@ -81,7 +81,7 @@ func.func @fold_collapse_into_stores_dynamic(%arg0 : tensor<2x?x32xf32>) {
// CHECK-LABEL: func @fold_collapse_into_stores_dynamic(
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[CONST:.+]] = hal.interface.constant.load
-// CHECK: %[[SHAPE:.+]] = arith.divui %[[CONST]], %[[C2]]
+// CHECK: %[[SHAPE:.+]] = arith.divsi %[[CONST]], %[[C2]]
// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan
// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<2x?x32xf32>>{%[[SHAPE]]}
// CHECK: flow.dispatch.tensor.store %{{.+}}, %[[SUBSPAN]]
@@ -197,17 +197,17 @@ util.func public @attention_dynamic(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
-// CHECK-DAG: %[[SPLIT0:.+]] = arith.divui %[[D0]]
+// CHECK-DAG: %[[SPLIT0:.+]] = arith.divsi %[[D0]]
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[SPLIT0]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16>
// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]]
// CHECK-DAG: %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]]
-// CHECK-DAG: %[[SPLIT1:.+]] = arith.divui %[[D5]], %[[C2]]
+// CHECK-DAG: %[[SPLIT1:.+]] = arith.divsi %[[D5]], %[[C2]]
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]]
// CHECK-DAG: %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]]
// CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-// CHECK-DAG: %[[SPLIT2:.+]] = arith.divui %[[D8]], %[[C2]]
+// CHECK-DAG: %[[SPLIT2:.+]] = arith.divsi %[[D8]], %[[C2]]
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]]
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME: indexing_maps =
@@ -262,22 +262,22 @@ util.func public @attention_dynamic_masked(%arg0: tensor<?x?x?xf16>, %arg1: tens
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
-// CHECK-DAG: %[[SPLIT0:.+]] = arith.divui %[[D0]]
+// CHECK-DAG: %[[SPLIT0:.+]] = arith.divsi %[[D0]]
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[SPLIT0]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16>
// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]]
// CHECK-DAG: %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]]
-// CHECK-DAG: %[[SPLIT1:.+]] = arith.divui %[[D5]], %[[C2]]
+// CHECK-DAG: %[[SPLIT1:.+]] = arith.divsi %[[D5]], %[[C2]]
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]]
// CHECK-DAG: %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]]
// CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-// CHECK-DAG: %[[SPLIT2:.+]] = arith.divui %[[D8]], %[[C2]]
+// CHECK-DAG: %[[SPLIT2:.+]] = arith.divsi %[[D8]], %[[C2]]
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]]
// CHECK-DAG: %[[D10:.+]] = tensor.dim %[[ARG4]], %[[C0]]
// CHECK-DAG: %[[D11:.+]] = tensor.dim %[[ARG4]], %[[C1]]
// CHECK-DAG: %[[D12:.+]] = tensor.dim %[[ARG4]], %[[C2]]
-// CHECK-DAG: %[[SPLIT3:.+]] = arith.divui %[[D10]], %[[C2]]
+// CHECK-DAG: %[[SPLIT3:.+]] = arith.divsi %[[D10]], %[[C2]]
// CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT3]], %[[D11]], %[[D12]]]
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
// CHECK-SAME: indexing_maps =