diff --git a/scripts/py_benchmarks.py b/scripts/py_benchmarks.py index b8bddfd..0217e74 100644 --- a/scripts/py_benchmarks.py +++ b/scripts/py_benchmarks.py @@ -2,6 +2,7 @@ import pyfftw import time import csv +from scipy.stats import unitary_group import matplotlib.pyplot as plt from pybindings import fft @@ -17,14 +18,16 @@ def main() -> None: for n in range(4, 29): print(f"n = {n}") big_n = 1 << n - a_re = np.asarray([float(i) for i in range(big_n)]) - a_im = np.asarray([float(i) for i in range(big_n)]) + x = unitary_group.rvs(big_n) + + a_re = x[:, 0].copy().real # np.asarray([float(i) for i in range(big_n)]) + a_im = x[:, 0].copy().imag # np.asarray([float(i) for i in range(big_n)]) start = time.time() fft(a_re, a_im) phastft_elapsed = round((time.time() - start) * 10**6) - a = [complex(i, i) for i in range(big_n)] + a = x[:, 0].copy() start = time.time() expected = np.fft.fft(a) @@ -41,7 +44,7 @@ def main() -> None: ) np.testing.assert_allclose(actual, expected) - arr = np.asarray([complex(i, i) for i in range(big_n)]) + arr = x[:, 0].copy() a = pyfftw.empty_aligned(big_n, dtype="complex128") a[:] = arr start = time.time() @@ -94,7 +97,6 @@ def plot_elapsed_times(data: dict) -> None: plt.ylabel("Elapsed Time (microseconds)") plt.yscale("log") plt.legend() - plt.savefig("py_benchmarks.png", dpi=600)