From e4ba12ea4bcd073b653b36a366891c3fe3bbde60 Mon Sep 17 00:00:00 2001
From: Jiawei Liu
Date: Tue, 14 Mar 2023 11:26:11 -0500
Subject: [PATCH] fix: tfmm backward deduction (#95)

---
 nnsmith/abstract/op.py                    |  1 -
 nnsmith/materialize/tensorflow/dialect.py | 16 ++++++++++++++++
 nnsmith/materialize/torch/dialect.py      | 13 ++++++++++++-
 3 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/nnsmith/abstract/op.py b/nnsmith/abstract/op.py
index 628aacc..a3df0ea 100644
--- a/nnsmith/abstract/op.py
+++ b/nnsmith/abstract/op.py
@@ -2024,7 +2024,6 @@ def __init__(self):
         super().__init__(DType.bool)
 
 
-@mark_materialize("core")
 class MatMul(BinaryOpBase):
     in_dtypes = [(i, i) for i in DTYPE_GEN_NON_BOOL if i not in DTYPE_GEN_COMPLEX]
     out_dtypes = [(i,) for i in DTYPE_GEN_NON_BOOL if i not in DTYPE_GEN_COMPLEX]
diff --git a/nnsmith/materialize/tensorflow/dialect.py b/nnsmith/materialize/tensorflow/dialect.py
index 84ed7ec..c7a5965 100644
--- a/nnsmith/materialize/tensorflow/dialect.py
+++ b/nnsmith/materialize/tensorflow/dialect.py
@@ -273,3 +273,19 @@ def __init__(self):
         super().__init__()
         self.inp_ranks = [(2, 3), (2, 3)]
         self.out_ranks = [(2, 3)]
+
+    def deduct_inp_ranks_and_dtype(
+        self, out_abs_tensor: List[AbsTensor]
+    ) -> List[Tuple[int, DType]]:
+        if out_abs_tensor[0].ndims == 2:
+            return [
+                (2, out_abs_tensor[0].dtype),
+                (2, out_abs_tensor[0].dtype),
+            ]
+        # at least one of them is 3
+        ranks = [3, random.choice([2, 3])]
+        random.shuffle(ranks)
+        return [
+            (ranks[0], out_abs_tensor[0].dtype),
+            (ranks[1], out_abs_tensor[0].dtype),
+        ]
diff --git a/nnsmith/materialize/torch/dialect.py b/nnsmith/materialize/torch/dialect.py
index c0dff75..5c2a6ee 100644
--- a/nnsmith/materialize/torch/dialect.py
+++ b/nnsmith/materialize/torch/dialect.py
@@ -11,7 +11,13 @@
     DTYPE_GEN_NON_BOOL,
     DType,
 )
-from nnsmith.abstract.op import ReduceBase, UnaryOpBase, mark_materialize, rank_from
+from nnsmith.abstract.op import (
+    MatMul,
+    ReduceBase,
+    UnaryOpBase,
+    mark_materialize,
+    rank_from,
+)
 from nnsmith.abstract.tensor import AbsTensor
 from nnsmith.error import ConstraintCheck
 
@@ -90,3 +96,8 @@ def type_transfer(self, input_shapes: List[AbsTensor]) -> List[AbsTensor]:
         if input_shapes[0].dtype in DTYPE_GEN_INTS:  # This is a PyTorch trick...
             output[0].dtype = DType.int64
         return output
+
+
+@mark_materialize("torch")
+class PTMatMul(MatMul):
+    pass