Commit ecca8ee

Author: Maximilian Azevedo (committed)
Message: max thesis code
Parent: 8ed59ff

File tree: 165 files changed (+36002, -0 lines)


users/azevedo/.gitignore

Lines changed: 2 additions & 0 deletions
**/.DS_Store
**/__pycache__

users/azevedo/__init__.py

Whitespace-only changes.

users/azevedo/experiments/__init__.py

Whitespace-only changes.

users/azevedo/experiments/librispeech/__init__.py

Whitespace-only changes.

users/azevedo/experiments/librispeech/ctc_rnnt_standalone_2024/README.md

Whitespace-only changes.
Lines changed: 4 additions & 0 deletions

"""
Used as "root" point for hashing related to everything in "pytorch_networks"
"""
PACKAGE = __package__
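
For context: `PACKAGE` gives downstream code a stable root for building fully qualified import paths during serialization. A minimal sketch of the idea, where the module and class names are hypothetical illustrations, not values from this commit:

# PACKAGE resolves to the dotted path of this package at import time.
# Downstream serializers can then address code objects relative to it.
network_module = "ctc.conformer_baseline"  # hypothetical module name
code_object_path = PACKAGE + ".pytorch_networks." + network_module + ".Model"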
Lines changed: 179 additions & 0 deletions

"""
Universal helpers to create configuration objects (i6_core ReturnnConfig) for RETURNN training/forwarding
"""
import copy
from typing import Any, Dict, Optional

from i6_core.returnn.config import ReturnnConfig, CodeWrapper

from i6_experiments.common.setups.returnn_pytorch.serialization import (
    Collection as TorchCollection,
)
from i6_experiments.common.setups.serialization import Import
from .data.common import TrainingDatasets
from .serializer import serialize_training, serialize_forward, PACKAGE


def get_training_config(
    training_datasets: TrainingDatasets,
    network_module: str,
    config: Dict[str, Any],
    net_args: Dict[str, Any],
    unhashed_net_args: Optional[Dict[str, Any]] = None,
    include_native_ops: bool = False,
    debug: bool = False,
    use_speed_perturbation: bool = False,
    post_config: Optional[Dict[str, Any]] = None,
) -> ReturnnConfig:
    """
    Get a generic config for training a model

    :param training_datasets: datasets for training
    :param network_module: path to the pytorch config file containing the Model
    :param config: config arguments for RETURNN
    :param net_args: extra arguments for constructing the PyTorch model
    :param unhashed_net_args: unhashed extra arguments for constructing the PyTorch model
    :param include_native_ops: include native ops in the serializer
    :param debug: run training in debug mode (linking from recipe instead of copy)
    :param use_speed_perturbation: use speed perturbation in the training
    :param post_config: additional non-hashed arguments for RETURNN
    """

    # changing these does not change the hash
    base_post_config = {"stop_on_nonfinite_train_score": True, "num_workers_per_gpu": 2, "backend": "torch"}

    # TODO: test
    base_config = {
        "cleanup_old_models": {
            "keep_last_n": 4,
            "keep_best_n": 4,
            "keep": [1, 10, 125],
        },
        #############
        "train": copy.deepcopy(training_datasets.train.as_returnn_opts()),
        "dev": training_datasets.cv.as_returnn_opts(),
        "eval_datasets": {"devtrain": training_datasets.devtrain.as_returnn_opts()},
    }
    config = {**base_config, **copy.deepcopy(config)}
    post_config = {**base_post_config, **copy.deepcopy(post_config or {})}

    serializer = serialize_training(
        network_module=network_module,
        net_args=net_args,
        unhashed_net_args=unhashed_net_args,
        include_native_ops=include_native_ops,
        debug=debug,
    )
    python_prolog = None

    # TODO: maybe make nice (if capability added to RETURNN itself)
    if use_speed_perturbation:
        prolog_serializer = TorchCollection(
            serializer_objects=[
                Import(
                    code_object_path=PACKAGE + ".extra_code.speed_perturbation.legacy_speed_perturbation",
                    unhashed_package_root=PACKAGE,
                )
            ]
        )
        python_prolog = [prolog_serializer]
        config["train"]["datasets"]["zip_dataset"]["audio"]["pre_process"] = CodeWrapper("legacy_speed_perturbation")

    returnn_config = ReturnnConfig(
        config=config, post_config=post_config, python_prolog=python_prolog, python_epilog=[serializer]
    )
    return returnn_config


def get_prior_config(
    training_datasets: TrainingDatasets,  # TODO: replace by single dataset
    network_module: str,
    config: Dict[str, Any],
    net_args: Dict[str, Any],
    unhashed_net_args: Optional[Dict[str, Any]] = None,
    debug: bool = False,
) -> ReturnnConfig:
    """
    Get a generic config for extracting output label priors

    :param training_datasets: datasets for training
    :param network_module: path to the pytorch config file containing the Model
    :param config: config arguments for RETURNN
    :param net_args: extra arguments for constructing the PyTorch model
    :param unhashed_net_args: unhashed extra arguments for constructing the PyTorch model
    :param debug: run in debug mode (linking from recipe instead of copy)
    """

    # changing these does not change the hash
    post_config = {
        "num_workers_per_gpu": 2,
    }

    base_config = {
        #############
        "batch_size": 500 * 16000,
        "max_seqs": 240,
        #############
        "forward": copy.deepcopy(training_datasets.prior.as_returnn_opts()),
    }
    config = {**base_config, **copy.deepcopy(config)}
    post_config["backend"] = "torch"

    serializer = serialize_forward(
        network_module=network_module,
        net_args=net_args,
        unhashed_net_args=unhashed_net_args,
        forward_module=None,  # same as network
        forward_step_name="prior",
        forward_init_args=None,
        unhashed_forward_init_args=None,
        debug=debug,
    )
    returnn_config = ReturnnConfig(config=config, post_config=post_config, python_epilog=[serializer])
    return returnn_config


def get_forward_config(
    network_module: str,
    config: Dict[str, Any],
    net_args: Dict[str, Any],
    decoder: str,
    decoder_args: Dict[str, Any],
    unhashed_decoder_args: Optional[Dict[str, Any]] = None,
    unhashed_net_args: Optional[Dict[str, Any]] = None,
    debug: bool = False,
) -> ReturnnConfig:
    """
    Get a generic config for forwarding

    :param network_module: path to the pytorch config file containing the Model
    :param config: config arguments for RETURNN
    :param net_args: extra arguments for constructing the PyTorch model
    :param decoder: which (python) file to load, which defines the forward, forward_init and forward_finish functions
    :param decoder_args: extra arguments to pass to forward_init
    :param unhashed_decoder_args: unhashed extra arguments for the forward init
    :param unhashed_net_args: unhashed extra arguments for constructing the PyTorch model
    :param debug: run forwarding in debug mode (linking from recipe instead of copy)
    """

    # changing these does not change the hash
    post_config = {}

    # changing these does change the hash
    base_config = {
        "batch_size": 1000 * 16000,
        "max_seqs": 240,
    }
    config = {**base_config, **copy.deepcopy(config)}
    post_config["backend"] = "torch"

    serializer = serialize_forward(
        network_module=network_module,
        net_args=net_args,
        unhashed_net_args=unhashed_net_args,
        forward_module=decoder,
        forward_init_args=decoder_args,
        unhashed_forward_init_args=unhashed_decoder_args,
        debug=debug,
    )
    returnn_config = ReturnnConfig(config=config, post_config=post_config, python_epilog=[serializer])
    return returnn_config
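
For orientation, a minimal usage sketch of `get_training_config`; the dataset variable, module name, and all hyperparameter values below are illustrative assumptions, not taken from this commit:

# Hypothetical usage sketch; concrete names and values are assumptions.
train_data = build_bpe_training_datasets(...)  # a TrainingDatasets instance (see the data helpers below)
training_config = get_training_config(
    training_datasets=train_data,
    network_module="ctc.conformer_baseline",  # hypothetical pytorch_networks module
    config={
        "batch_size": 300 * 16000,      # assumed: raw-audio samples per batch
        "learning_rates": [1e-4] * 10,  # assumed schedule
    },
    net_args={"model_config_dict": {}},  # assumed model constructor kwargs
    use_speed_perturbation=True,         # installs the legacy_speed_perturbation prolog
    debug=False,
)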

users/azevedo/experiments/librispeech/ctc_rnnt_standalone_2024/data/__init__.py

Whitespace-only changes.
Lines changed: 106 additions & 0 deletions

"""
Dataset helpers for the BPE-based training
"""
from sisyphus import tk

from i6_core.g2p.convert import BlissLexiconToG2PLexiconJob
from i6_core.lexicon.bpe import CreateBPELexiconJob

from i6_experiments.common.datasets.librispeech import get_ogg_zip_dict, get_bliss_lexicon
from i6_experiments.common.datasets.librispeech.vocab import get_subword_nmt_bpe_v2
from i6_experiments.common.setups.returnn.datastreams.vocabulary import BpeDatastream

from .common import DatasetSettings, TrainingDatasets, build_training_datasets
from ..default_tools import MINI_RETURNN_ROOT, RETURNN_EXE, SUBWORD_NMT_REPO


def get_bpe_datastream(librispeech_key: str, bpe_size: int, is_recog: bool, use_postfix: bool) -> BpeDatastream:
    """
    Returns the datastream for the BPE labels

    Uses the legacy BPE setup that is compatible with old LM models

    :param librispeech_key: which LibriSpeech corpus to use for BPE training
    :param bpe_size: size of the BPE label vocabulary
    :param is_recog: use the UNK label (needed in recognition, removed for training)
    :param use_postfix: True for RNN-T or Attention, False for CTC
    """
    bpe_settings = get_subword_nmt_bpe_v2(corpus_key=librispeech_key, bpe_size=bpe_size, unk_label="<unk>")

    bpe_targets = BpeDatastream(
        available_for_inference=False,
        bpe_settings=bpe_settings,
        use_unk_label=is_recog,
        seq_postfix=0 if use_postfix else None,
    )
    return bpe_targets


def get_bpe_lexicon(librispeech_key: str, bpe_size: int) -> tk.Path:
    """
    Create a BPE lexicon without unknown and silence

    :param librispeech_key: which LibriSpeech corpus to use for BPE training
    :param bpe_size: number of BPE splits
    :return: path to a lexicon bliss xml file
    """
    bpe_settings = get_subword_nmt_bpe_v2(corpus_key=librispeech_key, bpe_size=bpe_size, unk_label="<unk>")
    bpe_lexicon = CreateBPELexiconJob(
        base_lexicon_path=get_bliss_lexicon(add_unknown_phoneme_and_mapping=False, add_silence=False),
        bpe_codes=bpe_settings.bpe_codes,
        bpe_vocab=bpe_settings.bpe_vocab,
        subword_nmt_repo=SUBWORD_NMT_REPO,
        unk_label="<unk>",
    ).out_lexicon

    return bpe_lexicon


def get_text_lexicon(prefix: str, librispeech_key: str, bpe_size: int) -> tk.Path:
    """
    Get a BPE lexicon in line-based text format to be used for torchaudio/Flashlight decoding

    :param prefix:
    :param librispeech_key: which LibriSpeech corpus to use for BPE training
    :param bpe_size: number of BPE splits
    :return: path to a lexicon text file
    """
    bliss_lex = get_bpe_lexicon(librispeech_key=librispeech_key, bpe_size=bpe_size)
    word_lexicon = BlissLexiconToG2PLexiconJob(
        bliss_lex,
        include_pronunciation_variants=True,
        include_orthography_variants=True,
    ).out_g2p_lexicon
    return word_lexicon


def build_bpe_training_datasets(
    prefix: str,
    librispeech_key: str,
    bpe_size: int,
    settings: DatasetSettings,
    use_postfix: bool,
) -> TrainingDatasets:
    """
    Build the training datasets for BPE-label training

    :param prefix: prefix for the dataset preparation jobs
    :param librispeech_key: which LibriSpeech corpus to use for BPE training
    :param bpe_size: number of BPE splits
    :param settings: configuration object for the dataset pipeline
    :param use_postfix: True for RNN-T or Attention, False for CTC
    """
    label_datastream = get_bpe_datastream(
        librispeech_key=librispeech_key, bpe_size=bpe_size, is_recog=False, use_postfix=use_postfix
    )

    ogg_zip_dict = get_ogg_zip_dict(prefix, returnn_root=MINI_RETURNN_ROOT, returnn_python_exe=RETURNN_EXE)
    train_ogg = ogg_zip_dict[librispeech_key]
    dev_clean_ogg = ogg_zip_dict["dev-clean"]
    dev_other_ogg = ogg_zip_dict["dev-other"]

    return build_training_datasets(
        train_ogg=train_ogg,
        dev_clean_ogg=dev_clean_ogg,
        dev_other_ogg=dev_other_ogg,
        settings=settings,
        label_datastream=label_datastream,
    )
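
To close the loop, a minimal sketch of how this entry point might be called from an experiment script; the prefix, corpus key, BPE size, and settings object are illustrative assumptions, not values from this commit:

# Hypothetical usage sketch; argument values are assumptions.
dataset_settings = DatasetSettings(...)  # pipeline options; fields elided here
train_data = build_bpe_training_datasets(
    prefix="experiments/librispeech/ctc_rnnt_standalone_2024",  # assumed job prefix
    librispeech_key="train-other-960",  # the full 960h LibriSpeech training set
    bpe_size=128,                       # assumed number of BPE splits
    settings=dataset_settings,
    use_postfix=True,                   # True for RNN-T/Attention, False for CTC
)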
