From e8961f40a339a293b9f1278b664d6d90abf0d365 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Tue, 4 Jun 2024 11:22:02 +0200 Subject: [PATCH] return the pytorch file out of extracted zip file Former-commit-id: 20e5c5b768cb30920ba0a4d8c00885a575a1c910 [formerly f4319d71e7d1599557401530717c1b78dd1b6116] Former-commit-id: cb71c7499b08b6eee33fc965cedce6af66fa315a --- src/featureforest/utils/downloader.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/featureforest/utils/downloader.py b/src/featureforest/utils/downloader.py index 6191743..8388cbe 100644 --- a/src/featureforest/utils/downloader.py +++ b/src/featureforest/utils/downloader.py @@ -29,6 +29,14 @@ def download_model( known_hash=None, processor=pooch.Unzip() if is_archived else None ) + # for zip files, get the file ending with "pt" or "pth" as model weights file. + if is_archived: + pytorch_files = [ + f for f in downloaded_file + if Path(f).suffix in ["pt", "pth"] + ] + downloaded_file = pytorch_files[0] if len(pytorch_files) > 0 else None + return downloaded_file except Exception as err: