Skip to content

Commit

Permalink
Removed duplicate code from BufferDevice (#41)
Browse files Browse the repository at this point in the history
* Removed duplicate code from BufferDevice. Fixed bug in test. Added additional tests.

* removed unused variable
  • Loading branch information
will-saunders-ukaea authored Sep 8, 2023
1 parent 20b9a66 commit 278f960
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 21 deletions.
13 changes: 0 additions & 13 deletions include/compute_target.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,19 +330,6 @@ template <typename T> class BufferDevice {
if (this->size > 0) {
this->sycl_target->queue.memcpy(this->ptr, vec.data(), this->size_bytes())
.wait();

auto k_ptr = this->ptr;

sycl::buffer<T, 1> b_vec(vec.data(), vec.size());
sycl_target->queue
.submit([&](sycl::handler &cgh) {
auto a_vec =
b_vec.template get_access<sycl::access::mode::read>(cgh);
cgh.parallel_for<>(
sycl::range<1>(this->size),
[=](sycl::id<1> idx) { k_ptr[idx] = a_vec[idx]; });
})
.wait_and_throw();
}
}

Expand Down
14 changes: 8 additions & 6 deletions include/particle_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,25 +161,27 @@ class ParticleGroup {
* @param particle_dat New ParticleDat to add.
*/
inline void add_particle_dat(ParticleDatSharedPtr<REAL> particle_dat);

/**
* Add a ParticleDat to the ParticleGroup after construction.
*
* @param particle_dat New ParticleDat to add.
*/
inline void add_particle_dat(ParticleDatSharedPtr<INT> particle_dat);

/**
* Add particles to the ParticleGroup. Any rank may add particles that exist
* anywhere in the domain. Implemetation TODO. This call is collective
* across the ParticleGroup and ranks that do not add particles should not
* pass any new particle data.
* anywhere in the domain. This call is collective across the ParticleGroup
* and ranks that do not add particles should not pass any new particle
* data.
*/
inline void add_particles();

/**
* Add particles to the ParticleGroup. Any rank may add particles that exist
* anywhere in the domain. Implemetation TODO. This call is collective
* across the ParticleGroup and ranks that do not add particles should not
* pass any new particle data.
* anywhere in the domain. This call is collective across the ParticleGroup
* and ranks that do not add particles should not pass any new particle
* data.
*
* @param particle_data New particle data to add to the ParticleGroup.
*/
Expand Down
54 changes: 53 additions & 1 deletion test/test_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ TEST(Buffer, Host) {

std::vector<double> empty(0);
BufferHost to_test_empty{sycl_target, empty};

BufferHost to_test_vector{sycl_target, correct};
for (int ix = 0; ix < N; ix++) {
EXPECT_EQ(correct[ix], to_test_vector.ptr[ix]);
}
}

TEST(Buffer, Device) {
Expand All @@ -36,7 +41,7 @@ TEST(Buffer, Device) {
}
sycl::buffer<int, 1> b_to_test(to_test.data(), to_test.size());

BufferHost buffer{sycl_target, correct};
BufferDevice buffer{sycl_target, correct};

EXPECT_EQ(buffer.size, N);
EXPECT_EQ(buffer.size_bytes(), N * sizeof(int));
Expand All @@ -58,6 +63,29 @@ TEST(Buffer, Device) {

std::vector<double> empty(0);
BufferDevice to_test_empty{sycl_target, empty};

for (int ix = 0; ix < N; ix++) {
correct[ix] *= 2;
}

BufferDevice to_test_vector{sycl_target, correct};
auto k_to_test_vector = to_test_vector.ptr;

std::vector<int> to_test2(N);
sycl::buffer<int, 1> b_to_test2(to_test2.data(), to_test2.size());
sycl_target->queue
.submit([&](sycl::handler &cgh) {
auto a_to_test = b_to_test2.get_access<sycl::access::mode::write>(cgh);
cgh.parallel_for<>(sycl::range<1>(N), [=](sycl::id<1> idx) {
a_to_test[idx] = k_to_test_vector[idx];
});
})
.wait_and_throw();

auto h_to_test2 = b_to_test2.get_host_access();
for (int ix = 0; ix < N; ix++) {
EXPECT_EQ(correct[ix], h_to_test2[ix]);
}
}

TEST(Buffer, DeviceHost) {
Expand Down Expand Up @@ -118,4 +146,28 @@ TEST(Buffer, DeviceHost) {

std::vector<double> empty(0);
BufferDeviceHost to_test_empty{sycl_target, empty};

for (int ix = 0; ix < N; ix++) {
correct.at(ix) *= 2;
}

{
BufferDeviceHost to_test_vector{sycl_target, correct};
auto k_to_test_vector = to_test_vector.d_buffer.ptr;
sycl_target->queue
.submit([&](sycl::handler &cgh) {
cgh.parallel_for<>(sycl::range<1>(N), [=](sycl::id<1> idx) {
k_to_test_vector[idx] *= 2;
});
})
.wait_and_throw();

for (int ix = 0; ix < N; ix++) {
EXPECT_EQ(correct[ix], to_test_vector.h_buffer.ptr[ix]);
}
to_test_vector.device_to_host();
for (int ix = 0; ix < N; ix++) {
EXPECT_EQ(correct[ix] * 2, to_test_vector.h_buffer.ptr[ix]);
}
}
}
1 change: 0 additions & 1 deletion test/test_utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ TEST(Utility, NormalDistribution) {

{
std::mt19937 rng0 = std::mt19937(123);
const double extents[1] = {1.0};
auto u0 = NESO::Particles::normal_distribution(1, 1, 2.0, 3.0, rng0);

std::mt19937 rng1 = std::mt19937(123);
Expand Down

0 comments on commit 278f960

Please sign in to comment.