# segment_audio.py
from typing import List, Tuple
from pathlib import Path
import argparse
import json

import torch
import torchaudio
import torchaudio.transforms as T
from scipy.io.wavfile import write

from utils import MMS_SUBSAMPLING_RATIO, preprocess_verse, compute_alignments

parser = argparse.ArgumentParser()
parser.add_argument(
    "--audio_path",
    required=True,
    help="Path to the audio file. Example: downloads/wavs_44/PSA/PSA_119.wav",
)
parser.add_argument(
    "--json_path", required=True, help="Path to the JSON file. Example: data/openbible_swahili/PSA.json"
)
parser.add_argument("--output_dir", default="outputs/openbible_swahili/", help="Path to the output directory")
parser.add_argument("--chunk_size_s", type=int, default=15, help="Chunk size in seconds")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load MMS aligner model
bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model(with_star=True).to(device)
DICTIONARY = bundle.get_dict()
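# DICTIONARY maps each character (plus the "*" star token, since with_star=True)
# to the aligner's output token indices; compute_alignments uses it to map
# transcript words to token IDs for CTC forced alignment.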


def load_transcripts(json_path: Path, chapter: str) -> Tuple[List[str], List[str]]:
    with open(json_path, "r") as f:
        data = json.load(f)

    # convert PSA 19:1 -> PSA_019
    get_chapter = lambda x: x.split()[0] + "_" + x.split(":")[0].split()[1].zfill(3)

    # filter by book and chapter
    transcripts = [d["verseText"] for d in data if get_chapter(d["verseNumber"]) == chapter]
    verse_ids = [d["verseNumber"] for d in data if get_chapter(d["verseNumber"]) == chapter]

    return verse_ids, transcripts


def segment(audio_path: str, json_path: str, output_dir: str, chunk_size_s: int = 15):
    audio_path = Path(audio_path)
    json_path = Path(json_path)

    # book = "MAT"; chapter = "MAT_019"
    book, chapter = json_path.stem, audio_path.stem

    # prepare output directories
    output_dir = Path(output_dir) / book / chapter
    output_dir.mkdir(parents=True, exist_ok=True)

    # skip if already segmented
    if any(output_dir.iterdir()):
        print(f"Skipping {chapter}")
        return

    # load transcripts
    verse_ids, transcripts = load_transcripts(json_path, chapter)

    # apply preprocessing
    verses = [preprocess_verse(v) for v in transcripts]

    # insert "*" before every verse to absorb the chapter intro or spoken verse number;
    # see MMS robust noisy audio alignment:
    # https://pytorch.org/audio/main/tutorials/ctc_forced_alignment_api_tutorial.html
    augmented_verses = ["*"] * len(verses) * 2
    augmented_verses[1::2] = verses
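    # augmented_verses is now ["*", verse_1, "*", verse_2, ...]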
    words = [verse.split() for verse in verses]
    augmented_words = [word for verse in augmented_verses for word in verse.split()]

    # load audio
    input_waveform, input_sample_rate = torchaudio.load(audio_path)
    resampler = T.Resample(input_sample_rate, bundle.sample_rate, dtype=input_waveform.dtype)
    resampled_waveform = resampler(input_waveform)

    # split audio into chunks to avoid OOM and speed up inference
    chunk_size_frames = chunk_size_s * bundle.sample_rate
    chunks = [
        resampled_waveform[:, i : i + chunk_size_frames]
        for i in range(0, resampled_waveform.shape[1], chunk_size_frames)
    ]
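    # NOTE: the last chunk may be shorter than chunk_size_frames, hence the length check below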

    # collect per-chunk emissions, rejoin
    emissions = []
    with torch.inference_mode():
        for chunk in chunks:
            # NOTE: we could pad here, but it'll need to be removed later;
            # skipping for simplicity, since it's at most 25ms
            if chunk.size(1) >= MMS_SUBSAMPLING_RATIO:
                emission, _ = model(chunk.to(device))
                emissions.append(emission)

    emission = torch.cat(emissions, dim=1)
    num_frames = emission.size(1)
    assert len(DICTIONARY) == emission.shape[2]
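    # emission has shape (batch, num_frames, num_tokens): per-frame token scores over
    # the character vocabulary plus the "*" star token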

    # perform forced-alignment
    word_spans = compute_alignments(emission, augmented_words, DICTIONARY, device)

    # remove "*" from alignment
    word_only_spans = [spans for spans, word in zip(word_spans, augmented_words) if word != "*"]
    assert len(word_only_spans) == sum(len(word) for word in words)

    # collect verse-level segments
    segments, labels, start = [], [], 0
    for verse_words in words:
        end = start + len(verse_words)
        verse_spans = word_only_spans[start:end]
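        # rescale emission-frame indices back to sample indices in the original
        # (non-resampled) waveform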
        ratio = input_waveform.size(1) / num_frames
        x0 = int(ratio * verse_spans[0][0].start)
        x1 = int(ratio * verse_spans[-1][-1].end)
        transcript = " ".join(verse_words)
        segment = input_waveform[:, x0:x1]
        start = end
        segments.append(segment)
        labels.append(transcript)

    assert len(segments) == len(verse_ids) == len(labels)

    # export segments and forced-aligned transcripts
    for verse_id, segment, label in zip(verse_ids, segments, labels):
        # PSA 19:1 -> PSA_019_001
        verse_number = verse_id.split(":")[-1].zfill(3)
        verse_file_name = chapter + "_" + verse_number

        # write audio
        audio_path = (output_dir / verse_file_name).with_suffix(".wav")
        write(audio_path, input_sample_rate, segment.squeeze().numpy())

        # write transcript
        transcript_path = (output_dir / verse_file_name).with_suffix(".txt")
        with open(transcript_path, "w") as f:
            f.write(label)


if __name__ == "__main__":
    args = parser.parse_args()
    segment(args.audio_path, args.json_path, args.output_dir, args.chunk_size_s)