diff --git a/.cargo/config b/.cargo/config.toml similarity index 100% rename from .cargo/config rename to .cargo/config.toml diff --git a/Cargo.toml b/Cargo.toml index 410e232..dd141f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,8 @@ keywords = ["quantum", "fft", "discrete", "fourier", "transform"] categories = ["algorithms", "compression", "science"] exclude = ["assets", "scripts", "benches"] +[dependencies] +num-traits = "0.2.18" [dev-dependencies] utilities = { path = "utilities" } @@ -23,3 +25,4 @@ panic = "abort" [profile.profiling] inherits = "release" debug = true + diff --git a/README.md b/README.md index 42e481a..bfab241 100644 --- a/README.md +++ b/README.md @@ -58,14 +58,16 @@ including [RustFFT](https://crates.io/crates/rustfft/), while using significantl ### Rust ```rust -use phastft::planner::Direction; -use phastft::fft; +use phastft::{ + planner::Direction, + fft_64 +}; fn main() { let big_n = 1 << 10; let mut reals: Vec = (1..=big_n).map(|i| i as f64).collect(); let mut imags: Vec = (1..=big_n).map(|i| i as f64).collect(); - fft(&mut reals, &mut imags, Direction::Forward); + fft_64(&mut reals, &mut imags, Direction::Forward); } ``` diff --git a/examples/benchmark.rs b/examples/benchmark.rs index 952d197..46d940e 100644 --- a/examples/benchmark.rs +++ b/examples/benchmark.rs @@ -3,17 +3,17 @@ use std::str::FromStr; use utilities::gen_random_signal; -use phastft::fft; +use phastft::fft_64; use phastft::planner::Direction; -fn benchmark_fft(n: usize) { +fn benchmark_fft_64(n: usize) { let big_n = 1 << n; let mut reals = vec![0.0; big_n]; let mut imags = vec![0.0; big_n]; gen_random_signal(&mut reals, &mut imags); let now = std::time::Instant::now(); - fft(&mut reals, &mut imags, Direction::Forward); + fft_64(&mut reals, &mut imags, Direction::Forward); let elapsed = now.elapsed().as_micros(); println!("{elapsed}"); } @@ -24,5 +24,5 @@ fn main() { let n = usize::from_str(&args[1]).unwrap(); - benchmark_fft(n); + benchmark_fft_64(n); } diff --git a/examples/profile.rs b/examples/profile.rs index 112e17e..6725046 100644 --- a/examples/profile.rs +++ b/examples/profile.rs @@ -1,14 +1,14 @@ use std::env; use std::str::FromStr; -use phastft::fft; +use phastft::fft_64; use phastft::planner::Direction; fn benchmark_fft(num_qubits: usize) { let n = 1 << num_qubits; let mut reals: Vec = (1..=n).map(|i| i as f64).collect(); let mut imags: Vec = (1..=n).map(|i| i as f64).collect(); - fft(&mut reals, &mut imags, Direction::Forward); + fft_64(&mut reals, &mut imags, Direction::Forward); } fn main() { diff --git a/src/cobra.rs b/src/cobra.rs index c9809a9..56a42e3 100644 --- a/src/cobra.rs +++ b/src/cobra.rs @@ -13,7 +13,8 @@ //! Symposium on Foundations of Computer Science (Cat. No.98CB36280), Palo Alto, CA, USA, 1998, pp. 544-553, doi: //! 10.1109/SFCS.1998.743505. //! keywords: {Read-write memory;Costs;Computer science;Drives;Random access memory;Argon;Registers;Read only memory;Computational modeling;Libraries} -use crate::kernels::Float; + +use num_traits::Float; const BLOCK_WIDTH: usize = 128; // size of the cacheline @@ -78,7 +79,7 @@ pub(crate) fn bit_rev(buf: &mut [T], log_n: usize) { since = "0.1.0", note = "Please use COBRA for a cache-optimal bit reverse permutation." )] -fn complex_bit_rev(reals: &mut [Float], imags: &mut [Float], log_n: usize) { +fn complex_bit_rev(reals: &mut [T], imags: &mut [T], log_n: usize) { let mut nodd: usize; let mut noddrev; // to hold bitwise negated or odd values @@ -129,25 +130,6 @@ fn complex_bit_rev(reals: &mut [Float], imags: &mut [Float], log_n: usize) { } } -/// Run in-place bit reversal on the entire state, in parallel. -/// This function uses 2 threads to run a bit reverse on the reals buffer on one thread, and the other thread handles -/// the bit reversal of the imaginaries buffer -#[allow(dead_code)] -#[deprecated( - since = "0.1.0", - note = "Naive bit reverse permutation is slow and not cache friendly. COBRA should be used instead." -)] -pub(crate) fn bit_reverse_permute_state_par( - reals: &mut [Float], - imags: &mut [Float], - log_n: usize, -) { - std::thread::scope(|s| { - s.spawn(|| bit_rev(reals, log_n)); - s.spawn(|| bit_rev(imags, log_n)); - }); -} - #[allow(dead_code)] #[deprecated( since = "0.1.0", @@ -333,26 +315,6 @@ mod tests { } } - #[test] - fn jennifer_method_parallel() { - for n in 2..24 { - let big_n = 1 << n; - let mut actual_re: Vec = (0..big_n).map(f64::from).collect(); - let mut actual_im: Vec = (0..big_n).map(f64::from).collect(); - - #[allow(deprecated)] - bit_reverse_permute_state_par(&mut actual_re, &mut actual_im, n); - - let input_re: Vec = (0..big_n).map(f64::from).collect(); - let expected_re = top_down_bit_reverse_permutation(&input_re); - assert_eq!(actual_re, expected_re); - - let input_im: Vec = (0..big_n).map(f64::from).collect(); - let expected_im = top_down_bit_reverse_permutation(&input_im); - assert_eq!(actual_im, expected_im); - } - } - #[test] fn naive_bit_reverse_permutation() { for n in 2..24 { diff --git a/src/kernels.rs b/src/kernels.rs index 67991da..a02f77a 100644 --- a/src/kernels.rs +++ b/src/kernels.rs @@ -1,56 +1,61 @@ -use std::simd::f64x8; - -pub type Float = f64; - -pub(crate) fn fft_chunk_n_simd( - reals: &mut [Float], - imags: &mut [Float], - twiddles_re: &[Float], - twiddles_im: &[Float], - dist: usize, -) { - let chunk_size = dist << 1; - assert!(chunk_size >= 16); - - reals - .chunks_exact_mut(chunk_size) - .zip(imags.chunks_exact_mut(chunk_size)) - .for_each(|(reals_chunk, imags_chunk)| { - let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist); - let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist); - - reals_s0 - .chunks_exact_mut(8) - .zip(reals_s1.chunks_exact_mut(8)) - .zip(imags_s0.chunks_exact_mut(8)) - .zip(imags_s1.chunks_exact_mut(8)) - .zip(twiddles_re.chunks_exact(8)) - .zip(twiddles_im.chunks_exact(8)) - .for_each(|(((((re_s0, re_s1), im_s0), im_s1), w_re), w_im)| { - let real_c0 = f64x8::from_slice(re_s0); - let real_c1 = f64x8::from_slice(re_s1); - let imag_c0 = f64x8::from_slice(im_s0); - let imag_c1 = f64x8::from_slice(im_s1); - - let tw_re = f64x8::from_slice(w_re); - let tw_im = f64x8::from_slice(w_im); - - re_s0.copy_from_slice((real_c0 + real_c1).as_array()); - im_s0.copy_from_slice((imag_c0 + imag_c1).as_array()); - let v_re = real_c0 - real_c1; - let v_im = imag_c0 - imag_c1; - re_s1.copy_from_slice((v_re * tw_re - v_im * tw_im).as_array()); - im_s1.copy_from_slice((v_re * tw_im + v_im * tw_re).as_array()); +use std::simd::{f32x16, f64x8}; + +use num_traits::Float; + +macro_rules! fft_butterfly_n_simd { + ($func_name:ident, $precision:ty, $lanes:literal, $simd_vector:ty) => { + pub fn $func_name( + reals: &mut [$precision], + imags: &mut [$precision], + twiddles_re: &[$precision], + twiddles_im: &[$precision], + dist: usize, + ) { + let chunk_size = dist << 1; + assert!(chunk_size >= $lanes * 2); + reals + .chunks_exact_mut(chunk_size) + .zip(imags.chunks_exact_mut(chunk_size)) + .for_each(|(reals_chunk, imags_chunk)| { + let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist); + let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist); + + reals_s0 + .chunks_exact_mut($lanes) + .zip(reals_s1.chunks_exact_mut($lanes)) + .zip(imags_s0.chunks_exact_mut($lanes)) + .zip(imags_s1.chunks_exact_mut($lanes)) + .zip(twiddles_re.chunks_exact($lanes)) + .zip(twiddles_im.chunks_exact($lanes)) + .for_each(|(((((re_s0, re_s1), im_s0), im_s1), w_re), w_im)| { + let real_c0 = <$simd_vector>::from_slice(re_s0); + let real_c1 = <$simd_vector>::from_slice(re_s1); + let imag_c0 = <$simd_vector>::from_slice(im_s0); + let imag_c1 = <$simd_vector>::from_slice(im_s1); + + let tw_re = <$simd_vector>::from_slice(w_re); + let tw_im = <$simd_vector>::from_slice(w_im); + + re_s0.copy_from_slice((real_c0 + real_c1).as_array()); + im_s0.copy_from_slice((imag_c0 + imag_c1).as_array()); + let v_re = real_c0 - real_c1; + let v_im = imag_c0 - imag_c1; + re_s1.copy_from_slice((v_re * tw_re - v_im * tw_im).as_array()); + im_s1.copy_from_slice((v_re * tw_im + v_im * tw_re).as_array()); + }); }); - }); + } + }; } -// TODO(saveliy): parallelize -pub(crate) fn fft_chunk_n( - reals: &mut [Float], - imags: &mut [Float], - twiddles_re: &[Float], - twiddles_im: &[Float], +fft_butterfly_n_simd!(fft_64_chunk_n_simd, f64, 8, f64x8); +fft_butterfly_n_simd!(fft_32_chunk_n_simd, f32, 16, f32x16); + +pub(crate) fn fft_chunk_n( + reals: &mut [T], + imags: &mut [T], + twiddles_re: &[T], + twiddles_im: &[T], dist: usize, ) { let chunk_size = dist << 1; @@ -79,14 +84,14 @@ pub(crate) fn fft_chunk_n( *im_s0 = imag_c0 + imag_c1; let v_re = real_c0 - real_c1; let v_im = imag_c0 - imag_c1; - *re_s1 = v_re * w_re - v_im * w_im; - *im_s1 = v_re * w_im + v_im * w_re; + *re_s1 = v_re * *w_re - v_im * *w_im; + *im_s1 = v_re * *w_im + v_im * *w_re; }); }); } /// `chunk_size == 4`, so hard code twiddle factors -pub(crate) fn fft_chunk_4(reals: &mut [Float], imags: &mut [Float]) { +pub(crate) fn fft_chunk_4(reals: &mut [T], imags: &mut [T]) { let dist = 2; let chunk_size = dist << 1; @@ -120,7 +125,7 @@ pub(crate) fn fft_chunk_4(reals: &mut [Float], imags: &mut [Float]) { } /// `chunk_size == 2`, so skip phase -pub(crate) fn fft_chunk_2(reals: &mut [Float], imags: &mut [Float]) { +pub(crate) fn fft_chunk_2(reals: &mut [T], imags: &mut [T]) { reals .chunks_exact_mut(2) .zip(imags.chunks_exact_mut(2)) diff --git a/src/lib.rs b/src/lib.rs index 4d33350..7c7e4b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,9 +9,11 @@ #![feature(portable_simd)] use crate::cobra::cobra_apply; -use crate::kernels::{fft_chunk_2, fft_chunk_4, fft_chunk_n, fft_chunk_n_simd, Float}; +use crate::kernels::{ + fft_32_chunk_n_simd, fft_64_chunk_n_simd, fft_chunk_2, fft_chunk_4, fft_chunk_n, +}; use crate::options::Options; -use crate::planner::{Direction, Planner}; +use crate::planner::{Direction, Planner32, Planner64}; use crate::twiddles::filter_twiddles; pub mod cobra; @@ -20,173 +22,230 @@ pub mod options; pub mod planner; mod twiddles; -/// FFT -- Decimation in Frequency. This is just the decimation-in-time algorithm, reversed. -/// This call to FFT is run, in-place. -/// The input should be provided in normal order, and then the modified input is bit-reversed. -/// -/// # Panics -/// -/// Panics if `reals.len() != imags.len()` -/// -/// ## References -/// -pub fn fft(reals: &mut [Float], imags: &mut [Float], direction: Direction) { - assert_eq!( - reals.len(), - imags.len(), - "real and imaginary inputs must be of equal size, but got: {} {}", - reals.len(), - imags.len() - ); - - let mut planner = Planner::new(reals.len(), direction); - assert!(planner.num_twiddles().is_power_of_two() && planner.num_twiddles() == reals.len() / 2); - - let opts = Options::guess_options(reals.len()); - fft_with_opts_and_plan(reals, imags, &opts, &mut planner); +macro_rules! impl_fft_for { + ($func_name:ident, $precision:ty, $planner:ty, $opts_and_plan:ident) => { + /// FFT -- Decimation in Frequency. This is just the decimation-in-time algorithm, reversed. + /// This call to FFT is run, in-place. + /// The input should be provided in normal order, and then the modified input is bit-reversed. + /// + /// # Panics + /// + /// Panics if `reals.len() != imags.len()` or if `reals.len()` and `imags.len()` are not a power of + /// 2 + /// + /// ## References + /// + pub fn $func_name( + reals: &mut [$precision], + imags: &mut [$precision], + direction: Direction, + ) { + assert_eq!( + reals.len(), + imags.len(), + "real and imaginary inputs must be of equal size, but got: {} {}", + reals.len(), + imags.len() + ); + + let mut planner = <$planner>::new(reals.len(), direction); + assert!( + planner.num_twiddles().is_power_of_two() + && planner.num_twiddles() == reals.len() / 2 + ); + + let opts = Options::guess_options(reals.len()); + $opts_and_plan(reals, imags, &opts, &mut planner); + } + }; } -/// Same as [fft], but also accepts [`Options`] that control optimization strategies, as well as -/// a [`Planner`] in the case that this FFT will need to be run multiple times. -/// -/// `fft` automatically guesses the best strategy for a given input, -/// so you only need to call this if you are tuning performance for a specific hardware platform. -/// -/// In addition, `fft` automatically creates a planner to be used. In the case that you plan -/// on running an FFT many times on inputs of the same size, use this function with the pre-built -/// [`Planner`]. -/// -/// # Panics -/// -/// Panics if `reals.len() != imags.len()`, or if the input length is *not* a power of two. -pub fn fft_with_opts_and_plan( - reals: &mut [Float], - imags: &mut [Float], - opts: &Options, - planner: &mut Planner, -) { - assert!(reals.len() == imags.len() && reals.len().is_power_of_two()); - let n: usize = reals.len().ilog2() as usize; - - let twiddles_re = &mut planner.twiddles_re; - let twiddles_im = &mut planner.twiddles_im; - - // We shouldn't be able to execute FFT if the # of twiddles isn't equal to the distance - // between pairs - assert!(twiddles_re.len() == reals.len() / 2 && twiddles_im.len() == imags.len() / 2); - - for t in (0..n).rev() { - let dist = 1 << t; - let chunk_size = dist << 1; - - if chunk_size > 4 { - if t < n - 1 { - filter_twiddles(twiddles_re, twiddles_im); +impl_fft_for!(fft_64, f64, Planner64, fft_64_with_opts_and_plan); +impl_fft_for!(fft_32, f32, Planner32, fft_32_with_opts_and_plan); + +macro_rules! impl_fft_with_opts_and_plan_for { + ($func_name:ident, $precision:ty, $planner:ty, $simd_butterfly_kernel:ident, $lanes:literal) => { + /// Same as [fft], but also accepts [`Options`] that control optimization strategies, as well as + /// a [`Planner`] in the case that this FFT will need to be run multiple times. + /// + /// `fft` automatically guesses the best strategy for a given input, + /// so you only need to call this if you are tuning performance for a specific hardware platform. + /// + /// In addition, `fft` automatically creates a planner to be used. In the case that you plan + /// on running an FFT many times on inputs of the same size, use this function with the pre-built + /// [`Planner`]. + /// + /// # Panics + /// + /// Panics if `reals.len() != imags.len()`, or if the input length is *not* a power of 2. + pub fn $func_name( + reals: &mut [$precision], + imags: &mut [$precision], + opts: &Options, + planner: &mut $planner, + ) { + assert!(reals.len() == imags.len() && reals.len().is_power_of_two()); + let n: usize = reals.len().ilog2() as usize; + + let twiddles_re = &mut planner.twiddles_re; + let twiddles_im = &mut planner.twiddles_im; + + // We shouldn't be able to execute FFT if the # of twiddles isn't equal to the distance + // between pairs + assert!(twiddles_re.len() == reals.len() / 2 && twiddles_im.len() == imags.len() / 2); + + for t in (0..n).rev() { + let dist = 1 << t; + let chunk_size = dist << 1; + + if chunk_size > 4 { + if t < n - 1 { + filter_twiddles(twiddles_re, twiddles_im); + } + if chunk_size >= $lanes * 2 { + $simd_butterfly_kernel(reals, imags, twiddles_re, twiddles_im, dist); + } else { + fft_chunk_n(reals, imags, twiddles_re, twiddles_im, dist); + } + } else if chunk_size == 2 { + fft_chunk_2(reals, imags); + } else if chunk_size == 4 { + fft_chunk_4(reals, imags); + } } - if chunk_size >= 16 { - fft_chunk_n_simd(reals, imags, twiddles_re, twiddles_im, dist); + + if opts.multithreaded_bit_reversal { + std::thread::scope(|s| { + s.spawn(|| cobra_apply(reals, n)); + s.spawn(|| cobra_apply(imags, n)); + }); } else { - fft_chunk_n(reals, imags, twiddles_re, twiddles_im, dist); + cobra_apply(reals, n); + cobra_apply(imags, n); } - } else if chunk_size == 2 { - fft_chunk_2(reals, imags); - } else if chunk_size == 4 { - fft_chunk_4(reals, imags); } - } - - if opts.multithreaded_bit_reversal { - std::thread::scope(|s| { - s.spawn(|| cobra_apply(reals, n)); - s.spawn(|| cobra_apply(imags, n)); - }); - } else { - cobra_apply(reals, n); - cobra_apply(imags, n); - } + }; } +impl_fft_with_opts_and_plan_for!( + fft_64_with_opts_and_plan, + f64, + Planner64, + fft_64_chunk_n_simd, + 8 +); + +impl_fft_with_opts_and_plan_for!( + fft_32_with_opts_and_plan, + f32, + Planner32, + fft_32_chunk_n_simd, + 16 +); + #[cfg(test)] mod tests { use std::ops::Range; - use utilities::{ - assert_f64_closeness, - rustfft::{num_complex::Complex64, FftPlanner}, - }; - - use crate::planner::Direction; + use utilities::assert_float_closeness; + use utilities::rustfft::num_complex::Complex; + use utilities::rustfft::FftPlanner; use super::*; - #[should_panic] - #[test] - fn non_power_of_two_fft() { - let num_points = 5; - - // this test will actually always fail at this stage - let mut planner = Planner::new(num_points, Direction::Forward); - - let mut reals = vec![0.0; num_points]; - let mut imags = vec![0.0; num_points]; - let opts = Options::guess_options(reals.len()); + macro_rules! non_power_of_2_planner { + ($test_name:ident, $planner:ty) => { + #[should_panic] + #[test] + fn $test_name() { + let num_points = 5; - // but this call should, in principle, panic as well - fft_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner); + // this test _should_ always fail at this stage + let _ = <$planner>::new(num_points, Direction::Forward); + } + }; } - // A regression test to make sure the `Planner` is compatible with fft execution. - #[should_panic] - #[test] - fn wrong_num_points_in_planner() { - let n = 16; - let num_points = 1 << n; - - // We purposely set n = 16 and pass it to the planner. - // n = 16 == 2^{4} is clearly a power of two, so the planner won't throw it out. - // However, the call to `fft_with_opts_and_plan` should panic since it tests that the - // size of the generated twiddle factors is half the size of the input. - // In this case, we have an input of size 1024 (used for mp3), but we tell the planner the - // input size is 16. - let mut planner = Planner::new(n, Direction::Forward); - - let mut reals = vec![0.0; num_points]; - let mut imags = vec![0.0; num_points]; - let opts = Options::guess_options(reals.len()); - - // but this call should panic as well - fft_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner); + non_power_of_2_planner!(non_power_of_2_planner_32, Planner32); + non_power_of_2_planner!(non_power_of_2_planner_64, Planner64); + + macro_rules! wrong_num_points_in_planner { + ($test_name:ident, $planner:ty, $fft_with_opts_and_plan:ident) => { + // A regression test to make sure the `Planner` is compatible with fft execution. + #[should_panic] + #[test] + fn $test_name() { + let n = 16; + let num_points = 1 << n; + + // We purposely set n = 16 and pass it to the planner. + // n == 16 == 2^{4} is clearly a power of two, so the planner won't throw it out. + // However, the call to `fft_with_opts_and_plan` should panic since it tests that the + // size of the generated twiddle factors is half the size of the input. + // In this case, we have an input of size 1024 (used for mp3), but we tell the planner the + // input size is 16. + let mut planner = <$planner>::new(n, Direction::Forward); + + let mut reals = vec![0.0; num_points]; + let mut imags = vec![0.0; num_points]; + let opts = Options::guess_options(reals.len()); + + // this call should panic + $fft_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner); + } + }; } - #[test] - fn fft_correctness() { - let range = Range { start: 4, end: 17 }; - - for k in range { - let n: usize = 1 << k; - - let mut reals: Vec = (1..=n).map(|i| i as f64).collect(); - let mut imags: Vec = (1..=n).map(|i| i as f64).collect(); - fft(&mut reals, &mut imags, Direction::Forward); - - let mut buffer: Vec = (1..=n) - .map(|i| Complex64::new(i as f64, i as f64)) - .collect(); - - let mut planner = FftPlanner::new(); - let fft = planner.plan_fft_forward(buffer.len()); - fft.process(&mut buffer); - - reals - .iter() - .zip(imags.iter()) - .enumerate() - .for_each(|(i, (z_re, z_im))| { - let expect_re = buffer[i].re; - let expect_im = buffer[i].im; - assert_f64_closeness(*z_re, expect_re, 0.01); - assert_f64_closeness(*z_im, expect_im, 0.01); - }); - } + wrong_num_points_in_planner!( + wrong_num_points_in_planner_32, + Planner32, + fft_32_with_opts_and_plan + ); + wrong_num_points_in_planner!( + wrong_num_points_in_planner_64, + Planner64, + fft_64_with_opts_and_plan + ); + + macro_rules! test_fft_correctness { + ($test_name:ident, $precision:ty, $fft_type:ident, $range_start:literal, $range_end:literal) => { + #[test] + fn $test_name() { + let range = Range { + start: $range_start, + end: $range_end, + }; + + for k in range { + let n: usize = 1 << k; + + let mut reals: Vec<$precision> = (1..=n).map(|i| i as $precision).collect(); + let mut imags: Vec<$precision> = (1..=n).map(|i| i as $precision).collect(); + $fft_type(&mut reals, &mut imags, Direction::Forward); + + let mut buffer: Vec> = (1..=n) + .map(|i| Complex::new(i as $precision, i as $precision)) + .collect(); + + let mut planner = FftPlanner::new(); + let fft = planner.plan_fft_forward(buffer.len()); + fft.process(&mut buffer); + + reals + .iter() + .zip(imags.iter()) + .enumerate() + .for_each(|(i, (z_re, z_im))| { + let expect_re = buffer[i].re; + let expect_im = buffer[i].im; + assert_float_closeness(*z_re, expect_re, 0.01); + assert_float_closeness(*z_im, expect_im, 0.01); + }); + } + } + }; } + + test_fft_correctness!(fft_correctness_32, f32, fft_32, 4, 9); + test_fft_correctness!(fft_correctness_64, f64, fft_64, 4, 17); } diff --git a/src/planner.rs b/src/planner.rs index cec8fe4..548e972 100644 --- a/src/planner.rs +++ b/src/planner.rs @@ -2,8 +2,7 @@ //! a Fast Fourier Transform (FFT). Currently, the planner is responsible for //! pre-computing twiddle factors based on the input signal length, as well as the //! direction of the FFT. - -use crate::twiddles::{generate_twiddles, generate_twiddles_simd}; +use crate::twiddles::{generate_twiddles, generate_twiddles_simd_32, generate_twiddles_simd_64}; /// Reverse is for running the Inverse Fast Fourier Transform (IFFT) /// Forward is for running the regular FFT @@ -14,96 +13,117 @@ pub enum Direction { Reverse = -1, } -/// The planner is responsible for pre-computing and storing twiddle factors for all the -/// `log_2(N)` stages of the FFT. -/// The amount of twiddle factors should always be a power of 2. In addition, -/// the amount of twiddle factors should always be `(1/2) * N` -pub struct Planner { - /// The real components of the twiddle factors - pub twiddles_re: Vec, - /// The imaginary components of the twiddle factors - pub twiddles_im: Vec, -} - -impl Planner { - /// Create a `Planner` for an FFT of size `num_points`. - /// The twiddle factors are pre-computed based on the provided [`Direction`]. - /// For `Forward`, use [`Direction::Forward`]. - /// For `Reverse`, use [`Direction::Reverse`]. - /// - /// # Panics - /// - /// Panics if `num_points < 1` - pub fn new(num_points: usize, direction: Direction) -> Self { - assert!(num_points > 0 && num_points.is_power_of_two()); - if num_points <= 4 { - return Self { - twiddles_re: vec![], - twiddles_im: vec![], - }; +macro_rules! impl_planner_for { + ($struct_name:ident, $precision:ident, $generate_twiddles_simd_fn:ident) => { + /// The planner is responsible for pre-computing and storing twiddle factors for all the + /// `log_2(N)` stages of the FFT. + /// The amount of twiddle factors should always be a power of 2. In addition, + /// the amount of twiddle factors should always be `(1/2) * N` + pub struct $struct_name { + /// The real components of the twiddle factors + pub twiddles_re: Vec<$precision>, + /// The imaginary components of the twiddle factors + pub twiddles_im: Vec<$precision>, } + impl $struct_name { + /// Create a `Planner` for an FFT of size `num_points`. + /// The twiddle factors are pre-computed based on the provided [`Direction`]. + /// For `Forward`, use [`Direction::Forward`]. + /// For `Reverse`, use [`Direction::Reverse`]. + /// + /// # Panics + /// + /// Panics if `num_points < 1` or if `num_points` is __not__ a power of 2. + pub fn new(num_points: usize, direction: Direction) -> Self { + assert!(num_points > 0 && num_points.is_power_of_two()); + if num_points <= 4 { + return Self { + twiddles_re: vec![], + twiddles_im: vec![], + }; + } - let dist = num_points >> 1; - let (twiddles_re, twiddles_im) = if dist >= 8 * 2 { - generate_twiddles_simd(dist, direction) - } else { - generate_twiddles(dist, direction) - }; + let dist = num_points >> 1; - assert_eq!(twiddles_re.len(), twiddles_im.len()); + let (twiddles_re, twiddles_im) = if dist >= 8 * 2 { + $generate_twiddles_simd_fn(dist, direction) + } else { + generate_twiddles(dist, direction) + }; - Self { - twiddles_re, - twiddles_im, - } - } + assert_eq!(twiddles_re.len(), twiddles_im.len()); - pub(crate) fn num_twiddles(&self) -> usize { - assert_eq!(self.twiddles_re.len(), self.twiddles_im.len()); - self.twiddles_re.len() - } + Self { + twiddles_re, + twiddles_im, + } + } + + pub(crate) fn num_twiddles(&self) -> usize { + assert_eq!(self.twiddles_re.len(), self.twiddles_im.len()); + self.twiddles_re.len() + } + } + }; } +impl_planner_for!(Planner64, f64, generate_twiddles_simd_64); +impl_planner_for!(Planner32, f32, generate_twiddles_simd_32); + #[cfg(test)] mod tests { - use utilities::assert_f64_closeness; + use utilities::assert_float_closeness; - use crate::planner::{Direction, Planner}; + use super::*; - #[test] - fn no_twiddles() { - for num_points in [2, 4] { - let planner = Planner::new(num_points, Direction::Forward); - assert!(planner.twiddles_im.is_empty() && planner.twiddles_re.is_empty()); - } + macro_rules! test_no_twiddles { + ($test_name:ident, $planner:ty) => { + #[test] + fn $test_name() { + for num_points in [2, 4] { + let planner = <$planner>::new(num_points, Direction::Forward); + assert!(planner.twiddles_im.is_empty() && planner.twiddles_re.is_empty()); + } + } + }; } - #[test] - fn forward_mul_inverse_eq_identity() { - for i in 3..25 { - let num_points = 1 << i; - let planner_forward = Planner::new(num_points, Direction::Forward); - let planner_reverse = Planner::new(num_points, Direction::Reverse); - - assert_eq!( - planner_reverse.num_twiddles(), - planner_forward.num_twiddles() - ); - - // (a + ib) (c + id) = ac + iad + ibc - bd - // = ac - bd + i(ad + bc) - planner_forward - .twiddles_re - .iter() - .zip(planner_forward.twiddles_im.iter()) - .zip(planner_reverse.twiddles_re.iter()) - .zip(planner_reverse.twiddles_im) - .for_each(|(((a, b), c), d)| { - let temp_re = a * c - b * d; - let temp_im = a * d + b * c; - assert_f64_closeness(temp_re, 1.0, 1e-6); - assert_f64_closeness(temp_im, 0.0, 1e-6); - }); - } + test_no_twiddles!(no_twiddles_64, Planner64); + test_no_twiddles!(no_twiddles_32, Planner32); + + macro_rules! forward_mul_inverse_eq_identity { + ($test_name:ident, $planner:ty) => { + #[test] + fn $test_name() { + for i in 3..25 { + let num_points = 1 << i; + let planner_forward = <$planner>::new(num_points, Direction::Forward); + let planner_reverse = <$planner>::new(num_points, Direction::Reverse); + + assert_eq!( + planner_reverse.num_twiddles(), + planner_forward.num_twiddles() + ); + + // (a + ib) (c + id) = ac + iad + ibc - bd + // = ac - bd + i(ad + bc) + planner_forward + .twiddles_re + .iter() + .zip(planner_forward.twiddles_im.iter()) + .zip(planner_reverse.twiddles_re.iter()) + .zip(planner_reverse.twiddles_im) + .for_each(|(((a, b), c), d)| { + let temp_re = a * c - b * d; + let temp_im = a * d + b * c; + assert_float_closeness(temp_re, 1.0, 1e-2); + assert_float_closeness(temp_im, 0.0, 1e-2); + }); + } + } + }; } + + forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_64, Planner64); + forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_32, Planner32); } diff --git a/src/twiddles.rs b/src/twiddles.rs index 48b91fa..97f31a0 100644 --- a/src/twiddles.rs +++ b/src/twiddles.rs @@ -1,37 +1,43 @@ -use std::{f64::consts::PI, simd::f64x8}; +use std::simd::{f32x8, f64x8}; + +use num_traits::{Float, FloatConst, One, Zero}; -use crate::kernels::Float; use crate::planner::Direction; -pub(crate) struct Twiddles { - st: Float, - ct: Float, - w_re_prev: Float, - w_im_prev: Float, +pub(crate) struct Twiddles { + st: T, + ct: T, + w_re_prev: T, + w_im_prev: T, } -impl Twiddles { +impl Twiddles { /// `cache_size` is the amount of roots of unity kept pre-built at any point in time. /// `num_roots` is the total number of roots of unity that will need to be computed. /// `cache_size` can be thought of as the length of a chunk of roots of unity from /// out of the total amount (i.e., `num_roots`) #[allow(dead_code)] pub fn new(num_roots: usize) -> Self { - let theta = -PI / (num_roots as Float); + let theta = -T::PI() / (T::from(num_roots).unwrap()); let (st, ct) = theta.sin_cos(); Self { st, ct, - w_re_prev: 1.0, - w_im_prev: 0.0, + w_re_prev: T::one(), + w_im_prev: T::zero(), } } } -impl Iterator for Twiddles { - type Item = (Float, Float); +// TODO: generate twiddles using the first quarter chunk of twiddle factors +// 1st chunk: old fashioned multiplication of complex nums +// 2nd chunk: reverse the 1st chunk, swap components, and negate both components +// 3rd chunk: No reversal. Swap the components and negate the *new* imaginary components +// 4th chunk: reverse the 1st chunk, and negate the real component +impl Iterator for Twiddles { + type Item = (T, T); - fn next(&mut self) -> Option<(f64, f64)> { + fn next(&mut self) -> Option<(T, T)> { let w_re = self.w_re_prev; let w_im = self.w_im_prev; @@ -43,19 +49,22 @@ impl Iterator for Twiddles { } } -pub fn generate_twiddles(dist: usize, direction: Direction) -> (Vec, Vec) { - let mut twiddles_re = vec![0.0; dist]; - let mut twiddles_im = vec![0.0; dist]; - twiddles_re[0] = 1.0; +pub fn generate_twiddles( + dist: usize, + direction: Direction, +) -> (Vec, Vec) { + let mut twiddles_re = vec![T::zero(); dist]; + let mut twiddles_im = vec![T::zero(); dist]; + twiddles_re[0] = T::one(); let sign = match direction { - Direction::Forward => 1.0, - Direction::Reverse => -1.0, + Direction::Forward => T::one(), + Direction::Reverse => -T::one(), }; - let angle: Float = sign * -PI / (dist as f64); + let angle = sign * -T::PI() / T::from(dist).unwrap(); let (st, ct) = angle.sin_cos(); - let (mut w_re, mut w_im) = (1.0, 0.0); + let (mut w_re, mut w_im) = (T::one(), T::zero()); let mut i = 1; while i < (dist / 2) + 1 { @@ -76,106 +85,108 @@ pub fn generate_twiddles(dist: usize, direction: Direction) -> (Vec, Vec (Vec, Vec) { - const CHUNK_SIZE: usize = 8; // TODO: make this a const generic? - assert!(dist >= CHUNK_SIZE * 2); - assert_eq!(dist % CHUNK_SIZE, 0); - let mut twiddles_re = vec![0.0; dist]; - let mut twiddles_im = vec![0.0; dist]; - twiddles_re[0] = 1.0; - - let sign = match direction { - Direction::Forward => 1.0, - Direction::Reverse => -1.0, - }; - - let angle: Float = sign * -PI / (dist as f64); - let (st, ct) = angle.sin_cos(); - let (mut w_re, mut w_im) = (1.0, 0.0); - - let mut next_twiddle = || { - let temp = w_re; - w_re = w_re * ct - w_im * st; - w_im = temp * st + w_im * ct; - (w_re, w_im) - }; - - let apply_symmetry_re = |input: &[Float], output: &mut [Float]| { - let first_re = f64x8::from_slice(input); - let minus_one = f64x8::splat(-1.0); - let negated = (first_re * minus_one).reverse(); - output.copy_from_slice(negated.as_array()); - }; - - let apply_symmetry_im = |input: &[Float], output: &mut [Float]| { - let mut buf: [Float; CHUNK_SIZE] = [Float::default(); 8]; - buf.copy_from_slice(input); - buf.reverse(); - output.copy_from_slice(&buf); - }; - - // TODO: generate twiddles using the first quarter chunk of twiddle factors - // 1st chunk: old fashioned multiplication of complex nums - // 2nd chunk: reverse the 1st chunk, swap components, and negate both components - // 3rd chunk: No reversal. Swap the components and negate the *new* imaginary components - // 4th chunk: reverse the 1st chunk, and negate the real component - - // Split the twiddles into two halves. There is a cheaper way to calculate the second half - let (first_half_re, second_half_re) = twiddles_re[1..].split_at_mut(dist / 2); - assert_eq!(first_half_re.len(), second_half_re.len() + 1); - let (first_half_im, second_half_im) = twiddles_im[1..].split_at_mut(dist / 2); - assert_eq!(first_half_im.len(), second_half_im.len() + 1); - - first_half_re - .chunks_exact_mut(CHUNK_SIZE) - .zip(first_half_im.chunks_exact_mut(CHUNK_SIZE)) - .zip( - second_half_re[CHUNK_SIZE - 1..] - .chunks_exact_mut(CHUNK_SIZE) - .rev(), - ) - .zip( - second_half_im[CHUNK_SIZE - 1..] +macro_rules! generate_twiddles_simd { + ($func_name:ident, $precision:ty, $lanes:literal, $simd_vector:ty) => { + pub(crate) fn $func_name( + dist: usize, + direction: Direction, + ) -> (Vec<$precision>, Vec<$precision>) { + const CHUNK_SIZE: usize = 8; // TODO: make this a const generic? + assert!(dist >= CHUNK_SIZE * 2); + assert_eq!(dist % CHUNK_SIZE, 0); + let mut twiddles_re = vec![0.0; dist]; + let mut twiddles_im = vec![0.0; dist]; + twiddles_re[0] = 1.0; + + let sign = match direction { + Direction::Forward => 1.0, + Direction::Reverse => -1.0, + }; + + let angle = sign * -<$precision>::PI() / dist as $precision; + let (st, ct) = angle.sin_cos(); + let (mut w_re, mut w_im) = (<$precision>::one(), <$precision>::zero()); + + let mut next_twiddle = || { + let temp = w_re; + w_re = w_re * ct - w_im * st; + w_im = temp * st + w_im * ct; + (w_re, w_im) + }; + + let apply_symmetry_re = |input: &[$precision], output: &mut [$precision]| { + let first_re = <$simd_vector>::from_slice(input); + let minus_one = <$simd_vector>::splat(-1.0); + let negated = (first_re * minus_one).reverse(); + output.copy_from_slice(negated.as_array()); + }; + + let apply_symmetry_im = |input: &[$precision], output: &mut [$precision]| { + let mut buf: [$precision; CHUNK_SIZE] = [0.0; CHUNK_SIZE]; + buf.copy_from_slice(input); + buf.reverse(); + output.copy_from_slice(&buf); + }; + + // Split the twiddles into two halves. There is a cheaper way to calculate the second half + let (first_half_re, second_half_re) = twiddles_re[1..].split_at_mut(dist / 2); + assert_eq!(first_half_re.len(), second_half_re.len() + 1); + let (first_half_im, second_half_im) = twiddles_im[1..].split_at_mut(dist / 2); + assert_eq!(first_half_im.len(), second_half_im.len() + 1); + + first_half_re .chunks_exact_mut(CHUNK_SIZE) - .rev(), - ) - .for_each( - |(((first_ch_re, first_ch_im), second_ch_re), second_ch_im)| { - // Calculate a chunk of the first half in a plain old scalar way - first_ch_re - .iter_mut() - .zip(first_ch_im.iter_mut()) - .for_each(|(re, im)| { - (*re, *im) = next_twiddle(); - }); - // Calculate a chunk of the second half in a clever way by copying the first chunk - // This avoids data dependencies of the regular calculation and gets vectorized. - // We do it up front while the values we just calculated are still in the cache, - // so we don't have to re-load them from memory later, which would be slow. - apply_symmetry_re(first_ch_re, second_ch_re); - apply_symmetry_im(first_ch_im, second_ch_im); - }, - ); - - // Fill in the middle that the SIMD loop didn't - twiddles_re[dist / 2 - CHUNK_SIZE + 1..][..(CHUNK_SIZE * 2) - 1] - .iter_mut() - .zip(twiddles_im[dist / 2 - CHUNK_SIZE + 1..][..(CHUNK_SIZE * 2) - 1].iter_mut()) - .for_each(|(re, im)| { - (*re, *im) = next_twiddle(); - }); + .zip(first_half_im.chunks_exact_mut(CHUNK_SIZE)) + .zip( + second_half_re[CHUNK_SIZE - 1..] + .chunks_exact_mut(CHUNK_SIZE) + .rev(), + ) + .zip( + second_half_im[CHUNK_SIZE - 1..] + .chunks_exact_mut(CHUNK_SIZE) + .rev(), + ) + .for_each( + |(((first_ch_re, first_ch_im), second_ch_re), second_ch_im)| { + // Calculate a chunk of the first half in a plain old scalar way + first_ch_re + .iter_mut() + .zip(first_ch_im.iter_mut()) + .for_each(|(re, im)| { + (*re, *im) = next_twiddle(); + }); + // Calculate a chunk of the second half in a clever way by copying the first chunk + // This avoids data dependencies of the regular calculation and gets vectorized. + // We do it up front while the values we just calculated are still in the cache, + // so we don't have to re-load them from memory later, which would be slow. + apply_symmetry_re(first_ch_re, second_ch_re); + apply_symmetry_im(first_ch_im, second_ch_im); + }, + ); + + // Fill in the middle that the SIMD loop didn't + twiddles_re[dist / 2 - CHUNK_SIZE + 1..][..(CHUNK_SIZE * 2) - 1] + .iter_mut() + .zip(twiddles_im[dist / 2 - CHUNK_SIZE + 1..][..(CHUNK_SIZE * 2) - 1].iter_mut()) + .for_each(|(re, im)| { + (*re, *im) = next_twiddle(); + }); - (twiddles_re, twiddles_im) + (twiddles_re, twiddles_im) + } + }; } -pub(crate) fn filter_twiddles(twiddles_re: &mut Vec, twiddles_im: &mut Vec) { +generate_twiddles_simd!(generate_twiddles_simd_64, f64, 8, f64x8); +generate_twiddles_simd!(generate_twiddles_simd_32, f32, 8, f32x8); + +pub(crate) fn filter_twiddles(twiddles_re: &mut Vec, twiddles_im: &mut Vec) { assert_eq!(twiddles_re.len(), twiddles_im.len()); let dist = twiddles_re.len(); - let filtered_twiddles_re: Vec = - twiddles_re.chunks_exact(2).map(|chunk| chunk[0]).collect(); - let filtered_twiddles_im: Vec = - twiddles_im.chunks_exact(2).map(|chunk| chunk[0]).collect(); + let filtered_twiddles_re: Vec = twiddles_re.chunks_exact(2).map(|chunk| chunk[0]).collect(); + let filtered_twiddles_im: Vec = twiddles_im.chunks_exact(2).map(|chunk| chunk[0]).collect(); assert!( filtered_twiddles_re.len() == filtered_twiddles_im.len() @@ -190,7 +201,7 @@ pub(crate) fn filter_twiddles(twiddles_re: &mut Vec, twiddles_im: &mut Vec< mod tests { use std::f64::consts::FRAC_1_SQRT_2; - use utilities::assert_f64_closeness; + use utilities::assert_float_closeness; use super::*; @@ -201,62 +212,74 @@ mod tests { let (w_re, w_im) = twiddle_iter.next().unwrap(); println!("{w_re} {w_im}"); - assert_f64_closeness(w_re, 1.0, 1e-10); - assert_f64_closeness(w_im, 0.0, 1e-10); + assert_float_closeness(w_re, 1.0, 1e-10); + assert_float_closeness(w_im, 0.0, 1e-10); let (w_re, w_im) = twiddle_iter.next().unwrap(); println!("{w_re} {w_im}"); - assert_f64_closeness(w_re, FRAC_1_SQRT_2, 1e-10); - assert_f64_closeness(w_im, -FRAC_1_SQRT_2, 1e-10); + assert_float_closeness(w_re, FRAC_1_SQRT_2, 1e-10); + assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-10); let (w_re, w_im) = twiddle_iter.next().unwrap(); println!("{w_re} {w_im}"); - assert_f64_closeness(w_re, 0.0, 1e-10); - assert_f64_closeness(w_im, -1.0, 1e-10); + assert_float_closeness(w_re, 0.0, 1e-10); + assert_float_closeness(w_im, -1.0, 1e-10); let (w_re, w_im) = twiddle_iter.next().unwrap(); println!("{w_re} {w_im}"); - assert_f64_closeness(w_re, -FRAC_1_SQRT_2, 1e-10); - assert_f64_closeness(w_im, -FRAC_1_SQRT_2, 1e-10); + assert_float_closeness(w_re, -FRAC_1_SQRT_2, 1e-10); + assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-10); } - #[test] - fn twiddles_simd() { - for n in 4..28 { - let dist = 1 << n; - - let (twiddles_re_ref, twiddles_im_ref) = generate_twiddles(dist, Direction::Forward); - let (twiddles_re, twiddles_im) = generate_twiddles_simd(dist, Direction::Forward); - - twiddles_re - .iter() - .zip(twiddles_re_ref.iter()) - .for_each(|(simd, reference)| { - assert_f64_closeness(*simd, *reference, 1e-10); - }); - - twiddles_im - .iter() - .zip(twiddles_im_ref.iter()) - .for_each(|(simd, reference)| { - assert_f64_closeness(*simd, *reference, 1e-10); - }); - } + macro_rules! test_twiddles_simd { + ($test_name:ident, $generate_twiddles_simd:ident, $epsilon:literal) => { + #[test] + fn $test_name() { + for n in 4..25 { + let dist = 1 << n; + + let (twiddles_re_ref, twiddles_im_ref) = + generate_twiddles(dist, Direction::Forward); + let (twiddles_re, twiddles_im) = + $generate_twiddles_simd(dist, Direction::Forward); + + twiddles_re + .iter() + .zip(twiddles_re_ref.iter()) + .for_each(|(simd, reference)| { + assert_float_closeness(*simd, *reference, $epsilon); + }); + + twiddles_im + .iter() + .zip(twiddles_im_ref.iter()) + .for_each(|(simd, reference)| { + assert_float_closeness(*simd, *reference, $epsilon); + }); + } + } + }; } + test_twiddles_simd!(twiddles_simd_32, generate_twiddles_simd_32, 1e-1); + test_twiddles_simd!(twiddles_simd_64, generate_twiddles_simd_64, 1e-10); + #[test] fn twiddles_filter() { + // Assume n = 28 let n = 28; + // distance := 2^{n} / 2 == 2^{n-1} let dist = 1 << (n - 1); + let mut twiddles_iter = Twiddles::new(dist); let (mut twiddles_re, mut twiddles_im) = generate_twiddles(dist, Direction::Forward); for i in 0..dist { let (tw_re, tw_im) = twiddles_iter.next().unwrap(); - assert_f64_closeness(twiddles_re[i], tw_re, 1e-6); - assert_f64_closeness(twiddles_im[i], tw_im, 1e-6); + assert_float_closeness(twiddles_re[i], tw_re, 1e-6); + assert_float_closeness(twiddles_im[i], tw_im, 1e-6); } for t in (0..n - 1).rev() { @@ -271,8 +294,8 @@ mod tests { for i in 0..dist { let (tw_re, tw_im) = twiddles_iter.next().unwrap(); - assert_f64_closeness(twiddles_re[i], tw_re, 1e-6); - assert_f64_closeness(twiddles_im[i], tw_im, 1e-6); + assert_float_closeness(twiddles_re[i], tw_re, 1e-6); + assert_float_closeness(twiddles_im[i], tw_im, 1e-6); } } } diff --git a/utilities/src/lib.rs b/utilities/src/lib.rs index 878048f..f7eff6f 100644 --- a/utilities/src/lib.rs +++ b/utilities/src/lib.rs @@ -1,33 +1,20 @@ pub extern crate rustfft; use std::f64::consts::PI; +use std::fmt::Display; use rand::distributions::Uniform; use rand::prelude::*; +use rustfft::num_traits::Float; -/// Asserts that two f64 numbers are approximately equal. +/// Asserts that two floating-point numbers are approximately equal. /// /// # Panics /// /// Panics if `actual` and `expected` are too far from each other #[allow(dead_code)] #[track_caller] -pub fn assert_f64_closeness(actual: f64, expected: f64, epsilon: f64) { - if (actual - expected).abs() >= epsilon { - panic!( - "Assertion failed: {actual} too far from expected value {expected} (with epsilon {epsilon})", - ); - } -} - -/// Asserts that two f32 numbers are approximately equal. -/// -/// # Panics -/// -/// Panics if `actual` and `expected` are too far from each other -#[allow(dead_code)] -#[track_caller] -pub fn assert_f32_closeness(actual: f32, expected: f32, epsilon: f32) { +pub fn assert_float_closeness(actual: T, expected: T, epsilon: T) { if (actual - expected).abs() >= epsilon { panic!( "Assertion failed: {actual} too far from expected value {expected} (with epsilon {epsilon})",