Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for f32, as well as f64 #17

Merged
merged 9 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ keywords = ["quantum", "fft", "discrete", "fourier", "transform"]
categories = ["algorithms", "compression", "science"]
exclude = ["assets", "scripts", "benches"]

[dependencies]
num-traits = "0.2.18"

[dev-dependencies]
utilities = { path = "utilities" }
Expand All @@ -23,3 +25,4 @@ panic = "abort"
[profile.profiling]
inherits = "release"
debug = true

8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,16 @@ including [RustFFT](https://crates.io/crates/rustfft/), while using significantl
### Rust

```rust
use phastft::planner::Direction;
use phastft::fft;
use phastft::{
planner::Direction,
fft_64
};

fn main() {
    // FFT length — the library requires a power-of-two size.
    let size = 1 << 10;
    // Deterministic ramp signal 1.0, 2.0, …; imaginaries start identical to reals.
    let mut reals: Vec<f64> = (1..=size).map(f64::from).collect();
    let mut imags = reals.clone();
    fft_64(&mut reals, &mut imags, Direction::Forward);
}
```

Expand Down
8 changes: 4 additions & 4 deletions examples/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@ use std::str::FromStr;

use utilities::gen_random_signal;

use phastft::fft;
use phastft::fft_64;
use phastft::planner::Direction;

fn benchmark_fft(n: usize) {
/// Times one forward f64 FFT of length `2^n` over a random signal and prints
/// the elapsed wall-clock time in microseconds to stdout.
fn benchmark_fft_64(n: usize) {
    let len = 1 << n;
    let (mut reals, mut imags) = (vec![0.0; len], vec![0.0; len]);
    gen_random_signal(&mut reals, &mut imags);

    // Only the transform itself is timed — signal generation is excluded.
    let start = std::time::Instant::now();
    fft_64(&mut reals, &mut imags, Direction::Forward);
    println!("{}", start.elapsed().as_micros());
}
Expand All @@ -24,5 +24,5 @@ fn main() {

let n = usize::from_str(&args[1]).unwrap();

benchmark_fft(n);
benchmark_fft_64(n);
}
4 changes: 2 additions & 2 deletions examples/profile.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use std::env;
use std::str::FromStr;

use phastft::fft;
use phastft::fft_64;
use phastft::planner::Direction;

fn benchmark_fft(num_qubits: usize) {
let n = 1 << num_qubits;
let mut reals: Vec<f64> = (1..=n).map(|i| i as f64).collect();
let mut imags: Vec<f64> = (1..=n).map(|i| i as f64).collect();
fft(&mut reals, &mut imags, Direction::Forward);
fft_64(&mut reals, &mut imags, Direction::Forward);
}

fn main() {
Expand Down
44 changes: 3 additions & 41 deletions src/cobra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
//! Symposium on Foundations of Computer Science (Cat. No.98CB36280), Palo Alto, CA, USA, 1998, pp. 544-553, doi:
//! 10.1109/SFCS.1998.743505.
//! keywords: {Read-write memory;Costs;Computer science;Drives;Random access memory;Argon;Registers;Read only memory;Computational modeling;Libraries}
use crate::kernels::Float;

use num_traits::Float;

const BLOCK_WIDTH: usize = 128;
// size of the cacheline
Expand Down Expand Up @@ -78,7 +79,7 @@ pub(crate) fn bit_rev<T>(buf: &mut [T], log_n: usize) {
since = "0.1.0",
note = "Please use COBRA for a cache-optimal bit reverse permutation."
)]
fn complex_bit_rev(reals: &mut [Float], imags: &mut [Float], log_n: usize) {
fn complex_bit_rev<T: Float>(reals: &mut [T], imags: &mut [T], log_n: usize) {
let mut nodd: usize;
let mut noddrev; // to hold bitwise negated or odd values

Expand Down Expand Up @@ -129,25 +130,6 @@ fn complex_bit_rev(reals: &mut [Float], imags: &mut [Float], log_n: usize) {
}
}

/// Run in-place bit reversal on the entire state, in parallel.
/// This function uses 2 threads to run a bit reverse on the reals buffer on one thread, and the other thread handles
/// the bit reversal of the imaginaries buffer
#[allow(dead_code)]
#[deprecated(
since = "0.1.0",
note = "Naive bit reverse permutation is slow and not cache friendly. COBRA should be used instead."
)]
pub(crate) fn bit_reverse_permute_state_par(
reals: &mut [Float],
imags: &mut [Float],
log_n: usize,
) {
std::thread::scope(|s| {
s.spawn(|| bit_rev(reals, log_n));
s.spawn(|| bit_rev(imags, log_n));
});
}

#[allow(dead_code)]
#[deprecated(
since = "0.1.0",
Expand Down Expand Up @@ -333,26 +315,6 @@ mod tests {
}
}

