Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions experiments/robot/openvla_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Utils for evaluating OpenVLA or fine-tuned OpenVLA policies."""

import json
import os

import torch
from huggingface_hub import HfApi, hf_hub_download


def model_is_on_hf_hub(model_path: str) -> bool:
"""Checks whether a model path points to a model on Hugging Face Hub."""
# If the API call below runs without error, the model is on the hub
try:
HfApi().model_info(model_path)
return True
except Exception:
return False

def _load_dataset_stats(vla: torch.nn.Module, checkpoint_path: str) -> None:
"""
Load dataset statistics used during training for action normalization.

Args:
vla: The VLA model
checkpoint_path: Path to the checkpoint directory
"""
if model_is_on_hf_hub(checkpoint_path):
# Download dataset stats directly from HF Hub
try:
dataset_statistics_path = hf_hub_download(
repo_id=checkpoint_path,
filename="dataset_statistics.json",
)
except Exception as e:
print(f"Failed to download dataset_statistics.json from HF Hub: {e}")
dataset_statistics_path = ""
else:
dataset_statistics_path = os.path.join(checkpoint_path, "dataset_statistics.json")
if os.path.isfile(dataset_statistics_path):
with open(dataset_statistics_path, "r") as f:
norm_stats = json.load(f)
vla.norm_stats = norm_stats
else:
print(
"WARNING: No local dataset_statistics.json file found for current checkpoint.\n"
"You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint."
"Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`."
)
14 changes: 3 additions & 11 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from peft import LoraConfig, PeftModel, get_peft_model, TaskType
import json

from experiments.robot.openvla_utils import _load_dataset_stats


logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
Expand Down Expand Up @@ -196,17 +198,7 @@ def _build_model_optimizer(self,
#oft add
actor_module.vision_backbone.set_num_images_in_input(self.config.actor.num_images_in_input)

dataset_statistics_path = os.path.join(local_path, "dataset_statistics.json")
if os.path.isfile(dataset_statistics_path):
with open(dataset_statistics_path, "r") as f:
norm_stats = json.load(f)
actor_module.norm_stats = norm_stats
else:
print(
"WARNING: No local dataset_statistics.json file found for current checkpoint.\n"
"You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint."
"Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`."
)
_load_dataset_stats(actor_module, local_path)
elif self.config.model.vla == "openvla":
actor_module = AutoModelForVision2Seq.from_pretrained(
pretrained_model_name_or_path=local_path,
Expand Down