Skip to content

Commit

Permalink
Finish adding test macros for f32/f64
Browse files Browse the repository at this point in the history
- Make float comparison generic for f32/f64

- Add f32/f64 tests for twiddles using macros
  • Loading branch information
smu160 committed Mar 28, 2024
1 parent 14df45d commit 5806b50
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 134 deletions.
160 changes: 92 additions & 68 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,81 +145,105 @@ impl_fft_with_opts_and_plan_for!(
mod tests {
use std::ops::Range;

use utilities::{
assert_f64_closeness,
rustfft::{num_complex::Complex64, FftPlanner},
};
use utilities::assert_float_closeness;
use utilities::rustfft::num_complex::Complex;
use utilities::rustfft::FftPlanner;

use super::*;

#[should_panic]
#[test]
fn non_power_of_two_fft() {
let num_points = 5;

// this test will actually always fail at this stage
let mut planner = Planner64::new(num_points, Direction::Forward);

let mut reals = vec![0.0; num_points];
let mut imags = vec![0.0; num_points];
let opts = Options::guess_options(reals.len());
macro_rules! non_power_of_2_planner {
($test_name:ident, $planner:ty) => {
#[should_panic]
#[test]
fn $test_name() {
let num_points = 5;

// but this call should, in principle, panic as well
fft_64_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner);
// this test _should_ always fail at this stage
let _ = <$planner>::new(num_points, Direction::Forward);
}
};
}

// A regression test to make sure the `Planner` is compatible with fft execution.
#[should_panic]
#[test]
fn wrong_num_points_in_planner() {
let n = 16;
let num_points = 1 << n;

// We purposely set n = 16 and pass it to the planner.
// n = 16 == 2^{4} is clearly a power of two, so the planner won't throw it out.
// However, the call to `fft_with_opts_and_plan` should panic since it tests that the
// size of the generated twiddle factors is half the size of the input.
// In this case, we have an input of size 1024 (used for mp3), but we tell the planner the
// input size is 16.
let mut planner = Planner64::new(n, Direction::Forward);

let mut reals = vec![0.0; num_points];
let mut imags = vec![0.0; num_points];
let opts = Options::guess_options(reals.len());

// this call should panic
fft_64_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner);
non_power_of_2_planner!(non_power_of_2_planner_32, Planner32);
non_power_of_2_planner!(non_power_of_2_planner_64, Planner64);

macro_rules! wrong_num_points_in_planner {
($test_name:ident, $planner:ty, $fft_with_opts_and_plan:ident) => {
// A regression test to make sure the `Planner` is compatible with fft execution.
#[should_panic]
#[test]
fn $test_name() {
let n = 16;
let num_points = 1 << n;

// We purposely set n = 16 and pass it to the planner.
// n == 16 == 2^{4} is clearly a power of two, so the planner won't throw it out.
// However, the call to `fft_with_opts_and_plan` should panic since it tests that the
// size of the generated twiddle factors is half the size of the input.
// In this case, we have an input of size 1024 (used for mp3), but we tell the planner the
// input size is 16.
let mut planner = <$planner>::new(n, Direction::Forward);

let mut reals = vec![0.0; num_points];
let mut imags = vec![0.0; num_points];
let opts = Options::guess_options(reals.len());

// this call should panic
$fft_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner);
}
};
}

#[test]
fn fft_correctness() {
let range = Range { start: 4, end: 17 };

for k in range {
let n: usize = 1 << k;

let mut reals: Vec<_> = (1..=n).map(|i| i as f64).collect();
let mut imags: Vec<_> = (1..=n).map(|i| i as f64).collect();
fft_64(&mut reals, &mut imags, Direction::Forward);

let mut buffer: Vec<Complex64> = (1..=n)
.map(|i| Complex64::new(i as f64, i as f64))
.collect();

let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(buffer.len());
fft.process(&mut buffer);

reals
.iter()
.zip(imags.iter())
.enumerate()
.for_each(|(i, (z_re, z_im))| {
let expect_re = buffer[i].re;
let expect_im = buffer[i].im;
assert_f64_closeness(*z_re, expect_re, 0.01);
assert_f64_closeness(*z_im, expect_im, 0.01);
});
}
wrong_num_points_in_planner!(
wrong_num_points_in_planner_32,
Planner32,
fft_32_with_opts_and_plan
);
wrong_num_points_in_planner!(
wrong_num_points_in_planner_64,
Planner64,
fft_64_with_opts_and_plan
);

macro_rules! test_fft_correctness {
($test_name:ident, $precision:ty, $fft_type:ident, $range_start:literal, $range_end:literal) => {
#[test]
fn $test_name() {
let range = Range {
start: $range_start,
end: $range_end,
};

for k in range {
let n: usize = 1 << k;

let mut reals: Vec<$precision> = (1..=n).map(|i| i as $precision).collect();
let mut imags: Vec<$precision> = (1..=n).map(|i| i as $precision).collect();
$fft_type(&mut reals, &mut imags, Direction::Forward);

let mut buffer: Vec<Complex<$precision>> = (1..=n)
.map(|i| Complex::new(i as $precision, i as $precision))
.collect();

let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(buffer.len());
fft.process(&mut buffer);

reals
.iter()
.zip(imags.iter())
.enumerate()
.for_each(|(i, (z_re, z_im))| {
let expect_re = buffer[i].re;
let expect_im = buffer[i].im;
assert_float_closeness(*z_re, expect_re, 0.01);
assert_float_closeness(*z_im, expect_im, 0.01);
});
}
}
};
}

test_fft_correctness!(fft_correctness_32, f32, fft_32, 4, 9);
test_fft_correctness!(fft_correctness_64, f64, fft_64, 4, 17);
}
20 changes: 6 additions & 14 deletions src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl_planner_for!(Planner32, f32, generate_twiddles_simd_32);

#[cfg(test)]
mod tests {
use utilities::{assert_f32_closeness, assert_f64_closeness};
use utilities::assert_float_closeness;

use super::*;

Expand All @@ -92,7 +92,7 @@ mod tests {
test_no_twiddles!(no_twiddles_32, Planner32);

macro_rules! forward_mul_inverse_eq_identity {
($test_name:ident, $planner:ty, $float_check:ident) => {
($test_name:ident, $planner:ty) => {
#[test]
fn $test_name() {
for i in 3..25 {
Expand All @@ -116,22 +116,14 @@ mod tests {
.for_each(|(((a, b), c), d)| {
let temp_re = a * c - b * d;
let temp_im = a * d + b * c;
$float_check(temp_re, 1.0, 1e-2);
$float_check(temp_im, 0.0, 1e-2);
assert_float_closeness(temp_re, 1.0, 1e-2);
assert_float_closeness(temp_im, 0.0, 1e-2);
});
}
}
};
}

forward_mul_inverse_eq_identity!(
forward_reverse_eq_identity_64,
Planner64,
assert_f64_closeness
);
forward_mul_inverse_eq_identity!(
forward_reverse_eq_identity_32,
Planner32,
assert_f32_closeness
);
forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_64, Planner64);
forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_32, Planner32);
}
79 changes: 44 additions & 35 deletions src/twiddles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ pub(crate) fn filter_twiddles<T: Float>(twiddles_re: &mut Vec<T>, twiddles_im: &
mod tests {
use std::f64::consts::FRAC_1_SQRT_2;

use utilities::assert_f64_closeness;
use utilities::assert_float_closeness;

use super::*;

Expand All @@ -212,49 +212,58 @@ mod tests {

let (w_re, w_im) = twiddle_iter.next().unwrap();
println!("{w_re} {w_im}");
assert_f64_closeness(w_re, 1.0, 1e-10);
assert_f64_closeness(w_im, 0.0, 1e-10);
assert_float_closeness(w_re, 1.0, 1e-10);
assert_float_closeness(w_im, 0.0, 1e-10);

let (w_re, w_im) = twiddle_iter.next().unwrap();
println!("{w_re} {w_im}");
assert_f64_closeness(w_re, FRAC_1_SQRT_2, 1e-10);
assert_f64_closeness(w_im, -FRAC_1_SQRT_2, 1e-10);
assert_float_closeness(w_re, FRAC_1_SQRT_2, 1e-10);
assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-10);

let (w_re, w_im) = twiddle_iter.next().unwrap();
println!("{w_re} {w_im}");
assert_f64_closeness(w_re, 0.0, 1e-10);
assert_f64_closeness(w_im, -1.0, 1e-10);
assert_float_closeness(w_re, 0.0, 1e-10);
assert_float_closeness(w_im, -1.0, 1e-10);

let (w_re, w_im) = twiddle_iter.next().unwrap();
println!("{w_re} {w_im}");
assert_f64_closeness(w_re, -FRAC_1_SQRT_2, 1e-10);
assert_f64_closeness(w_im, -FRAC_1_SQRT_2, 1e-10);
assert_float_closeness(w_re, -FRAC_1_SQRT_2, 1e-10);
assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-10);
}

#[test]
fn twiddles_simd() {
for n in 4..28 {
let dist = 1 << n;

let (twiddles_re_ref, twiddles_im_ref) = generate_twiddles(dist, Direction::Forward);
let (twiddles_re, twiddles_im) = generate_twiddles_simd_64(dist, Direction::Forward);

twiddles_re
.iter()
.zip(twiddles_re_ref.iter())
.for_each(|(simd, reference)| {
assert_f64_closeness(*simd, *reference, 1e-10);
});

twiddles_im
.iter()
.zip(twiddles_im_ref.iter())
.for_each(|(simd, reference)| {
assert_f64_closeness(*simd, *reference, 1e-10);
});
}
macro_rules! test_twiddles_simd {
($test_name:ident, $generate_twiddles_simd:ident, $epsilon:literal) => {
#[test]
fn $test_name() {
for n in 4..25 {
let dist = 1 << n;

let (twiddles_re_ref, twiddles_im_ref) =
generate_twiddles(dist, Direction::Forward);
let (twiddles_re, twiddles_im) =
$generate_twiddles_simd(dist, Direction::Forward);

twiddles_re
.iter()
.zip(twiddles_re_ref.iter())
.for_each(|(simd, reference)| {
assert_float_closeness(*simd, *reference, $epsilon);
});

twiddles_im
.iter()
.zip(twiddles_im_ref.iter())
.for_each(|(simd, reference)| {
assert_float_closeness(*simd, *reference, $epsilon);
});
}
}
};
}

test_twiddles_simd!(twiddles_simd_32, generate_twiddles_simd_32, 1e-1);
test_twiddles_simd!(twiddles_simd_64, generate_twiddles_simd_64, 1e-10);

#[test]
fn twiddles_filter() {
// Assume n = 28
Expand All @@ -269,8 +278,8 @@ mod tests {

for i in 0..dist {
let (tw_re, tw_im) = twiddles_iter.next().unwrap();
assert_f64_closeness(twiddles_re[i], tw_re, 1e-6);
assert_f64_closeness(twiddles_im[i], tw_im, 1e-6);
assert_float_closeness(twiddles_re[i], tw_re, 1e-6);
assert_float_closeness(twiddles_im[i], tw_im, 1e-6);
}

for t in (0..n - 1).rev() {
Expand All @@ -285,8 +294,8 @@ mod tests {

for i in 0..dist {
let (tw_re, tw_im) = twiddles_iter.next().unwrap();
assert_f64_closeness(twiddles_re[i], tw_re, 1e-6);
assert_f64_closeness(twiddles_im[i], tw_im, 1e-6);
assert_float_closeness(twiddles_re[i], tw_re, 1e-6);
assert_float_closeness(twiddles_im[i], tw_im, 1e-6);
}
}
}
Expand Down
21 changes: 4 additions & 17 deletions utilities/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,20 @@
pub extern crate rustfft;

use std::f64::consts::PI;
use std::fmt::Display;

use rand::distributions::Uniform;
use rand::prelude::*;
use rustfft::num_traits::Float;

/// Asserts that two f64 numbers are approximately equal.
/// Asserts that two floating-point numbers are approximately equal.
///
/// # Panics
///
/// Panics if `actual` and `expected` are too far from each other
#[allow(dead_code)]
#[track_caller]
pub fn assert_f64_closeness(actual: f64, expected: f64, epsilon: f64) {
if (actual - expected).abs() >= epsilon {
panic!(
"Assertion failed: {actual} too far from expected value {expected} (with epsilon {epsilon})",
);
}
}

/// Asserts that two f32 numbers are approximately equal.
///
/// # Panics
///
/// Panics if `actual` and `expected` are too far from each other
#[allow(dead_code)]
#[track_caller]
pub fn assert_f32_closeness(actual: f32, expected: f32, epsilon: f32) {
pub fn assert_float_closeness<T: Float + Display>(actual: T, expected: T, epsilon: T) {
if (actual - expected).abs() >= epsilon {
panic!(
"Assertion failed: {actual} too far from expected value {expected} (with epsilon {epsilon})",
Expand Down

0 comments on commit 5806b50

Please sign in to comment.