From 6bdcc03deabe387a43c14b2610ef3d7f57513084 Mon Sep 17 00:00:00 2001
From: beverlylytle
Date: Fri, 13 Dec 2024 20:04:40 +0200
Subject: [PATCH] revert move to grad_transform

---
 thunder/core/transforms.py   | 36 ++++++++++++++++++------------------
 thunder/executors/torchex.py | 20 +-------------------
 2 files changed, 19 insertions(+), 37 deletions(-)

diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index cef37b8e69..3cb34a623e 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1437,29 +1437,29 @@ def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy):
 register_grad(pids.COPY_WITH_SETITEM, _copy_with_setitem_grad)
 
-# def _log_sigmoid_grad(
-#     a: TensorProxy,
-# ) -> TensorProxy:
-#     from thunder.torch import abs, exp, log_sigmoid_backward, logsigmoid
+def _log_sigmoid_grad(
+    a: TensorProxy,
+) -> TensorProxy:
+    from thunder.torch import abs, exp, log_sigmoid_backward, logsigmoid
 
-#     fwd = logsigmoid(a)
+    fwd = logsigmoid(a)
 
-#     g = get_grad(fwd)
-#     if a.device.type == "cpu":
-#         # NOTE PyTorch's CPU computation for logsigmoid's grad uses an additional "buffer" tensor, see
-#         # https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332
-#         buffer = exp(-abs(a))
-#         a_grad = log_sigmoid_backward(g, a, buffer)
-#     else:
-#         # Here a placeholder tensor is provided.
-#         placeholder_buffer = empty((0,), device=a.device, dtype=a.dtype)
-#         a_grad = log_sigmoid_backward(g, a, placeholder_buffer)
-#     put_grad(a, a_grad)
+    g = get_grad(fwd)
+    if a.device.type == "cpu":
+        # NOTE PyTorch's CPU computation for logsigmoid's grad uses an additional "buffer" tensor, see
+        # https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332
+        buffer = exp(-abs(a))
+        a_grad = log_sigmoid_backward(g, a, buffer)
+    else:
+        # Here a placeholder tensor is provided.
+        placeholder_buffer = empty((0,), device=a.device, dtype=a.dtype)
+        a_grad = log_sigmoid_backward(g, a, placeholder_buffer)
+    put_grad(a, a_grad)
 
-#     return fwd
+    return fwd
 
-# register_grad("torch.nn.functional.logsigmoid", _log_sigmoid_grad)
+register_grad("torch.nn.functional.logsigmoid", _log_sigmoid_grad)
 
 #
diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index ad23934dd9..df10fd9a4b 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -861,25 +861,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
 _register_elementwise_unary_implementation(
     ltorch.log_sigmoid_backward, log_sigmoid_backward, checker=_always_executable
 )
-# _register_elementwise_unary_implementation(ltorch.logsigmoid, logsigmoid)
-
-
-def log_sigmoid_grad_transform(a):
-    fwd = logsigmoid(a)
-
-    g = get_grad(fwd)
-    # NOTE PyTorch's CPU computation for logsigmoid's grad uses an additional "buffer" tensor, see
-    # https://github.com/pytorch/pytorch/blob/7667235a23e2ffca4d32e6e16aa60a683418e159/torch/_decomp/decompositions.py#L332
-    buffer = exp(-abs(a))
-    a_grad = log_sigmoid_backward(g, a, buffer)
-
-    put_grad(a, a_grad)
-    return fwd
-
-
-ex.register_implementation(
-    ltorch.logsigmoid, logsigmoid, checker=_elementwise_unary_checker, grad_transform=log_sigmoid_grad_transform
-)
+_register_elementwise_unary_implementation(ltorch.logsigmoid, logsigmoid)
 _register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
 _register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
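Note for reviewers: the restored `_log_sigmoid_grad` routes the backward through `log_sigmoid_backward`, whose result reduces to the identity d/dx logsigmoid(x) = 1 - sigmoid(x) = sigmoid(-x). A standalone PyTorch check of that identity (an illustrative sketch, not part of this patch or the project's test suite):

    import torch

    # float64 keeps the numerical comparison tight
    x = torch.randn(16, dtype=torch.float64, requires_grad=True)
    out = torch.nn.functional.logsigmoid(x)
    (grad,) = torch.autograd.grad(out.sum(), x)

    # d/dx logsigmoid(x) = 1 - sigmoid(x) = sigmoid(-x)
    torch.testing.assert_close(grad, torch.sigmoid(-x).detach())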
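An end-to-end way to exercise the reverted registration is to compare the jitted forward and backward against eager PyTorch. A minimal sketch, assuming a local lightning-thunder build; the shapes, tolerances, and the choice of a CPU tensor (which takes the `exp(-abs(a))` buffer branch) are illustrative:

    import torch
    import thunder

    def fn(x):
        return torch.nn.functional.logsigmoid(x)

    jfn = thunder.jit(fn)

    x_ref = torch.randn(8, requires_grad=True)  # CPU tensor -> "buffer" branch
    x_thunder = x_ref.detach().clone().requires_grad_(True)

    # Forward agreement
    out_ref = fn(x_ref)
    out_thunder = jfn(x_thunder)
    torch.testing.assert_close(out_thunder, out_ref)

    # Backward agreement under the same upstream gradient
    g = torch.randn_like(out_ref)
    out_ref.backward(g)
    out_thunder.backward(g)
    torch.testing.assert_close(x_thunder.grad, x_ref.grad)

Running the same check on a CUDA tensor would exercise the placeholder-buffer branch instead.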