Skip to content

Commit

Permalink
Saving as npy or pt works, getting further with the transforms and mo…
Browse files Browse the repository at this point in the history
…dels (still needs mapper)
  • Loading branch information
stefanklut committed Sep 6, 2024
1 parent 073deb3 commit c46ab23
Show file tree
Hide file tree
Showing 9 changed files with 487 additions and 123 deletions.
5 changes: 5 additions & 0 deletions configs/extra_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@
_C.PREPROCESS.RESIZE.MIN_SIZE = [1024]
_C.PREPROCESS.RESIZE.MAX_SIZE = 2048

# NOTE(review): presumably the on-disk format used when preprocessing saves each
# kind of output (image, semantic segmentation, instances, panoptic) -- the set of
# accepted values is not visible here; confirm against the preprocessing code.
_C.PREPROCESS.SAVE_METHOD_IMAGE = "pt"
_C.PREPROCESS.SAVE_METHOD_SEM_SEG = "pt"
_C.PREPROCESS.SAVE_METHOD_INSTANCES = "json"
_C.PREPROCESS.SAVE_METHOD_PANOS = "png"

# DPI correction in resizing
_C.PREPROCESS.DPI = CN()
_C.PREPROCESS.DPI.AUTO_DETECT = False
Expand Down
8 changes: 4 additions & 4 deletions data/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def wrapper(*args, **kwargs):
return wrapper


T.AugmentationList = TimedAugmentationList
# T.AugmentationList = TimedAugmentationList


class RandomApply(T.RandomApply):
Expand Down Expand Up @@ -1880,7 +1880,7 @@ def test(args) -> None:
from core.setup import setup_cfg
from data import preprocess
from data.mapper import AugInput
from utils.image_torch_utils import load_image_tensor_from_path
from utils.image_torch_utils import load_image_tensor_from_path_gpu_decode
from utils.image_utils import load_image_array_from_path
from utils.tempdir import OptionalTemporaryDirectory

Expand All @@ -1898,8 +1898,8 @@ def test(args) -> None:
# image = load_image_array_from_path(Path(tmp_dir).joinpath(output["image_paths"]))["image"] # type: ignore
# sem_seg = load_image_array_from_path(Path(tmp_dir).joinpath(output["sem_seg_paths"]), mode="grayscale")["image"] # type: ignore

image = load_image_tensor_from_path(Path(tmp_dir).joinpath(output["image_paths"]), device="cpu")["image"] # type: ignore
sem_seg = load_image_tensor_from_path(Path(tmp_dir).joinpath(output["sem_seg_paths"]), mode="grayscale", device="cpu")["image"] # type: ignore
image = load_image_tensor_from_path_gpu_decode(Path(tmp_dir).joinpath(output["image_paths"]), device="cpu")["image"] # type: ignore
sem_seg = load_image_tensor_from_path_gpu_decode(Path(tmp_dir).joinpath(output["sem_seg_paths"]), mode="grayscale", device="cpu")["image"] # type: ignore

augs = build_augmentation(cfg, mode="train")
aug = T.AugmentationList(augs)
Expand Down
93 changes: 69 additions & 24 deletions data/mapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import logging
from pathlib import Path
from typing import Any, Optional

import detectron2.data.transforms as T
Expand All @@ -16,7 +17,7 @@
from detectron2.data.transforms.augmentation import _check_img_dtype

from data.augmentations import build_augmentation
from utils.image_torch_utils import load_image_tensor_from_path
from utils.image_torch_utils import load_image_tensor_from_path_gpu_decode
from utils.image_utils import load_image_array_from_path
from utils.logging_utils import get_logger_name

