From 908d5625811efc946270e27106192c605ce51199 Mon Sep 17 00:00:00 2001
From: Zim Gong
Date: Fri, 13 Jun 2025 09:18:50 +0800
Subject: [PATCH 1/2] support load stats from hf hub

---
 experiments/robot/openvla_utils.py | 44 ++++++++++++++++++++++++++++++
 verl/workers/fsdp_workers.py       | 14 ++--------
 2 files changed, 47 insertions(+), 11 deletions(-)
 create mode 100644 experiments/robot/openvla_utils.py

diff --git a/experiments/robot/openvla_utils.py b/experiments/robot/openvla_utils.py
new file mode 100644
index 0000000..b0e3c01
--- /dev/null
+++ b/experiments/robot/openvla_utils.py
@@ -0,0 +1,44 @@
+"""Utils for evaluating OpenVLA or fine-tuned OpenVLA policies."""
+
+import json
+import os
+
+import torch
+from huggingface_hub import HfApi, hf_hub_download
+
+
+def model_is_on_hf_hub(model_path: str) -> bool:
+    """Checks whether a model path points to a model on Hugging Face Hub."""
+    # If the API call below runs without error, the model is on the hub
+    try:
+        HfApi().model_info(model_path)
+        return True
+    except Exception:
+        return False
+
+def _load_dataset_stats(vla: torch.nn.Module, checkpoint_path: str) -> None:
+    """
+    Load dataset statistics used during training for action normalization.
+
+    Args:
+        vla: The VLA model
+        checkpoint_path: Path to the checkpoint directory
+    """
+    if model_is_on_hf_hub(checkpoint_path):
+        # Download dataset stats directly from HF Hub
+        dataset_statistics_path = hf_hub_download(
+            repo_id=checkpoint_path,
+            filename="dataset_statistics.json",
+        )
+    else:
+        dataset_statistics_path = os.path.join(checkpoint_path, "dataset_statistics.json")
+    if os.path.isfile(dataset_statistics_path):
+        with open(dataset_statistics_path, "r") as f:
+            norm_stats = json.load(f)
+        vla.norm_stats = norm_stats
+    else:
+        print(
+            "WARNING: No local dataset_statistics.json file found for current checkpoint.\n"
+            "You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint.\n"
+            "Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`."
+        )
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 7463b96..9e4c853 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -45,6 +45,8 @@
 from peft import LoraConfig, PeftModel, get_peft_model, TaskType
 import json
 
+from experiments.robot.openvla_utils import _load_dataset_stats
+
 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
 
@@ -196,17 +198,7 @@ def _build_model_optimizer(self,
                 #oft add
                 actor_module.vision_backbone.set_num_images_in_input(self.config.actor.num_images_in_input)
 
-                dataset_statistics_path = os.path.join(local_path, "dataset_statistics.json")
-                if os.path.isfile(dataset_statistics_path):
-                    with open(dataset_statistics_path, "r") as f:
-                        norm_stats = json.load(f)
-                    actor_module.norm_stats = norm_stats
-                else:
-                    print(
-                        "WARNING: No local dataset_statistics.json file found for current checkpoint.\n"
-                        "You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint."
-                        "Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`."
-                    )
+                _load_dataset_stats(actor_module, local_path)
             elif self.config.model.vla == "openvla":
                 actor_module = AutoModelForVision2Seq.from_pretrained(
                     pretrained_model_name_or_path=local_path,

From b2c6a11c95a79da8b0df4d35ff0e410d71dc8763 Mon Sep 17 00:00:00 2001
From: Haozhan Li <92256546+zhan72@users.noreply.github.com>
Date: Fri, 13 Jun 2025 11:12:27 +0800
Subject: [PATCH 2/2] Update openvla_utils.py

---
 experiments/robot/openvla_utils.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/experiments/robot/openvla_utils.py b/experiments/robot/openvla_utils.py
index b0e3c01..ec9d757 100644
--- a/experiments/robot/openvla_utils.py
+++ b/experiments/robot/openvla_utils.py
@@ -26,10 +26,14 @@ def _load_dataset_stats(vla: torch.nn.Module, checkpoint_path: str) -> None:
     """
     if model_is_on_hf_hub(checkpoint_path):
         # Download dataset stats directly from HF Hub
-        dataset_statistics_path = hf_hub_download(
-            repo_id=checkpoint_path,
-            filename="dataset_statistics.json",
-        )
+        try:
+            dataset_statistics_path = hf_hub_download(
+                repo_id=checkpoint_path,
+                filename="dataset_statistics.json",
+            )
+        except Exception as e:
+            print(f"Failed to download dataset_statistics.json from HF Hub: {e}")
+            dataset_statistics_path = ""
     else:
         dataset_statistics_path = os.path.join(checkpoint_path, "dataset_statistics.json")
     if os.path.isfile(dataset_statistics_path):
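
A note for reviewers on the fallback added in PATCH 2/2: assigning `dataset_statistics_path = ""` after a failed download deliberately steers execution into the existing `os.path.isfile(...)` else-branch, because an empty string never names a file, so the caller still sees the WARNING message instead of an uncaught exception at load time. A minimal check of that assumption:

import os

# The empty-string sentinel set in the except-branch can never point at a
# real file, so _load_dataset_stats falls through to its WARNING branch
# rather than raising.
assert os.path.isfile("") is False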
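
For context, a minimal sketch of how the new helper would be exercised end to end; the `openvla/openvla-7b` repo id and the `AutoModelForVision2Seq` + `trust_remote_code=True` loading path are illustrative assumptions, not something this series prescribes:

from transformers import AutoModelForVision2Seq

from experiments.robot.openvla_utils import _load_dataset_stats, model_is_on_hf_hub

checkpoint = "openvla/openvla-7b"  # assumed Hub repo id; a local checkpoint dir also works

# Load the policy itself (illustrative; OpenVLA checkpoints ship custom modeling code).
vla = AutoModelForVision2Seq.from_pretrained(checkpoint, trust_remote_code=True)

# Resolve dataset_statistics.json from the Hub for a repo id, or from
# <checkpoint>/dataset_statistics.json for a local directory, and attach it.
_load_dataset_stats(vla, checkpoint)

print(model_is_on_hf_hub(checkpoint))                # True for a Hub repo id
print(getattr(vla, "norm_stats", None) is not None)  # True if the stats were found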