add logsigmoid op (#1520)
beverlylytle authored Dec 14, 2024
1 parent 84ba13c commit 6c4e020
Showing 5 changed files with 63 additions and 8 deletions.
27 changes: 27 additions & 0 deletions thunder/core/transforms.py
@@ -64,6 +64,7 @@
)
import thunder.clang as clang
from thunder.clang import (
empty,
full,
full_like,
unsqueeze,
@@ -1435,6 +1436,32 @@ def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy):

register_grad(pids.COPY_WITH_SETITEM, _copy_with_setitem_grad)


def _log_sigmoid_grad(
a: TensorProxy,
) -> TensorProxy:
from thunder.torch import abs, exp, log_sigmoid_backward, logsigmoid

fwd = logsigmoid(a)

g = get_grad(fwd)
if a.device.type == "cpu":
# NOTE PyTorch's CPU computation for logsigmoid's grad uses an additional "buffer" tensor, see
# https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332
buffer = exp(-abs(a))
a_grad = log_sigmoid_backward(g, a, buffer)
else:
# On non-CPU devices the buffer is not used, so an empty placeholder tensor is provided.
placeholder_buffer = empty((0,), device=a.device, dtype=a.dtype)
a_grad = log_sigmoid_backward(g, a, placeholder_buffer)
put_grad(a, a_grad)

return fwd


register_grad("torch.nn.functional.logsigmoid", _log_sigmoid_grad)


#
# Phantom grad transform helpers
#
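An aside, not part of the diff: both branches of _log_sigmoid_grad compute the same mathematical gradient, d/da logsigmoid(a) = sigmoid(-a); the CPU branch simply supplies the exp(-|a|) buffer that PyTorch's CPU kernel expects. A minimal standalone sanity check of that CPU path against autograd, assuming a PyTorch build where torch.ops.aten.log_sigmoid_backward(grad_output, self, buffer) is available:

import torch

# Reference gradient via autograd.
a = torch.randn(8, dtype=torch.float64, requires_grad=True)
g = torch.randn(8, dtype=torch.float64)
fwd = torch.nn.functional.logsigmoid(a)
(ref_grad,) = torch.autograd.grad(fwd, a, grad_outputs=g)

# CPU path from _log_sigmoid_grad: buffer = exp(-|a|), then the aten kernel.
buffer = torch.exp(-torch.abs(a.detach()))
aten_grad = torch.ops.aten.log_sigmoid_backward(g, a.detach(), buffer)

torch.testing.assert_close(aten_grad, ref_grad)
# Both match the closed form sigmoid(-a) * g.
torch.testing.assert_close(aten_grad, torch.sigmoid(-a.detach()) * g)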
16 changes: 12 additions & 4 deletions thunder/executors/torchex.py
@@ -835,11 +835,15 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
celu = _register_torch_operation("celu", module=torch.nn.functional)
elu = _register_torch_operation("elu", module=torch.nn.functional)
gelu = _register_torch_operation("gelu", module=torch.nn.functional)
hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
leaky_relu = _register_torch_operation("leaky_relu", module=torch.nn.functional)
logsigmoid = _register_torch_operation("logsigmoid", module=torch.nn.functional)
log_sigmoid_backward = _register_torch_operation(
"torch.ops.aten.log_sigmoid_backward", like=ltorch.log_sigmoid_backward
)
relu = _register_torch_operation("relu", module=torch.nn.functional)
relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
selu = _register_torch_operation("selu", module=torch.nn.functional)
silu = _register_torch_operation("silu", module=torch.nn.functional)
tanhshrink = _register_torch_operation("tanhshrink", module=torch.nn.functional)
@@ -852,11 +856,15 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
_register_elementwise_unary_implementation(ltorch.elu, elu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.celu, celu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.leaky_relu, leaky_relu, checker=_always_executable)
_register_elementwise_unary_implementation(
ltorch.log_sigmoid_backward, log_sigmoid_backward, checker=_always_executable
)
_register_elementwise_unary_implementation(ltorch.logsigmoid, logsigmoid)
_register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.silu, silu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.tanhshrink, tanhshrink, checker=_always_executable)
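Illustrative usage, not taken from the repository's tests: with the registrations above, torch.nn.functional.logsigmoid should both trace and differentiate through thunder.jit on the torch executor. The function and tensor names below are invented for the example.

import torch
import thunder

def fn(x):
    return torch.nn.functional.logsigmoid(x).sum()

jfn = thunder.jit(fn)
x = torch.randn(4, requires_grad=True)
jfn(x).backward()

# The gradient of logsigmoid is sigmoid(-x).
torch.testing.assert_close(x.grad, torch.sigmoid(-x.detach()))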
11 changes: 11 additions & 0 deletions thunder/tests/opinfos.py
@@ -1685,6 +1685,17 @@ def gen(op, device, dtype, requires_grad):
elementwise_unary_ops.append(leaky_relu_opinfo)


logsigmoid_opinfo = OpInfo(
ltorch.logsigmoid,
dtypes=(datatypes.floating,),
sample_input_generator=elementwise_unary_generator,
torch_reference=torch.nn.functional.logsigmoid,
domain=(-1, 1),
test_directives=(),
)
elementwise_unary_ops.append(logsigmoid_opinfo)


relu_opinfo = OpInfo(
ltorch.relu,
sample_input_generator=elementwise_unary_generator,
16 changes: 13 additions & 3 deletions thunder/torch/__init__.py
@@ -1812,6 +1812,19 @@ def leaky_relu(a: TensorProxy, /, negative_slope: float = 0.01, inplace: bool =
_inplace_to_out_of_place[leaky_relu] = leaky_relu, 2


@torchsymbol(torch.nn.functional.logsigmoid, is_method=False)
def logsigmoid(a: TensorProxy, /) -> TensorLike:
return where(a > 0, -log1p(exp(-a)), a - log1p(exp(a)))


@torchsymbol("log_sigmoid_backward", id="log_sigmoid_backward")
def log_sigmoid_backward(g: TensorProxy, a: TensorProxy, buffer: TensorProxy) -> TensorLike:
# buffer is used by PyTorch in cpu-based calculations. See
# https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332
# This is addressed in the custom grad fn thunder.core.transforms._log_sigmoid_grad.
return g * where(a > 0, exp(-a) / (1 + exp(-a)), 1 - exp(a) / (1 + exp(a)))


# TODO Should this use clamp? -- Would that propagate NaNs properly?
@torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True)
def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike:
@@ -1858,9 +1871,6 @@ def hardshrink(a: TensorProxy, /, lambd: float = 0.5) -> TensorLike:
return where(abs(a) <= lambd, 0, a)


_inplace_to_out_of_place[hardshrink] = hardshrink, -1


@torchsymbol(torch.nn.functional.hardswish, id="torch.hardswish", is_method=False)
def hardswish(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
utils.check(
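A quick numerical note, also not part of the commit: in the branch that where selects, the logsigmoid decomposition above never exponentiates a positive value, so it stays finite where the naive -log1p(exp(-a)) overflows for large negative inputs. A small float32 check using only stock PyTorch:

import torch

a = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])  # float32

stable = torch.where(a > 0, -torch.log1p(torch.exp(-a)), a - torch.log1p(torch.exp(a)))
torch.testing.assert_close(stable, torch.nn.functional.logsigmoid(a))

naive = -torch.log1p(torch.exp(-a))
print(naive[0])  # -inf: exp(100.0) overflows in float32, while the branched form gives -100.0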
1 change: 0 additions & 1 deletion thunder/torch/default_torch_ops.py
@@ -354,7 +354,6 @@
torch.nn.functional.kl_div,
torch.nn.functional.l1_loss,
torch.nn.functional.local_response_norm,
torch.nn.functional.logsigmoid,
torch.nn.functional.lp_pool1d,
torch.nn.functional.lp_pool2d,
torch.nn.functional.lp_pool3d,
