Skip to content

Commit

Permalink
return the pytorch file out of extracted zip file
Browse files Browse the repository at this point in the history
Former-commit-id: 20e5c5b [formerly f4319d7]
Former-commit-id: cb71c74
  • Loading branch information
mese79 committed Jun 4, 2024
1 parent 6433d43 commit e8961f4
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/featureforest/utils/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def download_model(
known_hash=None,
processor=pooch.Unzip() if is_archived else None
)
# for zip files, get the file ending with "pt" or "pth" as model weights file.
if is_archived:
pytorch_files = [
f for f in downloaded_file
if Path(f).suffix in ["pt", "pth"]
]
downloaded_file = pytorch_files[0] if len(pytorch_files) > 0 else None

return downloaded_file

except Exception as err:
Expand Down

0 comments on commit e8961f4

Please sign in to comment.