Skip to content

Commit

Permalink
Merge pull request #29 from QuState/hotfix
Browse files Browse the repository at this point in the history
Fixes inverse FFT ouput order issue
  • Loading branch information
smu160 authored May 3, 2024
2 parents 48a8d15 + aaa6a89 commit a715f81
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 42 deletions.
51 changes: 50 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,35 @@ macro_rules! impl_fft_for {
imags.len()
);

let mut planner = <$planner>::new(reals.len(), direction);
let mut planner = <$planner>::new(reals.len(), Direction::Forward);
assert!(
planner.num_twiddles().is_power_of_two()
&& planner.num_twiddles() == reals.len() / 2
);

let opts = Options::guess_options(reals.len());

match direction {
Direction::Reverse => {
for z_im in imags.iter_mut() {
*z_im = -*z_im;
}
}
_ => (),
}

$opts_and_plan(reals, imags, &opts, &mut planner);

match direction {
Direction::Reverse => {
let scaling_factor = (reals.len() as $precision).recip();
for (z_re, z_im) in reals.iter_mut().zip(imags.iter_mut()) {
*z_re *= scaling_factor;
*z_im *= -scaling_factor;
}
}
_ => (),
}
}
};
}
Expand Down Expand Up @@ -256,4 +277,32 @@ mod tests {

test_fft_correctness!(fft_correctness_32, f32, fft_32, 4, 9);
test_fft_correctness!(fft_correctness_64, f64, fft_64, 4, 17);

#[test]
fn ifft_using_fft() {
let n = 4;
let big_n = 1 << n;

let mut reals: Vec<_> = (1..=big_n).map(|i| i as f64).collect();
let mut imags: Vec<_> = vec![0.0; big_n];
println!("{:?}", reals);
println!("{:?}\n", imags);

fft_64(&mut reals, &mut imags, Direction::Forward);
println!("{:?}", reals);
println!("{:?}\n", imags);

fft_64(&mut reals, &mut imags, Direction::Reverse);
println!("{:?}", reals);
println!("{:?}", imags);

let mut signal_re = 1.0;

// Now check that the identity is indeed the original signal we generated above
for (z_re, z_im) in reals.into_iter().zip(imags.into_iter()) {
assert_float_closeness(z_re, signal_re, 1e-4);
assert_float_closeness(z_im, 0.0, 1e-4);
signal_re += 1.0;
}
}
}
45 changes: 4 additions & 41 deletions src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::twiddles::{generate_twiddles, generate_twiddles_simd_32, generate_twi

/// Reverse is for running the Inverse Fast Fourier Transform (IFFT)
/// Forward is for running the regular FFT
#[derive(Copy, Clone)]
pub enum Direction {
/// Leave the exponent term in the twiddle factor alone
Forward = 1,
Expand Down Expand Up @@ -34,7 +35,7 @@ macro_rules! impl_planner_for {
/// # Panics
///
/// Panics if `num_points < 1` or if `num_points` is __not__ a power of 2.
pub fn new(num_points: usize, direction: Direction) -> Self {
pub fn new(num_points: usize, _direction: Direction) -> Self {
assert!(num_points > 0 && num_points.is_power_of_two());
if num_points <= 4 {
return Self {
Expand All @@ -46,9 +47,9 @@ macro_rules! impl_planner_for {
let dist = num_points >> 1;

let (twiddles_re, twiddles_im) = if dist >= 8 * 2 {
$generate_twiddles_simd_fn(dist, direction)
$generate_twiddles_simd_fn(dist, Direction::Forward)
} else {
generate_twiddles(dist, direction)
generate_twiddles(dist, Direction::Forward)
};

assert_eq!(twiddles_re.len(), twiddles_im.len());
Expand All @@ -72,8 +73,6 @@ impl_planner_for!(Planner32, f32, generate_twiddles_simd_32);

#[cfg(test)]
mod tests {
use utilities::assert_float_closeness;

use super::*;

macro_rules! test_no_twiddles {
Expand All @@ -90,40 +89,4 @@ mod tests {

test_no_twiddles!(no_twiddles_64, Planner64);
test_no_twiddles!(no_twiddles_32, Planner32);

macro_rules! forward_mul_inverse_eq_identity {
($test_name:ident, $planner:ty) => {
#[test]
fn $test_name() {
for i in 3..25 {
let num_points = 1 << i;
let planner_forward = <$planner>::new(num_points, Direction::Forward);
let planner_reverse = <$planner>::new(num_points, Direction::Reverse);

assert_eq!(
planner_reverse.num_twiddles(),
planner_forward.num_twiddles()
);

// (a + ib) (c + id) = ac + iad + ibc - bd
// = ac - bd + i(ad + bc)
planner_forward
.twiddles_re
.iter()
.zip(planner_forward.twiddles_im.iter())
.zip(planner_reverse.twiddles_re.iter())
.zip(planner_reverse.twiddles_im)
.for_each(|(((a, b), c), d)| {
let temp_re = a * c - b * d;
let temp_im = a * d + b * c;
assert_float_closeness(temp_re, 1.0, 1e-2);
assert_float_closeness(temp_im, 0.0, 1e-2);
});
}
}
};
}

forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_64, Planner64);
forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_32, Planner32);
}
45 changes: 45 additions & 0 deletions src/twiddles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,49 @@ mod tests {
}
}
}

macro_rules! forward_mul_inverse_eq_identity {
($test_name:ident, $generate_twiddles_simd_fn:ident) => {
#[test]
fn $test_name() {
for i in 3..25 {
let num_points = 1 << i;
let dist = num_points >> 1;

let (fwd_twiddles_re, fwd_twiddles_im) = if dist >= 8 * 2 {
$generate_twiddles_simd_fn(dist, Direction::Forward)
} else {
generate_twiddles(dist, Direction::Forward)
};

assert_eq!(fwd_twiddles_re.len(), fwd_twiddles_im.len());

let (rev_twiddles_re, rev_twiddles_im) = if dist >= 8 * 2 {
$generate_twiddles_simd_fn(dist, Direction::Reverse)
} else {
generate_twiddles(dist, Direction::Reverse)
};

assert_eq!(rev_twiddles_re.len(), rev_twiddles_im.len());

// (a + ib) (c + id) = ac + iad + ibc - bd
// = ac - bd + i(ad + bc)
fwd_twiddles_re
.iter()
.zip(fwd_twiddles_im.iter())
.zip(rev_twiddles_re.iter())
.zip(rev_twiddles_im.iter())
.for_each(|(((a, b), c), d)| {
let temp_re = a * c - b * d;
let temp_im = a * d + b * c;
assert_float_closeness(temp_re, 1.0, 1e-2);
assert_float_closeness(temp_im, 0.0, 1e-2);
});
}
}
};
}

forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_64, generate_twiddles_simd_64);
forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_32, generate_twiddles_simd_32);
}

0 comments on commit a715f81

Please sign in to comment.