@@ -1598,6 +1598,46 @@ def relu6_error_generator(op, device, dtype=torch.float32, **kwargs):
 elementwise_unary_ops.append(relu6_opinfo)


+def hardswish_error_generator(op, device, dtype=torch.float32, **kwargs):
+    a = make_tensor((), dtype=dtype, device=device)
+    yield (SampleInput(a, inplace=True), NotImplementedError, "hardswish only supports inplace=False")
+
+
+hardswish_opinfo = OpInfo(
+    ltorch.hardswish,
+    sample_input_generator=elementwise_unary_generator,
+    error_input_generator=hardswish_error_generator,
+    torch_reference=_elementwise_unary_torch(torch.nn.functional.hardswish),
+    test_directives=(
+        # PyTorch does not support bool hardswish on either CPU or CUDA
+        DecorateInfo(
+            pytest.mark.xfail,
+            "test_core_vs_torch_consistency",
+            dtypes=(datatypes.bool8,),
+        ),
+        # PyTorch does not support complex hardswish on either CPU or CUDA
+        DecorateInfo(
+            pytest.mark.xfail,
+            "test_core_vs_torch_consistency",
+            dtypes=(datatypes.complexfloating,),
+        ),
+        # PyTorch does not support CPU Half hardswish
+        DecorateInfo(
+            pytest.mark.xfail,
+            "test_core_vs_torch_consistency",
+            dtypes=(datatypes.float16,),
+            devicetypes=(devices.DeviceType.CPU,),
+        ),
+        # TODO: we might have a tolerance issue here with hardswish, which is defined in terms of relu6
+        DecorateInfo(
+            pytest.mark.xfail(strict=False),
+            "test_vjp_correctness",
+        ),
+    ),
+)
+elementwise_unary_ops.append(hardswish_opinfo)
+
+
 def selu_error_generator(op, device, dtype=torch.float32, **kwargs):
     a = make_tensor((), dtype=dtype, device=device)
     yield (SampleInput(a, inplace=True), NotImplementedError, "selu only supports inplace=False")
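For reference on the test_vjp_correctness xfail: hardswish is defined as x * relu6(x + 3) / 6, so the backward pass composes through relu6 and can pick up small numerical differences versus PyTorch's native hardswish. A minimal plain-PyTorch sketch of that decomposition (the helper name is illustrative, not thunder's actual implementation):

import torch

def hardswish_via_relu6(x: torch.Tensor) -> torch.Tensor:
    # Standard definition: hardswish(x) = x * relu6(x + 3) / 6.
    # Composing through relu6 is why the TODO above flags a possible
    # tolerance issue for the VJP check.
    return x * torch.nn.functional.relu6(x + 3.0) / 6.0

# Sanity check against the torch_reference used by the OpInfo above.
x = torch.randn(1024)
assert torch.allclose(hardswish_via_relu6(x), torch.nn.functional.hardswish(x))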