Skip to content

Commit

Permalink
Merge branch 'main' into feature/r2c
Browse files Browse the repository at this point in the history
  • Loading branch information
smu160 committed May 11, 2024
2 parents 5159c10 + 7e8176a commit 443d6e2
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 46 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "phastft"
version = "0.2.0"
version = "0.2.1"
edition = "2021"
authors = ["Saveliy Yusufov", "Shnatsel"]
license = "MIT OR Apache-2.0"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Transform (FFT) library written in pure Rust.
- Performance on par with other Rust FFT implementations
- Zero `unsafe` code
- Takes advantage of latest CPU features up to and including `AVX-512`, but performs well even without them
- Selects the fastest implementation at runtime. No need for `-C target-cpu=native`!
- Optional parallelization of some steps to 2 threads (with even more planned)
- 2x lower memory usage than [RustFFT](https://crates.io/crates/rustfft/)
- Python bindings (via [PyO3](https://github.com/PyO3/pyo3))
Expand Down
2 changes: 1 addition & 1 deletion pyphastft/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pyphastft"
version = "0.2.0"
version = "0.2.1"
edition = "2021"
authors = ["Saveliy Yusufov", "Shnatsel"]
license = "MIT OR Apache-2.0"
Expand Down
55 changes: 53 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,35 @@ macro_rules! impl_fft_for {
imags.len()
);

let mut planner = <$planner>::new(reals.len(), direction);
let mut planner = <$planner>::new(reals.len(), Direction::Forward);
assert!(
planner.num_twiddles().is_power_of_two()
&& planner.num_twiddles() == reals.len() / 2
);

let opts = Options::guess_options(reals.len());

match direction {
Direction::Reverse => {
for z_im in imags.iter_mut() {
*z_im = -*z_im;
}
}
_ => (),
}

$opts_and_plan(reals, imags, &opts, &mut planner);

match direction {
Direction::Reverse => {
let scaling_factor = (reals.len() as $precision).recip();
for (z_re, z_im) in reals.iter_mut().zip(imags.iter_mut()) {
*z_re *= scaling_factor;
*z_im *= -scaling_factor;
}
}
_ => (),
}
}
};
}
Expand Down Expand Up @@ -156,7 +177,7 @@ impl_fft_with_opts_and_plan_for!(
mod tests {
use std::ops::Range;

use utilities::assert_float_closeness;
use utilities::{assert_float_closeness, gen_random_signal};
use utilities::rustfft::FftPlanner;
use utilities::rustfft::num_complex::Complex;

Expand Down Expand Up @@ -257,4 +278,34 @@ mod tests {

test_fft_correctness!(fft_correctness_32, f32, fft_32, 4, 9);
test_fft_correctness!(fft_correctness_64, f64, fft_64, 4, 17);

#[test]
fn fft_round_trip() {
for i in 4..23 {
let big_n = 1 << i;
let mut reals = vec![0.0; big_n];
let mut imags = vec![0.0; big_n];

gen_random_signal(&mut reals, &mut imags);

let original_reals = reals.clone();
let original_imags = imags.clone();

// Forward FFT
fft_64(&mut reals, &mut imags, Direction::Forward);

// Inverse FFT
fft_64(&mut reals, &mut imags, Direction::Reverse);

// Ensure we get back the original signal within some tolerance
for ((orig_re, orig_im), (res_re, res_im)) in original_reals
.into_iter()
.zip(original_imags.into_iter())
.zip(reals.into_iter().zip(imags.into_iter()))
{
assert_float_closeness(res_re, orig_re, 1e-6);
assert_float_closeness(res_im, orig_im, 1e-6);
}
}
}
}
45 changes: 4 additions & 41 deletions src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::twiddles::{generate_twiddles, generate_twiddles_simd_32, generate_twi

/// Reverse is for running the Inverse Fast Fourier Transform (IFFT)
/// Forward is for running the regular FFT
#[derive(Copy, Clone)]
pub enum Direction {
/// Leave the exponent term in the twiddle factor alone
Forward = 1,
Expand Down Expand Up @@ -34,7 +35,7 @@ macro_rules! impl_planner_for {
/// # Panics
///
/// Panics if `num_points < 1` or if `num_points` is __not__ a power of 2.
pub fn new(num_points: usize, direction: Direction) -> Self {
pub fn new(num_points: usize, _direction: Direction) -> Self {
assert!(num_points > 0 && num_points.is_power_of_two());
if num_points <= 4 {
return Self {
Expand All @@ -46,9 +47,9 @@ macro_rules! impl_planner_for {
let dist = num_points >> 1;

let (twiddles_re, twiddles_im) = if dist >= 8 * 2 {
$generate_twiddles_simd_fn(dist, direction)
$generate_twiddles_simd_fn(dist, Direction::Forward)
} else {
generate_twiddles(dist, direction)
generate_twiddles(dist, Direction::Forward)
};

assert_eq!(twiddles_re.len(), twiddles_im.len());
Expand All @@ -72,8 +73,6 @@ impl_planner_for!(Planner32, f32, generate_twiddles_simd_32);

#[cfg(test)]
mod tests {
use utilities::assert_float_closeness;

use super::*;

macro_rules! test_no_twiddles {
Expand All @@ -90,40 +89,4 @@ mod tests {

test_no_twiddles!(no_twiddles_64, Planner64);
test_no_twiddles!(no_twiddles_32, Planner32);

macro_rules! forward_mul_inverse_eq_identity {
($test_name:ident, $planner:ty) => {
#[test]
fn $test_name() {
for i in 3..25 {
let num_points = 1 << i;
let planner_forward = <$planner>::new(num_points, Direction::Forward);
let planner_reverse = <$planner>::new(num_points, Direction::Reverse);

assert_eq!(
planner_reverse.num_twiddles(),
planner_forward.num_twiddles()
);

// (a + ib) (c + id) = ac + iad + ibc - bd
// = ac - bd + i(ad + bc)
planner_forward
.twiddles_re
.iter()
.zip(planner_forward.twiddles_im.iter())
.zip(planner_reverse.twiddles_re.iter())
.zip(planner_reverse.twiddles_im)
.for_each(|(((a, b), c), d)| {
let temp_re = a * c - b * d;
let temp_im = a * d + b * c;
assert_float_closeness(temp_re, 1.0, 1e-2);
assert_float_closeness(temp_im, 0.0, 1e-2);
});
}
}
};
}

forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_64, Planner64);
forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_32, Planner32);
}
45 changes: 45 additions & 0 deletions src/twiddles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,49 @@ mod tests {
}
}
}

macro_rules! forward_mul_inverse_eq_identity {
($test_name:ident, $generate_twiddles_simd_fn:ident) => {
#[test]
fn $test_name() {
for i in 3..25 {
let num_points = 1 << i;
let dist = num_points >> 1;

let (fwd_twiddles_re, fwd_twiddles_im) = if dist >= 8 * 2 {
$generate_twiddles_simd_fn(dist, Direction::Forward)
} else {
generate_twiddles(dist, Direction::Forward)
};

assert_eq!(fwd_twiddles_re.len(), fwd_twiddles_im.len());

let (rev_twiddles_re, rev_twiddles_im) = if dist >= 8 * 2 {
$generate_twiddles_simd_fn(dist, Direction::Reverse)
} else {
generate_twiddles(dist, Direction::Reverse)
};

assert_eq!(rev_twiddles_re.len(), rev_twiddles_im.len());

// (a + ib) (c + id) = ac + iad + ibc - bd
// = ac - bd + i(ad + bc)
fwd_twiddles_re
.iter()
.zip(fwd_twiddles_im.iter())
.zip(rev_twiddles_re.iter())
.zip(rev_twiddles_im.iter())
.for_each(|(((a, b), c), d)| {
let temp_re = a * c - b * d;
let temp_im = a * d + b * c;
assert_float_closeness(temp_re, 1.0, 1e-2);
assert_float_closeness(temp_im, 0.0, 1e-2);
});
}
}
};
}

forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_64, generate_twiddles_simd_64);
forward_mul_inverse_eq_identity!(forward_reverse_eq_identity_32, generate_twiddles_simd_32);
}
2 changes: 1 addition & 1 deletion utilities/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "utilities"
version = "0.1.0"
version = "0.2.1"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down

0 comments on commit 443d6e2

Please sign in to comment.