From 5806b504296e5d035e2c2a0b544eca26c782434e Mon Sep 17 00:00:00 2001 From: Saveliy Yusufov Date: Thu, 28 Mar 2024 16:04:45 -0400 Subject: [PATCH] Finish adding test macros for f32/f64 - Make float comparison generic for f32/f64 - Add f32/f64 tests for twiddles using macros --- src/lib.rs | 160 +++++++++++++++++++++++++------------------ src/planner.rs | 20 ++---- src/twiddles.rs | 79 +++++++++++---------- utilities/src/lib.rs | 21 ++---- 4 files changed, 146 insertions(+), 134 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 1ac6edf..4a5b6fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -145,81 +145,105 @@ impl_fft_with_opts_and_plan_for!( mod tests { use std::ops::Range; - use utilities::{ - assert_f64_closeness, - rustfft::{num_complex::Complex64, FftPlanner}, - }; + use utilities::assert_float_closeness; + use utilities::rustfft::num_complex::Complex; + use utilities::rustfft::FftPlanner; use super::*; - #[should_panic] - #[test] - fn non_power_of_two_fft() { - let num_points = 5; - - // this test will actually always fail at this stage - let mut planner = Planner64::new(num_points, Direction::Forward); - - let mut reals = vec![0.0; num_points]; - let mut imags = vec![0.0; num_points]; - let opts = Options::guess_options(reals.len()); + macro_rules! non_power_of_2_planner { + ($test_name:ident, $planner:ty) => { + #[should_panic] + #[test] + fn $test_name() { + let num_points = 5; - // but this call should, in principle, panic as well - fft_64_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner); + // this test _should_ always fail at this stage + let _ = <$planner>::new(num_points, Direction::Forward); + } + }; } - // A regression test to make sure the `Planner` is compatible with fft execution. - #[should_panic] - #[test] - fn wrong_num_points_in_planner() { - let n = 16; - let num_points = 1 << n; - - // We purposely set n = 16 and pass it to the planner. - // n = 16 == 2^{4} is clearly a power of two, so the planner won't throw it out. - // However, the call to `fft_with_opts_and_plan` should panic since it tests that the - // size of the generated twiddle factors is half the size of the input. - // In this case, we have an input of size 1024 (used for mp3), but we tell the planner the - // input size is 16. - let mut planner = Planner64::new(n, Direction::Forward); - - let mut reals = vec![0.0; num_points]; - let mut imags = vec![0.0; num_points]; - let opts = Options::guess_options(reals.len()); - - // this call should panic - fft_64_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner); + non_power_of_2_planner!(non_power_of_2_planner_32, Planner32); + non_power_of_2_planner!(non_power_of_2_planner_64, Planner64); + + macro_rules! wrong_num_points_in_planner { + ($test_name:ident, $planner:ty, $fft_with_opts_and_plan:ident) => { + // A regression test to make sure the `Planner` is compatible with fft execution. + #[should_panic] + #[test] + fn $test_name() { + let n = 16; + let num_points = 1 << n; + + // We purposely set n = 16 and pass it to the planner. + // n == 16 == 2^{4} is clearly a power of two, so the planner won't throw it out. + // However, the call to `fft_with_opts_and_plan` should panic since it tests that the + // size of the generated twiddle factors is half the size of the input. + // In this case, we have an input of size 1024 (used for mp3), but we tell the planner the + // input size is 16. + let mut planner = <$planner>::new(n, Direction::Forward); + + let mut reals = vec![0.0; num_points]; + let mut imags = vec![0.0; num_points]; + let opts = Options::guess_options(reals.len()); + + // this call should panic + $fft_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner); + } + }; } - #[test] - fn fft_correctness() { - let range = Range { start: 4, end: 17 }; - - for k in range { - let n: usize = 1 << k; - - let mut reals: Vec<_> = (1..=n).map(|i| i as f64).collect(); - let mut imags: Vec<_> = (1..=n).map(|i| i as f64).collect(); - fft_64(&mut reals, &mut imags, Direction::Forward); - - let mut buffer: Vec = (1..=n) - .map(|i| Complex64::new(i as f64, i as f64)) - .collect(); - - let mut planner = FftPlanner::new(); - let fft = planner.plan_fft_forward(buffer.len()); - fft.process(&mut buffer); - - reals - .iter() - .zip(imags.iter()) - .enumerate() - .for_each(|(i, (z_re, z_im))| { - let expect_re = buffer[i].re; - let expect_im = buffer[i].im; - assert_f64_closeness(*z_re, expect_re, 0.01); - assert_f64_closeness(*z_im, expect_im, 0.01); - }); - } + wrong_num_points_in_planner!( + wrong_num_points_in_planner_32, + Planner32, + fft_32_with_opts_and_plan + ); + wrong_num_points_in_planner!( + wrong_num_points_in_planner_64, + Planner64, + fft_64_with_opts_and_plan + ); + + macro_rules! test_fft_correctness { + ($test_name:ident, $precision:ty, $fft_type:ident, $range_start:literal, $range_end:literal) => { + #[test] + fn $test_name() { + let range = Range { + start: $range_start, + end: $range_end, + }; + + for k in range { + let n: usize = 1 << k; + + let mut reals: Vec<$precision> = (1..=n).map(|i| i as $precision).collect(); + let mut imags: Vec<$precision> = (1..=n).map(|i| i as $precision).collect(); + $fft_type(&mut reals, &mut imags, Direction::Forward); + + let mut buffer: Vec> = (1..=n) + .map(|i| Complex::new(i as $precision, i as $precision)) + .collect(); + + let mut planner = FftPlanner::new(); + let fft = planner.plan_fft_forward(buffer.len()); + fft.process(&mut buffer); + + reals + .iter() + .zip(imags.iter()) + .enumerate() + .for_each(|(i, (z_re, z_im))| { + let expect_re = buffer[i].re; + let expect_im = buffer[i].im; + assert_float_closeness(*z_re, expect_re, 0.01); + assert_float_closeness(*z_im, expect_im, 0.01); + }); + } + } + }; } + + test_fft_correctness!(fft_correctness_32, f32, fft_32, 4, 9); + test_fft_correctness!(fft_correctness_64, f64, fft_64, 4, 17); } diff --git a/src/planner.rs b/src/planner.rs index 8e364c3..548e972 100644 --- a/src/planner.rs +++ b/src/planner.rs @@ -72,7 +72,7 @@ impl_planner_for!(Planner32, f32, generate_twiddles_simd_32); #[cfg(test)] mod tests { - use utilities::{assert_f32_closeness, assert_f64_closeness}; + use utilities::assert_float_closeness; use super::*; @@ -92,7 +92,7 @@ mod tests { test_no_twiddles!(no_twiddles_32, Planner32); macro_rules! forward_mul_inverse_eq_identity { - ($test_name:ident, $planner:ty, $float_check:ident) => { + ($test_name:ident, $planner:ty) => { #[test] fn $test_name() { for i in 3..25 { @@ -116,22 +116,14 @@ mod tests { .for_each(|(((a, b), c), d)| { let temp_re = a * c - b * d; let temp_im = a * d + b * c; - $float_check(temp_re, 1.0, 1e-2); - $float_check(temp_im, 0.0, 1e-2); + assert_float_closeness(temp_re, 1.0, 1e-2); + assert_float_closeness(temp_im, 0.0, 1e-2); }); } } }; } - forward_mul_inverse_eq_identity!( - forward_reverse_eq_identity_64, - Planner64, - assert_f64_closeness - ); - forward_mul_inverse_eq_identity!( - forward_reverse_eq_identity_32, - Planner32, - assert_f32_closeness - ); + forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_64, Planner64); + forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_32, Planner32); } diff --git a/src/twiddles.rs b/src/twiddles.rs index 93c9098..97f31a0 100644 --- a/src/twiddles.rs +++ b/src/twiddles.rs @@ -201,7 +201,7 @@ pub(crate) fn filter_twiddles(twiddles_re: &mut Vec, twiddles_im: & mod tests { use std::f64::consts::FRAC_1_SQRT_2; - use utilities::assert_f64_closeness; + use utilities::assert_float_closeness; use super::*; @@ -212,49 +212,58 @@ mod tests { let (w_re, w_im) = twiddle_iter.next().unwrap(); println!("{w_re} {w_im}"); - assert_f64_closeness(w_re, 1.0, 1e-10); - assert_f64_closeness(w_im, 0.0, 1e-10); + assert_float_closeness(w_re, 1.0, 1e-10); + assert_float_closeness(w_im, 0.0, 1e-10); let (w_re, w_im) = twiddle_iter.next().unwrap(); println!("{w_re} {w_im}"); - assert_f64_closeness(w_re, FRAC_1_SQRT_2, 1e-10); - assert_f64_closeness(w_im, -FRAC_1_SQRT_2, 1e-10); + assert_float_closeness(w_re, FRAC_1_SQRT_2, 1e-10); + assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-10); let (w_re, w_im) = twiddle_iter.next().unwrap(); println!("{w_re} {w_im}"); - assert_f64_closeness(w_re, 0.0, 1e-10); - assert_f64_closeness(w_im, -1.0, 1e-10); + assert_float_closeness(w_re, 0.0, 1e-10); + assert_float_closeness(w_im, -1.0, 1e-10); let (w_re, w_im) = twiddle_iter.next().unwrap(); println!("{w_re} {w_im}"); - assert_f64_closeness(w_re, -FRAC_1_SQRT_2, 1e-10); - assert_f64_closeness(w_im, -FRAC_1_SQRT_2, 1e-10); + assert_float_closeness(w_re, -FRAC_1_SQRT_2, 1e-10); + assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-10); } - #[test] - fn twiddles_simd() { - for n in 4..28 { - let dist = 1 << n; - - let (twiddles_re_ref, twiddles_im_ref) = generate_twiddles(dist, Direction::Forward); - let (twiddles_re, twiddles_im) = generate_twiddles_simd_64(dist, Direction::Forward); - - twiddles_re - .iter() - .zip(twiddles_re_ref.iter()) - .for_each(|(simd, reference)| { - assert_f64_closeness(*simd, *reference, 1e-10); - }); - - twiddles_im - .iter() - .zip(twiddles_im_ref.iter()) - .for_each(|(simd, reference)| { - assert_f64_closeness(*simd, *reference, 1e-10); - }); - } + macro_rules! test_twiddles_simd { + ($test_name:ident, $generate_twiddles_simd:ident, $epsilon:literal) => { + #[test] + fn $test_name() { + for n in 4..25 { + let dist = 1 << n; + + let (twiddles_re_ref, twiddles_im_ref) = + generate_twiddles(dist, Direction::Forward); + let (twiddles_re, twiddles_im) = + $generate_twiddles_simd(dist, Direction::Forward); + + twiddles_re + .iter() + .zip(twiddles_re_ref.iter()) + .for_each(|(simd, reference)| { + assert_float_closeness(*simd, *reference, $epsilon); + }); + + twiddles_im + .iter() + .zip(twiddles_im_ref.iter()) + .for_each(|(simd, reference)| { + assert_float_closeness(*simd, *reference, $epsilon); + }); + } + } + }; } + test_twiddles_simd!(twiddles_simd_32, generate_twiddles_simd_32, 1e-1); + test_twiddles_simd!(twiddles_simd_64, generate_twiddles_simd_64, 1e-10); + #[test] fn twiddles_filter() { // Assume n = 28 @@ -269,8 +278,8 @@ mod tests { for i in 0..dist { let (tw_re, tw_im) = twiddles_iter.next().unwrap(); - assert_f64_closeness(twiddles_re[i], tw_re, 1e-6); - assert_f64_closeness(twiddles_im[i], tw_im, 1e-6); + assert_float_closeness(twiddles_re[i], tw_re, 1e-6); + assert_float_closeness(twiddles_im[i], tw_im, 1e-6); } for t in (0..n - 1).rev() { @@ -285,8 +294,8 @@ mod tests { for i in 0..dist { let (tw_re, tw_im) = twiddles_iter.next().unwrap(); - assert_f64_closeness(twiddles_re[i], tw_re, 1e-6); - assert_f64_closeness(twiddles_im[i], tw_im, 1e-6); + assert_float_closeness(twiddles_re[i], tw_re, 1e-6); + assert_float_closeness(twiddles_im[i], tw_im, 1e-6); } } } diff --git a/utilities/src/lib.rs b/utilities/src/lib.rs index 878048f..f7eff6f 100644 --- a/utilities/src/lib.rs +++ b/utilities/src/lib.rs @@ -1,33 +1,20 @@ pub extern crate rustfft; use std::f64::consts::PI; +use std::fmt::Display; use rand::distributions::Uniform; use rand::prelude::*; +use rustfft::num_traits::Float; -/// Asserts that two f64 numbers are approximately equal. +/// Asserts that two floating-point numbers are approximately equal. /// /// # Panics /// /// Panics if `actual` and `expected` are too far from each other #[allow(dead_code)] #[track_caller] -pub fn assert_f64_closeness(actual: f64, expected: f64, epsilon: f64) { - if (actual - expected).abs() >= epsilon { - panic!( - "Assertion failed: {actual} too far from expected value {expected} (with epsilon {epsilon})", - ); - } -} - -/// Asserts that two f32 numbers are approximately equal. -/// -/// # Panics -/// -/// Panics if `actual` and `expected` are too far from each other -#[allow(dead_code)] -#[track_caller] -pub fn assert_f32_closeness(actual: f32, expected: f32, epsilon: f32) { +pub fn assert_float_closeness(actual: T, expected: T, epsilon: T) { if (actual - expected).abs() >= epsilon { panic!( "Assertion failed: {actual} too far from expected value {expected} (with epsilon {epsilon})",