Skip to content

Commit

Permalink
Remove rayon dep, add bar plot, rename to PhastFT
Browse files Browse the repository at this point in the history
  • Loading branch information
smu160 committed Jan 31, 2024
1 parent f07c3b7 commit daee2bb
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 19 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ keywords = ["quantum", "fft", "discrete", "fourier", "transform"]
categories = ["algorithms", "compression", "science"]

[dependencies]
rayon = "1.8.1"
rustfft = "6.1.0"

[profile.release]
Expand Down
23 changes: 12 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# PHFT
# PhastFT

**PH**ast**FT** (PHFT) is a high-performance, "quantum-inspired" Fast Fourier Transform (FFT) library written in pure
and safe Rust. It is the fastest pure-Rust FFT library according to our benchmarks.
PhastFT is a high-performance, "quantum-inspired" Fast Fourier
Transform (FFT) library written in pure and safe Rust. It is the fastest
pure-Rust FFT library according to our benchmarks.

## Features

Expand All @@ -18,7 +19,7 @@ and safe Rust. It is the fastest pure-Rust FFT library according to our benchmar

## How is it so fast?

PHFT is designed around the capabilities and limitations of modern hardware (that is, anything made in the last 10 years or so).
PhastFT is designed around the capabilities and limitations of modern hardware (that is, anything made in the last 10 years or so).

The two major bottlenecks in FFT are the **CPU cycles** and **memory accesses.**

Expand All @@ -35,11 +36,11 @@ including [RustFFT](https://crates.io/crates/rustfft/), on both large and small

## Getting Started

To integrate PHFT into your Rust project:
To integrate PhastFT into your Rust project:

...

To use PHFT with Python:
To use PhastFT with Python:

```bash
pip install ...
Expand All @@ -51,7 +52,7 @@ pip install ...

### Rust

PHFT provides a straightforward API for performing FFT computations. Here's an example of using PHFT for a basic FFT
PhastFT provides a straightforward API for performing FFT computations. Here's an example of using PhastFT for a basic FFT
operation:

...
Expand All @@ -62,12 +63,12 @@ operation:

## Benchmarks

PHFT is benchmarked against other FFT libraries. Detailed benchmarking results and instructions are available in the
PhastFT is benchmarked against other FFT libraries. Detailed benchmarking results and instructions are available in the
benchmarks directory.

## Contributing

Contributions to PHFT are welcome! If you find any issues or have improvements to suggest, please open an issue or
Contributions to PhastFT are welcome! If you find any issues or have improvements to suggest, please open an issue or
submit a pull request. Follow the contribution guidelines outlined in the CONTRIBUTING.md file.

## License
Expand Down Expand Up @@ -103,7 +104,7 @@ Finally, run:

## What's with the name?

The name, **PHFT**, is derived from the implementation of the
The name, **PhastFT**, is derived from the implementation of the
[Quantum Fourier Transform](https://en.wikipedia.org/wiki/Quantum_Fourier_transform) (QFT). Namely, the
[quantum circuit implementation of QFT](https://en.wikipedia.org/wiki/Quantum_Fourier_transform#Circuit_implementation)
consists of the **P**hase gates and **H**adamard gates. Hence, **PH**ast**FT**.
consists of the **P**hase gates and **H**adamard gates. Hence, **Ph**astFT.
66 changes: 59 additions & 7 deletions scripts/py_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import numpy as np
import pandas as pd
import pyfftw
import time
import csv
from scipy.stats import unitary_group
import matplotlib.pyplot as plt

from pybindings import fft

plt.style.use('seaborn-v0_8-poster')


def gen_random_signal(dim: int) -> np.ndarray:
x = np.asarray((1 / np.sqrt(2)) * (np.random.randn(dim) + 1j * np.random.randn(dim)), dtype="complex128")
return np.ascontiguousarray(x)
return np.asarray(
np.random.randn(dim) + 1j * np.random.randn(dim),
dtype="complex128",
)


def main() -> None:
Expand All @@ -24,18 +28,28 @@ def main() -> None:
print(f"n = {n}")
big_n = 1 << n
s = gen_random_signal(big_n)
a_re = np.ascontiguousarray(s.real) # np.asarray([float(i) for i in range(big_n)])
a_im = np.ascontiguousarray(s.imag) # np.asarray([float(i) for i in range(big_n)])

a_re = [None] * len(s)
a_im = [None] * len(s)

for i, val in enumerate(s):
a_re[i] = val.real
a_im[i] = val.imag

a_re = np.asarray(a_re, dtype=np.float64)
a_im = np.asarray(a_im, dtype=np.float64)

start = time.time()
fft(a_re, a_im)
phastft_elapsed = round((time.time() - start) * 10**6)
print(f"PhastFT completed in {phastft_elapsed} us")

a = s.copy()

start = time.time()
expected = np.fft.fft(a)
numpy_elapsed = round((time.time() - start) * 10**6)
print(f"NumPy fft completed in {numpy_elapsed} us")

actual = np.asarray(
[
Expand All @@ -54,6 +68,7 @@ def main() -> None:
start = time.time()
b = pyfftw.interfaces.numpy_fft.fft(a)
pyfftw_elapsed = round((time.time() - start) * 10**6)
print(f"pyFFTW completed in {pyfftw_elapsed} us")

np.testing.assert_allclose(b, actual, rtol=1e-3, atol=0)

Expand All @@ -68,7 +83,6 @@ def main() -> None:

file_path = "elapsed_times.csv"
loaded_data = read_csv_to_dict(file_path)
print(loaded_data)
plot_elapsed_times(loaded_data)


Expand Down Expand Up @@ -104,5 +118,43 @@ def plot_elapsed_times(data: dict) -> None:
plt.savefig("py_benchmarks.png", dpi=600)


def grouped_bar_plot(data: dict):
index = data["n"]
np_fft_timings = data["numpy_fft_time"]
pyfftw_timings = data["pyfftw_fft_time"] # / np_fft_timings
phastft_timings = data["phastft_time"] # / np_fft_timings

ratio_np_to_pyfftw = []
ratio_np_to_phastft = []
ratio_np_to_np = []
for (s1, s2, s3) in zip(np_fft_timings, pyfftw_timings, phastft_timings):
if s1 == 0 or s2 == 0 or s3 == 0:
continue

ratio_np_to_pyfftw.append(s2 / s1)
ratio_np_to_phastft.append(s3 / s1)
ratio_np_to_np.append(s1 / s1)

plt.figure()
df = pd.DataFrame(
{
"NumPy fft": ratio_np_to_np,
"pyFFTW": ratio_np_to_pyfftw,
"PhastFT": ratio_np_to_phastft,
},
index=index,
)

ax = df.plot(kind='bar', linewidth=3, rot=0)
plt.xticks(fontsize=8)
plt.xlabel("# of bits in index")
plt.ylabel("Time (relative to NumPy FFT)")
# plt.tight_layout(pad=0.0)
plt.savefig("py_benchmarks_bar_plot.png", dpi=600)


if __name__ == "__main__":
main()
# main()
file_path = "elapsed_times.csv"
loaded_data = read_csv_to_dict(file_path)
grouped_bar_plot(loaded_data)

0 comments on commit daee2bb

Please sign in to comment.