Skip to content

Commit cca9459

Browse files
committed
Split FFT operations from halo communication
1 parent 98496b4 commit cca9459

File tree

5 files changed

+45
-7
lines changed

5 files changed

+45
-7
lines changed

src/core/electrostatics/p3m.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ void CoulombP3MImpl<FloatType, Architecture>::init_cpu_kernels() {
262262

263263
assert(p3m.fft);
264264
p3m.local_mesh.calc_local_ca_mesh(p3m.params, local_geo, skin, elc_layer);
265+
p3m.fft->init_halo();
265266
p3m.fft->init_fft();
266267
p3m.calc_differential_operator();
267268

@@ -390,6 +391,7 @@ Utils::Vector9d CoulombP3MImpl<FloatType, Architecture>::long_range_pressure(
390391

391392
if (p3m.sum_q2 > 0.) {
392393
charge_assign(particles);
394+
p3m.fft->perform_scalar_halo_gather();
393395
p3m.fft->perform_scalar_fwd_fft();
394396

395397
auto constexpr mesh_start = Utils::Vector3i::broadcast(0);
@@ -455,6 +457,7 @@ double CoulombP3MImpl<FloatType, Architecture>::long_range_kernel(
455457
system.coulomb.impl->solver)) {
456458
charge_assign(particles);
457459
}
460+
p3m.fft->perform_scalar_halo_gather();
458461
p3m.fft->perform_scalar_fwd_fft();
459462
}
460463

@@ -513,6 +516,7 @@ double CoulombP3MImpl<FloatType, Architecture>::long_range_kernel(
513516
not p3m.params.tuning and check_complex_residuals;
514517
p3m.fft->check_complex_residuals = check_residuals;
515518
p3m.fft->perform_vector_back_fft();
519+
p3m.fft->perform_vector_halo_spread();
516520
p3m.fft->check_complex_residuals = false;
517521

518522
auto const force_prefac = prefactor / volume;

src/core/magnetostatics/dp3m.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ void DipolarP3MImpl<FloatType, Architecture>::init_cpu_kernels() {
136136

137137
assert(dp3m.fft);
138138
dp3m.local_mesh.calc_local_ca_mesh(dp3m.params, local_geo, verlet_skin, 0.);
139+
dp3m.fft->init_halo();
139140
dp3m.fft->init_fft();
140141
dp3m.calc_differential_operator();
141142

@@ -252,6 +253,7 @@ double DipolarP3MImpl<FloatType, Architecture>::long_range_kernel(
252253

253254
if (dp3m.sum_mu2 > 0.) {
254255
dipole_assign(particles);
256+
dp3m.fft->perform_vector_halo_gather();
255257
dp3m.fft->perform_vector_fwd_fft();
256258
}
257259

@@ -353,6 +355,7 @@ double DipolarP3MImpl<FloatType, Architecture>::long_range_kernel(
353355
++index;
354356
});
355357
dp3m.fft->perform_scalar_back_fft();
358+
dp3m.fft->perform_scalar_halo_spread();
356359
/* Assign force component from mesh to particle */
357360
auto const d_rs = (d + dp3m.mesh.ks_pnum) % 3;
358361
Utils::integral_parameter<int, AssignTorques, 1, 7>(
@@ -404,6 +407,7 @@ double DipolarP3MImpl<FloatType, Architecture>::long_range_kernel(
404407
++index;
405408
});
406409
dp3m.fft->perform_vector_back_fft();
410+
dp3m.fft->perform_vector_halo_spread();
407411
/* Assign force component from mesh to particle */
408412
auto const d_rs = (d + dp3m.mesh.ks_pnum) % 3;
409413
Utils::integral_parameter<int, AssignForces, 1, 7>(

src/core/p3m/FFTBackendLegacy.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,11 @@ void FFTBackendLegacy<FloatType>::update_mesh_data() {
6262
}
6363
}
6464

65-
template <typename FloatType> void FFTBackendLegacy<FloatType>::init_fft() {
65+
template <typename FloatType> void FFTBackendLegacy<FloatType>::init_halo() {
6666
mesh_comm.resize(::comm_cart, local_mesh);
67+
}
68+
69+
template <typename FloatType> void FFTBackendLegacy<FloatType>::init_fft() {
6770
auto ca_mesh_size = fft->initialize_fft(
6871
::comm_cart, local_mesh.dim, local_mesh.margin, params.mesh,
6972
params.mesh_off, mesh.ks_pnum, ::communicator.node_grid);
@@ -79,31 +82,41 @@ template <typename FloatType> void FFTBackendLegacy<FloatType>::init_fft() {
7982

8083
template <typename FloatType>
8184
void FFTBackendLegacy<FloatType>::perform_vector_back_fft() {
82-
/* Back FFT force component mesh */
8385
for (auto &rs_mesh_field : rs_mesh_fields) {
8486
fft->backward_fft(::comm_cart, rs_mesh_field.data(),
8587
check_complex_residuals);
8688
}
87-
/* redistribute force component mesh */
89+
}
90+
91+
template <typename FloatType>
92+
void FFTBackendLegacy<FloatType>::perform_vector_halo_spread() {
8893
std::array<FloatType *, 3u> meshes = {{rs_mesh_fields[0u].data(),
8994
rs_mesh_fields[1u].data(),
9095
rs_mesh_fields[2u].data()}};
9196
mesh_comm.spread_grid(::comm_cart, meshes, local_mesh.dim);
9297
}
9398

9499
template <typename FloatType>
95-
void FFTBackendLegacy<FloatType>::perform_scalar_fwd_fft() {
100+
void FFTBackendLegacy<FloatType>::perform_scalar_halo_gather() {
96101
mesh_comm.gather_grid(::comm_cart, rs_mesh.data(), local_mesh.dim);
102+
}
103+
104+
template <typename FloatType>
105+
void FFTBackendLegacy<FloatType>::perform_scalar_fwd_fft() {
97106
fft->forward_fft(::comm_cart, rs_mesh.data());
98107
update_mesh_data();
99108
}
100109

101110
template <typename FloatType>
102-
void FFTBackendLegacy<FloatType>::perform_vector_fwd_fft() {
111+
void FFTBackendLegacy<FloatType>::perform_vector_halo_gather() {
103112
std::array<FloatType *, 3u> meshes = {{rs_mesh_fields[0u].data(),
104113
rs_mesh_fields[1u].data(),
105114
rs_mesh_fields[2u].data()}};
106115
mesh_comm.gather_grid(::comm_cart, meshes, local_mesh.dim);
116+
}
117+
118+
template <typename FloatType>
119+
void FFTBackendLegacy<FloatType>::perform_vector_fwd_fft() {
107120
for (auto &rs_mesh_field : rs_mesh_fields) {
108121
fft->forward_fft(::comm_cart, rs_mesh_field.data());
109122
}
@@ -112,9 +125,11 @@ void FFTBackendLegacy<FloatType>::perform_vector_fwd_fft() {
112125

113126
template <typename FloatType>
114127
void FFTBackendLegacy<FloatType>::perform_scalar_back_fft() {
115-
/* Back FFT force component mesh */
116128
fft->backward_fft(::comm_cart, rs_mesh.data(), check_complex_residuals);
117-
/* redistribute force component mesh */
129+
}
130+
131+
template <typename FloatType>
132+
void FFTBackendLegacy<FloatType>::perform_scalar_halo_spread() {
118133
mesh_comm.spread_grid(::comm_cart, rs_mesh.data(), local_mesh.dim);
119134
}
120135

src/core/p3m/FFTBackendLegacy.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,15 @@ class FFTBackendLegacy : public FFTBackend<FloatType> {
6767
FFTBackendLegacy(p3m_data_struct_fft<FloatType> &obj, bool dipolar);
6868
~FFTBackendLegacy() override;
6969
void init_fft() override;
70+
void init_halo() override;
7071
void perform_scalar_fwd_fft() override;
7172
void perform_vector_fwd_fft() override;
7273
void perform_scalar_back_fft() override;
7374
void perform_vector_back_fft() override;
75+
void perform_scalar_halo_gather() override;
76+
void perform_vector_halo_gather() override;
77+
void perform_scalar_halo_spread() override;
78+
void perform_vector_halo_spread() override;
7479
void update_mesh_data();
7580

7681
/**

src/core/p3m/data_struct.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ template <typename FloatType> class FFTBackend {
104104
virtual ~FFTBackend() = default;
105105
/** @brief Initialize the FFT plans and buffers. */
106106
virtual void init_fft() = 0;
107+
/** @brief Initialize the halo buffers. */
108+
virtual void init_halo() = 0;
107109
/** @brief Carry out the forward FFT of the scalar mesh. */
108110
virtual void perform_scalar_fwd_fft() = 0;
109111
/** @brief Carry out the forward FFT of the vector meshes. */
@@ -112,6 +114,14 @@ template <typename FloatType> class FFTBackend {
112114
virtual void perform_scalar_back_fft() = 0;
113115
/** @brief Carry out the backward FFT of the vector meshes. */
114116
virtual void perform_vector_back_fft() = 0;
117+
/** @brief Update scalar mesh halo with data from neighbors (accumulation). */
118+
virtual void perform_scalar_halo_gather() = 0;
119+
/** @brief Update vector mesh halo with data from neighbors (accumulation). */
120+
virtual void perform_vector_halo_gather() = 0;
121+
/** @brief Update scalar mesh halo of all neighbors. */
122+
virtual void perform_scalar_halo_spread() = 0;
123+
/** @brief Update vector mesh halo of all neighbors. */
124+
virtual void perform_vector_halo_spread() = 0;
115125
/** @brief Get indices of the k-space data layout. */
116126
virtual std::tuple<int, int, int> get_permutations() const = 0;
117127
};

0 commit comments

Comments
 (0)