Merge pull request #17 from QuState/feature/single-precision
Add support for `f32`, as well as `f64`
smu160 authored Apr 1, 2024
2 parents a63ad3f + 4b909ae commit a65b542
Showing 11 changed files with 562 additions and 501 deletions.
File renamed without changes.
3 changes: 3 additions & 0 deletions Cargo.toml
@@ -10,6 +10,8 @@ keywords = ["quantum", "fft", "discrete", "fourier", "transform"]
categories = ["algorithms", "compression", "science"]
exclude = ["assets", "scripts", "benches"]

[dependencies]
num-traits = "0.2.18"

[dev-dependencies]
utilities = { path = "utilities" }
@@ -23,3 +25,4 @@ panic = "abort"
[profile.profiling]
inherits = "release"
debug = true

8 changes: 5 additions & 3 deletions README.md
@@ -58,14 +58,16 @@ including [RustFFT](https://crates.io/crates/rustfft/), while using significantl
### Rust

```rust
use phastft::planner::Direction;
use phastft::fft;
use phastft::{
planner::Direction,
fft_64
};

fn main() {
let big_n = 1 << 10;
let mut reals: Vec<f64> = (1..=big_n).map(|i| i as f64).collect();
let mut imags: Vec<f64> = (1..=big_n).map(|i| i as f64).collect();
fft(&mut reals, &mut imags, Direction::Forward);
fft_64(&mut reals, &mut imags, Direction::Forward);
}
```

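The README example now calls `fft_64` instead of the old `fft` entry point. Since the whole point of this PR is single-precision support, the `f32` path would presumably mirror it through an `fft_32` entry point; the sketch below is an assumption based on the `fft_64` / `fft_32_chunk_n_simd` naming visible in this diff, not a snippet taken from the changed README.

```rust
use phastft::{planner::Direction, fft_32};

fn main() {
    let big_n = 1 << 10;
    // Same shape as the f64 example above, but with single-precision buffers.
    let mut reals: Vec<f32> = (1..=big_n).map(|i| i as f32).collect();
    let mut imags: Vec<f32> = (1..=big_n).map(|i| i as f32).collect();
    // `fft_32` is assumed to be the f32 counterpart of `fft_64`.
    fft_32(&mut reals, &mut imags, Direction::Forward);
}
```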
8 changes: 4 additions & 4 deletions examples/benchmark.rs
@@ -3,17 +3,17 @@ use std::str::FromStr;

use utilities::gen_random_signal;

use phastft::fft;
use phastft::fft_64;
use phastft::planner::Direction;

fn benchmark_fft(n: usize) {
fn benchmark_fft_64(n: usize) {
let big_n = 1 << n;
let mut reals = vec![0.0; big_n];
let mut imags = vec![0.0; big_n];
gen_random_signal(&mut reals, &mut imags);

let now = std::time::Instant::now();
fft(&mut reals, &mut imags, Direction::Forward);
fft_64(&mut reals, &mut imags, Direction::Forward);
let elapsed = now.elapsed().as_micros();
println!("{elapsed}");
}
@@ -24,5 +24,5 @@ fn main() {

let n = usize::from_str(&args[1]).unwrap();

benchmark_fft(n);
benchmark_fft_64(n);
}
4 changes: 2 additions & 2 deletions examples/profile.rs
@@ -1,14 +1,14 @@
use std::env;
use std::str::FromStr;

use phastft::fft;
use phastft::fft_64;
use phastft::planner::Direction;

fn benchmark_fft(num_qubits: usize) {
let n = 1 << num_qubits;
let mut reals: Vec<f64> = (1..=n).map(|i| i as f64).collect();
let mut imags: Vec<f64> = (1..=n).map(|i| i as f64).collect();
fft(&mut reals, &mut imags, Direction::Forward);
fft_64(&mut reals, &mut imags, Direction::Forward);
}

fn main() {
44 changes: 3 additions & 41 deletions src/cobra.rs
@@ -13,7 +13,8 @@
//! Symposium on Foundations of Computer Science (Cat. No.98CB36280), Palo Alto, CA, USA, 1998, pp. 544-553, doi:
//! 10.1109/SFCS.1998.743505.
//! keywords: {Read-write memory;Costs;Computer science;Drives;Random access memory;Argon;Registers;Read only memory;Computational modeling;Libraries}
use crate::kernels::Float;
use num_traits::Float;

const BLOCK_WIDTH: usize = 128;
// size of the cacheline
@@ -78,7 +79,7 @@ pub(crate) fn bit_rev<T>(buf: &mut [T], log_n: usize) {
since = "0.1.0",
note = "Please use COBRA for a cache-optimal bit reverse permutation."
)]
fn complex_bit_rev(reals: &mut [Float], imags: &mut [Float], log_n: usize) {
fn complex_bit_rev<T: Float>(reals: &mut [T], imags: &mut [T], log_n: usize) {
let mut nodd: usize;
let mut noddrev; // to hold bitwise negated or odd values

@@ -129,25 +130,6 @@ fn complex_bit_rev(reals: &mut [Float], imags: &mut [Float], log_n: usize) {
}
}

/// Run in-place bit reversal on the entire state, in parallel.
/// This function uses 2 threads to run a bit reverse on the reals buffer on one thread, and the other thread handles
/// the bit reversal of the imaginaries buffer
#[allow(dead_code)]
#[deprecated(
since = "0.1.0",
note = "Naive bit reverse permutation is slow and not cache friendly. COBRA should be used instead."
)]
pub(crate) fn bit_reverse_permute_state_par(
reals: &mut [Float],
imags: &mut [Float],
log_n: usize,
) {
std::thread::scope(|s| {
s.spawn(|| bit_rev(reals, log_n));
s.spawn(|| bit_rev(imags, log_n));
});
}

#[allow(dead_code)]
#[deprecated(
since = "0.1.0",
@@ -333,26 +315,6 @@ mod tests {
}
}

#[test]
fn jennifer_method_parallel() {
for n in 2..24 {
let big_n = 1 << n;
let mut actual_re: Vec<f64> = (0..big_n).map(f64::from).collect();
let mut actual_im: Vec<f64> = (0..big_n).map(f64::from).collect();

#[allow(deprecated)]
bit_reverse_permute_state_par(&mut actual_re, &mut actual_im, n);

let input_re: Vec<f64> = (0..big_n).map(f64::from).collect();
let expected_re = top_down_bit_reverse_permutation(&input_re);
assert_eq!(actual_re, expected_re);

let input_im: Vec<f64> = (0..big_n).map(f64::from).collect();
let expected_im = top_down_bit_reverse_permutation(&input_im);
assert_eq!(actual_im, expected_im);
}
}

#[test]
fn naive_bit_reverse_permutation() {
for n in 2..24 {
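The cobra.rs changes above swap the crate-local `type Float = f64;` alias for the `num_traits::Float` trait (hence the new `num-traits` entry in Cargo.toml), so the bit-reversal routines now accept both `f32` and `f64` buffers. A minimal sketch of that pattern, using a hypothetical `normalize` helper rather than the real kernels:

```rust
use num_traits::Float;

// Before the change, a concrete alias pinned every signature to f64:
//     pub type Float = f64;
//     fn normalize(buf: &mut [Float]) { ... }
// With a `num_traits::Float` bound, one definition covers f32 and f64.
fn normalize<T: Float>(buf: &mut [T]) {
    // `T::zero()`, `sqrt`, and the arithmetic all come from the Float trait bound.
    let norm = buf.iter().fold(T::zero(), |acc, &x| acc + x * x).sqrt();
    if norm > T::zero() {
        buf.iter_mut().for_each(|x| *x = *x / norm);
    }
}

fn main() {
    let mut a: Vec<f64> = vec![3.0, 4.0];
    let mut b: Vec<f32> = vec![3.0, 4.0];
    normalize(&mut a); // monomorphized for f64
    normalize(&mut b); // monomorphized again for f32
    println!("{a:?} {b:?}"); // [0.6, 0.8] [0.6, 0.8]
}
```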
113 changes: 59 additions & 54 deletions src/kernels.rs
@@ -1,56 +1,61 @@
use std::simd::f64x8;

pub type Float = f64;

pub(crate) fn fft_chunk_n_simd(
reals: &mut [Float],
imags: &mut [Float],
twiddles_re: &[Float],
twiddles_im: &[Float],
dist: usize,
) {
let chunk_size = dist << 1;
assert!(chunk_size >= 16);

reals
.chunks_exact_mut(chunk_size)
.zip(imags.chunks_exact_mut(chunk_size))
.for_each(|(reals_chunk, imags_chunk)| {
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist);
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist);

reals_s0
.chunks_exact_mut(8)
.zip(reals_s1.chunks_exact_mut(8))
.zip(imags_s0.chunks_exact_mut(8))
.zip(imags_s1.chunks_exact_mut(8))
.zip(twiddles_re.chunks_exact(8))
.zip(twiddles_im.chunks_exact(8))
.for_each(|(((((re_s0, re_s1), im_s0), im_s1), w_re), w_im)| {
let real_c0 = f64x8::from_slice(re_s0);
let real_c1 = f64x8::from_slice(re_s1);
let imag_c0 = f64x8::from_slice(im_s0);
let imag_c1 = f64x8::from_slice(im_s1);

let tw_re = f64x8::from_slice(w_re);
let tw_im = f64x8::from_slice(w_im);

re_s0.copy_from_slice((real_c0 + real_c1).as_array());
im_s0.copy_from_slice((imag_c0 + imag_c1).as_array());
let v_re = real_c0 - real_c1;
let v_im = imag_c0 - imag_c1;
re_s1.copy_from_slice((v_re * tw_re - v_im * tw_im).as_array());
im_s1.copy_from_slice((v_re * tw_im + v_im * tw_re).as_array());
use std::simd::{f32x16, f64x8};

use num_traits::Float;

macro_rules! fft_butterfly_n_simd {
($func_name:ident, $precision:ty, $lanes:literal, $simd_vector:ty) => {
pub fn $func_name(
reals: &mut [$precision],
imags: &mut [$precision],
twiddles_re: &[$precision],
twiddles_im: &[$precision],
dist: usize,
) {
let chunk_size = dist << 1;
assert!(chunk_size >= $lanes * 2);
reals
.chunks_exact_mut(chunk_size)
.zip(imags.chunks_exact_mut(chunk_size))
.for_each(|(reals_chunk, imags_chunk)| {
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist);
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist);

reals_s0
.chunks_exact_mut($lanes)
.zip(reals_s1.chunks_exact_mut($lanes))
.zip(imags_s0.chunks_exact_mut($lanes))
.zip(imags_s1.chunks_exact_mut($lanes))
.zip(twiddles_re.chunks_exact($lanes))
.zip(twiddles_im.chunks_exact($lanes))
.for_each(|(((((re_s0, re_s1), im_s0), im_s1), w_re), w_im)| {
let real_c0 = <$simd_vector>::from_slice(re_s0);
let real_c1 = <$simd_vector>::from_slice(re_s1);
let imag_c0 = <$simd_vector>::from_slice(im_s0);
let imag_c1 = <$simd_vector>::from_slice(im_s1);

let tw_re = <$simd_vector>::from_slice(w_re);
let tw_im = <$simd_vector>::from_slice(w_im);

re_s0.copy_from_slice((real_c0 + real_c1).as_array());
im_s0.copy_from_slice((imag_c0 + imag_c1).as_array());
let v_re = real_c0 - real_c1;
let v_im = imag_c0 - imag_c1;
re_s1.copy_from_slice((v_re * tw_re - v_im * tw_im).as_array());
im_s1.copy_from_slice((v_re * tw_im + v_im * tw_re).as_array());
});
});
});
}
};
}

// TODO(saveliy): parallelize
pub(crate) fn fft_chunk_n(
reals: &mut [Float],
imags: &mut [Float],
twiddles_re: &[Float],
twiddles_im: &[Float],
fft_butterfly_n_simd!(fft_64_chunk_n_simd, f64, 8, f64x8);
fft_butterfly_n_simd!(fft_32_chunk_n_simd, f32, 16, f32x16);

pub(crate) fn fft_chunk_n<T: Float>(
reals: &mut [T],
imags: &mut [T],
twiddles_re: &[T],
twiddles_im: &[T],
dist: usize,
) {
let chunk_size = dist << 1;
@@ -79,14 +84,14 @@ pub(crate) fn fft_chunk_n(
*im_s0 = imag_c0 + imag_c1;
let v_re = real_c0 - real_c1;
let v_im = imag_c0 - imag_c1;
*re_s1 = v_re * w_re - v_im * w_im;
*im_s1 = v_re * w_im + v_im * w_re;
*re_s1 = v_re * *w_re - v_im * *w_im;
*im_s1 = v_re * *w_im + v_im * *w_re;
});
});
}

/// `chunk_size == 4`, so hard code twiddle factors
pub(crate) fn fft_chunk_4(reals: &mut [Float], imags: &mut [Float]) {
pub(crate) fn fft_chunk_4<T: Float>(reals: &mut [T], imags: &mut [T]) {
let dist = 2;
let chunk_size = dist << 1;

@@ -120,7 +125,7 @@ pub(crate) fn fft_chunk_4(reals: &mut [Float], imags: &mut [Float]) {
}

/// `chunk_size == 2`, so skip phase
pub(crate) fn fft_chunk_2(reals: &mut [Float], imags: &mut [Float]) {
pub(crate) fn fft_chunk_2<T: Float>(reals: &mut [T], imags: &mut [T]) {
reals
.chunks_exact_mut(2)
.zip(imags.chunks_exact_mut(2))
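In kernels.rs, the concrete SIMD butterfly is replaced by the `fft_butterfly_n_simd!` macro, instantiated twice to produce `fft_64_chunk_n_simd` (8 lanes of `f64x8`) and `fft_32_chunk_n_simd` (16 lanes of `f32x16`), while the scalar kernels become generic over `num_traits::Float`. The explicit dereferences that appear in the scalar path (`*w_re`, `*w_im`) are presumably needed because a generic `T` only implements `Mul<T>`, unlike the concrete `f64`, which also multiplies by `&f64`. For reference, a scalar sketch (hypothetical names) of the radix-2 decimation-in-time butterfly all of these kernels compute:

```rust
/// One radix-2 DIT butterfly: `a` and `b` are paired (re, im) samples,
/// `w` is the twiddle factor. Mirrors the arithmetic in `fft_chunk_n` above.
fn butterfly(a: (f64, f64), b: (f64, f64), w: (f64, f64)) -> ((f64, f64), (f64, f64)) {
    // The sum is written back to the first half unchanged.
    let top = (a.0 + b.0, a.1 + b.1);
    // The difference is rotated by the twiddle: (v_re + i*v_im) * (w_re + i*w_im).
    let v = (a.0 - b.0, a.1 - b.1);
    let bottom = (v.0 * w.0 - v.1 * w.1, v.0 * w.1 + v.1 * w.0);
    (top, bottom)
}

fn main() {
    // With w = 1 + 0i this reduces to a plain sum/difference pair.
    let (top, bottom) = butterfly((1.0, 0.0), (2.0, 0.0), (1.0, 0.0));
    assert_eq!(top, (3.0, 0.0));
    assert_eq!(bottom, (-1.0, 0.0));
}
```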
