Skip to content

Commit

Permalink
Fix mixing cache
Browse files Browse the repository at this point in the history
  • Loading branch information
dorian-K committed Jan 19, 2025
1 parent ef12762 commit f79fc48
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions users/dorian_koch/datasets/MixingDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@ def _reset_params(self):
rff = self.left_dataset.num_seqs * (1 + (self.mixing_ratio)/(1-self.mixing_ratio))
if self.how_to_handle_end_of_data_from_one_dataset in ["exception", "early_exit"]:
assert 0.0 < self.mixing_ratio < 1.0, "not implemented"
self.total_num_seqs_upper_bound = math.ceil(min(lff, rff))
else: # wrap_around, so both need to finish
self.total_num_seqs_upper_bound = math.ceil(min(lff, rff)) # only one needs to finish
elif self.how_to_handle_end_of_data_from_one_dataset == "wrap_around":
assert 0.0 < self.mixing_ratio < 1.0, "not implemented"
self.total_num_seqs_upper_bound = math.ceil(max(lff, rff))
self.total_num_seqs_upper_bound = math.ceil(max(lff, rff)) # both need to finish
else:
assert False

assert not math.isnan(self.total_num_seqs_upper_bound) and not math.isinf(self.total_num_seqs_upper_bound)
# for good measure
Expand Down Expand Up @@ -197,8 +199,7 @@ def _run_seq_idx(self, seq_idx):
assert seq_idx < self.total_num_seqs_upper_bound, "This assert fails only when the two datasets are very unbalanced, in the sense that one dataset has many long sequences while the other mostly has shorter once. Keep them on equal lengths on average please! Otherwise you need to somehow increase this upper bound (which will not cause issues, just eat more ram)"
if self.is_chooser_done:
raise Exception("chooser is done. change attribute 'how_to_handle_end_of_data_from_one_dataset' to 'exception' if you want to know why (probably because early_exit)")
# get old childindices
child_indices = self.chooser_childindices

child_lens = [
self.left_dataset.num_seqs,
self.right_dataset.num_seqs
Expand All @@ -208,11 +209,12 @@ def _run_seq_idx(self, seq_idx):
chooseRight = self.bias >= 0 and self.mixing_ratio > 0
self.bitset_chooser.set(self.chooser_index, chooseRight)
if self.chooser_index % 1024 == 0:
self.index_cache[self.chooser_index // 1024] = child_indices
# this works because index_cache is a numpy array; otherwise we would need to explicitly copy
self.index_cache[self.chooser_index // 1024] = self.chooser_childindices
dataset_index = 1 if chooseRight else 0
chosen_dataset = self.right_dataset if chooseRight else self.left_dataset

if child_indices[dataset_index] % child_lens[dataset_index] == 0 and child_indices[dataset_index] > 0:
if self.chooser_childindices[dataset_index] % child_lens[dataset_index] == 0 and self.chooser_childindices[dataset_index] > 0:
self.datasets_exhausted[dataset_index] = True
print(f"MixingDataset: ({dataset_index}) exhausted", file=log.v4)
self._print_progress()
Expand All @@ -231,30 +233,33 @@ def _run_seq_idx(self, seq_idx):
# so just start loading them at the beginning again
if all(self.datasets_exhausted):
self.is_chooser_done = True
c0 = child_indices[0] / max(1, child_lens[0])
c1 = child_indices[1] / max(1, child_lens[1])
c0 = self.chooser_childindices[0] / max(1, child_lens[0])
c1 = self.chooser_childindices[1] / max(1, child_lens[1])
print(f"MixingDataset: optimal mixing ratio = {(self.datalens[1] / c1) / max(1, self.datalens[0]/c0 + self.datalens[1]/c1)}", file=log.v4)
break
# the modulo operator below will wrap around
else:
assert False, f"{self.how_to_handle_end_of_data_from_one_dataset} not implemented"

self._make_sure_idx_is_loaded_in_child_ds(dataset_index, child_indices[dataset_index] % child_lens[dataset_index])
self._make_sure_idx_is_loaded_in_child_ds(dataset_index, self.chooser_childindices[dataset_index] % child_lens[dataset_index])
datalen = MixingDataset._data_metric(chosen_dataset.get_data(
child_indices[dataset_index] % child_lens[dataset_index], self.data_key
self.chooser_childindices[dataset_index] % child_lens[dataset_index], self.data_key
))
#print(f"({dataset_index}) datalen={datalen} shape={data.shape}")
self.bias -= (
(1 - self.mixing_ratio) if chooseRight else -self.mixing_ratio
) * max(datalen, 1)
self.datalens[dataset_index] += datalen
child_indices[dataset_index] += 1
self.chooser_childindices[dataset_index] += 1
self.chooser_index += 1

assert not math.isnan(self.bias) and not math.isinf(
self.bias
) # this should never ever happen

self.chooser_childindices = child_indices
return child_indices
if self.is_chooser_done:
return None
return (self.chooser_childindices[0], self.chooser_childindices[1])

@lru_cache(maxsize=500)
def _get_childindices_at_seq_idx(self, seq_idx):
Expand All @@ -277,6 +282,8 @@ def _get_childindices_at_seq_idx(self, seq_idx):
restore_from_idx = try_seq
restore_indices = result
break
# convert to list to avoid changing the index cache elements
restore_indices = list(restore_indices)

# replay the steps
while restore_from_idx < seq_idx:
Expand All @@ -300,8 +307,6 @@ def _collect_single_seq(self, seq_idx):
:param int seq_idx:
:rtype: DatasetSeq
"""
#if seq_idx % 100 == 0:
# print(f"Collecting idx {seq_idx}, chooser_index={self.chooser_index}, chooser_indices={self.chooser_childindices}", file=log.v4)
dataset_idx, dataset_seq_idx = self._get_dataset_and_childindex_at_seq_idx(seq_idx)
dataset = self.left_dataset if dataset_idx == 0 else self.right_dataset
self._make_sure_idx_is_loaded_in_child_ds(dataset_idx, dataset_seq_idx)
Expand Down

0 comments on commit f79fc48

Please sign in to comment.