File tree Expand file tree Collapse file tree 3 files changed +27
-7
lines changed
tests/fixtures/misc/checker Expand file tree Collapse file tree 3 files changed +27
-7
lines changed Original file line number Diff line number Diff line change 11import torch
2- a = torch . randn ( 5 )
3- b = torch .randn (5 )
2+
3+ x = torch .randn (5 )
44
55# logsumexp
66y = torch .log (torch .sum (torch .exp (x ), 1 , keepdim = True ))
7+ y = torch .log (torch .sum (torch .exp (x ), dim = 1 , keepdim = True ))
78y = torch .log (torch .sum (torch .exp (2.5 + x ), 1 ))
9+ y = torch .log (torch .sum (torch .exp (2.5 + x ), dim = 1 ))
810
911# not logsumexp
1012y = torch .log (torch .sum (torch .exp (x ), 1 , keepdim = True ) + 2.5 )
1113y = torch .log (torch .sum (torch .exp (x ) + 2.5 , 1 ))
1214y = torch .log (2 + x )
1315y = torch .sum (torch .log (torch .exp (x )), 1 )
1416y = torch .exp (torch .sum (torch .log (x ), 1 , keepdim = True ))
17+
18+ # not logsumexp because of https://github.com/pytorch/pytorch/issues/144339
19+ y = torch .log (torch .sum (torch .exp (x ), None , keepdim = True ))
20+ y = torch .log (torch .sum (torch .exp (x ), dim = None , keepdim = True ))
21+ y = torch .log (torch .sum (torch .exp (x ), keepdim = True ))
Original file line number Diff line number Diff line change 116:5 TOR108 Use numerically stabilized `torch.logsumexp`.
227:5 TOR108 Use numerically stabilized `torch.logsumexp`.
3+ 8:5 TOR108 Use numerically stabilized `torch.logsumexp`.
4+ 9:5 TOR108 Use numerically stabilized `torch.logsumexp`.
Original file line number Diff line number Diff line change @@ -184,9 +184,20 @@ def visit_Call(self, node):
184184 )
185185 == "torch.exp"
186186 ):
187- self . add_violation (
188- node ,
189- error_code = self . ERRORS [ 0 ]. error_code ,
190- message = self .ERRORS [ 0 ]. message (),
191- replacement = None ,
187+
188+ # if `dim` is not provided or None for sum, skip:
189+ # https://github.com/pytorch/pytorch/issues/144339
190+ dim_arg = self .get_specific_arg (
191+ node . args [ 0 ]. value , arg_name = "dim" , arg_pos = 1
192192 )
193+ if dim_arg is not None :
194+ if not (
195+ isinstance (dim_arg .value , cst .Name )
196+ and dim_arg .value .value == "None"
197+ ):
198+ self .add_violation (
199+ node ,
200+ error_code = self .ERRORS [0 ].error_code ,
201+ message = self .ERRORS [0 ].message (),
202+ replacement = None ,
203+ )
You can’t perform that action at this time.
0 commit comments