-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
75 lines (67 loc) · 1.97 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from importlib import import_module
import click
import torch
from interpretation import run_interpretation_method, run_pxpermute
from util import load_model, prepare_test_data, read_config_file
# Root command group for this CLI. Subcommands are registered on it via
# @cli.command() (see run_channel_importance below).
# NOTE: deliberately no docstring — click would surface it as --help text,
# and this group's help output should stay as-is.
@click.group()
def cli():
    pass
@cli.command()
@click.option("--output-path", default="../output", type=click.Path(exists=True))
@click.option(
    "--method",
    default="PxPermute",
    type=click.Choice(
        [
            "PxPermute",
            "DeepLift",
            "GuidedGradCam",
            "Saliency",
            "DeepLiftShap",
            "GradientShap",
            "InputXGradient",
            "IntegratedGradients",
            "GuidedBackprop",
            "Deconvolution",
            "Occlusion",
            "FeaturePermutation",
            "ShapleyValueSampling",
            "Lime",
            "KernelShap",
            "LRP",
        ]
    ),
)
def run_channel_importance(output_path, method="PxPermute"):
    """Run a channel-importance interpretation method on the test data.

    Loads parameters from ``config.cfg``, prepares the test data and model,
    then either runs the custom PxPermute procedure or instantiates the
    chosen Captum attribution class and delegates to
    ``run_interpretation_method``.

    Parameters
    ----------
    output_path : str
        Existing directory where results are written.
    method : str
        Name of the interpretation method; either ``"PxPermute"`` or the
        name of a class in ``captum.attr`` (restricted by the click Choice).
    """
    # Load parameters from the config file.
    parameters = read_config_file("config.cfg")
    # Fall back to CPU when CUDA was requested but is not available.
    if parameters["device"] == "cuda" and not torch.cuda.is_available():
        parameters["device"] = "cpu"
        click.echo("GPU is not available, using CPU instead")
    # Load data and model.
    metadata, loader, test_index, label_map, test_transform = prepare_test_data(
        **parameters
    )
    parameters["num_classes"] = len(label_map)
    model = load_model(**parameters)
    # Run the selected interpretation method.
    if method == "PxPermute":
        run_pxpermute(
            metadata,
            loader,
            model,
            output_path,
            test_index,
            test_transform,
            label_map,
            **parameters
        )
    else:
        # Resolve the Captum attribution class by name; avoid rebinding the
        # `method` parameter (the original shadowed it with a class object).
        mod = import_module("captum.attr")
        method_cls = getattr(mod, method)
        ablator = method_cls(model)
        run_interpretation_method(
            test_loader=loader, ablator=ablator, output_path=output_path, **parameters
        )
# Script entry point: invokes the command directly rather than the `cli`
# group, so the group defined above is bypassed when running this file.
# NOTE(review): if group-style invocation (`python main.py run-channel-importance ...`)
# is intended, this should call cli() instead — confirm intended CLI shape.
if __name__ == "__main__":
    run_channel_importance()