From 9699f9ee967b1c318f173eb1e39be593c5cefa09 Mon Sep 17 00:00:00 2001 From: Saveliy Yusufov Date: Wed, 22 May 2024 12:32:52 -0400 Subject: [PATCH] Vectorize deinterleaving of AoS --> SoA Use bytemuck + SIMD::deinterleave to rearrange input data from a slice of Complex values into 2 slices of f32 or f64 values --- Cargo.toml | 5 +-- src/lib.rs | 103 ++++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 86 insertions(+), 22 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2bc3158..e3ab005 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,11 +13,12 @@ exclude = ["assets", "scripts", "benches"] [dependencies] num-traits = "0.2.18" multiversion = "0.7" -num-complex = { version = "0.4.6", optional = true } +num-complex = { version = "0.4.6", features = ["bytemuck"], optional = true } +bytemuck = { version = "1.16.0", optional = true } [features] default = [] -complex-nums = ["dep:num-complex"] +complex-nums = ["dep:num-complex", "dep:bytemuck"] [dev-dependencies] utilities = { path = "utilities" } diff --git a/src/lib.rs b/src/lib.rs index 383dbd7..da93cdd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,20 @@ #![doc = include_str!("../README.md")] -#![warn(clippy::complexity)] -#![warn(missing_docs)] -#![warn(clippy::style)] -#![warn(clippy::correctness)] -#![warn(clippy::suspicious)] -#![warn(clippy::perf)] +#![warn( + missing_docs, + clippy::complexity, + clippy::perf, + clippy::style, + clippy::correctness, + clippy::suspicious +)] #![forbid(unsafe_code)] #![feature(portable_simd, avx512_target_feature)] +#[cfg(feature = "complex-nums")] +use std::simd::{f32x16, f64x8}; + +#[cfg(feature = "complex-nums")] +use bytemuck::cast_slice; #[cfg(feature = "complex-nums")] use num_complex::Complex; #[cfg(feature = "complex-nums")] @@ -90,9 +97,9 @@ impl_fft_for!(fft_32, f32, Planner32, fft_32_with_opts_and_plan); #[cfg(feature = "complex-nums")] macro_rules! impl_fft_interleaved_for { - ($func_name:ident, $precision:ty, $fft_func:ident) => { + ($func_name:ident, $precision:ty, $fft_func:ident, $deinterleaving_func: ident) => { /// FFT Interleaved -- this is an alternative to [`fft_64`]/[`fft_32`] in the case where - /// the input data is a array of [`num_complex::Complex`]. + /// the input data is a array of [`Complex`]. /// /// The input should be provided in normal order, and then the modified input is /// bit-reversed. @@ -100,7 +107,7 @@ macro_rules! impl_fft_interleaved_for { /// ## References /// pub fn $func_name(signal: &mut [Complex<$precision>], direction: Direction) { - let (mut reals, mut imags) = separate_re_im(signal); + let (mut reals, mut imags) = $deinterleaving_func(signal); $fft_func(&mut reals, &mut imags, direction); signal.copy_from_slice(&combine_re_im(&reals, &imags)) } @@ -108,9 +115,9 @@ macro_rules! impl_fft_interleaved_for { } #[cfg(feature = "complex-nums")] -impl_fft_interleaved_for!(fft_32_interleaved, f32, fft_32); +impl_fft_interleaved_for!(fft_32_interleaved, f32, fft_32, separate_re_im_f32); #[cfg(feature = "complex-nums")] -impl_fft_interleaved_for!(fft_64_interleaved, f64, fft_64); +impl_fft_interleaved_for!(fft_64_interleaved, f64, fft_64, separate_re_im_f64); macro_rules! impl_fft_with_opts_and_plan_for { ($func_name:ident, $precision:ty, $planner:ty, $simd_butterfly_kernel:ident, $lanes:literal) => { @@ -201,13 +208,69 @@ impl_fft_with_opts_and_plan_for!( 16 ); -/// Utility function to separate interleaved format signals (i.e., Vector of Complex Number Structs) -/// into separate vectors for the corresponding real and imaginary components. -#[cfg(feature = "complex-nums")] -pub fn separate_re_im(signal: &[Complex]) -> (Vec, Vec) { - signal.iter().map(|z| (z.re, z.im)).unzip() +macro_rules! impl_separate_re_im { + ($func_name:ident, $precision:ty, $lanes:literal, $simd_vec:ty) => { + /// Utility function to separate interleaved format signals (i.e., Vector of Complex Number Structs) + /// into separate vectors for the corresponding real and imaginary components. + #[multiversion::multiversion( + targets("x86_64+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", // x86_64-v4 + "x86_64+avx2+fma", // x86_64-v3 + "x86_64+sse4.2", // x86_64-v2 + "x86+avx512f+avx512bw+avx512cd+avx512dq+avx512vl", + "x86+avx2+fma", + "x86+sse4.2", + "x86+sse2", + ))] + pub fn $func_name( + signal: &[Complex<$precision>], + ) -> (Vec<$precision>, Vec<$precision>) { + let n = signal.len(); + let mut reals = vec![0.0; n]; + let mut imags = vec![0.0; n]; + + let complex_f64 = cast_slice(signal); + const CHUNK_SIZE: usize = $lanes * 2; + + let mut i = 0; + for ((chunk, chunk_re), chunk_im) in complex_f64 + .chunks_exact(CHUNK_SIZE) + .zip(reals.chunks_exact_mut($lanes)) + .zip(imags.chunks_exact_mut($lanes)) + { + let (first_half, second_half) = chunk.split_at($lanes); + + let a = <$simd_vec>::from_slice(&first_half); + let b = <$simd_vec>::from_slice(&second_half); + let (re_deinterleaved, im_deinterleaved) = a.deinterleave(b); + + chunk_re.copy_from_slice(&re_deinterleaved.to_array()); + chunk_im.copy_from_slice(&im_deinterleaved.to_array()); + i += CHUNK_SIZE; + } + + let remainder = complex_f64.chunks_exact(CHUNK_SIZE).remainder(); + if !remainder.is_empty() { + remainder + .chunks_exact(2) + .zip(reals[i..].iter_mut()) + .zip(imags[i..].iter_mut()) + .for_each(|((c, re), im)| { + *re = c[0]; + *im = c[1]; + }); + } + + (reals, imags) + } + }; } +#[cfg(feature = "complex-nums")] +impl_separate_re_im!(separate_re_im_f32, f32, 16, f32x16); + +#[cfg(feature = "complex-nums")] +impl_separate_re_im!(separate_re_im_f64, f64, 8, f64x8); + /// Utility function to combine separate vectors of real and imaginary components /// into a single vector of Complex Number Structs. /// @@ -238,14 +301,14 @@ mod tests { #[cfg(feature = "complex-nums")] #[test] fn test_separate_and_combine_re_im() { - let complex_vec = vec![ + let complex_vec: Vec<_> = vec![ Complex::new(1.0, 2.0), Complex::new(3.0, 4.0), Complex::new(5.0, 6.0), Complex::new(7.0, 8.0), ]; - let (reals, imags) = separate_re_im(&complex_vec); + let (reals, imags) = separate_re_im_f64(&complex_vec); let recombined_vec = combine_re_im(&reals, &imags); @@ -351,7 +414,7 @@ mod tests { #[cfg(feature = "complex-nums")] #[test] fn fft_interleaved_correctness() { - let n = 4; + let n = 10; let big_n = 1 << n; let mut actual_signal: Vec<_> = (1..=big_n).map(|i| Complex::new(i as f64, 0.0)).collect(); let mut expected_reals: Vec<_> = (1..=big_n).map(|i| i as f64).collect(); @@ -369,7 +432,7 @@ mod tests { assert_float_closeness(z.im, z_im, 1e-10); }); - let n = 4; + let n = 10; let big_n = 1 << n; let mut actual_signal: Vec<_> = (1..=big_n).map(|i| Complex::new(i as f32, 0.0)).collect(); let mut expected_reals: Vec<_> = (1..=big_n).map(|i| i as f32).collect();