Skip to content

Commit

Permalink
fix: tfmm backward deduction (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
ganler authored Mar 14, 2023
1 parent d70fb1f commit e4ba12e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
1 change: 0 additions & 1 deletion nnsmith/abstract/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2024,7 +2024,6 @@ def __init__(self):
super().__init__(DType.bool)


@mark_materialize("core")
class MatMul(BinaryOpBase):
in_dtypes = [(i, i) for i in DTYPE_GEN_NON_BOOL if i not in DTYPE_GEN_COMPLEX]
out_dtypes = [(i,) for i in DTYPE_GEN_NON_BOOL if i not in DTYPE_GEN_COMPLEX]
Expand Down
16 changes: 16 additions & 0 deletions nnsmith/materialize/tensorflow/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,19 @@ def __init__(self):
super().__init__()
self.inp_ranks = [(2, 3), (2, 3)]
self.out_ranks = [(2, 3)]

def deduct_inp_ranks_and_dtype(
self, out_abs_tensor: List[AbsTensor]
) -> List[Tuple[int, DType]]:
if out_abs_tensor[0].ndims == 2:
return [
(2, out_abs_tensor[0].dtype),
(2, out_abs_tensor[0].dtype),
]
# at least one of them is 3
ranks = [3, random.choice([2, 3])]
random.shuffle(ranks)
return [
(ranks[0], out_abs_tensor[0].dtype),
(ranks[1], out_abs_tensor[0].dtype),
]
13 changes: 12 additions & 1 deletion nnsmith/materialize/torch/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
DTYPE_GEN_NON_BOOL,
DType,
)
from nnsmith.abstract.op import ReduceBase, UnaryOpBase, mark_materialize, rank_from
from nnsmith.abstract.op import (
MatMul,
ReduceBase,
UnaryOpBase,
mark_materialize,
rank_from,
)
from nnsmith.abstract.tensor import AbsTensor
from nnsmith.error import ConstraintCheck

Expand Down Expand Up @@ -90,3 +96,8 @@ def type_transfer(self, input_shapes: List[AbsTensor]) -> List[AbsTensor]:
if input_shapes[0].dtype in DTYPE_GEN_INTS: # This is a PyTorch trick...
output[0].dtype = DType.int64
return output


@mark_materialize("torch")
class PTMatMul(MatMul):
pass

0 comments on commit e4ba12e

Please sign in to comment.