diff --git a/benches/benchmark_plots.py b/benches/benchmark_plots.py index 68fcb8d..8d416b4 100644 --- a/benches/benchmark_plots.py +++ b/benches/benchmark_plots.py @@ -1,8 +1,6 @@ """ Plot benchmark results for FFTW3, RustFFT, and PhastFT """ -from collections import defaultdict - import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -20,13 +18,13 @@ } -def read_file(filepath: str) -> list[int]: +def read_file(filepath: str) -> list[float]: y = [] with open(filepath) as f: for line in f: line = line.strip() - y.append(int(line)) + y.append(float(line)) return y @@ -35,31 +33,32 @@ def get_figure_of_interest(vals: list[int]) -> float: return np.median(vals) -def build_and_clean_data(root_benchmark_dir: str, *names) -> defaultdict[str, list]: - libs = ("rustfft", "phastft") - n_range = range(12, 30) - - data = defaultdict(list) +def build_and_clean_data(root_benchmark_dir: str, n_range: range, lib_name: str) -> list[float]: + data = [] - for lib in libs: - for n in n_range: - y = read_file(f"{root_benchmark_dir}/{lib}/size_{n}") - y_k = get_figure_of_interest(y) - data[lib].append(y_k) + for n in n_range: + y = read_file(f"{root_benchmark_dir}/{lib_name}/size_{n}") + y_k = get_figure_of_interest(y) + data.append(y_k) return data -def plot_lines(data: defaultdict[str, list]) -> None: - index = [bytes2human(2**n * (128 / 8)) for n in range(12, 30)] + +def plot_lines(data: dict[str, list], n_range: range) -> None: + index = [bytes2human(2**n * (128 / 8)) for n in n_range] plt.figure() - y0 = np.asarray(data["phastft"]) - y1 = np.asarray(data["rustfft"]) - y0 = y1/y0 + y0 = np.asarray(data["fftw3"]) + y1 = np.asarray(data["phastft"]) + y2 = np.asarray(data["rustfft"]) + + y0 /= y2 + y1 /= y2 df = pd.DataFrame( { - "PhastFT": y0, + "FFTW3": y0, + "PhastFT": y1, "RustFFT": np.ones(len(index)), }, index=index, @@ -70,10 +69,11 @@ def plot_lines(data: defaultdict[str, list]) -> None: plt.xticks(fontsize=9, rotation=-45) plt.yticks(fontsize=9) plt.xlabel("size of input") - plt.ylabel("speedup (relative to RustFFT)") + plt.ylabel("time taken (relative to RustFFT)") plt.legend(loc='best') plt.tight_layout() plt.savefig("benchmarks_bar_plot.png", dpi=600) + # plt.show() # Source: https://stackoverflow.com/a/1094933 @@ -97,9 +97,19 @@ def bytes2human(n, format='%(value).1f %(symbol)s', symbols='customary'): return format % dict(symbol=symbols[0], value=n) +def main(): + lib_names = ("rustfft", "phastft", "fftw3") + n_range = range(12, 30) + + all_data = {} + + for lib in lib_names: + data = build_and_clean_data("benchmark-data.2024.02.02.16-45-50", n_range, lib) + all_data[lib] = data + + assert len(all_data["rustfft"]) == len(all_data["fftw3"]) == len(all_data["phastft"]) + plot_lines(all_data, n_range) + if __name__ == "__main__": - # y = read_file("benchmark-data.2024.02.02.11-02-07/phastft/size_16") - data = build_and_clean_data("benchmark-data.2024.02.02.11-43-10") - # print(data) - plot_lines(data) \ No newline at end of file + main() \ No newline at end of file