From e3f12f54d84aa2c246029cec91c2cf386d0a4486 Mon Sep 17 00:00:00 2001 From: kwangneuraco Date: Tue, 3 Feb 2026 15:34:55 +0000 Subject: [PATCH] feat: add tfds dataformat import and combine with rlds, add bridge v2 dataset upload and all relevant tests --- neuracore-dictionary.txt | 5 +- neuracore/importer/config/bridge_v2.yaml | 78 +++ neuracore/importer/core/validation.py | 18 +- neuracore/importer/importer.py | 23 +- neuracore/importer/rlds_tfds_importer.py | 657 ++++++++++++++++++ tests/unit/importer/test_base_validation.py | 12 + .../test_rlds_tfds_importer_behavior.py | 214 ++++++ 7 files changed, 994 insertions(+), 13 deletions(-) create mode 100644 neuracore/importer/config/bridge_v2.yaml create mode 100644 neuracore/importer/rlds_tfds_importer.py create mode 100644 tests/unit/importer/test_rlds_tfds_importer_behavior.py diff --git a/neuracore-dictionary.txt b/neuracore-dictionary.txt index a4f11230..299ba5b1 100644 --- a/neuracore-dictionary.txt +++ b/neuracore-dictionary.txt @@ -251,8 +251,11 @@ nans posinf neginf rotvec +roundoff dataconfig pbtxt +tfrecord +widowx calcsize EADDRINUSE metas @@ -272,4 +275,4 @@ getpid WRONLY rels huggingface -itemsize \ No newline at end of file +itemsize diff --git a/neuracore/importer/config/bridge_v2.yaml b/neuracore/importer/config/bridge_v2.yaml new file mode 100644 index 00000000..2410c33a --- /dev/null +++ b/neuracore/importer/config/bridge_v2.yaml @@ -0,0 +1,78 @@ +input_dataset_name: bridge_v2_dataset_converted_externally_to_tfds + +output_dataset: + name: bridge_v2 + tags: [widowx, bridge_v2, manipulation] + description: "BridgeData V2 is a large and diverse dataset of robotic manipulation + behaviors designed to facilitate research in scalable robot learning. The dataset is + compatible with open-vocabulary, multi-task learning methods conditioned on goal + images or natural language instructions. Skills learned from the data generalize to + novel objects and environments, as well as across institutions. + +. 
Dataset source: https://rail-berkeley.github.io/bridgedata/" + +robot: + name: widowx + urdf_path: "neuracore/neuracore/importer/config/widowx.urdf" + override_existing: true + +frequency: 5.0 + +data_import_config: + RGB_IMAGES: + source: observation + image_convention: CHANNELS_LAST + order_of_channels: RGB + mapping: + - name: image_0 + source_name: image_0 + - name: image_1 + source_name: image_1 + - name: image_2 + source_name: image_2 + - name: image_3 + source_name: image_3 + + JOINT_POSITIONS: + source: observation.state + units: RADIANS + mapping: + - name: joint_1 + index: 0 + - name: joint_2 + index: 1 + - name: joint_3 + index: 2 + - name: joint_4 + index: 3 + - name: joint_5 + index: 4 + - name: gripper_revolute_joint + index: 5 + + + PARALLEL_GRIPPER_OPEN_AMOUNTS: + source: observation.state + inverted: false + mapping: + - name: gripper_open_amount + index: 6 + + VISUAL_JOINT_POSITIONS: + source: observation.state # path to the data in the input dataset + format: + visual_joint_type: GRIPPER + mapping: + - name: gripper_prismatic_joint_1 + index: 6 + inverted: false + offset: 0.0 + - name: gripper_prismatic_joint_2 + index: 6 + + LANGUAGE: + source: language_instruction + format: + language_type: BYTES + mapping: + - name: instruction diff --git a/neuracore/importer/core/validation.py b/neuracore/importer/core/validation.py index f9aebf52..8ab467d0 100644 --- a/neuracore/importer/core/validation.py +++ b/neuracore/importer/core/validation.py @@ -26,6 +26,10 @@ if DATA_TYPE_TO_NC_DATA_CLASS[dt] == JointData ] +# Small absolute tolerance for floating point comparisons against joint limits. +# This avoids false-positive warnings caused by float32/float64 roundoff. +JOINT_LIMIT_EPSILON = 1e-6 + def validate_rgb_images(data: Any, format: DataFormat) -> None: """Validate RGB image data. @@ -136,17 +140,13 @@ def validate_joint_positions( if name not in joint_info: raise DataValidationError(f"Joint {name} not found in robot model.") - if ( - joint_info[name].limits.lower is not None - and data < joint_info[name].limits.lower - ): + lower_limit = joint_info[name].limits.lower + if lower_limit is not None and data < lower_limit - JOINT_LIMIT_EPSILON: raise DataValidationWarning( - f"Position {data} is below the lower limit {joint_info[name].limits.lower}." + f"Position {data} is below the lower limit {lower_limit}." ) - if ( - joint_info[name].limits.upper is not None - and data > joint_info[name].limits.upper - ): + upper_limit = joint_info[name].limits.upper + if upper_limit is not None and data > upper_limit + JOINT_LIMIT_EPSILON: raise DataValidationWarning( f"Position {data} is above the upper limit {joint_info[name].limits.upper}." 
) diff --git a/neuracore/importer/importer.py b/neuracore/importer/importer.py index b5632209..14178b2a 100644 --- a/neuracore/importer/importer.py +++ b/neuracore/importer/importer.py @@ -31,7 +31,10 @@ validate_dataset_config_against_robot_model, ) from neuracore.importer.lerobot_importer import LeRobotDatasetImporter -from neuracore.importer.rlds_importer import RLDSDatasetImporter +from neuracore.importer.rlds_tfds_importer import ( + RLDSDatasetImporter, + TFDSDatasetImporter, +) logger = logging.getLogger(__name__) @@ -251,10 +254,22 @@ def _run_import( logger.info("Setup complete; beginning import.") - importer: RLDSDatasetImporter | LeRobotDatasetImporter | None = None skip_on_error = args.skip_on_error + importer: TFDSDatasetImporter | RLDSDatasetImporter | LeRobotDatasetImporter if dataset_type == DatasetTypeConfig.TFDS: - raise NotImplementedError("TFDS import not yet implemented.") + logger.info("Starting TFDS dataset import from %s", args.dataset_dir) + importer = TFDSDatasetImporter( + input_dataset_name=dataconfig.input_dataset_name, + output_dataset_name=dataconfig.output_dataset.name, + dataset_dir=args.dataset_dir, + dataset_config=dataconfig, + joint_info=robot.joint_info, + ik_urdf_path=ik_urdf_path, + ik_init_config=ik_init_config, + dry_run=args.dry_run, + suppress_warnings=args.no_validation_warnings, + ) + importer.import_all() elif dataset_type == DatasetTypeConfig.RLDS: logger.info("Starting RLDS dataset import from %s", args.dataset_dir) importer = RLDSDatasetImporter( @@ -285,6 +300,8 @@ def _run_import( skip_on_error=skip_on_error, ) importer.import_all() + else: + raise DatasetOperationError(f"Unsupported dataset type: {dataset_type}") logger.info("Finished importing dataset.") diff --git a/neuracore/importer/rlds_tfds_importer.py b/neuracore/importer/rlds_tfds_importer.py new file mode 100644 index 00000000..8ba2296e --- /dev/null +++ b/neuracore/importer/rlds_tfds_importer.py @@ -0,0 +1,657 @@ +"""Shared importer for RLDS/TFDS-style datasets.""" + +from __future__ import annotations + +import os +import time +import traceback +from collections.abc import Sequence +from pathlib import Path +from typing import Any + +import tensorflow_datasets as tfds +from neuracore_types import DataType +from neuracore_types.importer.config import LanguageConfig +from neuracore_types.nc_data import DatasetImportConfig + +import neuracore as nc +from neuracore.core.robot import JointInfo +from neuracore.importer.core.base import ( + ImportItem, + NeuracoreDatasetImporter, + WorkerError, +) +from neuracore.importer.core.exceptions import ImportError + +# Suppress TensorFlow informational messages (e.g., "End of sequence") +# 0 = all logs, 1 = no INFO, 2 = no WARNING, 3 = no ERROR +os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "1") + + +class RLDSAndTFDSDatasetImporterBase(NeuracoreDatasetImporter): + """Base class for RLDS/TFDS dataset importers.""" + + dataset_label = "DATASET" + allow_mapping_name_fallback = False + + def __init__( + self, + input_dataset_name: str, + output_dataset_name: str, + dataset_dir: Path, + dataset_config: DatasetImportConfig, + joint_info: dict[str, JointInfo] = {}, + ik_urdf_path: str | None = None, + ik_init_config: list[float] | None = None, + dry_run: bool = False, + suppress_warnings: bool = False, + max_workers: int | None = 1, + skip_on_error: str = "episode", + ): + """Initialize the RLDS/TFDS dataset importer. + + Args: + input_dataset_name: Name of the dataset to import. + output_dataset_name: Name of the dataset to create. 
+ dataset_dir: Directory containing the dataset. + dataset_config: Dataset configuration. + joint_info: Joint info to use for validation. + ik_urdf_path: URDF path for IK (used to recreate IK in worker processes). + ik_init_config: Initial joint configuration for IK. + dry_run: If True, skip actual logging (validation only). + suppress_warnings: If True, suppress warning messages. + max_workers: Maximum number of worker processes. + skip_on_error: "episode" to skip a failed episode; "step" to skip only + failing steps; "all" to abort on the first error. + """ + super().__init__( + dataset_dir=dataset_dir, + dataset_config=dataset_config, + output_dataset_name=output_dataset_name, + max_workers=max_workers, + joint_info=joint_info, + ik_urdf_path=ik_urdf_path, + ik_init_config=ik_init_config, + dry_run=dry_run, + suppress_warnings=suppress_warnings, + skip_on_error=skip_on_error, + ) + self.dataset_name = input_dataset_name + self.builder_dir = self._resolve_builder_dir() + if self.frequency is None: + raise ImportError( + f"Dataset frequency is required for {self.dataset_label} imports." + ) + + builder = self._load_builder() + self.split = self._pick_split(builder) + self.num_episodes = self._count_episodes(builder, self.split) + + self._builder: tfds.core.DatasetBuilder | None = None + self._episode_iter = None + + self.logger.info( + "Initialized %s importer for '%s' " + "(split=%s, episodes=%s, freq=%s, dir=%s)", + self.dataset_label, + self.dataset_name, + self.split, + self.num_episodes, + self.frequency, + self.builder_dir, + ) + + def __getstate__(self) -> dict: + """Drop worker-local handles when pickling for multiprocessing.""" + state = self.__dict__.copy() + state.pop("_builder", None) + state.pop("_episode_iter", None) + return state + + def build_work_items(self) -> Sequence[ImportItem]: + """Build work items for the dataset importer.""" + return [ + ImportItem(index=i, split=str(self.split)) for i in range(self.num_episodes) + ] + + def prepare_worker( + self, worker_id: int, chunk: Sequence[ImportItem] | None = None + ) -> None: + """Prepare the worker for the dataset importer.""" + super().prepare_worker(worker_id, chunk) + self._builder = self._load_builder() + chunk_start = chunk[0].index if chunk else 0 + chunk_length = len(chunk) if chunk else None + + self.logger.info( + "[worker %s] Loading split=%s (start=%s count=%s) from %s", + worker_id, + self.split, + chunk_start, + chunk_length if chunk_length is not None else "remainder", + self.builder_dir, + ) + + dataset = self._load_dataset(self._builder, self.split) + if chunk_start: + dataset = dataset.skip(chunk_start) + if chunk_length is not None: + dataset = dataset.take(chunk_length) + self._episode_iter = iter(dataset) + + def import_item(self, item: ImportItem) -> None: + """Import a single episode to the dataset importer.""" + if self._episode_iter is None: + raise ImportError("Worker dataset iterator was not initialized.") + + try: + episode = next(self._episode_iter) + except Exception as exc: # delegate to subclass + self._handle_episode_load_error(exc, item) + + steps = episode["steps"] + if self.frequency is None: + raise ImportError("Frequency is required for importing episodes.") + total_steps = self._infer_total_steps(steps) + base_time = time.time() + nc.start_recording() + episode_label = ( + f"{item.split or 'episode'} #{item.index}" + if item.split is not None + else str(item.index) + ) + worker_label = ( + f"worker {self._worker_id}" if self._worker_id is not None else "worker 0" + ) + 
self.logger.info( + "[%s] Importing %s (%s/%s, steps=%s)", + worker_label, + episode_label, + item.index + 1, + self.num_episodes, + total_steps if total_steps is not None else "unknown", + ) + self._emit_progress( + item.index, step=0, total_steps=total_steps, episode_label=episode_label + ) + for idx, step in enumerate(steps, start=1): + timestamp = base_time + (idx / self.frequency) + try: + self._record_step(step, timestamp) + except Exception as exc: # importer-specific policy hook + if self._handle_step_error(exc, item, idx): + continue + raise + self._emit_progress( + item.index, + step=idx, + total_steps=total_steps, + episode_label=episode_label, + ) + nc.stop_recording(wait=True) + self.logger.info("[%s] Completed %s", worker_label, episode_label) + + def _handle_step_error( + self, exc: Exception, item: ImportItem, step_index: int + ) -> bool: + """Skip failing steps when configured with skip_on_error='step'. + + Returns: + True if the error is handled and import should continue with next step. + False to re-raise and fail the current item. + """ + if getattr(self, "skip_on_error", "episode") != "step": + return False + + worker_id_attr = getattr(self, "_worker_id", None) + worker_id = worker_id_attr if worker_id_attr is not None else 0 + + error_queue = getattr(self, "_error_queue", None) + if error_queue is not None: + error_queue.put( + WorkerError( + worker_id=worker_id, + item_index=item.index, + message=f"Step {step_index}: {exc}", + traceback=traceback.format_exc(), + ) + ) + self._log_worker_error(worker_id, item.index, f"Step {step_index}: {exc}") + return True + + def _handle_episode_load_error(self, exc: Exception, item: ImportItem) -> None: + """Map dataset iteration exceptions into importer-level errors.""" + if isinstance(exc, StopIteration): + raise ImportError( + f"No episode available for index {item.index} " + f"(dataset has {self.num_episodes} episodes)." + ) from exc + raise exc + + def _resolve_builder_dir(self) -> Path: + """Find the dataset version directory that contains dataset_info.json.""" + if (self.dataset_dir / "dataset_info.json").exists(): + return self.dataset_dir + + version_dirs = [ + path + for path in self.dataset_dir.iterdir() + if path.is_dir() and (path / "dataset_info.json").exists() + ] + if version_dirs: + return sorted(version_dirs)[-1] + + name_dir = self.dataset_dir / self.dataset_name + if (name_dir / "dataset_info.json").exists(): + return name_dir + nested_versions = ( + [ + path + for path in name_dir.iterdir() + if path.is_dir() and (path / "dataset_info.json").exists() + ] + if name_dir.exists() + else [] + ) + if nested_versions: + return sorted(nested_versions)[-1] + + raise ImportError( + f"Could not find dataset_info.json under {self.dataset_dir}. " + "Pass either the dataset version directory or its parent." 
+ ) + + def _load_builder(self) -> tfds.core.DatasetBuilder: + """Load a TFDS builder directly from the local dataset directory.""" + self.logger.info( + "Loading %s builder from %s", self.dataset_label, self.builder_dir + ) + try: + builder = tfds.builder_from_directory(str(self.builder_dir)) + self._on_builder_loaded(builder) + return builder + except Exception as exc: + raise ImportError( + f"Failed to load {self.dataset_label} builder from " + f"'{self.builder_dir}': {exc}" + ) from exc + + def _on_builder_loaded(self, builder: tfds.core.DatasetBuilder) -> None: + """Hook for subclass-specific builder checks.""" + return None + + def _pick_split(self, builder: tfds.core.DatasetBuilder) -> tfds.typing.SplitArg: + """Select split to inspect; default to all splits.""" + splits = list(builder.info.splits.keys()) + if not splits: + raise ImportError( + f"No splits found in {self.dataset_label} dataset at " + f"'{self.builder_dir}'." + ) + return tfds.Split.ALL + + def _count_episodes( + self, builder: tfds.core.DatasetBuilder, split: tfds.typing.SplitArg + ) -> int: + """Count the number of episodes in the chosen split.""" + if split == tfds.Split.ALL or str(split).lower() == "all": + return int(builder.info.splits.total_num_examples) + try: + split_info = builder.info.splits[split] + except KeyError: + split_info = builder.info.splits[str(split)] + return int(split_info.num_examples) + + def _build_read_config(self) -> tfds.ReadConfig: + """Build read config for normal dataset loading.""" + return tfds.ReadConfig(try_autocache=False) + + def _build_retry_read_config(self) -> tfds.ReadConfig | None: + """Build read config for retry after missing-file errors.""" + return None + + @staticmethod + def _is_missing_file_error(exc: Exception) -> bool: + error_msg = str(exc).lower() + return "no such file or directory" in error_msg or "not found" in error_msg + + def _load_dataset( + self, builder: tfds.core.DatasetBuilder, split: tfds.typing.SplitArg + ) -> tfds.core.dataset_builder.DatasetBuilder: + """Load the TFDS dataset from the local builder.""" + self.logger.info("Opening dataset split '%s' for import.", split) + try: + return builder.as_dataset( + split=split, + shuffle_files=False, + read_config=self._build_read_config(), + ) + except Exception as exc: + retry_config = self._build_retry_read_config() + if self._is_missing_file_error(exc) and retry_config is not None: + self.logger.warning( + "Some dataset shard files appear to be missing. " + "This may indicate an incomplete dataset. " + "Attempting to continue with available shards. Error: %s", + exc, + ) + try: + return builder.as_dataset( + split=split, + shuffle_files=False, + read_config=retry_config, + ) + except Exception as retry_exc: + raise ImportError( + f"Failed to load {self.dataset_label} dataset split " + f"'{split}' even with lenient configuration. " + f"Original error: {exc}. Retry error: {retry_exc}. " + "Please ensure all dataset shard files are present." 
+ ) from retry_exc + raise ImportError( + f"Failed to load {self.dataset_label} dataset split '{split}': {exc}" + ) from exc + + def _infer_total_steps(self, steps: Any) -> int | None: + """Best-effort step count extraction without materializing the dataset.""" + try: + if not isinstance(steps, dict): + length = len(steps) + if isinstance(length, int): + return length + except Exception: + pass + + for attr in ("shape", "shapes"): + try: + shape = getattr(steps, attr) + first_dim = shape[0] if shape else None + if isinstance(first_dim, int): + return first_dim + except Exception: + continue + + if isinstance(steps, dict): + for value in steps.values(): + try: + length = len(value) + except Exception: + continue + if isinstance(length, int): + return length + return None + + def _record_step(self, step_data: dict, timestamp: float) -> None: + """Record a single step to Neuracore.""" + import tensorflow as tf + + for data_type, import_config in self.dataset_config.data_import_config.items(): + source = step_data + for path in import_config.source.split("."): + source = source[path] + + for item in import_config.mapping: + try: + source_data = self._extract_source_data( + source=source, + item=item, + import_source_path=import_config.source, + data_type=data_type, + ) + + source_data = self._convert_source_data( + source_data=source_data, + data_type=data_type, + language_type=import_config.format.language_type, + item_name=item.name, + import_source_path=import_config.source, + tf_module=tf, + ) + except (KeyError, IndexError, TypeError) as exc: + raise ImportError( + f"Failed to extract data for {data_type.value}." + f"{item.name if hasattr(item, 'name') else 'unknown'} " + f"from source '{import_config.source}': {exc}" + ) from exc + + self._log_data( + data_type, source_data, item, import_config.format, timestamp + ) + + def _extract_source_data( + self, + source: Any, + item: Any, + import_source_path: str, + data_type: DataType, + ) -> Any: + if item.source_name is not None: + source_data = source[item.source_name] + elif ( + self.allow_mapping_name_fallback + and isinstance(source, dict) + and item.name in source + ): + source_data = source[item.name] + else: + source_data = source + + if item.index is not None: + if isinstance(source_data, dict): + raise ImportError( + "Cannot index dict with integer index for " + f"{data_type.value}." + f"{item.name if hasattr(item, 'name') else 'unknown'}. " + f"Source path '{import_source_path}' resolved to a dict, " + "not a tensor. Check your dataset config." + ) + source_data = source_data[item.index] + elif item.index_range is not None: + if isinstance(source_data, dict): + raise ImportError( + "Cannot slice dict with index_range for " + f"{data_type.value}." + f"{item.name if hasattr(item, 'name') else 'unknown'}. " + f"Source path '{import_source_path}' resolved to a dict, " + "not a tensor." 
+ ) + source_data = source_data[item.index_range.start : item.index_range.end] + + return source_data + + def _convert_source_data( + self, + source_data: Any, + data_type: DataType, + language_type: LanguageConfig, + item_name: str | None, + import_source_path: str, + tf_module: Any, + ) -> Any: + if data_type == DataType.LANGUAGE and language_type == LanguageConfig.STRING: + return source_data + + if isinstance(source_data, tf_module.Tensor) or ( + hasattr(source_data, "numpy") + and not isinstance(source_data, (dict, list, str, bytes)) + ): + try: + return source_data.numpy() + except Exception as exc: + self.logger.warning( + "Failed to convert tensor to numpy for %s.%s: %s. Using raw value.", + data_type.value, + item_name if item_name else "unknown", + exc, + ) + return source_data + + if isinstance(source_data, dict): + raise ImportError( + "Expected tensor but got dict for " + f"{data_type.value}.{item_name if item_name else 'unknown'}. " + f"Source path '{import_source_path}' may be incorrect." + ) + + if isinstance(source_data, list) and source_data: + if isinstance(source_data[0], tf_module.Tensor) or hasattr( + source_data[0], "numpy" + ): + try: + return [ + value.numpy() if hasattr(value, "numpy") else value + for value in source_data + ] + except Exception: + return source_data + + return source_data + + +class RLDSDatasetImporter(RLDSAndTFDSDatasetImporterBase): + """Importer for RLDS datasets.""" + + dataset_label = "RLDS" + + def __init__( + self, + input_dataset_name: str, + output_dataset_name: str, + dataset_dir: Path, + dataset_config: DatasetImportConfig, + joint_info: dict[str, JointInfo] = {}, + ik_urdf_path: str | None = None, + ik_init_config: list[float] | None = None, + dry_run: bool = False, + suppress_warnings: bool = False, + skip_on_error: str = "episode", + ): + """Initialize the RLDS/TFDS dataset importer. + + Args: + input_dataset_name: Name of the dataset to import. + output_dataset_name: Name of the dataset to create. + dataset_dir: Directory containing the dataset. + dataset_config: Dataset configuration. + joint_info: Joint info to use for validation. + ik_urdf_path: URDF path for IK (used to recreate IK in worker processes). + ik_init_config: Initial joint configuration for IK. + dry_run: If True, skip actual logging (validation only). + suppress_warnings: If True, suppress warning messages. + skip_on_error: "episode" to skip a failed episode; "step" to skip only + failing steps; "all" to abort on the first error. + """ + super().__init__( + input_dataset_name=input_dataset_name, + output_dataset_name=output_dataset_name, + dataset_dir=dataset_dir, + dataset_config=dataset_config, + joint_info=joint_info, + ik_urdf_path=ik_urdf_path, + ik_init_config=ik_init_config, + dry_run=dry_run, + suppress_warnings=suppress_warnings, + max_workers=1, + skip_on_error=skip_on_error, + ) + + def _record_step(self, step_data: dict, timestamp: float) -> None: + """Record a single step to Neuracore. + + RLDS configs may omit source paths, so we preserve that behavior here. 
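+        Non-language values are converted to numpy via .numpy() before logging; only
+        LANGUAGE data with a STRING language_type is passed through as-is.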
+ """ + for data_type, import_config in self.dataset_config.data_import_config.items(): + if not import_config.source: + source: Any = step_data + else: + source_path = import_config.source.split(".") + source = step_data + for path in source_path: + source = source[path] + + for item in import_config.mapping: + source_data = ( + source[item.source_name] if item.source_name is not None else source + ) + + if item.index is not None: + source_data = source_data[item.index] + elif item.index_range is not None: + source_data = source_data[ + item.index_range.start : item.index_range.end + ] + + if not ( + data_type == DataType.LANGUAGE + and import_config.format.language_type == LanguageConfig.STRING + ): + source_data = source_data.numpy() + + self._log_data( + data_type, source_data, item, import_config.format, timestamp + ) + + +class TFDSDatasetImporter(RLDSAndTFDSDatasetImporterBase): + """Importer for TFDS (TensorFlow Datasets) datasets.""" + + dataset_label = "TFDS" + allow_mapping_name_fallback = True + + def _on_builder_loaded(self, builder: tfds.core.DatasetBuilder) -> None: + """Check for missing shards after builder load.""" + self._check_missing_shards() + + def _build_read_config(self) -> tfds.ReadConfig: + """Build read config for TFDS datasets.""" + return tfds.ReadConfig( + try_autocache=False, + skip_prefetch=True, + ) + + def _build_retry_read_config(self) -> tfds.ReadConfig | None: + """Build lenient retry read config for TFDS shard issues.""" + return tfds.ReadConfig( + try_autocache=False, + skip_prefetch=True, + interleave_cycle_length=1, + ) + + def _handle_episode_load_error(self, exc: Exception, item: ImportItem) -> None: + """Treat missing-file episode failures as skippable import errors.""" + if isinstance(exc, StopIteration): + super()._handle_episode_load_error(exc, item) + return + if self._is_missing_file_error(exc): + self.logger.warning( + "[worker %s item %s] Skipping episode due to missing shard file: %s", + self._worker_id if self._worker_id is not None else 0, + item.index, + exc, + ) + raise ImportError( + f"Episode {item.index} cannot be loaded due to missing shard " + f"file: {exc}" + ) from exc + raise exc + + def _check_missing_shards(self) -> None: + """Check for TFRecord shards and warn when none are found.""" + try: + tfrecord_files = list(self.builder_dir.glob("*.tfrecord*")) + if not tfrecord_files: + for subdir in self.builder_dir.iterdir(): + if subdir.is_dir(): + tfrecord_files.extend(subdir.glob("*.tfrecord*")) + if tfrecord_files: + self.logger.info( + "Found %d TFRecord shard files in dataset directory", + len(tfrecord_files), + ) + else: + self.logger.warning( + "No TFRecord files found in dataset directory. " + "Dataset may be incomplete or use a different format." 
+ ) + except Exception as exc: # noqa: BLE001 - informational check only + self.logger.debug("Could not check for missing shards: %s", exc) diff --git a/tests/unit/importer/test_base_validation.py b/tests/unit/importer/test_base_validation.py index 26888904..bfaa2f0f 100644 --- a/tests/unit/importer/test_base_validation.py +++ b/tests/unit/importer/test_base_validation.py @@ -293,6 +293,18 @@ def test_validate_joint_positions_above_limit(self, importer): with pytest.raises(DataValidationWarning): importer._validate_joint_data(DataType.JOINT_POSITIONS, data, "joint1") + def test_validate_joint_positions_within_upper_tolerance(self, importer): + """Do not warn for tiny floating-point overshoot near upper limit.""" + data = 1.0000005 # Within epsilon above upper limit 1.0 + + importer._validate_joint_data(DataType.JOINT_POSITIONS, data, "joint1") + + def test_validate_joint_positions_within_lower_tolerance(self, importer): + """Do not warn for tiny floating-point undershoot near lower limit.""" + data = -1.0000005 # Within epsilon below lower limit -1.0 + + importer._validate_joint_data(DataType.JOINT_POSITIONS, data, "joint1") + def test_validate_joint_positions_joint_not_found(self, importer): """Test joint position validation with joint not in joint_info.""" data = 0.5 diff --git a/tests/unit/importer/test_rlds_tfds_importer_behavior.py b/tests/unit/importer/test_rlds_tfds_importer_behavior.py new file mode 100644 index 00000000..ffb8b21f --- /dev/null +++ b/tests/unit/importer/test_rlds_tfds_importer_behavior.py @@ -0,0 +1,214 @@ +"""Tests for shared RLDS/TFDS importer behavior and RLDS overrides.""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from neuracore_types import DataType +from neuracore_types.importer.config import LanguageConfig + +from neuracore.importer.core.base import ImportItem, WorkerError +from neuracore.importer.core.exceptions import ImportError +from neuracore.importer.rlds_tfds_importer import ( + RLDSAndTFDSDatasetImporterBase, + RLDSDatasetImporter, + TFDSDatasetImporter, +) + + +class _FakeTensor: + """Simple test tensor-like object exposing a numpy() method.""" + + def __init__(self, value): + self._value = value + + def numpy(self): + return self._value + + +def test_base_handle_step_error_default_returns_false(): + """Base importer should re-raise step errors unless subclass handles them.""" + importer = object.__new__(RLDSAndTFDSDatasetImporterBase) + item = ImportItem(index=1) + + assert importer._handle_step_error(RuntimeError("boom"), item, 2) is False + + +def test_rlds_handle_step_error_step_mode_enqueues_and_logs(): + """RLDS importer should skip step failures when configured with step mode.""" + importer = object.__new__(RLDSDatasetImporter) + importer.skip_on_error = "step" + importer._worker_id = 3 + importer._error_queue = MagicMock() + importer._log_worker_error = MagicMock() + + try: + raise ValueError("bad step") + except ValueError as exc: + handled = importer._handle_step_error(exc, ImportItem(index=7), 4) + + assert handled is True + importer._error_queue.put.assert_called_once() + queued_error = importer._error_queue.put.call_args.args[0] + assert isinstance(queued_error, WorkerError) + assert queued_error.worker_id == 3 + assert queued_error.item_index == 7 + assert queued_error.message == "Step 4: bad step" + assert queued_error.traceback is not None + assert "ValueError: bad step" in queued_error.traceback + importer._log_worker_error.assert_called_once_with(3, 7, "Step 4: bad step") + + +def 
test_rlds_handle_step_error_non_step_mode_returns_false(): + """RLDS importer should not handle step errors unless step mode is enabled.""" + importer = object.__new__(RLDSDatasetImporter) + importer.skip_on_error = "episode" + importer._worker_id = 1 + importer._error_queue = MagicMock() + importer._log_worker_error = MagicMock() + + try: + raise RuntimeError("boom") + except RuntimeError as exc: + handled = importer._handle_step_error(exc, ImportItem(index=2), 1) + + assert handled is False + importer._error_queue.put.assert_not_called() + importer._log_worker_error.assert_not_called() + + +def test_rlds_record_step_supports_empty_source_path_for_language(): + """RLDS _record_step should allow empty source path and string language values.""" + importer = object.__new__(RLDSDatasetImporter) + mapping_item = SimpleNamespace( + source_name="instruction", + index=None, + index_range=None, + name="instruction", + ) + import_format = SimpleNamespace(language_type=LanguageConfig.STRING) + import_config = SimpleNamespace( + source="", + mapping=[mapping_item], + format=import_format, + ) + importer.dataset_config = SimpleNamespace( + data_import_config={DataType.LANGUAGE: import_config} + ) + importer._log_data = MagicMock() + + importer._record_step({"instruction": "pick up block"}, timestamp=12.5) + + importer._log_data.assert_called_once_with( + DataType.LANGUAGE, + "pick up block", + mapping_item, + import_format, + 12.5, + ) + + +def test_rlds_record_step_converts_tensor_to_numpy_for_non_language(): + """RLDS _record_step should call numpy() for non-language data.""" + importer = object.__new__(RLDSDatasetImporter) + mapping_item = SimpleNamespace( + source_name="joint_positions", + index=None, + index_range=None, + name="joint_positions", + ) + import_format = SimpleNamespace(language_type=LanguageConfig.STRING) + import_config = SimpleNamespace( + source="", + mapping=[mapping_item], + format=import_format, + ) + importer.dataset_config = SimpleNamespace( + data_import_config={DataType.JOINT_POSITIONS: import_config} + ) + importer._log_data = MagicMock() + + importer._record_step({"joint_positions": _FakeTensor([1.0, 2.0])}, timestamp=3.0) + + importer._log_data.assert_called_once_with( + DataType.JOINT_POSITIONS, + [1.0, 2.0], + mapping_item, + import_format, + 3.0, + ) + + +def test_rlds_init_forwards_ik_args_to_base(monkeypatch): + """RLDS importer should forward IK initialization args to base class.""" + captured = {} + + def fake_base_init(self, **kwargs): + captured.update(kwargs) + + monkeypatch.setattr(RLDSAndTFDSDatasetImporterBase, "__init__", fake_base_init) + + RLDSDatasetImporter( + input_dataset_name="in", + output_dataset_name="out", + dataset_dir=SimpleNamespace(), + dataset_config=SimpleNamespace(), + joint_info={}, + ik_urdf_path="/tmp/robot.urdf", + ik_init_config=[0.0, 1.0], + dry_run=True, + suppress_warnings=True, + skip_on_error="step", + ) + + assert captured["ik_urdf_path"] == "/tmp/robot.urdf" + assert captured["ik_init_config"] == [0.0, 1.0] + assert captured["skip_on_error"] == "step" + + +def test_tfds_extract_source_data_uses_mapping_name_fallback(): + """TFDS importer should fallback to mapping name when source_name is missing.""" + importer = object.__new__(TFDSDatasetImporter) + item = SimpleNamespace( + source_name=None, + index=None, + index_range=None, + name="joint_positions", + ) + + source_data = importer._extract_source_data( + source={"joint_positions": [0.1, 0.2]}, + item=item, + import_source_path="steps.observation", + 
data_type=DataType.JOINT_POSITIONS, + ) + + assert source_data == [0.1, 0.2] + + +def test_extract_then_convert_dict_raises_import_error_without_nested_search(): + """Dict outputs are now rejected directly instead of nested tensor search.""" + importer = object.__new__(RLDSAndTFDSDatasetImporterBase) + item = SimpleNamespace( + source_name="observation", + index=None, + index_range=None, + name="ee_pose", + ) + extracted = importer._extract_source_data( + source={"observation": {"state": _FakeTensor([1.0, 2.0])}}, + item=item, + import_source_path="steps.observation", + data_type=DataType.POSES, + ) + tf_module = SimpleNamespace(Tensor=_FakeTensor) + + with pytest.raises(ImportError, match="Expected tensor but got dict"): + importer._convert_source_data( + source_data=extracted, + data_type=DataType.POSES, + language_type=LanguageConfig.STRING, + item_name="ee_pose", + import_source_path="steps.observation", + tf_module=tf_module, + )
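
Usage note (illustrative, not part of the patch): a minimal sketch of driving the new TFDS path directly, mirroring the branch added to importer.py above. It assumes a DatasetImportConfig has already been loaded (for example from bridge_v2.yaml) and that joint_info comes from an existing robot model; the dataset path below is hypothetical.

    from pathlib import Path

    from neuracore.importer.rlds_tfds_importer import TFDSDatasetImporter

    # dataconfig (DatasetImportConfig) and joint_info (dict[str, JointInfo]) are assumed
    # to be available, as in _run_import above.
    importer = TFDSDatasetImporter(
        input_dataset_name=dataconfig.input_dataset_name,
        output_dataset_name=dataconfig.output_dataset.name,
        dataset_dir=Path("/data/bridge_v2_dataset_converted_externally_to_tfds"),  # hypothetical path
        dataset_config=dataconfig,
        joint_info=joint_info,
        dry_run=True,  # validate the mapping without logging to Neuracore
        suppress_warnings=False,
    )
    importer.import_all()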