Skip to content

Commit

Permalink
Clean up plots and split bar plot for py benches
Browse files Browse the repository at this point in the history
  • Loading branch information
smu160 committed Feb 10, 2024
1 parent 648d4d2 commit 5203d4e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 47 deletions.
40 changes: 22 additions & 18 deletions benches/benchmark_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def build_and_clean_data(

def plot(data: dict[str, list], n_range: range) -> None:
index = [bytes2human(2**n * (128 / 8)) for n in n_range]
plt.figure()

y0 = np.asarray(data["fftw3"])
y1 = np.asarray(data["phastft"])
Expand All @@ -59,34 +58,39 @@ def plot(data: dict[str, list], n_range: range) -> None:
index=index,
)

df.plot(kind="bar", linewidth=3, rot=0)
plt.xticks(fontsize=9, rotation=-45)
plt.yticks(fontsize=9)
title = "PhastFT vs. FFTW3 vs. RustFFT"
df.plot(kind="bar", linewidth=2, rot=0, title=title)
plt.xticks(fontsize=8, rotation=-45)
plt.xlabel("size of input")
plt.ylabel("time taken (relative to RustFFT)")
plt.ylabel("Execution Time Ratio\n(relative to RustFFT)")
plt.legend(loc="best")
plt.tight_layout()
plt.savefig("benchmarks_bar_plot.png", dpi=600)
plt.savefig(f"benchmarks_bar_plot_{n_range.start}_{n_range.stop -1}.png", dpi=600)
plt.show()


def main():
"""Entry point... yay"""
lib_names = ("rustfft", "phastft", "fftw3")
n_range = range(4, 30)
all_data = {}
ranges = (range(4, 13), range(13, 30))

for lib in lib_names:
root_folder = find_directory()
if root_folder is None:
raise FileNotFoundError("unable to find the benchmark data directory")
for n_range in ranges:
all_data = {}

data = build_and_clean_data(root_folder, n_range, lib)
all_data[lib] = data
for lib in lib_names:
root_folder = find_directory()
if root_folder is None:
raise FileNotFoundError("unable to find the benchmark data directory")

assert (
len(all_data["rustfft"]) == len(all_data["fftw3"]) == len(all_data["phastft"])
)
plot(all_data, n_range)
data = build_and_clean_data(root_folder, n_range, lib)
all_data[lib] = data

assert (
len(all_data["rustfft"])
== len(all_data["fftw3"])
== len(all_data["phastft"])
)
plot(all_data, n_range)


if __name__ == "__main__":
Expand Down
53 changes: 25 additions & 28 deletions benches/py_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import numpy as np
import pandas as pd
import pyfftw

pyfftw.interfaces.cache.enable()

from pybindings import fft

from utils import bytes2human
Expand All @@ -13,36 +16,30 @@


