-
Notifications
You must be signed in to change notification settings - Fork 403
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2782 from spectre-ns/fft
Pure xtensor FFT implementation
- Loading branch information
Showing
8 changed files
with
349 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,3 +62,4 @@ __pycache__ | |
|
||
# Generated files | ||
*.pc | ||
.vscode/settings.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
.. Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht | ||
Distributed under the terms of the BSD 3-Clause License. | ||
The full license is in the file LICENSE, distributed with this software. | ||
xfft | ||
==== | ||
|
||
Defined in ``xtensor/xfft.hpp`` | ||
|
||
.. doxygenclass:: xt::fft::convolve | ||
:project: xtensor | ||
:members: | ||
|
||
.. doxygentypedef:: xt::fft::fft | ||
:project: xtensor | ||
|
||
.. doxygentypedef:: xt::fft::ifft | ||
:project: xtensor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,241 @@ | ||
#ifdef XTENSOR_USE_TBB | ||
#include <oneapi/tbb.h> | ||
#endif | ||
#include <stdexcept> | ||
|
||
#include <xtl/xcomplex.hpp> | ||
|
||
#include <xtensor/xarray.hpp> | ||
#include <xtensor/xaxis_slice_iterator.hpp> | ||
#include <xtensor/xbuilder.hpp> | ||
#include <xtensor/xcomplex.hpp> | ||
#include <xtensor/xmath.hpp> | ||
#include <xtensor/xnoalias.hpp> | ||
#include <xtensor/xview.hpp> | ||
|
||
namespace xt | ||
{ | ||
namespace fft | ||
{ | ||
namespace detail | ||
{ | ||
template < | ||
class E, | ||
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true> | ||
inline auto radix2(E&& e) | ||
{ | ||
using namespace xt::placeholders; | ||
using namespace std::complex_literals; | ||
using value_type = typename std::decay_t<E>::value_type; | ||
using precision = typename value_type::value_type; | ||
auto N = e.size(); | ||
const bool powerOfTwo = !(N == 0) && !(N & (N - 1)); | ||
// check for power of 2 | ||
if (!powerOfTwo || N == 0) | ||
{ | ||
// TODO: Replace implementation with dft | ||
XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2"); | ||
} | ||
auto pi = xt::numeric_constants<precision>::PI; | ||
xt::xtensor<value_type, 1> ev = e; | ||
if (N <= 1) | ||
{ | ||
return ev; | ||
} | ||
else | ||
{ | ||
#ifdef XTENSOR_USE_TBB | ||
xt::xtensor<value_type, 1> even; | ||
xt::xtensor<value_type, 1> odd; | ||
oneapi::tbb::parallel_invoke( | ||
[&] | ||
{ | ||
even = radix2(xt::view(ev, xt::range(0, _, 2))); | ||
}, | ||
[&] | ||
{ | ||
odd = radix2(xt::view(ev, xt::range(1, _, 2))); | ||
} | ||
); | ||
#else | ||
auto even = radix2(xt::view(ev, xt::range(0, _, 2))); | ||
auto odd = radix2(xt::view(ev, xt::range(1, _, 2))); | ||
#endif | ||
|
||
auto range = xt::arange<double>(N / 2); | ||
auto exp = xt::exp(static_cast<value_type>(-2i) * pi * range / N); | ||
auto t = exp * odd; | ||
auto first_half = even + t; | ||
auto second_half = even - t; | ||
// TODO: should be a call to stack if performance was improved | ||
auto spectrum = xt::xtensor<value_type, 1>::from_shape({N}); | ||
xt::view(spectrum, xt::range(0, N / 2)) = first_half; | ||
xt::view(spectrum, xt::range(N / 2, N)) = second_half; | ||
return spectrum; | ||
} | ||
} | ||
|
||
template <typename E> | ||
auto transform_bluestein(E&& data) | ||
{ | ||
using value_type = typename std::decay_t<E>::value_type; | ||
using precision = typename value_type::value_type; | ||
|
||
// Find a power-of-2 convolution length m such that m >= n * 2 + 1 | ||
const std::size_t n = data.size(); | ||
size_t m = std::ceil(std::log2(n * 2 + 1)); | ||
m = std::pow(2, m); | ||
|
||
// Trignometric table | ||
auto exp_table = xt::xtensor<std::complex<precision>, 1>::from_shape({n}); | ||
xt::xtensor<std::size_t, 1> i = xt::pow(xt::linspace<std::size_t>(0, n - 1, n), 2); | ||
i %= (n * 2); | ||
|
||
auto angles = xt::eval(precision{3.141592653589793238463} * i / n); | ||
auto j = std::complex<precision>(0, 1); | ||
exp_table = xt::exp(-angles * j); | ||
|
||
// Temporary vectors and preprocessing | ||
auto av = xt::empty<std::complex<precision>>({m}); | ||
xt::view(av, xt::range(0, n)) = data * exp_table; | ||
|
||
|
||
auto bv = xt::empty<std::complex<precision>>({m}); | ||
xt::view(bv, xt::range(0, n)) = ::xt::conj(exp_table); | ||
xt::view(bv, xt::range(-n + 1, xt::placeholders::_)) = xt::view( | ||
::xt::conj(xt::flip(exp_table)), | ||
xt::range(xt::placeholders::_, -1) | ||
); | ||
|
||
// Convolution | ||
auto xv = radix2(av); | ||
auto yv = radix2(bv); | ||
auto spectrum_k = xv * yv; | ||
auto complex_args = xt::conj(spectrum_k); | ||
auto fft_res = radix2(complex_args); | ||
auto cv = xt::conj(fft_res) / m; | ||
|
||
return xt::eval(xt::view(cv, xt::range(0, n)) * exp_table); | ||
} | ||
} // namespace detail | ||
|
||
/** | ||
* @brief 1D FFT of an Nd array along a specified axis | ||
* @param e an Nd expression to be transformed to the fourier domain | ||
* @param axis the axis along which to perform the 1D FFT | ||
* @return a transformed xarray of the specified precision | ||
*/ | ||
template < | ||
class E, | ||
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true> | ||
inline auto fft(E&& e, std::ptrdiff_t axis = -1) | ||
{ | ||
using value_type = typename std::decay_t<E>::value_type; | ||
using precision = typename value_type::value_type; | ||
const auto saxis = xt::normalize_axis(e.dimension(), axis); | ||
const size_t N = e.shape(saxis); | ||
const bool powerOfTwo = !(N == 0) && !(N & (N - 1)); | ||
xt::xarray<std::complex<precision>> out = xt::eval(e); | ||
auto begin = xt::axis_slice_begin(out, saxis); | ||
auto end = xt::axis_slice_end(out, saxis); | ||
for (auto iter = begin; iter != end; iter++) | ||
{ | ||
if (powerOfTwo) | ||
{ | ||
xt::noalias(*iter) = detail::radix2(*iter); | ||
} | ||
else | ||
{ | ||
xt::noalias(*iter) = detail::transform_bluestein(*iter); | ||
} | ||
} | ||
return out; | ||
} | ||
|
||
/** | ||
* @brief 1D FFT of an Nd array along a specified axis | ||
* @param e an Nd expression to be transformed to the fourier domain | ||
* @param axis the axis along which to perform the 1D FFT | ||
* @return a transformed xarray of the specified precision | ||
*/ | ||
template < | ||
class E, | ||
typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true> | ||
inline auto fft(E&& e, std::ptrdiff_t axis = -1) | ||
{ | ||
using value_type = typename std::decay<E>::type::value_type; | ||
return fft(xt::cast<std::complex<value_type>>(e), axis); | ||
} | ||
|
||
template < | ||
class E, | ||
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true> | ||
auto ifft(E&& e, std::ptrdiff_t axis = -1) | ||
{ | ||
// check the length of the data on that axis | ||
const std::size_t n = e.shape(axis); | ||
if (n == 0) | ||
{ | ||
XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention"); | ||
} | ||
auto complex_args = xt::conj(e); | ||
auto fft_res = xt::fft::fft(complex_args, axis); | ||
fft_res = xt::conj(fft_res); | ||
return fft_res; | ||
} | ||
|
||
template < | ||
class E, | ||
typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true> | ||
inline auto ifft(E&& e, std::ptrdiff_t axis = -1) | ||
{ | ||
using value_type = typename std::decay<E>::type::value_type; | ||
return ifft(xt::cast<std::complex<value_type>>(e), axis); | ||
} | ||
|
||
/* | ||
* @brief performs a circular fft convolution xvec and yvec must | ||
* be the same shape. | ||
* @param xvec first array of the convolution | ||
* @param yvec second array of the convolution | ||
* @param axis axis along which to perform the convolution | ||
*/ | ||
template <typename E1, typename E2> | ||
auto convolve(E1&& xvec, E2&& yvec, std::ptrdiff_t axis = -1) | ||
{ | ||
// we could broadcast but that could get complicated??? | ||
if (xvec.dimension() != yvec.dimension()) | ||
{ | ||
XTENSOR_THROW(std::runtime_error, "Mismatched dimentions"); | ||
} | ||
|
||
auto saxis = xt::normalize_axis(xvec.dimension(), axis); | ||
if (xvec.shape(saxis) != yvec.shape(saxis)) | ||
{ | ||
XTENSOR_THROW(std::runtime_error, "Mismatched lengths along slice axis"); | ||
} | ||
|
||
const std::size_t n = xvec.shape(saxis); | ||
|
||
auto xv = fft(xvec, axis); | ||
auto yv = fft(yvec, axis); | ||
|
||
auto begin_x = xt::axis_slice_begin(xv, saxis); | ||
auto end_x = xt::axis_slice_end(xv, saxis); | ||
auto iter_y = xt::axis_slice_begin(yv, saxis); | ||
|
||
for (auto iter = begin_x; iter != end_x; iter++) | ||
{ | ||
(*iter) = (*iter_y++) * (*iter); | ||
} | ||
|
||
auto outvec = ifft(xv, axis); | ||
|
||
// Scaling (because this FFT implementation omits it) | ||
outvec = outvec / n; | ||
|
||
return outvec; | ||
} | ||
|
||
} | ||
} // namespace xt::fft |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#include "xtensor/xarray.hpp" | ||
#include "xtensor/xfft.hpp" | ||
|
||
#include "test_common_macros.hpp" | ||
|
||
namespace xt | ||
{ | ||
TEST(xfft, fft_power_2) | ||
{ | ||
size_t k = 2; | ||
size_t n = 8192; | ||
size_t A = 10; | ||
auto x = xt::linspace<float>(0, static_cast<float>(n - 1), n); | ||
xt::xarray<float> y = A * xt::sin(2 * xt::numeric_constants<float>::PI * x * k / n); | ||
auto res = xt::fft::fft(y) / (n / 2); | ||
REQUIRE(A == doctest::Approx(std::abs(res(k))).epsilon(.0001)); | ||
} | ||
|
||
TEST(xfft, ifft_power_2) | ||
{ | ||
size_t k = 2; | ||
size_t n = 8; | ||
size_t A = 10; | ||
auto x = xt::linspace<float>(0, static_cast<float>(n - 1), n); | ||
xt::xarray<float> y = A * xt::sin(2 * xt::numeric_constants<float>::PI * x * k / n); | ||
auto res = xt::fft::ifft(y) / (n / 2); | ||
REQUIRE(A == doctest::Approx(std::abs(res(k))).epsilon(.0001)); | ||
} | ||
|
||
TEST(xfft, convolve_power_2) | ||
{ | ||
xt::xarray<float> x = {1.0, 1.0, 1.0, 5.0}; | ||
xt::xarray<float> y = {5.0, 1.0, 1.0, 1.0}; | ||
xt::xarray<float> expected = {12, 12, 12, 28}; | ||
|
||
auto result = xt::fft::convolve(x, y); | ||
|
||
for (size_t i = 0; i < x.size(); i++) | ||
{ | ||
REQUIRE(expected(i) == doctest::Approx(std::abs(result(i))).epsilon(.0001)); | ||
} | ||
} | ||
|
||
TEST(xfft, fft_n_0_axis) | ||
{ | ||
size_t k = 2; | ||
size_t n = 10; | ||
size_t A = 1; | ||
size_t dim = 10; | ||
auto x = xt::linspace<float>(0, n - 1, n) * xt::ones<float>({dim, n}); | ||
xt::xarray<float> y = A * xt::sin(2 * xt::numeric_constants<float>::PI * x * k / n); | ||
y = xt::transpose(y); | ||
auto res = xt::fft::fft(y, 0) / (n / 2.0); | ||
REQUIRE(A == doctest::Approx(std::abs(res(k, 0))).epsilon(.0001)); | ||
REQUIRE(A == doctest::Approx(std::abs(res(k, 1))).epsilon(.0001)); | ||
} | ||
|
||
TEST(xfft, fft_n_1_axis) | ||
{ | ||
size_t k = 2; | ||
size_t n = 15; | ||
size_t A = 1; | ||
size_t dim = 2; | ||
auto x = xt::linspace<float>(0, n - 1, n) * xt::ones<float>({dim, n}); | ||
xt::xarray<float> y = A * xt::sin(2 * xt::numeric_constants<float>::PI * x * k / n); | ||
auto res = xt::fft::fft(y) / (n / 2.0); | ||
REQUIRE(A == doctest::Approx(std::abs(res(0, k))).epsilon(.0001)); | ||
REQUIRE(A == doctest::Approx(std::abs(res(1, k))).epsilon(.0001)); | ||
} | ||
|
||
TEST(xfft, convolve_n) | ||
{ | ||
xt::xarray<float> x = {1.0, 1.0, 1.0, 5.0, 1.0}; | ||
xt::xarray<float> y = {5.0, 1.0, 1.0, 1.0, 1.0}; | ||
xt::xarray<size_t> expected = {13, 13, 13, 29, 13}; | ||
|
||
auto result = xt::fft::convolve(x, y); | ||
|
||
xt::xarray<float> abs = xt::abs(result); | ||
|
||
for (size_t i = 0; i < abs.size(); i++) | ||
{ | ||
REQUIRE(expected(i) == doctest::Approx(abs(i)).epsilon(.0001)); | ||
} | ||
} | ||
} |