
Commit 0b75b7f

eellison authored and pytorchmergebot committed
[Easy] factor out inductor ophandler decompositions (pytorch#142400)
Factor out inductor operator decompositions.

Pull Request resolved: pytorch#142400
Approved by: https://github.com/Chillee, https://github.com/jansel
1 parent c170248 commit 0b75b7f
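
For orientation, here is a minimal, self-contained sketch of the pattern this commit factors out: a base class holding pure op decompositions, and an OpOverrides that inherits them while delegating everything else to a parent handler via __getattr__. The string-building stand-ins below are invented for illustration only; the real classes live in torch/_inductor/codegen/common.py and operate through torch._inductor.virtualized.ops.

class OpDecompositions:
    """Rewrites derived ops in terms of more primitive ones."""

    @staticmethod
    def relu(x):
        # stand-in for ops.maximum(x, ops.constant(0, torch.int32))
        return f"maximum({x}, 0)"


class OpOverrides(OpDecompositions):
    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        # only reached when normal attribute lookup fails, so decomposed
        # ops are handled above and unknown ops fall through to the parent
        return getattr(self._parent, item)


class ParentHandler:
    @staticmethod
    def exp(x):
        return f"exp({x})"


h = OpOverrides(ParentHandler())
print(h.relu("v0"))  # maximum(v0, 0) -- decomposed locally
print(h.exp("v0"))   # exp(v0)        -- delegated to the parent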

2 files changed: +80, -75 lines


torch/_inductor/codegen/common.py (79 additions, 73 deletions)
@@ -605,51 +605,16 @@ def doprint(self, expr, *, simplify: bool = True, p=True):
         return super().doprint(expr)
 
 
-class OpOverrides:
-    def __init__(self, parent):
-        super().__init__()
-        self._parent = parent
-
-    @staticmethod
-    def paren(string: str) -> str:
-        def all_in_parens(string: str) -> bool:
-            if string[0] != "(" or len(string) < 2:
-                return False
-            count = 1
-            for i, char in enumerate(string[1:]):
-                if char == "(":
-                    count += 1
-                elif char == ")":
-                    count -= 1
-                    if count == 0 and i != len(string) - 2:
-                        return False
-            assert count == 0
-            return True
-
-        if (
-            isinstance(string, CSEVariable)
-            or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
-            or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
-            or string == ""
-        ):
-            return string
-        # don't put extra parens for strings that are already wrapped in parens
-        if all_in_parens(string):
-            return string
-        return f"({string})"
-
-    def __getattr__(self, item):
-        return getattr(self._parent, item)
+class OpDecompositions:
+    """
+    Decomposes inductor ops
+    """
 
     @staticmethod
     def identity(value):
         # used to trigger cse
         return value
 
-    @staticmethod
-    def constant(value, dtype):
-        return repr(value)
-
     @staticmethod
     def reciprocal(x):
         return ops.truediv(ops.constant(1, torch.int32), x)
@@ -691,15 +656,86 @@ def sigmoid(x):
         one = ops.constant(1, torch.int32)
         return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))
 
+    @staticmethod
+    def relu(x):
+        return ops.maximum(x, ops.constant(0, torch.int32))
+
+    @staticmethod
+    def fma(x, y, z):
+        # for backends that don't override this (halide)
+        return ops.add(ops.mul(x, y), z)
+
+    @staticmethod
+    def floor_to_int(a, dtype):
+        return ops.to_dtype(ops.floor(a), dtype)
+
+    @staticmethod
+    def ceil_to_int(a, dtype):
+        return ops.to_dtype(ops.ceil(a), dtype)
+
+    @staticmethod
+    def trunc_to_int(a, dtype):
+        return ops.to_dtype(ops.trunc(a), dtype)
+
+    @staticmethod
+    def remainder(a, b):
+        r = ops.mod(a, b)
+        cond = ops.and_(
+            ops.ne(r, ops.constant(0, torch.int32)),
+            ops.ne(ops.signbit(r), ops.signbit(b)),
+        )
+        return ops.where(cond, ops.add(r, b), r)
+
+    @staticmethod
+    def round_to_int(a, dtype):
+        return ops.to_dtype(ops.round(a), dtype)
+
+
+class OpOverrides(OpDecompositions):
+    def __init__(self, parent):
+        super().__init__()
+        self._parent = parent
+
+    @staticmethod
+    def paren(string: str) -> str:
+        def all_in_parens(string: str) -> bool:
+            if string[0] != "(" or len(string) < 2:
+                return False
+            count = 1
+            for i, char in enumerate(string[1:]):
+                if char == "(":
+                    count += 1
+                elif char == ")":
+                    count -= 1
+                    if count == 0 and i != len(string) - 2:
+                        return False
+            assert count == 0
+            return True
+
+        if (
+            isinstance(string, CSEVariable)
+            or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
+            or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
+            or string == ""
+        ):
+            return string
+        # don't put extra parens for strings that are already wrapped in parens
+        if all_in_parens(string):
+            return string
+        return f"({string})"
+
+    def __getattr__(self, item):
+        return getattr(self._parent, item)
+
+    @staticmethod
+    def constant(value, dtype):
+        return repr(value)
+
     @staticmethod
     def libdevice_sigmoid(x):
         one = ops.constant(1, torch.int32)
         return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))
 
-    @staticmethod
-    def relu(x):
-        return ops.maximum(x, ops.constant(0, torch.int32))
-
     @staticmethod
     def libdevice_abs(x):
         return ops.abs(x)
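
The paren helper moved in the hunk above is self-contained string logic. Here is a runnable excerpt of it with the CSEVariable check dropped (that type only matters inside inductor), showing when parentheses are and are not added:

import re

def all_in_parens(string: str) -> bool:
    if string[0] != "(" or len(string) < 2:
        return False
    count = 1
    for i, char in enumerate(string[1:]):
        if char == "(":
            count += 1
        elif char == ")":
            count -= 1
            if count == 0 and i != len(string) - 2:
                return False
    assert count == 0
    return True

def paren(string: str) -> str:
    if (
        re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
        or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
        or string == ""
    ):
        return string
    # don't put extra parens for strings that are already wrapped in parens
    if all_in_parens(string):
        return string
    return f"({string})"

assert paren("tmp0") == "tmp0"              # bare identifier: no parens
assert paren("(a + b)") == "(a + b)"        # already fully wrapped
assert paren("(a) + (b)") == "((a) + (b))"  # outer parens don't span it all
assert paren("a + b") == "(a + b)"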
@@ -752,36 +788,6 @@ def bitwise_left_shift(x, y):
     def bitwise_right_shift(x, y):
         return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}"
 
-    @staticmethod
-    def remainder(a, b):
-        r = ops.mod(a, b)
-        cond = ops.and_(
-            ops.ne(r, ops.constant(0, torch.int32)),
-            ops.ne(ops.signbit(r), ops.signbit(b)),
-        )
-        return ops.where(cond, ops.add(r, b), r)
-
-    @staticmethod
-    def fma(x, y, z):
-        # for backends that don't override this (halide)
-        return ops.add(ops.mul(x, y), z)
-
-    @staticmethod
-    def trunc_to_int(a, dtype):
-        return ops.to_dtype(ops.trunc(a), dtype)
-
-    @staticmethod
-    def floor_to_int(a, dtype):
-        return ops.to_dtype(ops.floor(a), dtype)
-
-    @staticmethod
-    def ceil_to_int(a, dtype):
-        return ops.to_dtype(ops.ceil(a), dtype)
-
-    @staticmethod
-    def round_to_int(a, dtype):
-        return ops.to_dtype(ops.round(a), dtype)
-
     @staticmethod
     def int_truediv(a, b):
         # TODO: this is wrong
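
The remainder decomposition that leaves OpOverrides above (and lands in OpDecompositions) converts a truncated, C-style mod into Python-style floored modulo: when the truncated remainder is nonzero and its sign disagrees with the divisor's, add the divisor back. A standalone illustration of the same logic, with math.fmod standing in for ops.mod and copysign standing in for ops.signbit (both stand-ins are assumptions of this sketch):

import math

def remainder(a: float, b: float) -> float:
    r = math.fmod(a, b)  # truncated mod: sign follows the dividend a
    if r != 0 and math.copysign(1, r) != math.copysign(1, b):
        return r + b     # fix the sign so it follows the divisor b
    return r

assert remainder(-7, 3) == -7 % 3 == 2
assert remainder(7, -3) == 7 % -3 == -2
assert remainder(7, 3) == 7 % 3 == 1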

torch/_inductor/ops_handler.py (1 addition, 2 deletions)
@@ -51,8 +51,7 @@ def _arg_str(a) -> str:
 # implementations make heavy use of __getattr__ magic, and pre-existing
 # stubs for methods would interfere with this mechanism.
 #
-# TODO: A superclass that does desugaring for operations like
-# reciprocal/square might be useful.
+# See OpDecompositions for superclass that desugars operations like reciprocal/square.
 class OpsHandler(Protocol[T]):
     """
     Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
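
The rewritten comment points at OpDecompositions as the desugaring superclass the old TODO asked for: a backend implements primitives, and derived ops fall out of the shared decompositions. A toy illustration with an invented string-building ops object (square is not in the hunks above; it is sketched here only because the comment names it):

class StringOps:
    @staticmethod
    def constant(value, dtype=None):
        return repr(value)

    @staticmethod
    def truediv(a, b):
        return f"({a} / {b})"

    @staticmethod
    def mul(a, b):
        return f"({a} * {b})"

ops = StringOps()

def reciprocal(x):
    # mirrors OpDecompositions.reciprocal from the common.py diff above
    return ops.truediv(ops.constant(1), x)

def square(x):
    # hypothetical, in the same style as the other decompositions
    return ops.mul(x, x)

print(reciprocal("v0"))  # (1 / v0)
print(square("v0"))      # (v0 * v0)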
