@@ -605,51 +605,16 @@ def doprint(self, expr, *, simplify: bool = True, p=True):
605
605
return super ().doprint (expr )
606
606
607
607
608
- class OpOverrides :
609
- def __init__ (self , parent ):
610
- super ().__init__ ()
611
- self ._parent = parent
612
-
613
- @staticmethod
614
- def paren (string : str ) -> str :
615
- def all_in_parens (string : str ) -> bool :
616
- if string [0 ] != "(" or len (string ) < 2 :
617
- return False
618
- count = 1
619
- for i , char in enumerate (string [1 :]):
620
- if char == "(" :
621
- count += 1
622
- elif char == ")" :
623
- count -= 1
624
- if count == 0 and i != len (string ) - 2 :
625
- return False
626
- assert count == 0
627
- return True
628
-
629
- if (
630
- isinstance (string , CSEVariable )
631
- or re .match (r"^[a-z0-9_.]+$" , string , re .IGNORECASE )
632
- or re .match (r"^\([^)]*\)$" , string , re .IGNORECASE )
633
- or string == ""
634
- ):
635
- return string
636
- # don't put extra parens for strings that are already wrapped in parens
637
- if all_in_parens (string ):
638
- return string
639
- return f"({ string } )"
640
-
641
- def __getattr__ (self , item ):
642
- return getattr (self ._parent , item )
608
+ class OpDecompositions :
609
+ """
610
+ Decomposes inductor ops
611
+ """
643
612
644
613
@staticmethod
645
614
def identity (value ):
646
615
# used to trigger cse
647
616
return value
648
617
649
- @staticmethod
650
- def constant (value , dtype ):
651
- return repr (value )
652
-
653
618
@staticmethod
654
619
def reciprocal (x ):
655
620
return ops .truediv (ops .constant (1 , torch .int32 ), x )
@@ -691,15 +656,86 @@ def sigmoid(x):
691
656
one = ops .constant (1 , torch .int32 )
692
657
return ops .truediv (one , ops .add (one , ops .exp (ops .neg (x ))))
693
658
659
+ @staticmethod
660
+ def relu (x ):
661
+ return ops .maximum (x , ops .constant (0 , torch .int32 ))
662
+
663
+ @staticmethod
664
+ def fma (x , y , z ):
665
+ # for backends that don't override this (halide)
666
+ return ops .add (ops .mul (x , y ), z )
667
+
668
+ @staticmethod
669
+ def floor_to_int (a , dtype ):
670
+ return ops .to_dtype (ops .floor (a ), dtype )
671
+
672
+ @staticmethod
673
+ def ceil_to_int (a , dtype ):
674
+ return ops .to_dtype (ops .ceil (a ), dtype )
675
+
676
+ @staticmethod
677
+ def trunc_to_int (a , dtype ):
678
+ return ops .to_dtype (ops .trunc (a ), dtype )
679
+
680
+ @staticmethod
681
+ def remainder (a , b ):
682
+ r = ops .mod (a , b )
683
+ cond = ops .and_ (
684
+ ops .ne (r , ops .constant (0 , torch .int32 )),
685
+ ops .ne (ops .signbit (r ), ops .signbit (b )),
686
+ )
687
+ return ops .where (cond , ops .add (r , b ), r )
688
+
689
+ @staticmethod
690
+ def round_to_int (a , dtype ):
691
+ return ops .to_dtype (ops .round (a ), dtype )
692
+
693
+
694
+ class OpOverrides (OpDecompositions ):
695
+ def __init__ (self , parent ):
696
+ super ().__init__ ()
697
+ self ._parent = parent
698
+
699
+ @staticmethod
700
+ def paren (string : str ) -> str :
701
+ def all_in_parens (string : str ) -> bool :
702
+ if string [0 ] != "(" or len (string ) < 2 :
703
+ return False
704
+ count = 1
705
+ for i , char in enumerate (string [1 :]):
706
+ if char == "(" :
707
+ count += 1
708
+ elif char == ")" :
709
+ count -= 1
710
+ if count == 0 and i != len (string ) - 2 :
711
+ return False
712
+ assert count == 0
713
+ return True
714
+
715
+ if (
716
+ isinstance (string , CSEVariable )
717
+ or re .match (r"^[a-z0-9_.]+$" , string , re .IGNORECASE )
718
+ or re .match (r"^\([^)]*\)$" , string , re .IGNORECASE )
719
+ or string == ""
720
+ ):
721
+ return string
722
+ # don't put extra parens for strings that are already wrapped in parens
723
+ if all_in_parens (string ):
724
+ return string
725
+ return f"({ string } )"
726
+
727
+ def __getattr__ (self , item ):
728
+ return getattr (self ._parent , item )
729
+
730
+ @staticmethod
731
+ def constant (value , dtype ):
732
+ return repr (value )
733
+
694
734
@staticmethod
695
735
def libdevice_sigmoid (x ):
696
736
one = ops .constant (1 , torch .int32 )
697
737
return ops .truediv (one , ops .add (one , ops .libdevice_exp (ops .neg (x ))))
698
738
699
- @staticmethod
700
- def relu (x ):
701
- return ops .maximum (x , ops .constant (0 , torch .int32 ))
702
-
703
739
@staticmethod
704
740
def libdevice_abs (x ):
705
741
return ops .abs (x )
@@ -752,36 +788,6 @@ def bitwise_left_shift(x, y):
752
788
def bitwise_right_shift (x , y ):
753
789
return f"{ OpOverrides .paren (x )} >> { OpOverrides .paren (y )} "
754
790
755
- @staticmethod
756
- def remainder (a , b ):
757
- r = ops .mod (a , b )
758
- cond = ops .and_ (
759
- ops .ne (r , ops .constant (0 , torch .int32 )),
760
- ops .ne (ops .signbit (r ), ops .signbit (b )),
761
- )
762
- return ops .where (cond , ops .add (r , b ), r )
763
-
764
- @staticmethod
765
- def fma (x , y , z ):
766
- # for backends that don't override this (halide)
767
- return ops .add (ops .mul (x , y ), z )
768
-
769
- @staticmethod
770
- def trunc_to_int (a , dtype ):
771
- return ops .to_dtype (ops .trunc (a ), dtype )
772
-
773
- @staticmethod
774
- def floor_to_int (a , dtype ):
775
- return ops .to_dtype (ops .floor (a ), dtype )
776
-
777
- @staticmethod
778
- def ceil_to_int (a , dtype ):
779
- return ops .to_dtype (ops .ceil (a ), dtype )
780
-
781
- @staticmethod
782
- def round_to_int (a , dtype ):
783
- return ops .to_dtype (ops .round (a ), dtype )
784
-
785
791
@staticmethod
786
792
def int_truediv (a , b ):
787
793
# TODO: this is wrong
0 commit comments