diff --git a/hartufo/full.py b/hartufo/full.py index 48e3ce2..1b496b3 100644 --- a/hartufo/full.py +++ b/hartufo/full.py @@ -1,6 +1,6 @@ from .datareader import DataReader, CipicDataReader, AriDataReader, ListenDataReader, BiLiDataReader, CrossModDataReader, ItaDataReader, HutubsDataReader, RiecDataReader, ChedarDataReader, WidespreadDataReader, Sadie2DataReader, Princeton3D3ADataReader, ScutDataReader, SonicomDataReader, MitKemarDataReader from .specifications import Spec, HrirSpec, sanitise_specs, sanitise_multiple_specs -from .transforms.hrir import BatchTransform, ScaleTransform, MinPhaseTransform, ResampleTransform, TruncateTransform, DomainTransform, SelectValueRangeTransform, InterauralPlaneTransform, SphericalPlaneTransform +from .transforms.hrir import BatchTransform, ScaleTransform, MinPhaseTransform, ResampleTransform, TruncateTransform, DomainTransform, SelectValueRangeTransform, PlaneTransform, InterauralPlaneTransform, SphericalPlaneTransform from collections import defaultdict from copy import deepcopy from itertools import chain @@ -65,6 +65,7 @@ def __init__( self.orthogonal_angles = np.array([]) self.radii = np.array([]) self._selection_mask = np.array([]) + self._directions = np.array([]) self._data = {} return @@ -81,6 +82,10 @@ def __init__( self.fundamental_angles, self.orthogonal_angles, self.radii, self._selection_mask, *_ = datareader._map_sofa_position_order_to_matrix( self.subject_ids[0], requested_fundamental_angles, requested_orthogonal_angles, hrir_spec.get('distance'), ) + self._directions = np.ma.masked_where( + np.tile(self._selection_mask[..., np.newaxis], (1, 1, 1, 3)), + np.stack(np.meshgrid(self.fundamental_angles, self.orthogonal_angles, self.radii, indexing='ij'), axis=-1), + ) if hrir_spec.get('side', '').startswith('both-'): datareader._verify_angle_symmetry(self.fundamental_angles, self.orthogonal_angles) # Create plane transform from file angles and mask @@ -257,6 +262,14 @@ def hrtf_frequencies(self): return self._hrtf_frequencies[region_selector[0]._selection] + @property + def directions(self): + plane_transforms = [t for t in self.full_chain if isinstance(t, PlaneTransform)] + if plane_transforms: + return plane_transforms[0](self._directions) + return self._directions + + def split_by_angles(dataset: Dataset): angle_datasets = [] for row_idx, fundamental_angle in enumerate(dataset.fundamental_angles):