#[test]
// Checks the parallel naive bit-reversal against a known-good reference
// (`top_down_bit_reverse_permutation`) for every power-of-two size 2^2..2^23.
fn jennifer_method_parallel() {
for n in 2..24 {
let big_n = 1 << n;
// Distinct ramp buffers so reals and imaginaries are verified independently.
let mut actual_re: Vec<f64> = (0..big_n).map(f64::from).collect();
let mut actual_im: Vec<f64> = (0..big_n).map(f64::from).collect();

// The function under test is deprecated in favor of COBRA.
#[allow(deprecated)]
bit_reverse_permute_state_par(&mut actual_re, &mut actual_im, n);

// Rebuild the untouched input and permute it with the reference routine.
let input_re: Vec<f64> = (0..big_n).map(f64::from).collect();
let expected_re = top_down_bit_reverse_permutation(&input_re);
assert_eq!(actual_re, expected_re);

let input_im: Vec<f64> = (0..big_n).map(f64::from).collect();
let expected_im = top_down_bit_reverse_permutation(&input_im);
assert_eq!(actual_im, expected_im);
}
}

#[test]
fn naive_bit_reverse_permutation() {
for n in 2..24 {
Expand Down
113 changes: 59 additions & 54 deletions src/kernels.rs
Original file line number Diff line number Diff line change
@@ -1,56 +1,61 @@
use std::simd::f64x8;

pub type Float = f64;

pub(crate) fn fft_chunk_n_simd(
reals: &mut [Float],
imags: &mut [Float],
twiddles_re: &[Float],
twiddles_im: &[Float],
dist: usize,
) {
let chunk_size = dist << 1;
assert!(chunk_size >= 16);

reals
.chunks_exact_mut(chunk_size)
.zip(imags.chunks_exact_mut(chunk_size))
.for_each(|(reals_chunk, imags_chunk)| {
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist);
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist);

reals_s0
.chunks_exact_mut(8)
.zip(reals_s1.chunks_exact_mut(8))
.zip(imags_s0.chunks_exact_mut(8))
.zip(imags_s1.chunks_exact_mut(8))
.zip(twiddles_re.chunks_exact(8))
.zip(twiddles_im.chunks_exact(8))
.for_each(|(((((re_s0, re_s1), im_s0), im_s1), w_re), w_im)| {
let real_c0 = f64x8::from_slice(re_s0);
let real_c1 = f64x8::from_slice(re_s1);
let imag_c0 = f64x8::from_slice(im_s0);
let imag_c1 = f64x8::from_slice(im_s1);

let tw_re = f64x8::from_slice(w_re);
let tw_im = f64x8::from_slice(w_im);

re_s0.copy_from_slice((real_c0 + real_c1).as_array());
im_s0.copy_from_slice((imag_c0 + imag_c1).as_array());
let v_re = real_c0 - real_c1;
let v_im = imag_c0 - imag_c1;
re_s1.copy_from_slice((v_re * tw_re - v_im * tw_im).as_array());
im_s1.copy_from_slice((v_re * tw_im + v_im * tw_re).as_array());
use std::simd::{f32x16, f64x8};

use num_traits::Float;

/// Generates a SIMD FFT butterfly kernel for one floating-point precision.
///
/// Arguments:
/// - `$func_name`: name of the generated `pub fn`
/// - `$precision`: scalar element type (`f32` or `f64`)
/// - `$lanes`: SIMD lane count used to chunk the buffers
/// - `$simd_vector`: portable-SIMD vector type (`f32x16` or `f64x8`)
macro_rules! fft_butterfly_n_simd {
($func_name:ident, $precision:ty, $lanes:literal, $simd_vector:ty) => {
// Processes `reals`/`imags` in chunks of `2 * dist`, applying one butterfly
// stage with the supplied twiddle factors.
pub fn $func_name(
reals: &mut [$precision],
imags: &mut [$precision],
twiddles_re: &[$precision],
twiddles_im: &[$precision],
dist: usize,
) {
let chunk_size = dist << 1;
// Each half of a chunk must hold at least one full SIMD vector.
assert!(chunk_size >= $lanes * 2);
reals
.chunks_exact_mut(chunk_size)
.zip(imags.chunks_exact_mut(chunk_size))
.for_each(|(reals_chunk, imags_chunk)| {
// Split each chunk into its top (`s0`) and bottom (`s1`) halves.
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist);
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist);

reals_s0
.chunks_exact_mut($lanes)
.zip(reals_s1.chunks_exact_mut($lanes))
.zip(imags_s0.chunks_exact_mut($lanes))
.zip(imags_s1.chunks_exact_mut($lanes))
.zip(twiddles_re.chunks_exact($lanes))
.zip(twiddles_im.chunks_exact($lanes))
.for_each(|(((((re_s0, re_s1), im_s0), im_s1), w_re), w_im)| {
let real_c0 = <$simd_vector>::from_slice(re_s0);
let real_c1 = <$simd_vector>::from_slice(re_s1);
let imag_c0 = <$simd_vector>::from_slice(im_s0);
let imag_c1 = <$simd_vector>::from_slice(im_s1);

let tw_re = <$simd_vector>::from_slice(w_re);
let tw_im = <$simd_vector>::from_slice(w_im);

// Butterfly: top half receives the element-wise sum …
re_s0.copy_from_slice((real_c0 + real_c1).as_array());
im_s0.copy_from_slice((imag_c0 + imag_c1).as_array());
// … bottom half receives the difference multiplied by the
// twiddle factor (complex multiply, expanded into re/im parts).
let v_re = real_c0 - real_c1;
let v_im = imag_c0 - imag_c1;
re_s1.copy_from_slice((v_re * tw_re - v_im * tw_im).as_array());
im_s1.copy_from_slice((v_re * tw_im + v_im * tw_re).as_array());
});
});
// NOTE(review): the closer below looks unbalanced with the two closures
// above — likely a diff-rendering artifact; verify against the full file.
});
}
};
}

// TODO(saveliy): parallelize
pub(crate) fn fft_chunk_n(
reals: &mut [Float],
imags: &mut [Float],
twiddles_re: &[Float],
twiddles_im: &[Float],
// Instantiate the SIMD butterfly kernels for both precisions.
// Lane counts: 8 x f64 and 16 x f32 — per the PR discussion, the f32 lane
// count was bumped to 16 to fill a 512-bit (AVX-512-sized) vector.
fft_butterfly_n_simd!(fft_64_chunk_n_simd, f64, 8, f64x8);
fft_butterfly_n_simd!(fft_32_chunk_n_simd, f32, 16, f32x16);

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Shnatsel Bumped up # of lanes to 16 when using AVX-512 and f32

pub(crate) fn fft_chunk_n<T: Float>(
reals: &mut [T],
imags: &mut [T],
twiddles_re: &[T],
twiddles_im: &[T],
dist: usize,
) {
let chunk_size = dist << 1;
Expand Down Expand Up @@ -79,14 +84,14 @@ pub(crate) fn fft_chunk_n(
*im_s0 = imag_c0 + imag_c1;
let v_re = real_c0 - real_c1;
let v_im = imag_c0 - imag_c1;
*re_s1 = v_re * w_re - v_im * w_im;
*im_s1 = v_re * w_im + v_im * w_re;
*re_s1 = v_re * *w_re - v_im * *w_im;
*im_s1 = v_re * *w_im + v_im * *w_re;
});
});
}

/// `chunk_size == 4`, so hard code twiddle factors
pub(crate) fn fft_chunk_4(reals: &mut [Float], imags: &mut [Float]) {
pub(crate) fn fft_chunk_4<T: Float>(reals: &mut [T], imags: &mut [T]) {
let dist = 2;
let chunk_size = dist << 1;

Expand Down Expand Up @@ -120,7 +125,7 @@ pub(crate) fn fft_chunk_4(reals: &mut [Float], imags: &mut [Float]) {
}

/// `chunk_size == 2`, so skip phase
pub(crate) fn fft_chunk_2(reals: &mut [Float], imags: &mut [Float]) {
pub(crate) fn fft_chunk_2<T: Float>(reals: &mut [T], imags: &mut [T]) {
reals
.chunks_exact_mut(2)
.zip(imags.chunks_exact_mut(2))
Expand Down
Loading
Loading