Commit e659c35

Added support for HardSwish function
1 parent 94c9494 commit e659c35

3 files changed: +48 -0 lines


thunder/executors/torchex.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -743,6 +743,7 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
 gelu = _register_torch_operation("gelu", module=torch.nn.functional)
 relu = _register_torch_operation("relu", module=torch.nn.functional)
 relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
+hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
 selu = _register_torch_operation("selu", module=torch.nn.functional)
 silu = _register_torch_operation("silu", module=torch.nn.functional)
```

```diff
@@ -754,6 +755,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
 _register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
 _register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
+_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.silu, silu)
```
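
Taken together, these two registrations map Thunder's `ltorch.hardswish` symbol onto `torch.nn.functional.hardswish` for the torch executor, with the inplace checker rejecting `inplace=True` to match the symbol's restriction. A minimal sketch of the end-to-end effect, assuming `thunder.jit` as the compilation entry point (the entry point itself is not part of this diff):

```python
import torch
import thunder

def fn(x):
    return torch.nn.functional.hardswish(x)

jfn = thunder.jit(fn)  # assumed entry point; not shown in this commit
x = torch.randn(4)
# The torch executor can now dispatch the traced hardswish straight back to
# torch.nn.functional.hardswish, so the results should agree.
assert torch.allclose(jfn(x), torch.nn.functional.hardswish(x))
```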

thunder/tests/opinfos.py

Lines changed: 40 additions & 0 deletions
```diff
@@ -1598,6 +1598,46 @@ def relu6_error_generator(op, device, dtype=torch.float32, **kwargs):
 elementwise_unary_ops.append(relu6_opinfo)
 
 
+def hardswish_error_generator(op, device, dtype=torch.float32, **kwargs):
+    a = make_tensor((), dtype=dtype, device=device)
+    yield (SampleInput(a, inplace=True), NotImplementedError, "hardswish only supports inplace=False")
+
+
+hardswish_opinfo = OpInfo(
+    ltorch.hardswish,
+    sample_input_generator=elementwise_unary_generator,
+    error_input_generator=hardswish_error_generator,
+    torch_reference=_elementwise_unary_torch(torch.nn.functional.hardswish),
+    test_directives=(
+        # PyTorch does not support bool for either the CPU or the CUDA hardswish
+        DecorateInfo(
+            pytest.mark.xfail,
+            "test_core_vs_torch_consistency",
+            dtypes=(datatypes.bool8,),
+        ),
+        # PyTorch does not support complex types for either the CPU or the CUDA hardswish
+        DecorateInfo(
+            pytest.mark.xfail,
+            "test_core_vs_torch_consistency",
+            dtypes=(datatypes.complexfloating,),
+        ),
+        # PyTorch does not support CPU Half hardswish
+        DecorateInfo(
+            pytest.mark.xfail,
+            "test_core_vs_torch_consistency",
+            dtypes=(datatypes.float16,),
+            devicetypes=(devices.DeviceType.CPU,),
+        ),
+        # TODO: we might have a tolerance issue here with hardswish, a function of relu6
+        DecorateInfo(
+            pytest.mark.xfail(strict=False),
+            "test_vjp_correctness",
+        ),
+    ),
+)
+elementwise_unary_ops.append(hardswish_opinfo)
+
+
 def selu_error_generator(op, device, dtype=torch.float32, **kwargs):
     a = make_tensor((), dtype=dtype, device=device)
     yield (SampleInput(a, inplace=True), NotImplementedError, "selu only supports inplace=False")
```

thunder/torch/__init__.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -1208,6 +1208,12 @@ def relu6(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
     return clamp(a, 0, 6)
 
 
+@torchsymbol(torch.nn.functional.hardswish, id="torch.hardswish", is_method=False)
+def hardswish(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
+    utils.check(not inplace, lambda: f"hardswish only supports inplace=False", exception_type=NotImplementedError)
+    return a * relu6(a + 3) / 6
+
+
 # id=torch.selu because we ignore inplace argument in torch.nn.functional.selu
 @torchsymbol(torch.selu, torch.nn.functional.selu, id="torch.selu", is_method=False)
 def selu(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
```
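
The new symbol decomposes HardSwish into primitives Thunder already traces: hardswish(a) = a * relu6(a + 3) / 6, which is PyTorch's documented definition of the function. A quick sanity check of the algebra against the PyTorch reference (not part of the commit):

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-5.0, 5.0, steps=101)
# relu6(y) = min(max(y, 0), 6), so the product is 0 for x <= -3,
# x for x >= 3, and the ramp x * (x + 3) / 6 in between.
decomposed = x * F.relu6(x + 3) / 6
assert torch.allclose(decomposed, F.hardswish(x))
```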
