From 4d51834398bad308acec4835a908adba94dac07d Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 3 Dec 2024 20:57:16 +0100
Subject: [PATCH] new WriteSeqListFromShuffledJob, WriteSeqListInOrigOrderFromShuffledJob

---
 .../zeyer/datasets/utils/extract_seq_list.py | 86 ++++++++++++++++++-
 1 file changed, 85 insertions(+), 1 deletion(-)

diff --git a/users/zeyer/datasets/utils/extract_seq_list.py b/users/zeyer/datasets/utils/extract_seq_list.py
index d1bed9778..fb58d874f 100644
--- a/users/zeyer/datasets/utils/extract_seq_list.py
+++ b/users/zeyer/datasets/utils/extract_seq_list.py
@@ -2,7 +2,7 @@
 Extract/generate seq lists.
 """
 
-from typing import Optional, Union, Any, Dict
+from typing import Optional, Union, Any, Sequence, Dict
 
 from sisyphus import tk, Task, Job
 from i6_core.util import uopen
@@ -239,6 +239,90 @@ class WriteSeqListFromShuffledJob(Job):
     The code is based on :class:`i6_core.corpus.segments.ShuffleAndSplitSegmentsJob`.
     """
 
+    def __init__(
+        self,
+        *,
+        seq_tag_template: str,
+        num_seqs: Union[int, tk.Variable],
+        split: Dict[str, float],
+        selected_splits: Sequence[str],
+        shuffle=True,
+        shuffle_seed=0x3C5EA3E47D4E0077,
+    ):
+        """
+        :param seq_tag_template: e.g. `"librispeech-lm-part{split_key}/recording_{split_seq_idx}/line_{split_seq_idx}"`
+            or `"line-{orig_seq_idx}"`.
+        :param num_seqs: total number of sequences
+        :param split: dict of split keys to split ratio
+        :param selected_splits: list of split keys to use. also specifies the order.
+        :param shuffle: whether to shuffle
+        :param shuffle_seed: seed for the shuffle
+        """
+        assert isinstance(split, dict)
+        assert all(s > 0 for s in split.values())
+        assert abs(sum(split.values()) - 1.0) < 1e-10
+
+        self.seq_tag_template = seq_tag_template
+        self.num_seqs = num_seqs
+        self.split = split
+        self.selected_splits = selected_splits
+        for key in selected_splits:
+            assert key in split
+        self.shuffle = shuffle
+        self.shuffle_seed = shuffle_seed
+
+        self.out_segments = self.output_path("out.segments")
+
+    def tasks(self):
+        yield Task("run", mini_task=True)
+
+    def run(self):
+        import random
+        import itertools as it
+
+        n = self.num_seqs
+        if isinstance(n, tk.Variable):
+            n = n.get()
+        assert isinstance(n, int)
+        segments = list(range(n))
+
+        if self.shuffle:
+            rng = random.Random(self.shuffle_seed)
+            rng.shuffle(segments)
+
+        ordered_keys = sorted(self.split.keys())
+        split_indices = [0] + [int(n * c) for c in it.accumulate(self.split[k] for k in ordered_keys)]
+        split_indices[-1] = n  # just in case we get numeric errors that drop the last element
+        ordered_keys_indices = {k: i for i, k in enumerate(ordered_keys)}
+
+        with uopen(self.out_segments.get_path(), "wt", encoding="utf-8") as f:
+            for split_key in self.selected_splits:
+                split_idx = ordered_keys_indices[split_key]
+                for split_seq_idx, orig_seq_idx in enumerate(
+                    segments[split_indices[split_idx] : split_indices[split_idx + 1]]
+                ):
+                    f.write(
+                        self.seq_tag_template.format(
+                            split_key=split_key, split_seq_idx=split_seq_idx, orig_seq_idx=orig_seq_idx
+                        )
+                        + "\n"
+                    )
+
+
+class WriteSeqListInOrigOrderFromShuffledJob(Job):
+    """
+    Assuming that some dataset was split and shuffled using
+    :class:`i6_core.corpus.segments.ShuffleAndSplitSegmentsJob`,
+    and then merged again (via some more complex pipeline,
+    e.g. involving :class:`i6_core.returnn.oggzip.BlissToOggZipJob`),
+    we might get sequence tags like `"librispeech-lm-part{split_key}/recording_{split_seq_idx}/line_{split_seq_idx}"`.
+
+    We write those seq tags out to a seq list file.
+    This file could be used for :class:`returnn.datasets.meta.MetaDataset` ``seq_list_file``.
+
+    The code is based on :class:`i6_core.corpus.segments.ShuffleAndSplitSegmentsJob`.
+    """
+
     def __init__(
         self,
         *,