Skip to content

Commit

Permalink
Add custom, file-based HRTF dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
jpauwels committed Mar 3, 2024
1 parent 2877681 commit a7f1f68
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 3 deletions.
1 change: 1 addition & 0 deletions hartufo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Scut,
Sonicom,
MitKemar,
CustomSphericalDataset,
)

from .planar import (
Expand Down
18 changes: 16 additions & 2 deletions hartufo/datareader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .query import DataQuery, CipicDataQuery, AriDataQuery, ListenDataQuery, BiLiDataQuery, CrossModDataQuery, ItaDataQuery, HutubsDataQuery, RiecDataQuery, ChedarDataQuery, WidespreadDataQuery, Sadie2DataQuery, Princeton3D3ADataQuery, ScutDataQuery, SonicomDataQuery, MitKemarDataQuery
from .query import DataQuery, CipicDataQuery, AriDataQuery, ListenDataQuery, BiLiDataQuery, CrossModDataQuery, ItaDataQuery, HutubsDataQuery, RiecDataQuery, ChedarDataQuery, WidespreadDataQuery, Sadie2DataQuery, Princeton3D3ADataQuery, ScutDataQuery, SonicomDataQuery, MitKemarDataQuery, CustomDataQuery
from .util import wrap_closed_open_interval, wrap_closed_interval, spherical2cartesian, spherical2interaural, cartesian2spherical, cartesian2interaural, interaural2spherical, interaural2cartesian, quantise
from abc import abstractmethod
from pathlib import Path
from typing import Optional, Union
from typing import Iterable, Optional, Union
import numpy as np
import netCDF4 as ncdf
from PIL import Image
Expand Down Expand Up @@ -669,3 +669,17 @@ def __init__(self,

def _sofa_path(self, subject_id):
return str(self.query.sofa_directory_path / f'mit_kemar_{subject_id}_pinna.sofa')


class CustomSphericalDataReader(SofaSphericalDataReader):

def __init__(self,
collection_id: str,
file_paths: Iterable[Union[str, Path]],
):
query = CustomDataQuery(collection_id, file_paths)
super().__init__(query)


def _sofa_path(self, subject_id):
return str(self.query.file_paths[subject_id])
26 changes: 25 additions & 1 deletion hartufo/full.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .datareader import DataReader, CipicDataReader, AriDataReader, ListenDataReader, BiLiDataReader, CrossModDataReader, ItaDataReader, HutubsDataReader, RiecDataReader, ChedarDataReader, WidespreadDataReader, Sadie2DataReader, Princeton3D3ADataReader, ScutDataReader, SonicomDataReader, MitKemarDataReader
from .datareader import DataReader, CipicDataReader, AriDataReader, ListenDataReader, BiLiDataReader, CrossModDataReader, ItaDataReader, HutubsDataReader, RiecDataReader, ChedarDataReader, WidespreadDataReader, Sadie2DataReader, Princeton3D3ADataReader, ScutDataReader, SonicomDataReader, MitKemarDataReader, CustomSphericalDataReader
from .specifications import Spec, HrirSpec, sanitise_specs, sanitise_multiple_specs
from .transforms.hrir import BatchTransform, ScaleTransform, MinPhaseTransform, ResampleTransform, TruncateTransform, DomainTransform, SelectValueRangeTransform, PlaneTransform, InterauralPlaneTransform, SphericalPlaneTransform
from collections import defaultdict
Expand Down Expand Up @@ -723,3 +723,27 @@ def __init__(
verify=verify,
)
super().__init__(datareader, features_spec, target_spec, group_spec, subject_ids, subject_requirements, exclude_ids, dtype)


class CustomSphericalDataset(Dataset):
"""Custom HRTF Dataset
"""
PlaneTransform = SphericalPlaneTransform


def __init__(
self,
collection_id: str,
file_paths: Iterable[Union[str, Path]],
features_spec: Union[Spec, Iterable[Spec]],
target_spec: Optional[Union[Spec, Iterable[Spec]]] = None,
group_spec: Optional[Union[Spec, Iterable[Spec]]] = None,
subject_ids: Optional[Iterable[int]] = None,
subject_requirements: Optional[Dict] = None,
exclude_ids: Optional[Iterable[int]] = None,
dtype: type = np.float32,
download: bool = False,
verify: bool = False,
) -> None:
datareader = CustomSphericalDataReader(collection_id, file_paths)
super().__init__(datareader, features_spec, target_spec, group_spec, subject_ids, subject_requirements, exclude_ids, dtype)
11 changes: 11 additions & 0 deletions hartufo/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,3 +1056,14 @@ def __init__(self, sofa_directory_path='', download=False, verify=False):

def _all_hrir_ids(self, side):
return sorted([x.stem.split('_')[2] for x in self.sofa_directory_path.glob('mit_kemar_*_pinna.sofa')])


class CustomDataQuery(HrirDataQuery):

def __init__(self, collection_id, file_paths):
self.file_paths = file_paths
super().__init__(collection_id=collection_id, sofa_directory_path='.', download=False, verify=False)


def _all_hrir_ids(self, side):
return list(range(len(self.file_paths)))

0 comments on commit a7f1f68

Please sign in to comment.