diff --git a/experiments/robot/openvla_utils.py b/experiments/robot/openvla_utils.py new file mode 100644 index 0000000..ec9d757 --- /dev/null +++ b/experiments/robot/openvla_utils.py @@ -0,0 +1,48 @@ +"""Utils for evaluating OpenVLA or fine-tuned OpenVLA policies.""" + +import json +import os + +import torch +from huggingface_hub import HfApi, hf_hub_download + + +def model_is_on_hf_hub(model_path: str) -> bool: + """Checks whether a model path points to a model on Hugging Face Hub.""" + # If the API call below runs without error, the model is on the hub + try: + HfApi().model_info(model_path) + return True + except Exception: + return False + +def _load_dataset_stats(vla: torch.nn.Module, checkpoint_path: str) -> None: + """ + Load dataset statistics used during training for action normalization. + + Args: + vla: The VLA model + checkpoint_path: Path to the checkpoint directory + """ + if model_is_on_hf_hub(checkpoint_path): + # Download dataset stats directly from HF Hub + try: + dataset_statistics_path = hf_hub_download( + repo_id=checkpoint_path, + filename="dataset_statistics.json", + ) + except Exception as e: + print(f"Failed to download dataset_statistics.json from HF Hub: {e}") + dataset_statistics_path = "" + else: + dataset_statistics_path = os.path.join(checkpoint_path, "dataset_statistics.json") + if os.path.isfile(dataset_statistics_path): + with open(dataset_statistics_path, "r") as f: + norm_stats = json.load(f) + vla.norm_stats = norm_stats + else: + print( + "WARNING: No local dataset_statistics.json file found for current checkpoint.\n" + "You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint." + "Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`." + ) diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 7463b96..9e4c853 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -45,6 +45,8 @@ from peft import LoraConfig, PeftModel, get_peft_model, TaskType import json +from experiments.robot.openvla_utils import _load_dataset_stats + logger = logging.getLogger(__file__) logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) @@ -196,17 +198,7 @@ def _build_model_optimizer(self, #oft add actor_module.vision_backbone.set_num_images_in_input(self.config.actor.num_images_in_input) - dataset_statistics_path = os.path.join(local_path, "dataset_statistics.json") - if os.path.isfile(dataset_statistics_path): - with open(dataset_statistics_path, "r") as f: - norm_stats = json.load(f) - actor_module.norm_stats = norm_stats - else: - print( - "WARNING: No local dataset_statistics.json file found for current checkpoint.\n" - "You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint." - "Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`." - ) + _load_dataset_stats(actor_module, local_path) elif self.config.model.vla == "openvla": actor_module = AutoModelForVision2Seq.from_pretrained( pretrained_model_name_or_path=local_path,