Skip to content

Commit

Permalink
Add RGBI kelp species segmentation model
Browse files Browse the repository at this point in the history
  • Loading branch information
tayden committed Jan 15, 2024
1 parent d3121dd commit 9e57df3
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 60 deletions.
Binary file not shown.
3 changes: 3 additions & 0 deletions kelp_o_matic/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
rgb_kelp_species_torchscript_path = (
_base / "LRASPP_MobileNetV3_kelp_species_rgb_jit_miou=0.9634.pt"
)
rgbi_kelp_species_torchscript_path = (
_base / "UNetPlusPlus_EfficientNetB4_kelp_species_rgbi_jit_miou=0.8432.pt"
)
rgb_mussel_presence_torchscript_path = (
_base / "LRASPP_MobileNetV3_mussel_presence_rgb_jit_miou=0.8745.pt"
)
4 changes: 2 additions & 2 deletions kelp_o_matic/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
KelpRGBSpeciesSegmentationModel,
MusselRGBPresenceSegmentationModel,
KelpRGBIPresenceSegmentationModel,
KelpRGBISpeciesSegmentationModel,
)


Expand Down Expand Up @@ -93,10 +94,9 @@ def find_kelp(

_validate_band_order(band_order, use_nir)
_validate_paths(Path(source), Path(dest))
use_nir = len(band_order) == 4

if use_nir and species:
raise NotImplementedError("RGBI species classification not yet available.")
model = KelpRGBISpeciesSegmentationModel(use_gpu=use_gpu)
elif use_nir:
model = KelpRGBIPresenceSegmentationModel(use_gpu=use_gpu)
elif species:
Expand Down
53 changes: 38 additions & 15 deletions kelp_o_matic/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gc
import importlib.resources as importlib_resources
from abc import ABC, abstractmethod
from abc import ABC, abstractmethod, ABCMeta
from typing import Union

import numpy as np
Expand All @@ -13,6 +13,7 @@
rgb_kelp_species_torchscript_path,
rgb_mussel_presence_torchscript_path,
rgbi_kelp_presence_torchscript_path,
rgbi_kelp_species_torchscript_path,
)


Expand Down Expand Up @@ -63,17 +64,17 @@ def shortcut(self, crop_size: int):
return logits


class KelpRGBPresenceSegmentationModel(_Model):
torchscript_path = rgb_kelp_presence_torchscript_path


class KelpRGBSpeciesSegmentationModel(_Model):
torchscript_path = rgb_kelp_species_torchscript_path
class _SpeciesSegmentationModel(_Model, metaclass=ABCMeta):
register_depth = 4

@property
@abstractmethod
def presence_model_class(self):
raise NotImplementedError

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.presence_model = KelpRGBPresenceSegmentationModel(*args, **kwargs)
self.presence_model = self.presence_model_class(*args, **kwargs)

def __call__(self, x: "torch.Tensor") -> "torch.Tensor":
with torch.no_grad():
Expand All @@ -82,7 +83,7 @@ def __call__(self, x: "torch.Tensor") -> "torch.Tensor":
species_logits = self.model.forward(x) # 0: macro, 1: nerea
logits = torch.concat((presence_logits, species_logits), dim=1)

return logits # 0: bg, 1: kelp, 2: macro, 3: nereo
return logits # [[0: bg, 1: kelp], [0: macro, 1: nereo]]

def post_process(self, x: "torch.Tensor") -> "np.ndarray":
with torch.no_grad():
Expand All @@ -93,18 +94,40 @@ def post_process(self, x: "torch.Tensor") -> "np.ndarray":
return label.detach().cpu().numpy()


class KelpRGBPresenceSegmentationModel(_Model):
torchscript_path = rgb_kelp_presence_torchscript_path


class KelpRGBSpeciesSegmentationModel(_SpeciesSegmentationModel):
torchscript_path = rgb_kelp_species_torchscript_path
presence_model_class = KelpRGBPresenceSegmentationModel


class MusselRGBPresenceSegmentationModel(_Model):
torchscript_path = rgb_mussel_presence_torchscript_path


def _unet_efficientnet_b4_transform(x: Union[np.ndarray, Image]) -> torch.Tensor:
# to float
x = f.to_tensor(x)[:4, :, :].to(torch.float)
# min-max scale
min_, _ = torch.kthvalue(x.flatten().unique(), 2)
max_ = x.flatten().max()
return torch.clamp((x - min_) / (max_ - min_ + 1e-8), 0, 1)


class KelpRGBIPresenceSegmentationModel(_Model):
torchscript_path = rgbi_kelp_presence_torchscript_path

@staticmethod
def transform(x: Union[np.ndarray, Image]) -> torch.Tensor:
# to float
x = f.to_tensor(x)[:4, :, :].to(torch.float)
# min-max scale
min_, _ = torch.kthvalue(x.flatten().unique(), 2)
max_ = x.flatten().max()
return torch.clamp((x - min_) / (max_ - min_ + 1e-8), 0, 1)
return _unet_efficientnet_b4_transform(x)


class KelpRGBISpeciesSegmentationModel(_SpeciesSegmentationModel):
torchscript_path = rgbi_kelp_species_torchscript_path
presence_model_class = KelpRGBIPresenceSegmentationModel

@staticmethod
def transform(x: Union[np.ndarray, Image]) -> torch.Tensor:
return _unet_efficientnet_b4_transform(x)
Binary file not shown.
86 changes: 43 additions & 43 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9e57df3

Please sign in to comment.