From 850d85518271b5cf4fddb37b37c84df26726832e Mon Sep 17 00:00:00 2001 From: xiaozhah <1298856981@qq.com> Date: Wed, 17 Jul 2024 14:16:48 +0800 Subject: [PATCH] add refined beta_binomial_prior_distribution func and test it --- prior.py | 139 +++++++++++++++++++++++++++++++++---------------------- 1 file changed, 84 insertions(+), 55 deletions(-) diff --git a/prior.py b/prior.py index e1f1d4d..552be9a 100644 --- a/prior.py +++ b/prior.py @@ -1,21 +1,9 @@ -import os import numpy as np from scipy.stats import betabinom +import time -# from https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/helpers.py - -def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0): - """ - Calculate the Beta-Binomial prior distribution for alignment. - - Args: - phoneme_count (int): Number of phonemes. - mel_count (int): Number of mel spectrogram frames. - scaling_factor (float): Scaling factor for the distribution, default is 1.0. - - Returns: - np.array: 2D array of prior probabilities [mel_count, phoneme_count]. - """ +# Original implementation +def original_beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0): P, M = phoneme_count, mel_count x = np.arange(0, P) mel_text_probs = [] @@ -26,51 +14,88 @@ def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1. mel_text_probs.append(mel_i_prob) return np.array(mel_text_probs) -def compute_attn_prior(x_len, y_len, scaling_factor=1.0): - """ - Compute attention priors for the alignment network. +# Refined NumPy implementation +def refined_beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0): + P, M = phoneme_count, mel_count + m = np.arange(1, M + 1)[:, np.newaxis] + x = np.arange(P) + a = scaling_factor * m + b = scaling_factor * (M + 1 - m) + probs = betabinom.pmf(x, P, a, b) + return probs + +# Cython implementation (assuming it's compiled and imported) +# from beta_binomial_cython import cython_beta_binomial_prior_distribution + +# Unified interface +def compute_attn_prior(phoneme_count, mel_count, scaling_factor=1.0, method='refined'): + methods = { + 'original': original_beta_binomial_prior_distribution, + 'refined': refined_beta_binomial_prior_distribution, + # 'cython': cython_beta_binomial_prior_distribution # Uncomment if Cython version is available + } - Args: - x_len (int): Length of input sequence (e.g., number of phonemes). - y_len (int): Length of output sequence (e.g., number of mel frames). - scaling_factor (float): Scaling factor for the distribution. + if method not in methods: + raise ValueError(f"Unknown method: {method}. Available methods are: {', '.join(methods.keys())}") - Returns: - np.array: Attention prior matrix [y_len, x_len]. - """ - attn_prior = beta_binomial_prior_distribution( - x_len, - y_len, - scaling_factor, - ) - return attn_prior # [y_len, x_len] + return methods[method](phoneme_count, mel_count, scaling_factor) -def load_or_compute_attn_prior(self, token_ids, wav, rel_wav_path): - """ - Load or compute and save the attention prior. - - Args: - token_ids (list): Input token IDs. - wav (np.array): Waveform data. - rel_wav_path (str): Relative path to the wav file. +# Benchmark function +def benchmark(method, phoneme_count, mel_count, scaling_factor=1.0, runs=5): + times = [] + for _ in range(runs): + start = time.time() + result = compute_attn_prior(phoneme_count, mel_count, scaling_factor, method=method) + end = time.time() + times.append(end - start) + return np.mean(times), result + +# Function to compare outputs +def compare_outputs(original_output, test_output, tolerance=1e-10): + return np.allclose(original_output, test_output, rtol=tolerance, atol=tolerance) + +# Main benchmark script +def run_benchmark(): + test_cases = [ + (10, 20), # Small input + (50, 100), # Medium input + (200, 400), # Large input + (500, 1000) # Very large input + ] + + methods = ['original', 'refined'] # Add 'cython' if available - Returns: - np.array: Attention prior matrix. - """ - attn_prior_file = os.path.join(self.attn_prior_cache_path, f"{rel_wav_path}.npy") + print("Phonemes | Mels | Method | Time (s) | Speedup | Outputs Match") + print("---------|------|---------|----------|---------|---------------") + + for P, M in test_cases: + original_time, original_output = benchmark('original', P, M) + + for method in methods: + time_taken, output = benchmark(method, P, M) + speedup = original_time / time_taken if method != 'original' else 1.0 + outputs_match = compare_outputs(original_output, output) + + print(f"{P:8d} | {M:4d} | {method:7s} | {time_taken:.6f} | {speedup:.2f}x | {outputs_match}") + + print("---------|------|---------|----------|---------|---------------") + + print("\nVerifying output consistency across all methods...") + all_match = all( + compare_outputs( + compute_attn_prior(P, M, method='original'), + compute_attn_prior(P, M, method=method) + ) + for P, M in test_cases + for method in methods + ) - if os.path.exists(attn_prior_file): - # If cached prior exists, load and return it - return np.load(attn_prior_file) + if all_match: + print("All outputs match within the specified tolerance.") else: - # Compute the prior, save it, and return - token_len = len(token_ids) - mel_len = wav.shape[1] // self.ap.hop_length - attn_prior = compute_attn_prior(token_len, mel_len) - np.save(attn_prior_file, attn_prior) - return attn_prior - -if __name__ == "__main__": + print("WARNING: Not all outputs match. Please review the implementations for potential discrepancies.") + +def plot(): import matplotlib.pyplot as plt # Test parameters @@ -115,4 +140,8 @@ def load_or_compute_attn_prior(self, token_ids, wav, rel_wav_path): print(f"\nWith scaling factor {scaling_factor}:") print(f"Min value: {attn_prior_scaled.min():.6f}") print(f"Max value: {attn_prior_scaled.max():.6f}") - print(f"Mean value: {attn_prior_scaled.mean():.6f}") \ No newline at end of file + print(f"Mean value: {attn_prior_scaled.mean():.6f}") + +if __name__ == "__main__": + run_benchmark() + plot() \ No newline at end of file