Finalize Filtering Scripts; Refactor as Utils
w11wo committed Apr 5, 2024
1 parent a9eee58 commit c4244f9
Showing 7 changed files with 189 additions and 122 deletions.
44 changes: 42 additions & 2 deletions README.md
@@ -86,7 +86,7 @@ Examples of these JSON files can be found in the [data](./data/openbible_swahili
Then, to align and segment the chapter-level audio into verse-level audio, run the following script:

```sh
-python segment_audio.py \
+python scripts/segment_audio.py \
--audio_path downloads/wavs_44/PSA/PSA_119.wav \ # path to the audio file
--json_path data/openbible_swahili/PSA.json \ # path to the JSON file
--output_dir outputs/openbible_swahili/ \ # output directory
@@ -126,10 +126,50 @@ python scripts/run_segmentation.py \

Finally, we also provide a [bash script](./run_segmentation.sh) to segment all downloaded books.

### Probability-based Alignment Score Filtering

As proposed in §3.1.5 of [MMS](https://arxiv.org/abs/2305.13516), we implemented length-normalized probability-difference filtering to remove noisy alignments, based on the following equation:

$$\frac{1}{T} \left[\log P\left(Y^{\text {aligned}} \mid X\right)-\log P\left(Y^{\text {greedy}} \mid X\right)\right]$$

where $T$ is the length of the audio, $P\left(Y^{\text{aligned}} \mid X\right)$ is the probability of the forced-alignment path, and $P\left(Y^{\text{greedy}} \mid X\right)$ is the probability of the greedy sequence.

Like MMS, we use `−0.2` as the default threshold and retain samples whose scores are greater than this threshold.
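
To make the arithmetic concrete, here is a minimal, self-contained sketch of the score (the helper and variable names below are hypothetical, for illustration only; the actual computation lives in [scripts/filter_audio.py](./scripts/filter_audio.py)):

```python
import torch


def probability_difference(aligned_log_prob: float, greedy_log_prob: float, num_frames: int) -> float:
    # length-normalized log-probability difference; closer to 0 means the
    # forced-alignment path is nearly as likely as the unconstrained greedy path
    return (aligned_log_prob - greedy_log_prob) / num_frames


# toy emission matrix: (batch, frames, vocab) log-probabilities
emission = torch.randn(1, 120, 30).log_softmax(dim=-1)
# greedy log-probability: sum of the per-frame maxima
greedy_log_prob = emission.max(dim=-1).values.sum().item()
# pretend the forced-alignment path scored slightly lower overall
aligned_log_prob = greedy_log_prob - 6.0

score = probability_difference(aligned_log_prob, greedy_log_prob, num_frames=emission.size(1))
print(score)         # -0.05
print(score > -0.2)  # True -> accept
```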

The filtering script can be run as follows:

```sh
# score: -0.005685280751179646 (good alignment; accept)
python scripts/filter_audio.py \
--audio_path outputs/openbible_swahili/EPH/EPH_003/EPH_003_001.wav \
--ground_truth "kwa sababu hii mimi paulo mfungwa wa kristo yesu kwa ajili yenu ninyi watu wa mataifa" \
--chunk_size_s 15

# score: -0.5496844846810868 (bad alignment; reject)
python scripts/filter_audio.py \
--audio_path outputs/openbible_swahili/EPH/EPH_001/EPH_001_020.wav \
--ground_truth "aliyoitumia katika kristo alipomfufua kutoka kwa wafu na akamketisha mkono wake wa kuume huko mbinguni" \
--chunk_size_s 15
```

Likewise, we also provide a [runner script](./scripts/run_filter.py) that filters all the audio files in a directory, typically one directory per book of the Bible. You can run it as follows:

```sh
python scripts/run_filter.py \
--audio_dir outputs/openbible_swahili/PSA/ \
--output_dir outputs/openbible_swahili_filtered/ \
--chunk_size_s 15 \
--probability_difference_threshold -0.2
```

It will then generate a new directory containing only the audio segments that passed the filter, retaining the same directory structure.
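
For instance, filtering `PSA` would produce a layout like the following (hypothetical listing):

```
outputs/openbible_swahili_filtered/
└── PSA/
    ├── PSA_001/
    │   ├── PSA_001_001.wav
    │   ├── PSA_001_001.txt
    │   └── ...
    └── ...
```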

Finally, we also provide a [bash script](./run_filter.sh) to filter the generated segments for all books.

## Future Improvements

- [ ] Support chunk batching
- [x] Probability-based alignment filtering
- [ ] CER-based filtering

## License
9 changes: 9 additions & 0 deletions run_filter.sh
@@ -0,0 +1,9 @@
#!/bin/sh
audio_dir="outputs/openbible_swahili"
output_dir="outputs/openbible_swahili_filtered"

books="GEN EXO LEV NUM DEU JOS JDG RUT 1SA 2SA 1KI 2KI 1CH 2CH EZR NEH EST JOB PSA PRO ECC SNG ISA JER LAM EZK DAN HOS JOL AMO OBA JON MIC NAM HAB ZEP HAG ZEC MAL MAT MRK LUK JHN ACT ROM 1CO 2CO GAL EPH PHP COL 1TH 2TH 1TI 2TI TIT PHM HEB JAS 1PE 2PE 1JN 2JN 3JN JUD REV"

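# run probability-difference filtering for every book (the script's defaults apply: threshold -0.2, 15s chunks)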
for book in $books; do
    python scripts/run_filter.py --audio_dir "$audio_dir/$book" --output_dir "$output_dir"
done
78 changes: 8 additions & 70 deletions scripts/filter_audio.py
@@ -1,83 +1,24 @@
 from pathlib import Path
 import argparse
-import json
-import re
-import string
 
 import torch
 import torchaudio
-import torchaudio.functional as F
 import torchaudio.transforms as T
 
-from unidecode import unidecode
-from num2words import num2words
-import unicodedata
+from utils import MMS_SUBSAMPLING_RATIO, preprocess_verse, compute_alignment_scores

 parser = argparse.ArgumentParser()
 parser.add_argument(
     "--audio_path",
     required=True,
     help="Path to the audio file. Example: outputs/openbible_swahili/EPH/EPH_003/EPH_003_001.wav",
 )
-parser.add_argument(
-    "--json_path", required=True, help="Path to the JSON file. Example: data/openbible_swahili/EPH.json"
-)
-parser.add_argument("--output_dir", default="outputs/openbible_swahili/", help="Path to the output directory")
+parser.add_argument("--ground_truth", required=True, help="Ground truth text to forced-align with.")
 parser.add_argument("--chunk_size_s", type=int, default=15, help="Chunk size in seconds")

-# MMS feature extractor minimum input frame size (25ms)
-# also the same value as `ratio`
-# `ratio = input_waveform.size(1) / num_frames`
-SUBSAMPLING_RATIO = 400
-
-
-def preprocess_verse(text: str) -> str:
-    text = unidecode(text)
-    text = unicodedata.normalize("NFKC", text)
-    text = text.lower()
-    text = text.translate(str.maketrans("", "", string.punctuation))
-    text = re.sub(r"\d+", lambda x: num2words(int(x.group(0)), lang="sw"), text)
-    text = re.sub("\s+", " ", text)
-    return text
-
-
-def load_transcript(json_path: Path, verse: str) -> str:
-    with open(json_path, "r") as f:
-        data = json.load(f)
-
-    # convert PSA 19:1 -> PSA_019_001
-    get_verse = lambda x: x.split()[0] + "_" + x.split(":")[0].split()[1].zfill(3) + "_" + x.split(":")[1].zfill(3)
-    # filter by verse
-    transcript = [d["verseText"] for d in data if get_verse(d["verseNumber"]) == verse][0]
-    return transcript
-
-
-###############################################################################################################
-# functions modified from https://pytorch.org/audio/main/tutorials/ctc_forced_alignment_api_tutorial.html
-###############################################################################################################
-
-
-def align(emission, tokens, device):
-    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
-    alignments, scores = F.forced_align(emission, targets, blank=0)
-
-    alignments, scores = alignments[0], scores[0]  # remove batch dimension for simplicity
-    scores = scores.exp()  # convert back to probability
-    return alignments, scores
-
-
-def compute_alignment_scores(emission, transcript, dictionary, device):
-    tokens = [dictionary[char] for word in transcript for char in word]
-    _, scores = align(emission, tokens, device)
-    return scores


-def compute_probability_difference(audio_path: str, json_path: str, chunk_size_s: int = 15) -> float:
+def compute_probability_difference(audio_path: str, ground_truth: str, chunk_size_s: int = 15) -> float:
     audio_path = Path(audio_path)
-    json_path = Path(json_path)
-
-    # verse_id = "MAT_019_001"
-    verse_id = audio_path.stem
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -86,10 +27,8 @@ def compute_probability_difference(audio_path: str, json_path: str, chunk_size_s
     model = bundle.get_model(with_star=False).to(device)
     DICTIONARY = bundle.get_dict(star=None)
 
-    # load transcript
-    transcript = load_transcript(json_path, verse_id)
     # apply preprocessing
-    verse = preprocess_verse(transcript)
+    verse = preprocess_verse(ground_truth)
     words = verse.split()
 
     # load audio
@@ -109,7 +48,7 @@ def compute_probability_difference(audio_path: str, json_path: str, chunk_size_s
     for chunk in chunks:
         # NOTE: we could pad here, but it'll need to be removed later
         # skipping for simplicity, since it's at most 25ms
-        if chunk.size(1) >= SUBSAMPLING_RATIO:
+        if chunk.size(1) >= MMS_SUBSAMPLING_RATIO:
             emission, _ = model(chunk.to(device))
             emissions.append(emission)

@@ -126,7 +65,7 @@ def compute_probability_difference(audio_path: str, json_path: str, chunk_size_s
     greedy_log_probs = torch.sum(torch.log(greedy_probs)).cpu().numpy().item()
 
     # compute forced-alignment score
-    aligned_probs = compute_alignment_scores(emission, words, DICTIONARY)
+    aligned_probs = compute_alignment_scores(emission, words, DICTIONARY, device)
     aligned_log_probs = torch.sum(torch.log(aligned_probs)).cpu().numpy().item()
 
     # compute length-normalized probability difference
@@ -137,6 +76,5 @@ def compute_probability_difference(audio_path: str, json_path: str, chunk_size_s

 if __name__ == "__main__":
     args = parser.parse_args()
-    probability_diff = compute_probability_difference(
-        args.audio_path, args.json_path, args.output_dir, args.chunk_size_s
-    )
+    probability_difference = compute_probability_difference(args.audio_path, args.ground_truth, args.chunk_size_s)
+    print(probability_difference)
56 changes: 56 additions & 0 deletions scripts/run_filter.py
@@ -0,0 +1,56 @@
from pathlib import Path
import argparse
import shutil

from tqdm.auto import tqdm

from filter_audio import compute_probability_difference

parser = argparse.ArgumentParser()
parser.add_argument(
    "--audio_dir",
    required=True,
    help="Path to the audio directory, must contain txt files with the same name. Example: outputs/openbible_swahili/PSA/",
)
parser.add_argument("--output_dir", default="outputs/openbible_swahili_filtered/", help="Path to the output directory")
parser.add_argument("--chunk_size_s", type=int, default=15, help="Chunk size in seconds")
parser.add_argument(
    "--probability_difference_threshold",
    type=float,
    default=-0.2,
    help="Probability difference threshold for filtering. Default: -0.2 from MMS.",
)


def main(args):
    audio_dir = Path(args.audio_dir)
    audios = sorted(audio_dir.rglob("*/*.wav"))
    for audio_path in tqdm(audios, desc=f"Filtering {audio_dir.stem}"):
        transcript_path = audio_path.with_suffix(".txt")
        # create output directory `output_dir/{book}/{chapter}/`
        output_path = Path(args.output_dir) / audio_dir.stem / audio_path.parent.stem
        output_audio_path = output_path / audio_path.name
        output_transcript_path = output_path / transcript_path.name
        output_audio_path.parent.mkdir(parents=True, exist_ok=True)

        # skip if already filtered
        if output_audio_path.exists() and output_transcript_path.exists():
            print(f"Skipping {audio_path.stem}")
            continue

        # read ground truth
        with open(transcript_path) as f:
            ground_truth = f.read()

        # compute probability difference
        probability_difference = compute_probability_difference(audio_path, ground_truth, args.chunk_size_s)

        # copy audio and transcript if probability_difference is greater than threshold
        if probability_difference > args.probability_difference_threshold:
            shutil.copy(audio_path, output_audio_path)
            shutil.copy(transcript_path, output_transcript_path)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
2 changes: 1 addition & 1 deletion scripts/run_segmentation.py
@@ -12,7 +12,7 @@
 parser.add_argument(
     "--audio_dir",
     required=True,
-    help="Path to the audio file, must be 16kHz. Example: downloads/wavs_16/PSA/",
+    help="Path to the audio directory. Example: downloads/wavs_16/PSA/",
 )
 parser.add_argument("--output_dir", default="outputs/openbible_swahili/", help="Path to the output directory")
 parser.add_argument("--chunk_size_s", type=int, default=15, help="Chunk size in seconds")
53 changes: 4 additions & 49 deletions scripts/segment_audio.py
@@ -2,18 +2,15 @@
 from pathlib import Path
 import argparse
 import json
-import re
-import string
 
 import torch
 import torchaudio
 import torchaudio.functional as F
 import torchaudio.transforms as T
 
-from unidecode import unidecode
-from num2words import num2words
 from scipy.io.wavfile import write
-import unicodedata
 
+from utils import MMS_SUBSAMPLING_RATIO, preprocess_verse, compute_alignments
 
+
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -33,16 +30,6 @@
 SUBSAMPLING_RATIO = 400
 
 
-def preprocess_verse(text: str) -> str:
-    text = unidecode(text)
-    text = unicodedata.normalize("NFKC", text)
-    text = text.lower()
-    text = text.translate(str.maketrans("", "", string.punctuation))
-    text = re.sub(r"\d+", lambda x: num2words(int(x.group(0)), lang="sw"), text)
-    text = re.sub("\s+", " ", text)
-    return text
-
-
 def load_transcripts(json_path: Path, chapter: str) -> Tuple[List[str], List[str]]:
     with open(json_path, "r") as f:
         data = json.load(f)
@@ -55,38 +42,6 @@ def load_transcripts(json_path: Path, chapter: str) -> Tuple[List[str], List[str
     return verse_ids, transcripts
 
 
-###############################################################################################################
-# functions taken from https://pytorch.org/audio/main/tutorials/ctc_forced_alignment_api_tutorial.html
-###############################################################################################################
-
-
-def align(emission, tokens, device):
-    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
-    alignments, scores = F.forced_align(emission, targets, blank=0)
-
-    alignments, scores = alignments[0], scores[0]  # remove batch dimension for simplicity
-    scores = scores.exp()  # convert back to probability
-    return alignments, scores
-
-
-def unflatten(list_, lengths):
-    assert len(list_) == sum(lengths)
-    i = 0
-    ret = []
-    for l in lengths:
-        ret.append(list_[i : i + l])
-        i += l
-    return ret
-
-
-def compute_alignments(emission, transcript, dictionary, device):
-    tokens = [dictionary[char] for word in transcript for char in word]
-    alignment, scores = align(emission, tokens, device)
-    token_spans = F.merge_tokens(alignment, scores)
-    word_spans = unflatten(token_spans, [len(word) for word in transcript])
-    return word_spans
-
-
 def segment(audio_path: str, json_path: str, output_dir: str, chunk_size_s: int = 15):
     audio_path = Path(audio_path)
     json_path = Path(json_path)
@@ -141,7 +96,7 @@ def segment(audio_path: str, json_path: str, output_dir: str, chunk_size_s: int
     for chunk in chunks:
         # NOTE: we could pad here, but it'll need to be removed later
         # skipping for simplicity, since it's at most 25ms
-        if chunk.size(1) >= SUBSAMPLING_RATIO:
+        if chunk.size(1) >= MMS_SUBSAMPLING_RATIO:
             emission, _ = model(chunk.to(device))
             emissions.append(emission)

