Skip to content

Commit 16f45ea

Browse files
beverlylytle and riccardofelluga
authored and committed
Add hardshrink (#1505)
1 parent 3c9a23f commit 16f45ea

File tree

4 files changed

+47
-5
lines changed

4 files changed

+47
-5
lines changed

thunder/executors/torchex.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,7 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
838838
leaky_relu = _register_torch_operation("leaky_relu", module=torch.nn.functional)
839839
relu = _register_torch_operation("relu", module=torch.nn.functional)
840840
relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
841+
hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
841842
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
842843
selu = _register_torch_operation("selu", module=torch.nn.functional)
843844
silu = _register_torch_operation("silu", module=torch.nn.functional)
@@ -853,6 +854,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
853854
_register_elementwise_unary_implementation(ltorch.leaky_relu, leaky_relu, checker=_always_executable)
854855
_register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
855856
_register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
857+
_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable)
856858
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
857859
_register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
858860
_register_elementwise_unary_implementation(ltorch.silu, silu, checker=_always_executable)

thunder/tests/opinfos.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def push_away_from_singularities(x, singularity_fn, eps):
6666
`eps` away from them. The `singularity_fn` returns the (signed)
6767
distance from `x` to the nearest singularity."""
6868
x_dist = singularity_fn(x)
69-
x_ = torch.where((x_dist > 0) & (x_dist < eps), x + eps, x)
70-
return torch.where((x_dist < 0) & (x_dist > -eps), x - eps, x_)
69+
x_ = torch.where((x_dist >= 0) & (x_dist < eps), x + eps, x)
70+
return torch.where((x_dist <= 0) & (x_dist > -eps), x_ - eps, x_)
7171

7272

7373
# Randomly select a fraction of the elements in a tensor and set them to specified value
@@ -1727,6 +1727,37 @@ def gen(op, device, dtype, requires_grad):
17271727
elementwise_unary_ops.append(relu6_opinfo)
17281728

17291729

1730+
# For positive lambd, hardshrink's singularities occur at lambd and -lambd, the locations of jump discontinuities
1731+
# of its partial derivatives. Since lambd is passed as an input kwarg, the singularity_fn depends upon the input
1732+
# sample. Therefore, multiple opinfos with varying sample generator and singularity_fn pairs are added.
1733+
def get_hardshrink_singularity_fn(lambd):
    """Build the signed-distance function to hardshrink's singularities at +/-lambd.

    A ``lambd`` of ``None`` selects the torch default of 0.5. The returned callable
    maps a tensor ``a`` to its signed distance from the nearest singularity:
    nonnegative entries are measured against ``+lambd``, negative ones against ``-lambd``.
    """
    threshold = 0.5 if lambd is None else lambd

    def signed_distance(a):
        return torch.where(a >= 0, a - threshold, a + threshold)

    return signed_distance
1737+
1738+
1739+
def hardshrink_opinfo_factory(lambds):
    """Register one hardshrink OpInfo per value in *lambds* (``None`` exercises the default of 0.5)."""
    for lam in lambds:
        sample_kwargs = {} if lam is None else {"lambd": lam}
        elementwise_unary_ops.append(
            OpInfo(
                ltorch.hardshrink,
                name=f"hardshrink_{lam}",
                dtypes=(datatypes.floating,),
                sample_input_generator=get_elementwise_unary_with_kwargs_generator([sample_kwargs]),
                torch_reference=_elementwise_unary_torch(torch.nn.functional.hardshrink),
                # fdm.jvp, which is used in test_vjp_correctness, behaves badly at jump
                # discontinuities of the partial derivatives, so provide a singularity_fn
                # matched to this sample's lambd.
                singularity_fn=get_hardshrink_singularity_fn(lam),
                test_directives=(),
            )
        )


hardshrink_opinfo_factory([None, 0.25, -0.1])
1759+
1760+
17301761
hardswish_opinfo = OpInfo(
17311762
ltorch.hardswish,
17321763
sample_input_generator=elementwise_unary_generator,

thunder/torch/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,6 +1849,18 @@ def relu6(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
18491849
_inplace_to_out_of_place[relu6] = relu6, 1
18501850

18511851

1852+
@torchsymbol(torch.nn.functional.hardshrink, is_method=False)
def hardshrink(a: TensorProxy, /, lambd: float = 0.5) -> TensorLike:
    """Zero every element of ``a`` whose magnitude is at most ``lambd``; pass the rest through unchanged.

    Complex inputs are rejected.
    """
    utils.check(
        not dtypes.is_complex_dtype(a.dtype),
        lambda: f"hardshrink not implemented for '{a.dtype}'",
    )
    # Elements inside the closed band [-lambd, lambd] collapse to zero; the boundary
    # itself is zeroed (inclusive comparison), matching the strict |x| > lambd pass-through.
    in_band = abs(a) <= lambd
    return where(in_band, 0, a)


# NOTE(review): -1 presumably flags that hardshrink takes no `inplace` argument
# (siblings like relu6 record the position of theirs) — confirm against the
# _inplace_to_out_of_place consumers.
_inplace_to_out_of_place[hardshrink] = hardshrink, -1
1862+
1863+
18521864
@torchsymbol(torch.nn.functional.hardswish, id="torch.hardswish", is_method=False)
18531865
def hardswish(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
18541866
utils.check(

thunder/torch/default_torch_ops.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@
133133
torch.grid_sampler_3d,
134134
torch.gru,
135135
torch.gru_cell,
136-
torch.hardshrink,
137136
torch.heaviside,
138137
torch.hinge_embedding_loss,
139138
torch.histc,
@@ -348,7 +347,6 @@
348347
torch.nn.functional.gaussian_nll_loss,
349348
torch.nn.functional.grid_sample,
350349
torch.nn.functional.gumbel_softmax,
351-
torch.nn.functional.hardshrink,
352350
torch.nn.functional.hardtanh,
353351
torch.nn.functional.hinge_embedding_loss,
354352
torch.nn.functional.huber_loss,
@@ -478,7 +476,6 @@
478476
torch.Tensor.greater,
479477
torch.Tensor.greater_equal,
480478
torch.Tensor.half,
481-
torch.Tensor.hardshrink,
482479
torch.Tensor.has_names,
483480
torch.Tensor.heaviside,
484481
torch.Tensor.histc,

0 commit comments

Comments
 (0)