5 files changed: 58 additions, 1 deletion.

New test fixture under tests/fixtures/performance/checker:

+ import torch
+ import torch.nn as nn
+
+ x = torch.ones((100, 100))
+ model = nn.Sequential()
+ optimizer = torch.optim.Adam(model.parameters())
+
+ # This should raise flags
+ optimizer.zero_grad(set_to_none=False)
+ model.zero_grad(set_to_none=False)
+
+ # This should not raise flags
+ optimizer.zero_grad()
+ model.zero_grad()
Expected checker output for the fixture:

+ 9:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad().
+ 10:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad().
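For context on why TOR402 flags this pattern, here is a minimal sketch, independent of this diff, assuming a recent PyTorch release where set_to_none defaults to True: zero_grad(set_to_none=False) writes zeros into every .grad tensor and keeps that memory allocated, while the default drops the gradients entirely, which the PyTorch docs describe as lower memory footprint with a modest performance benefit.

import torch
import torch.nn as nn

model = nn.Linear(100, 100)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.ones(8, 100)).sum().backward()
optimizer.zero_grad(set_to_none=False)  # flagged: grads stay allocated, zero-filled
print(model.weight.grad.abs().sum())    # tensor(0.) -- memory still held

model(torch.ones(8, 100)).sum().backward()
optimizer.zero_grad()                   # not flagged: with set_to_none=True (default in PyTorch >= 2.0) grads are dropped
print(model.weight.grad)                # None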
Registration of the new visitor class alongside the existing ones:

      TorchVisionDeprecatedPretrainedVisitor,
      TorchVisionDeprecatedToTensorVisitor,
      TorchVisionSingletonImportVisitor,
+     TorchGradNotSetToNonePatternVisitor,
  )

  __version__ = "0.7.0"

  ...

      TorchVisionDeprecatedPretrainedVisitor,
      TorchVisionDeprecatedToTensorVisitor,
      TorchVisionSingletonImportVisitor,
+     TorchGradNotSetToNonePatternVisitor,
  ]
Import and re-export of the new visitor in the visitors' __init__:

      TorchRequireGradVisitor,
  )
  from .nonpublic import TorchNonPublicAliasVisitor
- from .performance import TorchSynchronizedDataLoaderVisitor
+ from .performance import (
+     TorchSynchronizedDataLoaderVisitor,
+     TorchGradNotSetToNonePatternVisitor,
+ )
  from .security import TorchUnsafeLoadVisitor
  from .vision import (
      TorchVisionDeprecatedPretrainedVisitor,

  ...

      "TorchVisionDeprecatedPretrainedVisitor",
      "TorchVisionDeprecatedToTensorVisitor",
      "TorchVisionSingletonImportVisitor",
+     "TorchGradNotSetToNonePatternVisitor",
  ]
New visitor in the performance checks module:

@@ -32,3 +32,36 @@ def visit_Call(self, node):
                      error_code=self.ERRORS[0].error_code,
                      message=self.ERRORS[0].message(),
                  )
+
+
+ class TorchGradNotSetToNonePatternVisitor(TorchVisitor):
+     """
+     Reimplementation of GradNotSetToNonePattern from
+     https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py
+     """
+
+     ERRORS = [
+         TorchError(
+             "TOR402",
+             (
+                 "Detected gradient set to zero instead of None. "
+                 "Please add 'set_to_none=True' when calling zero_grad()."
+             ),
+         )
+     ]
+
+     def visit_Call(self, node):
+         qualified_name = self.get_qualified_name_for_call(node)
+
+         if qualified_name and qualified_name.endswith("zero_grad"):
+
+             set_to_none_arg = self.get_specific_arg(node, "set_to_none", 0)
+
+             # hasattr check to handle mypy error
+             if set_to_none_arg and hasattr(set_to_none_arg.value, "value"):
+                 if set_to_none_arg.value.value == "False":
+                     self.add_violation(
+                         node,
+                         error_code=self.ERRORS[0].error_code,
+                         message=self.ERRORS[0].message(),
+                     )
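To see the shape of the check outside of torchfix, here is a minimal, standalone libcst sketch. It is not torchfix's own test harness; the ZeroGradCheck class name and the printed message format are illustrative assumptions, and unlike the real visitor it only inspects the keyword form of the argument.

# Standalone illustration only -- ZeroGradCheck is a hypothetical name, not part of torchfix.
import libcst as cst
import libcst.matchers as m
from libcst.metadata import MetadataWrapper, PositionProvider

CODE = """\
optimizer.zero_grad(set_to_none=False)  # should be reported
optimizer.zero_grad()                   # should pass
"""

class ZeroGradCheck(cst.CSTVisitor):
    METADATA_DEPENDENCIES = (PositionProvider,)

    def visit_Call(self, node: cst.Call) -> None:
        # Only look at calls whose attribute name is zero_grad.
        if not m.matches(node.func, m.Attribute(attr=m.Name("zero_grad"))):
            return
        for arg in node.args:
            # Flag an explicit set_to_none=False keyword argument
            # (the real visitor also handles it as a positional argument).
            if (
                arg.keyword is not None
                and arg.keyword.value == "set_to_none"
                and m.matches(arg.value, m.Name("False"))
            ):
                pos = self.get_metadata(PositionProvider, node).start
                print(f"{pos.line}:{pos.column + 1} zero_grad(set_to_none=False) detected")

MetadataWrapper(cst.parse_module(CODE)).visit(ZeroGradCheck())
# Expected to print something like: 1:1 zero_grad(set_to_none=False) detected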