
Commit

add refined beta_binomial_prior_distribution func and test it
xiaozhah committed Jul 17, 2024
1 parent 21f0047 commit 850d855
Showing 1 changed file with 84 additions and 55 deletions.
139 changes: 84 additions & 55 deletions prior.py
import os
import numpy as np
from scipy.stats import betabinom
import time

# from https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/helpers.py

# Original implementation
def original_beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0):
    """
    Calculate the Beta-Binomial prior distribution for alignment.

    Row i (the i-th mel frame, 1-indexed) is the Beta-Binomial pmf with
    n = P, a = scaling_factor * i, b = scaling_factor * (M + 1 - i),
    evaluated over phoneme indices k = 0..P-1; its mean is P * i / (M + 1),
    so probability mass concentrates along the diagonal k/P ≈ i/M.

    Args:
        phoneme_count (int): Number of phonemes.
        mel_count (int): Number of mel spectrogram frames.
        scaling_factor (float): Scaling factor for the distribution, default is 1.0.
    Returns:
        np.array: 2D array of prior probabilities [mel_count, phoneme_count].
    """
    P, M = phoneme_count, mel_count
    x = np.arange(0, P)
    mel_text_probs = []
    # (loop collapsed in the diff view; restored from the linked coqui-ai/TTS helper)
    for i in range(1, M + 1):
        a, b = scaling_factor * i, scaling_factor * (M + 1 - i)
        rv = betabinom(P, a, b)
        mel_i_prob = rv.pmf(x)
        mel_text_probs.append(mel_i_prob)
    return np.array(mel_text_probs)

# Refined NumPy implementation
def refined_beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0):
    """
    Vectorized equivalent of the original implementation: broadcasts the
    per-frame (a, b) parameters of shape [M, 1] against the phoneme support
    of shape [P], evaluating the whole [M, P] pmf grid in one call.
    """
    P, M = phoneme_count, mel_count
    m = np.arange(1, M + 1)[:, np.newaxis]  # mel frame indices, shape [M, 1]
    x = np.arange(P)                        # phoneme indices, shape [P]
    a = scaling_factor * m
    b = scaling_factor * (M + 1 - m)
    probs = betabinom.pmf(x, P, a, b)       # broadcasts to shape [M, P]
    return probs
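
# Sanity sketch (illustrative, not part of the commit): the broadcasted pmf
# should reproduce the loop version exactly, e.g.
#
#   ref = original_beta_binomial_prior_distribution(10, 20)
#   vec = refined_beta_binomial_prior_distribution(10, 20)
#   assert np.allclose(ref, vec)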

# Cython implementation (assuming it's compiled and imported)
# from beta_binomial_cython import cython_beta_binomial_prior_distribution

# Unified interface
def compute_attn_prior(phoneme_count, mel_count, scaling_factor=1.0, method='refined'):
    """
    Compute attention priors for the alignment network.

    Args:
        phoneme_count (int): Length of the input sequence (number of phonemes).
        mel_count (int): Length of the output sequence (number of mel frames).
        scaling_factor (float): Scaling factor for the distribution.
        method (str): Implementation to use, 'original' or 'refined'.

    Returns:
        np.array: Attention prior matrix [mel_count, phoneme_count].
    """
    methods = {
        'original': original_beta_binomial_prior_distribution,
        'refined': refined_beta_binomial_prior_distribution,
        # 'cython': cython_beta_binomial_prior_distribution  # Uncomment if the Cython version is available
    }

    if method not in methods:
        raise ValueError(f"Unknown method: {method}. Available methods are: {', '.join(methods.keys())}")

    return methods[method](phoneme_count, mel_count, scaling_factor)
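
# Illustrative usage (hypothetical names, not part of this commit): a TTS data
# loader would typically derive both lengths from its sample, e.g.
#
#   token_len = len(token_ids)
#   mel_len = wav.shape[1] // hop_length
#   attn_prior = compute_attn_prior(token_len, mel_len)  # [mel_len, token_len]
#
# and could cache the result on disk with np.save / np.load keyed by wav path.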

# Benchmark function
def benchmark(method, phoneme_count, mel_count, scaling_factor=1.0, runs=5):
    """Time compute_attn_prior over `runs` runs; return (mean seconds, last result)."""
    times = []
    for _ in range(runs):
        # perf_counter is monotonic and high-resolution, better suited to timing than time.time
        start = time.perf_counter()
        result = compute_attn_prior(phoneme_count, mel_count, scaling_factor, method=method)
        end = time.perf_counter()
        times.append(end - start)
    return np.mean(times), result
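
# Example (illustrative): mean_s, prior = benchmark('refined', 50, 100)
# gives the mean wall-clock time over 5 runs and a prior of shape (100, 50).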

# Function to compare outputs
def compare_outputs(original_output, test_output, tolerance=1e-10):
    # np.allclose checks |a - b| <= atol + rtol * |b| elementwise
    return np.allclose(original_output, test_output, rtol=tolerance, atol=tolerance)

# Main benchmark script
def run_benchmark():
    test_cases = [
        (10, 20),    # Small input
        (50, 100),   # Medium input
        (200, 400),  # Large input
        (500, 1000)  # Very large input
    ]

    methods = ['original', 'refined']  # Add 'cython' if available

print("Phonemes | Mels | Method | Time (s) | Speedup | Outputs Match")
print("---------|------|---------|----------|---------|---------------")

for P, M in test_cases:
original_time, original_output = benchmark('original', P, M)

for method in methods:
time_taken, output = benchmark(method, P, M)
speedup = original_time / time_taken if method != 'original' else 1.0
outputs_match = compare_outputs(original_output, output)

print(f"{P:8d} | {M:4d} | {method:7s} | {time_taken:.6f} | {speedup:.2f}x | {outputs_match}")

print("---------|------|---------|----------|---------|---------------")

print("\nVerifying output consistency across all methods...")
all_match = all(
compare_outputs(
compute_attn_prior(P, M, method='original'),
compute_attn_prior(P, M, method=method)
)
for P, M in test_cases
for method in methods
)

if os.path.exists(attn_prior_file):
# If cached prior exists, load and return it
return np.load(attn_prior_file)
if all_match:
print("All outputs match within the specified tolerance.")
else:
# Compute the prior, save it, and return
token_len = len(token_ids)
mel_len = wav.shape[1] // self.ap.hop_length
attn_prior = compute_attn_prior(token_len, mel_len)
np.save(attn_prior_file, attn_prior)
return attn_prior

if __name__ == "__main__":
print("WARNING: Not all outputs match. Please review the implementations for potential discrepancies.")

def plot():
    import matplotlib.pyplot as plt

    # Test parameters
    # ... (plotting code collapsed in the diff view; it defines scaling_factor
    # and attn_prior_scaled used below) ...

    print(f"\nWith scaling factor {scaling_factor}:")
    print(f"Min value: {attn_prior_scaled.min():.6f}")
    print(f"Max value: {attn_prior_scaled.max():.6f}")
    print(f"Mean value: {attn_prior_scaled.mean():.6f}")

if __name__ == "__main__":
    run_benchmark()
    plot()
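
# Running `python prior.py` (with scipy and matplotlib installed) prints the
# benchmark table for both implementations and then runs plot() for the
# matplotlib-based visualization of the prior.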
