Skip to content

Commit

Permalink
Fix new ShellBatched for Device, add additional std::future::get to p…
Browse files Browse the repository at this point in the history
…roagate exceptions
  • Loading branch information
wavefunction91 committed May 20, 2024
1 parent d5118ff commit 64e0f35
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*
* See LICENSE.txt for details
*/
#include "shell_batched_replicated_xc_host_integrator.hpp"
#include "shellbatched_replicated_xc_device_integrator.hpp"
#include "shell_batched_replicated_xc_integrator_integrate_den.hpp"
#include "shell_batched_replicated_xc_integrator_exc_vxc.hpp"
#include "shell_batched_replicated_xc_integrator_exc_grad.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
*/
#pragma once
#include <gauxc/xc_integrator/replicated/replicated_xc_device_integrator.hpp>
#include "reference_replicated_xc_host_integrator.hpp"
#include "incore_replicated_xc_device_integrator.hpp"
#include "shell_batched_replicated_xc_integrator.hpp"

namespace GauXC {
Expand All @@ -17,12 +17,12 @@ template <typename ValueType>
class ShellBatchedReplicatedXCDeviceIntegrator :
public ShellBatchedReplicatedXCIntegrator<
ReplicatedXCDeviceIntegrator<ValueType>,
ReferenceReplicatedXCDeviceIntegrator<ValueType>
IncoreReplicatedXCDeviceIntegrator<ValueType>
> {

using base_type = ShellBatchedReplicatedXCIntegrator<
ReplicatedXCDeviceIntegrator<ValueType>,
ReferenceReplicatedXCDeviceIntegrator<ValueType>
IncoreReplicatedXCDeviceIntegrator<ValueType>
>;

public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#pragma once
#include <gauxc/gauxc_config.hpp>
#include "shell_batched_xc_integrator.hpp"
#ifdef GAUXC_ENABLE_DEVICE
#ifdef GAUXC_HAS_DEVICE
#include "device/xc_device_data.hpp"
#endif

Expand All @@ -32,7 +32,7 @@ class ShellBatchedReplicatedXCIntegrator :

protected:

#ifdef GAUXC_ENABLE_DEVICE
#ifdef GAUXC_HAS_DEVICE
std::unique_ptr<XCDeviceData> device_data_ptr_;
#endif

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
*/
#pragma once
#include "shell_batched_replicated_xc_integrator.hpp"
#ifdef GAUXC_ENABLE_DEVICE
#ifdef GAUXC_HAS_DEVICE
#include "device/local_device_work_driver.hpp"
#include "device/xc_device_aos_data.hpp"
#endif
Expand Down Expand Up @@ -72,7 +72,7 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
// Get Tasks
auto& tasks = this->load_balancer_->get_tasks();

#ifdef GAUXC_ENABLE_DEVICE
#ifdef GAUXC_HAS_DEVICE
// Allocate Device memory
auto* lwd = dynamic_cast<LocalDeviceWorkDriver*>(this->local_work_driver_.get() );
auto rt = detail::as_device_runtime(this->load_balancer_->runtime());
Expand Down Expand Up @@ -114,7 +114,7 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
this->reduction_driver_->allreduce_inplace( &N_EL, 1 , ReductionOp::Sum );
});

#ifdef GAUXC_ENABLE_DEVICE
#ifdef GAUXC_HAS_DEVICE
device_data_ptr_.reset();
#endif

Expand Down Expand Up @@ -261,6 +261,7 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
// Execute remaining tasks sequentially
if( task_future.valid() ) {
task_future.wait();
task_future.get();
}
while( not incore_task_data_queue.empty() ) {
execute_incore_task();
Expand Down Expand Up @@ -379,19 +380,27 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType


// Process selected task batch
#ifdef GAUXC_ENABLE_DEVICE
#ifdef GAUXC_HAS_DEVICE
if constexpr (IncoreIntegratorType::is_device) {
if(Pz or Py or Py)
GAUXC_GENERIC_EXCEPTION("Device UKS/GKS + ShellBatched NYI");

incore_integrator.exc_vxc_local_work( basis_subset, Ps_submat, nbe,
Pz_submat, nbe, Py_submat, nbe, Px_submat, nbe, VXCs_submat, nbe,
VXCz_submat, nbe, VXCy_submat, nbe, VXCx_submat, nbe,
&EXC_tmp, &NEL_tmp, task_begin, task_end, *device_data_ptr_ );
} else {
VXCs_submat, nbe, &EXC_tmp, &NEL_tmp, task_begin, task_end,
*device_data_ptr_ );
// TODO: Make this work for UKS/GKS after
// https://github.com/wavefunction91/GauXC/pull/91
//incore_integrator.exc_vxc_local_work( basis_subset, Ps_submat, nbe,
// Pz_submat, nbe, Py_submat, nbe, Px_submat, nbe, VXCs_submat, nbe,
// VXCz_submat, nbe, VXCy_submat, nbe, VXCx_submat, nbe,
// &EXC_tmp, &NEL_tmp, task_begin, task_end, *device_data_ptr_ );
} else if constexpr (not IncoreIntegratorType::is_device) {
#endif
incore_integrator.exc_vxc_local_work( basis_subset, Ps_submat, nbe,
Pz_submat, nbe, Py_submat, nbe, Px_submat, nbe, VXCs_submat, nbe,
VXCz_submat, nbe, VXCy_submat, nbe, VXCx_submat, nbe,
&EXC_tmp, &NEL_tmp, IntegratorSettingsKS{}, task_begin, task_end );
#ifdef GAUXC_ENABLE_DEVICE
#ifdef GAUXC_HAS_DEVICE
}
#endif

Expand Down

0 comments on commit 64e0f35

Please sign in to comment.