Skip to content

Commit

Permalink
new WriteSeqListFromShuffledJob, WriteSeqListInOrigOrderFromShuffledJob
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Dec 3, 2024
1 parent 33f59f5 commit 4d51834
Showing 1 changed file with 85 additions and 1 deletion.
86 changes: 85 additions & 1 deletion users/zeyer/datasets/utils/extract_seq_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Extract/generate seq lists.
"""

from typing import Optional, Union, Any, Dict
from typing import Optional, Union, Any, Sequence, Dict
from sisyphus import tk, Task, Job
from i6_core.util import uopen

Expand Down Expand Up @@ -239,6 +239,90 @@ class WriteSeqListFromShuffledJob(Job):
The code is based on :class:`i6_core.corpus.segments.ShuffleAndSplitSegmentsJob`.
"""

def __init__(
self,
*,
seq_tag_template: str,
num_seqs: Union[int, tk.Variable],
split: Dict[str, float],
selected_splits: Sequence[str],
shuffle=True,
shuffle_seed=0x3C5EA3E47D4E0077,
):
"""
:param seq_tag_template: e.g. `"librispeech-lm-part{split_key}/recording_{split_seq_idx}/line_{split_seq_idx}"`
or `"line-{orig_seq_idx}"`.
:param num_seqs: total number of sequences
:param split: dict of split keys to split ratio
:param selected_splits: list of split keys to use. also specifies the order.
:param shuffle: whether to shuffle
:param shuffle_seed: seed for the shuffle
"""
assert isinstance(split, dict)
assert all(s > 0 for s in split.values())
assert abs(sum(split.values()) - 1.0) < 1e-10

self.seq_tag_template = seq_tag_template
self.num_seqs = num_seqs
self.split = split
self.selected_splits = selected_splits
for key in selected_splits:
assert key in split
self.shuffle = shuffle
self.shuffle_seed = shuffle_seed

self.out_segments = self.output_path("out.segments")

def tasks(self):
yield Task("run", mini_task=True)

def run(self):
import random
import itertools as it

n = self.num_seqs
if isinstance(n, tk.Variable):
n = n.get()
assert isinstance(n, int)
segments = list(range(n))

if self.shuffle:
rng = random.Random(self.shuffle_seed)
rng.shuffle(segments)

ordered_keys = sorted(self.split.keys())
split_indices = [0] + [int(n * c) for c in it.accumulate(self.split[k] for k in ordered_keys)]
split_indices[-1] = n # just in case we get numeric errors that drop the last element
ordered_keys_indices = {k: i for i, k in enumerate(ordered_keys)}

with uopen(self.out_segments.get_path(), "wt", encoding="utf-8") as f:
for split_key in self.selected_splits:
split_idx = ordered_keys_indices[split_key]
for split_seq_idx, orig_seq_idx in enumerate(
segments[split_indices[split_idx] : split_indices[split_idx + 1]]
):
f.write(
self.seq_tag_template.format(
split_key=split_key, split_seq_idx=split_seq_idx, orig_seq_idx=orig_seq_idx
)
+ "\n"
)


class WriteSeqListInOrigOrderFromShuffledJob(Job):
"""
Assuming that some dataset was split and shuffled using
:class:`i6_core.corpus.segments.ShuffleAndSplitSegmentsJob`,
and then merged again (via some more complex pipeline,
e.g. involving :class:`i6_core.returnn.oggzip.BlissToOggZipJob`),
we might get sequence tags like `"librispeech-lm-part{split_key}/recording_{split_seq_idx}/line_{split_seq_idx}"`.
We write those seq tags out to a seq list file.
This file could be used for :class:`returnn.datasets.meta.MetaDataset` ``seq_list_file``.
The code is based on :class:`i6_core.corpus.segments.ShuffleAndSplitSegmentsJob`.
"""

def __init__(
self,
*,
Expand Down

0 comments on commit 4d51834

Please sign in to comment.