diff --git a/env.yml b/env.yml index de72ce7..d3578ec 100644 --- a/env.yml +++ b/env.yml @@ -19,6 +19,7 @@ dependencies: - h5py - pytorch=2.1.2 - torchvision + - pooch - pip - pip: - numpy diff --git a/setup.cfg b/setup.cfg index 754f62c..f384d79 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,7 @@ install_requires = magicgui qtpy h5py + pooch ; pytorch ; torchvision timm diff --git a/src/featureforest/SAM/setup_model.py b/src/featureforest/SAM/setup_model.py index 172a505..1e33d9a 100644 --- a/src/featureforest/SAM/setup_model.py +++ b/src/featureforest/SAM/setup_model.py @@ -2,6 +2,7 @@ import torch +from ..utils.downloader import download_model from .models import MobileSAM @@ -10,13 +11,17 @@ def setup_mobile_sam_model(): print(f"running on {device}") # sam model (light hq sam) model = MobileSAM.setup_model().to(device) - # load weights - weights = torch.load( - Path(__file__).parent.joinpath( - "./models/weights/mobile_sam.pt" - ), - map_location=device + # download model's weights + model_url = "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt" + model_file = download_model( + model_url=model_url, + model_name="mobile_sam.pt" ) + if model_file is None: + raise ValueError(f"Could not download the model from {model_url}.") + + # load weights + weights = torch.load(model_file, map_location=device) model.load_state_dict(weights, strict=True) model.eval() diff --git a/src/featureforest/utils/downloader.py b/src/featureforest/utils/downloader.py new file mode 100644 index 0000000..8388cbe --- /dev/null +++ b/src/featureforest/utils/downloader.py @@ -0,0 +1,44 @@ +from pathlib import Path +import pooch + + +MODELS_CACHE_DIR = Path.home().joinpath(".featureforest", "models") + + +def download_model( + model_url: str, model_name: str, + cache_dir: Path = MODELS_CACHE_DIR, is_archived: bool = False +) -> str: + """Download a model weights from a given url. + + Args: + model_url (str): the model weights' url. + model_name (str): model's name that will be saved in cache. + cache_dir (Path, optional): download directory. Defaults to CACHE_DIR. + is_archived (bool, optional): set to True to unzip the downloaded file. + Defaults to False. + + Returns: + str: full path of the downloaded file. + """ + try: + downloaded_file = pooch.retrieve( + url=model_url, + fname=model_name, + path=cache_dir, + known_hash=None, + processor=pooch.Unzip() if is_archived else None + ) + # for zip files, get the file ending with "pt" or "pth" as model weights file. + if is_archived: + pytorch_files = [ + f for f in downloaded_file + if Path(f).suffix in ["pt", "pth"] + ] + downloaded_file = pytorch_files[0] if len(pytorch_files) > 0 else None + + return downloaded_file + + except Exception as err: + print(f"\nError while downloading the model:\n{err}") + return None