Skip to content

Commit 00954c9

Browse files
authored
Don't suggest logsumexp if sum's dim is None (#91)
See pytorch/pytorch#144339
1 parent 28f1a5f commit 00954c9

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed
Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
import torch
2-
a = torch.randn(5)
3-
b = torch.randn(5)
2+
3+
x = torch.randn(5)
44

55
# logsumexp
66
y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True))
7+
y = torch.log(torch.sum(torch.exp(x), dim=1, keepdim=True))
78
y = torch.log(torch.sum(torch.exp(2.5 + x), 1))
9+
y = torch.log(torch.sum(torch.exp(2.5 + x), dim=1))
810

911
# not logsumexp
1012
y = torch.log(torch.sum(torch.exp(x), 1, keepdim=True) + 2.5)
1113
y = torch.log(torch.sum(torch.exp(x) + 2.5, 1))
1214
y = torch.log(2 + x)
1315
y = torch.sum(torch.log(torch.exp(x)), 1)
1416
y = torch.exp(torch.sum(torch.log(x), 1, keepdim=True))
17+
18+
# not logsumexp because of https://github.com/pytorch/pytorch/issues/144339
19+
y = torch.log(torch.sum(torch.exp(x), None, keepdim=True))
20+
y = torch.log(torch.sum(torch.exp(x), dim=None, keepdim=True))
21+
y = torch.log(torch.sum(torch.exp(x), keepdim=True))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
6:5 TOR108 Use numerically stabilized `torch.logsumexp`.
22
7:5 TOR108 Use numerically stabilized `torch.logsumexp`.
3+
8:5 TOR108 Use numerically stabilized `torch.logsumexp`.
4+
9:5 TOR108 Use numerically stabilized `torch.logsumexp`.

torchfix/visitors/misc/__init__.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,20 @@ def visit_Call(self, node):
184184
)
185185
== "torch.exp"
186186
):
187-
self.add_violation(
188-
node,
189-
error_code=self.ERRORS[0].error_code,
190-
message=self.ERRORS[0].message(),
191-
replacement=None,
187+
188+
# if `dim` is not provided, or is None, for `sum`, skip:
189+
# https://github.com/pytorch/pytorch/issues/144339
190+
dim_arg = self.get_specific_arg(
191+
node.args[0].value, arg_name="dim", arg_pos=1
192192
)
193+
if dim_arg is not None:
194+
if not (
195+
isinstance(dim_arg.value, cst.Name)
196+
and dim_arg.value.value == "None"
197+
):
198+
self.add_violation(
199+
node,
200+
error_code=self.ERRORS[0].error_code,
201+
message=self.ERRORS[0].message(),
202+
replacement=None,
203+
)

0 commit comments

Comments (0)