Commit
move to grad_transform
beverlylytle committed Dec 13, 2024
1 parent d9c8766 commit f4cacf8
Showing 2 changed files with 37 additions and 19 deletions.
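In short: the logsigmoid gradient rule moves out of thunder/core/transforms.py, where it was registered through register_grad, and into the torch executor in thunder/executors/torchex.py, where it is now attached as a grad_transform via ex.register_implementation. The CPU/CUDA branching also disappears: the new transform always materializes the exp(-abs(a)) buffer instead of passing an empty placeholder on non-CPU devices.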
36 changes: 18 additions & 18 deletions thunder/core/transforms.py

@@ -1437,29 +1437,29 @@ def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy):
 register_grad(pids.COPY_WITH_SETITEM, _copy_with_setitem_grad)
 
 
-def _log_sigmoid_grad(
-    a: TensorProxy,
-) -> TensorProxy:
-    from thunder.torch import abs, exp, log_sigmoid_backward, logsigmoid
+# def _log_sigmoid_grad(
+#     a: TensorProxy,
+# ) -> TensorProxy:
+#     from thunder.torch import abs, exp, log_sigmoid_backward, logsigmoid
 
-    fwd = logsigmoid(a)
+#     fwd = logsigmoid(a)
 
-    g = get_grad(fwd)
-    if a.device.type == "cpu":
-        # NOTE PyTorch's CPU computation for logsigmoid's grad uses an additional "buffer" tensor, see
-        # https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332
-        buffer = exp(-abs(a))
-        a_grad = log_sigmoid_backward(g, a, buffer)
-    else:
-        # Here a placeholder tensor is provided.
-        placeholder_buffer = empty((0,), device=a.device, dtype=a.dtype)
-        a_grad = log_sigmoid_backward(g, a, placeholder_buffer)
-    put_grad(a, a_grad)
+#     g = get_grad(fwd)
+#     if a.device.type == "cpu":
+#         # NOTE PyTorch's CPU computation for logsigmoid's grad uses an additional "buffer" tensor, see
+#         # https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332
+#         buffer = exp(-abs(a))
+#         a_grad = log_sigmoid_backward(g, a, buffer)
+#     else:
+#         # Here a placeholder tensor is provided.
+#         placeholder_buffer = empty((0,), device=a.device, dtype=a.dtype)
+#         a_grad = log_sigmoid_backward(g, a, placeholder_buffer)
+#     put_grad(a, a_grad)
 
-    return fwd
+#     return fwd
 
 
-register_grad("torch.nn.functional.logsigmoid", _log_sigmoid_grad)
+# register_grad("torch.nn.functional.logsigmoid", _log_sigmoid_grad)
 
 
 #
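For reference, here is a sketch (not part of the commit) of the math behind the buffer argument: d/dx logsigmoid(x) = sigmoid(-x), which PyTorch's linked decomposition evaluates stably from z = exp(-|x|). The helper name log_sigmoid_backward_reference is made up for illustration:

import torch

def log_sigmoid_backward_reference(g, x, buffer):
    # buffer is expected to hold exp(-|x|); recompute it if an empty
    # placeholder was passed, as the old non-CPU path did.
    z = buffer if buffer.numel() else torch.exp(-torch.abs(x))
    neg = (x < 0).to(x.dtype)
    max_deriv = neg            # 1 where x < 0, else 0
    sign = 2 * neg - 1         # +1 where x < 0, else -1
    # Both branches reduce to g * sigmoid(-x), computed without overflow.
    return g * (max_deriv - sign * (z / (1 + z)))

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
torch.nn.functional.logsigmoid(x).sum().backward()
ref = log_sigmoid_backward_reference(torch.ones_like(x), x.detach(), torch.exp(-torch.abs(x.detach())))
assert torch.allclose(x.grad, ref)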
20 changes: 19 additions & 1 deletion thunder/executors/torchex.py

@@ -861,7 +861,25 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = False):
 _register_elementwise_unary_implementation(
     ltorch.log_sigmoid_backward, log_sigmoid_backward, checker=_always_executable
 )
-_register_elementwise_unary_implementation(ltorch.logsigmoid, logsigmoid)
+# _register_elementwise_unary_implementation(ltorch.logsigmoid, logsigmoid)
+
+
+def log_sigmoid_grad_transform(a):
+    fwd = logsigmoid(a)
+
+    g = get_grad(fwd)
+    # NOTE PyTorch's CPU computation for logsigmoid's grad uses an additional "buffer" tensor, see
+    # https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332
+    buffer = exp(-abs(a))
+    a_grad = log_sigmoid_backward(g, a, buffer)
+
+    put_grad(a, a_grad)
+    return fwd
+
+
+ex.register_implementation(
+    ltorch.logsigmoid, logsigmoid, checker=_elementwise_unary_checker, grad_transform=log_sigmoid_grad_transform
+)
 _register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
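A minimal end-to-end check, assuming the usual thunder.jit entry point (the function f and the tolerance check are illustrative, not part of the commit): with the grad transform registered on the torch executor, gradients through logsigmoid should match eager PyTorch.

import torch
import thunder

def f(x):
    return torch.nn.functional.logsigmoid(x).sum()

jf = thunder.jit(f)
x = torch.randn(8, requires_grad=True)
jf(x).backward()

x_ref = x.detach().clone().requires_grad_()
torch.nn.functional.logsigmoid(x_ref).sum().backward()
assert torch.allclose(x.grad, x_ref.grad)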
