Add hardshrink (#1505)
beverlylytle authored Dec 3, 2024
1 parent 2228304 commit a3cdbc4
Showing 4 changed files with 47 additions and 5 deletions.
2 changes: 2 additions & 0 deletions thunder/executors/torchex.py
@@ -838,6 +838,7 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
leaky_relu = _register_torch_operation("leaky_relu", module=torch.nn.functional)
relu = _register_torch_operation("relu", module=torch.nn.functional)
relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
+hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
selu = _register_torch_operation("selu", module=torch.nn.functional)
silu = _register_torch_operation("silu", module=torch.nn.functional)
@@ -853,6 +854,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
_register_elementwise_unary_implementation(ltorch.leaky_relu, leaky_relu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
+_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.silu, silu, checker=_always_executable)
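The two torchex.py additions register hardshrink with the torch executor: the first binds torch.nn.functional.hardshrink as the executing operation, and the second maps thunder's ltorch.hardshrink symbol onto it. As a reminder of the underlying op's semantics, here is a standalone PyTorch sketch (not part of the commit; the sample values are illustrative):

import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, -0.3, 0.0, 0.3, 1.0])

# Values with |x| <= lambd are zeroed; everything else passes through.
print(F.hardshrink(x))             # default lambd=0.5 -> tensor([-1., 0., 0., 0., 1.])
print(F.hardshrink(x, lambd=0.2))  # tensor([-1.0000, -0.3000, 0.0000, 0.3000, 1.0000])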
35 changes: 33 additions & 2 deletions thunder/tests/opinfos.py
@@ -66,8 +66,8 @@ def push_away_from_singularities(x, singularity_fn, eps):
    `eps` away from them. The `singularity_fn` returns the (signed)
    distance from `x` to the nearest singularity."""
    x_dist = singularity_fn(x)
-    x_ = torch.where((x_dist > 0) & (x_dist < eps), x + eps, x)
-    return torch.where((x_dist < 0) & (x_dist > -eps), x - eps, x_)
+    x_ = torch.where((x_dist >= 0) & (x_dist < eps), x + eps, x)
+    return torch.where((x_dist <= 0) & (x_dist > -eps), x_ - eps, x_)


# Randomly select a fraction of the elements in a tensor and set them to specified value
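
The comparisons in push_away_from_singularities become non-strict, and the second where now builds on the intermediate x_ rather than the original x. A standalone sketch (not part of the commit) that re-declares the updated helper together with hardshrink's singularity function for lambd = 0.5, whose singularities sit at +0.5 and -0.5:

import torch

def push_away_from_singularities(x, singularity_fn, eps):
    # Helper re-declared verbatim from the diff above.
    x_dist = singularity_fn(x)
    x_ = torch.where((x_dist >= 0) & (x_dist < eps), x + eps, x)
    return torch.where((x_dist <= 0) & (x_dist > -eps), x_ - eps, x_)

lambd = 0.5
singularity_fn = lambda a: torch.where(a >= 0, a - lambd, a + lambd)

x = torch.tensor([0.45, 0.52, -0.48, 0.0])
print(push_away_from_singularities(x, singularity_fn, eps=0.1))
# tensor([ 0.3500,  0.6200, -0.3800,  0.0000])
# 0.45 (0.05 below +lambd) is pushed down, 0.52 and -0.48 are pushed up,
# and 0.0 is already at least eps away from both singularities, so it is untouched.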
@@ -1727,6 +1727,37 @@ def gen(op, device, dtype, requires_grad):
elementwise_unary_ops.append(relu6_opinfo)


+# For positive lambd, hardshrink's singularities occur at lambd and -lambd, the locations of jump discontinuities
+# of its partial derivatives. Since lambd is passed as an input kwarg, the singularity_fn depends upon the input
+# sample. Therefore, multiple opinfos with varying sample generator and singularity_fn pairs are added.
+def get_hardshrink_singularity_fn(lambd):
+    if lambd is None:
+        lambd = 0.5
+    return lambda a: torch.where(a >= 0, a - lambd, a + lambd)
+
+
+def hardshrink_opinfo_factory(lambds):
+    for lambd in lambds:
+        kwargs = {} if lambd is None else {"lambd": lambd}
+        name = "hardshrink_" + str(lambd)
+        singularity_fn = get_hardshrink_singularity_fn(lambd)
+
+        hardshrink_opinfo = OpInfo(
+            ltorch.hardshrink,
+            name=name,
+            dtypes=(datatypes.floating,),
+            sample_input_generator=get_elementwise_unary_with_kwargs_generator([kwargs]),
+            torch_reference=_elementwise_unary_torch(torch.nn.functional.hardshrink),
+            # fdm.jvp, which is used in test_vjp_correctness, behaves badly at jump discontinuities of the partial derivatives
+            singularity_fn=singularity_fn,
+            test_directives=(),
+        )
+        elementwise_unary_ops.append(hardshrink_opinfo)
+
+
+hardshrink_opinfo_factory([None, 0.25, -0.1])
+
+
hardswish_opinfo = OpInfo(
ltorch.hardswish,
sample_input_generator=elementwise_unary_generator,
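Since the factory is parametrized over lambd (including None for the default of 0.5, and a negative value, for which hardshrink reduces to the identity and has no singularities), a quick numeric sanity check of the singularity function may help. This sketch re-declares the helper from the diff above, with illustrative inputs:

import torch

def get_hardshrink_singularity_fn(lambd):
    if lambd is None:
        lambd = 0.5  # hardshrink's default lambd
    return lambda a: torch.where(a >= 0, a - lambd, a + lambd)

fn = get_hardshrink_singularity_fn(0.25)
# Signed distance to the nearest singularity (at +0.25 for a >= 0, at -0.25 for a < 0).
print(fn(torch.tensor([0.3, 0.2, -0.2, -0.3])))
# tensor([ 0.0500, -0.0500,  0.0500, -0.0500])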
12 changes: 12 additions & 0 deletions thunder/torch/__init__.py
@@ -1849,6 +1849,18 @@ def relu6(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
_inplace_to_out_of_place[relu6] = relu6, 1


+@torchsymbol(torch.nn.functional.hardshrink, is_method=False)
+def hardshrink(a: TensorProxy, /, lambd: float = 0.5) -> TensorLike:
+    utils.check(
+        not dtypes.is_complex_dtype(a.dtype),
+        lambda: f"hardshrink not implemented for '{a.dtype}'",
+    )
+    return where(abs(a) <= lambd, 0, a)
+
+
+_inplace_to_out_of_place[hardshrink] = hardshrink, -1
+
+
@torchsymbol(torch.nn.functional.hardswish, id="torch.hardswish", is_method=False)
def hardswish(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
    utils.check(
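The new symbol computes hardshrink as where(abs(a) <= lambd, 0, a), after rejecting complex dtypes. A minimal sketch (not part of the commit) comparing this formula against the PyTorch reference, including the boundary |a| == lambd, where both return 0:

import torch
import torch.nn.functional as F

a = torch.tensor([-1.0, -0.5, -0.2, 0.0, 0.2, 0.5, 1.0])
lambd = 0.5

ours = torch.where(a.abs() <= lambd, torch.zeros_like(a), a)
print(ours)                    # tensor([-1., 0., 0., 0., 0., 0., 1.])
print(F.hardshrink(a, lambd))  # identical: the boundary values -0.5 and 0.5 are zeroed too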
3 changes: 0 additions & 3 deletions thunder/torch/default_torch_ops.py
@@ -133,7 +133,6 @@
torch.grid_sampler_3d,
torch.gru,
torch.gru_cell,
-torch.hardshrink,
torch.heaviside,
torch.hinge_embedding_loss,
torch.histc,
@@ -348,7 +347,6 @@
torch.nn.functional.gaussian_nll_loss,
torch.nn.functional.grid_sample,
torch.nn.functional.gumbel_softmax,
-torch.nn.functional.hardshrink,
torch.nn.functional.hardtanh,
torch.nn.functional.hinge_embedding_loss,
torch.nn.functional.huber_loss,
@@ -478,7 +476,6 @@
torch.Tensor.greater,
torch.Tensor.greater_equal,
torch.Tensor.half,
-torch.Tensor.hardshrink,
torch.Tensor.has_names,
torch.Tensor.heaviside,
torch.Tensor.histc,
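Removing these three entries from the fallback tables means hardshrink calls are no longer routed through the generic torch fallback; they are instead traced through the new ltorch.hardshrink symbol and executed via the registration in torchex.py. A minimal usage sketch, assuming the standard thunder.jit entry point (the tensor values are illustrative):

import torch
import thunder

def f(x):
    return torch.nn.functional.hardshrink(x, lambd=0.25)

jf = thunder.jit(f)
x = torch.tensor([-0.5, -0.1, 0.0, 0.1, 0.5])
print(jf(x))  # tensor([-0.5000,  0.0000,  0.0000,  0.0000,  0.5000])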
