-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path5dtensor.cpp
99 lines (80 loc) · 4.24 KB
/
5dtensor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
/*
This implementation focuses specifically on lowering the ONNX Col2Im operation to its Torch equivalent, with explicit support for 5D tensors.
Here are the key aspects of this implementation:
1. Input Handling:
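- Expects a 3D input tensor, assumed here to be laid out as (C * D * H * W, N * oD * oH, oW). NOTE: the ONNX specification instead defines the Col2Im input as (N, C * prod(block_shape), L), so this assumed layout should be confirmed.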
- The output shape is provided as a separate 1-D tensor operand whose length is expected to be 5 (one entry per dimension of the 5D output).
2. Attribute Validation:
- Validates that blockShape, dilations, and strides have 3 elements each (for D, H, W).
- Validates that pads has 6 elements (start and end padding for D, H, W).
3. Tensor Reshaping:
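- Intends to reshape the input from (C * D * H * W, N * oD * oH, oW) to (N, C * D * H * W, oD * oH, oW), to match the input format expected by Torch's Col2Im operation.
- NOTE: the reshape target passed in code contains two inferred (-1) dimensions, while torch reshape allows at most one. Moving the C * D * H * W axis from position 0 to position 1 is also a permutation that a reshape alone cannot perform.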
4. Torch Operation Creation:
- Creates a constant tensor for blockShape using Torch::ConstantOp.
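- Uses Torch::AtenCol2ImOp for the actual Col2Im operation. NOTE: PyTorch documents col2im (nn.Fold) as supporting only 4-D (batched, 2-D spatial) outputs, so 5D support must be confirmed.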
5. Flexible Output Shape:
- The output shape is set to (-1, -1, -1, -1, -1), allowing for dynamic shape inference in Torch.
6. ONNX to Torch Mapping:
- Maps the ONNX Col2Im operation directly to its Torch equivalent, maintaining the semantics of the operation while adapting to Torch's specific requirements.
This implementation provides a clear lowering strategy from the ONNX Col2Im operation to its Torch equivalent, with explicit support for 5D tensors.
It handles the necessary reshaping and attribute conversions to ensure compatibility between the ONNX and Torch versions of the operation.
*/
/// Lowers an ONNX Col2Im op to Torch dialect ops, targeting a 5-D output.
///
/// Binds operand 0 as `input` and operand 1 as `outputShape`, plus the integer
/// array attributes "blockShape", "dilations", "pads" and "strides", then
/// validates them:
///   - `input` must be a ranked 3-D tensor;
///   - `outputShape` must be a ranked 1-D tensor of static length 5;
///   - blockShape, dilations and strides must have 3 entries, pads must have 6.
/// On success it emits a blockShape constant, an AtenReshapeOp on `input` and
/// an AtenCol2ImOp, then replaces `binder.op` with the Col2Im result.
///
/// @return failure() if operand/attribute binding fails, a match failure with
///         a diagnostic if validation fails, success() once the op has been
///         replaced.
///
/// NOTE(review): the operand/attribute layout bound here does not appear to
/// match the ONNX Col2Im spec. That spec takes `image_shape` (spatial dims
/// only) as input 1 and `block_shape` as input 2 rather than as an attribute,
/// and defines the input as (N, C * prod(block_shape), L). PyTorch also
/// documents col2im (nn.Fold) for 4-D outputs only. TODO confirm the intended
/// contract.
LogicalResult col2imONNXToTorchLowering(OpBinder binder,
                                        ConversionPatternRewriter &rewriter) {
  // Tensor operands.
  Value input;
  Value outputShape;
  // Attributes.
  SmallVector<int64_t> blockShape, dilations, pads, strides;

  // Bind operands and attributes; any missing piece means this pattern does
  // not apply.
  if (binder.tensorOperandAtIndex(input, 0) ||
      binder.tensorOperandAtIndex(outputShape, 1) ||
      binder.s64IntegerArrayAttr(blockShape, "blockShape") ||
      binder.s64IntegerArrayAttr(dilations, "dilations") ||
      binder.s64IntegerArrayAttr(pads, "pads") ||
      binder.s64IntegerArrayAttr(strides, "strides"))
    return failure();

  // Validate the input tensor. getSizes() asserts on unranked tensor types,
  // so the rank must be known before its sizes are read.
  auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
  if (!inputTy || !inputTy.hasSizes())
    return rewriter.notifyMatchFailure(binder.op,
                                       "Expected input to be a ranked tensor");
  auto inputShape = inputTy.getSizes();
  if (inputShape.size() != 3)
    return rewriter.notifyMatchFailure(binder.op,
                                       "Expected input to be a 3D tensor");

  // Validate the output-shape operand. It must be a 1-D tensor before its
  // first dimension can be inspected; indexing [0] on a 0-D tensor's sizes
  // would read out of bounds.
  auto outputShapeTy = dyn_cast<Torch::ValueTensorType>(outputShape.getType());
  if (!outputShapeTy || !outputShapeTy.hasSizes() ||
      outputShapeTy.getSizes().size() != 1)
    return rewriter.notifyMatchFailure(
        binder.op, "Expected output shape operand to be a ranked 1-D tensor");
  // A dynamic length is reported as a negative size and is rejected here too.
  if (outputShapeTy.getSizes()[0] != 5)
    return rewriter.notifyMatchFailure(binder.op,
                                       "Expected output shape to be 5D");

  // Validate attribute sizes: one entry per spatial dim (D, H, W), and a
  // begin/end pair per spatial dim for pads.
  if (blockShape.size() != 3 || dilations.size() != 3 || strides.size() != 3 ||
      pads.size() != 6)
    return rewriter.notifyMatchFailure(
        binder.op, "Attribute sizes don't match 5D tensor requirements");

  // Create a Torch tensor holding blockShape.
  auto blockShapeTensor = rewriter.create<Torch::ConstantOp>(
      binder.op->getLoc(),
      Torch::ValueTensorType::get(rewriter.getContext(), {3},
                                  rewriter.getI64Type()),
      rewriter.getDenseI64ArrayAttr(blockShape));

  // Reshape input from (C * D * H * W, N * oD * oH, oW) to
  // (N, C * D * H * W, oD * oH, oW), per the original design.
  // NOTE(review): this target shape contains two inferred (-1) dimensions,
  // but torch reshape allows at most one. Moving the C*D*H*W axis from
  // position 0 to 1 is also a permutation that a reshape alone cannot
  // express. This step likely needs a permute plus a reshape with N taken
  // from `outputShape` -- TODO fix once the intended input layout is
  // confirmed.
  auto reshapedInput = rewriter.create<Torch::AtenReshapeOp>(
      binder.op->getLoc(),
      Torch::ValueTensorType::get(rewriter.getContext(),
                                  {-1, inputShape[0], -1, inputShape[2]},
                                  inputTy.getDtype()),
      input,
      rewriter.getDenseI64ArrayAttr({-1, inputShape[0], -1, inputShape[2]}));

  // Create the Torch Col2Im op. Every result dimension is left dynamic (-1).
  auto col2imOp = rewriter.create<Torch::AtenCol2ImOp>(
      binder.op->getLoc(),
      Torch::ValueTensorType::get(rewriter.getContext(), {-1, -1, -1, -1, -1},
                                  inputTy.getDtype()),
      reshapedInput, outputShape, blockShapeTensor.getResult(),
      rewriter.getI64ArrayAttr(dilations), rewriter.getI64ArrayAttr(pads),
      rewriter.getI64ArrayAttr(strides));

  // Replace the original op with the new Col2Im op.
  rewriter.replaceOp(binder.op, col2imOp.getResult());
  return success();
}
// Register the Col2Im lowering for ONNX opset 18 (the opset that introduced
// Col2Im).
// NOTE(review): torch-mlir's onOp() keys patterns by the bare ONNX op name
// (e.g. "Col2Im"), not a class-style name such as "ONNXCol2ImOp", which
// presumably would never match an imported onnx.Col2Im op. Confirm against
// the surrounding pattern-population function, where `patterns` must be in
// scope.
patterns.onOp("Col2Im", 18,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                return col2imONNXToTorchLowering(binder, rewriter);
              });