From e974a820c7e9c76d85ecf7b289097e6a1457a3aa Mon Sep 17 00:00:00 2001 From: Stefan Klut Date: Tue, 12 Dec 2023 13:08:30 +0100 Subject: [PATCH] fix scaling issue in validation --- core/trainer.py | 4 ++-- datasets/augmentations.py | 25 ++++++++++++++++++------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/core/trainer.py b/core/trainer.py index 29ad244..4784945 100644 --- a/core/trainer.py +++ b/core/trainer.py @@ -291,7 +291,7 @@ def build_train_loader(cls, cfg): mapper = DatasetMapper( is_train=True, recompute_boxes=cfg.MODEL.MASK_ON, - augmentations=build_augmentation(cfg, is_train=True), + augmentations=build_augmentation(cfg, mode="train"), image_format=cfg.INPUT.FORMAT, use_instance_mask=cfg.MODEL.MASK_ON, instance_mask_format=cfg.INPUT.MASK_FORMAT, @@ -308,7 +308,7 @@ def build_test_loader(cls, cfg, dataset_name): mapper = DatasetMapper( is_train=False, recompute_boxes=cfg.MODEL.MASK_ON, - augmentations=build_augmentation(cfg, is_train=False), + augmentations=build_augmentation(cfg, mode="val"), image_format=cfg.INPUT.FORMAT, use_instance_mask=cfg.MODEL.MASK_ON, instance_mask_format=cfg.INPUT.MASK_FORMAT, diff --git a/datasets/augmentations.py b/datasets/augmentations.py index abd10f9..bebe5ce 100644 --- a/datasets/augmentations.py +++ b/datasets/augmentations.py @@ -700,13 +700,13 @@ def get_transform(self, image): return BlendTransform(src_image=np.asarray(0).astype(np.float32), src_weight=1 - w, dst_weight=w) -def build_augmentation(cfg: CfgNode, is_train: bool) -> list[T.Augmentation | T.Transform]: +def build_augmentation(cfg: CfgNode, mode: str = "train") -> list[T.Augmentation | T.Transform]: """ Function to generate all the augmentations used in the inference and training process Args: cfg (CfgNode): config node - is_train (bool): flag if the augmentation are used for inference or training + mode (str): flag if the augmentation are used for inference or training Returns: list[T.Augmentation | T.Transform]: list of augmentations to apply to an image @@ -716,30 +716,41 @@ def build_augmentation(cfg: CfgNode, is_train: bool) -> list[T.Augmentation | T. if cfg.INPUT.RESIZE_MODE == "none": pass elif cfg.INPUT.RESIZE_MODE in ["shortest_edge", "longest_edge"]: - if is_train: + if mode == "train": min_size = cfg.INPUT.MIN_SIZE_TRAIN max_size = cfg.INPUT.MAX_SIZE_TRAIN sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING - else: + elif mode == "val": + min_size = cfg.INPUT.MIN_SIZE_TRAIN + max_size = cfg.INPUT.MAX_SIZE_TRAIN + sample_style = "choice" + elif mode == "test": min_size = cfg.INPUT.MIN_SIZE_TEST max_size = cfg.INPUT.MAX_SIZE_TEST sample_style = "choice" + else: + raise NotImplementedError(f"Unknown mode: {mode}") if cfg.INPUT.RESIZE_MODE == "shortest_edge": augmentation.append(ResizeShortestEdge(min_size, max_size, sample_style)) elif cfg.INPUT.RESIZE_MODE == "longest_edge": augmentation.append(ResizeLongestEdge(min_size, max_size, sample_style)) elif cfg.INPUT.RESIZE_MODE == "scaling": - if is_train: + if mode == "train": max_size = cfg.INPUT.MAX_SIZE_TRAIN scaling = cfg.INPUT.SCALING_TRAIN - else: + elif mode == "val": + max_size = cfg.INPUT.MAX_SIZE_TRAIN + scaling = cfg.INPUT.SCALING_TRAIN + elif mode == "test": max_size = cfg.INPUT.MAX_SIZE_TEST scaling = cfg.INPUT.SCALING_TEST + else: + raise NotImplementedError(f"Unknown mode: {mode}") augmentation.append(ResizeScaling(scaling, max_size)) else: raise NotImplementedError(f"{cfg.INPUT.RESIZE_MODE} is not a known resize mode") - if not is_train: + if not mode == "train": return augmentation # TODO Add random crop