Skip to content

Commit

Permalink
fix scaling issue in validation
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanklut committed Dec 12, 2023
1 parent 985274b commit e974a82
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
4 changes: 2 additions & 2 deletions core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def build_train_loader(cls, cfg):
mapper = DatasetMapper(
is_train=True,
recompute_boxes=cfg.MODEL.MASK_ON,
augmentations=build_augmentation(cfg, is_train=True),
augmentations=build_augmentation(cfg, mode="train"),
image_format=cfg.INPUT.FORMAT,
use_instance_mask=cfg.MODEL.MASK_ON,
instance_mask_format=cfg.INPUT.MASK_FORMAT,
Expand All @@ -308,7 +308,7 @@ def build_test_loader(cls, cfg, dataset_name):
mapper = DatasetMapper(
is_train=False,
recompute_boxes=cfg.MODEL.MASK_ON,
augmentations=build_augmentation(cfg, is_train=False),
augmentations=build_augmentation(cfg, mode="val"),
image_format=cfg.INPUT.FORMAT,
use_instance_mask=cfg.MODEL.MASK_ON,
instance_mask_format=cfg.INPUT.MASK_FORMAT,
Expand Down
25 changes: 18 additions & 7 deletions datasets/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,13 +700,13 @@ def get_transform(self, image):
return BlendTransform(src_image=np.asarray(0).astype(np.float32), src_weight=1 - w, dst_weight=w)


def build_augmentation(cfg: CfgNode, is_train: bool) -> list[T.Augmentation | T.Transform]:
def build_augmentation(cfg: CfgNode, mode: str = "train") -> list[T.Augmentation | T.Transform]:
"""
Function to generate all the augmentations used in the inference and training process
Args:
cfg (CfgNode): config node
is_train (bool): flag if the augmentation are used for inference or training
mode (str): flag if the augmentation are used for inference or training
Returns:
list[T.Augmentation | T.Transform]: list of augmentations to apply to an image
Expand All @@ -716,30 +716,41 @@ def build_augmentation(cfg: CfgNode, is_train: bool) -> list[T.Augmentation | T.
if cfg.INPUT.RESIZE_MODE == "none":
pass
elif cfg.INPUT.RESIZE_MODE in ["shortest_edge", "longest_edge"]:
if is_train:
if mode == "train":
min_size = cfg.INPUT.MIN_SIZE_TRAIN
max_size = cfg.INPUT.MAX_SIZE_TRAIN
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
else:
elif mode == "val":
min_size = cfg.INPUT.MIN_SIZE_TRAIN
max_size = cfg.INPUT.MAX_SIZE_TRAIN
sample_style = "choice"
elif mode == "test":
min_size = cfg.INPUT.MIN_SIZE_TEST
max_size = cfg.INPUT.MAX_SIZE_TEST
sample_style = "choice"
else:
raise NotImplementedError(f"Unknown mode: {mode}")
if cfg.INPUT.RESIZE_MODE == "shortest_edge":
augmentation.append(ResizeShortestEdge(min_size, max_size, sample_style))
elif cfg.INPUT.RESIZE_MODE == "longest_edge":
augmentation.append(ResizeLongestEdge(min_size, max_size, sample_style))
elif cfg.INPUT.RESIZE_MODE == "scaling":
if is_train:
if mode == "train":
max_size = cfg.INPUT.MAX_SIZE_TRAIN
scaling = cfg.INPUT.SCALING_TRAIN
else:
elif mode == "val":
max_size = cfg.INPUT.MAX_SIZE_TRAIN
scaling = cfg.INPUT.SCALING_TRAIN
elif mode == "test":
max_size = cfg.INPUT.MAX_SIZE_TEST
scaling = cfg.INPUT.SCALING_TEST
else:
raise NotImplementedError(f"Unknown mode: {mode}")
augmentation.append(ResizeScaling(scaling, max_size))
else:
raise NotImplementedError(f"{cfg.INPUT.RESIZE_MODE} is not a known resize mode")

if not is_train:
if not mode == "train":
return augmentation

# TODO Add random crop
Expand Down

0 comments on commit e974a82

Please sign in to comment.