diff --git a/conda/meta.yaml b/conda/meta.yaml index d32baf5..0f00b15 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -32,7 +32,6 @@ requirements: test: imports: - kelp_o_matic - - kelp_o_matic.data - kelp_o_matic.geotiff_io commands: - pip check diff --git a/kelp_o_matic/data/__init__.py b/kelp_o_matic/data/__init__.py deleted file mode 100644 index 71d66c8..0000000 --- a/kelp_o_matic/data/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -import importlib.resources as importlib_resources - -_base = importlib_resources.files(__name__) - -rgb_kelp_presence_torchscript_path = ( - _base / "LRASPP_MobileNetV3_kelp_presence_rgb_jit_miou=0.8023.pt" -) -rgbi_kelp_presence_torchscript_path = ( - _base / "UNetPlusPlus_EfficientNetB4_kelp_presence_rgbi_jit_miou=0.8785.pt" -) -rgb_kelp_species_torchscript_path = ( - _base / "LRASPP_MobileNetV3_kelp_species_rgb_jit_miou=0.9634.pt" -) -rgbi_kelp_species_torchscript_path = ( - _base / "UNetPlusPlus_EfficientNetB4_kelp_species_rgbi_jit_miou=0.8432.pt" -) -rgb_mussel_presence_torchscript_path = ( - _base / "LRASPP_MobileNetV3_mussel_presence_rgb_jit_miou=0.8745.pt" -) diff --git a/kelp_o_matic/models.py b/kelp_o_matic/models.py index ddd79ca..df2efe1 100644 --- a/kelp_o_matic/models.py +++ b/kelp_o_matic/models.py @@ -1,5 +1,4 @@ import gc -import importlib.resources as importlib_resources from abc import ABC, abstractmethod, ABCMeta from typing import Union @@ -8,13 +7,7 @@ import torchvision.transforms.functional as f from PIL.Image import Image -from kelp_o_matic.data import ( - rgb_kelp_presence_torchscript_path, - rgb_kelp_species_torchscript_path, - rgb_mussel_presence_torchscript_path, - rgbi_kelp_presence_torchscript_path, - rgbi_kelp_species_torchscript_path, -) +from kelp_o_matic.utils import lazy_load_params class _Model(ABC): @@ -37,8 +30,8 @@ def torchscript_path(self): raise NotImplementedError def load_model(self) -> "torch.nn.Module": - with importlib_resources.as_file(self.torchscript_path) as ts: - model = torch.jit.load(ts, map_location=self.device) + params_file = lazy_load_params(self.torchscript_path) + model = torch.jit.load(params_file, map_location=self.device) model.eval() return model @@ -95,16 +88,16 @@ def post_process(self, x: "torch.Tensor") -> "np.ndarray": class KelpRGBPresenceSegmentationModel(_Model): - torchscript_path = rgb_kelp_presence_torchscript_path + torchscript_path = "LRASPP_MobileNetV3_kelp_presence_rgb_jit_miou=0.8023.pt" class KelpRGBSpeciesSegmentationModel(_SpeciesSegmentationModel): - torchscript_path = rgb_kelp_species_torchscript_path + torchscript_path = "LRASPP_MobileNetV3_kelp_species_rgb_jit_miou=0.9634.pt" presence_model_class = KelpRGBPresenceSegmentationModel class MusselRGBPresenceSegmentationModel(_Model): - torchscript_path = rgb_mussel_presence_torchscript_path + torchscript_path = "LRASPP_MobileNetV3_mussel_presence_rgb_jit_miou=0.8745.pt" def _unet_efficientnet_b4_transform(x: Union[np.ndarray, Image]) -> torch.Tensor: @@ -117,7 +110,9 @@ def _unet_efficientnet_b4_transform(x: Union[np.ndarray, Image]) -> torch.Tensor class KelpRGBIPresenceSegmentationModel(_Model): - torchscript_path = rgbi_kelp_presence_torchscript_path + torchscript_path = ( + "UNetPlusPlus_EfficientNetB4_kelp_presence_rgbi_jit_miou=0.8785.pt" + ) @staticmethod def transform(x: Union[np.ndarray, Image]) -> torch.Tensor: @@ -125,7 +120,9 @@ def transform(x: Union[np.ndarray, Image]) -> torch.Tensor: class KelpRGBISpeciesSegmentationModel(_SpeciesSegmentationModel): - torchscript_path = rgbi_kelp_species_torchscript_path + torchscript_path = ( + "UNetPlusPlus_EfficientNetB4_kelp_species_rgbi_jit_miou=0.8432.pt" + ) presence_model_class = KelpRGBIPresenceSegmentationModel @staticmethod diff --git a/kelp_o_matic/utils.py b/kelp_o_matic/utils.py index 444c9e0..34a249e 100644 --- a/kelp_o_matic/utils.py +++ b/kelp_o_matic/utils.py @@ -1,2 +1,58 @@ +import os +import tempfile +import urllib.request +from pathlib import Path + +from rich.progress import Progress + +S3_BUCKET = "https://kelp-o-matic.s3.amazonaws.com/pt_jit" +CACHE_DIR = Path("~/.cache/kelp_o_matic").expanduser() + + +def lazy_load_params(object_name: str): + object_name = object_name + remote_url = f"{S3_BUCKET}/{object_name}" + local_file = CACHE_DIR / object_name + + # Create cache directory if it doesn't exist + if not CACHE_DIR.is_dir(): + CACHE_DIR.mkdir(parents=True) + + # Download file if it doesn't exist + if not local_file.is_file(): + download_file(remote_url, local_file) + + return local_file + + +def download_file(url: str, filename: Path): + # Make a request to the URL + response = urllib.request.urlopen(url) + + # Get the total size of the file + file_size = int(response.getheader("Content-Length")) + + # Create a task with the total file size + with Progress(transient=True) as progress: + task = progress.add_task(f"Downloading {filename.name}...", total=file_size) + + # Download the file + with tempfile.NamedTemporaryFile("wb") as f: + # Read data in chunks (e.g., 1024 bytes) + while True: + chunk = response.read(1024) + if not chunk: + break + f.write(chunk) + + # Update progress bar + progress.update(task, advance=len(chunk)) + + # Move the file to the cache directory once downloaded + f.flush() + os.fsync(f.fileno()) + filename.hardlink_to(f.name) + + def all_same(items): return all(x == items[0] for x in items) diff --git a/kelp_o_matic/data/LRASPP_MobileNetV3_kelp_presence_rgb_jit_miou=0.8023.pt b/pt_jit/LRASPP_MobileNetV3_kelp_presence_rgb_jit_miou=0.8023.pt similarity index 100% rename from kelp_o_matic/data/LRASPP_MobileNetV3_kelp_presence_rgb_jit_miou=0.8023.pt rename to pt_jit/LRASPP_MobileNetV3_kelp_presence_rgb_jit_miou=0.8023.pt diff --git a/kelp_o_matic/data/LRASPP_MobileNetV3_kelp_species_rgb_jit_miou=0.9634.pt b/pt_jit/LRASPP_MobileNetV3_kelp_species_rgb_jit_miou=0.9634.pt similarity index 100% rename from kelp_o_matic/data/LRASPP_MobileNetV3_kelp_species_rgb_jit_miou=0.9634.pt rename to pt_jit/LRASPP_MobileNetV3_kelp_species_rgb_jit_miou=0.9634.pt diff --git a/kelp_o_matic/data/LRASPP_MobileNetV3_mussel_presence_rgb_jit_miou=0.8745.pt b/pt_jit/LRASPP_MobileNetV3_mussel_presence_rgb_jit_miou=0.8745.pt similarity index 100% rename from kelp_o_matic/data/LRASPP_MobileNetV3_mussel_presence_rgb_jit_miou=0.8745.pt rename to pt_jit/LRASPP_MobileNetV3_mussel_presence_rgb_jit_miou=0.8745.pt diff --git a/kelp_o_matic/data/UNetPlusPlus_EfficientNetB4_kelp_presence_rgbi_jit_miou=0.8785.pt b/pt_jit/UNetPlusPlus_EfficientNetB4_kelp_presence_rgbi_jit_miou=0.8785.pt similarity index 100% rename from kelp_o_matic/data/UNetPlusPlus_EfficientNetB4_kelp_presence_rgbi_jit_miou=0.8785.pt rename to pt_jit/UNetPlusPlus_EfficientNetB4_kelp_presence_rgbi_jit_miou=0.8785.pt diff --git a/kelp_o_matic/data/UNetPlusPlus_EfficientNetB4_kelp_species_rgbi_jit_miou=0.8432.pt b/pt_jit/UNetPlusPlus_EfficientNetB4_kelp_species_rgbi_jit_miou=0.8432.pt similarity index 100% rename from kelp_o_matic/data/UNetPlusPlus_EfficientNetB4_kelp_species_rgbi_jit_miou=0.8432.pt rename to pt_jit/UNetPlusPlus_EfficientNetB4_kelp_species_rgbi_jit_miou=0.8432.pt