From e0988aa1b10ca7d319c752a2a095f00cd676637a Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Tue, 17 Sep 2024 14:52:56 -0700 Subject: [PATCH] Add TorchExpm1Visitor (#78) Follow-up for https://github.com/pytorch-labs/torchfix/pull/77 --- tests/fixtures/misc/checker/expm1.py | 12 ++++++++++ tests/fixtures/misc/checker/expm1.txt | 3 +++ tests/test_torchfix.py | 10 +++++++- torchfix/torchfix.py | 2 ++ torchfix/visitors/__init__.py | 4 +++- torchfix/visitors/misc/__init__.py | 33 +++++++++++++++++++++++++++ 6 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 tests/fixtures/misc/checker/expm1.py create mode 100644 tests/fixtures/misc/checker/expm1.txt diff --git a/tests/fixtures/misc/checker/expm1.py b/tests/fixtures/misc/checker/expm1.py new file mode 100644 index 0000000..4a7646d --- /dev/null +++ b/tests/fixtures/misc/checker/expm1.py @@ -0,0 +1,12 @@ +import torch +a = torch.randn(5) +b = torch.exp(a) - 1 +c = torch.exp(a) - 1.0 + +ret = (torch.exp(a) - 1) * torch.exp(2 * b) + +# False negative: can not detect currently +x = a.exp() - 1 + +# False negative: should be rare and would complicate implementation +x = -1 + torch.exp(a) diff --git a/tests/fixtures/misc/checker/expm1.txt b/tests/fixtures/misc/checker/expm1.txt new file mode 100644 index 0000000..ed24905 --- /dev/null +++ b/tests/fixtures/misc/checker/expm1.txt @@ -0,0 +1,3 @@ +3:5 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`. +4:5 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`. +6:7 TOR107 Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. It is more accurate for small values of `x`. diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 29f6cf9..7b79fb7 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -37,7 +37,15 @@ def pytest_generate_tests(metafunc): ("TOR102,TOR101", {"TOR102", "TOR101"}), ( "TOR1,TOR102", - {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105", "TOR106"}, + { + "TOR101", + "TOR102", + "TOR103", + "TOR104", + "TOR105", + "TOR106", + "TOR107", + }, ), (None, set(GET_ALL_ERROR_CODES()) - exclude_set), ] diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 80acb1c..0798a5e 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -9,6 +9,7 @@ from .visitors import ( TorchDeprecatedSymbolsVisitor, + TorchExpm1Visitor, TorchLog1pVisitor, TorchNonPublicAliasVisitor, TorchReentrantCheckpointVisitor, @@ -29,6 +30,7 @@ ALL_VISITOR_CLS = [ TorchDeprecatedSymbolsVisitor, + TorchExpm1Visitor, TorchLog1pVisitor, TorchNonPublicAliasVisitor, TorchRequireGradVisitor, diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py index f63e405..8e56b4a 100644 --- a/torchfix/visitors/__init__.py +++ b/torchfix/visitors/__init__.py @@ -1,9 +1,10 @@ from .deprecated_symbols import TorchDeprecatedSymbolsVisitor from .internal import TorchScopedLibraryVisitor from .misc import ( + TorchExpm1Visitor, + TorchLog1pVisitor, TorchReentrantCheckpointVisitor, TorchRequireGradVisitor, - TorchLog1pVisitor, ) from .nonpublic import TorchNonPublicAliasVisitor from .performance import TorchSynchronizedDataLoaderVisitor @@ -16,6 +17,7 @@ __all__ = [ "TorchDeprecatedSymbolsVisitor", + "TorchExpm1Visitor", "TorchLog1pVisitor", "TorchNonPublicAliasVisitor", "TorchReentrantCheckpointVisitor", diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index edc6809..348612c 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -121,3 +121,36 @@ def visit_Call(self, node): message=self.ERRORS[0].message(), replacement=None, ) + + +class TorchExpm1Visitor(TorchVisitor): + """ + Suggest using `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. + """ + + ERRORS = [ + TorchError( + "TOR107", + ( + "Use `torch.special.expm1(x)` instead of `torch.exp(x) - 1`. " + "It is more accurate for small values of `x`." + ), + ) + ] + + def visit_BinaryOperation(self, node): + if m.matches( + node, + m.BinaryOperation( + left=m.Call(), + operator=m.Subtract(), + right=m.Integer(value="1") | m.Float(value="1.0"), + ), + ): + if self.get_qualified_name_for_call(node.left) == "torch.exp": + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + )