From d911d4cf3a136605fe1536a477067afa3ddecf66 Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Fri, 7 Feb 2025 19:26:48 -0800 Subject: [PATCH] FFT: Add batch support --- Docs/sphinx_documentation/source/FFT.rst | 26 +++ Src/FFT/AMReX_FFT_Helper.H | 99 +++++---- Src/FFT/AMReX_FFT_OpenBCSolver.H | 23 +- Src/FFT/AMReX_FFT_Poisson.H | 4 +- Src/FFT/AMReX_FFT_R2C.H | 263 +++++++++++++---------- Src/FFT/AMReX_FFT_R2X.H | 24 +-- Tests/FFT/Batch/CMakeLists.txt | 10 + Tests/FFT/Batch/GNUmakefile | 26 +++ Tests/FFT/Batch/Make.package | 1 + Tests/FFT/Batch/main.cpp | 167 ++++++++++++++ Tests/FFT/R2C/main.cpp | 9 +- 11 files changed, 471 insertions(+), 181 deletions(-) create mode 100644 Tests/FFT/Batch/CMakeLists.txt create mode 100644 Tests/FFT/Batch/GNUmakefile create mode 100644 Tests/FFT/Batch/Make.package create mode 100644 Tests/FFT/Batch/main.cpp diff --git a/Docs/sphinx_documentation/source/FFT.rst b/Docs/sphinx_documentation/source/FFT.rst index 2a5957e40bc..27dbb9ec2b7 100644 --- a/Docs/sphinx_documentation/source/FFT.rst +++ b/Docs/sphinx_documentation/source/FFT.rst @@ -67,6 +67,32 @@ object. Therefore, one should cache it for reuse if possible. Although :cpp:`std::unique_ptr>` to store an object in one's class. +Class template `FFT::R2C` also supports batched FFTs. The batch size is set +in an :cpp:`FFT::Info` object passed to the constructor of +:cpp:`FFT::R2C`. Below is an example. + +.. highlight:: c++ + +:: + + int batch_size = 10; + Geometry geom(...); + MultiFab mf(ba, dm, batch_size, 0); + + FFT::Info info{}; + info.setBatchSize(batch_size); + FFT::R2C r2c(geom.Domain(), info); + + auto const& [cba, cdm] = r2c.getSpectralDataLayout(); + cMultiFab cmf(cba, cdm, batch_size, 0); + + r2c.forward(mf, cmf); + + // Do work on cmf. + // Function forwardThenBackward is not yet supported for a batched FFT. + + r2c.backward(cmf, mf); + .. 
_sec:FFT:localr2c: FFT::LocalR2C Class diff --git a/Src/FFT/AMReX_FFT_Helper.H b/Src/FFT/AMReX_FFT_Helper.H index a0783dfac5d..1c41cbffff7 100644 --- a/Src/FFT/AMReX_FFT_Helper.H +++ b/Src/FFT/AMReX_FFT_Helper.H @@ -47,7 +47,7 @@ namespace amrex::FFT enum struct Direction { forward, backward, both, none }; -enum struct DomainStrategy { slab, pencil }; +enum struct DomainStrategy { automatic, slab, pencil }; AMREX_ENUM( Boundary, periodic, even, odd ); @@ -56,15 +56,28 @@ enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b, struct Info { - //! Supported only in 3D. When batch_mode is true, FFT is performed on + //! Domain decomposition strategy. + DomainStrategy domain_strategy = DomainStrategy::automatic; + + //! For automatic strategy, this is the size per process below which we + //! switch from slab to pencil. + int pencil_threshold = 8; + + //! Supported only in 3D. When twod_mode is true, FFT is performed on //! the first two dimensions only and the third dimension size is the //! batch size. - bool batch_mode = false; + bool twod_mode = false; + + //! Batched FFT size. Only supported in R2C, not R2X. + int batch_size = 1; //! 
Max number of processes to use int nprocs = std::numeric_limits::max(); - Info& setBatchMode (bool x) { batch_mode = x; return *this; } + Info& setDomainStrategy (DomainStrategy s) { domain_strategy = s; return *this; } + Info& setPencilThreshold (int t) { pencil_threshold = t; return *this; } + Info& setTwoDMode (bool x) { twod_mode = x; return *this; } + Info& setBatchSize (int bsize) { batch_size = bsize; return *this; } Info& setNumProcs (int n) { nprocs = n; return *this; } }; @@ -170,7 +183,7 @@ struct Plan } template - void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false) + void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false, int ncomp = 1) { static_assert(D == Direction::forward || D == Direction::backward); @@ -198,6 +211,7 @@ struct Plan howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2)) : AMREX_D_TERM(1, *1 , *box.length(2)); #endif + howmany *= ncomp; amrex::ignore_unused(nc); @@ -293,10 +307,10 @@ struct Plan } template - void init_r2c (IntVectND const& fft_size, void*, void*, bool cache); + void init_r2c (IntVectND const& fft_size, void*, void*, bool cache, int ncomp = 1); template - void init_c2c (Box const& box, VendorComplex* p) + void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1) { static_assert(D == Direction::forward || D == Direction::backward); @@ -307,6 +321,7 @@ struct Plan n = box.length(0); howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2)); + howmany *= ncomp; #if defined(AMREX_USE_CUDA) AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan)); @@ -1131,7 +1146,7 @@ struct Plan } }; -using Key = std::tuple,Direction,Kind>; +using Key = std::tuple,int,Direction,Kind>; using PlanD = typename Plan::VendorPlan; using PlanF = typename Plan::VendorPlan; @@ -1143,7 +1158,7 @@ void add_vendor_plan_f (Key const& key, PlanF plan); template template -void Plan::init_r2c (IntVectND const& fft_size, void* pbf, void* pbb, bool cache) +void Plan::init_r2c 
(IntVectND const& fft_size, void* pbf, void* pbb, bool cache, int ncomp) { static_assert(D == Direction::forward || D == Direction::backward); @@ -1154,10 +1169,10 @@ void Plan::init_r2c (IntVectND const& fft_size, void* pbf, void* pbb, bool n = 1; for (auto s : fft_size) { n *= s; } - howmany = 1; + howmany = ncomp; #if defined(AMREX_USE_GPU) - Key key = {fft_size.template expand<3>(), D, kind}; + Key key = {fft_size.template expand<3>(), ncomp, D, kind}; if (cache) { VendorPlan* cached_plan = nullptr; if constexpr (std::is_same_v) { @@ -1174,27 +1189,34 @@ void Plan::init_r2c (IntVectND const& fft_size, void* pbf, void* pbb, bool amrex::ignore_unused(cache); #endif + int len[M]; + for (int i = 0; i < M; ++i) { + len[i] = fft_size[M-1-i]; + } + + int nc = fft_size[0]/2+1; + for (int i = 1; i < M; ++i) { + nc *= fft_size[i]; + } + #if defined(AMREX_USE_CUDA) AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan)); AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0)); cufftType type; + int n_in, n_out; if constexpr (D == Direction::forward) { type = std::is_same_v ? CUFFT_R2C : CUFFT_D2Z; + n_in = n; + n_out = nc; } else { type = std::is_same_v ? 
CUFFT_C2R : CUFFT_Z2D; + n_in = nc; + n_out = n; } std::size_t work_size; - if constexpr (M == 1) { - AMREX_CUFFT_SAFE_CALL - (cufftMakePlan1d(plan, fft_size[0], type, howmany, &work_size)); - } else if constexpr (M == 2) { - AMREX_CUFFT_SAFE_CALL - (cufftMakePlan2d(plan, fft_size[1], fft_size[0], type, &work_size)); - } else if constexpr (M == 3) { - AMREX_CUFFT_SAFE_CALL - (cufftMakePlan3d(plan, fft_size[2], fft_size[1], fft_size[0], type, &work_size)); - } + AMREX_CUFFT_SAFE_CALL + (cufftMakePlanMany(plan, M, len, nullptr, 1, n_in, nullptr, 1, n_out, type, howmany, &work_size)); #elif defined(AMREX_USE_HIP) @@ -1219,11 +1241,11 @@ void Plan::init_r2c (IntVectND const& fft_size, void* pbf, void* pbb, bool if (M == 1) { pp = new mkl_desc_r(fft_size[0]); } else { - std::vector len(M); + std::vector len64(M); for (int idim = 0; idim < M; ++idim) { - len[idim] = fft_size[M-1-idim]; + len64[idim] = len[idim]; } - pp = new mkl_desc_r(len); + pp = new mkl_desc_r(len64); } #ifndef AMREX_USE_MKL_DFTI_2024 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, @@ -1231,7 +1253,9 @@ void Plan::init_r2c (IntVectND const& fft_size, void* pbf, void* pbb, bool #else pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE); #endif - + pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany); + pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n); + pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc); std::vector strides(M+1); strides[0] = 0; strides[M] = 1; @@ -1258,29 +1282,24 @@ void Plan::init_r2c (IntVectND const& fft_size, void* pbf, void* pbb, bool return; } - int size_for_row_major[M]; - for (int idim = 0; idim < M; ++idim) { - size_for_row_major[idim] = fft_size[M-1-idim]; - } - if constexpr (std::is_same_v) { if constexpr (D == Direction::forward) { - plan = fftwf_plan_dft_r2c - (M, size_for_row_major, (float*)pf, (fftwf_complex*)pb, + plan = fftwf_plan_many_dft_r2c + (M, len, howmany, (float*)pf, nullptr, 1, 
n, (fftwf_complex*)pb, nullptr, 1, nc, FFTW_ESTIMATE); } else { - plan = fftwf_plan_dft_c2r - (M, size_for_row_major, (fftwf_complex*)pb, (float*)pf, + plan = fftwf_plan_many_dft_c2r + (M, len, howmany, (fftwf_complex*)pb, nullptr, 1, nc, (float*)pf, nullptr, 1, n, FFTW_ESTIMATE); } } else { if constexpr (D == Direction::forward) { - plan = fftw_plan_dft_r2c - (M, size_for_row_major, (double*)pf, (fftw_complex*)pb, + plan = fftw_plan_many_dft_r2c + (M, len, howmany, (double*)pf, nullptr, 1, n, (fftw_complex*)pb, nullptr, 1, nc, FFTW_ESTIMATE); } else { - plan = fftw_plan_dft_c2r - (M, size_for_row_major, (fftw_complex*)pb, (double*)pf, + plan = fftw_plan_many_dft_c2r + (M, len, howmany, (fftw_complex*)pb, nullptr, 1, nc, (double*)pf, nullptr, 1, n, FFTW_ESTIMATE); } } @@ -1508,10 +1527,10 @@ namespace detail b = make_box(b); } auto const& ng = make_iv(mf.nGrowVect()); - FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), 1, ng, MFInfo{}.SetAlloc(false)); + FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), mf.nComp(), ng, MFInfo{}.SetAlloc(false)); using FAB = typename FA::fab_type; for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) { - submf.setFab(mfi, FAB(mfi.fabbox(), 1, mf[mfi].dataPtr())); + submf.setFab(mfi, FAB(mfi.fabbox(), mf.nComp(), mf[mfi].dataPtr())); } return submf; } diff --git a/Src/FFT/AMReX_FFT_OpenBCSolver.H b/Src/FFT/AMReX_FFT_OpenBCSolver.H index 58aa771cc6a..cd93ea2a67e 100644 --- a/Src/FFT/AMReX_FFT_OpenBCSolver.H +++ b/Src/FFT/AMReX_FFT_OpenBCSolver.H @@ -37,7 +37,7 @@ Box OpenBCSolver::make_grown_domain (Box const& domain, Info const& info) { IntVect len = domain.length(); #if (AMREX_SPACEDIM == 3) - if (info.batch_mode) { len[2] = 0; } + if (info.twod_mode) { len[2] = 0; } #else amrex::ignore_unused(info); #endif @@ -48,10 +48,11 @@ template OpenBCSolver::OpenBCSolver (Box const& domain, Info const& info) : m_domain(domain), m_info(info), - m_r2c(OpenBCSolver::make_grown_domain(domain,info), info) + 
m_r2c(OpenBCSolver::make_grown_domain(domain,info), + m_info.setDomainStrategy(FFT::DomainStrategy::slab)) { #if (AMREX_SPACEDIM == 3) - if (m_info.batch_mode) { + if (m_info.twod_mode) { auto gdom = make_grown_domain(domain,m_info); gdom.enclosedCells(2); gdom.setSmall(2, 0); @@ -59,7 +60,7 @@ OpenBCSolver::OpenBCSolver (Box const& domain, Info const& info) m_info.nprocs, m_domain.length(2)}); gdom.setBig(2, nprocs-1); - m_r2c_green = std::make_unique>(gdom,info); + m_r2c_green = std::make_unique>(gdom,m_info); auto [sd, ord] = m_r2c_green->getSpectralData(); m_G_fft = cMF(*sd, amrex::make_alias, 0, 1); } else @@ -78,7 +79,7 @@ void OpenBCSolver::setGreensFunction (F const& greens_function) { BL_PROFILE("OpenBCSolver::setGreensFunction"); - auto* infab = m_info.batch_mode ? detail::get_fab(m_r2c_green->m_rx) + auto* infab = m_info.twod_mode ? detail::get_fab(m_r2c_green->m_rx) : detail::get_fab(m_r2c.m_rx); auto const& lo = m_domain.smallEnd(); auto const& lo3 = lo.dim3(); @@ -87,7 +88,7 @@ void OpenBCSolver::setGreensFunction (F const& greens_function) auto const& a = infab->array(); auto box = infab->box(); GpuArray nimages{1,1,1}; - int ndims = m_info.batch_mode ? AMREX_SPACEDIM-1 : AMREX_SPACEDIM; + int ndims = m_info.twod_mode ? 
AMREX_SPACEDIM-1 : AMREX_SPACEDIM; for (int idim = 0; idim < ndims; ++idim) { if (box.smallEnd(idim) == lo[idim] && box.length(idim) == 2*len[idim]) { box.growHi(idim, -len[idim]+1); // +1 to include the middle plane @@ -129,13 +130,13 @@ void OpenBCSolver::setGreensFunction (F const& greens_function) }); } - if (m_info.batch_mode) { + if (m_info.twod_mode) { m_r2c_green->forward(m_r2c_green->m_rx); } else { m_r2c.forward(m_r2c.m_rx); } - if (!m_info.batch_mode) { + if (!m_info.twod_mode) { auto [sd, ord] = m_r2c.getSpectralData(); amrex::ignore_unused(ord); auto const* srcfab = detail::get_fab(*sd); @@ -166,7 +167,7 @@ void OpenBCSolver::solve (MF& phi, MF const& rho) inmf.setVal(T(0)); inmf.ParallelCopy(rho, 0, 0, 1); - m_r2c.m_openbc_half = !m_info.batch_mode; + m_r2c.m_openbc_half = !m_info.twod_mode; m_r2c.forward(inmf); m_r2c.m_openbc_half = false; @@ -183,7 +184,7 @@ void OpenBCSolver::solve (MF& phi, MF const& rho) Box const& rhobox = rhofab->box(); #if (AMREX_SPACEDIM == 3) Long leng = gfab->box().numPts(); - if (m_info.batch_mode) { + if (m_info.twod_mode) { AMREX_ASSERT(gfab->box().length(2) == 1 && leng == (rhobox.length(0) * rhobox.length(1))); } else { @@ -204,7 +205,7 @@ void OpenBCSolver::solve (MF& phi, MF const& rho) } } - m_r2c.m_openbc_half = !m_info.batch_mode; + m_r2c.m_openbc_half = !m_info.twod_mode; m_r2c.backward_doit(phi, phi.nGrowVect()); m_r2c.m_openbc_half = false; } diff --git a/Src/FFT/AMReX_FFT_Poisson.H b/Src/FFT/AMReX_FFT_Poisson.H index 36c4cc62b44..afeacf2b379 100644 --- a/Src/FFT/AMReX_FFT_Poisson.H +++ b/Src/FFT/AMReX_FFT_Poisson.H @@ -127,7 +127,7 @@ public: } } Info info{}; - info.setBatchMode(true); + info.setTwoDMode(true); if (periodic_xy) { m_r2c = std::make_unique>(m_geom.Domain(), info); @@ -145,7 +145,7 @@ public: std::make_pair(Boundary::periodic,Boundary::periodic), std::make_pair(Boundary::even,Boundary::even))}, m_r2c(std::make_unique> - (geom.Domain(), Info().setBatchMode(true))) + (geom.Domain(), 
Info().setTwoDMode(true))) { #if (AMREX_SPACEDIM == 3) AMREX_ALWAYS_ASSERT(geom.isPeriodic(0) && geom.isPeriodic(1)); diff --git a/Src/FFT/AMReX_FFT_R2C.H b/Src/FFT/AMReX_FFT_R2C.H index 141e8254116..8d1837e971e 100644 --- a/Src/FFT/AMReX_FFT_R2C.H +++ b/Src/FFT/AMReX_FFT_R2C.H @@ -29,9 +29,7 @@ template class PoissonHybrid; * For more details, we refer the users to * https://amrex-codes.github.io/amrex/docs_html/FFT_Chapter.html. */ -template - // Don't change the default. Otherwise OpenBCSolver might break. +template class R2C { public: @@ -79,12 +77,13 @@ public: */ template = 0> - void forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forward) + void forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forward, + int incomp = 0, int outcomp = 0) { BL_PROFILE("FFT::R2C::forwardbackward"); - this->forward(inmf); + this->forward(inmf, incomp); this->post_forward_doit_0(post_forward); - this->backward(outmf); + this->backward(outmf, outcomp); } /** @@ -98,7 +97,7 @@ public: */ template = 0> - void forward (MF const& inmf); + void forward (MF const& inmf, int incomp = 0); /** * \brief Forward transform @@ -111,7 +110,7 @@ public: */ template = 0> - void forward (MF const& inmf, cMF& outmf); + void forward (MF const& inmf, cMF& outmf, int incomp = 0, int outcomp = 0); /** * \brief Backward transform @@ -122,7 +121,7 @@ public: * \param outmf output data in MultiFab or FabArray> */ template = 0> - void backward (MF& outmf); + void backward (MF& outmf, int outcomp = 0); /** * \brief Backward transform @@ -135,7 +134,7 @@ public: */ template = 0> - void backward (cMF const& inmf, MF& outmf); + void backward (cMF const& inmf, MF& outmf, int incomp = 0, int outcomp = 0); //! Scaling factor. If the data goes through forward and then backward, //! 
the result multiplied by the scaling factor is equal to the original @@ -176,13 +175,15 @@ private: void prepare_openbc (); void backward_doit (MF& outmf, IntVect const& ngout = IntVect(0), - Periodicity const& period = Periodicity::NonPeriodic()); + Periodicity const& period = Periodicity::NonPeriodic(), + int outcomp = 0); void backward_doit (cMF const& inmf, MF& outmf, IntVect const& ngout = IntVect(0), - Periodicity const& period = Periodicity::NonPeriodic()); + Periodicity const& period = Periodicity::NonPeriodic(), + int incomp = 0, int outcomp = 0); - static std::pair,Plan> make_c2c_plans (cMF& inout); + std::pair,Plan> make_c2c_plans (cMF& inout) const; Plan m_fft_fwd_x{}; Plan m_fft_bwd_x{}; @@ -224,7 +225,7 @@ private: Box m_spectral_domain_y; Box m_spectral_domain_z; - std::unique_ptr> m_r2c_sub; + std::unique_ptr> m_r2c_sub; detail::SubHelper m_sub_helper; Info m_info; @@ -234,8 +235,8 @@ private: bool m_openbc_half = false; }; -template -R2C::R2C (Box const& domain, Info const& info) +template +R2C::R2C (Box const& domain, Info const& info) : m_real_domain(domain), m_spectral_domain_x(IntVect(0), IntVect(AMREX_D_DECL(domain.length(0)/2, domain.length(1)-1, @@ -262,9 +263,9 @@ R2C::R2C (Box const& domain, Info const& info) AMREX_ALWAYS_ASSERT(m_real_domain.numPts() > 1); #if (AMREX_SPACEDIM == 2) - AMREX_ALWAYS_ASSERT(!m_info.batch_mode); + AMREX_ALWAYS_ASSERT(!m_info.twod_mode); #else - if (m_info.batch_mode) { + if (m_info.twod_mode) { AMREX_ALWAYS_ASSERT((int(domain.length(0) > 1) + int(domain.length(1) > 1) + int(domain.length(2) > 1)) >= 2); @@ -274,7 +275,7 @@ R2C::R2C (Box const& domain, Info const& info) { Box subbox = m_sub_helper.make_box(m_real_domain); if (subbox.size() != m_real_domain.size()) { - m_r2c_sub = std::make_unique>(subbox, info); + m_r2c_sub = std::make_unique>(subbox, m_info); return; } } @@ -283,8 +284,16 @@ R2C::R2C (Box const& domain, Info const& info) int nprocs = std::min(ParallelContext::NProcsSub(), m_info.nprocs); #if 
(AMREX_SPACEDIM == 3) - if (S == DomainStrategy::slab && (m_real_domain.length(1) > 1)) { - if (m_info.batch_mode && m_real_domain.length(2) == 1) { + if (m_info.domain_strategy == DomainStrategy::automatic) { + int shortside = m_real_domain.shortside(); + if (shortside <= m_info.pencil_threshold*nprocs) { + m_info.domain_strategy = DomainStrategy::pencil; + } else { + m_info.domain_strategy = DomainStrategy::slab; + } + } + if (m_info.domain_strategy == DomainStrategy::slab && (m_real_domain.length(1) > 1)) { + if (m_info.twod_mode && m_real_domain.length(2) == 1) { m_slab_decomp = false; } else { m_slab_decomp = true; @@ -292,10 +301,12 @@ R2C::R2C (Box const& domain, Info const& info) } #endif + auto const ncomp = m_info.batch_size; + auto bax = amrex::decompose(m_real_domain, nprocs, {AMREX_D_DECL(false,!m_slab_decomp,true)}, true); DistributionMapping dmx = detail::make_iota_distromap(bax.size()); - m_rx.define(bax, dmx, 1, 0, MFInfo().SetAlloc(false)); + m_rx.define(bax, dmx, ncomp, 0, MFInfo().SetAlloc(false)); { BoxList bl = bax.boxList(); @@ -304,10 +315,10 @@ R2C::R2C (Box const& domain, Info const& info) b.setBig(0, m_spectral_domain_x.bigEnd(0)); } BoxArray cbax(std::move(bl)); - m_cx.define(cbax, dmx, 1, 0, MFInfo().SetAlloc(false)); + m_cx.define(cbax, dmx, ncomp, 0, MFInfo().SetAlloc(false)); } - m_do_alld_fft = (ParallelDescriptor::NProcs() == 1) && (! m_info.batch_mode); + m_do_alld_fft = (ParallelDescriptor::NProcs() == 1) && (! 
m_info.twod_mode); if (!m_do_alld_fft) // do a series of 1d or 2d ffts { @@ -319,7 +330,7 @@ R2C::R2C (Box const& domain, Info const& info) #if (AMREX_SPACEDIM == 2) bool batch_on_y = false; #else - bool batch_on_y = m_info.batch_mode && (m_real_domain.length(2) == 1); + bool batch_on_y = m_info.twod_mode && (m_real_domain.length(2) == 1); #endif DistributionMapping cdmy; if ((m_real_domain.length(1) > 1) && !m_slab_decomp && !batch_on_y) @@ -331,13 +342,13 @@ R2C::R2C (Box const& domain, Info const& info) } else { cdmy = detail::make_iota_distromap(cbay.size()); } - m_cy.define(cbay, cdmy, 1, 0, MFInfo().SetAlloc(false)); + m_cy.define(cbay, cdmy, ncomp, 0, MFInfo().SetAlloc(false)); } #endif #if (AMREX_SPACEDIM == 3) if (m_real_domain.length(1) > 1 && - (! m_info.batch_mode && m_real_domain.length(2) > 1)) + (! m_info.twod_mode && m_real_domain.length(2) > 1)) { auto cbaz = amrex::decompose(m_spectral_domain_z, nprocs, {false,true,true}, true); @@ -349,7 +360,7 @@ R2C::R2C (Box const& domain, Info const& info) } else { cdmz = detail::make_iota_distromap(cbaz.size()); } - m_cz.define(cbaz, cdmz, 1, 0, MFInfo().SetAlloc(false)); + m_cz.define(cbaz, cdmz, ncomp, 0, MFInfo().SetAlloc(false)); } #endif @@ -402,14 +413,14 @@ R2C::R2C (Box const& domain, Info const& info) auto* pr = m_rx[myproc].dataPtr(); auto* pc = (typename Plan::VendorComplex *)m_cx[myproc].dataPtr(); #ifdef AMREX_USE_SYCL - m_fft_fwd_x.template init_r2c(box, pr, pc, m_slab_decomp); + m_fft_fwd_x.template init_r2c(box, pr, pc, m_slab_decomp, ncomp); m_fft_bwd_x = m_fft_fwd_x; #else if constexpr (D == Direction::both || D == Direction::forward) { - m_fft_fwd_x.template init_r2c(box, pr, pc, m_slab_decomp); + m_fft_fwd_x.template init_r2c(box, pr, pc, m_slab_decomp, ncomp); } if constexpr (D == Direction::both || D == Direction::backward) { - m_fft_bwd_x.template init_r2c(box, pr, pc, m_slab_decomp); + m_fft_bwd_x.template init_r2c(box, pr, pc, m_slab_decomp, ncomp); } #endif } @@ -434,21 +445,21 @@ 
R2C::R2C (Box const& domain, Info const& info) auto* pr = (void*)m_rx[0].dataPtr(); auto* pc = (void*)m_cx[0].dataPtr(); #ifdef AMREX_USE_SYCL - m_fft_fwd_x.template init_r2c(len, pr, pc, false); + m_fft_fwd_x.template init_r2c(len, pr, pc, false, ncomp); m_fft_bwd_x = m_fft_fwd_x; #else if constexpr (D == Direction::both || D == Direction::forward) { - m_fft_fwd_x.template init_r2c(len, pr, pc, false); + m_fft_fwd_x.template init_r2c(len, pr, pc, false, ncomp); } if constexpr (D == Direction::both || D == Direction::backward) { - m_fft_bwd_x.template init_r2c(len, pr, pc, false); + m_fft_bwd_x.template init_r2c(len, pr, pc, false, ncomp); } #endif } } -template -R2C::~R2C () +template +R2C::~R2C () { if (m_fft_bwd_x.plan != m_fft_fwd_x.plan) { m_fft_bwd_x.destroy(); @@ -468,14 +479,16 @@ R2C::~R2C () m_fft_fwd_x_half.destroy(); } -template -void R2C::prepare_openbc () +template +void R2C::prepare_openbc () { if (m_r2c_sub) { amrex::Abort("R2C: OpenBC not supported with reduced dimensions"); } #if (AMREX_SPACEDIM == 3) if (m_do_alld_fft) { return; } + auto const ncomp = m_info.batch_size; + if (m_slab_decomp && ! 
m_fft_fwd_x_half.defined) { auto* fab = detail::get_fab(m_rx); if (fab) { @@ -488,16 +501,16 @@ void R2C::prepare_openbc () detail::get_fab(m_cx)->dataPtr(); #ifdef AMREX_USE_SYCL m_fft_fwd_x_half.template init_r2c - (box, pr, pc, m_slab_decomp); + (box, pr, pc, m_slab_decomp, ncomp); m_fft_bwd_x_half = m_fft_fwd_x_half; #else if constexpr (D == Direction::both || D == Direction::forward) { m_fft_fwd_x_half.template init_r2c - (box, pr, pc, m_slab_decomp); + (box, pr, pc, m_slab_decomp, ncomp); } if constexpr (D == Direction::both || D == Direction::backward) { m_fft_bwd_x_half.template init_r2c - (box, pr, pc, m_slab_decomp); + (box, pr, pc, m_slab_decomp, ncomp); } #endif } @@ -522,26 +535,28 @@ void R2C::prepare_openbc () #endif } -template +template template > -void R2C::forward (MF const& inmf) +void R2C::forward (MF const& inmf, int incomp) { BL_PROFILE("FFT::R2C::forward(in)"); + auto const ncomp = m_info.batch_size; + if (m_r2c_sub) { if (m_sub_helper.ghost_safe(inmf.nGrowVect())) { - m_r2c_sub->forward(m_sub_helper.make_alias_mf(inmf)); + m_r2c_sub->forward(m_sub_helper.make_alias_mf(inmf), incomp); } else { - MF tmp(inmf.boxArray(), inmf.DistributionMap(), 1, 0); - tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); - m_r2c_sub->forward(m_sub_helper.make_alias_mf(tmp)); + MF tmp(inmf.boxArray(), inmf.DistributionMap(), ncomp, 0); + tmp.LocalCopy(inmf, incomp, 0, ncomp, IntVect(0)); + m_r2c_sub->forward(m_sub_helper.make_alias_mf(tmp),0); } return; } if (&m_rx != &inmf) { - m_rx.ParallelCopy(inmf, 0, 0, 1); + m_rx.ParallelCopy(inmf, incomp, 0, ncomp); } if (m_do_alld_fft) { @@ -553,95 +568,99 @@ void R2C::forward (MF const& inmf) fft_x.template compute_r2c(); if ( m_cmd_x2y) { - ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y); + ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, ncomp, m_dtos_x2y); } m_fft_fwd_y.template compute_c2c(); if ( m_cmd_y2z) { - ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, 1, m_dtos_y2z); + ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, ncomp, 
m_dtos_y2z); } #if (AMREX_SPACEDIM == 3) else if ( m_cmd_x2z) { if (m_openbc_half) { + NonLocalBC::PackComponents components{}; + components.n_components = ncomp; NonLocalBC::ApplyDtosAndProjectionOnReciever packing - {NonLocalBC::PackComponents{}, m_dtos_x2z}; + {components, m_dtos_x2z}; auto handler = ParallelCopy_nowait(m_cz, m_cx, *m_cmd_x2z_half, packing); Box upper_half = m_spectral_domain_z; // Note that z-direction's index is 0 because we z is the // unit-stride direction here. upper_half.growLo (0,-m_spectral_domain_z.length(0)/2); - m_cz.setVal(0, upper_half, 0, 1); + m_cz.setVal(0, upper_half, 0, ncomp); ParallelCopy_finish(m_cz, std::move(handler), *m_cmd_x2z_half, packing); } else { - ParallelCopy(m_cz, m_cx, *m_cmd_x2z, 0, 0, 1, m_dtos_x2z); + ParallelCopy(m_cz, m_cx, *m_cmd_x2z, 0, 0, ncomp, m_dtos_x2z); } } #endif m_fft_fwd_z.template compute_c2c(); } -template +template template > -void R2C::backward (MF& outmf) +void R2C::backward (MF& outmf, int outcomp) { - backward_doit(outmf); + backward_doit(outmf, IntVect(0), Periodicity::NonPeriodic(), outcomp); } -template -void R2C::backward_doit (MF& outmf, IntVect const& ngout, - Periodicity const& period) +template +void R2C::backward_doit (MF& outmf, IntVect const& ngout, + Periodicity const& period, int outcomp) { BL_PROFILE("FFT::R2C::backward(out)"); + auto const ncomp = m_info.batch_size; + if (m_r2c_sub) { if (m_sub_helper.ghost_safe(outmf.nGrowVect())) { MF submf = m_sub_helper.make_alias_mf(outmf); IntVect const& subngout = m_sub_helper.make_iv(ngout); Periodicity const& subperiod = m_sub_helper.make_periodicity(period); - m_r2c_sub->backward_doit(submf, subngout, subperiod); + m_r2c_sub->backward_doit(submf, subngout, subperiod, outcomp); } else { - MF tmp(outmf.boxArray(), outmf.DistributionMap(), 1, + MF tmp(outmf.boxArray(), outmf.DistributionMap(), ncomp, m_sub_helper.make_safe_ghost(outmf.nGrowVect())); - this->backward_doit(tmp, ngout, period); - outmf.LocalCopy(tmp, 0, 0, 1, 
tmp.nGrowVect()); + this->backward_doit(tmp, ngout, period, 0); + outmf.LocalCopy(tmp, 0, outcomp, ncomp, tmp.nGrowVect()); } return; } if (m_do_alld_fft) { m_fft_bwd_x.template compute_r2c(); - outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), + outmf.ParallelCopy(m_rx, 0, outcomp, ncomp, IntVect(0), amrex::elemwiseMin(ngout,outmf.nGrowVect()), period); return; } m_fft_bwd_z.template compute_c2c(); if ( m_cmd_z2y) { - ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y); + ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, ncomp, m_dtos_z2y); } #if (AMREX_SPACEDIM == 3) else if ( m_cmd_z2x) { auto const& cmd = m_openbc_half ? m_cmd_z2x_half : m_cmd_z2x; - ParallelCopy(m_cx, m_cz, *cmd, 0, 0, 1, m_dtos_z2x); + ParallelCopy(m_cx, m_cz, *cmd, 0, 0, ncomp, m_dtos_z2x); } #endif m_fft_bwd_y.template compute_c2c(); if ( m_cmd_y2x) { - ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x); + ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, ncomp, m_dtos_y2x); } auto& fft_x = m_openbc_half ? m_fft_bwd_x_half : m_fft_bwd_x; fft_x.template compute_r2c(); - outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), + outmf.ParallelCopy(m_rx, 0, outcomp, ncomp, IntVect(0), amrex::elemwiseMin(ngout,outmf.nGrowVect()), period); } -template +template std::pair, Plan> -R2C::make_c2c_plans (cMF& inout) +R2C::make_c2c_plans (cMF& inout) const { Plan fwd; Plan bwd; @@ -652,26 +671,28 @@ R2C::make_c2c_plans (cMF& inout) Box const& box = fab->box(); auto* pio = (typename Plan::VendorComplex *)fab->dataPtr(); + auto const ncomp = m_info.batch_size; + #ifdef AMREX_USE_SYCL - fwd.template init_c2c(box, pio); + fwd.template init_c2c(box, pio, ncomp); bwd = fwd; #else if constexpr (D == Direction::both || D == Direction::forward) { - fwd.template init_c2c(box, pio); + fwd.template init_c2c(box, pio, ncomp); } if constexpr (D == Direction::both || D == Direction::backward) { - bwd.template init_c2c(box, pio); + bwd.template init_c2c(box, pio, ncomp); } #endif return {fwd, bwd}; } -template +template template 
-void R2C::post_forward_doit_0 (F const& post_forward) +void R2C::post_forward_doit_0 (F const& post_forward) { - if (m_info.batch_mode) { + if (m_info.twod_mode || m_info.batch_size > 1) { amrex::Abort("xxxxx todo: post_forward"); #if (AMREX_SPACEDIM > 1) } else if (m_r2c_sub) { @@ -722,11 +743,11 @@ void R2C::post_forward_doit_0 (F const& post_forward) } } -template +template template -void R2C::post_forward_doit_1 (F const& post_forward) +void R2C::post_forward_doit_1 (F const& post_forward) { - if (m_info.batch_mode) { + if (m_info.twod_mode || m_info.batch_size > 1) { amrex::Abort("xxxxx todo: post_forward"); } else if (m_r2c_sub) { amrex::Abort("R2C::post_forward_doit_1: How did this happen?"); @@ -765,11 +786,11 @@ void R2C::post_forward_doit_1 (F const& post_forward) } } -template -T R2C::scalingFactor () const +template +T R2C::scalingFactor () const { #if (AMREX_SPACEDIM == 3) - if (m_info.batch_mode) { + if (m_info.twod_mode) { if (m_real_domain.length(2) > 1) { return T(1)/T(Long(m_real_domain.length(0)) * Long(m_real_domain.length(1))); @@ -783,11 +804,11 @@ T R2C::scalingFactor () const } } -template +template template > -std::pair::cMF *, IntVect> -R2C::getSpectralData () +std::pair::cMF *, IntVect> +R2C::getSpectralData () { #if (AMREX_SPACEDIM > 1) if (m_r2c_sub) { @@ -804,100 +825,116 @@ R2C::getSpectralData () } } -template +template template > -void R2C::forward (MF const& inmf, cMF& outmf) +void R2C::forward (MF const& inmf, cMF& outmf, int incomp, int outcomp) { BL_PROFILE("FFT::R2C::forward(inout)"); + auto const ncomp = m_info.batch_size; + if (m_r2c_sub) { bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect()); MF inmf_sub, inmf_tmp; + int incomp_sub; if (inmf_safe) { inmf_sub = m_sub_helper.make_alias_mf(inmf); + incomp_sub = incomp; } else { - inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), 1, 0); - inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); + inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), ncomp, 0); + 
inmf_tmp.LocalCopy(inmf, incomp, 0, ncomp, IntVect(0)); inmf_sub = m_sub_helper.make_alias_mf(inmf_tmp); + incomp_sub = 0; } bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect()); cMF outmf_sub, outmf_tmp; + int outcomp_sub; if (outmf_safe) { outmf_sub = m_sub_helper.make_alias_mf(outmf); + outcomp_sub = outcomp; } else { - outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, 0); + outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), ncomp, 0); outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp); + outcomp_sub = 0; } - m_r2c_sub->forward(inmf_sub, outmf_sub); + m_r2c_sub->forward(inmf_sub, outmf_sub, incomp_sub, outcomp_sub); if (!outmf_safe) { - outmf.LocalCopy(outmf_tmp, 0, 0, 1, IntVect(0)); + outmf.LocalCopy(outmf_tmp, 0, outcomp, ncomp, IntVect(0)); } } else { - forward(inmf); + forward(inmf, incomp); if (!m_cz.empty()) { // m_cz's order (z,x,y) -> (x,y,z) RotateBwd dtos{}; MultiBlockCommMetaData cmd (outmf, m_spectral_domain_x, m_cz, IntVect(0), dtos); - ParallelCopy(outmf, m_cz, cmd, 0, 0, 1, dtos); + ParallelCopy(outmf, m_cz, cmd, 0, outcomp, ncomp, dtos); } else if (!m_cy.empty()) { // m_cy's order (y,x,z) -> (x,y,z) MultiBlockCommMetaData cmd (outmf, m_spectral_domain_x, m_cy, IntVect(0), m_dtos_y2x); - ParallelCopy(outmf, m_cy, cmd, 0, 0, 1, m_dtos_y2x); + ParallelCopy(outmf, m_cy, cmd, 0, outcomp, ncomp, m_dtos_y2x); } else { - outmf.ParallelCopy(m_cx, 0, 0, 1); + outmf.ParallelCopy(m_cx, 0, outcomp, ncomp); } } } -template +template template > -void R2C::backward (cMF const& inmf, MF& outmf) +void R2C::backward (cMF const& inmf, MF& outmf, int incomp, int outcomp) { - backward_doit(inmf, outmf); + backward_doit(inmf, outmf, IntVect(0), Periodicity::NonPeriodic(), incomp, outcomp); } -template -void R2C::backward_doit (cMF const& inmf, MF& outmf, IntVect const& ngout, - Periodicity const& period) +template +void R2C::backward_doit (cMF const& inmf, MF& outmf, IntVect const& ngout, + Periodicity const& period, int incomp, int 
outcomp) { BL_PROFILE("FFT::R2C::backward(inout)"); + auto const ncomp = m_info.batch_size; + if (m_r2c_sub) { bool inmf_safe = m_sub_helper.ghost_safe(inmf.nGrowVect()); cMF inmf_sub, inmf_tmp; + int incomp_sub; if (inmf_safe) { inmf_sub = m_sub_helper.make_alias_mf(inmf); + incomp_sub = incomp; } else { - inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), 1, 0); - inmf_tmp.LocalCopy(inmf, 0, 0, 1, IntVect(0)); + inmf_tmp.define(inmf.boxArray(), inmf.DistributionMap(), ncomp, 0); + inmf_tmp.LocalCopy(inmf, incomp, 0, ncomp, IntVect(0)); inmf_sub = m_sub_helper.make_alias_mf(inmf_tmp); + incomp_sub = 0; } bool outmf_safe = m_sub_helper.ghost_safe(outmf.nGrowVect()); MF outmf_sub, outmf_tmp; + int outcomp_sub; if (outmf_safe) { outmf_sub = m_sub_helper.make_alias_mf(outmf); + outcomp_sub = outcomp; } else { IntVect const& ngtmp = m_sub_helper.make_safe_ghost(outmf.nGrowVect()); - outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), 1, ngtmp); + outmf_tmp.define(outmf.boxArray(), outmf.DistributionMap(), ncomp, ngtmp); outmf_sub = m_sub_helper.make_alias_mf(outmf_tmp); + outcomp_sub = 0; } IntVect const& subngout = m_sub_helper.make_iv(ngout); Periodicity const& subperiod = m_sub_helper.make_periodicity(period); - m_r2c_sub->backward_doit(inmf_sub, outmf_sub, subngout, subperiod); + m_r2c_sub->backward_doit(inmf_sub, outmf_sub, subngout, subperiod, incomp_sub, outcomp_sub); if (!outmf_safe) { - outmf.LocalCopy(outmf_tmp, 0, 0, 1, outmf_tmp.nGrowVect()); + outmf.LocalCopy(outmf_tmp, 0, outcomp, ncomp, outmf_tmp.nGrowVect()); } } else @@ -906,21 +943,21 @@ void R2C::backward_doit (cMF const& inmf, MF& outmf, IntVect const& ngout RotateFwd dtos{}; MultiBlockCommMetaData cmd (m_cz, m_spectral_domain_z, inmf, IntVect(0), dtos); - ParallelCopy(m_cz, inmf, cmd, 0, 0, 1, dtos); + ParallelCopy(m_cz, inmf, cmd, incomp, 0, ncomp, dtos); } else if (!m_cy.empty()) { // (x,y,z) -> m_cy's ordering (y,x,z) MultiBlockCommMetaData cmd (m_cy, m_spectral_domain_y, inmf, 
IntVect(0), m_dtos_x2y); - ParallelCopy(m_cy, inmf, cmd, 0, 0, 1, m_dtos_x2y); + ParallelCopy(m_cy, inmf, cmd, incomp, 0, ncomp, m_dtos_x2y); } else { - m_cx.ParallelCopy(inmf, 0, 0, 1); + m_cx.ParallelCopy(inmf, incomp, 0, ncomp); } - backward_doit(outmf, ngout, period); + backward_doit(outmf, ngout, period, outcomp); } } -template +template std::pair -R2C::getSpectralDataLayout () const +R2C::getSpectralDataLayout () const { #if (AMREX_SPACEDIM > 1) if (m_r2c_sub) { diff --git a/Src/FFT/AMReX_FFT_R2X.H b/Src/FFT/AMReX_FFT_R2X.H index 6d2b95bb833..050685baa48 100644 --- a/Src/FFT/AMReX_FFT_R2X.H +++ b/Src/FFT/AMReX_FFT_R2X.H @@ -133,11 +133,11 @@ R2X::R2X (Box const& domain, static_assert(std::is_same_v || std::is_same_v); - AMREX_ALWAYS_ASSERT(m_dom_0.numPts() > 1); + AMREX_ALWAYS_ASSERT((m_dom_0.numPts() > 1) && (m_info.batch_size == 1)); #if (AMREX_SPACEDIM == 2) - AMREX_ALWAYS_ASSERT(!m_info.batch_mode); + AMREX_ALWAYS_ASSERT(!m_info.twod_mode); #else - if (m_info.batch_mode) { + if (m_info.twod_mode) { AMREX_ALWAYS_ASSERT((int(domain.length(0) > 1) + int(domain.length(1) > 1) + int(domain.length(2) > 1)) >= 2); @@ -191,7 +191,7 @@ R2X::R2X (Box const& domain, #if (AMREX_SPACEDIM == 2) bool batch_on_y = false; #else - bool batch_on_y = m_info.batch_mode && (m_dom_0.length(2) == 1); + bool batch_on_y = m_info.twod_mode && (m_dom_0.length(2) == 1); #endif if ((domain.length(1) > 1) && !batch_on_y) { @@ -245,7 +245,7 @@ R2X::R2X (Box const& domain, #endif #if (AMREX_SPACEDIM == 3) - if (domain.length(2) > 1 && !m_info.batch_mode) { + if (domain.length(2) > 1 && !m_info.twod_mode) { if (! m_cy.empty()) { // copy(m_cy, m_cz) m_dom_cz = Box(IntVect(0), IntVect(AMREX_D_DECL(m_dom_cy.bigEnd(2), @@ -539,9 +539,9 @@ template T R2X::scalingFactor () const { Long r = 1; - int ndims = m_info.batch_mode ? AMREX_SPACEDIM-1 : AMREX_SPACEDIM; + int ndims = m_info.twod_mode ? 
AMREX_SPACEDIM-1 : AMREX_SPACEDIM; #if (AMREX_SPACEDIM == 3) - if (m_info.batch_mode && m_dom_0.length(2) == 1) { ndims = 1; }; + if (m_info.twod_mode && m_dom_0.length(2) == 1) { ndims = 1; }; #endif for (int idim = 0; idim < ndims; ++idim) { r *= m_dom_0.length(idim); @@ -770,7 +770,7 @@ void R2X::forward (MF const& inmf, MF& outmf) this->forward(inmf); #if (AMREX_SPACEDIM == 3) - if (m_info.batch_mode) { + if (m_info.twod_mode) { if (m_cy.empty() && !m_ry.empty()) { ParallelCopy(outmf, m_dom_rx, m_ry, 0, 0, 1, IntVect(0), Swap01{}); } else if (m_ry.empty() && m_cy.empty() && m_cx.empty()) { @@ -822,7 +822,7 @@ void R2X::forward (MF const& inmf, cMF& outmf) this->forward(inmf); #if (AMREX_SPACEDIM == 3) - if (m_info.batch_mode) { + if (m_info.twod_mode) { if (!m_cy.empty()) { auto lo = m_dom_cy.smallEnd(); auto hi = m_dom_cy.bigEnd(); @@ -936,7 +936,7 @@ void R2X::backward (MF const& inmf, MF& outmf, IntVect const& ngout, else { #if (AMREX_SPACEDIM == 3) - if (m_info.batch_mode) { + if (m_info.twod_mode) { if (m_cy.empty() && !m_ry.empty()) { ParallelCopy(m_ry, m_dom_ry, inmf, 0, 0, 1, IntVect(0), Swap01{}); } else if (m_ry.empty() && m_cy.empty() && m_cx.empty()) { @@ -995,7 +995,7 @@ void R2X::backward (cMF const& inmf, MF& outmf, IntVect const& ngout, else { #if (AMREX_SPACEDIM == 3) - if (m_info.batch_mode) { + if (m_info.twod_mode) { if (!m_cy.empty()) { ParallelCopy(m_cy, m_dom_cy, inmf, 0, 0, 1, IntVect(0), Swap01{}); } else if (m_ry.empty() && m_cy.empty() && !m_cx.empty()) { @@ -1021,7 +1021,7 @@ template template void R2X::post_forward_doit (FAB* fab, F const& f) { - if (m_info.batch_mode) { + if (m_info.twod_mode) { amrex::Abort("xxxxx post_forward_doit: todo"); } if (fab) { diff --git a/Tests/FFT/Batch/CMakeLists.txt b/Tests/FFT/Batch/CMakeLists.txt new file mode 100644 index 00000000000..21a9d3b2681 --- /dev/null +++ b/Tests/FFT/Batch/CMakeLists.txt @@ -0,0 +1,10 @@ +foreach(D IN LISTS AMReX_SPACEDIM) + set(_sources main.cpp) + + set(_input_files) + 
+ setup_test(${D} _sources _input_files) + + unset(_sources) + unset(_input_files) +endforeach() diff --git a/Tests/FFT/Batch/GNUmakefile b/Tests/FFT/Batch/GNUmakefile new file mode 100644 index 00000000000..93376f44852 --- /dev/null +++ b/Tests/FFT/Batch/GNUmakefile @@ -0,0 +1,26 @@ +AMREX_HOME := ../../.. + +DEBUG = FALSE + +DIM = 3 + +COMP = gcc + +USE_MPI = TRUE +USE_OMP = FALSE +USE_CUDA = FALSE +USE_HIP = FALSE +USE_SYCL = FALSE + +USE_FFT = TRUE + +BL_NO_FORT = TRUE + +TINY_PROFILE = FALSE + +include $(AMREX_HOME)/Tools/GNUMake/Make.defs + +include ./Make.package +include $(AMREX_HOME)/Src/Base/Make.package + +include $(AMREX_HOME)/Tools/GNUMake/Make.rules diff --git a/Tests/FFT/Batch/Make.package b/Tests/FFT/Batch/Make.package new file mode 100644 index 00000000000..6b4b865e8fc --- /dev/null +++ b/Tests/FFT/Batch/Make.package @@ -0,0 +1 @@ +CEXE_sources += main.cpp diff --git a/Tests/FFT/Batch/main.cpp b/Tests/FFT/Batch/main.cpp new file mode 100644 index 00000000000..69f5aa5711c --- /dev/null +++ b/Tests/FFT/Batch/main.cpp @@ -0,0 +1,167 @@ +#include // Put this at the top for testing + +#include +#include +#include +#include + +using namespace amrex; + +int main (int argc, char* argv[]) +{ + amrex::Initialize(argc, argv); + { + BL_PROFILE("main"); + + AMREX_D_TERM(int n_cell_x = 64;, + int n_cell_y = 16;, + int n_cell_z = 32); + + AMREX_D_TERM(int max_grid_size_x = 32;, + int max_grid_size_y = 16;, + int max_grid_size_z = 16); + + AMREX_D_TERM(Real prob_lo_x = 0.;, + Real prob_lo_y = 0.;, + Real prob_lo_z = 0.); + AMREX_D_TERM(Real prob_hi_x = 1.;, + Real prob_hi_y = 1.;, + Real prob_hi_z = 1.); + + int batch_size = 4; + + { + ParmParse pp; + AMREX_D_TERM(pp.query("n_cell_x", n_cell_x);, + pp.query("n_cell_y", n_cell_y);, + pp.query("n_cell_z", n_cell_z)); + AMREX_D_TERM(pp.query("max_grid_size_x", max_grid_size_x);, + pp.query("max_grid_size_y", max_grid_size_y);, + pp.query("max_grid_size_z", max_grid_size_z)); + pp.query("batch_size", batch_size); + } 
+ + Box domain(IntVect(0),IntVect(AMREX_D_DECL(n_cell_x-1,n_cell_y-1,n_cell_z-1))); + BoxArray ba(domain); + ba.maxSize(IntVect(AMREX_D_DECL(max_grid_size_x, + max_grid_size_y, + max_grid_size_z))); + DistributionMapping dm(ba); + + Geometry geom; + { + geom.define(domain, + RealBox(AMREX_D_DECL(prob_lo_x,prob_lo_y,prob_lo_z), + AMREX_D_DECL(prob_hi_x,prob_hi_y,prob_hi_z)), + CoordSys::cartesian, {AMREX_D_DECL(1,1,1)}); + } + auto const& dx = geom.CellSizeArray(); + + MultiFab mf(ba,dm,batch_size,0); + auto const& ma = mf.arrays(); + ParallelFor(mf, IntVect(0), batch_size, + [=] AMREX_GPU_DEVICE (int b, int i, int j, int k, int n) + { + AMREX_D_TERM(Real x = (i+0.5_rt) * dx[0] - 0.5_rt;, + Real y = (j+0.5_rt) * dx[1] - 0.5_rt;, + Real z = (k+0.5_rt) * dx[2] - 0.5_rt); + ma[b](i,j,k,n) = std::exp(-10._rt* + (AMREX_D_TERM(x*x*1.05_rt, + y*y*0.90_rt, + z*z))) + Real(n); + }); + + MultiFab mf2(ba,dm,batch_size,0); + + auto scaling = Real(1) / Real(geom.Domain().d_numPts()); + + cMultiFab cmf; + + // forward + { + FFT::Info info{}; + info.setDomainStrategy(FFT::DomainStrategy::pencil); + info.setBatchSize(batch_size); + FFT::R2C r2c(geom.Domain(), info); + auto const& [cba, cdm] = r2c.getSpectralDataLayout(); + cmf.define(cba, cdm, batch_size, 0); + r2c.forward(mf,cmf); + } + + // backward + { + FFT::Info info{}; + info.setDomainStrategy(FFT::DomainStrategy::slab); + info.setBatchSize(batch_size); + FFT::R2C r2c(geom.Domain(), info); + r2c.backward(cmf,mf2); + } + + { + auto const& ma2 = mf2.arrays(); + ParallelFor(mf2, IntVect(0), batch_size, + [=] AMREX_GPU_DEVICE (int b, int i, int j, int k, int n) + { + ma2[b](i,j,k,n) = ma[b](i,j,k,n) - ma2[b](i,j,k,n)*scaling; + }); + + auto error = mf2.norminf(0, batch_size, IntVect(0)); + amrex::Print() << " Expected to be close to zero: " << error << "\n"; +#ifdef AMREX_USE_FLOAT + auto eps = 1.e-6f; +#else + auto eps = 1.e-13; +#endif + AMREX_ALWAYS_ASSERT(error < eps); + } + + { + FFT::R2C r2c(geom.Domain()); + cMultiFab 
cmf2(cmf.boxArray(), cmf.DistributionMap(), 2, 0); + MultiFab errmf(cmf.boxArray(), cmf.DistributionMap(), cmf.nComp(), 0); + for (int icomp = 0; icomp < batch_size; ++icomp) { + r2c.forward(mf, cmf2, icomp, 1); + auto const& cma = cmf.const_arrays(); + auto const& cma2 = cmf2.const_arrays(); + auto const& ema = errmf.arrays(); + ParallelFor(errmf, [=] AMREX_GPU_DEVICE (int b, int i, int j, int k) + { + auto c = cma[b](i,j,k,icomp) - cma2[b](i,j,k,1); + ema[b](i,j,k,icomp) = amrex::norm(c); + }); + Gpu::streamSynchronize(); + } + + auto error = errmf.norminf(0, batch_size, IntVect(0)); + amrex::Print() << " Expected to be close to zero: " << error << "\n"; +#ifdef AMREX_USE_FLOAT + auto eps = 0.5e-6f; +#else + auto eps = 1.e-15; +#endif + AMREX_ALWAYS_ASSERT(error < eps); + } + + { + FFT::R2C r2c(geom.Domain()); + for (int icomp = 0; icomp < batch_size; ++icomp) { + r2c.backward(cmf, mf2, icomp, icomp); + } + + auto const& ma2 = mf2.arrays(); + ParallelFor(mf2, IntVect(0), batch_size, + [=] AMREX_GPU_DEVICE (int b, int i, int j, int k, int n) + { + ma2[b](i,j,k,n) = ma[b](i,j,k,n) - ma2[b](i,j,k,n)*scaling; + }); + + auto error = mf2.norminf(0, batch_size, IntVect(0)); + amrex::Print() << " Expected to be close to zero: " << error << "\n"; +#ifdef AMREX_USE_FLOAT + auto eps = 1.e-6f; +#else + auto eps = 1.e-13; +#endif + AMREX_ALWAYS_ASSERT(error < eps); + } + } + amrex::Finalize(); +} diff --git a/Tests/FFT/R2C/main.cpp b/Tests/FFT/R2C/main.cpp index ee70b43b7bb..1f3a0e68547 100644 --- a/Tests/FFT/R2C/main.cpp +++ b/Tests/FFT/R2C/main.cpp @@ -74,7 +74,8 @@ int main (int argc, char* argv[]) // forward { - FFT::R2C r2c(geom.Domain()); + FFT::R2C r2c + (geom.Domain(), FFT::Info{}.setDomainStrategy(FFT::DomainStrategy::pencil)); auto const& [cba, cdm] = r2c.getSpectralDataLayout(); cmf.define(cba, cdm, 1, 0); r2c.forward(mf,cmf); @@ -82,7 +83,8 @@ int main (int argc, char* argv[]) // backward { - FFT::R2C r2c(geom.Domain()); + FFT::R2C r2c + (geom.Domain(), 
FFT::Info{}.setDomainStrategy(FFT::DomainStrategy::pencil)); r2c.backward(cmf,mf2); } @@ -105,7 +107,8 @@ int main (int argc, char* argv[]) mf2.setVal(std::numeric_limits::max()); { // forward and backward - FFT::R2C r2c(geom.Domain()); + FFT::R2C r2c + (geom.Domain(), FFT::Info{}.setDomainStrategy(FFT::DomainStrategy::slab)); r2c.forwardThenBackward(mf, mf2, [=] AMREX_GPU_DEVICE (int, int, int, auto& sp) {