Skip to content

Commit 1ba8d12

Browse files
authored
Anomalib CLI Improvements - Update metrics and create post_processing section in the config file (#607)
* Rename image/pixel_metrics_names to image/pixel_metrics * Rename image/pixel_metrics_names to image/pixel_metrics * Modified config files. * Modify get_callbacks function for the old cli * Create pre-processing-configuration callback * Add new CLI configuration for the post-processing configuration * Add options to normalization_method * Address mypy issues * Fix docstring * Address codacy issues
1 parent c1f51a6 commit 1ba8d12

File tree

15 files changed

+208
-113
lines changed

15 files changed

+208
-113
lines changed

anomalib/utils/callbacks/__init__.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,22 @@
1919
from .metrics_configuration import MetricsConfigurationCallback
2020
from .min_max_normalization import MinMaxNormalizationCallback
2121
from .model_loader import LoadModelCallback
22+
from .post_processing_configuration import PostProcessingConfigurationCallback
2223
from .tiler_configuration import TilerConfigurationCallback
2324
from .timer import TimerCallback
2425
from .visualizer import ImageVisualizerCallback, MetricVisualizerCallback
2526

2627
__all__ = [
2728
"CdfNormalizationCallback",
29+
"GraphLogger",
30+
"ImageVisualizerCallback",
2831
"LoadModelCallback",
2932
"MetricsConfigurationCallback",
33+
"MetricVisualizerCallback",
3034
"MinMaxNormalizationCallback",
35+
"PostProcessingConfigurationCallback",
3136
"TilerConfigurationCallback",
3237
"TimerCallback",
33-
"ImageVisualizerCallback",
34-
"MetricVisualizerCallback",
3538
]
3639

3740

@@ -64,20 +67,25 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
6467

6568
callbacks.extend([checkpoint, TimerCallback()])
6669

67-
# Add metric configuration to the model via MetricsConfigurationCallback
68-
image_metric_names = config.metrics.image if "image" in config.metrics.keys() else None
69-
pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else None
70+
# Add post-processing configurations to AnomalyModule.
7071
image_threshold = (
7172
config.metrics.threshold.image_default if "image_default" in config.metrics.threshold.keys() else None
7273
)
7374
pixel_threshold = (
7475
config.metrics.threshold.pixel_default if "pixel_default" in config.metrics.threshold.keys() else None
7576
)
77+
post_processing_callback = PostProcessingConfigurationCallback(
78+
adaptive_threshold=config.metrics.threshold.adaptive,
79+
default_image_threshold=image_threshold,
80+
default_pixel_threshold=pixel_threshold,
81+
)
82+
callbacks.append(post_processing_callback)
83+
84+
# Add metric configuration to the model via MetricsConfigurationCallback
85+
image_metric_names = config.metrics.image if "image" in config.metrics.keys() else None
86+
pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else None
7687
metrics_callback = MetricsConfigurationCallback(
77-
config.metrics.threshold.adaptive,
7888
config.dataset.task,
79-
image_threshold,
80-
pixel_threshold,
8189
image_metric_names,
8290
pixel_metric_names,
8391
)
@@ -172,7 +180,8 @@ def add_visualizer_callback(callbacks: List[Callback], config: Union[DictConfig,
172180
config.visualization.inputs_are_normalized = not config.model.normalization_method == "none"
173181
else:
174182
config.visualization.task = config.data.init_args.task
175-
config.visualization.inputs_are_normalized = not config.metrics.normalization_method == "none"
183+
config.visualization.inputs_are_normalized = not config.post_processing.normalization_method == "none"
184+
176185
if config.visualization.log_images or config.visualization.save_images or config.visualization.show_images:
177186
image_save_path = (
178187
config.visualization.image_save_path

anomalib/utils/callbacks/metrics_configuration.py

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import List, Optional
99

1010
import pytorch_lightning as pl
11-
import torch
1211
from pytorch_lightning.callbacks import Callback
1312
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY
1413

@@ -26,13 +25,9 @@ class MetricsConfigurationCallback(Callback):
2625

2726
def __init__(
2827
self,
29-
adaptive_threshold: bool,
3028
task: str = "segmentation",
31-
default_image_threshold: Optional[float] = None,
32-
default_pixel_threshold: Optional[float] = None,
33-
image_metric_names: Optional[List[str]] = None,
34-
pixel_metric_names: Optional[List[str]] = None,
35-
normalization_method: str = "min_max",
29+
image_metrics: Optional[List[str]] = None,
30+
pixel_metrics: Optional[List[str]] = None,
3631
):
3732
"""Create image and pixel-level AnomalibMetricsCollection.
3833
@@ -43,30 +38,12 @@ def __init__(
4338
4439
Args:
4540
task (str): Task type of the current run.
46-
adaptive_threshold (bool): Flag indicating whether threshold should be adaptive.
47-
default_image_threshold (Optional[float]): Default image threshold value.
48-
default_pixel_threshold (Optional[float]): Default pixel threshold value.
49-
image_metric_names (Optional[List[str]]): List of image-level metrics.
50-
pixel_metric_names (Optional[List[str]]): List of pixel-level metrics.
51-
normalization_method(Optional[str]): Normalization method. <None, min_max, cdf>
41+
image_metrics (Optional[List[str]]): List of image-level metrics.
42+
pixel_metrics (Optional[List[str]]): List of pixel-level metrics.
5243
"""
53-
# TODO: https://github.com/openvinotoolkit/anomalib/issues/384
5444
self.task = task
55-
self.image_metric_names = image_metric_names
56-
self.pixel_metric_names = pixel_metric_names
57-
58-
# TODO: https://github.com/openvinotoolkit/anomalib/issues/384
59-
# TODO: This is a workaround. normalization-method is actually not used in metrics.
60-
# It's only accessed from `before_instantiate` method in `AnomalibCLI` to configure
61-
# its callback.
62-
self.normalization_method = normalization_method
63-
64-
assert (
65-
adaptive_threshold or default_image_threshold is not None and default_pixel_threshold is not None
66-
), "Default thresholds must be specified when adaptive threshold is disabled."
67-
self.adaptive_threshold = adaptive_threshold
68-
self.default_image_threshold = default_image_threshold
69-
self.default_pixel_threshold = default_pixel_threshold
45+
self.image_metric_names = image_metrics
46+
self.pixel_metric_names = pixel_metrics
7047

7148
def setup(
7249
self,
@@ -97,12 +74,6 @@ def setup(
9774
pixel_metric_names = self.pixel_metric_names
9875

9976
if isinstance(pl_module, AnomalyModule):
100-
pl_module.adaptive_threshold = self.adaptive_threshold
101-
if not self.adaptive_threshold:
102-
# pylint: disable=not-callable
103-
pl_module.image_threshold.value = torch.tensor(self.default_image_threshold).cpu()
104-
pl_module.pixel_threshold.value = torch.tensor(self.default_pixel_threshold).cpu()
105-
10677
pl_module.image_metrics = metric_collection_from_names(image_metric_names, "image_")
10778
pl_module.pixel_metrics = metric_collection_from_names(pixel_metric_names, "pixel_")
10879

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Post-Processing Configuration Callback."""
2+
3+
# Copyright (C) 2022 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
7+
import logging
8+
from typing import Optional
9+
10+
import torch
11+
from pytorch_lightning import Callback, LightningModule, Trainer
12+
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY
13+
14+
from anomalib.models.components.base.anomaly_module import AnomalyModule
15+
16+
logger = logging.getLogger(__name__)
17+
18+
__all__ = ["PostProcessingConfigurationCallback"]
19+
20+
21+
@CALLBACK_REGISTRY
22+
class PostProcessingConfigurationCallback(Callback):
23+
"""Post-Processing Configuration Callback.
24+
25+
Args:
26+
normalization_method(Optional[str]): Normalization method. <None, min_max, cdf>
27+
adaptive_threshold (bool): Flag indicating whether threshold should be adaptive.
28+
default_image_threshold (Optional[float]): Default image threshold value.
29+
default_pixel_threshold (Optional[float]): Default pixel threshold value.
30+
"""
31+
32+
def __init__(
33+
self,
34+
normalization_method: str = "min_max",
35+
adaptive_threshold: bool = True,
36+
default_image_threshold: Optional[float] = None,
37+
default_pixel_threshold: Optional[float] = None,
38+
) -> None:
39+
super().__init__()
40+
self.normalization_method = normalization_method
41+
42+
assert (
43+
adaptive_threshold or default_image_threshold is not None and default_pixel_threshold is not None
44+
), "Default thresholds must be specified when adaptive threshold is disabled."
45+
46+
self.adaptive_threshold = adaptive_threshold
47+
self.default_image_threshold = default_image_threshold
48+
self.default_pixel_threshold = default_pixel_threshold
49+
50+
# pylint: disable=unused-argument
51+
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
52+
"""Setup post-processing configuration within Anomalib Model.
53+
54+
Args:
55+
trainer (Trainer): PyTorch Lightning Trainer
56+
pl_module (LightningModule): Anomalib Model that inherits pl LightningModule.
57+
stage (Optional[str], optional): fit, validate, test or predict. Defaults to None.
58+
"""
59+
if isinstance(pl_module, AnomalyModule):
60+
pl_module.adaptive_threshold = self.adaptive_threshold
61+
if pl_module.adaptive_threshold is False:
62+
pl_module.image_threshold.value = torch.tensor(self.default_image_threshold).cpu()
63+
pl_module.pixel_threshold.value = torch.tensor(self.default_pixel_threshold).cpu()

anomalib/utils/cli/cli.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
MetricsConfigurationCallback,
2727
MinMaxNormalizationCallback,
2828
ModelCheckpoint,
29+
PostProcessingConfigurationCallback,
2930
TilerConfigurationCallback,
3031
TimerCallback,
3132
add_visualizer_callback,
@@ -90,7 +91,6 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
9091
Args:
9192
parser (LightningArgumentParser): Lightning Argument Parser.
9293
"""
93-
# TODO: https://github.com/openvinotoolkit/anomalib/issues/19
9494
# TODO: https://github.com/openvinotoolkit/anomalib/issues/20
9595
parser.add_argument(
9696
"--export_mode", type=str, default="", help="Select export mode to ONNX or OpenVINO IR format."
@@ -105,18 +105,24 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
105105
parser.add_lightning_class_args(TilerConfigurationCallback, "tiling") # type: ignore
106106
parser.set_defaults({"tiling.enable": False})
107107

108+
parser.add_lightning_class_args(PostProcessingConfigurationCallback, "post_processing") # type: ignore
109+
parser.set_defaults(
110+
{
111+
"post_processing.normalization_method": "min_max",
112+
"post_processing.adaptive_threshold": True,
113+
"post_processing.default_image_threshold": None,
114+
"post_processing.default_pixel_threshold": None,
115+
}
116+
)
117+
108118
# TODO: Assign these default values within the MetricsConfigurationCallback
109119
# - https://github.com/openvinotoolkit/anomalib/issues/384
110120
parser.add_lightning_class_args(MetricsConfigurationCallback, "metrics") # type: ignore
111121
parser.set_defaults(
112122
{
113-
"metrics.adaptive_threshold": True,
114123
"metrics.task": "segmentation",
115-
"metrics.default_image_threshold": None,
116-
"metrics.default_pixel_threshold": None,
117-
"metrics.image_metric_names": ["F1Score", "AUROC"],
118-
"metrics.pixel_metric_names": ["F1Score", "AUROC"],
119-
"metrics.normalization_method": "min_max",
124+
"metrics.image_metrics": ["F1Score", "AUROC"],
125+
"metrics.pixel_metrics": ["F1Score", "AUROC"],
120126
}
121127
)
122128

@@ -203,7 +209,7 @@ def __set_callbacks(self) -> None:
203209
# TODO: This could be set in PostProcessingConfiguration callback
204210
# - https://github.com/openvinotoolkit/anomalib/issues/384
205211
# Normalization.
206-
normalization = config.metrics.normalization_method
212+
normalization = config.post_processing.normalization_method
207213
if normalization:
208214
if normalization == "min_max":
209215
callbacks.append(MinMaxNormalizationCallback())

anomalib/utils/sweep/helpers/callbacks.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from omegaconf import DictConfig, ListConfig
1010
from pytorch_lightning import Callback
1111

12-
from anomalib.utils.callbacks import MetricsConfigurationCallback
12+
from anomalib.utils.callbacks import (
13+
MetricsConfigurationCallback,
14+
PostProcessingConfigurationCallback,
15+
)
1316
from anomalib.utils.callbacks.timer import TimerCallback
1417

1518

@@ -24,23 +27,40 @@ def get_sweep_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]
2427
"""
2528
callbacks: List[Callback] = [TimerCallback()]
2629
# Add metric configuration to the model via MetricsConfigurationCallback
27-
image_metric_names = config.metrics.image if "image" in config.metrics.keys() else None
28-
pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else None
29-
image_threshold = (
30-
config.metrics.threshold.image_default if "image_default" in config.metrics.threshold.keys() else None
31-
)
32-
pixel_threshold = (
33-
config.metrics.threshold.pixel_default if "pixel_default" in config.metrics.threshold.keys() else None
34-
)
35-
metrics_callback = MetricsConfigurationCallback(
36-
adaptive_threshold=config.metrics.threshold.adaptive,
37-
task=config.dataset.task,
30+
31+
# TODO: Remove this once the old CLI is deprecated.
32+
if isinstance(config, DictConfig):
33+
image_metrics = config.metrics.image if "image" in config.metrics.keys() else None
34+
pixel_metrics = config.metrics.pixel if "pixel" in config.metrics.keys() else None
35+
image_threshold = (
36+
config.metrics.threshold.image_default if "image_default" in config.metrics.threshold.keys() else None
37+
)
38+
pixel_threshold = (
39+
config.metrics.threshold.pixel_default if "pixel_default" in config.metrics.threshold.keys() else None
40+
)
41+
normalization_method = config.model.normalization_method
42+
# NOTE: This is for the new anomalib CLI.
43+
else:
44+
image_metrics = config.metrics.image_metrics if "image_metrics" in config.metrics else None
45+
pixel_metrics = config.metrics.pixel_metrics if "pixel_metrics" in config.metrics else None
46+
image_threshold = (
47+
config.post_processing.default_image_threshold if "image_default" in config.post_processing.keys() else None
48+
)
49+
pixel_threshold = (
50+
config.post_processing.default_pixel_threshold if "pixel_default" in config.post_processing.keys() else None
51+
)
52+
normalization_method = config.post_processing.normalization_method
53+
54+
post_processing_configuration_callback = PostProcessingConfigurationCallback(
55+
normalization_method=normalization_method,
3856
default_image_threshold=image_threshold,
3957
default_pixel_threshold=pixel_threshold,
40-
image_metric_names=image_metric_names,
41-
pixel_metric_names=pixel_metric_names,
42-
normalization_method=config.model.normalization_method,
4358
)
44-
callbacks.append(metrics_callback)
59+
callbacks.append(post_processing_configuration_callback)
60+
61+
metrics_configuration_callback = MetricsConfigurationCallback(
62+
task=config.dataset.task, image_metrics=image_metrics, pixel_metrics=pixel_metrics
63+
)
64+
callbacks.append(metrics_configuration_callback)
4565

4666
return callbacks

configs/model/cflow.yaml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,23 @@ model:
3737
permute_soft: false
3838
lr: 0.0001
3939

40-
metrics:
40+
post_processing:
41+
normalization_method: min_max # <null, min_max, cdf>
4142
adaptive_threshold: true
4243
default_image_threshold: null
4344
default_pixel_threshold: null
44-
image_metric_names:
45+
46+
metrics:
47+
image_metrics:
4548
- F1Score
4649
- AUROC
47-
pixel_metric_names:
50+
pixel_metrics:
4851
- F1Score
4952
- AUROC
50-
normalization_method: min_max
5153

5254
visualization:
5355
show_images: False # show images on the screen
54-
save_images: False # save images to the file system
56+
save_images: True # save images to the file system
5557
log_images: False # log images to the available loggers (if any)
5658
mode: full # options: ["full", "simple"]
5759

configs/model/dfkde.yaml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,20 @@ model:
3232
threshold_steepness: 0.05
3333
threshold_offset: 12
3434

35-
metrics:
35+
post_processing:
36+
normalization_method: min_max # <null, min_max, cdf>
3637
adaptive_threshold: true
3738
default_image_threshold: null
38-
image_metric_names:
39+
default_pixel_threshold: null
40+
41+
metrics:
42+
image_metrics:
3943
- F1Score
4044
- AUROC
4145

4246
visualization:
4347
show_images: False # show images on the screen
44-
save_images: False # save images to the file system
48+
save_images: True # save images to the file system
4549
log_images: False # log images to the available loggers (if any)
4650
mode: full # options: ["full", "simple"]
4751

0 commit comments

Comments
 (0)