Skip to content

Commit

Permalink
Add TorchExpm1Visitor (#78)
Browse files Browse the repository at this point in the history
Follow-up for #77
  • Loading branch information
kit1980 authored Sep 17, 2024
1 parent 7bf6f67 commit e0988aa
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 2 deletions.
12 changes: 12 additions & 0 deletions tests/fixtures/misc/checker/expm1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
a = torch.randn(5)
b = torch.exp(a) - 1
c = torch.exp(a) - 1.0

ret = (torch.exp(a) - 1) * torch.exp(2 * b)

# False negative: can not detect currently
x = a.exp() - 1

# False negative: should be rare and would complicate implementation
x = -1 + torch.exp(a)
3 changes: 3 additions & 0 deletions tests/fixtures/misc/checker/expm1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
3:5 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`.
4:5 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`.
6:7 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`.
10 changes: 9 additions & 1 deletion tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@ def pytest_generate_tests(metafunc):
("TOR102,TOR101", {"TOR102", "TOR101"}),
(
"TOR1,TOR102",
{"TOR102", "TOR101", "TOR103", "TOR104", "TOR105", "TOR106"},
{
"TOR101",
"TOR102",
"TOR103",
"TOR104",
"TOR105",
"TOR106",
"TOR107",
},
),
(None, set(GET_ALL_ERROR_CODES()) - exclude_set),
]
Expand Down
2 changes: 2 additions & 0 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .visitors import (
TorchDeprecatedSymbolsVisitor,
TorchExpm1Visitor,
TorchLog1pVisitor,
TorchNonPublicAliasVisitor,
TorchReentrantCheckpointVisitor,
Expand All @@ -29,6 +30,7 @@

ALL_VISITOR_CLS = [
TorchDeprecatedSymbolsVisitor,
TorchExpm1Visitor,
TorchLog1pVisitor,
TorchNonPublicAliasVisitor,
TorchRequireGradVisitor,
Expand Down
4 changes: 3 additions & 1 deletion torchfix/visitors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .deprecated_symbols import TorchDeprecatedSymbolsVisitor
from .internal import TorchScopedLibraryVisitor
from .misc import (
TorchExpm1Visitor,
TorchLog1pVisitor,
TorchReentrantCheckpointVisitor,
TorchRequireGradVisitor,
TorchLog1pVisitor,
)
from .nonpublic import TorchNonPublicAliasVisitor
from .performance import TorchSynchronizedDataLoaderVisitor
Expand All @@ -16,6 +17,7 @@

__all__ = [
"TorchDeprecatedSymbolsVisitor",
"TorchExpm1Visitor",
"TorchLog1pVisitor",
"TorchNonPublicAliasVisitor",
"TorchReentrantCheckpointVisitor",
Expand Down
33 changes: 33 additions & 0 deletions torchfix/visitors/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,36 @@ def visit_Call(self, node):
message=self.ERRORS[0].message(),
replacement=None,
)


class TorchExpm1Visitor(TorchVisitor):
"""
Suggest using `torch.special.expm1(x)` instead of `torch.exp(x) - 1`.
"""

ERRORS = [
TorchError(
"TOR107",
(
"Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. "
"It is more accurate for small values of `x`."
),
)
]

def visit_BinaryOperation(self, node):
if m.matches(
node,
m.BinaryOperation(
left=m.Call(),
operator=m.Subtract(),
right=m.Integer(value="1") | m.Float(value="1.0"),
),
):
if self.get_qualified_name_for_call(node.left) == "torch.exp":
self.add_violation(
node,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=None,
)

0 comments on commit e0988aa

Please sign in to comment.