diff --git a/examples/benchmark.rs b/examples/benchmark.rs index e00bb78..bc4844f 100644 --- a/examples/benchmark.rs +++ b/examples/benchmark.rs @@ -1,7 +1,8 @@ use std::env; use std::str::FromStr; -use phastft::fft_dif; +use phastft::fft; +use phastft::planner::Planner; use utilities::gen_random_signal; fn benchmark_fft(n: usize) { @@ -11,7 +12,8 @@ fn benchmark_fft(n: usize) { gen_random_signal(&mut reals, &mut imags); let now = std::time::Instant::now(); - fft_dif(&mut reals, &mut imags); + let mut planner = Planner::new(n); + fft(&mut reals, &mut imags, &mut planner); let elapsed = now.elapsed().as_micros(); println!("{elapsed}"); } diff --git a/examples/profile.rs b/examples/profile.rs index 4cfae4a..1f62a6b 100644 --- a/examples/profile.rs +++ b/examples/profile.rs @@ -1,13 +1,15 @@ use std::env; use std::str::FromStr; -use phastft::fft_dif; +use phastft::fft; +use phastft::planner::Planner; fn benchmark_fft(num_qubits: usize) { let n = 1 << num_qubits; - let mut reals: Vec<f64> = (1..=n).map(f64::from).collect(); - let mut imags: Vec<f64> = (1..=n).map(f64::from).collect(); - fft_dif(&mut reals, &mut imags); + let mut reals: Vec<f64> = (1..=n).map(|i| i as f64).collect(); + let mut imags: Vec<f64> = (1..=n).map(|i| i as f64).collect(); + let mut planner = Planner::new(n); + fft(&mut reals, &mut imags, &mut planner); } fn main() { diff --git a/src/lib.rs b/src/lib.rs index 3c1cdc9..441ea1d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,68 +3,62 @@ use crate::cobra::cobra_apply; use crate::kernels::{fft_chunk_2, fft_chunk_4, fft_chunk_n, fft_chunk_n_simd, Float}; use crate::options::Options; -use crate::twiddles::{filter_twiddles, generate_twiddles, generate_twiddles_simd}; +use crate::planner::Planner; +use crate::twiddles::filter_twiddles; mod cobra; mod kernels; pub mod options; +pub mod planner; mod twiddles; /// FFT -- Decimation in Frequency /// /// This is just the decimation-in-time algorithm, reversed. 
-/// The inputs are in normal order, and the outputs are bit reversed. +/// The inputs are in normal order, and the outputs are then bit reversed. /// /// # Panics /// /// Panics if `reals.len() != imags.len()` /// /// [1] https://inst.eecs.berkeley.edu/~ee123/sp15/Notes/Lecture08_FFT_and_SpectAnalysis.key.pdf -pub fn fft_dif(reals: &mut [Float], imags: &mut [Float]) { +pub fn fft(reals: &mut [Float], imags: &mut [Float], planner: &mut Planner) { let opts = Options::guess_options(reals.len()); - fft_dif_with_opts(reals, imags, &opts) + fft_with_opts(reals, imags, &opts, planner); } -/// Same as [fft_dif], but also accepts [`Options`] that control optimization strategies. +/// Same as [`fft`], but also accepts [`Options`] that control optimization strategies. /// -/// `fft_dif` automatically guesses the best strategy for a given input, +/// [`fft`] automatically guesses the best strategy for a given input, /// so you only need to call this if you are tuning performance for a specific hardware platform. -pub fn fft_dif_with_opts(reals: &mut [Float], imags: &mut [Float], opts: &Options) { +/// +/// # Panics +/// +/// Panics if `reals.len() != imags.len()` +pub fn fft_with_opts( + reals: &mut [Float], + imags: &mut [Float], + opts: &Options, + planner: &mut Planner, +) { assert_eq!(reals.len(), imags.len()); let n: usize = reals.len().ilog2() as usize; - let dist = 1 << (n - 1); - let chunk_size = dist << 1; - let (mut twiddles_re, mut twiddles_im) = if dist >= 8 * 2 { - generate_twiddles_simd(dist) - } else { - generate_twiddles(dist) - }; + let twiddles_re = &mut planner.twiddles_re; + let twiddles_im = &mut planner.twiddles_im; - assert_eq!(twiddles_re.len(), twiddles_im.len()); - - if chunk_size > 4 { - if chunk_size >= 16 { - fft_chunk_n_simd(reals, imags, &twiddles_re, &twiddles_im, dist); - } else { - fft_chunk_n(reals, imags, &twiddles_re, &twiddles_im, dist); - } - } else if chunk_size == 2 { - fft_chunk_2(reals, imags); - } else if chunk_size == 4 { - fft_chunk_4(reals, imags); - } - - for t in (0..n - 1).rev() { + for t in 
(0..n).rev() { let dist = 1 << t; let chunk_size = dist << 1; if chunk_size > 4 { - filter_twiddles(&mut twiddles_re, &mut twiddles_im); + if t < n - 1 { + filter_twiddles(twiddles_re, twiddles_im); + } if chunk_size >= 16 { - fft_chunk_n_simd(reals, imags, &twiddles_re, &twiddles_im, dist); + fft_chunk_n_simd(reals, imags, twiddles_re, twiddles_im, dist); } else { - fft_chunk_n(reals, imags, &twiddles_re, &twiddles_im, dist); + fft_chunk_n(reals, imags, twiddles_re, twiddles_im, dist); } } else if chunk_size == 2 { fft_chunk_2(reals, imags); @@ -96,18 +90,19 @@ mod tests { use super::*; #[test] - fn fft() { + fn fft_correctness() { let range = Range { start: 4, end: 17 }; for k in range { - let n = 1 << k; + let n: usize = 1 << k; - let mut reals: Vec<f64> = (1..=n).map(f64::from).collect(); - let mut imags: Vec<f64> = (1..=n).map(f64::from).collect(); - fft_dif(&mut reals, &mut imags); + let mut reals: Vec<f64> = (1..=n).map(|i| i as f64).collect(); + let mut imags: Vec<f64> = (1..=n).map(|i| i as f64).collect(); + let mut planner = Planner::new(n); + fft(&mut reals, &mut imags, &mut planner); let mut buffer: Vec<Complex64> = (1..=n) - .map(|i| Complex64::new(f64::from(i), f64::from(i))) + .map(|i| Complex64::new(i as f64, i as f64)) .collect(); let mut planner = FftPlanner::new();