Commit 1cb14f6

Authored by Dave Liddell
Rob's atenTensor folder (llvm#2867)
If a tensor is initialized by a list with a single constant integer, this folder turns it into a torch.vtensor.literal.

Co-authored-by: Dave Liddell <dliddell@xilinx.com>
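
For illustration, a minimal before/after sketch (not taken from the patch; the values and SSA names are made up, mirroring the test added below) of the rewrite this folder enables once a pass such as canonicalize applies folds:

// Before the fold:
%none = torch.constant.none
%false = torch.constant.bool false
%int7 = torch.constant.int 7
%list = torch.prim.ListConstruct %int7 : (!torch.int) -> !torch.list<int>
%t = torch.aten.tensor %list, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>

// After the fold:
%t = torch.vtensor.literal(dense<7> : tensor<1xsi64>) : !torch.vtensor<[1],si64>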

4 files changed, +34 -1 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 1 addition & 0 deletions
@@ -8582,6 +8582,7 @@ def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [
       printDefaultTorchOp(printer, *this, 4, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenTensorBoolOp : Torch_Op<"aten.tensor.bool", [

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 21 additions & 0 deletions
@@ -2758,6 +2758,27 @@ void AtenDeviceWithIndexOp::getCanonicalizationPatterns(
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenTensorOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
+  // If a torch.aten.tensor op is initialized by a list with a constant, single
+  // element, fold it into a torch.vtensor.literal
+  auto resultTy = dyn_cast<ValueTensorType>(getType());
+  Type eTy = resultTy.getDtype();
+  ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);
+
+  SmallVector<int64_t> data;
+  if (matchPattern(getData(), m_TorchListOfConstantInts(data)) &&
+      data.size() == 1) {
+    Attribute attribute = IntegerAttr::get(eTy, data[0]);
+    return DenseElementsAttr::get(shapedTy, attribute);
+  }
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenIntTensorOp
 //===----------------------------------------------------------------------===//
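
As written, the fold assumes the result is a value tensor with a known dtype and only rewrites data lists that are exactly one constant integer; any other list makes fold() return nullptr and leaves the op in place. A hypothetical input (not part of this patch) that the data.size() == 1 guard skips:

%none = torch.constant.none
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%list = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// Two constant elements: fold() returns nullptr, so torch.aten.tensor is kept as-is.
%t = torch.aten.tensor %list, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2],si64>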

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 1 deletion
@@ -570,7 +570,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
-    emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
+    emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)", has_folder=True)
     emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
     emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
     emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")

test/Dialect/Torch/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
@@ -1461,6 +1461,17 @@ func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !to
   return %0 : !torch.tensor<[],f32>
 }
 
+// CHECK-LABEL:   func.func @torch.aten.tensor$one_elem(
+// CHECK-NEXT:       torch.vtensor.literal(dense<42> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+func.func @torch.aten.tensor$one_elem() -> (!torch.vtensor<[1],si64>) {
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %int42 = torch.constant.int 42
+  %66 = torch.prim.ListConstruct %int42 : (!torch.int) -> !torch.list<int>
+  %67 = torch.aten.tensor %66, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
+  return %67 : !torch.vtensor<[1],si64>
+}
+
 // CHECK-LABEL:   func.func @torch.aten.to.dtype$same_dtype(
 // CHECK-SAME:      %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
 // CHECK-NEXT:       return %[[ARG]] : !torch.tensor<*,f32>