Improve mixingdataset
dorian-K committed Jan 18, 2025
1 parent ce520c5 commit e9d7879
Showing 1 changed file with 60 additions and 9 deletions.
users/dorian_koch/datasets/MixingDataset.py (69 changes: 60 additions, 9 deletions)
@@ -12,6 +12,7 @@
import typing
from i6_experiments.users.zeyer.utils.lru_cache import lru_cache
import math
from returnn.log import log


class Bitarray:
@@ -45,6 +46,10 @@ class MixingDataset(CachedDataset2):
Both datasets work in lockstep, meaning that they are at the same epoch at all times.
This means that, under some configurations, an epoch of one dataset may be seen many times.
If this is problematic, maybe wrap it in a MultiEpochDataset? (does it support num_seqs? idk)
# TODO I overcomplicated some things in the design of this:
1. I hyper-optimized for memory usage, which makes the code very messy
2. Because of 1, this doesn't scale well at all inside a MultiProcDataset
"""

def __init__(
@@ -85,6 +90,11 @@ def __init__(
def _reset_params(self):
# TODO fix this upper bound
self.total_num_seqs_upper_bound = self.left_dataset.num_seqs + self.right_dataset.num_seqs
if self.total_num_seqs_upper_bound > 0:
print(f"MixingDataset init: {self.left_dataset.num_seqs} + {self.right_dataset.num_seqs} = {self.total_num_seqs_upper_bound}", file=log.v4)
else:
print("MixingDataset init: both datasets are empty", file=log.v4)
self._estimated_num_seqs = self.total_num_seqs_upper_bound
assert self.total_num_seqs_upper_bound < 2**31, "sequences do not fit into int32"
# 0 means left, 1 means right
self.bitset_chooser = Bitarray(self.total_num_seqs_upper_bound)
@@ -100,6 +110,7 @@ def _reset_params(self):
self.datasets_loaded_until = [0, 0] # we need to _load_seqs the datasets
# we will get out of balance while choosing; we will correct this by biasing the next choice
self.bias = 0.0
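# the seq-idx -> child-indices mappings cached by lru_cache are stale after a reset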
self._get_childindices_at_seq_idx.cache_clear()

def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
"""
@@ -152,6 +163,10 @@ def _load_seqs(self, start, end):

super()._load_seqs(start=start, end=end)

@staticmethod
def _data_metric(v: numpy.ndarray):
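"""Measure of one data entry: its length along the first axis, or 1 for scalar data."""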
return v.shape[0] if v.ndim >= 1 else 1

def _run_seq_idx(self, seq_idx):
if seq_idx < self.chooser_index:
raise Exception("seq_idx < chooser_index")
@@ -206,9 +221,9 @@ def _run_seq_idx(self, seq_idx):
else:
chosen_dataset.load_seqs(start, end)
self.datasets_loaded_until[dataset_index] = end
- datalen = chosen_dataset.get_data(
+ datalen = MixingDataset._data_metric(chosen_dataset.get_data(
child_indices[dataset_index] % child_lens[dataset_index], self.data_key
- ).shape[0]
+ ))
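# length-weighted correction: steer the upcoming choices back toward mixing_ratio
# in terms of frames rather than sequence counts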
self.bias -= (
(1 - self.mixing_ratio) if chooseRight else self.mixing_ratio
) * datalen
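The update above weights the correction by datalen, so the mixture is balanced in frames rather than in sequence count. For intuition, a self-contained toy sketch of such length-weighted deficit mixing; the selection rule and the sign convention of the deficit are assumptions made for illustration and are not this class's actual chooser logic, which sits in a collapsed part of the diff:

# Toy sketch of length-weighted ratio mixing. The selection rule and the
# sign of the "deficit" are assumptions for illustration only.
def mix(lengths_left, lengths_right, ratio=0.3):
    i = j = 0
    deficit = 0.0  # > 0 means the right side is owed frames
    order = []
    while i < len(lengths_left) and j < len(lengths_right):
        if deficit > 0:
            deficit -= (1 - ratio) * lengths_right[j]  # mirrors the update in this diff
            order.append(("right", lengths_right[j]))
            j += 1
        else:
            deficit += ratio * lengths_left[i]  # assumed counterpart update
            order.append(("left", lengths_left[i]))
            i += 1
    return order

# The right side's share of frames converges to `ratio`:
order = mix([10, 12, 8, 11] * 50, [9, 10, 14] * 50, ratio=0.3)
print(sum(n for side, n in order if side == "right") / sum(n for _, n in order))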
@@ -283,22 +298,58 @@ def num_seqs(self):
# we can calculate this, but it's very expensive! TODO what do?
raise Exception("num_seqs not known yet")

- @property
- def _estimated_num_seqs(self):
-     return self.total_num_seqs_upper_bound

def get_target_list(self):
"""
:rtype: list[str]
"""
return self.left_dataset.get_target_list()

def get_data_keys(self) -> List[str]:
"""data keys"""
return self.left_dataset.get_data_keys()

def finish_epoch(self, *, free_resources: bool = False):
"""finish epoch"""
super().finish_epoch(free_resources=free_resources)
print("MixingDataset: finishing epoch! Datasets:")
print(f"Left dataset: {self.chooser_childindices[0]}/{self.left_dataset.num_seqs} ({self.chooser_childindices[0] / self.left_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[0]}")
print(f"Right dataset: {self.chooser_childindices[1]}/{self.right_dataset.num_seqs} ({self.chooser_childindices[1] / self.right_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[1]}")
print("MixingDataset: finishing epoch! Datasets:", file=log.v4)
if self.left_dataset.num_seqs > 0:
print(f"Left dataset: {self.chooser_childindices[0]}/{self.left_dataset.num_seqs} ({self.chooser_childindices[0] / self.left_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[0]}", file=log.v4)
else:
print("Left dataset: empty", file=log.v4)
if self.right_dataset.num_seqs > 0:
print(f"Right dataset: {self.chooser_childindices[1]}/{self.right_dataset.num_seqs} ({self.chooser_childindices[1] / self.right_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[1]}", file=log.v4)
else:
print("Right dataset: empty", file=log.v4)

self.left_dataset.finish_epoch(free_resources=free_resources)
self.right_dataset.finish_epoch(free_resources=free_resources)

def get_data_dim(self, key: str) -> int:
"""data dim"""
return self.left_dataset.get_data_dim(key)

def get_data_shape(self, data_key: str) -> List[int]:
"""data shape"""
return self.left_dataset.get_data_shape(data_key)

def get_data_dtype(self, key: str) -> str:
"""data dtype"""
return self.left_dataset.get_data_dtype(key)

def is_data_sparse(self, key: str) -> bool:
"""is data sparse"""
return self.left_dataset.is_data_sparse(key)

def get_epoch_continuous(self, sorted_seq_idx: int) -> float:
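"""continuous epoch progress, as a fraction of the child datasets' num_seqs consumed"""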
assert self.left_dataset.num_seqs > 0 and self.right_dataset.num_seqs > 0
indices = self._get_childindices_at_seq_idx(sorted_seq_idx)
if indices is None:
return 1.0 # we are done
frac_left = indices[0] / self.left_dataset.num_seqs
frac_right = indices[1] / self.right_dataset.num_seqs
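# with "wrap_around", progress is governed by the dataset that is furthest behind,
# since the other one wraps and repeats (min); with "early_exit"/"exception" the
# epoch ends as soon as the faster dataset is exhausted (max)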
if self.how_to_handle_end_of_data_from_one_dataset == "wrap_around":
return min(frac_left, frac_right)
# "early_exit" or "exception"
return max(frac_left, frac_right)

# TODO implement is_less_than_num_seqs
