Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FFT: Add batch support #4327

Merged
merged 1 commit into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions Docs/sphinx_documentation/source/FFT.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,32 @@ object. Therefore, one should cache it for reuse if possible. Although
:cpp:`std::unique_ptr<FFT::R2C<Real>>` to store an object in one's class.


Class template `FFT::R2C` also supports batched FFTs. The batch size is set
in an :cpp:`FFT::Info` object passed to the constructor of
:cpp:`FFT::R2C`. Below is an example.

.. highlight:: c++

::

int batch_size = 10;
Geometry geom(...);
MultiFab mf(ba, dm, batch_size, 0);

FFT::Info info{};
info.setBatchSize(batch_size));
FFT::R2C<Real,FFT::Direction::both> r2c(geom.Domain(), info);

auto const& [cba, cdm] = r2c.getSpectralDataLayout();
cMultiFab cmf(cba, cdm, batch_size, 0);

r2c.forward(mf, cmf);

// Do work on cmf.
// Function forwardThenBackward is not yet supported for a batched FFT.

r2c.backward(cmf, mf);

.. _sec:FFT:localr2c:

FFT::LocalR2C Class
Expand Down
99 changes: 59 additions & 40 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace amrex::FFT

enum struct Direction { forward, backward, both, none };

enum struct DomainStrategy { slab, pencil };
enum struct DomainStrategy { automatic, slab, pencil };

AMREX_ENUM( Boundary, periodic, even, odd );

Expand All @@ -56,15 +56,28 @@ enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b,

struct Info
{
//! Supported only in 3D. When batch_mode is true, FFT is performed on
//! Domain composition strategy.
DomainStrategy domain_strategy = DomainStrategy::automatic;

//! For automatic strategy, this is the size per process below which we
//! switch from slab to pencil.
int pencil_threshold = 8;

//! Supported only in 3D. When twod_mode is true, FFT is performed on
//! the first two dimensions only and the third dimension size is the
//! batch size.
bool batch_mode = false;
bool twod_mode = false;

//! Batched FFT size. Only support in R2C, not R2X.
int batch_size = 1;

//! Max number of processes to use
int nprocs = std::numeric_limits<int>::max();

Info& setBatchMode (bool x) { batch_mode = x; return *this; }
Info& setDomainStrategy (DomainStrategy s) { domain_strategy = s; return *this; }
Info& setPencilThreshold (int t) { pencil_threshold = t; return *this; }
Info& setTwoDMode (bool x) { twod_mode = x; return *this; }
Info& setBatchSize (int bsize) { batch_size = bsize; return *this; }
Info& setNumProcs (int n) { nprocs = n; return *this; }
};

Expand Down Expand Up @@ -170,7 +183,7 @@ struct Plan
}

template <Direction D>
void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false)
void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false, int ncomp = 1)
{
static_assert(D == Direction::forward || D == Direction::backward);

Expand Down Expand Up @@ -198,6 +211,7 @@ struct Plan
howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
: AMREX_D_TERM(1, *1 , *box.length(2));
#endif
howmany *= ncomp;

amrex::ignore_unused(nc);

Expand Down Expand Up @@ -293,10 +307,10 @@ struct Plan
}

template <Direction D, int M>
void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache);
void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache, int ncomp = 1);

template <Direction D>
void init_c2c (Box const& box, VendorComplex* p)
void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1)
{
static_assert(D == Direction::forward || D == Direction::backward);

Expand All @@ -307,6 +321,7 @@ struct Plan

n = box.length(0);
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
howmany *= ncomp;

#if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
Expand Down Expand Up @@ -1131,7 +1146,7 @@ struct Plan
}
};

using Key = std::tuple<IntVectND<3>,Direction,Kind>;
using Key = std::tuple<IntVectND<3>,int,Direction,Kind>;
using PlanD = typename Plan<double>::VendorPlan;
using PlanF = typename Plan<float>::VendorPlan;

Expand All @@ -1143,7 +1158,7 @@ void add_vendor_plan_f (Key const& key, PlanF plan);

