diff --git a/hartufo/__init__.py b/hartufo/__init__.py index 9d40c32..a2b1b27 100644 --- a/hartufo/__init__.py +++ b/hartufo/__init__.py @@ -19,6 +19,7 @@ Scut, Sonicom, MitKemar, + CustomSphericalDataset, ) from .planar import ( diff --git a/hartufo/datareader.py b/hartufo/datareader.py index da50088..d8bceae 100644 --- a/hartufo/datareader.py +++ b/hartufo/datareader.py @@ -1,8 +1,8 @@ -from .query import DataQuery, CipicDataQuery, AriDataQuery, ListenDataQuery, BiLiDataQuery, CrossModDataQuery, ItaDataQuery, HutubsDataQuery, RiecDataQuery, ChedarDataQuery, WidespreadDataQuery, Sadie2DataQuery, Princeton3D3ADataQuery, ScutDataQuery, SonicomDataQuery, MitKemarDataQuery +from .query import DataQuery, CipicDataQuery, AriDataQuery, ListenDataQuery, BiLiDataQuery, CrossModDataQuery, ItaDataQuery, HutubsDataQuery, RiecDataQuery, ChedarDataQuery, WidespreadDataQuery, Sadie2DataQuery, Princeton3D3ADataQuery, ScutDataQuery, SonicomDataQuery, MitKemarDataQuery, CustomDataQuery from .util import wrap_closed_open_interval, wrap_closed_interval, spherical2cartesian, spherical2interaural, cartesian2spherical, cartesian2interaural, interaural2spherical, interaural2cartesian, quantise from abc import abstractmethod from pathlib import Path -from typing import Optional, Union +from typing import Iterable, Optional, Union import numpy as np import netCDF4 as ncdf from PIL import Image @@ -669,3 +669,17 @@ def __init__(self, def _sofa_path(self, subject_id): return str(self.query.sofa_directory_path / f'mit_kemar_{subject_id}_pinna.sofa') + + +class CustomSphericalDataReader(SofaSphericalDataReader): + + def __init__(self, + collection_id: str, + file_paths: Iterable[Union[str, Path]], + ): + query = CustomDataQuery(collection_id, file_paths) + super().__init__(query) + + + def _sofa_path(self, subject_id): + return str(self.query.file_paths[subject_id]) diff --git a/hartufo/full.py b/hartufo/full.py index 1b496b3..2a04b01 100644 --- a/hartufo/full.py +++ b/hartufo/full.py @@ -1,4 +1,4 @@ -from .datareader import DataReader, CipicDataReader, AriDataReader, ListenDataReader, BiLiDataReader, CrossModDataReader, ItaDataReader, HutubsDataReader, RiecDataReader, ChedarDataReader, WidespreadDataReader, Sadie2DataReader, Princeton3D3ADataReader, ScutDataReader, SonicomDataReader, MitKemarDataReader +from .datareader import DataReader, CipicDataReader, AriDataReader, ListenDataReader, BiLiDataReader, CrossModDataReader, ItaDataReader, HutubsDataReader, RiecDataReader, ChedarDataReader, WidespreadDataReader, Sadie2DataReader, Princeton3D3ADataReader, ScutDataReader, SonicomDataReader, MitKemarDataReader, CustomSphericalDataReader from .specifications import Spec, HrirSpec, sanitise_specs, sanitise_multiple_specs from .transforms.hrir import BatchTransform, ScaleTransform, MinPhaseTransform, ResampleTransform, TruncateTransform, DomainTransform, SelectValueRangeTransform, PlaneTransform, InterauralPlaneTransform, SphericalPlaneTransform from collections import defaultdict @@ -723,3 +723,27 @@ def __init__( verify=verify, ) super().__init__(datareader, features_spec, target_spec, group_spec, subject_ids, subject_requirements, exclude_ids, dtype) + + +class CustomSphericalDataset(Dataset): + """Custom HRTF Dataset + """ + PlaneTransform = SphericalPlaneTransform + + + def __init__( + self, + collection_id: str, + file_paths: Iterable[Union[str, Path]], + features_spec: Union[Spec, Iterable[Spec]], + target_spec: Optional[Union[Spec, Iterable[Spec]]] = None, + group_spec: Optional[Union[Spec, Iterable[Spec]]] = None, + subject_ids: Optional[Iterable[int]] = None, + subject_requirements: Optional[Dict] = None, + exclude_ids: Optional[Iterable[int]] = None, + dtype: type = np.float32, + download: bool = False, + verify: bool = False, + ) -> None: + datareader = CustomSphericalDataReader(collection_id, file_paths) + super().__init__(datareader, features_spec, target_spec, group_spec, subject_ids, subject_requirements, exclude_ids, dtype) diff --git a/hartufo/query.py b/hartufo/query.py index 8810955..e9e4a0b 100644 --- a/hartufo/query.py +++ b/hartufo/query.py @@ -1056,3 +1056,14 @@ def __init__(self, sofa_directory_path='', download=False, verify=False): def _all_hrir_ids(self, side): return sorted([x.stem.split('_')[2] for x in self.sofa_directory_path.glob('mit_kemar_*_pinna.sofa')]) + + +class CustomDataQuery(HrirDataQuery): + + def __init__(self, collection_id, file_paths): + self.file_paths = file_paths + super().__init__(collection_id=collection_id, sofa_directory_path='.', download=False, verify=False) + + + def _all_hrir_ids(self, side): + return list(range(len(self.file_paths)))