Skip to content

Commit

Permalink
Merge pull request #3 from juglab/ms/feat/model_downloader
Browse files Browse the repository at this point in the history
added model_downloader using pooch

Former-commit-id: ff41bbe
Former-commit-id: 8d6e853
  • Loading branch information
mese79 authored Jun 4, 2024
2 parents 2685e12 + e8961f4 commit 254716f
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 6 deletions.
1 change: 1 addition & 0 deletions env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- h5py
- pytorch=2.1.2
- torchvision
- pooch
- pip
- pip:
- numpy
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ install_requires =
magicgui
qtpy
h5py
pooch
; pytorch
; torchvision
timm
Expand Down
17 changes: 11 additions & 6 deletions src/featureforest/SAM/setup_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch

from ..utils.downloader import download_model
from .models import MobileSAM


Expand All @@ -10,13 +11,17 @@ def setup_mobile_sam_model():
print(f"running on {device}")
# sam model (light hq sam)
model = MobileSAM.setup_model().to(device)
# load weights
weights = torch.load(
Path(__file__).parent.joinpath(
"./models/weights/mobile_sam.pt"
),
map_location=device
# download model's weights
model_url = "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"
model_file = download_model(
model_url=model_url,
model_name="mobile_sam.pt"
)
if model_file is None:
raise ValueError(f"Could not download the model from {model_url}.")

# load weights
weights = torch.load(model_file, map_location=device)
model.load_state_dict(weights, strict=True)
model.eval()

Expand Down
44 changes: 44 additions & 0 deletions src/featureforest/utils/downloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from pathlib import Path
import pooch


MODELS_CACHE_DIR = Path.home().joinpath(".featureforest", "models")


def download_model(
model_url: str, model_name: str,
cache_dir: Path = MODELS_CACHE_DIR, is_archived: bool = False
) -> str:
"""Download a model weights from a given url.
Args:
model_url (str): the model weights' url.
model_name (str): model's name that will be saved in cache.
cache_dir (Path, optional): download directory. Defaults to CACHE_DIR.
is_archived (bool, optional): set to True to unzip the downloaded file.
Defaults to False.
Returns:
str: full path of the downloaded file.
"""
try:
downloaded_file = pooch.retrieve(
url=model_url,
fname=model_name,
path=cache_dir,
known_hash=None,
processor=pooch.Unzip() if is_archived else None
)
# for zip files, get the file ending with "pt" or "pth" as model weights file.
if is_archived:
pytorch_files = [
f for f in downloaded_file
if Path(f).suffix in ["pt", "pth"]
]
downloaded_file = pytorch_files[0] if len(pytorch_files) > 0 else None

return downloaded_file

except Exception as err:
print(f"\nError while downloading the model:\n{err}")
return None

0 comments on commit 254716f

Please sign in to comment.