Skip to content

Commit

Permalink
setup_model: raising err on failed download; cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mese79 committed Jun 3, 2024
1 parent 2ea98c9 commit 65deee0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
6 changes: 5 additions & 1 deletion src/featureforest/SAM/setup_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ def setup_mobile_sam_model():
# sam model (light hq sam)
model = MobileSAM.setup_model().to(device)
# download model's weights
model_url = "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"
model_file = download_model(
model_url="https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt",
model_url=model_url,
model_name="mobile_sam.pt"
)
if model_file is None:
raise ValueError(f"Could not download the model from {model_url}.")

# load weights
weights = torch.load(model_file, map_location=device)
model.load_state_dict(weights, strict=True)
Expand Down
8 changes: 0 additions & 8 deletions src/featureforest/utils/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,3 @@ def download_model(
except Exception as err:
print(f"\nError while downloading the model:\n{err}")
return None


if __name__ == "__main__":
model_file = download_model(
model_url="https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt",
model_name="mobile_sam.pt"
)
print(model_file)

0 comments on commit 65deee0

Please sign in to comment.