diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 0ee80f9cf3..8484d389db 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -838,6 +838,7 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor: leaky_relu = _register_torch_operation("leaky_relu", module=torch.nn.functional) relu = _register_torch_operation("relu", module=torch.nn.functional) relu6 = _register_torch_operation("relu6", module=torch.nn.functional) +hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional) hardswish = _register_torch_operation("hardswish", module=torch.nn.functional) selu = _register_torch_operation("selu", module=torch.nn.functional) silu = _register_torch_operation("silu", module=torch.nn.functional) @@ -853,6 +854,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F _register_elementwise_unary_implementation(ltorch.leaky_relu, leaky_relu, checker=_always_executable) _register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker) _register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker) +_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable) _register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker) _register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker) _register_elementwise_unary_implementation(ltorch.silu, silu, checker=_always_executable) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 159e1ce3d9..d41a9ae433 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -66,8 +66,8 @@ def push_away_from_singularities(x, singularity_fn, eps): `eps` away from them. 
The `singularity_fn` returns the (signed) distance from `x` to the nearest singularity.""" x_dist = singularity_fn(x) - x_ = torch.where((x_dist > 0) & (x_dist < eps), x + eps, x) - return torch.where((x_dist < 0) & (x_dist > -eps), x - eps, x_) + x_ = torch.where((x_dist >= 0) & (x_dist < eps), x + eps, x) + return torch.where((x_dist <= 0) & (x_dist > -eps), x_ - eps, x_) # Randomly select a fraction of the elements in a tensor and set them to specified value @@ -1727,6 +1727,37 @@ def gen(op, device, dtype, requires_grad): elementwise_unary_ops.append(relu6_opinfo) +# For positive lambd, hardshrink's singularities occur at lambd and -lambd, the locations of jump discontinuities +# of its partial derivatives. Since lambd is passed as an input kwarg, the singularity_fn depends upon the input +# sample. Therefore, multiple opinfos with varying sample generator and singularity_fn pairs are added. +def get_hardshrink_singularity_fn(lambd): + if lambd is None: + lambd = 0.5 + return lambda a: torch.where(a >= 0, a - lambd, a + lambd) + + +def hardshrink_opinfo_factory(lambds): + for lambd in lambds: + kwargs = {} if lambd is None else {"lambd": lambd} + name = "hardshrink_" + str(lambd) + singularity_fn = get_hardshrink_singularity_fn(lambd) + + hardshrink_opinfo = OpInfo( + ltorch.hardshrink, + name=name, + dtypes=(datatypes.floating,), + sample_input_generator=get_elementwise_unary_with_kwargs_generator([kwargs]), + torch_reference=_elementwise_unary_torch(torch.nn.functional.hardshrink), + # fdm.jvp, which is used in test_vjp_correctness, behaves badly at jump discontinuities of the partial derivatives + singularity_fn=singularity_fn, + test_directives=(), + ) + elementwise_unary_ops.append(hardshrink_opinfo) + + +hardshrink_opinfo_factory([None, 0.25, -0.1]) + + hardswish_opinfo = OpInfo( ltorch.hardswish, sample_input_generator=elementwise_unary_generator, diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 1dd272ee9f..4998a6d29c 100644 ---
a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -1849,6 +1849,18 @@ def relu6(a: TensorProxy, /, inplace: bool = False) -> TensorLike: _inplace_to_out_of_place[relu6] = relu6, 1 +@torchsymbol(torch.nn.functional.hardshrink, is_method=False) +def hardshrink(a: TensorProxy, /, lambd: float = 0.5) -> TensorLike: + utils.check( + not dtypes.is_complex_dtype(a.dtype), + lambda: f"hardshrink not implemented for '{a.dtype}'", + ) + return where(abs(a) <= lambd, 0, a) + + +_inplace_to_out_of_place[hardshrink] = hardshrink, -1 + + @torchsymbol(torch.nn.functional.hardswish, id="torch.hardswish", is_method=False) def hardswish(a: TensorProxy, /, inplace: bool = False) -> TensorLike: utils.check( diff --git a/thunder/torch/default_torch_ops.py b/thunder/torch/default_torch_ops.py index a43b6c3131..800c548ba6 100644 --- a/thunder/torch/default_torch_ops.py +++ b/thunder/torch/default_torch_ops.py @@ -133,7 +133,6 @@ torch.grid_sampler_3d, torch.gru, torch.gru_cell, - torch.hardshrink, torch.heaviside, torch.hinge_embedding_loss, torch.histc, @@ -348,7 +347,6 @@ torch.nn.functional.gaussian_nll_loss, torch.nn.functional.grid_sample, torch.nn.functional.gumbel_softmax, - torch.nn.functional.hardshrink, torch.nn.functional.hardtanh, torch.nn.functional.hinge_embedding_loss, torch.nn.functional.huber_loss, @@ -478,7 +476,6 @@ torch.Tensor.greater, torch.Tensor.greater_equal, torch.Tensor.half, - torch.Tensor.hardshrink, torch.Tensor.has_names, torch.Tensor.heaviside, torch.Tensor.histc,