Skip to content

Commit

Permalink
Merge branch 'main' of github.com:QuState/PhastFT
Browse files Browse the repository at this point in the history
  • Loading branch information
smu160 committed Feb 3, 2024
2 parents 880e05a + fdddc26 commit 1a37997
Showing 1 changed file with 35 additions and 25 deletions.
60 changes: 35 additions & 25 deletions benches/benchmark_plots.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""
Plot benchmark results for FFTW3, RustFFT, and PhastFT
"""
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand All @@ -20,13 +18,13 @@
}


def read_file(filepath: str) -> list[int]:
def read_file(filepath: str) -> list[float]:
y = []

with open(filepath) as f:
for line in f:
line = line.strip()
y.append(int(line))
y.append(float(line))

return y

Expand All @@ -35,31 +33,32 @@ def get_figure_of_interest(vals: list[int]) -> float:
return np.median(vals)


def build_and_clean_data(root_benchmark_dir: str, *names) -> defaultdict[str, list]:
libs = ("rustfft", "phastft")
n_range = range(12, 30)

data = defaultdict(list)
def build_and_clean_data(root_benchmark_dir: str, n_range: range, lib_name: str) -> list[float]:
data = []

for lib in libs:
for n in n_range:
y = read_file(f"{root_benchmark_dir}/{lib}/size_{n}")
y_k = get_figure_of_interest(y)
data[lib].append(y_k)
for n in n_range:
y = read_file(f"{root_benchmark_dir}/{lib_name}/size_{n}")
y_k = get_figure_of_interest(y)
data.append(y_k)

return data

def plot_lines(data: defaultdict[str, list]) -> None:
index = [bytes2human(2**n * (128 / 8)) for n in range(12, 30)]

def plot_lines(data: dict[str, list], n_range: range) -> None:
index = [bytes2human(2**n * (128 / 8)) for n in n_range]
plt.figure()

y0 = np.asarray(data["phastft"])
y1 = np.asarray(data["rustfft"])
y0 = y1/y0
y0 = np.asarray(data["fftw3"])
y1 = np.asarray(data["phastft"])
y2 = np.asarray(data["rustfft"])

y0 /= y2
y1 /= y2

df = pd.DataFrame(
{
"PhastFT": y0,
"FFTW3": y0,
"PhastFT": y1,
"RustFFT": np.ones(len(index)),
},
index=index,
Expand All @@ -70,10 +69,11 @@ def plot_lines(data: defaultdict[str, list]) -> None:
plt.xticks(fontsize=9, rotation=-45)
plt.yticks(fontsize=9)
plt.xlabel("size of input")
plt.ylabel("speedup (relative to RustFFT)")
plt.ylabel("time taken (relative to RustFFT)")
plt.legend(loc='best')
plt.tight_layout()
plt.savefig("benchmarks_bar_plot.png", dpi=600)
# plt.show()


# Source: https://stackoverflow.com/a/1094933
Expand All @@ -97,9 +97,19 @@ def bytes2human(n, format='%(value).1f %(symbol)s', symbols='customary'):
return format % dict(symbol=symbols[0], value=n)


def main():
lib_names = ("rustfft", "phastft", "fftw3")
n_range = range(12, 30)

all_data = {}

for lib in lib_names:
data = build_and_clean_data("benchmark-data.2024.02.02.16-45-50", n_range, lib)
all_data[lib] = data

assert len(all_data["rustfft"]) == len(all_data["fftw3"]) == len(all_data["phastft"])
plot_lines(all_data, n_range)


if __name__ == "__main__":
# y = read_file("benchmark-data.2024.02.02.11-02-07/phastft/size_16")
data = build_and_clean_data("benchmark-data.2024.02.02.11-43-10")
# print(data)
plot_lines(data)
main()

0 comments on commit 1a37997

Please sign in to comment.