Skip to content

Commit 17d7446

Browse files
committed
Add --disable option to disable some rules
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
1 parent 0d9c3fe commit 17d7446

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

torchfix/__main__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ def _parse_args() -> argparse.Namespace:
7070
type=str,
7171
default=None,
7272
)
73+
parser.add_argument(
74+
"--disable",
75+
help="Comma-separated list of rules to disable. Defaults to None.",
76+
type=str,
77+
default=None,
78+
)
7379
parser.add_argument("--version", action="version", version=f"{TorchFixVersion}")
7480

7581
# XXX TODO: Get rid of this!
@@ -101,7 +107,15 @@ def main() -> None:
101107
if not torch_files:
102108
return
103109
config = TorchCodemodConfig()
104-
config.select = list(process_error_code_str(args.select))
110+
selected_rules = process_error_code_str(args.select, True)
111+
if args.disable is not None:
112+
if args.disable == "ALL":
113+
print("No rule to apply", file=sys.stderr)
114+
sys.exit(1)
115+
disabled_rules = process_error_code_str(args.disable, False)
116+
selected_rules = set(selected_rules) - set(disabled_rules)
117+
118+
config.select = list(selected_rules)
105119
command_instance = TorchCodemod(codemod.CodemodContext(), config)
106120
DIFF_CONTEXT = 5
107121
try:

torchfix/torchfix.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,14 @@ def get_visitors_with_error_codes(error_codes):
9595
return [construct_visitor(cls) for cls in visitor_classes]
9696

9797

98-
def process_error_code_str(code_str):
98+
def process_error_code_str(code_str, enabled = True):
9999
# Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001.
100100
# We deduplicate them here.
101101

102102
# Default when --select is not provided.
103103
if code_str is None:
104+
if not enabled:
105+
return set()
104106
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
105107
return set(GET_ALL_ERROR_CODES()) - exclude_set
106108

0 commit comments

Comments
 (0)