diff --git a/Cargo.toml b/Cargo.toml index 410e232..fb9715d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,10 @@ keywords = ["quantum", "fft", "discrete", "fourier", "transform"] categories = ["algorithms", "compression", "science"] exclude = ["assets", "scripts", "benches"] +[features] +default = ["double"] +single = [] +double = [] [dev-dependencies] utilities = { path = "utilities" } diff --git a/examples/benchmark.rs b/examples/benchmark.rs index 952d197..a9455d6 100644 --- a/examples/benchmark.rs +++ b/examples/benchmark.rs @@ -3,13 +3,13 @@ use std::str::FromStr; use utilities::gen_random_signal; -use phastft::fft; use phastft::planner::Direction; +use phastft::{fft, Float}; fn benchmark_fft(n: usize) { let big_n = 1 << n; - let mut reals = vec![0.0; big_n]; - let mut imags = vec![0.0; big_n]; + let mut reals: Vec = vec![0.0; big_n]; + let mut imags: Vec = vec![0.0; big_n]; gen_random_signal(&mut reals, &mut imags); let now = std::time::Instant::now(); diff --git a/examples/fftwrb.rs b/examples/fftwrb.rs index db4beee..87743c6 100644 --- a/examples/fftwrb.rs +++ b/examples/fftwrb.rs @@ -2,18 +2,24 @@ use std::{env, ptr::slice_from_raw_parts_mut, str::FromStr}; use fftw::{ array::AlignedVec, - plan::{C2CPlan, C2CPlan64}, types::{Flag, Sign}, }; +use fftw::plan::C2CPlan; +#[cfg(feature = "single")] +use fftw::plan::C2CPlan32; +#[cfg(feature = "double")] +use fftw::plan::C2CPlan64; use utilities::{gen_random_signal, rustfft::num_complex::Complex}; +use phastft::Float; + fn benchmark_fftw(n: usize) { let big_n = 1 << n; let mut reals = vec![0.0; big_n]; let mut imags = vec![0.0; big_n]; - gen_random_signal(&mut reals, &mut imags); + gen_random_signal::(&mut reals, &mut imags); let mut nums = AlignedVec::new(big_n); reals .drain(..) @@ -22,6 +28,8 @@ fn benchmark_fftw(n: usize) { .for_each(|((re, im), z)| *z = Complex::new(re, im)); let now = std::time::Instant::now(); + + #[cfg(feature = "double")] C2CPlan64::aligned( &[big_n], Sign::Backward, @@ -34,6 +42,21 @@ fn benchmark_fftw(n: usize) { &mut nums, ) .unwrap(); + + #[cfg(feature = "single")] + C2CPlan32::aligned( + &[big_n], + Sign::Backward, + Flag::DESTROYINPUT | Flag::ESTIMATE, + ) + .unwrap() + .c2c( + // SAFETY: See above comment. + unsafe { &mut *slice_from_raw_parts_mut(nums.as_mut_ptr(), big_n) }, + &mut nums, + ) + .unwrap(); + let elapsed = now.elapsed().as_micros(); println!("{elapsed}"); } diff --git a/examples/profile.rs b/examples/profile.rs index 112e17e..e56d511 100644 --- a/examples/profile.rs +++ b/examples/profile.rs @@ -1,13 +1,13 @@ use std::env; use std::str::FromStr; -use phastft::fft; use phastft::planner::Direction; +use phastft::{fft, Float}; fn benchmark_fft(num_qubits: usize) { let n = 1 << num_qubits; - let mut reals: Vec = (1..=n).map(|i| i as f64).collect(); - let mut imags: Vec = (1..=n).map(|i| i as f64).collect(); + let mut reals: Vec<_> = (1..=n).map(|i| i as Float).collect(); + let mut imags: Vec<_> = (1..=n).map(|i| i as Float).collect(); fft(&mut reals, &mut imags, Direction::Forward); } diff --git a/src/cobra.rs b/src/cobra.rs index c9809a9..150da77 100644 --- a/src/cobra.rs +++ b/src/cobra.rs @@ -13,7 +13,8 @@ //! Symposium on Foundations of Computer Science (Cat. No.98CB36280), Palo Alto, CA, USA, 1998, pp. 544-553, doi: //! 10.1109/SFCS.1998.743505. //! keywords: {Read-write memory;Costs;Computer science;Drives;Random access memory;Argon;Registers;Read only memory;Computational modeling;Libraries} -use crate::kernels::Float; + +use crate::Float; const BLOCK_WIDTH: usize = 128; // size of the cacheline @@ -317,17 +318,17 @@ mod tests { fn jennifer_method() { for n in 2..24 { let big_n = 1 << n; - let mut actual_re: Vec = (0..big_n).map(f64::from).collect(); - let mut actual_im: Vec = (0..big_n).map(f64::from).collect(); + let mut actual_re: Vec = (0..big_n).map(|i| i as Float).collect(); + let mut actual_im: Vec = (0..big_n).map(|i| i as Float).collect(); #[allow(deprecated)] complex_bit_rev(&mut actual_re, &mut actual_im, n); - let input_re: Vec = (0..big_n).map(f64::from).collect(); + let input_re: Vec = (0..big_n).map(|i| i as Float).collect(); let expected_re = top_down_bit_reverse_permutation(&input_re); assert_eq!(actual_re, expected_re); - let input_im: Vec = (0..big_n).map(f64::from).collect(); + let input_im: Vec = (0..big_n).map(|i| i as Float).collect(); let expected_im = top_down_bit_reverse_permutation(&input_im); assert_eq!(actual_im, expected_im); } @@ -337,17 +338,17 @@ mod tests { fn jennifer_method_parallel() { for n in 2..24 { let big_n = 1 << n; - let mut actual_re: Vec = (0..big_n).map(f64::from).collect(); - let mut actual_im: Vec = (0..big_n).map(f64::from).collect(); + let mut actual_re: Vec = (0..big_n).map(|i| i as Float).collect(); + let mut actual_im: Vec = (0..big_n).map(|i| i as Float).collect(); #[allow(deprecated)] bit_reverse_permute_state_par(&mut actual_re, &mut actual_im, n); - let input_re: Vec = (0..big_n).map(f64::from).collect(); + let input_re: Vec = (0..big_n).map(|i| i as Float).collect(); let expected_re = top_down_bit_reverse_permutation(&input_re); assert_eq!(actual_re, expected_re); - let input_im: Vec = (0..big_n).map(f64::from).collect(); + let input_im: Vec = (0..big_n).map(|i| i as Float).collect(); let expected_im = top_down_bit_reverse_permutation(&input_im); assert_eq!(actual_im, expected_im); } diff --git a/src/kernels.rs b/src/kernels.rs index 67991da..d6da467 100644 --- a/src/kernels.rs +++ b/src/kernels.rs @@ -1,7 +1,11 @@ +#[cfg(feature = "single")] +use std::simd::f32x16; +#[cfg(feature = "double")] use std::simd::f64x8; -pub type Float = f64; +use crate::Float; +#[cfg(feature = "double")] pub(crate) fn fft_chunk_n_simd( reals: &mut [Float], imags: &mut [Float], @@ -45,7 +49,51 @@ pub(crate) fn fft_chunk_n_simd( }); } -// TODO(saveliy): parallelize +#[cfg(feature = "single")] +pub(crate) fn fft_chunk_n_simd( + reals: &mut [Float], + imags: &mut [Float], + twiddles_re: &[Float], + twiddles_im: &[Float], + dist: usize, +) { + const VECTOR_SIZE: usize = 16; + let chunk_size = dist << 1; + assert!(chunk_size >= 32); + + reals + .chunks_exact_mut(chunk_size) + .zip(imags.chunks_exact_mut(chunk_size)) + .for_each(|(reals_chunk, imags_chunk)| { + let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist); + let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist); + + reals_s0 + .chunks_exact_mut(VECTOR_SIZE) + .zip(reals_s1.chunks_exact_mut(VECTOR_SIZE)) + .zip(imags_s0.chunks_exact_mut(VECTOR_SIZE)) + .zip(imags_s1.chunks_exact_mut(VECTOR_SIZE)) + .zip(twiddles_re.chunks_exact(VECTOR_SIZE)) + .zip(twiddles_im.chunks_exact(VECTOR_SIZE)) + .for_each(|(((((re_s0, re_s1), im_s0), im_s1), w_re), w_im)| { + let real_c0 = f32x16::from_slice(re_s0); + let real_c1 = f32x16::from_slice(re_s1); + let imag_c0 = f32x16::from_slice(im_s0); + let imag_c1 = f32x16::from_slice(im_s1); + + let tw_re = f32x16::from_slice(w_re); + let tw_im = f32x16::from_slice(w_im); + + re_s0.copy_from_slice((real_c0 + real_c1).as_array()); + im_s0.copy_from_slice((imag_c0 + imag_c1).as_array()); + let v_re = real_c0 - real_c1; + let v_im = imag_c0 - imag_c1; + re_s1.copy_from_slice((v_re * tw_re - v_im * tw_im).as_array()); + im_s1.copy_from_slice((v_re * tw_im + v_im * tw_re).as_array()); + }); + }); +} + pub(crate) fn fft_chunk_n( reals: &mut [Float], imags: &mut [Float], @@ -119,7 +167,7 @@ pub(crate) fn fft_chunk_4(reals: &mut [Float], imags: &mut [Float]) { }); } -/// `chunk_size == 2`, so skip phase +/// `chunk_size == 2`, so we only need 1 and -1 pub(crate) fn fft_chunk_2(reals: &mut [Float], imags: &mut [Float]) { reals .chunks_exact_mut(2) diff --git a/src/lib.rs b/src/lib.rs index 4d33350..fb98342 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,7 @@ #![feature(portable_simd)] use crate::cobra::cobra_apply; -use crate::kernels::{fft_chunk_2, fft_chunk_4, fft_chunk_n, fft_chunk_n_simd, Float}; +use crate::kernels::{fft_chunk_2, fft_chunk_4, fft_chunk_n, fft_chunk_n_simd}; use crate::options::Options; use crate::planner::{Direction, Planner}; use crate::twiddles::filter_twiddles; @@ -20,6 +20,14 @@ pub mod options; pub mod planner; mod twiddles; +/// Redefine `Float` as `f64` for double precision data +#[cfg(feature = "double")] +pub type Float = f64; + +/// Redefine `Float` as `f32` for single precision data +#[cfg(feature = "single")] +pub type Float = f32; + /// FFT -- Decimation in Frequency. This is just the decimation-in-time algorithm, reversed. /// This call to FFT is run, in-place. /// The input should be provided in normal order, and then the modified input is bit-reversed. @@ -83,7 +91,7 @@ pub fn fft_with_opts_and_plan( if t < n - 1 { filter_twiddles(twiddles_re, twiddles_im); } - if chunk_size >= 16 { + if chunk_size >= 32 { fft_chunk_n_simd(reals, imags, twiddles_re, twiddles_im, dist); } else { fft_chunk_n(reals, imags, twiddles_re, twiddles_im, dist); @@ -110,10 +118,9 @@ pub fn fft_with_opts_and_plan( mod tests { use std::ops::Range; - use utilities::{ - assert_f64_closeness, - rustfft::{num_complex::Complex64, FftPlanner}, - }; + use utilities::assert_float_closeness; + use utilities::rustfft::FftPlanner; + use utilities::rustfft::num_complex::Complex; use crate::planner::Direction; @@ -165,12 +172,12 @@ mod tests { for k in range { let n: usize = 1 << k; - let mut reals: Vec = (1..=n).map(|i| i as f64).collect(); - let mut imags: Vec = (1..=n).map(|i| i as f64).collect(); + let mut reals: Vec = (1..=n).map(|i| i as Float).collect(); + let mut imags: Vec = (1..=n).map(|i| i as Float).collect(); fft(&mut reals, &mut imags, Direction::Forward); - let mut buffer: Vec = (1..=n) - .map(|i| Complex64::new(i as f64, i as f64)) + let mut buffer: Vec> = (1..=n) + .map(|i| Complex::new(i as Float, i as Float)) .collect(); let mut planner = FftPlanner::new(); @@ -184,8 +191,8 @@ mod tests { .for_each(|(i, (z_re, z_im))| { let expect_re = buffer[i].re; let expect_im = buffer[i].im; - assert_f64_closeness(*z_re, expect_re, 0.01); - assert_f64_closeness(*z_im, expect_im, 0.01); + assert_float_closeness(*z_re, expect_re, 0.01); + assert_float_closeness(*z_im, expect_im, 0.01); }); } } diff --git a/src/planner.rs b/src/planner.rs index cec8fe4..77db7ab 100644 --- a/src/planner.rs +++ b/src/planner.rs @@ -2,7 +2,7 @@ //! a Fast Fourier Transform (FFT). Currently, the planner is responsible for //! pre-computing twiddle factors based on the input signal length, as well as the //! direction of the FFT. - +use crate::Float; use crate::twiddles::{generate_twiddles, generate_twiddles_simd}; /// Reverse is for running the Inverse Fast Fourier Transform (IFFT) @@ -20,9 +20,9 @@ pub enum Direction { /// the amount of twiddle factors should always be `(1/2) * N` pub struct Planner { /// The real components of the twiddle factors - pub twiddles_re: Vec, + pub twiddles_re: Vec, /// The imaginary components of the twiddle factors - pub twiddles_im: Vec, + pub twiddles_im: Vec, } impl Planner { @@ -66,7 +66,7 @@ impl Planner { #[cfg(test)] mod tests { - use utilities::assert_f64_closeness; + use utilities::assert_float_closeness; use crate::planner::{Direction, Planner}; @@ -101,8 +101,8 @@ mod tests { .for_each(|(((a, b), c), d)| { let temp_re = a * c - b * d; let temp_im = a * d + b * c; - assert_f64_closeness(temp_re, 1.0, 1e-6); - assert_f64_closeness(temp_im, 0.0, 1e-6); + assert_float_closeness(temp_re, 1.0, 1e-3); + assert_float_closeness(temp_im, 0.0, 1e-3); }); } } diff --git a/src/twiddles.rs b/src/twiddles.rs index 48b91fa..ff98dd8 100644 --- a/src/twiddles.rs +++ b/src/twiddles.rs @@ -1,8 +1,15 @@ -use std::{f64::consts::PI, simd::f64x8}; +#[cfg(feature = "double")] +use std::simd::Simd; -use crate::kernels::Float; +use crate::Float; use crate::planner::Direction; +#[cfg(feature = "single")] +const PI: Float = std::f32::consts::PI; + +#[cfg(feature = "double")] +const PI: Float = std::f64::consts::PI; + pub(crate) struct Twiddles { st: Float, ct: Float, @@ -31,7 +38,7 @@ impl Twiddles { impl Iterator for Twiddles { type Item = (Float, Float); - fn next(&mut self) -> Option<(f64, f64)> { + fn next(&mut self) -> Option<(Float, Float)> { let w_re = self.w_re_prev; let w_im = self.w_im_prev; @@ -43,7 +50,7 @@ impl Iterator for Twiddles { } } -pub fn generate_twiddles(dist: usize, direction: Direction) -> (Vec, Vec) { +pub fn generate_twiddles(dist: usize, direction: Direction) -> (Vec, Vec) { let mut twiddles_re = vec![0.0; dist]; let mut twiddles_im = vec![0.0; dist]; twiddles_re[0] = 1.0; @@ -53,7 +60,7 @@ pub fn generate_twiddles(dist: usize, direction: Direction) -> (Vec, Vec -1.0, }; - let angle: Float = sign * -PI / (dist as f64); + let angle: Float = sign * -PI / (dist as Float); let (st, ct) = angle.sin_cos(); let (mut w_re, mut w_im) = (1.0, 0.0); @@ -76,7 +83,10 @@ pub fn generate_twiddles(dist: usize, direction: Direction) -> (Vec, Vec (Vec, Vec) { +pub(crate) fn generate_twiddles_simd( + dist: usize, + direction: Direction, +) -> (Vec, Vec) { const CHUNK_SIZE: usize = 8; // TODO: make this a const generic? assert!(dist >= CHUNK_SIZE * 2); assert_eq!(dist % CHUNK_SIZE, 0); @@ -89,7 +99,7 @@ pub(crate) fn generate_twiddles_simd(dist: usize, direction: Direction) -> (Vec< Direction::Reverse => -1.0, }; - let angle: Float = sign * -PI / (dist as f64); + let angle: Float = sign * -PI / (dist as Float); let (st, ct) = angle.sin_cos(); let (mut w_re, mut w_im) = (1.0, 0.0); @@ -101,8 +111,8 @@ pub(crate) fn generate_twiddles_simd(dist: usize, direction: Direction) -> (Vec< }; let apply_symmetry_re = |input: &[Float], output: &mut [Float]| { - let first_re = f64x8::from_slice(input); - let minus_one = f64x8::splat(-1.0); + let first_re = Simd::::from_slice(input); + let minus_one = Simd::::splat(-1.0); let negated = (first_re * minus_one).reverse(); output.copy_from_slice(negated.as_array()); }; @@ -168,13 +178,13 @@ pub(crate) fn generate_twiddles_simd(dist: usize, direction: Direction) -> (Vec< (twiddles_re, twiddles_im) } -pub(crate) fn filter_twiddles(twiddles_re: &mut Vec, twiddles_im: &mut Vec) { +pub(crate) fn filter_twiddles(twiddles_re: &mut Vec, twiddles_im: &mut Vec) { assert_eq!(twiddles_re.len(), twiddles_im.len()); let dist = twiddles_re.len(); - let filtered_twiddles_re: Vec = + let filtered_twiddles_re: Vec = twiddles_re.chunks_exact(2).map(|chunk| chunk[0]).collect(); - let filtered_twiddles_im: Vec = + let filtered_twiddles_im: Vec = twiddles_im.chunks_exact(2).map(|chunk| chunk[0]).collect(); assert!( @@ -188,12 +198,16 @@ pub(crate) fn filter_twiddles(twiddles_re: &mut Vec, twiddles_im: &mut Vec< #[cfg(test)] mod tests { - use std::f64::consts::FRAC_1_SQRT_2; - - use utilities::assert_f64_closeness; + use utilities::assert_float_closeness; use super::*; + #[cfg(feature = "double")] + const FRAC_1_SQRT_2: Float = std::f64::consts::FRAC_1_SQRT_2; + + #[cfg(feature = "single")] + const FRAC_1_SQRT_2: Float = std::f32::consts::FRAC_1_SQRT_2; + #[test] fn twiddles_4() { const N: usize = 4; @@ -201,23 +215,23 @@ mod tests { let (w_re, w_im) = twiddle_iter.next().unwrap(); println!("{w_re} {w_im}"); - assert_f64_closeness(w_re, 1.0, 1e-10); - assert_f64_closeness(w_im, 0.0, 1e-10); + assert_float_closeness(w_re, 1.0, 1e-6); + assert_float_closeness(w_im, 0.0, 1e-6); let (w_re, w_im) = twiddle_iter.next().unwrap(); println!("{w_re} {w_im}"); - assert_f64_closeness(w_re, FRAC_1_SQRT_2, 1e-10); - assert_f64_closeness(w_im, -FRAC_1_SQRT_2, 1e-10); + assert_float_closeness(w_re, FRAC_1_SQRT_2, 1e-6); + assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-6); let (w_re, w_im) = twiddle_iter.next().unwrap(); println!("{w_re} {w_im}"); - assert_f64_closeness(w_re, 0.0, 1e-10); - assert_f64_closeness(w_im, -1.0, 1e-10); + assert_float_closeness(w_re, 0.0, 1e-6); + assert_float_closeness(w_im, -1.0, 1e-6); let (w_re, w_im) = twiddle_iter.next().unwrap(); println!("{w_re} {w_im}"); - assert_f64_closeness(w_re, -FRAC_1_SQRT_2, 1e-10); - assert_f64_closeness(w_im, -FRAC_1_SQRT_2, 1e-10); + assert_float_closeness(w_re, -FRAC_1_SQRT_2, 1e-6); + assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-6); } #[test] @@ -232,14 +246,14 @@ mod tests { .iter() .zip(twiddles_re_ref.iter()) .for_each(|(simd, reference)| { - assert_f64_closeness(*simd, *reference, 1e-10); + assert_float_closeness(*simd, *reference, 1e-3); }); twiddles_im .iter() .zip(twiddles_im_ref.iter()) .for_each(|(simd, reference)| { - assert_f64_closeness(*simd, *reference, 1e-10); + assert_float_closeness(*simd, *reference, 1e-3); }); } } @@ -255,8 +269,8 @@ mod tests { for i in 0..dist { let (tw_re, tw_im) = twiddles_iter.next().unwrap(); - assert_f64_closeness(twiddles_re[i], tw_re, 1e-6); - assert_f64_closeness(twiddles_im[i], tw_im, 1e-6); + assert_float_closeness(twiddles_re[i], tw_re, 1e-6); + assert_float_closeness(twiddles_im[i], tw_im, 1e-6); } for t in (0..n - 1).rev() { @@ -271,8 +285,8 @@ mod tests { for i in 0..dist { let (tw_re, tw_im) = twiddles_iter.next().unwrap(); - assert_f64_closeness(twiddles_re[i], tw_re, 1e-6); - assert_f64_closeness(twiddles_im[i], tw_im, 1e-6); + assert_float_closeness(twiddles_re[i], tw_re, 1e-6); + assert_float_closeness(twiddles_im[i], tw_im, 1e-6); } } } diff --git a/utilities/src/lib.rs b/utilities/src/lib.rs index 878048f..47b7c8f 100644 --- a/utilities/src/lib.rs +++ b/utilities/src/lib.rs @@ -1,33 +1,17 @@ pub extern crate rustfft; -use std::f64::consts::PI; +// export rustfft to phastft +use rand::{distributions::Uniform, prelude::*}; +use rustfft::num_traits::Float; -use rand::distributions::Uniform; -use rand::prelude::*; - -/// Asserts that two f64 numbers are approximately equal. -/// -/// # Panics -/// -/// Panics if `actual` and `expected` are too far from each other -#[allow(dead_code)] -#[track_caller] -pub fn assert_f64_closeness(actual: f64, expected: f64, epsilon: f64) { - if (actual - expected).abs() >= epsilon { - panic!( - "Assertion failed: {actual} too far from expected value {expected} (with epsilon {epsilon})", - ); - } -} - -/// Asserts that two f32 numbers are approximately equal. +/// Asserts that two fp numbers are approximately equal. /// /// # Panics /// /// Panics if `actual` and `expected` are too far from each other #[allow(dead_code)] #[track_caller] -pub fn assert_f32_closeness(actual: f32, expected: f32, epsilon: f32) { +pub fn assert_float_closeness(actual: T, expected: T, epsilon: T) { if (actual - expected).abs() >= epsilon { panic!( "Assertion failed: {actual} too far from expected value {expected} (with epsilon {epsilon})", @@ -40,34 +24,23 @@ pub fn assert_f32_closeness(actual: f32, expected: f32, epsilon: f32) { /// # Panics /// /// Panics if `reals.len() != imags.len()` -pub fn gen_random_signal(reals: &mut [f64], imags: &mut [f64]) { - assert!(reals.len() == imags.len() && !reals.is_empty()); - let mut rng = thread_rng(); - let between = Uniform::from(0.0..1.0); - let angle_dist = Uniform::from(0.0..2.0 * PI); - let num_amps = reals.len(); +pub fn gen_random_signal(reals: &mut [T], imags: &mut [T]) +where + T: Float + rand::distributions::uniform::SampleUniform, +{ + assert_eq!( + reals.len(), + imags.len(), + "Real and imaginary slices must be of equal length" + ); - let mut probs: Vec<_> = (0..num_amps).map(|_| between.sample(&mut rng)).collect(); - - let total: f64 = probs.iter().sum(); - let total_recip = total.recip(); - - probs.iter_mut().for_each(|p| *p *= total_recip); - - let angles = (0..num_amps).map(|_| angle_dist.sample(&mut rng)); + let mut rng = thread_rng(); - probs - .iter() - .zip(angles) - .enumerate() - .for_each(|(i, (p, a))| { - let p_sqrt = p.sqrt(); - let (sin_a, cos_a) = a.sin_cos(); - let re = p_sqrt * cos_a; - let im = p_sqrt * sin_a; - reals[i] = re; - imags[i] = im; - }); + let uniform_dist = Uniform::new(T::from(-1.0).unwrap(), T::from(1.0).unwrap()); + for (real, imag) in reals.iter_mut().zip(imags.iter_mut()) { + *real = uniform_dist.sample(&mut rng); + *imag = uniform_dist.sample(&mut rng); + } } #[cfg(test)] @@ -77,17 +50,17 @@ mod tests { #[test] fn generate_random_signal() { let big_n = 1 << 25; - let mut reals = vec![0.0; big_n]; - let mut imags = vec![0.0; big_n]; + let mut reals: Vec<_> = vec![0.0; big_n]; + let mut imags: Vec<_> = vec![0.0; big_n]; - gen_random_signal(&mut reals, &mut imags); + gen_random_signal::(&mut reals, &mut imags); - let sum: f64 = reals + let sum = reals .iter() .zip(imags.iter()) .map(|(re, im)| re.powi(2) + im.powi(2)) .sum(); - assert_f64_closeness(sum, 1.0, 1e-6); + assert_float_closeness(sum, 1.0, 1e-6); } }