diff --git a/src/kernels.rs b/src/kernels.rs index 7e61c29..a02f77a 100644 --- a/src/kernels.rs +++ b/src/kernels.rs @@ -1,4 +1,4 @@ -use std::simd::{f32x8, f64x8}; +use std::simd::{f32x16, f64x8}; use num_traits::Float; @@ -12,7 +12,7 @@ macro_rules! fft_butterfly_n_simd { dist: usize, ) { let chunk_size = dist << 1; - assert!(chunk_size >= 16); + assert!(chunk_size >= $lanes * 2); reals .chunks_exact_mut(chunk_size) .zip(imags.chunks_exact_mut(chunk_size)) @@ -49,7 +49,7 @@ macro_rules! fft_butterfly_n_simd { } fft_butterfly_n_simd!(fft_64_chunk_n_simd, f64, 8, f64x8); -fft_butterfly_n_simd!(fft_32_chunk_n_simd, f32, 8, f32x8); +fft_butterfly_n_simd!(fft_32_chunk_n_simd, f32, 16, f32x16); pub(crate) fn fft_chunk_n( reals: &mut [T], diff --git a/src/lib.rs b/src/lib.rs index 4a5b6fb..eb0c6ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -64,7 +64,7 @@ impl_fft_for!(fft_64, f64, Planner64, fft_64_with_opts_and_plan); impl_fft_for!(fft_32, f32, Planner32, fft_32_with_opts_and_plan); macro_rules! impl_fft_with_opts_and_plan_for { - ($func_name:ident, $precision:ty, $planner:ty, $simd_butterfly_kernel:ident) => { + ($func_name:ident, $precision:ty, $planner:ty, $simd_butterfly_kernel:ident, $lanes:literal) => { /// Same as [fft], but also accepts [`Options`] that control optimization strategies, as well as /// a [`Planner`] in the case that this FFT will need to be run multiple times. /// @@ -102,7 +102,7 @@ macro_rules! impl_fft_with_opts_and_plan_for { if t < n - 1 { filter_twiddles(twiddles_re, twiddles_im); } - if chunk_size >= 16 { + if chunk_size >= $lanes * 2 { $simd_butterfly_kernel(reals, imags, twiddles_re, twiddles_im, dist); } else { fft_chunk_n(reals, imags, twiddles_re, twiddles_im, dist); @@ -131,14 +131,16 @@ impl_fft_with_opts_and_plan_for!( fft_64_with_opts_and_plan, f64, Planner64, - fft_64_chunk_n_simd + fft_64_chunk_n_simd, + 8 ); impl_fft_with_opts_and_plan_for!( fft_32_with_opts_and_plan, f32, Planner32, - fft_32_chunk_n_simd + fft_32_chunk_n_simd, + 16 ); #[cfg(test)] @@ -146,8 +148,8 @@ mod tests { use std::ops::Range; use utilities::assert_float_closeness; - use utilities::rustfft::num_complex::Complex; use utilities::rustfft::FftPlanner; + use utilities::rustfft::num_complex::Complex; use super::*;