Skip to content

Commit

Permalink
Add flag to do test time augmentation to improve prediction robustness (
Browse files Browse the repository at this point in the history
#98)

* Add rotate and flip test time augmentation

Accumulate logits, with maximum confidence output being the one used to predict the output

* Add test time augmentation interfaces to CLI and lib functions

* Update docs
  • Loading branch information
tayden authored Jan 23, 2024
1 parent 49e67e9 commit 5120cf2
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ $ kom find-kelp --help
--rgbi --rgb Use RGB and NIR bands for classification. Assumes RGBI ordering. [default: rgb]
-b INTEGER GDAL-style band re-ordering flag. Defaults to RGB or RGBI order. To e.g., reorder a BGRI image at runtime, pass flags `-b 3 -b 2 -b 1 -b 4`. [default: None]
--gpu --no-gpu Enable or disable GPU, if available. [default: gpu]
--tta --no-tta Use test time augmentation to improve accuracy at the cost of processing time. [default: no-tta]
--help -h Show this message and exit.
```

Expand Down Expand Up @@ -110,6 +111,7 @@ $ kom find-mussels --help
--crop-size INTEGER The data window size to run through the segmentation model. [default: 1024]
-b INTEGER GDAL-style band re-ordering flag. Defaults to RGB order. To e.g., reorder a BGR image at runtime, pass flags `-b 3 -b 2 -b 1`. [default: None]
--gpu --no-gpu Enable or disable GPU, if available. [default: gpu]
--tta --no-tta Use test time augmentation to improve accuracy at the cost of processing time. [default: no-tta]
--help -h Show this message and exit.
```

Expand Down
16 changes: 14 additions & 2 deletions kelp_o_matic/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@ def find_kelp(
use_gpu: bool = typer.Option(
True, "--gpu/--no-gpu", help="Enable or disable GPU, if available."
),
use_tta: bool = typer.Option(
False,
"--tta/--no-tta",
help="Use test time augmentation to improve accuracy at the cost of "
"processing time.",
),
):
"""
Detect kelp in image at path SOURCE and output the resulting classification raster
to file at path DEST.
"""
find_kelp_(source, dest, species, crop_size, use_nir, band_order, use_gpu)
find_kelp_(source, dest, species, crop_size, use_nir, band_order, use_gpu, use_tta)


@cli.command()
Expand All @@ -64,12 +70,18 @@ def find_mussels(
use_gpu: bool = typer.Option(
True, "--gpu/--no-gpu", help="Enable or disable GPU, if available."
),
use_tta: bool = typer.Option(
False,
"--tta/--no-tta",
help="Use test time augmentation to improve accuracy at the cost of "
"processing time.",
),
):
"""
Detect mussels in image at path SOURCE and output the resulting classification
raster to file at path DEST.
"""
find_mussels_(source, dest, crop_size, band_order, use_gpu)
find_mussels_(source, dest, crop_size, band_order, use_gpu, use_tta)


def version_callback(value: bool) -> None:
Expand Down
18 changes: 16 additions & 2 deletions kelp_o_matic/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def find_kelp(
use_nir: bool = False,
band_order: Optional[list[int]] = None,
use_gpu: bool = True,
test_time_augmentation: bool = False,
):
"""
Detect kelp in image at path `source` and output the resulting classification raster
Expand All @@ -86,6 +87,7 @@ def find_kelp(
band_order: GDAL-style band re-ordering. Defaults to RGB or RGBI order.
e.g. to reorder a BGRI image at runtime, pass `[3,2,1,4]`.
use_gpu: Disable Cuda GPU usage and run on CPU only.
test_time_augmentation: Use test time augmentation to improve model accuracy.
"""
if not band_order:
band_order = [1, 2, 3]
Expand All @@ -104,7 +106,12 @@ def find_kelp(
else:
model = KelpRGBPresenceSegmentationModel(use_gpu=use_gpu)
RichSegmentationManager(
model, Path(source), Path(dest), band_order=band_order, crop_size=crop_size
model,
Path(source),
Path(dest),
band_order=band_order,
crop_size=crop_size,
test_time_augmentation=test_time_augmentation,
)()


Expand All @@ -114,6 +121,7 @@ def find_mussels(
crop_size: int = 1024,
band_order: Optional[list[int]] = None,
use_gpu: bool = True,
test_time_augmentation: bool = False,
):
"""
Detect mussels in image at path `source` and output the resulting classification
Expand All @@ -126,6 +134,7 @@ def find_mussels(
band_order: GDAL-style band re-ordering flag. Defaults to RGB order.
e.g. to reorder a BGR image at runtime, pass `[3,2,1]`.
use_gpu: Disable Cuda GPU usage and run on CPU only.
test_time_augmentation: Use test time augmentation to improve model accuracy.
"""
if not band_order:
band_order = [1, 2, 3]
Expand All @@ -134,5 +143,10 @@ def find_mussels(
_validate_paths(Path(source), Path(dest))
model = MusselRGBPresenceSegmentationModel(use_gpu=use_gpu)
RichSegmentationManager(
model, Path(source), Path(dest), band_order=band_order, crop_size=crop_size
model,
Path(source),
Path(dest),
band_order=band_order,
crop_size=crop_size,
test_time_augmentation=test_time_augmentation,
)()
15 changes: 15 additions & 0 deletions kelp_o_matic/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
output_path: Union[str, Path],
band_order: tuple[int] = (1, 2, 3),
crop_size: int = 1024,
test_time_augmentation: bool = False,
):
"""Create the segmentation object.
Expand All @@ -35,10 +36,12 @@ def __init__(
Defaults to [1, 2, 3] for RGB ordering.
crop_size: The size of image crop to classify iteratively until the entire
image is classified.
test_time_augmentation: Use test time augmentation to improve accuracy.
"""
self.model = model
self.band_order = band_order
self.crop_size = crop_size
self.tta = test_time_augmentation
self.input_path = str(Path(input_path).expanduser().resolve())
self.output_path = str(Path(output_path).expanduser().resolve())

Expand Down Expand Up @@ -84,6 +87,18 @@ def __call__(self):
)
logits = self.model(crop.unsqueeze(0))[0]

if self.tta:
for k in range(1, 4):
aug_crop = torch.rot90(crop, k=k, dims=(1, 2))
aug_logits = self.model(aug_crop.unsqueeze(0))[0]
unaug_logits = torch.rot90(aug_logits, k=-k, dims=(1, 2))
logits = torch.maximum(logits, unaug_logits)
for d in [1, 2]:
aug_crop = torch.flip(crop, dims=(d,))
aug_logits = self.model(aug_crop.unsqueeze(0))[0]
unaug_logits = torch.flip(aug_logits, dims=(d,))
logits = torch.maximum(logits, unaug_logits)

logits = self.kernel(
logits,
top=self.reader.is_top_window(read_window),
Expand Down

0 comments on commit 5120cf2

Please sign in to comment.