Skip to content

Commit d283d65

Browse files
committed
Add tests and make API breaking changes
- Add a test to make sure inverse twiddle factors multiplied by forward twiddle factors always gives 1.0 - `fft` now takes the real/imaginary input and the `Direction` of the `fft` - Pre-built planners can be used with the `fft_with_opts_and_plan` function
1 parent c365d00 commit d283d65

File tree

6 files changed

+92
-35
lines changed

6 files changed

+92
-35
lines changed

examples/benchmark.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
use std::env;
22
use std::str::FromStr;
33

4-
use phastft::fft;
5-
use phastft::planner::{Direction, Planner};
64
use utilities::gen_random_signal;
75

6+
use phastft::fft;
7+
use phastft::planner::Direction;
8+
89
fn benchmark_fft(n: usize) {
910
let big_n = 1 << n;
1011
let mut reals = vec![0.0; big_n];
1112
let mut imags = vec![0.0; big_n];
1213
gen_random_signal(&mut reals, &mut imags);
1314

1415
let now = std::time::Instant::now();
15-
let mut planner = Planner::new(big_n, Direction::Forward);
16-
fft(&mut reals, &mut imags, &mut planner);
16+
fft(&mut reals, &mut imags, Direction::Forward);
1717
let elapsed = now.elapsed().as_micros();
1818
println!("{elapsed}");
1919
}

examples/profile.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@ use std::env;
22
use std::str::FromStr;
33

44
use phastft::fft;
5-
use phastft::planner::Planner;
5+
use phastft::planner::Direction;
66

77
fn benchmark_fft(num_qubits: usize) {
88
let n = 1 << num_qubits;
99
let mut reals: Vec<f64> = (1..=n).map(|i| i as f64).collect();
1010
let mut imags: Vec<f64> = (1..=n).map(|i| i as f64).collect();
11-
let mut planner = Planner::new(n);
12-
fft(&mut reals, &mut imags, &mut planner);
11+
fft(&mut reals, &mut imags, Direction::Forward);
1312
}
1413