def gen_random_signal(dim: int) -> np.ndarray:
return np.asarray(
"""Generate a random, complex 1D signal"""
return np.ascontiguousarray(
np.random.randn(dim) + 1j * np.random.randn(dim),
dtype="complex128",
)


def main() -> None:
with open("elapsed_times.csv", "w", newline="") as csvfile:
with open("elapsed_times.csv", "w", newline="", encoding="utf-8") as csvfile:
fieldnames = ["n", "phastft_time", "numpy_fft_time", "pyfftw_fft_time"]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

writer.writeheader()

for n in range(12, 29):
for n in range(4, 29):
print(f"n = {n}")
big_n = 1 << n
s = gen_random_signal(big_n)

a_re = [None] * len(s)
a_im = [None] * len(s)

for i, val in enumerate(s):
a_re[i] = val.real
a_im[i] = val.imag

a_re = np.asarray(a_re, dtype=np.float64)
a_im = np.asarray(a_im, dtype=np.float64)
a_re = np.ascontiguousarray(s.real, dtype=np.float64)
a_im = np.ascontiguousarray(s.imag, dtype=np.float64)

start = time.time()
fft(a_re, a_im)
fft(a_re, a_im, "f")
phastft_elapsed = round((time.time() - start) * 10**6)
print(f"PhastFT completed in {phastft_elapsed} us")

Expand All @@ -68,11 +65,11 @@ def main() -> None:
a = pyfftw.empty_aligned(big_n, dtype="complex128")
a[:] = arr
start = time.time()
b = pyfftw.interfaces.numpy_fft.fft(a)
a = pyfftw.interfaces.numpy_fft.fft(a)
pyfftw_elapsed = round((time.time() - start) * 10**6)
print(f"pyFFTW completed in {pyfftw_elapsed} us")

np.testing.assert_allclose(b, actual, rtol=1e-3, atol=0)
np.testing.assert_allclose(a, actual, rtol=1e-3, atol=0)

writer.writerow(
{
Expand All @@ -85,18 +82,19 @@ def main() -> None:

file_path = "elapsed_times.csv"
loaded_data = read_csv_to_dict(file_path)
plot_elapsed_times(loaded_data)
grouped_bar_plot(loaded_data)
grouped_bar_plot(loaded_data, start=0, end=9)
grouped_bar_plot(loaded_data, start=9, end=29)


def read_csv_to_dict(file_path: str) -> dict:
"""Read the benchmark results from the csv file and convert it to a dict"""
data: dict[str, list] = {
"n": [],
"phastft_time": [],
"numpy_fft_time": [],
"pyfftw_fft_time": [],
}
with open(file_path, newline="") as csvfile:
with open(file_path, newline="", encoding="utf-8") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
data["n"].append(int(row["n"]))
Expand All @@ -113,6 +111,7 @@ def read_csv_to_dict(file_path: str) -> dict:


def plot_elapsed_times(data: dict) -> None:
"""Plot the timings for all libs using line plots"""
index = [bytes2human(2**n * (128 / 8)) for n in data["n"]]
np_fft_timings = np.asarray(data["numpy_fft_time"])
pyfftw_timings = np.asarray(data["pyfftw_fft_time"])
Expand All @@ -122,9 +121,8 @@ def plot_elapsed_times(data: dict) -> None:
plt.plot(index, pyfftw_timings, label="PyFFTW FFT", lw=0.8)
plt.plot(index, phastft_timings, label="PhastFT", lw=0.8)

plt.title("FFT Elapsed Times Comparison")
plt.xticks(fontsize=9, rotation=-45)
plt.yticks(fontsize=9)
plt.title("PhastFT vs. pyFFTW vs. NumPy FFT")
plt.xticks(fontsize=8, rotation=-45)
plt.xlabel("size of input")
plt.ylabel("time (us)")
plt.yscale("log")
Expand All @@ -133,14 +131,14 @@ def plot_elapsed_times(data: dict) -> None:
plt.savefig("py_benchmarks.png", dpi=600)


def grouped_bar_plot(data: dict):
def grouped_bar_plot(data: dict, start=0, end=1):
"""Plot the timings for all libs using a grouped bar chart"""
index = data["n"]
index = [bytes2human(2**n * (128 / 8)) for n in index]
np_fft_timings = np.asarray(data["numpy_fft_time"])
pyfftw_timings = np.asarray(data["pyfftw_fft_time"]) # / np_fft_timings
phastft_timings = np.asarray(data["phastft_time"]) # / np_fft_timings

plt.figure()
df = pd.DataFrame(
{
"NumPy fft": np.ones(len(index)),
Expand All @@ -150,12 +148,11 @@ def grouped_bar_plot(data: dict):
index=index,
)

_ax = df.plot(kind="bar", linewidth=3, rot=0)
plt.title("FFT Elapsed Times Comparison")
plt.xticks(fontsize=9, rotation=-45)
plt.yticks(fontsize=9)
title = "PhastFT vs. pyFFTW vs. NumPy FFT"
df[start:end].plot(kind="bar", linewidth=2, rot=0, title=title)
plt.xticks(fontsize=8, rotation=-45)
plt.xlabel("size of input")
plt.ylabel("time taken (relative to NumPy FFT)")
plt.ylabel("Execution Time Ratio\n(relative to NumPy FFT)")
plt.legend(loc="best")
plt.tight_layout()
plt.savefig("py_benchmarks_bar_plot.png", dpi=600)
Expand Down
1 change: 0 additions & 1 deletion benches/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import re
from datetime import datetime


SYMBOLS = {
"customary": ("B", "K", "M", "G", "T", "P", "E", "Z", "Y"),
"customary_ext": (
Expand Down

0 comments on commit 5203d4e

Please sign in to comment.