diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index cc6b7195fe2..58189f1fd29 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -90,6 +90,8 @@
         title: Create a document dataset
       - local: nifti_dataset
         title: Create a medical imaging dataset
+      - local: bids_dataset
+        title: Load a BIDS dataset
     title: "Vision"
   - sections:
       - local: nlp_load
diff --git a/docs/source/bids_dataset.mdx b/docs/source/bids_dataset.mdx
new file mode 100644
index 00000000000..89ca31e566e
--- /dev/null
+++ b/docs/source/bids_dataset.mdx
@@ -0,0 +1,63 @@
+# BIDS Dataset
+
+[BIDS (Brain Imaging Data Structure)](https://bids.neuroimaging.io/) is a standard for organizing and describing neuroimaging and behavioral data. The `datasets` library supports loading BIDS datasets directly, leveraging `pybids` for parsing and `nibabel` for handling NIfTI files.
+
+<Tip>
+
+To use the BIDS loader, you need to install the `bids` extra (which installs `pybids` and `nibabel`):
+
+```bash
+pip install datasets[bids]
+```
+
+</Tip>
+
+## Loading a BIDS Dataset
+
+You can load a BIDS dataset by pointing to its root directory (containing `dataset_description.json`):
+
+```python
+from datasets import load_dataset
+
+# Load a local BIDS dataset
+ds = load_dataset("bids", data_dir="/path/to/bids/dataset")
+
+# Access the first example
+print(ds["train"][0])
+# {
+#     'subject': '01',
+#     'session': 'baseline',
+#     'datatype': 'anat',
+#     'suffix': 'T1w',
+#     'nifti': <nibabel.nifti1.Nifti1Image object>,
+#     ...
+# }
+```
+
+The `nifti` column contains `nibabel` image objects, which can be visualized interactively in Jupyter notebooks.
+
+## Filtering
+
+You can filter the dataset by BIDS entities like `subject`, `session`, and `datatype` when loading:
+
+```python
+# Load only specific subjects, sessions, and datatypes
+ds = load_dataset(
+    "bids",
+    data_dir="/path/to/bids/dataset",
+    subjects=["01", "05", "10"],
+    sessions=["pre", "post"],
+    datatypes=["func"],
+)
+```
+
+## Metadata
+
+BIDS datasets often include JSON sidecar files with metadata (e.g., scanner parameters). This metadata is loaded into the `metadata` column as a JSON string.
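+You can decode it with `json.loads`: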
+
+```python
+import json
+
+metadata = json.loads(ds["train"][0]["metadata"])
+print(metadata["RepetitionTime"])
+```
diff --git a/setup.py b/setup.py
index 30d66fc54db..42dd5c101b8 100644
--- a/setup.py
+++ b/setup.py
@@ -210,6 +210,8 @@
 NIBABEL_REQUIRE = ["nibabel>=5.3.2", "ipyniivue==2.4.2"]
 
+PYBIDS_REQUIRE = ["pybids>=0.21.0"] + NIBABEL_REQUIRE
+
 EXTRAS_REQUIRE = {
     "audio": AUDIO_REQUIRE,
     "vision": VISION_REQUIRE,
@@ -228,6 +230,7 @@
     "docs": DOCS_REQUIRE,
     "pdfs": PDFS_REQUIRE,
     "nibabel": NIBABEL_REQUIRE,
+    "bids": PYBIDS_REQUIRE,
 }
 
 setup(
diff --git a/src/datasets/config.py b/src/datasets/config.py
index b6412682727..2df571e4b8f 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -140,6 +140,7 @@
 TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
 PDFPLUMBER_AVAILABLE = importlib.util.find_spec("pdfplumber") is not None
 NIBABEL_AVAILABLE = importlib.util.find_spec("nibabel") is not None
+PYBIDS_AVAILABLE = importlib.util.find_spec("bids") is not None
 
 # Optional compression tools
 RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None
diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py
index c9a32ff71f0..9655dffdc10 100644
--- a/src/datasets/packaged_modules/__init__.py
+++ b/src/datasets/packaged_modules/__init__.py
@@ -6,6 +6,7 @@
 
 from .arrow import arrow
 from .audiofolder import audiofolder
+from .bids import bids
 from .cache import cache
 from .csv import csv
 from .eval import eval
@@ -49,6 +50,7 @@ def _hash_python_lines(lines: list[str]) -> str:
     "videofolder": (videofolder.__name__, _hash_python_lines(inspect.getsource(videofolder).splitlines())),
     "pdffolder": (pdffolder.__name__, _hash_python_lines(inspect.getsource(pdffolder).splitlines())),
     "niftifolder": (niftifolder.__name__, _hash_python_lines(inspect.getsource(niftifolder).splitlines())),
+    "bids": (bids.__name__, _hash_python_lines(inspect.getsource(bids).splitlines())),
     "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
     "xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())),
     "hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())),
diff --git a/src/datasets/packaged_modules/bids/__init__.py b/src/datasets/packaged_modules/bids/__init__.py
new file mode 100644
index 00000000000..1d167b51030
--- /dev/null
+++ b/src/datasets/packaged_modules/bids/__init__.py
@@ -0,0 +1 @@
+from .bids import Bids, BidsConfig
diff --git a/src/datasets/packaged_modules/bids/bids.py b/src/datasets/packaged_modules/bids/bids.py
new file mode 100644
index 00000000000..d165218de4d
--- /dev/null
+++ b/src/datasets/packaged_modules/bids/bids.py
@@ -0,0 +1,116 @@
+import json
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+import datasets
+from datasets import config
+
+
+logger = datasets.utils.logging.get_logger(__name__)
+
+
+@dataclass
+class BidsConfig(datasets.BuilderConfig):
+    """BuilderConfig for BIDS datasets."""
+
+    data_dir: Optional[str] = None
+    database_path: Optional[str] = None  # For pybids caching
+    subjects: Optional[list[str]] = None  # Filter by subject
+    sessions: Optional[list[str]] = None  # Filter by session
+    datatypes: Optional[list[str]] = None  # Filter by datatype
+
+
+class Bids(datasets.GeneratorBasedBuilder):
+    """BIDS dataset loader using pybids."""
+
+    BUILDER_CONFIG_CLASS = BidsConfig
+
+    def _info(self):
+        if not config.PYBIDS_AVAILABLE:
+            raise ImportError("To load BIDS datasets, please install pybids: pip install pybids")
+        if not config.NIBABEL_AVAILABLE:
+            raise ImportError("To load BIDS datasets, please install nibabel: pip install nibabel")
+
+        return datasets.DatasetInfo(
+            features=datasets.Features(
+                {
+                    "subject": datasets.Value("string"),
+                    "session": datasets.Value("string"),
+                    "datatype": datasets.Value("string"),
+                    "suffix": datasets.Value("string"),
+                    "task": datasets.Value("string"),
+                    "run": datasets.Value("string"),
+                    "path": datasets.Value("string"),
+                    "nifti": datasets.Nifti(),
+                    "metadata": datasets.Value("string"),
+                }
+            )
+        )
+
+    def _split_generators(self, dl_manager):
+        from bids import BIDSLayout
+
+        if not self.config.data_dir:
+            raise ValueError("data_dir is required for BIDS datasets")
+
+        if not os.path.isdir(self.config.data_dir):
+            raise ValueError(f"data_dir does not exist: {self.config.data_dir}")
+
+        desc_file = os.path.join(self.config.data_dir, "dataset_description.json")
+        if not os.path.exists(desc_file):
+            raise ValueError(f"Not a valid BIDS dataset: missing dataset_description.json in {self.config.data_dir}")
+
+        layout = BIDSLayout(
+            self.config.data_dir,
+            database_path=self.config.database_path,
+            validate=False,  # Don't fail on minor validation issues
+        )
+
+        # Build query kwargs
+        query = {"extension": [".nii", ".nii.gz"]}
+        if self.config.subjects:
+            query["subject"] = self.config.subjects
+        if self.config.sessions:
+            query["session"] = self.config.sessions
+        if self.config.datatypes:
+            query["datatype"] = self.config.datatypes
+
+        # Get all NIfTI files
+        nifti_files = layout.get(**query)
+
+        if not nifti_files:
+            logger.warning(
+                f"No NIfTI files found in {self.config.data_dir} with filters: {query}. "
+                "Check that the dataset is valid BIDS and filters match existing data."
+            )
+
+        return [
+            datasets.SplitGenerator(
+                name=datasets.Split.TRAIN,
+                gen_kwargs={"layout": layout, "files": nifti_files},
+            )
+        ]
+
+    def _generate_examples(self, layout, files):
+        for idx, bids_file in enumerate(files):
+            entities = bids_file.get_entities()
+
+            # Get JSON sidecar metadata
+            metadata = layout.get_metadata(bids_file.path)
+            metadata_str = json.dumps(metadata) if metadata else "{}"
+
+            yield (
+                idx,
+                {
+                    "subject": entities.get("subject"),
+                    "session": entities.get("session"),
+                    "datatype": entities.get("datatype"),
+                    "suffix": entities.get("suffix"),
+                    "task": entities.get("task"),
+                    "run": str(entities.get("run")) if entities.get("run") else None,
+                    "path": bids_file.path,
+                    "nifti": bids_file.path,
+                    "metadata": metadata_str,
+                },
+            )
diff --git a/tests/packaged_modules/test_bids.py b/tests/packaged_modules/test_bids.py
new file mode 100644
index 00000000000..8ce2be9b72b
--- /dev/null
+++ b/tests/packaged_modules/test_bids.py
@@ -0,0 +1,120 @@
+import json
+
+import numpy as np
+import pytest
+
+import datasets.config
+
+
+@pytest.fixture
+def minimal_bids_dataset(tmp_path):
+    """Minimal valid BIDS dataset with one subject, one T1w scan."""
+    # dataset_description.json (required)
+    (tmp_path / "dataset_description.json").write_text(
+        json.dumps({"Name": "Test BIDS Dataset", "BIDSVersion": "1.10.1"})
+    )
+
+    # Create subject/anat folder
+    anat_dir = tmp_path / "sub-01" / "anat"
+    anat_dir.mkdir(parents=True)
+
+    # Create dummy NIfTI
+    if datasets.config.NIBABEL_AVAILABLE:
+        import nibabel as nib
+
+        data = np.zeros((4, 4, 4), dtype=np.float32)
+        img = nib.Nifti1Image(data, np.eye(4))
+        nib.save(img, str(anat_dir / "sub-01_T1w.nii.gz"))
+    else:
+        # Fallback if nibabel is not available (should not happen in the test environment)
+        (anat_dir / "sub-01_T1w.nii.gz").write_bytes(b"DUMMY NIFTI CONTENT")
"sub-01_T1w.nii.gz").write_bytes(b"DUMMY NIFTI CONTENT") + + # JSON sidecar + (anat_dir / "sub-01_T1w.json").write_text(json.dumps({"RepetitionTime": 2.0})) + + return str(tmp_path) + + +@pytest.fixture +def multi_subject_bids(tmp_path): + """BIDS dataset with multiple subjects and sessions.""" + (tmp_path / "dataset_description.json").write_text( + json.dumps({"Name": "Multi-Subject Test", "BIDSVersion": "1.10.1"}) + ) + + data = np.zeros((4, 4, 4), dtype=np.float32) + + if datasets.config.NIBABEL_AVAILABLE: + import nibabel as nib + else: + nib = None + + for sub in ["01", "02"]: + for ses in ["baseline", "followup"]: + anat_dir = tmp_path / f"sub-{sub}" / f"ses-{ses}" / "anat" + anat_dir.mkdir(parents=True) + + file_path = anat_dir / f"sub-{sub}_ses-{ses}_T1w.nii.gz" + if nib: + img = nib.Nifti1Image(data, np.eye(4)) + nib.save(img, str(file_path)) + else: + file_path.write_bytes(b"DUMMY NIFTI CONTENT") + + (anat_dir / f"sub-{sub}_ses-{ses}_T1w.json").write_text(json.dumps({"RepetitionTime": 2.0})) + + return str(tmp_path) + + +def test_bids_module_imports(): + from datasets.packaged_modules.bids import Bids, BidsConfig + + assert Bids is not None + assert BidsConfig is not None + + +def test_bids_requires_pybids(monkeypatch): + """Test helpful error when pybids not installed.""" + from datasets.packaged_modules.bids.bids import Bids + + monkeypatch.setattr(datasets.config, "PYBIDS_AVAILABLE", False) + + with pytest.raises(ImportError, match="pybids"): + Bids() + + +@pytest.mark.skipif( + not datasets.config.PYBIDS_AVAILABLE or not datasets.config.NIBABEL_AVAILABLE, + reason="pybids or nibabel not installed", +) +def test_bids_loads_single_subject(minimal_bids_dataset): + from datasets import load_dataset + + ds = load_dataset("bids", data_dir=minimal_bids_dataset) + + assert "train" in ds + assert len(ds["train"]) == 1 + + sample = ds["train"][0] + assert sample["subject"] == "01" + assert sample["suffix"] == "T1w" + assert sample["datatype"] == "anat" + assert sample["session"] is None + + +@pytest.mark.skipif( + not datasets.config.PYBIDS_AVAILABLE or not datasets.config.NIBABEL_AVAILABLE, + reason="pybids or nibabel not installed", +) +def test_bids_multi_subject(multi_subject_bids): + from datasets import load_dataset + + ds = load_dataset("bids", data_dir=multi_subject_bids) + + assert len(ds["train"]) == 4 # 2 subjects × 2 sessions + + subjects = {sample["subject"] for sample in ds["train"]} + assert subjects == {"01", "02"} + + sessions = {sample["session"] for sample in ds["train"]} + assert sessions == {"baseline", "followup"}