diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 157d5aa..626f294 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -63,6 +63,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: test + args: --all-features coverage: runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index 5aa2807..bd19297 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,12 @@ exclude = ["assets", "scripts", "benches"] [dependencies] num-traits = "0.2.18" multiversion = "0.7" +num-complex = { version = "0.4.6", features = ["bytemuck"], optional = true } +bytemuck = { version = "1.16.0", optional = true } + +[features] +default = [] +complex-nums = ["dep:num-complex", "dep:bytemuck"] [dev-dependencies] criterion = "0.5.1" @@ -33,3 +39,5 @@ panic = "abort" inherits = "release" debug = true +[package.metadata.docs.rs] +all-features = true \ No newline at end of file diff --git a/benches/README.md b/benches/README.md index 3e84f65..e276e06 100644 --- a/benches/README.md +++ b/benches/README.md @@ -4,24 +4,15 @@ ### Setup Environment -1. Install [FFTW3](http://www.fftw.org/download.html)[^1] +1. Clone the `PhastFT` git repository [^2]. - It may be possible to install `fftw3` using a package manager. - - ##### debian - ```bash - sudo apt install libfftw3-dev - ``` - -2. Clone the `PhastFT` git repository [^2]. - -3. Create virtual env +2. Create virtual env ```bash cd ~/PhastFT/benches && python3 -m venv .env && source .env/bin/activate ``` -4. Install python dependencies[^1] +3. Install python dependencies[^1] ```bash pip install -r requirements.txt @@ -29,10 +20,10 @@ cd ~/PhastFT/pyphastft pip install . ``` -5. Run the `FFTW3` vs. `RustFFT` vs. `PhastFT` benchmark for all inputs of size `2^n`, where `n \in [4, 30].` +5. Run the `FFTW3-RB` vs. `RustFFT` vs. `PhastFT` benchmarks` ```bash -./benchmark.sh 4 29 +python run_benches.py ``` 6. Plot the results @@ -101,13 +92,6 @@ On linux, open access to performance monitoring, and observability operations fo echo -1 | sudo tee /proc/sys/kernel/perf_event_paranoid ``` -Add debug to `Cargo.toml` under `profile.release`: - -```bash -[profile.release] -debug = true -``` - Finally, run: ```bash diff --git a/benches/bench.rs b/benches/bench.rs index 9bf511e..00d7833 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -13,7 +13,7 @@ use utilities::rustfft::num_complex::Complex; use utilities::rustfft::FftPlanner; const LENGTHS: &[usize] = &[ - 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, ]; fn generate_numbers(n: usize) -> (Vec, Vec) @@ -69,7 +69,7 @@ fn benchmark_forward_f32(c: &mut Criterion) { let planner = Planner32::new(len, Direction::Forward); let (mut reals, mut imags) = generate_numbers(len); - group.bench_with_input(BenchmarkId::new(id, len), &len, |b, &len| { + group.bench_with_input(BenchmarkId::new(id, len), &len, |b, &_len| { b.iter(|| { fft_32_with_opts_and_plan( black_box(&mut reals), @@ -85,7 +85,7 @@ fn benchmark_forward_f32(c: &mut Criterion) { let fft = planner.plan_fft_forward(len); let mut signal = generate_complex_numbers(len); - group.bench_with_input(BenchmarkId::new(id, len), &len, |b, &len| { + group.bench_with_input(BenchmarkId::new(id, len), &len, |b, &_len| { b.iter(|| fft.process(black_box(&mut signal))); }); } @@ -103,31 +103,49 @@ fn benchmark_inverse_f32(c: &mut Criterion) { c.bench_function(&id, |b| { let (mut reals, mut imags) = generate_numbers(len); b.iter(|| { - black_box(fft_32_with_opts_and_plan( - &mut reals, &mut imags, &options, &planner, - )); + fft_32_with_opts_and_plan( + black_box(&mut reals), + black_box(&mut imags), + black_box(&options), + black_box(&planner), + ); }); }); } } fn benchmark_forward_f64(c: &mut Criterion) { - let options = Options::default(); + let mut group = c.benchmark_group("Forward f64"); for n in LENGTHS.iter() { let len = 1 << n; - let id = format!("FFT Forward f64 {} elements", len); + let id = "PhastFT FFT Forward"; + let options = Options::guess_options(len); let planner = Planner64::new(len, Direction::Forward); + let (mut reals, mut imags) = generate_numbers(len); + group.throughput(Throughput::Elements(len as u64)); - c.bench_function(&id, |b| { - let (mut reals, mut imags) = generate_numbers(len); + group.bench_with_input(BenchmarkId::new(id, len), &len, |b, &_len| { b.iter(|| { - black_box(fft_64_with_opts_and_plan( - &mut reals, &mut imags, &options, &planner, - )); + fft_64_with_opts_and_plan( + black_box(&mut reals), + black_box(&mut imags), + black_box(&options), + black_box(&planner), + ); }); }); + + let id = "RustFFT FFT Forward"; + let mut planner = FftPlanner::::new(); + let fft = planner.plan_fft_forward(len); + let mut signal = generate_complex_numbers(len); + + group.bench_with_input(BenchmarkId::new(id, len), &len, |b, &_len| { + b.iter(|| fft.process(black_box(&mut signal))); + }); } + group.finish(); } fn benchmark_inverse_f64(c: &mut Criterion) { @@ -141,9 +159,12 @@ fn benchmark_inverse_f64(c: &mut Criterion) { c.bench_function(&id, |b| { let (mut reals, mut imags) = generate_numbers(len); b.iter(|| { - black_box(fft_64_with_opts_and_plan( - &mut reals, &mut imags, &options, &planner, - )); + fft_64_with_opts_and_plan( + black_box(&mut reals), + black_box(&mut imags), + black_box(&options), + black_box(&planner), + ); }); }); } diff --git a/benches/run_benches.py b/benches/run_benches.py new file mode 100644 index 0000000..7c72b19 --- /dev/null +++ b/benches/run_benches.py @@ -0,0 +1,228 @@ +import os +import subprocess +import sys +from pathlib import Path +import matplotlib.pyplot as plt +import shutil +from datetime import datetime +import logging +import numpy as np + +# Configuration +OUTPUT_DIR = "benchmark_output" +HISTORY_DIR = "benchmark_history" +LOG_DIR = "benchmark_logs" +MAX_ITERS = 1 << 10 +START = 6 +END = 20 +STD_THRESHOLD = 0.05 # 5% standard deviation threshold + +# Ensure log directory exists +Path(LOG_DIR).mkdir(parents=True, exist_ok=True) + +# Setup logging +logging.basicConfig( + filename=Path(LOG_DIR) / "benchmark.log", + level=logging.INFO, + format="%(asctime)s - %(message)s", +) +console = logging.StreamHandler() +console.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s - %(message)s") +console.setFormatter(formatter) +logging.getLogger().addHandler(console) + + +def run_command(command, cwd=None): + result = subprocess.run( + command, shell=True, text=True, capture_output=True, cwd=cwd + ) + if result.returncode != 0: + logging.error(f"Error running command: {command}\n{result.stderr}") + sys.exit(result.returncode) + return result.stdout.strip() + + +def clean_build_rust(): + logging.info("Cleaning and building Rust project...") + run_command("cargo clean") + run_command("cargo build --release --examples") + + +def benchmark_with_stabilization(executable_name, n, max_iters, std_threshold): + times = [] + for i in range(max_iters): + result = run_command(f"../target/release/examples/{executable_name} {n}") + times.append(int(result)) + if len(times) > 10: # Start evaluating after a minimum number of runs + current_std = np.std(times) / np.mean(times) + if current_std < std_threshold: + break + return times + + +def benchmark( + benchmark_name, output_subdir, start, end, max_iters, std_threshold, executable_name +): + output_dir_path = Path(OUTPUT_DIR) / output_subdir + output_dir_path.mkdir(parents=True, exist_ok=True) + + for n in range(start, end + 1): + logging.info( + f"Running {benchmark_name} benchmark for N = 2^{n} with a standard deviation threshold of {std_threshold * 100}%..." + ) + times = benchmark_with_stabilization( + executable_name, n, max_iters, std_threshold + ) + output_file = output_dir_path / f"size_{n}" + with open(output_file, "w") as f: + for time in times: + f.write(f"{time}\n") + logging.info( + f"Completed N = 2^{n} in {len(times)} iterations with a final standard deviation of {np.std(times) / np.mean(times):.2%}" + ) + + +def read_benchmark_results(output_dir, start, end): + sizes = [] + times = [] + + for n in range(start, end + 1): + size_file = Path(output_dir) / f"size_{n}" + if size_file.exists(): + with open(size_file, "r") as f: + data = f.readlines() + data = [int(line.strip()) for line in data] + if data: + min_time_ns = min(data) + sizes.append(2**n) + times.append(min_time_ns) + else: + logging.warning(f"No data found in file: {size_file}") + else: + logging.warning(f"File does not exist: {size_file}") + + return sizes, times + + +def plot_benchmark_results(output_subdirs, start, end, history_dirs=[]): + plt.figure(figsize=(10, 6)) + has_data = False + + # Plot current results + for subdir in output_subdirs: + sizes, times = read_benchmark_results(Path(OUTPUT_DIR) / subdir, start, end) + if sizes and times: + has_data = True + plt.plot(sizes, times, marker="o", label=f"current {subdir}") + + # Plot previous results from history for PhastFT + for history_dir in history_dirs: + sizes, times = read_benchmark_results( + Path(history_dir) / "benchmark_output" / "phastft", start, end + ) + if sizes and times: + has_data = True + timestamp = Path(history_dir).stem + plt.plot( + sizes, times, marker="x", linestyle="--", label=f"{timestamp} phastft" + ) + + if has_data: + plt.title("Benchmark Results") + plt.xlabel("FFT Size (N)") + plt.ylabel("Minimum Time (ns)") + plt.xscale("log") + plt.yscale("log") + plt.grid(True, which="both", ls="--") + plt.legend() + plt.savefig(f"{OUTPUT_DIR}/benchmark_results.png", dpi=600) + # plt.show() + else: + logging.warning("No data available to plot.") + + +def compare_results(current_dir, previous_dir, start, end): + changes = {} + for n in range(start, end + 1): + current_file = Path(current_dir) / f"size_{n}" + previous_file = ( + Path(previous_dir) / "benchmark_output" / "phastft" / f"size_{n}" + ) + + if current_file.exists() and previous_file.exists(): + with open(current_file, "r") as cf, open(previous_file, "r") as pf: + current_data = [int(line.strip()) for line in cf.readlines()] + previous_data = [int(line.strip()) for line in pf.readlines()] + + if current_data and previous_data: + current_min = min(current_data) + previous_min = min(previous_data) + + if current_min != previous_min: + change = ((current_min - previous_min) / previous_min) * 100 + changes[n] = change + else: + logging.warning( + f"Data missing in files for size 2^{n}: Current data length: {len(current_data)}, Previous data length: {len(previous_data)}" + ) + else: + logging.warning( + f"Missing files for size 2^{n}: Current file exists: {current_file.exists()}, Previous file exists: {previous_file.exists()}" + ) + + return changes + + +def archive_current_results(): + if Path(OUTPUT_DIR).exists(): + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + history_dir = Path(HISTORY_DIR) / timestamp + history_dir.mkdir(parents=True, exist_ok=True) + shutil.move(OUTPUT_DIR, history_dir) + logging.info(f"Archived current results to: {history_dir}") + else: + logging.warning( + f"Output directory '{OUTPUT_DIR}' does not exist and cannot be archived." + ) + + +def main(): + clean_build_rust() + + # Check if there are previous results for comparison + history_dirs = ( + sorted(Path(HISTORY_DIR).iterdir(), key=os.path.getmtime) + if Path(HISTORY_DIR).exists() + else [] + ) + latest_previous_dir = history_dirs[-1] if history_dirs else None + + # Run new benchmarks for PhastFT, RustFFT, and FFTW3 + benchmark("PhastFT", "phastft", START, END, MAX_ITERS, STD_THRESHOLD, "benchmark") + benchmark("RustFFT", "rustfft", START, END, MAX_ITERS, STD_THRESHOLD, "rustfft") + benchmark( + "FFTW3 Rust bindings", "fftwrb", START, END, MAX_ITERS, STD_THRESHOLD, "fftwrb" + ) + + # Compare new PhastFT benchmarks against previous results + if latest_previous_dir: + logging.info(f"Comparing with previous results from: {latest_previous_dir}") + changes = compare_results( + Path(OUTPUT_DIR) / "phastft", latest_previous_dir, START, END + ) + for n, change in changes.items(): + status = "improvement" if change < 0 else "regression" + logging.info(f"N = 2^{n}: {abs(change):.2f}% {status}") + else: + logging.info("No previous results found for comparison.") + + # Plot benchmark results + plot_benchmark_results(["phastft", "rustfft", "fftwrb"], START, END, history_dirs) + + # Archive current results + archive_current_results() + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark.rs b/examples/benchmark.rs index e6ce897..0fe62bf 100644 --- a/examples/benchmark.rs +++ b/examples/benchmark.rs @@ -18,7 +18,7 @@ fn benchmark_fft_64(n: usize) { let now = std::time::Instant::now(); fft_64_with_opts_and_plan(&mut reals, &mut imags, &opts, &planner); - let elapsed = now.elapsed().as_micros(); + let elapsed = now.elapsed().as_nanos(); println!("{elapsed}"); } diff --git a/examples/fftwrb.rs b/examples/fftwrb.rs index db4beee..3b1bf5a 100644 --- a/examples/fftwrb.rs +++ b/examples/fftwrb.rs @@ -34,7 +34,7 @@ fn benchmark_fftw(n: usize) { &mut nums, ) .unwrap(); - let elapsed = now.elapsed().as_micros(); + let elapsed = now.elapsed().as_nanos(); println!("{elapsed}"); } diff --git a/examples/rustfft.rs b/examples/rustfft.rs index 8e9ac51..6f40f80 100644 --- a/examples/rustfft.rs +++ b/examples/rustfft.rs @@ -28,7 +28,7 @@ fn benchmark_rustfft(n: usize) { let now = std::time::Instant::now(); fft.process(&mut signal); - let elapsed = now.elapsed().as_micros(); + let elapsed = now.elapsed().as_nanos(); println!("{elapsed}"); } diff --git a/hooks/pre-commit b/hooks/pre-commit index 6d3f529..8f289a5 100755 --- a/hooks/pre-commit +++ b/hooks/pre-commit @@ -9,13 +9,13 @@ then exit 1 fi -if ! cargo clippy -- -D warnings +if ! cargo clippy --all-targets --all-features --tests -- -D warnings then echo "There are some clippy issues." exit 1 fi -if ! cargo test +if ! cargo test --all-features then echo "There are some test issues." exit 1 diff --git a/src/kernels.rs b/src/kernels.rs index 8651ee9..a1fe09c 100644 --- a/src/kernels.rs +++ b/src/kernels.rs @@ -119,15 +119,15 @@ pub(crate) fn fft_chunk_n( ))] #[inline] pub(crate) fn fft_chunk_4(reals: &mut [T], imags: &mut [T]) { - let dist = 2; - let chunk_size = dist << 1; + const DIST: usize = 2; + const CHUNK_SIZE: usize = DIST << 1; reals - .chunks_exact_mut(chunk_size) - .zip(imags.chunks_exact_mut(chunk_size)) + .chunks_exact_mut(CHUNK_SIZE) + .zip(imags.chunks_exact_mut(CHUNK_SIZE)) .for_each(|(reals_chunk, imags_chunk)| { - let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist); - let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist); + let (reals_s0, reals_s1) = reals_chunk.split_at_mut(DIST); + let (imags_s0, imags_s1) = imags_chunk.split_at_mut(DIST); let real_c0 = reals_s0[0]; let real_c1 = reals_s1[0]; diff --git a/src/lib.rs b/src/lib.rs index 2ece7ec..8f03b6d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,20 @@ #![doc = include_str!("../README.md")] -#![warn(clippy::complexity)] -#![warn(missing_docs)] -#![warn(clippy::style)] -#![warn(clippy::correctness)] -#![warn(clippy::suspicious)] -#![warn(clippy::perf)] +#![warn( + missing_docs, + clippy::complexity, + clippy::perf, + clippy::style, + clippy::correctness, + clippy::suspicious +)] #![forbid(unsafe_code)] #![feature(portable_simd, avx512_target_feature)] +#[cfg(feature = "complex-nums")] +use crate::utils::{combine_re_im, deinterleave_complex32, deinterleave_complex64}; +#[cfg(feature = "complex-nums")] +use num_complex::Complex; + use crate::cobra::cobra_apply; use crate::kernels::{ fft_32_chunk_n_simd, fft_64_chunk_n_simd, fft_chunk_2, fft_chunk_4, fft_chunk_n, @@ -64,6 +71,30 @@ macro_rules! impl_fft_for { impl_fft_for!(fft_64, f64, Planner64, fft_64_with_opts_and_plan); impl_fft_for!(fft_32, f32, Planner32, fft_32_with_opts_and_plan); +#[cfg(feature = "complex-nums")] +macro_rules! impl_fft_interleaved_for { + ($func_name:ident, $precision:ty, $fft_func:ident, $deinterleaving_func: ident) => { + /// FFT Interleaved -- this is an alternative to [`fft_64`]/[`fft_32`] in the case where + /// the input data is a array of [`Complex`]. + /// + /// The input should be provided in normal order, and then the modified input is + /// bit-reversed. + /// + /// ## References + /// + pub fn $func_name(signal: &mut [Complex<$precision>], direction: Direction) { + let (mut reals, mut imags) = $deinterleaving_func(signal); + $fft_func(&mut reals, &mut imags, direction); + signal.copy_from_slice(&combine_re_im(&reals, &imags)) + } + }; +} + +#[cfg(feature = "complex-nums")] +impl_fft_interleaved_for!(fft_32_interleaved, f32, fft_32, deinterleave_complex32); +#[cfg(feature = "complex-nums")] +impl_fft_interleaved_for!(fft_64_interleaved, f64, fft_64, deinterleave_complex64); + macro_rules! impl_fft_with_opts_and_plan_for { ($func_name:ident, $precision:ty, $planner:ty, $simd_butterfly_kernel:ident, $lanes:literal) => { /// Same as [fft], but also accepts [`Options`] that control optimization strategies, as well as @@ -300,6 +331,46 @@ mod tests { test_fft_correctness!(fft_correctness_32, f32, fft_32, 4, 9); test_fft_correctness!(fft_correctness_64, f64, fft_64, 4, 17); + #[cfg(feature = "complex-nums")] + #[test] + fn fft_interleaved_correctness() { + let n = 10; + let big_n = 1 << n; + let mut actual_signal: Vec<_> = (1..=big_n).map(|i| Complex::new(i as f64, 0.0)).collect(); + let mut expected_reals: Vec<_> = (1..=big_n).map(|i| i as f64).collect(); + let mut expected_imags = vec![0.0; big_n]; + + fft_64_interleaved(&mut actual_signal, Direction::Forward); + fft_64(&mut expected_reals, &mut expected_imags, Direction::Forward); + + actual_signal + .iter() + .zip(expected_reals) + .zip(expected_imags) + .for_each(|((z, z_re), z_im)| { + assert_float_closeness(z.re, z_re, 1e-10); + assert_float_closeness(z.im, z_im, 1e-10); + }); + + let n = 10; + let big_n = 1 << n; + let mut actual_signal: Vec<_> = (1..=big_n).map(|i| Complex::new(i as f32, 0.0)).collect(); + let mut expected_reals: Vec<_> = (1..=big_n).map(|i| i as f32).collect(); + let mut expected_imags = vec![0.0; big_n]; + + fft_32_interleaved(&mut actual_signal, Direction::Forward); + fft_32(&mut expected_reals, &mut expected_imags, Direction::Forward); + + actual_signal + .iter() + .zip(expected_reals) + .zip(expected_imags) + .for_each(|((z, z_re), z_im)| { + assert_float_closeness(z.re, z_re, 1e-10); + assert_float_closeness(z.im, z_im, 1e-10); + }); + } + #[test] fn fft_round_trip() { for i in 4..23 { @@ -341,9 +412,9 @@ mod tests { let mut re = reals.clone(); let mut im = imags.clone(); - let mut planner = Planner64::new(num_points, Direction::Forward); + let planner = Planner64::new(num_points, Direction::Forward); let opts = Options::guess_options(reals.len()); - fft_64_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner); + fft_64_with_opts_and_plan(&mut reals, &mut imags, &opts, &planner); fft_64(&mut re, &mut im, Direction::Forward); @@ -372,9 +443,9 @@ mod tests { let mut re = reals.clone(); let mut im = imags.clone(); - let mut planner = Planner32::new(num_points, direction); + let planner = Planner32::new(num_points, direction); let opts = Options::guess_options(reals.len()); - fft_32_with_opts_and_plan(&mut reals, &mut imags, &opts, &mut planner); + fft_32_with_opts_and_plan(&mut reals, &mut imags, &opts, &planner); fft_32(&mut re, &mut im, direction); diff --git a/src/twiddles.rs b/src/twiddles.rs index 7e11ba9..144e9fc 100644 --- a/src/twiddles.rs +++ b/src/twiddles.rs @@ -217,7 +217,7 @@ mod tests { use super::*; - // TODO(saveliy): use + // TODO(saveliy): try to use only real twiddle factors since sin is just a phase shift of cos #[test] fn twiddles_cos_only() { let n = 4; @@ -234,12 +234,12 @@ mod tests { assert!(fwd_twiddles_re.len() == dist && fwd_twiddles_im.len() == dist); for i in 0..dist { - let w_re = fwd_twiddles_re[i]; + let _w_re = fwd_twiddles_re[i]; let expected_w_im = fwd_twiddles_im[i]; let actual_w_im = -fwd_twiddles_re[(i + dist / 2) % dist]; //assert_float_closeness(actual_w_im, expected_w_im, 1e-6); - print!("actual: {actual_w_im} expected: {expected_w_im}\n"); + println!("actual: {actual_w_im} expected: {expected_w_im}"); } println!("{:?}", fwd_twiddles_re); println!("{:?}", fwd_twiddles_im); diff --git a/src/utils.rs b/src/utils.rs index 70ed734..3c0f240 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,14 @@ //! Utility functions such as interleave/deinterleave +#[cfg(feature = "complex-nums")] +use num_complex::Complex; + +#[cfg(feature = "complex-nums")] +use num_traits::Float; + +#[cfg(feature = "complex-nums")] +use bytemuck::cast_slice; + use std::simd::{prelude::Simd, simd_swizzle, SimdElement}; // We don't multiversion for AVX-512 here and keep the chunk size below AVX-512 @@ -59,12 +68,53 @@ pub(crate) fn deinterleave(input: &[T]) -> (Vec (out_odd, out_even) } +/// Utility function to separate a slice of [`Complex64``] +/// into a single vector of Complex Number Structs. +/// +/// # Panics +/// +/// Panics if `reals.len() != imags.len()`. +#[cfg(feature = "complex-nums")] +pub(crate) fn deinterleave_complex64(signal: &[Complex]) -> (Vec, Vec) { + let complex_t: &[f64] = cast_slice(signal); + deinterleave(complex_t) +} + +/// Utility function to separate a slice of [`Complex32``] +/// into a single vector of Complex Number Structs. +/// +/// # Panics +/// +/// Panics if `reals.len() != imags.len()`. +#[cfg(feature = "complex-nums")] +pub(crate) fn deinterleave_complex32(signal: &[Complex]) -> (Vec, Vec) { + let complex_t: &[f32] = cast_slice(signal); + deinterleave(complex_t) +} + +/// Utility function to combine separate vectors of real and imaginary components +/// into a single vector of Complex Number Structs. +/// +/// # Panics +/// +/// Panics if `reals.len() != imags.len()`. +#[cfg(feature = "complex-nums")] +pub(crate) fn combine_re_im(reals: &[T], imags: &[T]) -> Vec> { + assert_eq!(reals.len(), imags.len()); + + reals + .iter() + .zip(imags.iter()) + .map(|(z_re, z_im)| Complex::new(*z_re, *z_im)) + .collect() +} + #[cfg(test)] mod tests { - use super::deinterleave; + use super::*; fn gen_test_vec(len: usize) -> Vec { - (0..len).into_iter().collect() + (0..len).collect() } /// Slow but obviously correct implementation of deinterleaving, @@ -83,4 +133,21 @@ mod tests { assert_eq!(naive_b, opt_b); } } + + #[cfg(feature = "complex-nums")] + #[test] + fn test_separate_and_combine_re_im() { + let complex_vec: Vec<_> = vec![ + Complex::new(1.0, 2.0), + Complex::new(3.0, 4.0), + Complex::new(5.0, 6.0), + Complex::new(7.0, 8.0), + ]; + + let (reals, imags) = deinterleave_complex64(&complex_vec); + + let recombined_vec = combine_re_im(&reals, &imags); + + assert_eq!(complex_vec, recombined_vec); + } }