1514
fn main() {

pybindings/src/lib.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
use numpy::PyReadwriteArray1;
2-
use phastft::{fft as fft_rs, planner::Planner};
2+
use phastft::{fft as fft_rs, planner::Direction};
33
use pyo3::prelude::*;
44

55
#[pyfunction]
6-
fn fft(mut reals: PyReadwriteArray1<f64>, mut imags: PyReadwriteArray1<f64>) {
7-
let mut planner = Planner::new(reals.len());
6+
fn fft(mut reals: PyReadwriteArray1<f64>, mut imags: PyReadwriteArray1<f64>, direction: char) {
7+
assert!(direction == 'f' || direction == 'r');
8+
let dir = if direction == 'f' {
9+
Direction::Forward
10+
} else {
11+
Direction::Reverse
12+
};
13+
814
fft_rs(
915
reals.as_slice_mut().unwrap(),
1016
imags.as_slice_mut().unwrap(),
11-
&mut planner,
17+
dir,
1218
);
1319
}
1420

src/lib.rs

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use crate::cobra::cobra_apply;
44
use crate::kernels::{fft_chunk_2, fft_chunk_4, fft_chunk_n, fft_chunk_n_simd, Float};
55
use crate::options::Options;
6-
use crate::planner::Planner;
6+
use crate::planner::{Direction, Planner};
77
use crate::twiddles::filter_twiddles;
88

99
mod cobra;
@@ -22,20 +22,36 @@ mod twiddles;
2222
/// Panics if `reals.len() != imags.len()`
2323
///
2424
/// [1] https://inst.eecs.berkeley.edu/~ee123/sp15/Notes/Lecture08_FFT_and_SpectAnalysis.key.pdf
25-
pub fn fft(reals: &mut [Float], imags: &mut [Float], planner: &mut Planner) {
25+
pub fn fft(reals: &mut [Float], imags: &mut [Float], direction: Direction) {
26+
assert_eq!(
27+
reals.len(),
28+
imags.len(),
29+
"real and imaginary inputs must be of equal size, but got: {} {}",
30+
reals.len(),
31+
imags.len()
32+
);
33+
34+
let mut planner = Planner::new(reals.len(), direction);
35+
assert!(planner.num_twiddles().is_power_of_two() && planner.num_twiddles() == reals.len() / 2);
36+
2637
let opts = Options::guess_options(reals.len());
27-
fft_with_opts(reals, imags, &opts, planner);
38+
fft_with_opts_and_plan(reals, imags, &opts, &mut planner);
2839
}
2940

30-
/// Same as [fft], but also accepts [`Options`] that control optimization strategies.
41+
/// Same as [fft], but also accepts [`Options`] that control optimization strategies, as well as
42+
/// a [`Planner`] in the case that this FFT will need to be run multiple times.
3143
///
32-
/// `fft_dif` automatically guesses the best strategy for a given input,
44+
/// `fft` automatically guesses the best strategy for a given input,
3345
/// so you only need to call this if you are tuning performance for a specific hardware platform.
3446
///
47+
/// In addition, `fft` automatically creates a planner to be used. In the case that you plan
48+
/// on running an FFT many times on inputs of the same size, use this function with the pre-built
49+
/// [`Planner`].
50+
///
3551
/// # Panics
3652
///
37-
/// Panics if `reals.len() != imags.len()`
38-
pub fn fft_with_opts(
53+
/// Panics if `reals.len() != imags.len()`, or if the input length is *not* a power of two.
54+
pub fn fft_with_opts_and_plan(
3955
reals: &mut [Float],
4056
imags: &mut [Float],
4157
opts: &Options,
@@ -87,6 +103,8 @@ mod tests {
87103
rustfft::{num_complex::Complex64, FftPlanner},
88104
};
89105

106+
use crate::planner::Direction;
107+
90108
use super::*;
91109

92110
#[should_panic]
@@ -95,11 +113,14 @@ mod tests {
95113
let num_points = 5;
96114

97115
// this test will actually always fail at this stage
98-
let mut planner = Planner::new(num_points);
116+
let mut planner = Planner::new(num_points, Direction::Forward);
99117

100118
let mut reals = vec![0.0; num_points];
101119
let mut imags = vec![0.0; num_points];
102-
fft(&mut reals, &mut imags, &mut planner);
120+
let opts = Options::guess_options(reals.len());
121+
122+
// but this call should panic as well
123+
fft_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner);
103124
}
104125

105126
#[test]
@@ -111,8 +132,7 @@ mod tests {
111132

112133
let mut reals: Vec<Float> = (1..=n).map(|i| i as f64).collect();
113134
let mut imags: Vec<Float> = (1..=n).map(|i| i as f64).collect();
114-
let mut planner = Planner::new(n);
115-
fft(&mut reals, &mut imags, &mut planner);
135+
fft(&mut reals, &mut imags, Direction::Forward);
116136

117137
let mut buffer: Vec<Complex64> = (1..=n)
118138
.map(|i| Complex64::new(i as f64, i as f64))

src/planner.rs

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,17 @@ pub struct Planner {
1010
pub twiddles_im: Vec<f64>,
1111
}
1212

13-
// TODO(saveliy yusufov): Add a parameter to `new` that will take into consideration whether we do inverse FFT (IFFT)
14-
// In this case, the twiddle factors should be pre-computed as follows:
15-
//
16-
// FFT Twiddle Factor: e^{i2π*k/N}
17-
// IFFT Twiddle Factor: e^{-i2π*k/N}
18-
//
19-
// source: https://dsp.stackexchange.com/q/73367
2013
impl Planner {
21-
/// Create a `Planner` for an FFT of size `num_points`
14+
/// Create a `Planner` for an FFT of size `num_points`.
15+
/// The twiddle factors are pre-computed based on the provided [`Direction`].
16+
/// For `Forward`, use [`Direction::Forward`].
17+
/// For `Reverse`, use [`Direction::Reverse`].
2218
///
2319
/// # Panics
2420
///
2521
/// Panics if `num_points` is less than 1
2622
pub fn new(num_points: usize, direction: Direction) -> Self {
27-
assert!(num_points > 0);
23+
assert!(num_points > 0 && num_points.is_power_of_two());
2824
if num_points <= 4 {
2925
return Self {
3026
twiddles_re: vec![],
@@ -46,17 +42,53 @@ impl Planner {
4642
twiddles_im,
4743
}
4844
}
45+
46+
pub(crate) fn num_twiddles(&self) -> usize {
47+
assert_eq!(self.twiddles_re.len(), self.twiddles_im.len());
48+
self.twiddles_re.len()
49+
}
4950
}
5051

5152
#[cfg(test)]
5253
mod tests {
54+
use utilities::assert_f64_closeness;
55+
5356
use crate::planner::{Direction, Planner};
5457

5558
#[test]
5659
fn no_twiddles() {
57-
for num_points in 2..=4 {
60+
for num_points in [2, 4] {
5861
let planner = Planner::new(num_points, Direction::Forward);
5962
assert!(planner.twiddles_im.is_empty() && planner.twiddles_re.is_empty());
6063
}
6164
}
65+
66+
#[test]
67+
fn forward_mul_inverse_eq_identity() {
68+
for i in 3..25 {
69+
let num_points = 1 << i;
70+
let planner_forward = Planner::new(num_points, Direction::Forward);
71+
let planner_reverse = Planner::new(num_points, Direction::Reverse);
72+
73+
assert_eq!(
74+
planner_reverse.num_twiddles(),
75+
planner_forward.num_twiddles()
76+
);
77+
78+
// (a + ib) (c + id) = ac + iad + ibc - bd
79+
// = ac - bd + i(ad + bc)
80+
planner_forward
81+
.twiddles_re
82+
.iter()
83+
.zip(planner_forward.twiddles_im.iter())
84+
.zip(planner_reverse.twiddles_re.iter())
85+
.zip(planner_reverse.twiddles_im)
86+
.for_each(|(((a, b), c), d)| {
87+
let temp_re = a * c - b * d;
88+
let temp_im = a * d + b * c;
89+
assert_f64_closeness(temp_re, 1.0, 1e-6);
90+
assert_f64_closeness(temp_im, 0.0, 1e-6);
91+
});
92+
}
93+
}
6294
}

src/twiddles.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ mod tests {
219219
for n in 4..28 {
220220
let dist = 1 << n;
221221

222-
let (twiddles_re_ref, twiddles_im_ref) = generate_twiddles(dist);
223-
let (twiddles_re, twiddles_im) = generate_twiddles_simd(dist);
222+
let (twiddles_re_ref, twiddles_im_ref) = generate_twiddles(dist, Direction::Forward);
223+
let (twiddles_re, twiddles_im) = generate_twiddles_simd(dist, Direction::Forward);
224224

225225
twiddles_re
226226
.iter()
@@ -245,7 +245,7 @@ mod tests {
245245
let dist = 1 << (n - 1);
246246
let mut twiddles_iter = Twiddles::new(dist);
247247

248-
let (mut twiddles_re, mut twiddles_im) = generate_twiddles(dist);
248+
let (mut twiddles_re, mut twiddles_im) = generate_twiddles(dist, Direction::Forward);
249249

250250
for i in 0..dist {
251251
let (tw_re, tw_im) = twiddles_iter.next().unwrap();

0 commit comments

Comments
 (0)