Skip to content

Commit

Permalink
register relu_
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Jun 18, 2024
1 parent e4cf487 commit d6fe101
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1652,6 +1652,18 @@ def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike:
_inplace_to_out_of_place[relu] = relu, 1


@torchsymbol(torch.relu_, torch.nn.functional.relu_, id="torch.relu_", is_method=True)
def relu_(
a: TensorLike,
/,
) -> TensorLike:
return prims.copy_(relu(a, False), a)


# The default value of `inplace` is False, so no need to tweak args/kwargs
_inplace_to_out_of_place[relu_] = relu, -1


# id=torch.relu because we ignore inplace argument in torch.nn.functional.relu
@torchsymbol(torch.nn.functional.relu6, id="torch.relu6", is_method=False)
def relu6(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
Expand Down

0 comments on commit d6fe101

Please sign in to comment.