diff --git a/users/dorian_koch/datasets/MixingDataset.py b/users/dorian_koch/datasets/MixingDataset.py index 6460c3cc2..47a44e9bf 100644 --- a/users/dorian_koch/datasets/MixingDataset.py +++ b/users/dorian_koch/datasets/MixingDataset.py @@ -12,6 +12,7 @@ import typing from i6_experiments.users.zeyer.utils.lru_cache import lru_cache import math +from returnn.log import log class Bitarray: @@ -45,6 +46,10 @@ class MixingDataset(CachedDataset2): Both datasets work in steplock, meaning that they are at the same epoch at all times. This means that, under some configurations, an epoch of one dataset may be seen many times. If this is problematic maybe wrap it in a MultiEpochDataset? (does it support num_seqs? idk) + + # TODO i overcomplicated some things in the design of this, + 1. I hyper optimized for memory usage, which makes the code very messy + 2. Because of 1, this doesnt scale well at all inside a MultiProcDataset """ def __init__( @@ -85,6 +90,11 @@ def __init__( def _reset_params(self): # TODO fix this this upper bound self.total_num_seqs_upper_bound = self.left_dataset.num_seqs + self.right_dataset.num_seqs + if self.total_num_seqs_upper_bound > 0: + print(f"MixingDataset init: {self.left_dataset.num_seqs} + {self.right_dataset.num_seqs} = {self.total_num_seqs_upper_bound}", file=log.v4) + else: + print("MixingDataset init: both datasets are empty", file=log.v4) + self._estimated_num_seqs = self.total_num_seqs_upper_bound assert self.total_num_seqs_upper_bound < 2**31, "sequences do not fit into int32" # 0 means left, 1 means right self.bitset_chooser = Bitarray(self.total_num_seqs_upper_bound) @@ -100,6 +110,7 @@ def _reset_params(self): self.datasets_loaded_until = [0, 0] # we need to _load_seqs the datasets # we will get out of balance while choosing, we will correct this by biasing the next choice self.bias = 0.0 + self._get_childindices_at_seq_idx.cache_clear() def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): """ @@ -152,6 +163,10 @@ def _load_seqs(self, start, end): super()._load_seqs(start=start, end=end) + @staticmethod + def _data_metric(v: numpy.ndarray): + return v.shape[0] if v.ndim >= 1 else 1 + def _run_seq_idx(self, seq_idx): if seq_idx < self.chooser_index: raise Exception("seq_idx < chooser_index") @@ -206,9 +221,9 @@ def _run_seq_idx(self, seq_idx): else: chosen_dataset.load_seqs(start, end) self.datasets_loaded_until[dataset_index] = end - datalen = chosen_dataset.get_data( + datalen = MixingDataset._data_metric(chosen_dataset.get_data( child_indices[dataset_index] % child_lens[dataset_index], self.data_key - ).shape[0] + )) self.bias -= ( (1 - self.mixing_ratio) if chooseRight else self.mixing_ratio ) * datalen @@ -283,22 +298,58 @@ def num_seqs(self): # we can calculate this, but its very expensive! TODO what do? raise Exception("num_seqs not known yet") - @property - def _estimated_num_seqs(self): - return self.total_num_seqs_upper_bound - def get_target_list(self): """ :rtype: list[str] """ return self.left_dataset.get_target_list() + + def get_data_keys(self) -> List[str]: + """data keys""" + return self.left_dataset.get_data_keys() def finish_epoch(self, *, free_resources: bool = False): """finish epoch""" super().finish_epoch(free_resources=free_resources) - print("MixingDataset: finishing epoch! Datasets:") - print(f"Left dataset: {self.chooser_childindices[0]}/{self.left_dataset.num_seqs} ({self.chooser_childindices[0] / self.left_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[0]}") - print(f"Right dataset: {self.chooser_childindices[1]}/{self.right_dataset.num_seqs} ({self.chooser_childindices[1] / self.right_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[1]}") + print("MixingDataset: finishing epoch! Datasets:", file=log.v4) + if self.left_dataset.num_seqs > 0: + print(f"Left dataset: {self.chooser_childindices[0]}/{self.left_dataset.num_seqs} ({self.chooser_childindices[0] / self.left_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[0]}", file=log.v4) + else: + print("Left dataset: empty", file=log.v4) + if self.right_dataset.num_seqs > 0: + print(f"Right dataset: {self.chooser_childindices[1]}/{self.right_dataset.num_seqs} ({self.chooser_childindices[1] / self.right_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[1]}", file=log.v4) + else: + print("Right dataset: empty", file=log.v4) self.left_dataset.finish_epoch(free_resources=free_resources) self.right_dataset.finish_epoch(free_resources=free_resources) + + def get_data_dim(self, key: str) -> int: + """data dim""" + return self.left_dataset.get_data_dim(key) + + def get_data_shape(self, data_key: str) -> List[int]: + """data shape""" + return self.left_dataset.get_data_shape(data_key) + + def get_data_dtype(self, key: str) -> str: + """data dtype""" + return self.left_dataset.get_data_dtype(key) + + def is_data_sparse(self, key: str) -> bool: + """is data sparse""" + return self.left_dataset.is_data_sparse(key) + + def get_epoch_continuous(self, sorted_seq_idx: int) -> float: + assert self.left_dataset.num_seqs > 0 and self.right_dataset.num_seqs > 0 + indices = self._get_childindices_at_seq_idx(sorted_seq_idx) + if indices is None: + return 1.0 # we are done + frac_left = indices[0] / self.left_dataset.num_seqs + frac_right = indices[1] / self.right_dataset.num_seqs + if self.how_to_handle_end_of_data_from_one_dataset == "wrap_around": + return min(frac_left, frac_right) + # "early_exit" or "exception" + return max(frac_left, frac_right) + + # TODO implement is_less_than_num_seqs