Expand Down Expand Up @@ -49,7 +50,7 @@ def _transform_to_aug(tfm_or_aug):
def _check_img_dtype(img):
if isinstance(img, torch.Tensor):
assert img.dtype == torch.uint8 or img.dtype == torch.float32, f"[Augmentation] Got image of type {img.dtype}!"
assert img.dim() in [2, 3], img.dim()
assert img.dim() == 3, img.dim()
elif isinstance(img, np.ndarray):
assert img.dtype == np.uint8 or img.dtype == np.float32, f"[Augmentation] Got image of type {img.dtype}!"
assert img.ndim in [2, 3], img.ndim
Expand Down Expand Up @@ -221,6 +222,51 @@ def from_config(cls, cfg: CfgNode, mode: str = "train", device=torch.device("cpu
)
return ret

def load_array(self, path: Path | str, mode: str = "color") -> dict[str, Any]:
"""
Load an image from a file path.
Args:
path (str): The path to the image file.
mode (str): The mode to use when loading the image.
Returns:
dict: The loaded image and its DPI.
"""
path = Path(path)
if path.suffix == ".npy":
array = np.load(path)
if array is None:
raise ValueError(f"Array {path} cannot be loaded")
assert array.ndim == 3 or array.ndim == 2, f"Invalid array shape: {array.shape}"
if array.ndim == 2:
array = array[:, :, None]
return {"image": array, "dpi": None}
elif path.suffix == ".pt":
tensor = torch.load(path, weights_only=True)
if tensor is None:
raise ValueError(f"Tensor {path} cannot be loaded")

if tensor.dim() == 2:
tensor = tensor.unsqueeze(0)
elif tensor.dim() == 3:
pass
else:
raise ValueError(f"Invalid tensor shape: {tensor.shape}")

return {"image": tensor, "dpi": None}
else:
if self.on_gpu:
data = load_image_tensor_from_path_gpu_decode(path, mode=mode, device=self.device)
if data is None:
raise ValueError(f"Image {path} cannot be loaded")
return data
else:
data = load_image_array_from_path(path, mode=mode)
if data is None:
raise ValueError(f"Image {path} cannot be loaded")
return data

def __call__(self, dataset_dict):
"""
Args:
Expand All @@ -231,27 +277,20 @@ def __call__(self, dataset_dict):
"""
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below

if self.on_gpu:
image = load_image_tensor_from_path(dataset_dict["file_name"], mode="color", device=self.device)
else:
image = load_image_array_from_path(dataset_dict["file_name"], mode="color")
# Load image.
image = self.load_array(dataset_dict["file_name"], mode="color")

if image is None:
raise ValueError(f"Image {dataset_dict['file_name']} cannot be loaded")
check_image_size(dataset_dict, image["image"])

# USER: Remove if you don't do semantic/panoptic segmentation.
if "sem_seg_file_name" in dataset_dict:
if self.on_gpu:
sem_seg_gt = load_image_tensor_from_path(
dataset_dict["sem_seg_file_name"], mode="grayscale", device=self.device
)
else:
sem_seg_gt = load_image_array_from_path(dataset_dict["sem_seg_file_name"], mode="grayscale")
if sem_seg_gt is None:
raise ValueError(f"Sem-seg {dataset_dict['sem_seg_file_name']} cannot be loaded")
sem_seg_gt = self.load_array(dataset_dict["sem_seg_file_name"], mode="grayscale")
else:
sem_seg_gt = {"image": None}
sem_seg_gt = {"image": None, "dpi": None}

assert type(image["image"]) == type(
sem_seg_gt["image"]
), f"Image and sem_seg_gt have different types: {type(image['image'])} and {type(sem_seg_gt['image'])}"

aug_input = AugInput(
image["image"],
Expand All @@ -266,21 +305,27 @@ def __call__(self, dataset_dict):
image, sem_seg_gt = aug_input.image, aug_input.sem_seg

if image is None:
raise ValueError(f"Image {dataset_dict['file_name']} cannot be loaded")
raise ValueError(f"Image {dataset_dict['file_name']} has become None after augmentation")

# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
# Therefore it's important to use torch.Tensor.
if self.on_gpu:
if isinstance(image, torch.Tensor):
image_shape = image.shape[-2:] # h, w
dataset_dict["image"] = image.squeeze(0)
if sem_seg_gt is not None:
dataset_dict["sem_seg"] = sem_seg_gt.squeeze(0).to(dtype=torch.long)
else:
dataset_dict["image"] = image.clone()
elif isinstance(image, np.ndarray):
image_shape = image.shape[:2]
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
if sem_seg_gt is not None:
else:
raise ValueError(f"image is not a numpy array or torch tensor: {type(image)}")

if sem_seg_gt is not None:
if isinstance(sem_seg_gt, torch.Tensor):
dataset_dict["sem_seg"] = sem_seg_gt.to(dtype=torch.long).squeeze(0).clone()
elif isinstance(sem_seg_gt, np.ndarray):
dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
else:
raise ValueError(f"sem_seg_gt is not a numpy array or torch tensor: {type(sem_seg_gt)}")

# USER: Remove if you don't use pre-computed proposals.
# Most users would not need this feature.
Expand Down
2 changes: 1 addition & 1 deletion data/numpy_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def wrapper(*args, **kwargs):
return wrapper


T.Transform = TimedTransform
# T.Transform = TimedTransform


class ResizeTransform(T.Transform):
Expand Down
Loading

0 comments on commit c46ab23

Please sign in to comment.