template <typename T>
template <Direction D, int M>
void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache)
void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache, int ncomp)
{
static_assert(D == Direction::forward || D == Direction::backward);

Expand All @@ -1154,10 +1169,10 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool

n = 1;
for (auto s : fft_size) { n *= s; }
howmany = 1;
howmany = ncomp;

#if defined(AMREX_USE_GPU)
Key key = {fft_size.template expand<3>(), D, kind};
Key key = {fft_size.template expand<3>(), ncomp, D, kind};
if (cache) {
VendorPlan* cached_plan = nullptr;
if constexpr (std::is_same_v<float,T>) {
Expand All @@ -1174,27 +1189,34 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
amrex::ignore_unused(cache);
#endif

int len[M];
for (int i = 0; i < M; ++i) {
len[i] = fft_size[M-1-i];
}

int nc = fft_size[0]/2+1;
for (int i = 1; i < M; ++i) {
nc *= fft_size[i];
}

#if defined(AMREX_USE_CUDA)

AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
cufftType type;
int n_in, n_out;
if constexpr (D == Direction::forward) {
type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
n_in = n;
n_out = nc;
} else {
type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
n_in = nc;
n_out = n;
}
std::size_t work_size;
if constexpr (M == 1) {
AMREX_CUFFT_SAFE_CALL
(cufftMakePlan1d(plan, fft_size[0], type, howmany, &work_size));
} else if constexpr (M == 2) {
AMREX_CUFFT_SAFE_CALL
(cufftMakePlan2d(plan, fft_size[1], fft_size[0], type, &work_size));
} else if constexpr (M == 3) {
AMREX_CUFFT_SAFE_CALL
(cufftMakePlan3d(plan, fft_size[2], fft_size[1], fft_size[0], type, &work_size));
}
AMREX_CUFFT_SAFE_CALL
(cufftMakePlanMany(plan, M, len, nullptr, 1, n_in, nullptr, 1, n_out, type, howmany, &work_size));

#elif defined(AMREX_USE_HIP)

Expand All @@ -1219,19 +1241,21 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
if (M == 1) {
pp = new mkl_desc_r(fft_size[0]);
} else {
std::vector<std::int64_t> len(M);
std::vector<std::int64_t> len64(M);
for (int idim = 0; idim < M; ++idim) {
len[idim] = fft_size[M-1-idim];
len64[idim] = len[idim];
}
pp = new mkl_desc_r(len);
pp = new mkl_desc_r(len64);
}
#ifndef AMREX_USE_MKL_DFTI_2024
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
#else
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
#endif

pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
std::vector<std::int64_t> strides(M+1);
strides[0] = 0;
strides[M] = 1;
Expand All @@ -1258,29 +1282,24 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
return;
}

int size_for_row_major[M];
for (int idim = 0; idim < M; ++idim) {
size_for_row_major[idim] = fft_size[M-1-idim];
}

if constexpr (std::is_same_v<float,T>) {
if constexpr (D == Direction::forward) {
plan = fftwf_plan_dft_r2c
(M, size_for_row_major, (float*)pf, (fftwf_complex*)pb,
plan = fftwf_plan_many_dft_r2c
(M, len, howmany, (float*)pf, nullptr, 1, n, (fftwf_complex*)pb, nullptr, 1, nc,
FFTW_ESTIMATE);
} else {
plan = fftwf_plan_dft_c2r
(M, size_for_row_major, (fftwf_complex*)pb, (float*)pf,
plan = fftwf_plan_many_dft_c2r
(M, len, howmany, (fftwf_complex*)pb, nullptr, 1, nc, (float*)pf, nullptr, 1, n,
FFTW_ESTIMATE);
}
} else {
if constexpr (D == Direction::forward) {
plan = fftw_plan_dft_r2c
(M, size_for_row_major, (double*)pf, (fftw_complex*)pb,
plan = fftw_plan_many_dft_r2c
(M, len, howmany, (double*)pf, nullptr, 1, n, (fftw_complex*)pb, nullptr, 1, nc,
FFTW_ESTIMATE);
} else {
plan = fftw_plan_dft_c2r
(M, size_for_row_major, (fftw_complex*)pb, (double*)pf,
plan = fftw_plan_many_dft_c2r
(M, len, howmany, (fftw_complex*)pb, nullptr, 1, nc, (double*)pf, nullptr, 1, n,
FFTW_ESTIMATE);
}
}
Expand Down Expand Up @@ -1508,10 +1527,10 @@ namespace detail
b = make_box(b);
}
auto const& ng = make_iv(mf.nGrowVect());
FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), 1, ng, MFInfo{}.SetAlloc(false));
FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), mf.nComp(), ng, MFInfo{}.SetAlloc(false));
using FAB = typename FA::fab_type;
for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) {
submf.setFab(mfi, FAB(mfi.fabbox(), 1, mf[mfi].dataPtr()));
submf.setFab(mfi, FAB(mfi.fabbox(), mf.nComp(), mf[mfi].dataPtr()));
}
return submf;
}
Expand Down
23 changes: 12 additions & 11 deletions Src/FFT/AMReX_FFT_OpenBCSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Box OpenBCSolver<T>::make_grown_domain (Box const& domain, Info const& info)
{
IntVect len = domain.length();
#if (AMREX_SPACEDIM == 3)
if (info.batch_mode) { len[2] = 0; }
if (info.twod_mode) { len[2] = 0; }
#else
amrex::ignore_unused(info);
#endif
Expand All @@ -48,18 +48,19 @@ template <typename T>
OpenBCSolver<T>::OpenBCSolver (Box const& domain, Info const& info)
: m_domain(domain),
m_info(info),
m_r2c(OpenBCSolver<T>::make_grown_domain(domain,info), info)
m_r2c(OpenBCSolver<T>::make_grown_domain(domain,info),
m_info.setDomainStrategy(FFT::DomainStrategy::slab))
{
#if (AMREX_SPACEDIM == 3)
if (m_info.batch_mode) {
if (m_info.twod_mode) {
auto gdom = make_grown_domain(domain,m_info);
gdom.enclosedCells(2);
gdom.setSmall(2, 0);
int nprocs = std::min({ParallelContext::NProcsSub(),
m_info.nprocs,
m_domain.length(2)});
gdom.setBig(2, nprocs-1);
m_r2c_green = std::make_unique<R2C<T>>(gdom,info);
m_r2c_green = std::make_unique<R2C<T>>(gdom,m_info);
auto [sd, ord] = m_r2c_green->getSpectralData();
m_G_fft = cMF(*sd, amrex::make_alias, 0, 1);
} else
Expand All @@ -78,7 +79,7 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
{
BL_PROFILE("OpenBCSolver::setGreensFunction");

auto* infab = m_info.batch_mode ? detail::get_fab(m_r2c_green->m_rx)
auto* infab = m_info.twod_mode ? detail::get_fab(m_r2c_green->m_rx)
: detail::get_fab(m_r2c.m_rx);
auto const& lo = m_domain.smallEnd();
auto const& lo3 = lo.dim3();
Expand All @@ -87,7 +88,7 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
auto const& a = infab->array();
auto box = infab->box();
GpuArray<int,3> nimages{1,1,1};
int ndims = m_info.batch_mode ? AMREX_SPACEDIM-1 : AMREX_SPACEDIM;
int ndims = m_info.twod_mode ? AMREX_SPACEDIM-1 : AMREX_SPACEDIM;
for (int idim = 0; idim < ndims; ++idim) {
if (box.smallEnd(idim) == lo[idim] && box.length(idim) == 2*len[idim]) {
box.growHi(idim, -len[idim]+1); // +1 to include the middle plane
Expand Down Expand Up @@ -129,13 +130,13 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
});
}

if (m_info.batch_mode) {
if (m_info.twod_mode) {
m_r2c_green->forward(m_r2c_green->m_rx);
} else {
m_r2c.forward(m_r2c.m_rx);
}

if (!m_info.batch_mode) {
if (!m_info.twod_mode) {
auto [sd, ord] = m_r2c.getSpectralData();
amrex::ignore_unused(ord);
auto const* srcfab = detail::get_fab(*sd);
Expand Down Expand Up @@ -166,7 +167,7 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
inmf.setVal(T(0));
inmf.ParallelCopy(rho, 0, 0, 1);

m_r2c.m_openbc_half = !m_info.batch_mode;
m_r2c.m_openbc_half = !m_info.twod_mode;
m_r2c.forward(inmf);
m_r2c.m_openbc_half = false;

Expand All @@ -183,7 +184,7 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
Box const& rhobox = rhofab->box();
#if (AMREX_SPACEDIM == 3)
Long leng = gfab->box().numPts();
if (m_info.batch_mode) {
if (m_info.twod_mode) {
AMREX_ASSERT(gfab->box().length(2) == 1 &&
leng == (rhobox.length(0) * rhobox.length(1)));
} else {
Expand All @@ -204,7 +205,7 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
}
}

m_r2c.m_openbc_half = !m_info.batch_mode;
m_r2c.m_openbc_half = !m_info.twod_mode;
m_r2c.backward_doit(phi, phi.nGrowVect());
m_r2c.m_openbc_half = false;
}
Expand Down
4 changes: 2 additions & 2 deletions Src/FFT/AMReX_FFT_Poisson.H
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public:
}
}
Info info{};
info.setBatchMode(true);
info.setTwoDMode(true);
if (periodic_xy) {
m_r2c = std::make_unique<R2C<typename MF::value_type>>(m_geom.Domain(),
info);
Expand All @@ -145,7 +145,7 @@ public:
std::make_pair(Boundary::periodic,Boundary::periodic),
std::make_pair(Boundary::even,Boundary::even))},
m_r2c(std::make_unique<R2C<typename MF::value_type>>
(geom.Domain(), Info().setBatchMode(true)))
(geom.Domain(), Info().setTwoDMode(true)))
{
#if (AMREX_SPACEDIM == 3)
AMREX_ALWAYS_ASSERT(geom.isPeriodic(0) && geom.isPeriodic(1));
Expand Down
Loading
Loading