diff --git a/GANDLF/utils/modelio.py b/GANDLF/utils/modelio.py
index 23556b206..65152449c 100644
--- a/GANDLF/utils/modelio.py
+++ b/GANDLF/utils/modelio.py
@@ -1,7 +1,7 @@
 import hashlib
 import os
 import subprocess
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import torch
 
@@ -136,6 +136,7 @@ def save_model(
     params: Dict[str, Any],
     path: str,
     onnx_export: bool = True,
+    hf_hub_repo_id: Optional[str] = None,
 ):
     """
     Save the model dictionary to a file.
@@ -146,6 +147,7 @@
         params (dict): The parameter dictionary.
         path (str): The path to save the model dictionary to.
         onnx_export (bool): Whether to export to ONNX and OpenVINO.
+        hf_hub_repo_id (Optional[str]): The Hugging Face Hub repo id to push the model to. Defaults to None (do not push to the Hub).
     """
     model_dict["timestamp"] = get_unique_timestamp()
     model_dict["timestamp_hash"] = hashlib.sha256(
@@ -169,6 +171,11 @@
     # post-training optimization
     optimize_and_save_model(model, params, path, onnx_export=onnx_export)
 
+    if hf_hub_repo_id is not None:
+        # TODO: also push the optimized (ONNX/OpenVINO) models?
+        push_model_to_hf_hub(
+            repo_id=hf_hub_repo_id, local_path=path, path_in_repo=os.path.basename(path)
+        )
 
 def load_model(
     path: str, device: torch.device, full_sanity_check: bool = True
@@ -242,3 +247,128 @@ def load_ov_model(path: str, device: str = "CPU"):
     output_layer = compiled_model.outputs
 
     return compiled_model, input_layer, output_layer
+
+
+def load_model_from_hf_hub(
+    repo_id: str,
+    model_filename: str,
+    revision: Optional[str] = None,
+    local_dir: Optional[str] = None,
+    device: str = "CPU",
+) -> Dict[str, Any]:
+    """
+    Download a model from the Hugging Face Hub and load it. If the repo is private, credentials must be set beforehand.
+
+    Args:
+        repo_id (str): The Hugging Face Hub repo id.
+        model_filename (str): The model filename in the repo.
+        revision (Optional[str]): The revision of the model to load. Defaults to the latest revision.
+        local_dir (Optional[str]): The local directory to download the model to. Defaults to None.
+        device (str): The device to run inference on; can be "CPU", "GPU" or "MULTI:CPU,GPU". Defaults to "CPU".
+
+    Returns:
+        dict: The model dictionary, as returned by `load_model`.
+    """
+    try:
+        import huggingface_hub
+    except ImportError:
+        raise ImportError(
+            "huggingface_hub is not installed; install it with `pip install huggingface_hub`."
+        )
+
+    local_model_path = huggingface_hub.hf_hub_download(
+        repo_id=repo_id,
+        filename=model_filename,
+        revision=revision,
+        local_dir=local_dir,
+    )
+
+    return load_model(local_model_path, device)
+
+
+def upload_model_repo_to_hf_hub(
+    repo_id: str,
+    local_path: str,
+    path_in_repo: str,
+    model_card_path: Optional[str] = None,
+    private: bool = False,
+    upload_onnx: bool = False,
+    upload_ov: bool = False,
+) -> str:
+    """
+    Upload a model to a repo on the Hugging Face Hub, creating the repo if it does not already exist. Must be logged in to the Hugging Face Hub.
+
+    Args:
+        repo_id (str): The Hugging Face Hub repo id to upload to. Will be created if it does not already exist.
+        local_path (str): The local path to the model file to upload.
+        path_in_repo (str): The path to the model in the repo.
+        model_card_path (Optional[str]): The path to the model card to upload. Defaults to None.
+        private (bool): Whether to make the repo private. Defaults to False.
+        upload_onnx (bool): Whether to also upload the ONNX model. Defaults to False.
+        upload_ov (bool): Whether to also upload the OpenVINO model. Defaults to False.
+
+    Returns:
+        str: The URL of the uploaded model file on the Hub.
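+
+    Example (repo id and file names are hypothetical; requires a prior `huggingface-cli login`):
+        >>> upload_model_repo_to_hf_hub(
+        ...     repo_id="my-org/my-gandlf-model",
+        ...     local_path="model_dir/model_best.pth.tar",
+        ...     path_in_repo="model_best.pth.tar",
+        ... )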
+ """ + try: + import huggingface_hub + except ImportError: + print("WARNING: huggingface_hub is not present.") + + huggingface_hub.create_repo(repo_id=repo_id, private=private) + + api = huggingface_hub.HfApi() + + api.upload_file(path_or_file_obj=local_path, path_in_repo=path_in_repo, repo_id=repo_id) + + if model_card_path is not None: + #TODO: upload model card + pass + else: + #TODO: create new model card + pass + + #TODO: upload optimized models? + + +def push_model_to_hf_hub( + repo_id: str, + local_path: str, + path_in_repo: str +) -> str: + """ + Push model to repo on Hugging Face Hub. Must be logged in to Hugging Face Hub. + + Args: + repo_id (str): The Hugging Face Hub repo id to push to. + model_path (str): The local path to the model to push. + path_in_repo (str): The path to the model in the repo. + + Returns: + str: The URL to visualize the uploaded file on the hub. + """ + + try: + import huggingface_hub + except ImportError: + print("WARNING: huggingface_hub is not present.") + + api = huggingface_hub.HfApi() + + return api.upload_file( + path_or_file_obj=local_path, + path_in_repo=path_in_repo, + repo_id=repo_id + ) \ No newline at end of file