diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 72758aaf59..f060cc2126 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -1652,6 +1652,18 @@ def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike: _inplace_to_out_of_place[relu] = relu, 1 +@torchsymbol(torch.relu_, torch.nn.functional.relu_, id="torch.relu_", is_method=True) +def relu_( + a: TensorLike, + /, +) -> TensorLike: + return prims.copy_(relu(a, False), a) + + +# The default value of `inplace` is False, so no need to tweak args/kwargs +_inplace_to_out_of_place[relu_] = relu, -1 + + # id=torch.relu because we ignore inplace argument in torch.nn.functional.relu @torchsymbol(torch.nn.functional.relu6, id="torch.relu6", is_method=False) def relu6(a: TensorProxy, /, inplace: bool = False) -> TensorLike: