Skip to content

Commit e07a0bf

Browse files
authored
onnx.resize: Add support for coordTfMode "half_pixel" (#3441)
half_pixel is also the default coordinate transformation mode used by ONNX; see https://onnx.ai/onnx/operators/onnx__Resize.html
1 parent d77bab3 commit e07a0bf

File tree

4 files changed

+70
-3
lines changed

4 files changed

+70
-3
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,10 +2823,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
28232823
binder.op, "unimplemented: coordinate transformation mode: "
28242824
"tf_crop_and_resize");
28252825

2826-
if (mode == "nearest" && coordTfMode != "asymmetric") {
2826+
if (mode == "nearest" && coordTfMode != "asymmetric" &&
2827+
coordTfMode != "half_pixel") {
28272828
return rewriter.notifyMatchFailure(
28282829
binder.op, "unimplemented: support not present for coord tf mode "
2829-
"except asymmetric");
2830+
"except asymmetric and half_pixel");
28302831
}
28312832

28322833
unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2631,7 +2631,17 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
26312631

26322632
Value outInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), outIndex);
26332633
Value outFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), outInt);
2634-
Value proj = b.create<arith::DivFOp>(loc, outFP, scale);
2634+
Value proj;
2635+
if (coordStr.empty() || coordStr == "_asymmetric") {
2636+
proj = b.create<arith::DivFOp>(loc, outFP, scale);
2637+
} else if (coordStr == "_half_pixel") {
2638+
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
2639+
Value add = b.create<arith::AddFOp>(loc, outFP, cstHalf);
2640+
Value div = b.create<arith::DivFOp>(loc, add, scale);
2641+
proj = b.create<arith::SubFOp>(loc, div, cstHalf);
2642+
} else {
2643+
llvm_unreachable("Unsupported coordination transformation mode");
2644+
}
26352645

26362646
Value nearestFP;
26372647
// get nearest pixel using floor
@@ -2655,6 +2665,8 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
26552665
nearestFP = b.create<arith::SelectOp>(loc, cmp, ceil, floor);
26562666
} else if (nearestMode == "ceil") {
26572667
nearestFP = b.create<math::CeilOp>(loc, proj);
2668+
} else {
2669+
llvm_unreachable("Unsupported nearest mode");
26582670
}
26592671
Value nearestInt =
26602672
b.create<arith::FPToSIOp>(loc, b.getI64Type(), nearestFP);

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2183,6 +2183,19 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
21832183

21842184
// -----
21852185

2186+
// CHECK-LABEL: func.func @test_resize_sizes_nearest
2187+
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
2188+
%none = torch.constant.none
2189+
// CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
2190+
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2191+
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
2192+
torch.onnx.coordinate_transformation_mode = "half_pixel",
2193+
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
2194+
return %0 : !torch.vtensor<[?,?,?,?],f32>
2195+
}
2196+
2197+
// -----
2198+
21862199
// CHECK-LABEL: func.func @test_resize_sizes_linear
21872200
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
21882201
f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {

test/Conversion/TorchToLinalg/resize.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,44 @@ func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1:
155155
%7 = torch.aten.__interpolate.size_list_scale_list %arg0, %6, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32>
156156
return %7 : !torch.vtensor<[?,?,?,?,?],f32>
157157
}
158+
159+
// CHECK-LABEL: func.func @test_resize_nearest_half_pixel
160+
func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> {
161+
// CHECK: %[[GENERIC:.*]] = linalg.generic
162+
// CHECK: %[[x11:.*]] = linalg.index 0 : index
163+
// CHECK: %[[x12:.*]] = linalg.index 1 : index
164+
// CHECK: %[[x13:.*]] = linalg.index 2 : index
165+
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
166+
// CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
167+
// CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
168+
// CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64
169+
// CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32
170+
// CHECK: %[[cst:.*]] = arith.constant 5.000000e-01 : f32
171+
// CHECK: %[[add:.*]] = arith.addf %[[x24]], %[[cst]] : f32
172+
// CHECK: %[[x25:.*]] = arith.divf %[[add]], %[[x21]] : f32
173+
// CHECK: %[[sub:.*]] = arith.subf %[[x25]], %[[cst]] : f32
174+
// CHECK: %[[cst3:.*]] = arith.constant 5.000000e-01 : f32
175+
// CHECK: %[[floor:.*]] = math.floor %[[sub]] : f32
176+
// CHECK: %[[ceil:.*]] = math.ceil %[[sub]] : f32
177+
// CHECK: %[[sub2:.*]] = arith.subf %[[sub]], %[[floor]] : f32
178+
// CHECK: %[[cmpf:.*]] = arith.cmpf ule, %[[sub2]], %[[cst3]] : f32
179+
// CHECK: %[[select:.*]] = arith.select %[[cmpf]], %[[floor]], %[[ceil]] : f32
180+
// CHECK: %[[x31:.*]] = arith.fptosi %[[select]] : f32 to i64
181+
// CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index
182+
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor<?x?x?xf32>
183+
// CHECK: linalg.yield %[[extracted]] : f32
184+
%none = torch.constant.none
185+
%none_0 = torch.constant.none
186+
%int0 = torch.constant.int 0
187+
%false = torch.constant.bool false
188+
%true = torch.constant.bool true
189+
%str = torch.constant.str "nearest_half_pixel,round_prefer_floor"
190+
%int2 = torch.constant.int 2
191+
%0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
192+
%1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
193+
%4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>
194+
%5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32>
195+
return %5 : !torch.vtensor<[?,?,?],f32>
196+
}
197+
198+
// -----

0 commit comments

Comments
 (0)