Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flag to do test time augmentation to improve prediction robustness #98

Merged
merged 3 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ $ kom find-kelp --help
--rgbi --rgb Use RGB and NIR bands for classification. Assumes RGBI ordering. [default: rgb]
-b INTEGER GDAL-style band re-ordering flag. Defaults to RGB or RGBI order. To e.g., reorder a BGRI image at runtime, pass flags `-b 3 -b 2 -b 1 -b 4`. [default: None]
--gpu --no-gpu Enable or disable GPU, if available. [default: gpu]
--tta --no-tta Use test time augmentation to improve accuracy at the cost of processing time. [default: no-tta]
--help -h Show this message and exit.
```

Expand Down Expand Up @@ -110,6 +111,7 @@ $ kom find-mussels --help
--crop-size INTEGER The data window size to run through the segmentation model. [default: 1024]
-b INTEGER GDAL-style band re-ordering flag. Defaults to RGB order. To e.g., reorder a BGR image at runtime, pass flags `-b 3 -b 2 -b 1`. [default: None]
--gpu --no-gpu Enable or disable GPU, if available. [default: gpu]
--tta --no-tta Use test time augmentation to improve accuracy at the cost of processing time. [default: no-tta]
--help -h Show this message and exit.
```

Expand Down
16 changes: 14 additions & 2 deletions kelp_o_matic/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@ def find_kelp(
use_gpu: bool = typer.Option(
True, "--gpu/--no-gpu", help="Enable or disable GPU, if available."
),
use_tta: bool = typer.Option(
False,
"--tta/--no-tta",
help="Use test time augmentation to improve accuracy at the cost of "
"processing time.",
),
):
"""
Detect kelp in image at path SOURCE and output the resulting classification raster
to file at path DEST.
"""
find_kelp_(source, dest, species, crop_size, use_nir, band_order, use_gpu)
find_kelp_(source, dest, species, crop_size, use_nir, band_order, use_gpu, use_tta)


@cli.command()
Expand All @@ -64,12 +70,18 @@ def find_mussels(
use_gpu: bool = typer.Option(
True, "--gpu/--no-gpu", help="Enable or disable GPU, if available."
),
use_tta: bool = typer.Option(
False,
"--tta/--no-tta",
help="Use test time augmentation to improve accuracy at the cost of "
"processing time.",
),
):
"""
Detect mussels in image at path SOURCE and output the resulting classification
raster to file at path DEST.
"""
find_mussels_(source, dest, crop_size, band_order, use_gpu)
find_mussels_(source, dest, crop_size, band_order, use_gpu, use_tta)


def version_callback(value: bool) -> None:
Expand Down
18 changes: 16 additions & 2 deletions kelp_o_matic/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def find_kelp(
use_nir: bool = False,
band_order: Optional[list[int]] = None,
use_gpu: bool = True,
test_time_augmentation: bool = False,
):
"""
Detect kelp in image at path `source` and output the resulting classification raster
Expand All @@ -86,6 +87,7 @@ def find_kelp(
band_order: GDAL-style band re-ordering. Defaults to RGB or RGBI order.
e.g. to reorder a BGRI image at runtime, pass `[3,2,1,4]`.
use_gpu: Disable Cuda GPU usage and run on CPU only.
test_time_augmentation: Use test time augmentation to improve model accuracy.
"""
if not band_order:
band_order = [1, 2, 3]
Expand All @@ -104,7 +106,12 @@ def find_kelp(
else:
model = KelpRGBPresenceSegmentationModel(use_gpu=use_gpu)
RichSegmentationManager(
model, Path(source), Path(dest), band_order=band_order, crop_size=crop_size
model,
Path(source),
Path(dest),
band_order=band_order,
crop_size=crop_size,
test_time_augmentation=test_time_augmentation,
)()


Expand All @@ -114,6 +121,7 @@ def find_mussels(
crop_size: int = 1024,
band_order: Optional[list[int]] = None,
use_gpu: bool = True,
test_time_augmentation: bool = False,
):
"""
Detect mussels in image at path `source` and output the resulting classification
Expand All @@ -126,6 +134,7 @@ def find_mussels(
band_order: GDAL-style band re-ordering flag. Defaults to RGB order.
e.g. to reorder a BGR image at runtime, pass `[3,2,1]`.
use_gpu: Disable Cuda GPU usage and run on CPU only.
test_time_augmentation: Use test time augmentation to improve model accuracy.
"""
if not band_order:
band_order = [1, 2, 3]
Expand All @@ -134,5 +143,10 @@ def find_mussels(
_validate_paths(Path(source), Path(dest))
model = MusselRGBPresenceSegmentationModel(use_gpu=use_gpu)
RichSegmentationManager(
model, Path(source), Path(dest), band_order=band_order, crop_size=crop_size
model,
Path(source),
Path(dest),
band_order=band_order,
crop_size=crop_size,
test_time_augmentation=test_time_augmentation,
)()
15 changes: 15 additions & 0 deletions kelp_o_matic/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
output_path: Union[str, Path],
band_order: tuple[int] = (1, 2, 3),
crop_size: int = 1024,
test_time_augmentation: bool = False,
):
"""Create the segmentation object.

Expand All @@ -35,10 +36,12 @@ def __init__(
Defaults to [1, 2, 3] for RGB ordering.
crop_size: The size of image crop to classify iteratively until the entire
image is classified.
test_time_augmentation: Use test time augmentation to improve accuracy.
"""
self.model = model
self.band_order = band_order
self.crop_size = crop_size
self.tta = test_time_augmentation
self.input_path = str(Path(input_path).expanduser().resolve())
self.output_path = str(Path(output_path).expanduser().resolve())

Expand Down Expand Up @@ -84,6 +87,18 @@ def __call__(self):
)
logits = self.model(crop.unsqueeze(0))[0]

if self.tta:
for k in range(1, 4):
aug_crop = torch.rot90(crop, k=k, dims=(1, 2))
aug_logits = self.model(aug_crop.unsqueeze(0))[0]
unaug_logits = torch.rot90(aug_logits, k=-k, dims=(1, 2))
logits = torch.maximum(logits, unaug_logits)
for d in [1, 2]:
aug_crop = torch.flip(crop, dims=(d,))
aug_logits = self.model(aug_crop.unsqueeze(0))[0]
unaug_logits = torch.flip(aug_logits, dims=(d,))
logits = torch.maximum(logits, unaug_logits)

logits = self.kernel(
logits,
top=self.reader.is_top_window(read_window),
Expand Down
Loading