diff --git a/Cargo.toml b/Cargo.toml index 98e51eb..12d50ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "phastft" -version = "0.2.0" +version = "0.2.1" edition = "2021" authors = ["Saveliy Yusufov", "Shnatsel"] license = "MIT OR Apache-2.0" diff --git a/README.md b/README.md index e885f3b..54c505a 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Transform (FFT) library written in pure Rust. - Performance on par with other Rust FFT implementations - Zero `unsafe` code - Takes advantage of latest CPU features up to and including `AVX-512`, but performs well even without them +- Selects the fastest implementation at runtime. No need for `-C target-cpu=native`! - Optional parallelization of some steps to 2 threads (with even more planned) - 2x lower memory usage than [RustFFT](https://crates.io/crates/rustfft/) - Python bindings (via [PyO3](https://github.com/PyO3/pyo3)) diff --git a/pyphastft/Cargo.toml b/pyphastft/Cargo.toml index eaf4221..a1d234c 100644 --- a/pyphastft/Cargo.toml +++ b/pyphastft/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pyphastft" -version = "0.2.0" +version = "0.2.1" edition = "2021" authors = ["Saveliy Yusufov", "Shnatsel"] license = "MIT OR Apache-2.0" diff --git a/src/lib.rs b/src/lib.rs index 235279c..5756aba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,14 +48,35 @@ macro_rules! impl_fft_for { imags.len() ); - let mut planner = <$planner>::new(reals.len(), direction); + let mut planner = <$planner>::new(reals.len(), Direction::Forward); assert!( planner.num_twiddles().is_power_of_two() && planner.num_twiddles() == reals.len() / 2 ); let opts = Options::guess_options(reals.len()); + + match direction { + Direction::Reverse => { + for z_im in imags.iter_mut() { + *z_im = -*z_im; + } + } + _ => (), + } + $opts_and_plan(reals, imags, &opts, &mut planner); + + match direction { + Direction::Reverse => { + let scaling_factor = (reals.len() as $precision).recip(); + for (z_re, z_im) in reals.iter_mut().zip(imags.iter_mut()) { + *z_re *= scaling_factor; + *z_im *= -scaling_factor; + } + } + _ => (), + } } }; } @@ -156,7 +177,7 @@ impl_fft_with_opts_and_plan_for!( mod tests { use std::ops::Range; - use utilities::assert_float_closeness; + use utilities::{assert_float_closeness, gen_random_signal}; use utilities::rustfft::FftPlanner; use utilities::rustfft::num_complex::Complex; @@ -257,4 +278,34 @@ mod tests { test_fft_correctness!(fft_correctness_32, f32, fft_32, 4, 9); test_fft_correctness!(fft_correctness_64, f64, fft_64, 4, 17); + + #[test] + fn fft_round_trip() { + for i in 4..23 { + let big_n = 1 << i; + let mut reals = vec![0.0; big_n]; + let mut imags = vec![0.0; big_n]; + + gen_random_signal(&mut reals, &mut imags); + + let original_reals = reals.clone(); + let original_imags = imags.clone(); + + // Forward FFT + fft_64(&mut reals, &mut imags, Direction::Forward); + + // Inverse FFT + fft_64(&mut reals, &mut imags, Direction::Reverse); + + // Ensure we get back the original signal within some tolerance + for ((orig_re, orig_im), (res_re, res_im)) in original_reals + .into_iter() + .zip(original_imags.into_iter()) + .zip(reals.into_iter().zip(imags.into_iter())) + { + assert_float_closeness(res_re, orig_re, 1e-6); + assert_float_closeness(res_im, orig_im, 1e-6); + } + } + } } diff --git a/src/planner.rs b/src/planner.rs index 548e972..b174567 100644 --- a/src/planner.rs +++ b/src/planner.rs @@ -6,6 +6,7 @@ use crate::twiddles::{generate_twiddles, generate_twiddles_simd_32, generate_twi /// Reverse is for running the Inverse Fast Fourier Transform (IFFT) /// Forward is for running the regular FFT +#[derive(Copy, Clone)] pub enum Direction { /// Leave the exponent term in the twiddle factor alone Forward = 1, @@ -34,7 +35,7 @@ macro_rules! impl_planner_for { /// # Panics /// /// Panics if `num_points < 1` or if `num_points` is __not__ a power of 2. - pub fn new(num_points: usize, direction: Direction) -> Self { + pub fn new(num_points: usize, _direction: Direction) -> Self { assert!(num_points > 0 && num_points.is_power_of_two()); if num_points <= 4 { return Self { @@ -46,9 +47,9 @@ macro_rules! impl_planner_for { let dist = num_points >> 1; let (twiddles_re, twiddles_im) = if dist >= 8 * 2 { - $generate_twiddles_simd_fn(dist, direction) + $generate_twiddles_simd_fn(dist, Direction::Forward) } else { - generate_twiddles(dist, direction) + generate_twiddles(dist, Direction::Forward) }; assert_eq!(twiddles_re.len(), twiddles_im.len()); @@ -72,8 +73,6 @@ impl_planner_for!(Planner32, f32, generate_twiddles_simd_32); #[cfg(test)] mod tests { - use utilities::assert_float_closeness; - use super::*; macro_rules! test_no_twiddles { @@ -90,40 +89,4 @@ mod tests { test_no_twiddles!(no_twiddles_64, Planner64); test_no_twiddles!(no_twiddles_32, Planner32); - - macro_rules! forward_mul_inverse_eq_identity { - ($test_name:ident, $planner:ty) => { - #[test] - fn $test_name() { - for i in 3..25 { - let num_points = 1 << i; - let planner_forward = <$planner>::new(num_points, Direction::Forward); - let planner_reverse = <$planner>::new(num_points, Direction::Reverse); - - assert_eq!( - planner_reverse.num_twiddles(), - planner_forward.num_twiddles() - ); - - // (a + ib) (c + id) = ac + iad + ibc - bd - // = ac - bd + i(ad + bc) - planner_forward - .twiddles_re - .iter() - .zip(planner_forward.twiddles_im.iter()) - .zip(planner_reverse.twiddles_re.iter()) - .zip(planner_reverse.twiddles_im) - .for_each(|(((a, b), c), d)| { - let temp_re = a * c - b * d; - let temp_im = a * d + b * c; - assert_float_closeness(temp_re, 1.0, 1e-2); - assert_float_closeness(temp_im, 0.0, 1e-2); - }); - } - } - }; - } - - forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_64, Planner64); - forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_32, Planner32); } diff --git a/src/twiddles.rs b/src/twiddles.rs index b754a11..13dedba 100644 --- a/src/twiddles.rs +++ b/src/twiddles.rs @@ -300,4 +300,49 @@ mod tests { } } } + + macro_rules! forward_mul_inverse_eq_identity { + ($test_name:ident, $generate_twiddles_simd_fn:ident) => { + #[test] + fn $test_name() { + for i in 3..25 { + let num_points = 1 << i; + let dist = num_points >> 1; + + let (fwd_twiddles_re, fwd_twiddles_im) = if dist >= 8 * 2 { + $generate_twiddles_simd_fn(dist, Direction::Forward) + } else { + generate_twiddles(dist, Direction::Forward) + }; + + assert_eq!(fwd_twiddles_re.len(), fwd_twiddles_im.len()); + + let (rev_twiddles_re, rev_twiddles_im) = if dist >= 8 * 2 { + $generate_twiddles_simd_fn(dist, Direction::Reverse) + } else { + generate_twiddles(dist, Direction::Reverse) + }; + + assert_eq!(rev_twiddles_re.len(), rev_twiddles_im.len()); + + // (a + ib) (c + id) = ac + iad + ibc - bd + // = ac - bd + i(ad + bc) + fwd_twiddles_re + .iter() + .zip(fwd_twiddles_im.iter()) + .zip(rev_twiddles_re.iter()) + .zip(rev_twiddles_im.iter()) + .for_each(|(((a, b), c), d)| { + let temp_re = a * c - b * d; + let temp_im = a * d + b * c; + assert_float_closeness(temp_re, 1.0, 1e-2); + assert_float_closeness(temp_im, 0.0, 1e-2); + }); + } + } + }; + } + + forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_64, generate_twiddles_simd_64); + forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_32, generate_twiddles_simd_32); } diff --git a/utilities/Cargo.toml b/utilities/Cargo.toml index e4faada..4bd855b 100644 --- a/utilities/Cargo.toml +++ b/utilities/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "utilities" -version = "0.1.0" +version = "0.2.1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html