From ea0159ca20b4109a84f572020b73220fb3c85328 Mon Sep 17 00:00:00 2001 From: beverlylytle <57254617+beverlylytle@users.noreply.github.com> Date: Wed, 27 Nov 2024 11:05:52 +0100 Subject: [PATCH] add leaky_relu op (#1459) Co-authored-by: Thomas Viehmann --- thunder/executors/torchex.py | 2 ++ thunder/tests/opinfos.py | 36 +++++++++++++++++++++--------- thunder/torch/__init__.py | 11 +++++++++ thunder/torch/default_torch_ops.py | 1 - 4 files changed, 39 insertions(+), 11 deletions(-) diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index f354e9abf0..9385eb3970 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -836,6 +836,7 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor: celu = _register_torch_operation("celu", module=torch.nn.functional) elu = _register_torch_operation("elu", module=torch.nn.functional) gelu = _register_torch_operation("gelu", module=torch.nn.functional) +leaky_relu = _register_torch_operation("leaky_relu", module=torch.nn.functional) relu = _register_torch_operation("relu", module=torch.nn.functional) relu6 = _register_torch_operation("relu6", module=torch.nn.functional) hardswish = _register_torch_operation("hardswish", module=torch.nn.functional) @@ -850,6 +851,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F _register_elementwise_unary_implementation(ltorch.elu, elu, checker=_always_executable) _register_elementwise_unary_implementation(ltorch.celu, celu, checker=_always_executable) _register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable) +_register_elementwise_unary_implementation(ltorch.leaky_relu, leaky_relu, checker=_always_executable) _register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker) _register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker) _register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index b8daca7de2..5cc5c89077 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -1633,20 +1633,24 @@ def _abs_torch(x: torch.Tensor | Number): elementwise_unary_ops.append(reciprocal_opinfo) -def elementwise_unary_with_alpha_generator(op, device, dtype, requires_grad): - alphas = (None, -1.0, 0.5) - samples = elementwise_unary_generator(op, device, dtype, requires_grad) - for alpha, sample in itertools.product(alphas, samples): - if alpha is None: - yield sample - else: - yield SampleInput(*sample.args, alpha=alpha, **sample.kwargs) +def get_elementwise_unary_with_alpha_generator(): + kwargs_list = [{}, {"alpha": -1.0}, {"alpha": 0.5}] + return get_elementwise_unary_with_kwargs_generator(kwargs_list) + + +def get_elementwise_unary_with_kwargs_generator(kwargs_list): + def gen(op, device, dtype, requires_grad): + samples = elementwise_unary_generator(op, device, dtype, requires_grad) + for kwargs, sample in itertools.product(kwargs_list, samples): + yield SampleInput(*sample.args, **kwargs, **sample.kwargs) + + return gen celu_opinfo = OpInfo( ltorch.celu, dtypes=(datatypes.floating,), - sample_input_generator=elementwise_unary_with_alpha_generator, + sample_input_generator=get_elementwise_unary_with_alpha_generator(), torch_reference=_elementwise_unary_torch(torch.celu), test_directives=(), ) @@ -1656,7 +1660,7 @@ def elementwise_unary_with_alpha_generator(op, device, dtype, requires_grad): elu_opinfo = OpInfo( ltorch.elu, dtypes=(datatypes.floating,), - sample_input_generator=elementwise_unary_with_alpha_generator, + sample_input_generator=get_elementwise_unary_with_alpha_generator(), torch_reference=torch.nn.functional.elu, # fdm.jvp, which is used in test_vjp_correctness, behaves badly on (-1e-6, 1e-6) for this function singularity_fn=lambda x: x, @@ -1665,6 +1669,18 @@ def elementwise_unary_with_alpha_generator(op, device, dtype, requires_grad): elementwise_unary_ops.append(elu_opinfo) +leaky_relu_opinfo = OpInfo( + ltorch.leaky_relu, + dtypes=(datatypes.floating,), + sample_input_generator=get_elementwise_unary_with_kwargs_generator([{}, {"negative_slope": 0.5}]), + torch_reference=torch.nn.functional.leaky_relu, + # fdm.jvp, which is used in test_vjp_correctness, behaves badly on (-1e-6, 1e-6) for this function + singularity_fn=lambda x: x, + test_directives=(), +) +elementwise_unary_ops.append(leaky_relu_opinfo) + + relu_opinfo = OpInfo( ltorch.relu, sample_input_generator=elementwise_unary_generator, diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index a94ada1cc8..5ec568e024 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -1801,6 +1801,17 @@ def gelu(a: TensorProxy, /, *, approximate: str = "none") -> TensorLike: raise ValueError(f"gelu does not support the approximate={approximate} argument") +@torchsymbol(torch.nn.functional.leaky_relu, is_method=False) +def leaky_relu(a: TensorProxy, /, negative_slope: float = 0.01, inplace: bool = False) -> TensorLike: + out = where(a > 0, a, a * negative_slope) + if inplace: + return prims.copy_(out, a) + return out + + +_inplace_to_out_of_place[leaky_relu] = leaky_relu, 2 + + # TODO Should this use clamp? -- Would that propagate NaNs properly? @torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True) def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike: diff --git a/thunder/torch/default_torch_ops.py b/thunder/torch/default_torch_ops.py index 84e0ae0f90..91ea98adf0 100644 --- a/thunder/torch/default_torch_ops.py +++ b/thunder/torch/default_torch_ops.py @@ -356,7 +356,6 @@ torch.nn.functional.instance_norm, torch.nn.functional.kl_div, torch.nn.functional.l1_loss, - torch.nn.functional.leaky_relu, torch.nn.functional.local_response_norm, torch.nn.functional.logsigmoid, torch.nn.functional.lp_pool1d,