diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index cbf948be52..3cb34a623e 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -64,6 +64,7 @@ ) import thunder.clang as clang from thunder.clang import ( + empty, full, full_like, unsqueeze, @@ -1435,6 +1436,32 @@ def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy): register_grad(pids.COPY_WITH_SETITEM, _copy_with_setitem_grad) + +def _log_sigmoid_grad( + a: TensorProxy, +) -> TensorProxy: + from thunder.torch import abs, exp, log_sigmoid_backward, logsigmoid + + fwd = logsigmoid(a) + + g = get_grad(fwd) + if a.device.type == "cpu": + # NOTE PyTorch's CPU computation for logsigmoid's grad uses an additional "buffer" tensor, see + # https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332 + buffer = exp(-abs(a)) + a_grad = log_sigmoid_backward(g, a, buffer) + else: + # Here a placeholder tensor is provided. + placeholder_buffer = empty((0,), device=a.device, dtype=a.dtype) + a_grad = log_sigmoid_backward(g, a, placeholder_buffer) + put_grad(a, a_grad) + + return fwd + + +register_grad("torch.nn.functional.logsigmoid", _log_sigmoid_grad) + + # # Phantom grad transform helpers # diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index faeb74bf98..6de644204d 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -835,11 +835,15 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor: celu = _register_torch_operation("celu", module=torch.nn.functional) elu = _register_torch_operation("elu", module=torch.nn.functional) gelu = _register_torch_operation("gelu", module=torch.nn.functional) +hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional) +hardswish = _register_torch_operation("hardswish", module=torch.nn.functional) leaky_relu = _register_torch_operation("leaky_relu", module=torch.nn.functional) +logsigmoid = _register_torch_operation("logsigmoid", module=torch.nn.functional) +log_sigmoid_backward = _register_torch_operation( + "torch.ops.aten.log_sigmoid_backward", like=ltorch.log_sigmoid_backward +) relu = _register_torch_operation("relu", module=torch.nn.functional) relu6 = _register_torch_operation("relu6", module=torch.nn.functional) -hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional) -hardswish = _register_torch_operation("hardswish", module=torch.nn.functional) selu = _register_torch_operation("selu", module=torch.nn.functional) silu = _register_torch_operation("silu", module=torch.nn.functional) tanhshrink = _register_torch_operation("tanhshrink", module=torch.nn.functional) @@ -852,11 +856,15 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F _register_elementwise_unary_implementation(ltorch.elu, elu, checker=_always_executable) _register_elementwise_unary_implementation(ltorch.celu, celu, checker=_always_executable) _register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable) +_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable) +_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker) _register_elementwise_unary_implementation(ltorch.leaky_relu, leaky_relu, checker=_always_executable) +_register_elementwise_unary_implementation( + ltorch.log_sigmoid_backward, log_sigmoid_backward, checker=_always_executable +) +_register_elementwise_unary_implementation(ltorch.logsigmoid, logsigmoid) _register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker) _register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker) -_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable) -_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker) _register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker) _register_elementwise_unary_implementation(ltorch.silu, silu, checker=_always_executable) _register_elementwise_unary_implementation(ltorch.tanhshrink, tanhshrink, checker=_always_executable) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index a142ab8a94..c63c4667b3 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -1685,6 +1685,17 @@ def gen(op, device, dtype, requires_grad): elementwise_unary_ops.append(leaky_relu_opinfo) +logsigmoid_opinfo = OpInfo( + ltorch.logsigmoid, + dtypes=(datatypes.floating,), + sample_input_generator=elementwise_unary_generator, + torch_reference=torch.nn.functional.logsigmoid, + domain=(-1, 1), + test_directives=(), +) +elementwise_unary_ops.append(logsigmoid_opinfo) + + relu_opinfo = OpInfo( ltorch.relu, sample_input_generator=elementwise_unary_generator, diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 66a4047287..57b64bfee0 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -1812,6 +1812,19 @@ def leaky_relu(a: TensorProxy, /, negative_slope: float = 0.01, inplace: bool = _inplace_to_out_of_place[leaky_relu] = leaky_relu, 2 +@torchsymbol(torch.nn.functional.logsigmoid, is_method=False) +def logsigmoid(a: TensorProxy, /) -> TensorLike: + return where(a > 0, -log1p(exp(-a)), a - log1p(exp(a))) + + +@torchsymbol("log_sigmoid_backward", id="log_sigmoid_backward") +def log_sigmoid_backward(g: TensorProxy, a: TensorProxy, buffer: TensorProxy) -> TensorLike: + # buffer is used by PyTorch in cpu-based calculations. See + # https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332 + # This is addressed in the custom grad fn thunder.core.transforms._log_sigmoid_grad. + return g * where(a > 0, exp(-a) / (1 + exp(-a)), 1 - exp(a) / (1 + exp(a))) + + # TODO Should this use clamp? -- Would that propagate NaNs properly? @torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True) def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike: @@ -1858,9 +1871,6 @@ def hardshrink(a: TensorProxy, /, lambd: float = 0.5) -> TensorLike: return where(abs(a) <= lambd, 0, a) -_inplace_to_out_of_place[hardshrink] = hardshrink, -1 - - @torchsymbol(torch.nn.functional.hardswish, id="torch.hardswish", is_method=False) def hardswish(a: TensorProxy, /, inplace: bool = False) -> TensorLike: utils.check( diff --git a/thunder/torch/default_torch_ops.py b/thunder/torch/default_torch_ops.py index 3be2f163d3..e6b56ece4b 100644 --- a/thunder/torch/default_torch_ops.py +++ b/thunder/torch/default_torch_ops.py @@ -354,7 +354,6 @@ torch.nn.functional.kl_div, torch.nn.functional.l1_loss, torch.nn.functional.local_response_norm, - torch.nn.functional.logsigmoid, torch.nn.functional.lp_pool1d, torch.nn.functional.lp_pool2d, torch.nn.functional.lp_pool3d,