From 5120cf28aeac3e4fdce01a9f36ca4ec8fe450e7a Mon Sep 17 00:00:00 2001 From: Taylor Denouden Date: Tue, 23 Jan 2024 13:15:59 -0800 Subject: [PATCH] Add flag to do test time augmentation to improve prediction robustness (#98) * Add rotate and flip test time augmentation Accumulate logits, with maximum confidence output being the one used to predict the output * Add test time augmentation interfaces to CLI and lib functions * Update docs --- docs/cli.md | 2 ++ kelp_o_matic/cli.py | 16 ++++++++++++++-- kelp_o_matic/lib.py | 18 ++++++++++++++++-- kelp_o_matic/managers.py | 15 +++++++++++++++ 4 files changed, 47 insertions(+), 4 deletions(-) diff --git a/docs/cli.md b/docs/cli.md index dab2155..de5beea 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -39,6 +39,7 @@ $ kom find-kelp --help --rgbi --rgb Use RGB and NIR bands for classification. Assumes RGBI ordering. [default: rgb] -b INTEGER GDAL-style band re-ordering flag. Defaults to RGB or RGBI order. To e.g., reorder a BGRI image at runtime, pass flags `-b 3 -b 2 -b 1 -b 4`. [default: None] --gpu --no-gpu Enable or disable GPU, if available. [default: gpu] + --tta --no-tta Use test time augmentation to improve accuracy at the cost of processing time. [default: no-tta] --help -h Show this message and exit. ``` @@ -110,6 +111,7 @@ $ kom find-mussels --help --crop-size INTEGER The data window size to run through the segmentation model. [default: 1024] -b INTEGER GDAL-style band re-ordering flag. Defaults to RGB order. To e.g., reorder a BGR image at runtime, pass flags `-b 3 -b 2 -b 1`. [default: None] --gpu --no-gpu Enable or disable GPU, if available. [default: gpu] + --tta --no-tta Use test time augmentation to improve accuracy at the cost of processing time. [default: no-tta] --help -h Show this message and exit. ``` diff --git a/kelp_o_matic/cli.py b/kelp_o_matic/cli.py index 0638906..a020c01 100644 --- a/kelp_o_matic/cli.py +++ b/kelp_o_matic/cli.py @@ -39,12 +39,18 @@ def find_kelp( use_gpu: bool = typer.Option( True, "--gpu/--no-gpu", help="Enable or disable GPU, if available." ), + use_tta: bool = typer.Option( + False, + "--tta/--no-tta", + help="Use test time augmentation to improve accuracy at the cost of " + "processing time.", + ), ): """ Detect kelp in image at path SOURCE and output the resulting classification raster to file at path DEST. """ - find_kelp_(source, dest, species, crop_size, use_nir, band_order, use_gpu) + find_kelp_(source, dest, species, crop_size, use_nir, band_order, use_gpu, use_tta) @cli.command() @@ -64,12 +70,18 @@ def find_mussels( use_gpu: bool = typer.Option( True, "--gpu/--no-gpu", help="Enable or disable GPU, if available." ), + use_tta: bool = typer.Option( + False, + "--tta/--no-tta", + help="Use test time augmentation to improve accuracy at the cost of " + "processing time.", + ), ): """ Detect mussels in image at path SOURCE and output the resulting classification raster to file at path DEST. """ - find_mussels_(source, dest, crop_size, band_order, use_gpu) + find_mussels_(source, dest, crop_size, band_order, use_gpu, use_tta) def version_callback(value: bool) -> None: diff --git a/kelp_o_matic/lib.py b/kelp_o_matic/lib.py index 98d1950..8050584 100644 --- a/kelp_o_matic/lib.py +++ b/kelp_o_matic/lib.py @@ -72,6 +72,7 @@ def find_kelp( use_nir: bool = False, band_order: Optional[list[int]] = None, use_gpu: bool = True, + test_time_augmentation: bool = False, ): """ Detect kelp in image at path `source` and output the resulting classification raster @@ -86,6 +87,7 @@ def find_kelp( band_order: GDAL-style band re-ordering. Defaults to RGB or RGBI order. e.g. to reorder a BGRI image at runtime, pass `[3,2,1,4]`. use_gpu: Disable Cuda GPU usage and run on CPU only. + test_time_augmentation: Use test time augmentation to improve model accuracy. """ if not band_order: band_order = [1, 2, 3] @@ -104,7 +106,12 @@ def find_kelp( else: model = KelpRGBPresenceSegmentationModel(use_gpu=use_gpu) RichSegmentationManager( - model, Path(source), Path(dest), band_order=band_order, crop_size=crop_size + model, + Path(source), + Path(dest), + band_order=band_order, + crop_size=crop_size, + test_time_augmentation=test_time_augmentation, )() @@ -114,6 +121,7 @@ def find_mussels( crop_size: int = 1024, band_order: Optional[list[int]] = None, use_gpu: bool = True, + test_time_augmentation: bool = False, ): """ Detect mussels in image at path `source` and output the resulting classification @@ -126,6 +134,7 @@ def find_mussels( band_order: GDAL-style band re-ordering flag. Defaults to RGB order. e.g. to reorder a BGR image at runtime, pass `[3,2,1]`. use_gpu: Disable Cuda GPU usage and run on CPU only. + test_time_augmentation: Use test time augmentation to improve model accuracy. """ if not band_order: band_order = [1, 2, 3] @@ -134,5 +143,10 @@ def find_mussels( _validate_paths(Path(source), Path(dest)) model = MusselRGBPresenceSegmentationModel(use_gpu=use_gpu) RichSegmentationManager( - model, Path(source), Path(dest), band_order=band_order, crop_size=crop_size + model, + Path(source), + Path(dest), + band_order=band_order, + crop_size=crop_size, + test_time_augmentation=test_time_augmentation, )() diff --git a/kelp_o_matic/managers.py b/kelp_o_matic/managers.py index ddd24b7..5f19d98 100644 --- a/kelp_o_matic/managers.py +++ b/kelp_o_matic/managers.py @@ -23,6 +23,7 @@ def __init__( output_path: Union[str, Path], band_order: tuple[int] = (1, 2, 3), crop_size: int = 1024, + test_time_augmentation: bool = False, ): """Create the segmentation object. @@ -35,10 +36,12 @@ def __init__( Defaults to [1, 2, 3] for RGB ordering. crop_size: The size of image crop to classify iteratively until the entire image is classified. + test_time_augmentation: Use test time augmentation to improve accuracy. """ self.model = model self.band_order = band_order self.crop_size = crop_size + self.tta = test_time_augmentation self.input_path = str(Path(input_path).expanduser().resolve()) self.output_path = str(Path(output_path).expanduser().resolve()) @@ -84,6 +87,18 @@ def __call__(self): ) logits = self.model(crop.unsqueeze(0))[0] + if self.tta: + for k in range(1, 4): + aug_crop = torch.rot90(crop, k=k, dims=(1, 2)) + aug_logits = self.model(aug_crop.unsqueeze(0))[0] + unaug_logits = torch.rot90(aug_logits, k=-k, dims=(1, 2)) + logits = torch.maximum(logits, unaug_logits) + for d in [1, 2]: + aug_crop = torch.flip(crop, dims=(d,)) + aug_logits = self.model(aug_crop.unsqueeze(0))[0] + unaug_logits = torch.flip(aug_logits, dims=(d,)) + logits = torch.maximum(logits, unaug_logits) + logits = self.kernel( logits, top=self.reader.is_top_window(read_window),