Skip to content

Commit

Permalink
Add eelgrass detection model
Browse files Browse the repository at this point in the history
Fix eelgrass outputs
  • Loading branch information
tayden committed Feb 9, 2024
1 parent 044a9c3 commit 0b60a07
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 151 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,5 @@ conda/index.html
.idea/

*.profile

*.tif
3 changes: 2 additions & 1 deletion kelp_o_matic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from kelp_o_matic.lib import find_kelp, find_mussels
from kelp_o_matic.lib import find_kelp, find_mussels, find_eelgrass

__all__ = [
"find_kelp",
"find_mussels",
"find_eelgrass",
]
__version__ = "0.0.0"
32 changes: 32 additions & 0 deletions kelp_o_matic/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
__version__,
find_kelp as find_kelp_,
find_mussels as find_mussels_,
find_eelgrass as find_eelgrass_,
)

cli = typer.Typer(context_settings={"help_option_names": ["-h", "--help"]})
Expand Down Expand Up @@ -84,6 +85,37 @@ def find_mussels(
find_mussels_(source, dest, crop_size, band_order, use_gpu, use_tta)


@cli.command()
def find_eelgrass(
source: Path = typer.Argument(..., help="Input image with Byte data type."),
dest: Path = typer.Argument(..., help="File path location to save output to."),
crop_size: int = typer.Option(
1024,
help="The data window size to run through the segmentation model.",
),
band_order: Optional[list[int]] = typer.Option(
None,
"-b",
help="GDAL-style band re-ordering flag. Defaults to RGB order. "
"To e.g., reorder a BGR image at runtime, pass flags `-b 3 -b 2 -b 1`.",
),
use_gpu: bool = typer.Option(
True, "--gpu/--no-gpu", help="Enable or disable GPU, if available."
),
use_tta: bool = typer.Option(
False,
"--tta/--no-tta",
help="Use test time augmentation to improve accuracy at the cost of "
"processing time.",
),
):
"""
Detect eelgrass in image at path SOURCE and output the resulting classification
raster to file at path DEST.
"""
find_eelgrass_(source, dest, crop_size, band_order, use_gpu, use_tta)


def version_callback(value: bool) -> None:
if value:
typer.echo(f"Kelp-O-Matic {__version__}")
Expand Down
2 changes: 1 addition & 1 deletion kelp_o_matic/hann.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,6 @@ def step(self, new_logits: torch.Tensor, img_window: Window):
height=min(self.hws, img_window.height),
width=min(self.hws, img_window.width),
)
preds = logits_a[:, : img_window.height, : img_window.width].softmax(axis=0)
preds = logits_a[:, : img_window.height, : img_window.width]

return preds, preds_win
38 changes: 38 additions & 0 deletions kelp_o_matic/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
MusselRGBPresenceSegmentationModel,
KelpRGBIPresenceSegmentationModel,
KelpRGBISpeciesSegmentationModel,
SeagrassPresenceSegmentationModel,
)


Expand Down Expand Up @@ -150,3 +151,40 @@ def find_mussels(
crop_size=crop_size,
test_time_augmentation=test_time_augmentation,
)()


def find_eelgrass(
source: Union[str, Path],
dest: Union[str, Path],
crop_size: int = 1024,
band_order: Optional[list[int]] = None,
use_gpu: bool = True,
test_time_augmentation: bool = False,
):
"""
Detect eelgrass in image at path `source` and output the resulting classification
raster to file at path `dest`.
Args:
source: Input image with Byte data type.
dest: File path location to save output to.
crop_size: The size of cropped image square run through the segmentation model.
band_order: GDAL-style band re-ordering flag. Defaults to RGB order.
e.g. to reorder a BGR image at runtime, pass `[3,2,1]`.
use_gpu: Disable Cuda GPU usage and run on CPU only.
test_time_augmentation: Use test time augmentation to improve model accuracy.
"""
if not band_order:
band_order = [1, 2, 3]

_validate_band_order(band_order)
_validate_paths(Path(source), Path(dest))
model = SeagrassPresenceSegmentationModel(use_gpu=use_gpu)
RichSegmentationManager(
model,
Path(source),
Path(dest),
band_order=band_order,
crop_size=crop_size,
test_time_augmentation=test_time_augmentation,
)()
13 changes: 13 additions & 0 deletions kelp_o_matic/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,16 @@ class KelpRGBISpeciesSegmentationModel(_SpeciesSegmentationModel):
@staticmethod
def transform(x: Union[np.ndarray, Image]) -> torch.Tensor:
return _unet_efficientnet_b4_transform(x)


class SeagrassPresenceSegmentationModel(_Model):
register_depth = 1

torchscript_path = (
"UNetPlusPlus_EfficientNetB5_seagrass_presence_rgb_jit_dice=0.8498.pt"
)

def post_process(self, x: "torch.Tensor") -> "np.ndarray":
with torch.no_grad():
label = torch.sigmoid(x) > 0.5 # 0: bg, 1: eelgrass
return label.squeeze(0).detach().cpu().numpy().astype(np.uint8)
Loading

0 comments on commit 0b60a07

Please sign in to comment.