diff --git a/augmentation/augmentations_list.py b/augmentation/augmentations_list.py index 61412d7..032617a 100644 --- a/augmentation/augmentations_list.py +++ b/augmentation/augmentations_list.py @@ -2,16 +2,14 @@ import kornia as K import random -from typing import Dict, Tuple +from typing import Dict, Tuple, Union from augmentation.registry import register_augmentation @register_augmentation(name="random_rotate") def random_rotate( - images: torch.Tensor, - degrees: Tuple[float, float] = (-20.0, 20.0), - same_on_batch: bool = False, + images: torch.Tensor, same_on_batch: bool = False, **kwargs ) -> torch.Tensor: """ Random rotate a batch of tensor images. @@ -25,8 +23,19 @@ def random_rotate( ------ transformed tensor images of shape [B, C, H, W] """ + degree: Union[Tuple[float, float], float, int] = kwargs.get("parameters", {}).get( + "degree" + ) + # If degrees are given, ensure the same output over the whole batch + if degree is not None: + if isinstance(degree, (float, int)): + degree = (degree, degree) + same_on_batch = True + else: + degree = (-20.0, 20.0) + transform = K.augmentation.RandomRotation( - degrees, same_on_batch=same_on_batch, p=1.0 + degree, same_on_batch=same_on_batch, p=1.0 ) images_out: torch.Tensor = transform(images) return images_out @@ -34,9 +43,7 @@ def random_rotate( @register_augmentation(name="random_scale") def random_scale( - images: torch.Tensor, - scale: Tuple[float, float] = (0.5, 2), - same_on_batch: bool = False, + images: torch.Tensor, same_on_batch: bool = False, **kwargs ) -> torch.Tensor: """ Random scale a batch of tensor images. @@ -50,6 +57,17 @@ def random_scale( ------ transformed tensor images of shape [B, C, H, W] """ + scale: Union[Tuple[float, float], float, int] = kwargs.get("parameters", {}).get( + "scale" + ) + # If scale is given, ensure the same output over the whole batch + if scale is not None: + if isinstance(scale, (float, int)): + scale = (scale, scale) + same_on_batch = True + else: + scale: Tuple[float, float] = (0.5, 2) + transform = K.augmentation.RandomAffine( degrees=0.0, scale=scale, same_on_batch=same_on_batch, p=1.0 ) @@ -60,8 +78,8 @@ def random_scale( @register_augmentation(name="random_translate") def random_translate( images: torch.Tensor, - translate: Tuple[float, float] = (0.2, 0.2), same_on_batch: bool = False, + **kwargs, ) -> torch.Tensor: """ Randomly translate a batch of tensor images. @@ -75,6 +93,18 @@ def random_translate( ------ transformed tensor images of shape [B, C, H, W] """ + translate_horizontal: float = kwargs.get("parameters", {}).get( + "translate_horizontal", None + ) + translate_vertical: float = kwargs.get("parameters", {}).get( + "translate_vertical", None + ) + # If translate is given, ensure the same output over the whole batch + if translate_horizontal is not None and translate_vertical is not None: + translate: Tuple[float, float] = (translate_horizontal, translate_vertical) + same_on_batch = True + else: + translate = (0.2, 0.2) transform = K.augmentation.RandomAffine( degrees=0.0, translate=translate, same_on_batch=same_on_batch, p=1.0 ) @@ -84,7 +114,7 @@ def random_translate( @register_augmentation(name="random_horizontal_flip") def random_horizontal_flip( - images: torch.Tensor, same_on_batch: bool = False + images: torch.Tensor, same_on_batch: bool = False, **kwargs ) -> torch.Tensor: """ Randomly horizontal flip a batch of tensor images. 
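Every augmentation now pulls its settings from an optional `parameters` dict passed through `**kwargs`, falling back to a random range when a key is absent and pinning `same_on_batch` when a fixed value is supplied. A minimal sketch of calling a registered augmentation directly under the new contract (the dummy tensor and values are illustrative only):

```python
import torch

import augmentation.augmentations_list  # noqa: F401  # Import to register all augmentations
from augmentation.registry import AUGMENTATIONS

images = torch.rand(4, 3, 256, 256)  # dummy batch of 4 RGB images in [0, 1]

# No "degree" key: falls back to the random (-20.0, 20.0) range per image
rotated_randomly = AUGMENTATIONS["random_rotate"](images, parameters={})

# Fixed "degree": promoted to (30.0, 30.0) and same_on_batch is forced to True,
# so all 4 images are rotated by exactly the same angle
rotated_fixed = AUGMENTATIONS["random_rotate"](images, parameters={"degree": 30.0})
```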
@@ -98,14 +128,27 @@ def random_horizontal_flip(
     ------
     transformed tensor images of shape [B, C, H, W]
     """
-    transform = K.augmentation.RandomHorizontalFlip(same_on_batch=same_on_batch, p=1.0)
-    images_out: torch.Tensor = transform(images)
-    return images_out
+    flip: bool = kwargs.get("parameters", {}).get("flip")
+    if flip is None:
+        # If flip is not given, perform random flip
+        transform = K.augmentation.RandomHorizontalFlip(
+            same_on_batch=same_on_batch, p=1.0
+        )
+        images_out: torch.Tensor = transform(images)
+        return images_out
+    elif flip is True:
+        # If flip is True, flip deterministically: the whole batch gets the same output
+        images_out: torch.Tensor = K.geometry.flips.hflip(images)
+        return images_out
+    else:
+        # If flip is False, do nothing
+        return images
 
 
 @register_augmentation(name="random_vertical_flip")
 def random_vertical_flip(
-    images: torch.Tensor, same_on_batch: bool = False
+    images: torch.Tensor, same_on_batch: bool = False, **kwargs
 ) -> torch.Tensor:
     """
     Randomly vertical flip a batch of tensor images.
@@ -119,16 +162,27 @@ def random_vertical_flip(
     ------
     transformed tensor images of shape [B, C, H, W]
     """
-    transform = K.augmentation.RandomVerticalFlip(same_on_batch=same_on_batch, p=1.0)
-    images_out: torch.Tensor = transform(images)
-    return images_out
+    flip: bool = kwargs.get("parameters", {}).get("flip")
+    if flip is None:
+        # If flip is not given, perform random flip
+        transform = K.augmentation.RandomVerticalFlip(
+            same_on_batch=same_on_batch, p=1.0
+        )
+        images_out: torch.Tensor = transform(images)
+        return images_out
+    elif flip is True:
+        # If flip is True, flip deterministically: the whole batch gets the same output
+        images_out: torch.Tensor = K.geometry.flips.vflip(images)
+        return images_out
+    else:
+        # If flip is False, do nothing
+        return images
 
 
 @register_augmentation(name="random_crop")
 def random_crop(
-    images: torch.Tensor,
-    size: Tuple[int, int] = (512, 512),
-    same_on_batch: bool = False,
+    images: torch.Tensor, same_on_batch: bool = False, **kwargs
 ) -> torch.Tensor:
     """
     Randomly crop a batch of tensor images.
@@ -142,6 +196,13 @@ def random_crop(
     ------
     transformed tensor images of shape [B, C, H, W]
     """
+    size: Union[Tuple[int, int], int] = kwargs.get("parameters", {}).get("size")
+    if size is not None:
+        if isinstance(size, int):
+            size: Tuple[int, int] = (size, size)
+        same_on_batch = True
+    else:
+        size: Tuple[int, int] = (512, 512)
 
     transform = K.augmentation.RandomCrop(size=size, same_on_batch=same_on_batch, p=1.0)
     images_out: torch.Tensor = transform(images)
@@ -150,12 +211,12 @@ def random_crop(
 
 @register_augmentation(name="random_tile")
 def random_tile(
-    images: torch.Tensor, window_size: int = 512, same_on_batch: bool = False
+    images: torch.Tensor, same_on_batch: bool = False, **kwargs
 ) -> Dict[str, object]:
     """
     Apply tiling to batch of tensor images and extract a random patch .
 
-    Warning: unfolding image into patches consumes A LOT OF memory.
+    Warning: unfolding image into patches consumes A LOT of memory.
https://scikit-image.org/docs/dev/api/skimage.util.html#skimage.util.view_as_windows Parameters: @@ -167,6 +228,10 @@ def random_tile( ------ transformed tensor images of shape [B, C, H, W] """ + window_size: int = kwargs.get("parameters", {}).get("window_size") + if window_size is None: + window_size = 512 + patches: torch.Tensor = K.contrib.extract_tensor_patches( images, window_size, stride=window_size ) @@ -179,10 +244,7 @@ def random_tile( @register_augmentation(name="random_erase") def random_erase( - images: torch.Tensor, - scale: Tuple[float, float] = (0.02, 0.33), - ratio: Tuple[float, float] = (0.3, 3.3), - same_on_batch: bool = False, + images: torch.Tensor, same_on_batch: bool = False, **kwargs ) -> torch.Tensor: """ Randomly erase a batch of tensor images. @@ -196,6 +258,22 @@ def random_erase( ------ transformed tensor images of shape [B, C, H, W] """ + scale: Union[Tuple[float, float], float, int] = kwargs.get("parameters", {}).get( + "scale" + ) + ratio: Union[Tuple[float, float], float, int] = kwargs.get("parameters", {}).get( + "ratio" + ) + if scale is not None and ratio is not None: + if isinstance(scale, (float, int)): + scale = (scale, scale) + if isinstance(ratio, (float, int)): + ratio = (ratio, ratio) + same_on_batch = True + else: + scale = (0.02, 0.33) + ratio = (0.3, 3.3) + transform = K.augmentation.RandomErasing( scale, ratio, value=0.0, same_on_batch=same_on_batch, p=1.0 ) @@ -206,9 +284,8 @@ def random_erase( @register_augmentation(name="random_gaussian_noise") def random_gaussian_noise( images: torch.Tensor, - mean: float = 0.0, - std: float = 0.1, same_on_batch: bool = False, + **kwargs, ) -> torch.Tensor: """ Randomly erase a batch of tensor images. @@ -222,6 +299,14 @@ def random_gaussian_noise( ------ transformed tensor images of shape [B, C, H, W] """ + mean: float = kwargs.get("parameters", {}).get("mean") + std: float = kwargs.get("parameters", {}).get("std") + if mean is not None and std is not None: + same_on_batch = True + else: + mean = 0.0 + std = 0.1 + transform = K.augmentation.RandomGaussianNoise( mean, std, same_on_batch=same_on_batch, p=1.0 ) @@ -235,6 +320,7 @@ def random_gaussian_blur( kernel_sizes: Tuple[int, int] = (3, 27), sigmas: Tuple[float, float] = (1.0, 10.0), same_on_batch: bool = False, + **kwargs, ) -> torch.Tensor: """ Randomly blur a batch of tensor images. @@ -254,7 +340,15 @@ def get_random_kernel_size_and_sigma() -> Tuple[int, float]: sigma: float = random.uniform(sigmas[0], sigmas[1]) return kernel_size, sigma - kernel_size, sigma = get_random_kernel_size_and_sigma() + kernel_size: int = kwargs.get("parameters", {}).get("kernel_size") + sigma: float = kwargs.get("parameters", {}).get("sigma") + if kernel_size is not None and sigma is not None: + if kernel_size % 2 == 0: + raise ValueError(f"kernel_size must be an odd number. Got {kernel_size=}") + same_on_batch = True + else: + kernel_size, sigma = get_random_kernel_size_and_sigma() + transform = K.augmentation.RandomGaussianBlur( (kernel_size, kernel_size), (sigma, sigma), same_on_batch=same_on_batch, p=1.0 ) @@ -264,7 +358,7 @@ def get_random_kernel_size_and_sigma() -> Tuple[int, float]: @register_augmentation(name="random_sharpness") def random_sharpness( - images: torch.Tensor, sharpness: float = 0.5, same_on_batch: bool = False + images: torch.Tensor, same_on_batch: bool = False, **kwargs ) -> torch.Tensor: """ Randomly enhance sharpness a batch of tensor images. 
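The read-promote-pin sequence above repeats almost verbatim in each function: fetch an optional value from `parameters`, widen a scalar to a `(v, v)` range, and force `same_on_batch=True` so a user-supplied value applies identically across the batch. A hypothetical helper (not part of this change) that would factor the pattern out:

```python
from typing import Any, Dict, Optional, Tuple, Union

Number = Union[float, int]


def resolve_range(
    kwargs: Dict[str, Any],
    key: str,
    default: Tuple[Number, Number],
    same_on_batch: bool,
) -> Tuple[Tuple[Number, Number], bool]:
    """Read `key` from kwargs["parameters"]; widen scalars and pin the batch."""
    value: Optional[Union[Tuple[Number, Number], Number]] = kwargs.get(
        "parameters", {}
    ).get(key)
    if value is None:
        # Nothing supplied: keep the default random range and batch behavior
        return default, same_on_batch
    if isinstance(value, (float, int)):
        value = (value, value)
    # A user-supplied value applies to the whole batch
    return value, True
```

With it, `random_rotate` would reduce to `degree, same_on_batch = resolve_range(kwargs, "degree", (-20.0, 20.0), same_on_batch)`.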
@@ -278,6 +372,16 @@ def random_sharpness( ------ transformed tensor images of shape [B, C, H, W] """ + sharpness: Union[Tuple[float, float], float, int] = kwargs.get( + "parameters", {} + ).get("sharpness") + if sharpness is not None: + if isinstance(sharpness, (float, int)): + sharpness = (sharpness, sharpness) + same_on_batch = True + else: + sharpness = (0.3, 0.7) + transform = K.augmentation.RandomSharpness( sharpness, same_on_batch=same_on_batch, p=1.0 ) @@ -287,9 +391,7 @@ def random_sharpness( @register_augmentation(name="random_brightness") def random_brightness( - images: torch.Tensor, - brightness: Tuple[float, float] = (0.75, 1.5), - same_on_batch: bool = False, + images: torch.Tensor, same_on_batch: bool = False, **kwargs ) -> torch.Tensor: """ Adjust brightness of a batch of tensor images randomly. @@ -303,7 +405,16 @@ def random_brightness( ------ transformed tensor images of shape [B, C, H, W] """ - # Random brightness + brightness: Union[Tuple[float, float], float, int] = kwargs.get( + "parameters", {} + ).get("brightness") + if brightness is not None: + if isinstance(brightness, (float, int)): + brightness = (brightness, brightness) + same_on_batch = True + else: + brightness = (0.75, 1.5) + transform = K.augmentation.ColorJitter( brightness=brightness, same_on_batch=same_on_batch, p=1.0 ) @@ -313,9 +424,7 @@ def random_brightness( @register_augmentation(name="random_hue") def random_hue( - images: torch.Tensor, - hue: Tuple[float, float] = (-0.5, 0.5), - same_on_batch: bool = False, + images: torch.Tensor, same_on_batch: bool = False, **kwargs ) -> torch.Tensor: """ Adjust hue of a batch of tensor images randomly. @@ -329,6 +438,16 @@ def random_hue( ------ transformed tensor images of shape [B, C, H, W] """ + hue: Union[Tuple[float, float], float, int] = kwargs.get("parameters", {}).get( + "hue" + ) + if hue is not None: + if isinstance(hue, (float, int)): + hue = (hue, hue) + same_on_batch = True + else: + hue = (-0.5, 0.5) + transform = K.augmentation.ColorJitter(hue=hue, same_on_batch=same_on_batch, p=1.0) images_out: torch.Tensor = transform(images) return images_out @@ -337,8 +456,8 @@ def random_hue( @register_augmentation(name="random_saturation") def random_saturation( images: torch.Tensor, - saturation: Tuple[float, float] = (0.5, 1.5), same_on_batch: bool = False, + **kwargs, ) -> torch.Tensor: """ Adjust saturation of a batch of tensor images randomly. @@ -352,7 +471,16 @@ def random_saturation( ------ transformed tensor images of shape [B, C, H, W] """ - # Random brightness + saturation: Union[Tuple[float, float], float, int] = kwargs.get( + "parameters", {} + ).get("saturation") + if saturation is not None: + if isinstance(saturation, (float, int)): + saturation = (saturation, saturation) + same_on_batch = True + else: + saturation = (0.1, 2.0) + transform = K.augmentation.ColorJitter( saturation=saturation, same_on_batch=same_on_batch, p=1.0 ) @@ -362,9 +490,7 @@ def random_saturation( @register_augmentation(name="random_contrast") def random_contrast( - images: torch.Tensor, - contrast: Tuple[float, float] = (0.5, 1.5), - same_on_batch: bool = False, + images: torch.Tensor, same_on_batch: bool = False, **kwargs ) -> torch.Tensor: """ Adjust contrast of a batch of tensor images randomly. 
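Because a user-supplied value forces `same_on_batch=True`, the color-jitter augmentations become deterministic across a batch. A quick sanity check (a sketch, assuming the same registry import as above):

```python
import torch

import augmentation.augmentations_list  # noqa: F401  # Import to register all augmentations
from augmentation.registry import AUGMENTATIONS

batch = torch.rand(1, 3, 64, 64).repeat(8, 1, 1, 1)  # 8 copies of one image

out = AUGMENTATIONS["random_brightness"](batch, parameters={"brightness": 1.2})

# Every image received the same fixed brightness factor, so outputs match
assert torch.allclose(out[0], out[1])
```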
@@ -378,7 +504,16 @@ def random_contrast(
     ------
     transformed tensor images of shape [B, C, H, W]
     """
-    # Random brightness
+    contrast: Union[Tuple[float, float], float, int] = kwargs.get("parameters", {}).get(
+        "contrast"
+    )
+    if contrast is not None:
+        if isinstance(contrast, (float, int)):
+            contrast = (contrast, contrast)
+        same_on_batch = True
+    else:
+        contrast = (0.5, 1.5)
+
     transform = K.augmentation.ColorJitter(
         contrast=contrast, same_on_batch=same_on_batch, p=1.0
     )
@@ -388,10 +523,7 @@ def random_contrast(
 
 @register_augmentation(name="random_solarize")
 def random_solarize(
-    images: torch.Tensor,
-    thresholds: float = 0.1,
-    additions: float = 0.1,
-    same_on_batch: bool = False,
+    images: torch.Tensor, same_on_batch: bool = False, **kwargs
 ) -> torch.Tensor:
     """
     Adjust solarize of a batch of tensor images randomly.
@@ -405,9 +537,24 @@ def random_solarize(
     ------
     transformed tensor images of shape [B, C, H, W]
     """
-    # Random brightness
+    threshold: Union[Tuple[float, float], float, int] = kwargs.get(
+        "parameters", {}
+    ).get("threshold")
+    addition: Union[Tuple[float, float], float, int] = kwargs.get("parameters", {}).get(
+        "addition"
+    )
+    if threshold is not None and addition is not None:
+        if isinstance(threshold, (float, int)):
+            threshold = (threshold, threshold)
+        if isinstance(addition, (float, int)):
+            addition = (addition, addition)
+        same_on_batch = True
+    else:
+        threshold = 0.1
+        addition = 0.1
+
     transform = K.augmentation.RandomSolarize(
-        thresholds, additions, same_on_batch=same_on_batch, p=1.0
+        threshold, addition, same_on_batch=same_on_batch, p=1.0
     )
     images_out: torch.Tensor = transform(images)
     return images_out
@@ -415,7 +562,7 @@ def random_solarize(
 
 @register_augmentation(name="random_posterize")
 def random_posterize(
-    images: torch.Tensor, bits: int = 3, same_on_batch: bool = False
+    images: torch.Tensor, same_on_batch: bool = False, **kwargs
 ) -> torch.Tensor:
     """
     Adjust posterize of a batch of tensor images randomly.
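At the serving layer (see the `augmentation/deploy.py` changes below), `parameters` arrives keyed by augment code and is forwarded to whichever augmentation is randomly chosen from `codes`. A sketch of a request against the reworked endpoint, using the `requests` library with placeholder paths:

```python
import requests

payload = {
    "images_paths": ["/efs/sample/input/image1.png"],
    "output_folder": "/efs/sample/output",
    "codes": ["AUG-000", "AUG-016"],  # random_rotate, random_posterize
    "num_augments_per_image": 2,
    "parameters": {
        "AUG-000": {"degree": 15},  # fixed rotation for the whole batch
        "AUG-016": {"bit": 4},      # fixed posterize bit depth
    },
}
response = requests.post("http://0.0.0.0:8000/augmentation", json=payload)
print(response.json())  # {"images_paths": [...], "json_paths": [...]}
```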
@@ -429,16 +576,21 @@ def random_posterize(
     ------
     transformed tensor images of shape [B, C, H, W]
     """
-    # Random brightness
-    transform = K.augmentation.RandomPosterize(bits, same_on_batch=same_on_batch, p=1.0)
+    bit: Union[Tuple[int, int], int] = kwargs.get("parameters", {}).get("bit")
+    if bit is not None:
+        if isinstance(bit, int):
+            bit = (bit, bit)
+        same_on_batch = True
+    else:
+        bit = 3
+
+    transform = K.augmentation.RandomPosterize(bit, same_on_batch=same_on_batch, p=1.0)
     images_out: torch.Tensor = transform(images)
     return images_out
 
 
 @register_augmentation(name="super_resolution")
-def super_resolution(
-    images: torch.Tensor, factor: Tuple[float, float] = (0.25, 4.0)
-) -> torch.Tensor:
+def super_resolution(images: torch.Tensor, **kwargs) -> torch.Tensor:
     """
     Increase resolution of images randomly
 
@@ -452,8 +604,17 @@ def super_resolution(
     transformed tensor images of shape [B, C, H, W]
     ```
     """
-    B, C, H, W = images.shape
+    factor: Union[Tuple[float, float], float, int] = kwargs.get("parameters", {}).get(
+        "factor"
+    )
+    if factor is not None:
+        if isinstance(factor, (float, int)):
+            factor = (factor, factor)
+    else:
+        factor = (0.25, 4.0)
+    factor: float = random.uniform(factor[0], factor[1])
 
+    B, C, H, W = images.shape
     new_height: int = round(H * factor)
     new_width: int = round(W * factor)
 
diff --git a/augmentation/augmentor.py b/augmentation/augmentor.py
index 0852813..443773d 100644
--- a/augmentation/augmentor.py
+++ b/augmentation/augmentor.py
@@ -6,10 +6,10 @@
 import json
 import uuid
 from pprint import pformat
-import traceback
 import random
 import os
-from typing import List, Optional, Tuple
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
 
 import augmentation.augmentations_list  # Import to register all augmentations
 from augmentation.registry import AUGMENTATIONS, CodeToAugment
@@ -17,6 +17,8 @@
 
 
 class Augmentor:
+    SUPPORTED_EXTENSIONS: Tuple = (".png", ".jpg", ".jpeg")
+
     def __init__(self, use_gpu: bool = False):
         """
         Apply random augmentations on batch of images.
@@ -34,8 +36,8 @@ def process(
         input_image_paths: List[str],
         augment_codes: List[str],
         num_augments_per_image: int,
+        parameters: Dict[str, Dict[str, Any]],
         output_dir: str,
-        **kwargs,
     ) -> Tuple[List[str], List[str]]:
         """
         Apply augmentation on a list of images.
@@ -102,29 +104,50 @@ def process(
         ]
         ```
         """
+        print("*" * 100)
+
+        pid = os.getpid()
+        print(
+            f"[AUGMENTATION][pid {pid}] Found {len(input_image_paths)} images: {input_image_paths}"
+        )
+
+        # Skip running for unsupported extensions
+        for image_path in deepcopy(input_image_paths):
+            _, extension = os.path.splitext(image_path)
+            if extension.lower() not in Augmentor.SUPPORTED_EXTENSIONS:
+                print(
+                    f"[AUGMENTATION][pid {pid}] [WARNING] Only support these extensions: {Augmentor.SUPPORTED_EXTENSIONS}. "
+                    f"But got {extension=} in image {image_path}."
+                    " Skip this image."
+ ) + input_image_paths.remove(image_path) + + start_augmenting = time.time() if len(augment_codes) > 0: self.__check_valid_augment_codes(augment_codes) else: augment_codes: List[str] = list(CodeToAugment.keys()) + print( + f"[AUGMENTATION][pid {pid}] " + f"{ {augment_code: CodeToAugment[augment_code] for augment_code in augment_codes} }" + ) augment_code: str = random.choice(augment_codes) augment_name: str = CodeToAugment[augment_code] - print(f"{augment_code}: {augment_name}") - - print(f"Found {len(input_image_paths)} images.") - output_image_paths: List[str] = [] - output_json_paths: List[str] = [] - try: - output_image_paths, output_json_paths = self.__process_batch( - input_image_paths, - augment_name, - num_augments_per_image, - output_dir, - **kwargs, - ) - except Exception: - print(f"Error: {traceback.format_exc()}") + output_image_paths, output_json_paths = self.__process_batch( + input_image_paths, + augment_name, + num_augments_per_image, + parameters.get(augment_code, {}), + output_dir, + ) + + end_augmenting = time.time() + print( + f"[AUGMENTATION][pid {pid}] Done augmenting {len(input_image_paths)} images: " + f"{round(end_augmenting - start_augmenting, 4)} seconds" + ) return output_image_paths, output_json_paths def __process_batch( @@ -132,8 +155,8 @@ def __process_batch( image_paths: List[str], augment_name: str, num_augments_per_image: int, + parameters: Dict[str, Any], output_dir: str, - **kwargs, ) -> Tuple[List[str], List[str]]: """ Generate list of augmented images from an image path. @@ -180,25 +203,30 @@ def __process_batch( ] ``` """ + pid = os.getpid() + original_sizes: List[ Tuple[int, int] ] = [] # original height and widths of images images_tensor: List[torch.Tensor] = [] for image_path in image_paths: - start = time.time() + print(f"[AUGMENTATION][pid {pid}] {image_path} | ", end="") + start_read = time.time() image: np.ndarray = read_image(image_path) - end = time.time() - print(f"Read image {image_path}: {round(end - start, 2)} seconds") + end_read = time.time() + print(f"Read image: {round(end_read - start_read, 2)} seconds | ", end="") # Resize tensor images for faster processeing image_tensor: torch.Tensor = image_to_tensor(image).to(self.device) original_sizes.append(image_tensor.shape[-2:]) - start = time.time() + + start_resize = time.time() image_tensor: torch.Tensor = K.geometry.resize( image_tensor, size=(1024, 1024) ) - end = time.time() - print(f"Resize image: {round(end - start, 2)} seconds") + end_resize = time.time() + print(f"Resize image: {round(end_resize - start_resize, 2)} seconds") + images_tensor.append(image_tensor) # Stack multiple same images to form a batch @@ -208,22 +236,31 @@ def __process_batch( output_image_paths: List[str] = [] output_json_paths: List[str] = [] - # Augment batch + print( + f"[AUGMENTATION][pid {pid}] Augmenting batch of {len(images_tensor)} images: {parameters=}" + ) for _ in range(num_augments_per_image): + # Augment a batch of images start = time.time() - images_tensor_out = AUGMENTATIONS[augment_name](images_tensor) + images_tensor_out = AUGMENTATIONS[augment_name]( + images_tensor, parameters=parameters + ) end = time.time() print( - f"Generated {len(images_tensor)} images: {round(end - start, 2)} seconds" + f"[AUGMENTATION][pid {pid}] " + f"{augment_name=}: {round(end - start, 2)} seconds" ) # Save generated images for image_path, image_tensor, original_size in zip( image_paths, images_tensor_out, original_sizes ): - # Resize back to original size - height, width = original_size - image_tensor = 
K.geometry.resize(image_tensor, size=(height, width))
+                print(f"[AUGMENTATION][pid {pid}] {image_path} | ", end="")
+
+                # Resize images back to original size, EXCEPT for super_resolution
+                if augment_name != "super_resolution":
+                    height, width = original_size
+                    image_tensor = K.geometry.resize(image_tensor, size=(height, width))
 
                 image: np.ndarray = tensor_to_image(image_tensor)
                 name_without_ext, ext = os.path.splitext(os.path.basename(image_path))
@@ -245,15 +282,23 @@ def __process_batch(
 
         assert len(output_image_paths) == len(output_json_paths)
 
-        print("*" * 100)
         return output_image_paths, output_json_paths
 
     def __check_valid_augment_codes(self, augment_codes: List[str]) -> Optional[bool]:
+        pid = os.getpid()
+
+        # Map from an augment code to its augment name
+        supported_augment_codes: Dict[str, str] = {
+            augment_code: augment_name
+            for augment_code, augment_name in CodeToAugment.items()
+            if augment_name in AUGMENTATIONS.keys()
+        }
+
         for augment_code in augment_codes:
-            augment_name: str = CodeToAugment[augment_code]
-            if augment_name not in AUGMENTATIONS.keys():
+            if augment_code not in supported_augment_codes.keys():
                 message: str = (
-                    f"Only support these of augmentations: {pformat(CodeToAugment)}. "
+                    f"[AUGMENTATION][pid {pid}] "
+                    f"Only support these augmentations: {pformat(supported_augment_codes)}. "
                     f"Got {augment_code=}!"
                 )
                 print(message)
@@ -274,10 +319,15 @@ def __init_device(self, use_gpu: bool) -> torch.device:
         -------
         "cpu" or "cuda" device
         """
+        pid = os.getpid()
         if use_gpu and torch.cuda.is_available():
             device: torch.device = torch.device("cuda:0")
-            print(f"{use_gpu=} and cuda is available. Initialized {device}")
+            print(
+                f"[AUGMENTATION][pid {pid}] {use_gpu=} and cuda is available. Initialized {device}"
+            )
         else:
             device = torch.device("cpu")
-            print(f"{use_gpu=} and cuda not found. Initialized {device}")
+            print(
+                f"[AUGMENTATION][pid {pid}] {use_gpu=} and cuda not found. Initialized {device}"
+            )
 
         return device
diff --git a/augmentation/deploy.py b/augmentation/deploy.py
index 159ceab..3a197e8 100644
--- a/augmentation/deploy.py
+++ b/augmentation/deploy.py
@@ -1,34 +1,21 @@
 import ray
 from ray import serve
 from starlette.requests import Request
-from starlette.responses import Response
+from starlette.responses import JSONResponse
 
-import logging
-import sys
 import traceback
-import multiprocessing as mp
-from typing import List, Dict
+from typing import Any, List, Dict
 
 from augmentation.augmentor import Augmentor
-from utils import get_current_time
-
-
-CURRENT_TIME: str = get_current_time()
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
-    datefmt="%y-%b-%d %H:%M:%S",
-    handlers=[
-        logging.StreamHandler(sys.stdout),
-        logging.FileHandler(
-            f"logs/augmentaions_{CURRENT_TIME}.txt", mode="w", encoding="utf-8"
-        ),
-    ],
-)
-logger = logging.getLogger(__file__)
 
 
-class AugmentorDeployment:
+@serve.deployment(
+    route_prefix="/augmentation",
+    num_replicas=1,
+    max_concurrent_queries=100,
+    ray_actor_options={"num_cpus": 1, "num_gpus": 0},
+)
+class AugmentationDeployment:
     def __init__(self, use_gpu: bool = False):
         """
         Deploy and apply random augmentations on batch of images with Ray Serve.
@@ -43,7 +30,7 @@ def __init__(self, use_gpu: bool = False):
 
     async def __call__(self, request: Request) -> List[Dict[str, object]]:
         """
         Wrapper of `Augmentor.process` when called with HTTP request.
Parameters: ---------- @@ -115,15 +102,21 @@ async def __call__(self, request: Request) -> List[Dict[str, object]]: try: input_image_paths: str = data["images_paths"] output_dir: str = data["output_folder"] - augment_code: str = data["augment_code"] - num_augments_per_image: int = data["num_augments_per_image"] + augment_codes: List[str] = data["codes"] + num_augments_per_image: int = data.get("num_augments_per_image", 1) + parameters: Dict[str, Dict[str, Any]] = data.get("parameters", {}) output_image_paths, output_json_paths = self.augmentor.process( - input_image_paths, augment_code, num_augments_per_image, output_dir + input_image_paths, + augment_codes, + num_augments_per_image, + parameters, + output_dir, ) return {"images_paths": output_image_paths, "json_paths": output_json_paths} + except Exception: - return Response(status_code=500, content=traceback.format_exc()) + return JSONResponse(status_code=500, content=traceback.format_exc()) if __name__ == "__main__": @@ -132,13 +125,4 @@ async def __call__(self, request: Request) -> List[Dict[str, object]]: serve.start(detached=True, http_options={"host": "0.0.0.0", "port": 8000}) # Deploy - num_cpus: int = mp.cpu_count() - serve.deployment(AugmentorDeployment).options( - route_prefix="/augmentation", - num_replicas=2, - max_concurrent_queries=32, - ray_actor_options={"num_cpus": num_cpus, "num_gpus": 0}, - init_kwargs={ - "use_gpu": False, - }, - ).deploy() + AugmentationDeployment.deploy(use_gpu=False) diff --git a/augmentation/registry.py b/augmentation/registry.py index 4300d5d..9e9ad2d 100644 --- a/augmentation/registry.py +++ b/augmentation/registry.py @@ -25,8 +25,12 @@ def register_augmentation(name: str): - def wrapper(augmentation_class): - AUGMENTATIONS[name] = augmentation_class - return augmentation_class + def decorator(augmentation_function): + AUGMENTATIONS[name] = augmentation_function - return wrapper + def wrapper(*args, **kwargs): + return augmentation_function(*args, **kwargs) + + return wrapper + + return decorator diff --git a/deploy.py b/deploy.py index 7a508d2..cb1f701 100644 --- a/deploy.py +++ b/deploy.py @@ -1,145 +1,9 @@ import ray from ray import serve -from starlette.requests import Request -from starlette.responses import Response - -import traceback import multiprocessing as mp -from typing import List, Dict - -from augmentation.augmentor import Augmentor -from preprocessing.preprocessor import Preprocessor - - -class Deployment: - def __init__(self, use_gpu: bool = False): - """ - Deploy and apply random augmentations on batch of images with Ray Serve. - - Parameters: - ----------- - use_gpu: - Whether to perform augmentations on gpu or not. Default: False - """ - self.use_gpu: bool = use_gpu - self.augmentor = Augmentor(use_gpu) - self.preprocessor = Preprocessor(use_gpu) - - async def __call__(self, request: Request) -> List[Dict[str, object]]: - """ - Wrapper of `Augmentor.process` when called with HTTP request. - - Parameters: - ---------- - request: HTTP POST request. The request MUST contain these keys: - - - "images_folder": path to folder containing input images. - - - "output_folder": path to foler containing output images. 
- - - "augment_code": code of augmentation, must be one of: - - - "AUG-000": "random_rotate" - - "AUG-001": "random_scale" - - "AUG-002": "random_translate" - - "AUG-003": "random_horizontal_flip" - - "AUG-004": "random_vertical_flip" - - "AUG-005": "random_crop" - - "AUG-006": "random_tile" - - "AUG-007": "random_erase" - - "AUG-008": "random_gaussian_noise" - - "AUG-009": "random_gaussian_blur" - - "AUG-010": "random_sharpness" - - "AUG-011": "random_brightness" - - "AUG-012": "random_hue" - - "AUG-013": "random_saturation" - - "AUG-014": "random_contrast" - - "AUG-015": "random_solarize" - - "AUG-016": "random_posterize" - - "AUG-017": "super_resolution" - - - "num_augments_per_image": number of augmentations generated for each image. - - Return: - ------ - List of dictionaries. - Each dict contains output images's paths and corresponding parameters. - Example: - ``` - [ - # Augmentation result of first image - { - "name": "random_scale", # name of augmentation - "images": [ - "/efs/sample/output/image-93daf30e-d68d-414d-af47-c5692190955e.png", - "/efs/sample/output/image-93daf30e-d68d-414d-af47-c5692190955e.png", - ], - "parameters": { - "scale": [1.1, 0.9] # assume num_augment_per_image = 3 - } - }, - - # Augmentation result of second image - { - "name": "random_scale", - "images": [ - "/efs/sample/output/image-dd004211-2c77-4d7a-873f-61fa3a176c87.png", - "/efs/sample/output/image-d59d7e15-63a3-4ef4-a5a9-6f43d69dbd19.png", - ], - "parameters": { - "scale": [0.82, 0.97] - } - }, - ... - ] - ``` - """ - data: Dict[str, object] = await request.json() - - try: - input_image_paths: str = data["images_paths"] - output_dir: str = data["output_folder"] - # Codes can be a list of augment codes or preprocessing code - codes: List[str] = data["codes"] - type_: str = data["type"] - if type_ not in ("augmentation", "preprocessing"): - return Response( - status_code=500, - content=f"Field 'type' must be 'preprocessing' or 'augmentation'. Got type={type_}", - ) - - # Handle augmentations - if type_ == "augmentation": - augment_codes: List[str] = list( - filter(lambda code: "AUG" in code, codes) - ) - if "num_augments_per_image" not in data.keys(): - raise KeyError( - f"Got {augment_codes=}. " - "Missing 'num_augments_per_image' in request body!" 
- ) - num_augments_per_image: int = data["num_augments_per_image"] - output_image_paths, output_json_paths = self.augmentor.process( - input_image_paths, augment_codes, num_augments_per_image, output_dir - ) - return { - "images_paths": output_image_paths, - "json_paths": output_json_paths, - } - # Handle preprocessing - elif type_ == "preprocessing": - preprocess_codes: List[str] = list( - filter(lambda code: "PRE" in code, codes) - ) - output_image_paths = self.preprocessor.process( - input_image_paths, output_dir, preprocess_codes - ) - return { - "images_paths": output_image_paths, - } - except Exception: - return Response(status_code=500, content=traceback.format_exc()) +from augmentation.deploy import AugmentationDeployment +from preprocessing.deploy import PreprocessingDeployment if __name__ == "__main__": @@ -147,11 +11,21 @@ async def __call__(self, request: Request) -> List[Dict[str, object]]: ray.init(address="auto", namespace="serve") serve.start(detached=True, http_options={"host": "0.0.0.0", "port": 8000}) - # Deploy - num_cpus: int = mp.cpu_count() - serve.deployment(Deployment).options( - route_prefix="/ai", - num_replicas=1, - max_concurrent_queries=32, - ray_actor_options={"num_cpus": num_cpus, "num_gpus": 0}, - ).deploy(use_gpu=False) + deploy_augmentation: bool = True + deploy_preprocessing: bool = True + + if deploy_augmentation: + AugmentationDeployment.options( + route_prefix="/augmentation", + num_replicas=1, + max_concurrent_queries=100, + ray_actor_options={"num_cpus": 1, "num_gpus": 0}, + ).deploy(use_gpu=False) + + if deploy_preprocessing: + PreprocessingDeployment.options( + route_prefix="/preprocessing", + num_replicas=mp.cpu_count() - 1 if mp.cpu_count() > 1 else 1, + max_concurrent_queries=100, + ray_actor_options={"num_cpus": 1, "num_gpus": 0}, + ).deploy(use_gpu=False) diff --git a/healthcheck/__init__.py b/healthcheck/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/healthcheck/healthcheck_list.py b/healthcheck/healthcheck_list.py deleted file mode 100644 index 66854d0..0000000 --- a/healthcheck/healthcheck_list.py +++ /dev/null @@ -1,94 +0,0 @@ -import numpy as np -import cv2 -import boto3 - -import os -from typing import Tuple - -from healthcheck.registry import register_healthcheck -from healthcheck.utils import S3 -from preprocessing.utils import ( - calculate_contrast_score, - calculate_signal_to_noise, - calculate_sharpness_score, - calculate_luminance, -) - - -@register_healthcheck(name="signal_to_noise") -def check_signal_to_noise_RGB( - image: np.ndarray, **kwargs -) -> Tuple[float, float, float]: - R, G, B = cv2.split(image) - snr_R: float = calculate_signal_to_noise(R) - snr_G: float = calculate_signal_to_noise(G) - snr_B: float = calculate_signal_to_noise(B) - return (snr_R, snr_G, snr_B) - - -@register_healthcheck(name="sharpness") -def check_sharpness(image: np.ndarray, **kwargs) -> float: - sharpness: float = calculate_sharpness_score(image) - return sharpness - - -@register_healthcheck(name="contrast") -def check_contrast(image: np.ndarray, **kwargs) -> float: - contrast: float = calculate_contrast_score(image) - return contrast - - -@register_healthcheck(name="luminance") -def check_luminance(image: np.ndarray, **kwargs) -> int: - luminance: float = calculate_luminance(image) - return luminance - - -@register_healthcheck(name="file_size") -def check_file_size(image: np.ndarray, **kwargs) -> int: - image_path: str = kwargs["image_path"] - if "s3://" in image_path: # path is an S3 URI - bucket, key_name = 
S3.split_s3_path(image_path) - file_size_in_bytes: int = ( - boto3.resource("s3").Bucket(bucket).Object(key_name).content_length - ) - else: # local path - file_size_in_bytes = os.path.getsize(image_path) - - file_size_in_mb: int = file_size_in_bytes / 1024 / 1024 - return file_size_in_mb - - -@register_healthcheck(name="height_width_aspect_ratio") -def check_image_height_and_width(image: np.ndarray, **kwargs) -> Tuple[int, int, float]: - height, width = image.shape[:2] - aspect_ratio: float = height / width - return (height, width, aspect_ratio) - - -@register_healthcheck(name="mean_red_channel") -def check_mean_red_channel(image: np.ndarray, **kwargs) -> int: - R, _, _ = cv2.split(image) - mean = np.mean(R) - return mean - - -@register_healthcheck(name="mean_green_channel") -def check_mean_green_channel(image: np.ndarray, **kwargs) -> int: - _, G, _ = cv2.split(image) - mean = np.mean(G) - return mean - - -@register_healthcheck(name="mean_blue_channel") -def check_mean_blue_channel(image: np.ndarray, **kwargs) -> int: - _, _, B = cv2.split(image) - mean = np.mean(B) - return mean - - -@register_healthcheck(name="extension") -def check_extension(image: np.ndarray, **kwargs) -> str: - image_path: str = kwargs["image_path"] - _, extension = os.path.splitext(image_path) - return extension.lower() diff --git a/healthcheck/registry.py b/healthcheck/registry.py deleted file mode 100644 index c479db3..0000000 --- a/healthcheck/registry.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Callable, Dict - - -HEALTHCHECK: Dict[str, Callable] = {} - - -def register_healthcheck(name: str): - def wrapper(healthcheck_function): - HEALTHCHECK[name] = healthcheck_function - return healthcheck_function - - return wrapper diff --git a/healthcheck/run.py b/healthcheck/run.py deleted file mode 100644 index 3b8c40e..0000000 --- a/healthcheck/run.py +++ /dev/null @@ -1,172 +0,0 @@ -import ray -import pandas as pd - -import os -import time -import datetime -import uuid -import argparse -import traceback -from typing import List, Dict, Union - -import healthcheck_list # Import to register all preprocessing -from healthcheck.registry import HEALTHCHECK -from healthcheck.utils import read_image - - -ray.init() - - -@ray.remote -def healthcheck(image_path: str) -> Dict[str, Union[int, float, str, None]]: - result: Dict[str, Union[int, float, str, None]] = { - "file_name": None, - "signal_to_noise_red_channel": None, - "signal_to_noise_green_channel": None, - "signal_to_noise_blue_channel": None, - "sharpness": None, - "contrast": None, - "luminance": None, - "file_size": None, - "height": None, - "width": None, - "aspect_ratio": None, - "mean_red_channel": None, - "mean_green_channel": None, - "mean_blue_channel": None, - "extension": None, - } - - print(f"Processing {image_path}") - try: - start = time.time() - image = read_image(image_path) - - # Check file name - result["file_name"] = os.path.basename(image_path) - - # Check signal to noise of each channel - snr_R, snr_G, snr_B = HEALTHCHECK["signal_to_noise"]( - image, image_path=image_path - ) - result["signal_to_noise_red_channel"] = snr_R - result["signal_to_noise_green_channel"] = snr_G - result["signal_to_noise_blue_channel"] = snr_B - - # Check sharpeness - sharpness = HEALTHCHECK["sharpness"](image, image_path=image_path) - result["sharpness"] = sharpness - - # Check contrast - contrast = HEALTHCHECK["contrast"](image, image_path=image_path) - result["contrast"] = contrast - - # Check luminance - luminance = HEALTHCHECK["luminance"](image, 
image_path=image_path) - result["luminance"] = luminance - - # Check file size - file_size_in_mb = HEALTHCHECK["file_size"](image, image_path=image_path) - result["file_size"] = file_size_in_mb - - # Check height, width and aspect ratio - height, width, aspect_ratio = HEALTHCHECK["height_width_aspect_ratio"]( - image, image_path=image_path - ) - result["height"] = height - result["width"] = width - result["aspect_ratio"] = aspect_ratio - - # Check mean of each channel - mean_red_channel = HEALTHCHECK["mean_red_channel"](image, image_path=image_path) - mean_green_channel = HEALTHCHECK["mean_green_channel"]( - image, image_path=image_path - ) - mean_blue_channel = HEALTHCHECK["mean_blue_channel"]( - image, image_path=image_path - ) - result["mean_red_channel"] = mean_red_channel - result["mean_green_channel"] = mean_green_channel - result["mean_blue_channel"] = mean_blue_channel - - # Check extension - extension = HEALTHCHECK["extension"](image, image_path=image_path) - result["extension"] = extension - - end = time.time() - print(result) - print(f"Done processing {image_path}: {round(end - start, 4)} seconds") - return result - - except Exception: - print(result) - print(f"Error processing {image_path}: {traceback.format_exc()} !!!") - return result - - -def main(image_paths: List[str], output_dir: str) -> Union[str, bool]: - output: Dict[str, List[Union[int, float, str, None]]] = { - "file_name": [], - "signal_to_noise_red_channel": [], - "signal_to_noise_green_channel": [], - "signal_to_noise_blue_channel": [], - "sharpness": [], - "contrast": [], - "luminance": [], - "file_size": [], - "height": [], - "width": [], - "aspect_ratio": [], - "mean_red_channel": [], - "mean_green_channel": [], - "mean_blue_channel": [], - "extension": [], - } - - try: - # Multiprocessing heathccheck - results: List[Dict[str, List[Union[int, float, str, None]]]] = [ - healthcheck.remote(image_path) for image_path in image_paths - ] - results = ray.get(results) - - # Combine results into a dict of list - for result in results: - for key in output.keys(): - output[key].append(result[key]) - - # Convert data to csv and save to disk - csv_name: str = ( - "healthcheck_" - + datetime.datetime.utcnow().strftime("%Y-%m-%d_%H-%M-%S") - + "_" - + uuid.uuid4().hex - + ".csv" - ) - csv_path: str = os.path.join(output_dir, csv_name) - pd.DataFrame(output).to_csv(csv_path, index=None) - print(f"Output csv path: {csv_path}") - - except Exception: - print("Error !!!") - print(traceback.format_exc()) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Check health of dataset") - parser.add_argument( - "--images", - type=str, - nargs="*", - default=[], - help="List of image path (can be S3 URI or local path)", - ) - parser.add_argument( - "--output_dir", type=str, required=True, help="Directory to save csv output" - ) - args = vars(parser.parse_args()) - - image_paths: List[str] = args["images"] - output_dir: str = args["output_dir"] - - main(image_paths, output_dir) diff --git a/healthcheck/utils.py b/healthcheck/utils.py deleted file mode 100644 index 051d07e..0000000 --- a/healthcheck/utils.py +++ /dev/null @@ -1,67 +0,0 @@ -import cv2 -import boto3 -import numpy as np -from datetime import datetime -from typing import List, Dict, Tuple, Union - - -def is_gray_image(image: np.ndarray) -> bool: - if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1): - return True - return False - - -def get_current_time() -> str: - return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - - -def 
read_image(image_path: str) -> np.ndarray:
-    if "s3://" in image_path:  # image in S3 bucket
-        image: np.ndarray = S3.read_image(image_path)
-    else:  # image in local machine
-        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    return np.ascontiguousarray(image)
-
-
-def save_image(image_path: str, image: np.ndarray) -> None:
-    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-    cv2.imwrite(image_path, image)
-
-
-class S3:
-    s3 = boto3.client("s3")
-
-    @staticmethod
-    def split_s3_path(path: str) -> Tuple[str, str]:
-        """
-        Split s3 path into bucket and file name
-
-        >>> split_s3_uri('s3://bucket/folder/image.png')
-        ('bucket', 'folder/image.png')
-        """
-        # s3_path, file_name = os.path.split
-        bucket, _, file_name = path.replace("s3://", "").partition("/")
-        return bucket, file_name
-
-    def read_image(uri: str) -> np.ndarray:
-        try:
-            bucket, file_name = S3.split_s3_path(uri)
-            s3_response_object = S3.s3.get_object(Bucket=bucket, Key=file_name)
-
-            array: np.ndarray = np.frombuffer(
-                s3_response_object["Body"].read(), np.uint8
-            )
-            image = cv2.imdecode(array, cv2.IMREAD_COLOR)
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            return image
-
-        except S3.s3.exceptions.NoSuchKey:
-            message: str = f"File not found. [bucket={bucket},key={file_name}]"
-            print(message)
-            raise Exception(message)
-
-        except S3.s3.exceptions.NoSuchBucket:
-            message: str = f"Bucket not found. [bucket={bucket},key={file_name}]"
-            print(message)
-            raise Exception(message)
diff --git a/preprocessing/deploy.py b/preprocessing/deploy.py
index c178fb1..3c1fccc 100644
--- a/preprocessing/deploy.py
+++ b/preprocessing/deploy.py
@@ -1,34 +1,21 @@
 import ray
 from ray import serve
 from starlette.requests import Request
-from starlette.responses import Response
+from starlette.responses import JSONResponse
 
-import logging
-import sys
 import traceback
-import multiprocessing as mp
 from typing import List, Dict
 
 from preprocessing.preprocessor import Preprocessor
-from utils import get_current_time
-
-
-CURRENT_TIME: str = get_current_time()
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
-    datefmt="%y-%b-%d %H:%M:%S",
-    handlers=[
-        logging.StreamHandler(sys.stdout),
-        logging.FileHandler(
-            f"AI/logs/preprocessing_{CURRENT_TIME}.txt", mode="w", encoding="utf-8"
-        ),
-    ],
-)
-logger = logging.getLogger(__file__)
 
 
-class PreprocessorDeployment:
+@serve.deployment(
+    route_prefix="/preprocessing",
+    num_replicas=1,
+    max_concurrent_queries=100,
+    ray_actor_options={"num_cpus": 1, "num_gpus": 0},
+)
+class PreprocessingDeployment:
     def __init__(self, use_gpu: bool = False):
         """
         Deploy and apply random augmentations on batch of images with Ray Serve.
@@ -42,23 +29,93 @@ def __init__(self, use_gpu: bool = False):
         self.preprocessor = Preprocessor(use_gpu)
 
     async def __call__(self, request: Request) -> List[Dict[str, object]]:
+        """
+        Wrapper of `Preprocessor.process` when called with HTTP request.
+
+        Parameters:
+        ----------
+        request: HTTP POST request. The request MUST contain these keys:
+
+            - "images_paths": list of paths to input images.
+
+            - "output_folder": path to folder containing output images.
+
+            - "codes": list of preprocessing codes, each being a key of
+              `CodeToPreprocess` (for example "PRE-000", "PRE-001", ...).
+              If empty, preprocessing methods are chosen automatically.
+
+            - "reference_images": dict mapping a preprocessing code to the path
+              of its reference image. If empty, reference images are found
+              automatically.
+
+        Return:
+        ------
+        Dictionary containing the output images' paths.
+        Example:
+        ```
+        {
+            "images_paths": [
+                "/efs/sample/output/preprocessed_image1.png",
+                "/efs/sample/output/preprocessed_image2.png",
+            ]
+        }
+        ```
+        """
         data: Dict[str, object] = await request.json()
 
         try:
             input_image_paths: str = data["images_paths"]
             output_dir: str = data["output_folder"]
-            preprocess_codes: List[str] = data.get("preprocess_code", None)
+            preprocess_codes: List[str] = data["codes"]
+            reference_paths_dict: Dict[str, str] = data["reference_images"]
 
-            output_image_paths = self.preprocessor.process(
-                input_image_paths,
-                output_dir,
-                preprocess_codes,
+            output_image_paths: List[str] = self.preprocessor.process(
+                input_image_paths, output_dir, preprocess_codes, reference_paths_dict
             )
             return {
                 "images_paths": output_image_paths,
             }
         except Exception:
-            return Response(status_code=500, content=traceback.format_exc())
+            return JSONResponse(
+                status_code=500,
+                content={"images_paths": [], "error": traceback.format_exc()},
+            )
 
 
 if __name__ == "__main__":
@@ -67,13 +124,4 @@ async def __call__(self, request: Request) -> List[Dict[str, object]]:
     serve.start(detached=True, http_options={"host": "0.0.0.0", "port": 8000})
 
     # Deploy
-    num_cpus: int = mp.cpu_count()
-    serve.deployment(PreprocessorDeployment).options(
-        route_prefix="/preprocessing",
-        num_replicas=2,
-        max_concurrent_queries=32,
-        ray_actor_options={"num_cpus": num_cpus, "num_gpus": 0},
-        init_kwargs={
-            "use_gpu": False,
-        },
-    ).deploy()
+    PreprocessingDeployment.deploy(use_gpu=False)
diff --git a/preprocessing/preprocessing_list.py b/preprocessing/preprocessing_list.py
index 14976ad..0f9f09e 100644
--- a/preprocessing/preprocessing_list.py
+++ b/preprocessing/preprocessing_list.py
@@ -4,7 +4,7 @@
 from utils import resize_image
 from preprocessing.base import BasePreprocessing
 from preprocessing.registry import register_preprocessing
-from preprocessing.utils import (
+from preprocessing.preprocessing_utils import (
     calculate_contrast_score,
     calculate_sharpness_score,
 )
@@ -104,7 +104,7 @@ def process(
         from skimage.color import rgb2gray
 
         is_normalized: bool = True
-        image_out: np.ndarray = rgb2gray(image)
+        image_out: np.ndarray = (rgb2gray(image) * 255).astype(np.uint8)
        return image_out
 
@@ -136,11 +136,11 @@ def process(
 
         is_normalized: bool = False
 
-        reference_image_hsv = rgb2hsv(reference_image)
-        reference_brightness: float = reference_image_hsv[2].var()
+        reference_image_hsv: np.ndarray = rgb2hsv(reference_image)
+        reference_brightness: float = reference_image_hsv[:, :, 2].var()
 
-        image_hsv = rgb2hsv(image)
-        brightness: float = image_hsv[2].var()
+        image_hsv: np.ndarray = rgb2hsv(image)
+        brightness: float = image_hsv[:, :, 2].var()
 
         if abs(brightness - reference_brightness) / reference_brightness > 0.75:
             matched_hsv = match_histograms(
@@ -182,11 +182,11 @@ def process(
 
         is_normalized: bool = False
 
-        reference_image_hsv = rgb2hsv(reference_image)
-        reference_hue: float = reference_image_hsv[0].var()
+        reference_image_hsv: np.ndarray = rgb2hsv(reference_image)
+        reference_hue: float = reference_image_hsv[:, :, 0].var()
 
-        image_hsv = rgb2hsv(image)
-        hue: float = image_hsv[0].var()
+        image_hsv: np.ndarray = rgb2hsv(image)
+        hue: float = image_hsv[:, :, 0].var()
 
         if abs(hue - reference_hue) / reference_hue > 0.75:
             matched_hsv = match_histograms(
@@ -229,11 +229,11 @@ def process(
 
         is_normalized: bool = False
 
-        reference_image_hsv = rgb2hsv(reference_image)
-        reference_saturation: float = np.mean(reference_image_hsv[1])
+        reference_image_hsv: np.ndarray = rgb2hsv(reference_image)
+        reference_saturation: float = np.mean(reference_image_hsv[:, :, 1])
 
-        image_hsv = rgb2hsv(image)
-        saturation: float = np.mean(image_hsv[1])
+        image_hsv: np.ndarray = rgb2hsv(image)
+        saturation: float = np.mean(image_hsv[:, :, 1])
 
         if abs(saturation - reference_saturation) / reference_saturation > 0.75:
             matched_hsv = match_histograms(
@@ -364,7 +364,7 @@ def process(
 @register_preprocessing(name="high_resolution")
 class IncreaseResolution(BasePreprocessing):
     def __init__(self):
-        self.scale_factor = 2.0
+        pass
 
     def process(
         self, image: np.ndarray, reference_image: np.ndarray, **kwargs
@@ -385,11 +385,8 @@ def process(
         normalized tensor images of shape [B, C, H, W]
         """
         is_normalized: bool = False
-        H, W, _ = image.shape
-
-        new_height = int(H * self.scale_factor)
-        new_width = int(W * self.scale_factor)
-        image_out: np.ndarray = resize_image(image, (new_height, new_width))
+        reference_height, reference_width = reference_image.shape[:2]
+        image_out: np.ndarray = resize_image(image, (reference_height, reference_width))
 
         if image_out.shape != image.shape:
             is_normalized = True
diff --git a/preprocessing/preprocessing_utils.py b/preprocessing/preprocessing_utils.py
new file mode 100644
index 0000000..9214e42
--- /dev/null
+++ b/preprocessing/preprocessing_utils.py
@@ -0,0 +1,71 @@
+import numpy as np
+from skimage.color import rgb2ycbcr
+from skimage.color import rgb2lab
+import cv2
+from typing import List, Union
+
+
+def calculate_contrast_score(image: np.ndarray) -> float:
+    """
+    https://en.wikipedia.org/wiki/Contrast_(vision)#Michelson_contrast
+    """
+    YCrCb: np.ndarray = rgb2ycbcr(image)
+    Y = YCrCb[:, :, 0]
+
+    y_min = np.min(Y)
+    y_max = np.max(Y)
+
+    # compute contrast
+    contrast = (y_max - y_min) / (y_max + y_min)
+    return float(contrast)
+
+
+def calculate_sharpness_score(image: np.ndarray) -> float:
+    sharpness: float = cv2.Laplacian(image, cv2.CV_16S).std()
+    return sharpness
+
+
+def calculate_signal_to_noise(image: np.ndarray, axis=None, ddof=0) -> float:
+    """
+    The signal-to-noise ratio of the input data.
+    Returns the signal-to-noise ratio of an image, here defined as the mean
+    divided by the standard deviation.
+
+    Parameters
+    ----------
+    image : array_like
+        An array_like object containing the sample data.
+
+    axis : int or None, optional
+        Axis along which to operate. If None, compute over
+        the whole image.
+
+    ddof : int, optional
+        Degrees of freedom correction for standard deviation. Default is 0.
+
+    Returns
+    -------
+    s2n : float
+        The mean to standard deviation ratio(s) along `axis`, or 0 where the
+        standard deviation is 0.
+    """
+    image = np.asanyarray(image)
+    mean = image.mean(axis)
+    std = image.std(axis=axis, ddof=ddof)
+    signal_to_noise: np.ndarray = np.where(std == 0, 0, mean / std)
+    return float(signal_to_noise)
+
+
+def calculate_luminance(image: np.ndarray) -> float:
+    lab: np.ndarray = rgb2lab(image)
+    luminance: float = cv2.Laplacian(lab, cv2.CV_16S).std()
+    return luminance
+
+
+def get_index_of_median_value(array: Union[List[float], np.ndarray]) -> int:
+    """
+    Find index of the median value in a list or 1-D array
+    """
+    index: int = np.argsort(array)[len(array) // 2]
+    return index
diff --git a/preprocessing/preprocessor.py b/preprocessing/preprocessor.py
index 29d2e4b..4b56b70 100644
--- a/preprocessing/preprocessor.py
+++ b/preprocessing/preprocessor.py
@@ -1,20 +1,27 @@
 import numpy as np
 import torch
-import multiprocessing as mp
 
 import time
 import os
+from copy import deepcopy
 import traceback
-from functools import partial
-from typing import Callable, List, Optional
+from typing import List, Optional, Dict, Tuple
 
 import preprocessing.preprocessing_list  # Import to register all preprocessing
 from preprocessing.registry import PREPROCESSING, CodeToPreprocess
 from utils import read_image, resize_image, save_image
-from preprocessing.utils import calculate_signal_to_noise
+from preprocessing.references import (
+    find_reference_brightness_image,
+    find_reference_hue_image,
+    find_reference_saturation_image,
+    find_reference_signal_to_noise_image,
+    find_reference_high_resolution_image,
+)
 
 
 class Preprocessor:
+    SUPPORTED_EXTENSIONS: Tuple = (".png", ".jpg", ".jpeg")
+
     def __init__(self, use_gpu: bool = False):
         """
         Apply random augmentations on batch of images.
@@ -32,126 +39,209 @@ def process(
         input_image_paths: List[str],
         output_dir: str,
         preprocess_codes: List[str],
+        reference_paths_dict: Dict[str, str],
         **kwargs,
     ) -> List[str]:
+        print("*" * 100)
 
-        print(f"Found {len(input_image_paths)} images.")
-        # If preprocess codes are not given, run all preprocessing methods
+        pid: int = os.getpid()
+        print(f"[PREPROCESSING][pid {pid}] Found {len(input_image_paths)} images")
+
+        # Skip running for unsupported extensions
+        for image_path in deepcopy(input_image_paths):
+            _, extension = os.path.splitext(image_path)
+            if extension.lower() not in Preprocessor.SUPPORTED_EXTENSIONS:
+                print(
+                    f"[PREPROCESSING][pid {pid}] [WARNING] Only support these extensions: {Preprocessor.SUPPORTED_EXTENSIONS}. "
+                    f"But got {extension=} in image {image_path}."
+                    " Skip this image."
+ ) + input_image_paths.remove(image_path) + + start_preprocess = time.time() + + # In mode 'auto' (preprocess codes are not given): + # Run all preprocessing methods, except grayscale (PRE-001) if len(preprocess_codes) == 0: - preprocess_codes = list(CodeToPreprocess.keys()) + preprocess_codes = [ + code for code in CodeToPreprocess.keys() if code != "PRE-001" + ] + else: + # In mode 'expert' (some preprocess codes are given): + # If grayscale (PRE-001) in preprocess_codes, + # remove all other codes, except for auto orientation (PRE-000) and super-resolution (PRE-009) + if "PRE-001" in preprocess_codes: + preprocess_codes = [ + code + for code in preprocess_codes + if code in ("PRE-000", "PRE-001", "PRE-009") + ] + print( + f"[PREPROCESSING][pid {pid}] " f"Preprocess codes: { {code: CodeToPreprocess[code] for code in preprocess_codes} }" ) - # Find reference image - print("Finding reference image...") - start = time.time() - reference_image_path: str = self.__find_reference_image(input_image_paths) - end = time.time() + # If reference images are not already given for each preprocessing method, + # find those reference images + if len(reference_paths_dict.keys()) == 0: + # Find reference image + print( + f"[PREPROCESSING][pid {pid}] Reference images are not given. Finding reference image..." + ) + reference_paths_dict: Dict[str, str] = self.get_reference_image_paths( + input_image_paths, preprocess_codes + ) + else: + print( + f"[PREPROCESSING][pid {pid}] " + f"Reference images are already given: {reference_paths_dict}" + ) + + # Load and resize all reference images beforehand + reference_images_dict: Dict[str, np.ndarray] = { + preprocess_code: resize_image(read_image(reference_image_path), 1024) + for preprocess_code, reference_image_path in reference_paths_dict.items() + } + + # Process each input image + output_image_paths: List[str] = [] + for input_image_path in input_image_paths: + try: + output_image_path: Optional[str] = self.__process_one_image( + input_image_path, + output_dir, + preprocess_codes, + reference_images_dict, + ) + output_image_paths.append(output_image_path) + # If there are some errors (input image not found, weird image...) then skip the image + except Exception: + print(f"[PREPROCESSING][pid {pid}] ERROR: {traceback.format_exc()}") + continue + + end_preprocess = time.time() print( - f"Found reference image {reference_image_path}: {round(end - start, 4)} seconds." 
+            f"[PREPROCESSING][pid {pid}] Done preprocessing {len(input_image_paths)} images: "
+            f"{round(end_preprocess - start_preprocess, 4)} seconds"
         )
+        return output_image_paths

-        reference_image: np.ndarray = read_image(reference_image_path)
-        # Resize reference image for faster processeing
-        reference_image = resize_image(reference_image, size=1024)
+    def get_reference_image_paths(
+        self,
+        input_image_paths: List[str],
+        preprocess_codes: List[str],
+    ) -> Dict[str, str]:

-        # Multiprocessing pool
-        num_processes = (
-            len(input_image_paths)
-            if len(input_image_paths) <= mp.cpu_count()
-            else mp.cpu_count()
-        )
-        pool = mp.Pool(num_processes)
-
-        # Run process_one_image on each image in a separate process
-        process_one_image: Callable = partial(
-            self._process_one_image,
-            reference_image=reference_image,
-            output_dir=output_dir,
-            preprocess_codes=preprocess_codes,
-        )
-        output_image_paths: List[str] = pool.map(process_one_image, input_image_paths)
+        pid: int = os.getpid()

-        # Remove error image paths
-        output_image_paths = [
-            image_path for image_path in output_image_paths if image_path is not None
+        # Mapping from a preprocess_code to its corresponding reference image path
+        reference_paths_dict: Dict[str, str] = {}
+        # Read all input images beforehand
+        input_images: List[np.ndarray] = [
+            read_image(input_image_path) for input_image_path in input_image_paths
         ]
-        return output_image_paths

-    def _process_one_image(
+        # Find reference image for each preprocessing code
+        for code in preprocess_codes:
+            preprocess_name: str = CodeToPreprocess[code]
+            print(
+                f"[PREPROCESSING][pid {pid}] Reference image of {code} ({preprocess_name}): ",
+                end="",
+            )
+            start = time.time()
+            reference_image_path: str = self.__find_reference_image_path(
+                input_images, input_image_paths, preprocess_name
+            )
+            reference_paths_dict[code] = reference_image_path
+            end = time.time()
+            print(
+                f"{os.path.basename(reference_image_path)} ({round(end - start, 4)} seconds)"
+            )
+
+        return reference_paths_dict
+
+    def __process_one_image(
         self,
         input_image_path: str,
-        reference_image: np.ndarray,
         output_dir: str,
         preprocess_codes: List[str],
+        reference_images_dict: Dict[str, np.ndarray],
     ) -> Optional[str]:
         """
         Apply preprocessing to input image given a reference image.
         Return saved output image path or None.
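+
+        Note (illustrative): `reference_images_dict` maps each preprocess code to
+        an already-resized reference image, e.g. {"PRE-002": <H x W x 3 ndarray>},
+        so no reference image is read from disk again inside this method.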
""" pid: int = os.getpid() - print(f"[pid {pid}] Preprocessing {input_image_path}") - try: - start = time.time() - image: np.ndarray = read_image(input_image_path) - H, W, _ = image.shape - # Resize images for faster processeing - image = resize_image(image, size=1024) - - for preprocess_code in preprocess_codes: - try: - preprocess_name: str = CodeToPreprocess[preprocess_code] - print(f"[pid {pid}] {preprocess_name}:", end=" ") - image, is_normalized = PREPROCESSING[preprocess_name]().process( - image, reference_image, image_path=input_image_path - ) - print(is_normalized) - - except Exception: - print(f"[pid {pid}] ERROR: {preprocess_name}") - print(f"[pid {pid}] {traceback.format_exc()}") - continue - - # Resize back to original size - image = resize_image(image, size=H if H < W else W) - end = time.time() - print(f"[pid {pid}] Done preprocessing: {round(end - start, 4)} seconds") + print(f"[PREPROCESSING][pid {pid}] Preprocessing {input_image_path}") + start = time.time() - # Save output image - image_name: str = os.path.basename(input_image_path) - output_image_path: str = os.path.join( - output_dir, f"preprocessed_{image_name}" - ) - start = time.time() - save_image(output_image_path, image) - end = time.time() - print( - f"[pid {pid}] Save image {output_image_path}: {round(end - start, 2)} seconds" - ) - return output_image_path + image: np.ndarray = read_image(input_image_path) + H, W, _ = image.shape + # Resize images for faster processeing + image = resize_image(image, size=1024) - # If there are some errors (input image not found, weird image...) then return None - except Exception: - print(f"[pid {pid}] ERROR: {traceback.format_exc()}") - return None + for preprocess_code in preprocess_codes: + try: + preprocess_name: str = CodeToPreprocess[preprocess_code] + reference_image: np.ndarray = reference_images_dict[preprocess_code] + image, is_normalized = PREPROCESSING[preprocess_name]().process( + image, reference_image, image_path=input_image_path + ) + print(f"[PREPROCESSING][pid {pid}] {preprocess_name}:", is_normalized) - def __find_reference_image(self, input_image_paths: List[str]) -> str: - images: List[np.ndarray] = [ - read_image(image_path) for image_path in input_image_paths - ] + except Exception: + print(f"[PREPROCESSING][pid {pid}] ERROR: {preprocess_name}") + print(f"[PREPROCESSING][pid {pid}] {traceback.format_exc()}") + continue - signal_to_noise_ratios: List[float] = [ - calculate_signal_to_noise(image) for image in images - ] + # Resize back to original size + image = resize_image(image, size=H if H < W else W) + end = time.time() + print( + f"[PREPROCESSING][pid {pid}] " + f"Done preprocessing: {round(end - start, 4)} seconds" + ) - idxs_sorted: List[int] = sorted( - range(len(signal_to_noise_ratios)), - key=lambda i: signal_to_noise_ratios[i], + # Save output image + image_name: str = os.path.basename(input_image_path) + output_image_path: str = os.path.join(output_dir, f"preprocessed_{image_name}") + start = time.time() + save_image(output_image_path, image) + end = time.time() + print( + f"[PREPROCESSING][pid {pid}] " + f"Save image {output_image_path}: {round(end - start, 2)} seconds" ) - idx: int = idxs_sorted[0] - reference_image_path: str = input_image_paths[idx] + return output_image_path + + def __find_reference_image_path( + self, + input_images: List[np.ndarray], + input_image_paths: List[str], + preprocess_name: str, + ) -> str: + if preprocess_name == "normalize_brightness": + reference_image_path: str = find_reference_brightness_image( + 
+        elif preprocess_name == "normalize_hue":
+            reference_image_path = find_reference_hue_image(
+                input_images, input_image_paths
+            )
+        elif preprocess_name == "normalize_saturation":
+            reference_image_path = find_reference_saturation_image(
+                input_images, input_image_paths
+            )
+        elif preprocess_name == "high_resolution":
+            reference_image_path = find_reference_high_resolution_image(
+                input_images, input_image_paths
+            )
+        else:
+            reference_image_path = find_reference_signal_to_noise_image(
+                input_images, input_image_paths
+            )
         return reference_image_path

     @staticmethod
@@ -168,10 +258,15 @@ def __init_device(use_gpu: bool) -> torch.device:
         -------
         "cpu" or "cuda" device
         """
+        pid = os.getpid()
         if use_gpu and torch.cuda.is_available():
             device: torch.device = torch.device("cuda:0")
-            print(f"{use_gpu=} and cuda is available. Initialized {device}")
+            print(
+                f"[PREPROCESSING][pid {pid}] {use_gpu=} and cuda is available. Initialized {device}"
+            )
         else:
             device = torch.device("cpu")
-            print(f"{use_gpu=} and cuda not found. Initialized {device}")
+            print(
+                f"[PREPROCESSING][pid {pid}] {use_gpu=} and cuda not found. Initialized {device}"
+            )
         return device
diff --git a/preprocessing/references.py b/preprocessing/references.py
new file mode 100644
index 0000000..246e80e
--- /dev/null
+++ b/preprocessing/references.py
@@ -0,0 +1,104 @@
+import numpy as np
+from skimage.color import rgb2hsv
+from typing import List
+
+from preprocessing.preprocessing_utils import (
+    calculate_signal_to_noise,
+    get_index_of_median_value,
+)
+
+
+def find_reference_brightness_image(
+    input_images: List[np.ndarray], input_image_paths: List[str]
+) -> str:
+    hsv_images: List[np.ndarray] = [rgb2hsv(image) for image in input_images]
+
+    # List of the images' brightness variance (variance of the V channel)
+    brightness_ls: List[float] = [hsv_image[:, :, 2].var() for hsv_image in hsv_images]
+    median_idx: int = get_index_of_median_value(brightness_ls)
+    # Reference image is the one that has the median brightness variance
+    reference_image_path: str = input_image_paths[median_idx]
+    return reference_image_path
+
+
+def find_reference_hue_image(
+    input_images: List[np.ndarray], input_image_paths: List[str]
+) -> str:
+    hsv_images: List[np.ndarray] = [rgb2hsv(image) for image in input_images]
+    # List of the images' hue variance (variance of the H channel)
+    hue_ls: List[float] = [hsv_image[:, :, 0].var() for hsv_image in hsv_images]
+    median_idx: int = get_index_of_median_value(hue_ls)
+    # Reference image is the one that has the median hue variance
+    reference_image_path: str = input_image_paths[median_idx]
+    return reference_image_path
+
+
+def find_reference_saturation_image(
+    input_images: List[np.ndarray], input_image_paths: List[str]
+) -> str:
+    hsv_images: List[np.ndarray] = [rgb2hsv(image) for image in input_images]
+    # List of the images' saturation variance (variance of the S channel)
+    saturation_ls: List[float] = [hsv_image[:, :, 1].var() for hsv_image in hsv_images]
+    median_idx: int = get_index_of_median_value(saturation_ls)
+    # Reference image is the one that has the median saturation variance
+    reference_image_path: str = input_image_paths[median_idx]
+    return reference_image_path
+
+
+def find_reference_signal_to_noise_image(
+    input_images: List[np.ndarray], input_image_paths: List[str]
+) -> str:
+    signal_to_noise_ratios: List[float] = [
+        calculate_signal_to_noise(image) for image in input_images
+    ]
+
+    # Sort ascending: the reference is the image with the lowest signal-to-noise ratio
+    idxs_sorted: List[int] = sorted(
+        range(len(signal_to_noise_ratios)),
+        key=lambda i: signal_to_noise_ratios[i],
+    )
+    idx: int = idxs_sorted[0]
+    reference_image_path: str = input_image_paths[idx]
+    return reference_image_path
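+# For intuition: calculate_signal_to_noise returns image.mean() / image.std()
+# (0 when the std is 0), so the image with the lowest mean-to-deviation ratio
+# is picked. E.g. for a toy 2x2 grayscale image [[0, 255], [0, 255]],
+# mean = std = 127.5, giving a ratio of exactly 1.0.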
+
+
+def find_reference_high_resolution_image(
+    input_images: List[np.ndarray], input_image_paths: List[str]
+) -> str:
+
+    heights: List[int] = [image.shape[0] for image in input_images]
+    widths: List[int] = [image.shape[1] for image in input_images]
+    aspect_ratios: List[float] = [
+        height / width for height, width in zip(heights, widths)
+    ]
+
+    # Divide aspect ratios into multiple bins
+    bins_count, bins_values = np.histogram(
+        aspect_ratios, np.arange(start=0.1, stop=10, step=0.2)
+    )
+    # Find idx of the bin that occurs most
+    most_common_bin_idx: int = np.argmax(bins_count)
+    # Count of the most common bin
+    most_common_bin_count: int = bins_count[most_common_bin_idx]
+    # If there is only 1 bin that occurs most
+    if bins_count.tolist().count(most_common_bin_count) == 1:
+        most_common_aspect_ratio_idx: int = [
+            idx
+            for idx, aspect_ratio in enumerate(aspect_ratios)
+            if bins_values[most_common_bin_idx]
+            <= aspect_ratio
+            <= bins_values[most_common_bin_idx + 1]
+        ][0]
+        # Reference image is the first one whose aspect ratio falls into the most common bin
+        reference_image_path: str = input_image_paths[most_common_aspect_ratio_idx]
+        return reference_image_path
+    # If there are multiple bins with the same count, fall back to the largest image
+    else:
+        max_height: int = max(heights)
+        max_width: int = max(widths)
+        if max_height > max_width:
+            max_height_idx: int = np.argmax(heights)
+            reference_image_path: str = input_image_paths[max_height_idx]
+        else:
+            max_width_idx: int = np.argmax(widths)
+            reference_image_path = input_image_paths[max_width_idx]
+        return reference_image_path
diff --git a/preprocessing/registry.py b/preprocessing/registry.py
index ad158d3..d36617d 100644
--- a/preprocessing/registry.py
+++ b/preprocessing/registry.py
@@ -5,7 +5,7 @@ PREPROCESSING: Dict[str, BasePreprocessing] = {}
 CodeToPreprocess: Dict[str, str] = {
     "PRE-000": "auto_orientation",
-    # "PRE-001": "grayscale",
+    "PRE-001": "grayscale",
     "PRE-002": "normalize_brightness",
     "PRE-003": "normalize_hue",
     "PRE-004": "normalize_saturation",
@@ -13,7 +13,7 @@
     "PRE-006": "normalize_contrast",
     # "PRE-007": "normalize_affine",
     "PRE-008": "equalize_histogram",
-    # "PRE-009": "high_resolution",
+    "PRE-009": "high_resolution",
     # "PRE-010": "detect_outlier",
     # "PRE-011": "tilling",
     # "PRE-012": "cropping",
diff --git a/preprocessing/utils.py b/preprocessing/utils.py
deleted file mode 100644
index 34c858f..0000000
--- a/preprocessing/utils.py
+++ /dev/null
@@ -1,144 +0,0 @@
-import torch
-import numpy as np
-from skimage.color import rgb2ycbcr
-from skimage.color import rgb2lab
-import kornia as K
-import cv2
-
-
-def calculate_contrast_score(image: np.ndarray) -> float:
-    """
-    https://en.wikipedia.org/wiki/Contrast_(vision)#Michelson_contrast
-    """
-    YCrCb: torch.Tensor = rgb2ycbcr(image)
-    Y = YCrCb[0, :, :]
-
-    min = np.min(Y)
-    max = np.max(Y)
-
-    # compute contrast
-    contrast = (max - min) / (max + min)
-    return float(contrast)
-
-
-def calculate_sharpness_score(image: np.ndarray) -> float:
-    sharpness: float = cv2.Laplacian(image, cv2.CV_16S).std()
-    return sharpness
-
-
-def calculate_signal_to_noise(image: np.ndarray, axis=None, ddof=0) -> float:
-    """
-    The signal-to-noise ratio of the input data.
-    Returns the signal-to-noise ratio of an image, here defined as the mean
-    divided by the standard deviation.
-
-    Parameters
-    ----------
-    a : array_like
-        An array_like object containing the sample data.
-
-    axis : int or None, optional
-        Axis along which to operate. If None, compute over
-        the whole image.
-
-    ddof : int, optional
-        Degrees of freedom correction for standard deviation. Default is 0.
-
-    Returns
-    -------
-    s2n : ndarray
-        The mean to standard deviation ratio(s) along `axis`, or 0 where the
-        standard deviation is 0.
-    """
-    image = np.asanyarray(image)
-    mean = image.mean(axis)
-    std = image.std(axis=axis, ddof=ddof)
-    signal_to_noise: np.ndarray = np.where(std == 0, 0, mean / std)
-    return float(signal_to_noise)
-
-
-def calculate_luminance(image: np.ndarray) -> float:
-    lab: float = rgb2lab(image)
-    luminance: int = cv2.Laplacian(lab, cv2.CV_64F).std()
-    return luminance
-
-
-class AdaptiveGammaCorrection:
-    def process(image: torch.Tensor) -> torch.Tensor:
-        """
-        Apply adaptive gamma correction on a tensor image.
-
-        Parameters:
-        ----------
-        image: tensor image of shape [C, H, W]
-
-        Return
-        ------
-        normalized tensor images of shape [C, H, W]
-        """
-        C, H, W = image.shape
-        YCrCb: torch.Tensor = K.color.rgb_to_ycbcr(image.unsqueeze(dim=0))[
-            0
-        ]  # shape: [C, H, W]
-        YCrCb = (YCrCb * 255).type(torch.int32)
-        Y = YCrCb[0, :, :]
-
-        # Threshold to determine whether image is bright or dimmed
-        threshold: float = 0.3
-        expected_global_avg_intensity: int = 112
-        mean_intensity: float = torch.sum(Y / (H * W)).item()
-        t: float = (
-            mean_intensity - expected_global_avg_intensity
-        ) / expected_global_avg_intensity
-
-        if t <= threshold:
-            result: torch.Tensor = AdaptiveGammaCorrection.__process_dimmed(Y)
-        else:
-            result: torch.Tensor = AdaptiveGammaCorrection.__process_bright(Y)
-
-        YCrCb[0, :, :] = result
-        YCrCb = (YCrCb / 255).astype(float).unqueeze(dim=0)
-        image_out = K.color.ycbcr_to_rgb(YCrCb)
-        return image_out.squeeze(dim=0)
-
-    @staticmethod
-    def __process_bright(image: torch.Tensor):
-        img_negative = 255 - image
-        out = AdaptiveGammaCorrection.__correct_gamma(
-            img_negative, a=0.25, truncated_cdf=False
-        )
-        out = 255 - out
-        return out
-
-    @staticmethod
-    def __process_dimmed(image: torch.Tensor):
-        out = AdaptiveGammaCorrection.__correct_gamma(image, a=0.75, truncated_cdf=True)
-        return out
-
-    @staticmethod
-    def __correct_gamma(
-        self, image: torch.Tensor, a: float = 0.25, truncated_cdf: bool = False
-    ) -> torch.Tensor:
-        H, W = image.shape
-        hist, bins = torch.histogram(image.ravel(), bins=256, range=(0, 256))
-        proba_normalized = hist / hist.sum()
-
-        unique_intensity = torch.unique(image)
-        proba_min = proba_normalized.min()
-        proba_max = proba_normalized.max()
-
-        pn_temp = (proba_normalized - proba_min) / (proba_max - proba_min)
-        pn_temp[pn_temp > 0] = proba_max * (pn_temp[pn_temp > 0] ** a)
-        pn_temp[pn_temp < 0] = proba_max * (-((-pn_temp[pn_temp < 0]) ** a))
-        prob_normalized_wd = pn_temp / pn_temp.sum()  # normalize to [0,1]
-        cdf_prob_normalized_wd = prob_normalized_wd.cumsum()
-
-        if truncated_cdf:
-            inverse_cdf = torch.maximum(0.5, 1 - cdf_prob_normalized_wd)
-        else:
-            inverse_cdf = 1 - cdf_prob_normalized_wd
-
-        image_new = image.clone()
-        for i in unique_intensity:
-            image_new[image == i] = torch.round(255 * (i / 255) ** inverse_cdf[i])
-        return image_new
diff --git a/requirements-gpu.txt b/requirements-gpu.txt
index dc0bc67..23b6361 100644
--- a/requirements-gpu.txt
+++ b/requirements-gpu.txt
@@ -8,7 +8,6 @@ opencv-python==4.5.4.58
 Pillow==9.0.1
 scikit-image==0.18.3
 scipy==1.7.2
-pandas==1.4.1
 numpy
 awscli
 boto3
diff --git a/utils.py b/utils.py
index 77abf59..b63f7a9 100644
--- a/utils.py
+++ b/utils.py
@@ -5,6 +5,7 @@
 import boto3
 from PIL import Image
+import os
 import base64
 from io import BytesIO
 import random
@@ -47,7 +48,7 @@ def save_image(image_path: str, image: np.ndarray) -> None:
 def resize_image(image: np.ndarray, size: Union[Tuple[int, int], int]) -> np.ndarray:
-    height, width, _ = image.shape
+    height, width = image.shape[:2]
     if isinstance(size, int):
         if height < width:
             new_height = size
@@ -63,7 +64,12 @@ def resize_image(image: np.ndarray, size: Union[Tuple[int, int], int]) -> np.nda


 class S3Downloader:
-    s3 = boto3.client("s3")
+    s3 = boto3.client(
+        "s3",
+        aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
+        aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
+        region_name=os.getenv("REGION_NAME"),
+    )

     @staticmethod
     def split_s3_path(path: str) -> Tuple[str, str]:
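
Usage sketch (illustrative only; the file paths below are hypothetical, but the
call matches the reworked Preprocessor.process signature in this diff):

    from preprocessing.preprocessor import Preprocessor

    preprocessor = Preprocessor(use_gpu=False)
    output_paths = preprocessor.process(
        input_image_paths=["data/a.jpg", "data/b.png"],  # hypothetical paths
        output_dir="outputs",
        preprocess_codes=[],      # empty -> 'auto' mode: every code except PRE-001 (grayscale)
        reference_paths_dict={},  # empty -> reference images are computed per code
    )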