Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding stft-based specaugment for CTC #261

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,15 @@ def get_returnn_config(
if audio_perturbation:
prolog += get_code_for_perturbation()
for layer in list(network.keys()):
if layer in ("stft", "istft", "wave_input"):
continue
if network[layer]["from"] == "data":
network[layer]["from"] = "features"
elif isinstance(network[layer]["from"], list) and "data" in network[layer]["from"]:
assert len(network[layer]["from"]) == 1
network[layer]["from"] = "features"
network["features"] = feature_net
feature_net["from"] = "wave_input"
if recognition:
for layer in list(network.keys()):
if "aux" in layer:
Expand Down
120 changes: 120 additions & 0 deletions users/vieting/experiments/switchboard/ctc/feat/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,124 @@ def run_mel_audio_perturbation_from_checkpoint():
return report


def run_stft_experiments():
gs.ALIAS_AND_OUTPUT_SUBDIR = "experiments/switchboard/ctc/feat/"

(
returnn_datasets,
rasr_loss_corpus_path,
rasr_loss_corpus_segments,
rasr_loss_lexicon_path,
dev_corpora,
) = get_datasets()
returnn_args = {
"batch_size": 5000,
"rasr_binary_path": RASR_BINARY_PATH,
"rasr_loss_corpus_path": rasr_loss_corpus_path,
"rasr_loss_corpus_segments": rasr_loss_corpus_segments,
"rasr_loss_lexicon_path": rasr_loss_lexicon_path,
"datasets": returnn_datasets,
"extra_args": {
"accum_grad_multiple_step": 2,
"conv_pad_seq_len_to_power": 1.5,
},
"conformer_type": "wei",
}
feature_args = {"class": "ScfNetwork", "size_tf": 256 // 2, "stride_tf": 10 // 2, "preemphasis": 0.97}
feature_args_lgm = {
"class": "LogMelNetwork",
"wave_norm": True,
"frame_size": 200,
"frame_shift": 80,
"fft_size": 256,
}
lr_args = {
"peak_lr": 4e-4,
"start_lr": 1.325e-05,
"end_lr": 1e-5,
"increase_epochs": 180,
"decrease_epochs": 180,
"final_epochs": 0,
}

nn_args, report_args_collection = get_nn_args_baseline(
nn_base_args={
"bs2x5k_scf_stft_time_only": dict(
returnn_args={
**returnn_args,
"specaug_old": {"max_feature": 0, "max_feature_num": 0, "stft": True},
},
feature_args=feature_args,
lr_args=lr_args,
report_args={"batch_size": "2x5k", "stft": True},
),
"bs2x5k_scf_stft_mask_1_1": dict(
returnn_args={
**returnn_args,
"specaug_old": {"max_feature": 1, "max_feature_num": 1, "stft": True},
},
feature_args=feature_args,
lr_args=lr_args,
report_args={"batch_size": "2x5k", "stft": True},
),
"bs2x5k_scf_stft_mask_2_4": dict(
returnn_args={
**returnn_args,
"specaug_old": {"max_feature": 4, "max_feature_num": 2, "stft": True},
},
feature_args=feature_args,
lr_args=lr_args,
report_args={"batch_size": "2x5k", "stft": True},
),
"bs2x5k_scf_stft_mask_5_8": dict(
returnn_args={
**returnn_args,
"specaug_old": {"max_feature": 8, "stft": True},
},
feature_args=feature_args,
lr_args=lr_args,
report_args={"batch_size": "2x5k", "stft": True},
),
"bs2x5k_scf_stft_mask_5_15": dict(
returnn_args={
**returnn_args,
"specaug_old": {"max_feature": 15, "stft": True},
},
feature_args=feature_args,
lr_args=lr_args,
report_args={"batch_size": "2x5k", "stft": True},
),
"bs2x5k_mel_stft_mask_5_8": dict(
returnn_args={
**returnn_args,
"specaug_old": {"max_feature": 8, "stft": True},
},
feature_args=feature_args_lgm,
lr_args=lr_args,
report_args={"batch_size": "2x5k", "stft": True},
),
},
num_epochs=450,
evaluation_epochs=[350, 390, 400, 410, 450],
prefix="conformer_",
)

returnn_root = CloneGitRepositoryJob(
"https://github.com/rwth-i6/returnn",
commit="c4d36d06f6465e82a50d400d114259e07b8b0709",
).out_repository
returnn_root.hash_overwrite = "returnn_conv_padding"
report, ctc_nn_system = run_nn_args(
nn_args,
report_args_collection,
dev_corpora,
"report_stft",
returnn_root=returnn_root,
recog_args={"epochs": [350, 390, 400, 410, 450]},
)
return report, ctc_nn_system


def py():
"""
called if the file is passed to sis manager, used to run all experiments (replacement for main)
Expand All @@ -1131,6 +1249,7 @@ def py():
report_scf_specaug_sort = run_scf_specaug_sort()
report_scf_audio_perturbation_from_checkpoint = run_scf_audio_perturbation_from_checkpoint()
report_mel_audio_perturbation_from_checkpoint = run_mel_audio_perturbation_from_checkpoint()
report_stft = run_stft_experiments()

report_base = Report(
columns_start=["train_name", "batch_size"],
Expand All @@ -1144,6 +1263,7 @@ def py():
report_scf_specaug_sort,
report_scf_audio_perturbation_from_checkpoint,
report_mel_audio_perturbation_from_checkpoint,
report_stft,
]
)
tk.register_report(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,22 +158,57 @@ def make_conformer_fullsum_ctc_model(

if recognition:
python_code = []
network["wave_input"] = {"class": "copy", "from": "data"}
else:
if specaug_old is not None:
assert specaug_config is None
sort_layer2 = specaug_old.pop("sort_layer2", False)
specaug_func = add_specaug_layer_sort_layer2 if sort_layer2 else add_specaug_layer
specaug_old_args = {
"max_time_num": 1,
"max_time": 15,
"max_feature_num": 5,
"max_feature": 4,
**specaug_old,
}
from_list, python_code = specaug_func(network, from_list=from_list, **specaug_old_args)
if specaug_old.get("stft", False):
specaug_old_args = {
"max_time_num": 1,
"max_time": 15,
"max_feature_num": 5,
"max_feature": 4,
**{k: v for k, v in specaug_old.items() if k != "stft"},
}
# Add STFT layer
network["stft"] = {
"class": "stft",
"from": ["data"],
"frame_size": 400,
"frame_shift": 160,
"fft_size": 512,
}
from_list = ["stft"]

specaug_func = add_specaug_layer
from_list, python_code = specaug_func(network, from_list=from_list, **specaug_old_args)

# Add iSTFT layer
network["istft"] = {
"class": "istft",
"from": from_list,
"frame_size": 400,
"frame_shift": 160,
"fft_size": 512,
}
network["wave_input"] = {"class": "copy", "from": "istft"}
else:
assert specaug_config is None
sort_layer2 = specaug_old.pop("sort_layer2", False)
specaug_func = add_specaug_layer_sort_layer2 if sort_layer2 else add_specaug_layer
specaug_old_args = {
"max_time_num": 1,
"max_time": 15,
"max_feature_num": 5,
"max_feature": 4,
**specaug_old,
}
from_list, python_code = specaug_func(network, from_list=from_list, **specaug_old_args)
network["wave_input"] = {"class": "copy", "from": "data"}
elif specaug_config is not None:
assert specaug_old is None
from_list, python_code = add_specaug_layer_configurable(network, from_list=from_list, num_epochs=num_epochs, config=specaug_config)
from_list, python_code = add_specaug_layer_configurable(
network, from_list=from_list, num_epochs=num_epochs, config=specaug_config
)
else:
from_list, python_code = add_specaug_layer_v2(network, from_list=from_list)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _mask(x, batch_axis, axis, pos, max_amount):
)
from TFUtil import where_bc

x = where_bc(cond, 0.0, x)
x = where_bc(cond, tf.constant(0.0, dtype=x.dtype), x)
return x


Expand Down
Loading