Fix MixingDataset
dorian-K committed Jan 18, 2025
1 parent e9d7879 commit 623e647
Showing 1 changed file with 80 additions and 55 deletions.

users/dorian_koch/datasets/MixingDataset.py
@@ -5,7 +5,6 @@
from returnn.datasets.cached2 import CachedDataset2
import returnn.util.basic as util
from returnn.util.basic import NumbersDict, load_json, OptionalNotImplementedError
from returnn.log import log
from random import Random
import numpy
import sys
@@ -88,10 +87,25 @@ def __init__(
self._reset_params()

def _reset_params(self):
-        # TODO fix this this upper bound
-        self.total_num_seqs_upper_bound = self.left_dataset.num_seqs + self.right_dataset.num_seqs
+        assert not (0 < self.right_dataset.num_seqs < 10) and not (0 < self.left_dataset.num_seqs < 10), "mixing can go wrong when one dataset has very few seqs"
+        # left finishes first
+        lff = self.right_dataset.num_seqs * (1 + (1 - self.mixing_ratio) / self.mixing_ratio)
+        # right finishes first
+        rff = self.left_dataset.num_seqs * (1 + self.mixing_ratio / (1 - self.mixing_ratio))
+        if self.how_to_handle_end_of_data_from_one_dataset in ["exception", "early_exit"]:
+            assert 0.0 < self.mixing_ratio < 1.0, "not implemented"
+            self.total_num_seqs_upper_bound = math.ceil(min(lff, rff))
+        else:  # wrap_around, so both need to finish
+            assert 0.0 < self.mixing_ratio < 1.0, "not implemented"
+            self.total_num_seqs_upper_bound = math.ceil(max(lff, rff))
+
+        assert not math.isnan(self.total_num_seqs_upper_bound) and not math.isinf(self.total_num_seqs_upper_bound)
+        # for good measure
+        self.total_num_seqs_upper_bound += 10
+        self.total_num_seqs_upper_bound *= 2

if self.total_num_seqs_upper_bound > 0:
print(f"MixingDataset init: {self.left_dataset.num_seqs} + {self.right_dataset.num_seqs} = {self.total_num_seqs_upper_bound}", file=log.v4)
print(f"MixingDataset init: {self.left_dataset.num_seqs} + {self.right_dataset.num_seqs}, upperbound={self.total_num_seqs_upper_bound}, mixingratio={self.mixing_ratio}", file=log.v4)
else:
print("MixingDataset init: both datasets are empty", file=log.v4)
self._estimated_num_seqs = self.total_num_seqs_upper_bound
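To make the new upper bound concrete, here is a small standalone sketch with invented dataset sizes (names and numbers are illustrative only, not part of the commit). Note that `1 + (1 - r) / r` simplifies to `1 / r`, so `lff` is just `right_num_seqs / r`, and likewise `rff` is `left_num_seqs / (1 - r)`:

```python
import math

# Invented sizes, for illustration only.
left_num_seqs, right_num_seqs = 1000, 3000
mixing_ratio = 0.5

# Same formulas as in _reset_params above.
lff = right_num_seqs * (1 + (1 - mixing_ratio) / mixing_ratio)  # "left finishes first" bound
rff = left_num_seqs * (1 + mixing_ratio / (1 - mixing_ratio))   # "right finishes first" bound

bound_early = math.ceil(min(lff, rff))  # "exception"/"early_exit": first exhaustion ends the epoch
bound_wrap = math.ceil(max(lff, rff))   # "wrap_around": both datasets must finish

print(bound_early, bound_wrap)  # 2000 6000
print((bound_early + 10) * 2)   # 4020, after the +10 / *2 safety margin
```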
@@ -110,6 +124,7 @@ def _reset_params(self):
self.datasets_loaded_until = [0, 0] # we need to _load_seqs the datasets
# we will get out of balance while choosing, we will correct this by biasing the next choice
self.bias = 0.0
+        self.datalens = [0, 0]
self._get_childindices_at_seq_idx.cache_clear()

def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
@@ -141,36 +156,39 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
self._reset_params()
return True

-    def _load_seqs(self, start, end):
-        """
-        :param int start:
-        :param int end:
-        """
-        end_indices = self._get_childindices_at_seq_idx(end)
-        start_indices = self._get_childindices_at_seq_idx(start)
-
-        assert end_indices is not None
-        assert start_indices is not None
-
-        if self.datasets_loaded_until[0] <= end_indices[0]:
-            load_until = min(self.left_dataset.num_seqs, end_indices[0] + 1)
-            self.left_dataset.load_seqs(start_indices[0], load_until)
-            self.datasets_loaded_until[0] = load_until
-        if self.datasets_loaded_until[1] <= end_indices[1]:
-            load_until = min(self.right_dataset.num_seqs, end_indices[1] + 1)
-            self.right_dataset.load_seqs(start_indices[1], load_until)
-            self.datasets_loaded_until[1] = load_until
-
-        super()._load_seqs(start=start, end=end)

@staticmethod
def _data_metric(v: numpy.ndarray):
return v.shape[0] if v.ndim >= 1 else 1

+    def _make_sure_idx_is_loaded_in_child_ds(self, dataset_index, seq_idx):
+        chosen_dataset = self.left_dataset if dataset_index == 0 else self.right_dataset
+        child_len = chosen_dataset.num_seqs
+
+        # TODO fix this stupid hack
+        if hasattr(chosen_dataset, "expected_load_seq_start") and seq_idx < chosen_dataset.expected_load_seq_start:
+            chosen_dataset.init_seq_order(epoch=self.epoch)
+            self.datasets_loaded_until[dataset_index] = 0
+
+        if self.datasets_loaded_until[dataset_index] <= seq_idx:
+            # 512 is just some arbitrary number. TODO: maybe decrease this for more intensive workloads?
+            start = seq_idx
+            end = (seq_idx + min(child_len - 1, 512)) % child_len
+
+            if end < start:
+                # print(f"({dataset_index}) end < start: loading segs from {start} to {child_len}", file=log.v4)
+                chosen_dataset.load_seqs(start, child_len)
+                self.datasets_loaded_until[dataset_index] = child_len
+                assert self.datasets_loaded_until[dataset_index] >= seq_idx
+                # not sure if we should also load from 0 to end here; it may erase the data from start to child_len
+            else:
+                # print(f"({dataset_index}) just loading segs from {start} to {end}", file=log.v4)
+                chosen_dataset.load_seqs(start, end)
+                self.datasets_loaded_until[dataset_index] = end

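The subtle case in `_make_sure_idx_is_loaded_in_child_ds` is the wrap-around: the 512-sequence window is computed modulo `child_len`, so near the end of a child dataset `end` comes out smaller than `start`, and the code then loads only the tail. A minimal standalone model of just that window arithmetic (invented numbers; `load_seqs` is reduced to returning the planned range):

```python
def plan_load(seq_idx: int, child_len: int, window: int = 512):
    # Mirrors the window arithmetic above: which range would load_seqs get?
    start = seq_idx
    end = (seq_idx + min(child_len - 1, window)) % child_len
    if end < start:
        # The window wraps past the end of the child dataset: load only the
        # tail now; the wrapped part is loaded later, once seq_idx itself wraps.
        return start, child_len
    return start, end

print(plan_load(10, 1000))   # (10, 522): the plain case
print(plan_load(990, 1000))  # (990, 1000): (990 + 512) % 1000 == 502 < 990
```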
def _run_seq_idx(self, seq_idx):
if seq_idx < self.chooser_index:
raise Exception("seq_idx < chooser_index")
-        assert seq_idx < self.total_num_seqs_upper_bound
+        assert seq_idx < self.total_num_seqs_upper_bound, "This assert fails only when the two datasets are very unbalanced, in the sense that one dataset has many long sequences while the other mostly has shorter ones. Keep them at roughly equal average lengths! Otherwise you need to somehow increase this upper bound (which will not cause issues, just eat more RAM)"
if self.is_chooser_done:
raise Exception("chooser is done. change attribute 'how_to_handle_end_of_data_from_one_dataset' to 'exception' if you want to know why (probably because early_exit)")
# get old childindices
@@ -188,8 +206,10 @@ def _run_seq_idx(self, seq_idx):
dataset_index = 1 if chooseRight else 0
chosen_dataset = self.right_dataset if chooseRight else self.left_dataset

-            if child_indices[dataset_index] >= chosen_dataset.num_seqs:
+            if child_indices[dataset_index] % child_lens[dataset_index] == 0 and child_indices[dataset_index] > 0:
self.datasets_exhausted[dataset_index] = True
print(f"MixingDataset: ({dataset_index}) exhausted", file=log.v4)
self._print_progress()
if self.how_to_handle_end_of_data_from_one_dataset == "exception":
self.is_chooser_done = True
raise Exception(
@@ -203,30 +223,20 @@ def _run_seq_idx(self, seq_idx):
elif self.how_to_handle_end_of_data_from_one_dataset == "wrap_around":
                    # I'm not sure of the logic inside the datasets and whether they keep data that has been loaded before indefinitely,
                    # so just start loading them at the beginning again
-                    self.datasets_loaded_until[dataset_index] = 0
if all(self.datasets_exhausted):
self.is_chooser_done = True
break
# the modulo operator below will wrap around

-            if self.datasets_loaded_until[dataset_index] <= child_indices[dataset_index] % child_lens[dataset_index]:
-                # 512 is just some arbitrary number
-                start = child_indices[dataset_index] % child_lens[dataset_index]
-                end = (child_indices[dataset_index] + 512) % child_lens[dataset_index]
-                if end < start:
-                    chosen_dataset.load_seqs(start, child_lens[dataset_index])
-                    self.datasets_loaded_until[dataset_index] = child_lens[dataset_index]
-                    assert self.datasets_loaded_until[dataset_index] >= child_indices[dataset_index]
-                    # not sure if we should also load from 0 to end here, it may erase the data from start to child_lens? idk
-                else:
-                    chosen_dataset.load_seqs(start, end)
-                    self.datasets_loaded_until[dataset_index] = end
+            self._make_sure_idx_is_loaded_in_child_ds(dataset_index, child_indices[dataset_index] % child_lens[dataset_index])
+            datalen = MixingDataset._data_metric(chosen_dataset.get_data(
+                child_indices[dataset_index] % child_lens[dataset_index], self.data_key
+            ))
#print(f"({dataset_index}) datalen={datalen} shape={data.shape}")
self.bias -= (
-                (1 - self.mixing_ratio) if chooseRight else self.mixing_ratio
-            ) * datalen
+                (1 - self.mixing_ratio) if chooseRight else -self.mixing_ratio
+            ) * max(datalen, 1)
+            self.datalens[dataset_index] += datalen
child_indices[dataset_index] += 1
self.chooser_index += 1

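The bias update is what holds the mix at the requested ratio by amount of data rather than by sequence count: taking from one side pushes `bias` toward the other side in proportion to the data taken, and the sign flip on the left branch (the actual fix here) makes the two branches push in opposite directions. Below is a toy rendition with invented lengths; the rule that right is chosen when `bias >= 0` is an assumption for illustration, since the chooser itself is outside this hunk:

```python
mixing_ratio = 0.5
bias = 0.0
consumed = [0, 0]                     # total datalen taken from (left, right)
left_lens = [10, 10, 10, 10, 10, 10]  # invented: left has short sequences
right_lens = [30, 30, 30]             # invented: right has long sequences
li = ri = 0

for _ in range(8):
    choose_right = bias >= 0.0  # assumed decision rule
    datalen = right_lens[ri % 3] if choose_right else left_lens[li % 6]
    # Same update as above: choosing right lowers bias, choosing left raises it.
    bias -= ((1 - mixing_ratio) if choose_right else -mixing_ratio) * max(datalen, 1)
    consumed[1 if choose_right else 0] += datalen
    if choose_right:
        ri += 1
    else:
        li += 1

print(consumed)  # [60, 60]: equal data from both sides despite 3x longer right seqs
```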
@@ -248,7 +258,7 @@ def _get_childindices_at_seq_idx(self, seq_idx):
ran_ids = self._run_seq_idx(seq_idx)
if seq_idx >= self.chooser_index:
return None # we could not progress to the desired seq_idx, maybe early exit or exhaustion?
-            return ran_ids
+            return (ran_ids[0] % self.left_dataset.num_seqs, ran_ids[1] % self.right_dataset.num_seqs)
# maybe in cache? this should happen often when we go over the dataset sequentially
restore_from_idx = seq_idx - (seq_idx % 1024)
restore_indices = self.index_cache[restore_from_idx // 1024]
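The cache granularity explains the rounding above: the chooser checkpoints its child indices every 1024 steps as it advances, so a later lookup replays at most 1023 choices from the nearest checkpoint instead of starting from zero. A toy version of the scheme (the names and the dummy alternating chooser step are made up; the real step consults the bias):

```python
CACHE_STEP = 1024
index_cache = {0: (0, 0)}  # seq_idx // CACHE_STEP -> (left_idx, right_idx)

def step(left, right, i):
    # Stand-in for one chooser decision; the real one is length-weighted.
    return (left + 1, right) if i % 2 == 0 else (left, right + 1)

def run_to(seq_idx, chooser_pos=0):
    # Sequential advance that leaves checkpoints behind, like _run_seq_idx.
    left, right = index_cache[chooser_pos // CACHE_STEP]
    for i in range(chooser_pos, seq_idx):
        left, right = step(left, right, i)
        if (i + 1) % CACHE_STEP == 0:
            index_cache[(i + 1) // CACHE_STEP] = (left, right)
    return left, right

run_to(3000)                       # walk forward once, storing checkpoints
base = 2050 - (2050 % CACHE_STEP)  # same rounding as above -> 2048
left, right = index_cache[base // CACHE_STEP]
for i in range(base, 2050):        # replay at most CACHE_STEP - 1 steps
    left, right = step(left, right, i)
print(left, right)                 # 1025 1025
```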
@@ -273,21 +283,33 @@ def _get_dataset_and_childindex_at_seq_idx(self, seq_idx):
indices = self._get_childindices_at_seq_idx(seq_idx)
assert indices is not None
choose_right = self.bitset_chooser.get(seq_idx)
-        dataset = self.right_dataset if choose_right else self.left_dataset
-        return dataset, indices[1 if choose_right else 0]
+        dataset_index = 1 if choose_right else 0
+        return dataset_index, indices[1 if choose_right else 0]

def _collect_single_seq(self, seq_idx):
"""
:param int seq_idx:
:rtype: DatasetSeq
"""
-        dataset, dataset_seq_idx = self._get_dataset_and_childindex_at_seq_idx(seq_idx)
+        # if seq_idx % 100 == 0:
+        #     print(f"Collecting idx {seq_idx}, chooser_index={self.chooser_index}, chooser_indices={self.chooser_childindices}", file=log.v4)
+        dataset_idx, dataset_seq_idx = self._get_dataset_and_childindex_at_seq_idx(seq_idx)
+        dataset = self.left_dataset if dataset_idx == 0 else self.right_dataset
+        self._make_sure_idx_is_loaded_in_child_ds(dataset_idx, dataset_seq_idx)
seq_tag = dataset.get_tag(dataset_seq_idx)
features = {
k: dataset.get_data(dataset_seq_idx, k) for k in dataset.get_data_keys()
}
return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features)

+    def is_less_than_num_seqs(self, seq_idx: int):
+        if seq_idx < self.chooser_index:
+            return True
+        if self.is_chooser_done:
+            return False
+        ids = self._get_childindices_at_seq_idx(seq_idx)
+        return ids is not None

@property
def num_seqs(self):
"""
@@ -308,18 +330,21 @@ def get_data_keys(self) -> List[str]:
"""data keys"""
return self.left_dataset.get_data_keys()

-    def finish_epoch(self, *, free_resources: bool = False):
-        """finish epoch"""
-        super().finish_epoch(free_resources=free_resources)
-        print("MixingDataset: finishing epoch! Datasets:", file=log.v4)
+    def _print_progress(self):
if self.left_dataset.num_seqs > 0:
print(f"Left dataset: {self.chooser_childindices[0]}/{self.left_dataset.num_seqs} ({self.chooser_childindices[0] / self.left_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[0]}", file=log.v4)
print(f"MixingDataset: Left dataset: {self.chooser_childindices[0]}/{self.left_dataset.num_seqs} ({self.chooser_childindices[0] / self.left_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[0]}, avg_datalen={self.datalens[0]/max(1, self.chooser_childindices[0])}", file=log.v4)
else:
print("Left dataset: empty", file=log.v4)
print("MixingDataset: Left dataset: empty", file=log.v4)
if self.right_dataset.num_seqs > 0:
print(f"Right dataset: {self.chooser_childindices[1]}/{self.right_dataset.num_seqs} ({self.chooser_childindices[1] / self.right_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[1]}", file=log.v4)
print(f"MixingDataset: Right dataset: {self.chooser_childindices[1]}/{self.right_dataset.num_seqs} ({self.chooser_childindices[1] / self.right_dataset.num_seqs * 100}%) exhausted={self.datasets_exhausted[1]}, avg_datalen={self.datalens[1]/max(1, self.chooser_childindices[1])}", file=log.v4)
else:
print("Right dataset: empty", file=log.v4)
print("MixingDataset: Right dataset: empty", file=log.v4)

+    def finish_epoch(self, *, free_resources: bool = False):
+        """finish epoch"""
+        super().finish_epoch(free_resources=free_resources)
+        print("MixingDataset: finishing epoch! Datasets:", file=log.v4)
+        self._print_progress()

self.left_dataset.finish_epoch(free_resources=free_resources)
self.right_dataset.finish_epoch(free_resources=free_resources)
