diff --git a/src/xc_integrator/CMakeLists.txt b/src/xc_integrator/CMakeLists.txt index 5c21557d..c8adc6ad 100644 --- a/src/xc_integrator/CMakeLists.txt +++ b/src/xc_integrator/CMakeLists.txt @@ -7,6 +7,7 @@ # add_subdirectory(integrator_util) add_subdirectory(local_work_driver) +add_subdirectory(shell_batched) add_subdirectory(replicated) add_subdirectory(xc_data) diff --git a/src/xc_integrator/replicated/device/CMakeLists.txt b/src/xc_integrator/replicated/device/CMakeLists.txt index f580b381..0d789eff 100644 --- a/src/xc_integrator/replicated/device/CMakeLists.txt +++ b/src/xc_integrator/replicated/device/CMakeLists.txt @@ -8,6 +8,6 @@ target_sources( gauxc PRIVATE replicated_xc_device_integrator.cxx incore_replicated_xc_device_integrator.cxx - shellbatched_replicated_xc_device_integrator.cxx + shell_batched_replicated_xc_device_integrator.cxx ) diff --git a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.hpp b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.hpp index 6d169c03..11d0e067 100644 --- a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.hpp +++ b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.hpp @@ -20,6 +20,7 @@ class IncoreReplicatedXCDeviceIntegrator : public: + static constexpr bool is_device = true; using value_type = typename base_type::value_type; using basis_type = typename base_type::basis_type; diff --git a/src/xc_integrator/replicated/device/replicated_xc_device_integrator.cxx b/src/xc_integrator/replicated/device/replicated_xc_device_integrator.cxx index 534d5333..dc022f18 100644 --- a/src/xc_integrator/replicated/device/replicated_xc_device_integrator.cxx +++ b/src/xc_integrator/replicated/device/replicated_xc_device_integrator.cxx @@ -7,7 +7,7 @@ */ #include #include "incore_replicated_xc_device_integrator.hpp" -#include "shellbatched_replicated_xc_device_integrator.hpp" +#include "shell_batched_replicated_xc_device_integrator.hpp" 
#include "device/local_device_work_driver.hpp" namespace GauXC { diff --git a/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator.cxx b/src/xc_integrator/replicated/device/shell_batched_replicated_xc_device_integrator.cxx similarity index 63% rename from src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator.cxx rename to src/xc_integrator/replicated/device/shell_batched_replicated_xc_device_integrator.cxx index 4fd5e5ba..35414129 100644 --- a/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator.cxx +++ b/src/xc_integrator/replicated/device/shell_batched_replicated_xc_device_integrator.cxx @@ -5,10 +5,11 @@ * * See LICENSE.txt for details */ -#include "shellbatched_replicated_xc_device_integrator_integrate_den.hpp" -#include "shellbatched_replicated_xc_device_integrator_exc_vxc.hpp" -#include "shellbatched_replicated_xc_device_integrator_exc_grad.hpp" -#include "shellbatched_replicated_xc_device_integrator_exx.hpp" +#include "shell_batched_replicated_xc_device_integrator.hpp" +#include "shell_batched_replicated_xc_integrator_integrate_den.hpp" +#include "shell_batched_replicated_xc_integrator_exc_vxc.hpp" +#include "shell_batched_replicated_xc_integrator_exc_grad.hpp" +#include "shell_batched_replicated_xc_integrator_exx.hpp" namespace GauXC { namespace detail { diff --git a/src/xc_integrator/replicated/device/shell_batched_replicated_xc_device_integrator.hpp b/src/xc_integrator/replicated/device/shell_batched_replicated_xc_device_integrator.hpp new file mode 100644 index 00000000..dd165f64 --- /dev/null +++ b/src/xc_integrator/replicated/device/shell_batched_replicated_xc_device_integrator.hpp @@ -0,0 +1,42 @@ +/** + * GauXC Copyright (c) 2020-2024, The Regents of the University of California, + * through Lawrence Berkeley National Laboratory (subject to receipt of + * any required approvals from the U.S. Dept. of Energy). All rights reserved. 
+ * + * See LICENSE.txt for details + */ +#pragma once +#include +#include "incore_replicated_xc_device_integrator.hpp" +#include "shell_batched_replicated_xc_integrator.hpp" + +namespace GauXC { +namespace detail { + +template +class ShellBatchedReplicatedXCDeviceIntegrator : + public ShellBatchedReplicatedXCIntegrator< + ReplicatedXCDeviceIntegrator, + IncoreReplicatedXCDeviceIntegrator + > { + + using base_type = ShellBatchedReplicatedXCIntegrator< + ReplicatedXCDeviceIntegrator, + IncoreReplicatedXCDeviceIntegrator + >; + +public: + + template + ShellBatchedReplicatedXCDeviceIntegrator( Args&&... args ) : + base_type( std::forward(args)... ) { } + + virtual ~ShellBatchedReplicatedXCDeviceIntegrator() noexcept; + +}; + +extern template class ShellBatchedReplicatedXCDeviceIntegrator; + +} +} + diff --git a/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator.hpp b/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator.hpp deleted file mode 100644 index 01210e0c..00000000 --- a/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator.hpp +++ /dev/null @@ -1,131 +0,0 @@ -/** - * GauXC Copyright (c) 2020-2024, The Regents of the University of California, - * through Lawrence Berkeley National Laboratory (subject to receipt of - * any required approvals from the U.S. Dept. of Energy). All rights reserved. 
- * - * See LICENSE.txt for details - */ -#pragma once -#include -#include "device/xc_device_data.hpp" -#include "incore_replicated_xc_device_integrator.hpp" - -namespace GauXC { -namespace detail { - -template -class ShellBatchedReplicatedXCDeviceIntegrator : - public ReplicatedXCDeviceIntegrator { - - using base_type = ReplicatedXCDeviceIntegrator; - -public: - - using value_type = typename base_type::value_type; - using basis_type = typename base_type::basis_type; - - using host_task_container = std::vector; - using host_task_iterator = typename host_task_container::iterator; - -protected: - - using incore_integrator_type = - IncoreReplicatedXCDeviceIntegrator; - - // Struct to manage data associated with task subset to execute on the device - struct incore_device_task { - host_task_iterator task_begin; - host_task_iterator task_end; - std::vector shell_list; - }; - - void integrate_den_( int64_t m, int64_t n, const value_type* P, - int64_t ldp, value_type* integrate_den ) override; - - void eval_exc_vxc_( int64_t m, int64_t n, const value_type* P, - int64_t ldp, value_type* VXC, int64_t ldvxc, - value_type* EXC, const IntegratorSettingsXC& settings ) override; - - void eval_exc_vxc_( int64_t m, int64_t n, const value_type* Ps, - int64_t ldps, - const value_type* Pz, - int64_t ldpz, - value_type* VXCs, int64_t ldvxcs, - value_type* VXCz, int64_t ldvxcz, - value_type* EXC, const IntegratorSettingsXC& settings ) override; - - void eval_exc_vxc_( int64_t m, int64_t n, const value_type* Ps, - int64_t ldps, - const value_type* Pz, - int64_t ldpz, - const value_type* Py, - int64_t ldpy, - const value_type* Px, - int64_t ldpx, - value_type* VXCs, int64_t ldvxcs, - value_type* VXCz, int64_t ldvxcz, - value_type* VXCy, int64_t ldvxcy, - value_type* VXCx, int64_t ldvxcx, - value_type* EXC, const IntegratorSettingsXC& settings ) override; - - void eval_exc_grad_( int64_t m, int64_t n, const value_type* P, - int64_t ldp, value_type* EXC_GRAD ) override; - - void eval_exx_( 
int64_t m, int64_t n, const value_type* P, - int64_t ldp, value_type* K, int64_t ldk, - const IntegratorSettingsEXX& settings ) override; - - void exc_vxc_local_work_( const basis_type& basis, const value_type* P, int64_t ldp, - value_type* VXC, int64_t ldvxc, value_type* EXC, value_type *N_EL, - host_task_iterator task_begin, host_task_iterator task_end, - incore_integrator_type& incore_integrator, - XCDeviceData& device_data ); - - void exc_vxc_local_work_( const basis_type& basis, const value_type* Ps, int64_t ldps, - const value_type* Pz, int64_t ldpz, - value_type* VXCs, int64_t ldvxcs, - value_type* VXCz, int64_t ldvxcz, value_type* EXC, value_type *N_EL, - host_task_iterator task_begin, host_task_iterator task_end, - XCDeviceData& device_data ); - - void exc_vxc_local_work_( const basis_type& basis, const value_type* Ps, int64_t ldps, - const value_type* Pz, int64_t ldpz, - const value_type* Py, int64_t ldpy, - const value_type* Px, int64_t ldpx, - value_type* VXC, int64_t ldvxc, - value_type* VXCz, int64_t ldvxcz, - value_type* VXCy, int64_t ldvxcy, - value_type* VXCx, int64_t ldvxcx, value_type* EXC, value_type *N_EL, - host_task_iterator task_begin, host_task_iterator task_end, - XCDeviceData& device_data ); - - void eval_exc_grad_local_work_( int64_t m, int64_t n, const value_type* P, - int64_t ldp, value_type* EXC_GRAD, - host_task_iterator task_begin, host_task_iterator task_end, - incore_integrator_type& incore_integrator, - XCDeviceData& device_data ); - - - incore_device_task generate_incore_device_task( const uint32_t nbf_threshold, - const basis_type& basis, - host_task_iterator task_begin, - host_task_iterator task_end ); - - void execute_task_batch( incore_device_task& task, const basis_type& basis, const Molecule& mol, const value_type* P, - int64_t ldp, value_type* VXC, int64_t ldvxc, value_type* EXC, - value_type* N_EL, incore_integrator_type& incore_integrator, - XCDeviceData& device_data ); -public: - - template - 
ShellBatchedReplicatedXCDeviceIntegrator( Args&&... args ) : - base_type( std::forward(args)... ) { } - - virtual ~ShellBatchedReplicatedXCDeviceIntegrator() noexcept; - -}; - -extern template class ShellBatchedReplicatedXCDeviceIntegrator; - -} -} diff --git a/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator_exc_vxc.hpp b/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator_exc_vxc.hpp deleted file mode 100644 index df2d6f57..00000000 --- a/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator_exc_vxc.hpp +++ /dev/null @@ -1,497 +0,0 @@ -/** - * GauXC Copyright (c) 2020-2024, The Regents of the University of California, - * through Lawrence Berkeley National Laboratory (subject to receipt of - * any required approvals from the U.S. Dept. of Energy). All rights reserved. - * - * See LICENSE.txt for details - */ -#include "shellbatched_replicated_xc_device_integrator.hpp" -#include "device/local_device_work_driver.hpp" -#include "device/xc_device_aos_data.hpp" -#include "integrator_util/integrator_common.hpp" -#include "host/util.hpp" -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace GauXC { -namespace detail { - - -template -void ShellBatchedReplicatedXCDeviceIntegrator:: - eval_exc_vxc_( int64_t m, int64_t n, const value_type* P, - int64_t ldp, value_type* VXC, int64_t ldvxc, - value_type* EXC, const IntegratorSettingsXC& settings ) { - - - const auto& basis = this->load_balancer_->basis(); - - // Check that P / VXC are sane - const int64_t nbf = basis.nbf(); - if( m != n ) - GAUXC_GENERIC_EXCEPTION("P/VXC Must Be Square"); - if( m != nbf ) - GAUXC_GENERIC_EXCEPTION("P/VXC Must Have Same Dimension as Basis"); - if( ldp < nbf ) - GAUXC_GENERIC_EXCEPTION("Invalid LDP"); - if( ldvxc < nbf ) - GAUXC_GENERIC_EXCEPTION("Invalid LDVXC"); - - - // Get Tasks - auto& tasks = this->load_balancer_->get_tasks(); - - // Allocate Device memory - 
auto* lwd = dynamic_cast(this->local_work_driver_.get() ); - auto rt = detail::as_device_runtime(this->load_balancer_->runtime()); - auto device_data_ptr = - this->timer_.time_op("XCIntegrator.DeviceAlloc", - [&](){ return lwd->create_device_data(rt); }); - - // Generate incore integrator instance, transfer ownership of LWD - incore_integrator_type incore_integrator( this->func_, this->load_balancer_, - this->release_local_work_driver(), this->reduction_driver_ ); - - // Temporary electron count to judge integrator accuracy - value_type N_EL; - - // Compute local contributions to EXC/VXC - this->timer_.time_op("XCIntegrator.LocalWork", [&](){ - exc_vxc_local_work_( basis, P, ldp, VXC, ldvxc, EXC, - &N_EL, tasks.begin(), tasks.end(), incore_integrator, - *device_data_ptr ); - }); - - // Release ownership of LWD back to this integrator instance - this->local_work_driver_ = std::move( incore_integrator.release_local_work_driver() ); - - - // Reduce Results - this->timer_.time_op("XCIntegrator.Allreduce", [&](){ - this->reduction_driver_->allreduce_inplace( VXC, nbf*nbf, ReductionOp::Sum ); - this->reduction_driver_->allreduce_inplace( EXC, 1 , ReductionOp::Sum ); - this->reduction_driver_->allreduce_inplace( &N_EL, 1 , ReductionOp::Sum ); - }); - - -} - - - -template -void ShellBatchedReplicatedXCDeviceIntegrator:: - eval_exc_vxc_( int64_t m, int64_t n, const value_type* Ps, - int64_t ldps, - const value_type* Pz, - int64_t ldpz, - value_type* VXCs, int64_t ldvxcs, - value_type* VXCz, int64_t ldvxcz, - value_type* EXC, const IntegratorSettingsXC& settings ) { - - GauXC::util::unused(m,n,Ps,ldps,Pz,ldpz,VXCs,ldvxcs,VXCz,ldvxcz,EXC,settings); - GAUXC_GENERIC_EXCEPTION("UKS NOT YET IMPLEMENTED FOR DEVICE"); -} - - -template -void ShellBatchedReplicatedXCDeviceIntegrator:: - eval_exc_vxc_( int64_t m, int64_t n, const value_type* Ps, - int64_t ldps, - const value_type* Pz, - int64_t ldpz, - const value_type* Py, - int64_t ldpy, - const value_type* Px, - int64_t ldpx, - 
value_type* VXCs, int64_t ldvxcs, - value_type* VXCz, int64_t ldvxcz, - value_type* VXCy, int64_t ldvxcy, - value_type* VXCx, int64_t ldvxcx, - value_type* EXC, const IntegratorSettingsXC& settings ) { - GauXC::util::unused(m,n,Ps,ldps,Pz,ldpz,Py,ldpy,Px,ldpx,VXCs,ldvxcs,VXCz,ldvxcz,VXCy,ldvxcy,VXCx,ldvxcx,EXC,settings); - GAUXC_GENERIC_EXCEPTION("GKS NOT YET IMPLEMENTED FOR DEVICE"); -} - - -template -void ShellBatchedReplicatedXCDeviceIntegrator:: - exc_vxc_local_work_( const basis_type& basis, const value_type* P, int64_t ldp, - value_type* VXC, int64_t ldvxc, value_type* EXC, value_type *N_EL, - host_task_iterator task_begin, host_task_iterator task_end, - incore_integrator_type& incore_integrator, XCDeviceData& device_data ) { - - - //incore_integrator.exc_vxc_local_work( basis, P, ldp, VXC, ldvxc, EXC, N_EL, task_begin, task_end, device_data ); - //return; - - - const auto nbf = basis.nbf(); - const uint32_t nbf_threshold = 8000; - const auto& mol = this->load_balancer_->molecule(); - // Zero out integrands on host - this->timer_.time_op("XCIntegrator.ZeroHost", [&](){ - *EXC = 0.; - *N_EL = 0.; - for( auto j = 0; j < nbf; ++j ) - for( auto i = 0; i < nbf; ++i ) - VXC[i + j*ldvxc] = 0.; - }); - - - // Task queue - std::queue< incore_device_task > incore_device_task_queue; - - // Task queue modification mutex - std::mutex queue_mod_ex; - - // Lambda for the execution of incore tasks on the device - auto execute_incore_device_task = [&]() { - - // Early return if there is no task to execute - if( incore_device_task_queue.empty() ) return; - - incore_device_task next_task; - { - std::lock_guard lock(queue_mod_ex); - - // Move the next task into local scope and remove - // from queue - next_task = std::move( incore_device_task_queue.front() ); - incore_device_task_queue.pop(); - } - - // Execute task - execute_task_batch( next_task, basis, mol, P, ldp, VXC, ldvxc, EXC, N_EL, - incore_integrator, device_data ); - }; - - - // Setup future to track execution of 
currently running - // device task - std::future task_future; - - auto task_it = task_begin; - while( task_it != task_end ) { - - // Generate and enqueue task - incore_device_task_queue.emplace( - generate_incore_device_task( nbf_threshold, basis, task_it, task_end ) - ); - - // Update iterator for next task generation - task_it = incore_device_task_queue.back().task_end; - - if( not task_future.valid() ) { - // No device task to wait on - task_future = std::async( std::launch::async, execute_incore_device_task ); - } else { - // Check the status of current device task - auto status = task_future.wait_for( std::chrono::milliseconds(5) ); - if( status == std::future_status::ready ) { - // If the status is ready - execute the next task in queue - task_future.get(); - task_future = std::async( std::launch::async, execute_incore_device_task ); - } - } - - } // Loop until all tasks have been enqued - - // TODO: Try to merge remaining tasks appropriately - - // Execute remaining tasks sequentially - if( task_future.valid() ) { - task_future.wait(); - task_future.get(); - } - while( not incore_device_task_queue.empty() ) { - execute_incore_device_task(); - } -} - -template -void ShellBatchedReplicatedXCDeviceIntegrator:: - exc_vxc_local_work_( const basis_type& basis, const value_type* Ps, int64_t ldps, - const value_type* Pz, int64_t ldpz, - value_type* VXCs, int64_t ldvxcs, - value_type* VXCz, int64_t ldvxcz, value_type* EXC, value_type *N_EL, - host_task_iterator task_begin, host_task_iterator task_end, - XCDeviceData& device_data ) { - - GauXC::util::unused(basis,Ps,ldps,Pz,ldpz,VXCs,ldvxcs,VXCz,ldvxcz,EXC,N_EL,task_begin,task_end,device_data); - GAUXC_GENERIC_EXCEPTION("UKS NOT YET IMPLEMENTED FOR DEVICE"); -} - -template -void ShellBatchedReplicatedXCDeviceIntegrator:: - exc_vxc_local_work_( const basis_type& basis, const value_type* Ps, int64_t ldps, - const value_type* Pz, int64_t ldpz, - const value_type* Py, int64_t ldpy, - const value_type* Px, int64_t ldpx, - 
value_type* VXCs, int64_t ldvxcs, - value_type* VXCz, int64_t ldvxcz, - value_type* VXCy, int64_t ldvxcy, - value_type* VXCx, int64_t ldvxcx, value_type* EXC, value_type *N_EL, - host_task_iterator task_begin, host_task_iterator task_end, - XCDeviceData& device_data ) { - - GauXC::util::unused(basis,Ps,ldps,Pz,ldpz,Py,ldpy,Px,ldpx,VXCs,ldvxcs,VXCz,ldvxcz,VXCy,ldvxcy,VXCx,ldvxcx,EXC,N_EL,task_begin,task_end,device_data); - GAUXC_GENERIC_EXCEPTION("GKS NOT YET IMPLEMENTED FOR DEVICE"); -} - - -template -typename ShellBatchedReplicatedXCDeviceIntegrator::incore_device_task - ShellBatchedReplicatedXCDeviceIntegrator:: - generate_incore_device_task( const uint32_t nbf_threshold, - const basis_type& basis, - host_task_iterator task_begin, - host_task_iterator task_end ) { - - - auto nbe_comparator = []( const auto& task_a, const auto& task_b ) { - return task_a.bfn_screening.nbe < task_b.bfn_screening.nbe; - }; - - // Find task with largest NBE - auto max_task = this->timer_.time_op_accumulate("XCIntegrator.MaxTask", [&]() { - return std::max_element( task_begin, task_end, nbe_comparator ); - } ); - - const auto max_shell_list = max_task->bfn_screening.shell_list; // copy for reset - - - // Init union shell list to max shell list outside of loop - std::set union_shell_set(max_shell_list.begin(), - max_shell_list.end()); - - - - int n_overlap_pthresh = 20; - double overlap_pthresh_delta = 1. 
/ n_overlap_pthresh; - std::vector overlap_pthresh; - for( int i = 1; i < n_overlap_pthresh; ++i ) - overlap_pthresh.emplace_back( i*overlap_pthresh_delta ); - - std::vector overlap_pthresh_idx( overlap_pthresh.size() ); - std::iota( overlap_pthresh_idx.begin(), overlap_pthresh_idx.end(), 0 ); - - std::map> - cached_task_ends; - - int cur_partition_pthresh_idx = -1; - - auto _it = std::partition_point( overlap_pthresh_idx.rbegin(), - overlap_pthresh_idx.rend(), - [&](int idx) { - - uint32_t overlap_threshold = - std::max(1., max_shell_list.size() * overlap_pthresh[idx] ); - - - host_task_iterator search_st = task_begin; - host_task_iterator search_en = task_end; - - // Make a local copy of union list - std::set local_union_shell_set; - - // Attempt to limit task search based on current partition - if( cur_partition_pthresh_idx >= 0 ) { - - const auto& last_pthresh = - cached_task_ends.at(cur_partition_pthresh_idx); - - if( cur_partition_pthresh_idx > idx ) { - search_st = last_pthresh.first; - local_union_shell_set = last_pthresh.second; - } else { - search_en = last_pthresh.first; - local_union_shell_set = union_shell_set; - } - - } else { - local_union_shell_set = union_shell_set; - } - - - // Partition tasks into those which overlap max_task up to - // specified threshold - auto local_task_end = - this->timer_.time_op_accumulate("XCIntegrator.TaskIntersection", [&]() { - return std::partition( search_st, search_en, [&](const auto& t) { - return util::integral_list_intersect( max_shell_list, t.bfn_screening.shell_list, - overlap_threshold ); - } ); - } ); - - - - // Take union of shell list for all overlapping tasks - this->timer_.time_op_accumulate("XCIntegrator.ShellListUnion",[&]() { - for( auto task_it = search_st; task_it != local_task_end; ++task_it ) { - local_union_shell_set.insert( task_it->bfn_screening.shell_list.begin(), - task_it->bfn_screening.shell_list.end() ); - } - } ); - - auto cur_nbe = basis.nbf_subset( local_union_shell_set.begin(), - 
local_union_shell_set.end() ); - - //std::cout << " Threshold % = " << std::setw(5) << overlap_pthresh[idx] << ", "; - //std::cout << " Overlap Threshold = " << std::setw(8) << overlap_threshold << ", "; - //std::cout << " Current NBE = " << std::setw(8) << cur_nbe << std::endl; - - // Cache the data - cached_task_ends[idx] = std::make_pair( local_task_end, local_union_shell_set ); - - // Update partitioned threshold - cur_partition_pthresh_idx = idx; - - return (uint32_t)cur_nbe < nbf_threshold; - - } ); - - host_task_iterator local_task_end; - auto _idx_partition = (_it == overlap_pthresh_idx.rend()) ? 0 : *_it; - std::tie( local_task_end, union_shell_set ) = cached_task_ends.at(_idx_partition); - - - - - - //std::cout << "FOUND " << std::distance( task_begin, local_task_end ) - // << " OVERLAPPING TASKS" << std::endl; - - - std::vector union_shell_list( union_shell_set.begin(), - union_shell_set.end() ); - - // Try to add additional tasks given current union list - local_task_end = this->timer_.time_op_accumulate("XCIntegrator.SubtaskGeneration", [&]() { - return std::partition( local_task_end, task_end, [&]( const auto& t ) { - return util::list_subset( union_shell_list, t.bfn_screening.shell_list ); - } ); - } ); - - //std::cout << "FOUND " << std::distance( task_begin, local_task_end ) - // << " SUBTASKS" << std::endl; - - - incore_device_task ex_task; - ex_task.task_begin = task_begin; - ex_task.task_end = local_task_end; - ex_task.shell_list = std::move( union_shell_list ); - - return ex_task; - -} - - - - - - - - - - - - -template -void ShellBatchedReplicatedXCDeviceIntegrator:: - execute_task_batch( incore_device_task& task, const basis_type& basis, const Molecule& mol, - const value_type* P, int64_t ldp, value_type* VXC, int64_t ldvxc, - value_type* EXC, value_type *N_EL, incore_integrator_type& incore_integrator, - XCDeviceData& device_data ) { - - // Alias information - auto task_begin = task.task_begin; - auto task_end = task.task_end; - auto& 
union_shell_list = task.shell_list; - - - // Extract subbasis - BasisSet basis_subset; basis_subset.reserve(union_shell_list.size()); - this->timer_.time_op_accumulate("XCIntegrator.CopySubBasis",[&]() { - for( auto i : union_shell_list ) { - basis_subset.emplace_back( basis.at(i) ); - } - }); - - // Setup basis maps - BasisSetMap basis_map( basis, mol ); - - //const size_t nshells = basis_subset.nshells(); - const size_t nbe = basis_subset.nbf(); - //std::cout << "TASK_UNION HAS:" << std::endl - // << " NSHELLS = " << nshells << std::endl - // << " NBE = " << nbe << std::endl; - - // Recalculate shell_list based on subbasis - this->timer_.time_op_accumulate("XCIntegrator.RecalcShellList",[&]() { - for( auto _it = task_begin; _it != task_end; ++_it ) { - auto union_list_idx = 0; - auto& cur_shell_list = _it->bfn_screening.shell_list; - for( auto j = 0ul; j < cur_shell_list.size(); ++j ) { - while( union_shell_list[union_list_idx] != cur_shell_list[j] ) - union_list_idx++; - cur_shell_list[j] = union_list_idx; - } - } - } ); - - - // Allocate host temporaries - std::vector P_submat_host(nbe*nbe), VXC_submat_host(nbe*nbe,0.); - double EXC_tmp, NEL_tmp; - double* P_submat = P_submat_host.data(); - double* VXC_submat = VXC_submat_host.data(); - - - - // Extract subdensity - std::vector> union_submat_cut; - std::vector foo; - std::tie(union_submat_cut,foo) = - gen_compressed_submat_map( basis_map, union_shell_list, - basis.nbf(), basis.nbf() ); - - this->timer_.time_op_accumulate("XCIntegrator.ExtractSubDensity",[&]() { - detail::submat_set( basis.nbf(), basis.nbf(), nbe, nbe, P, ldp, - P_submat, nbe, union_submat_cut ); - } ); - - - // Process selected task batch - incore_integrator.exc_vxc_local_work( basis_subset, P_submat, nbe, VXC_submat, nbe, - &EXC_tmp, &NEL_tmp, task_begin, task_end, device_data ); - - - // Update full quantities - *EXC += EXC_tmp; - *N_EL += NEL_tmp; - this->timer_.time_op_accumulate("XCIntegrator.IncrementSubPotential",[&]() { - 
detail::inc_by_submat( basis.nbf(), basis.nbf(), nbe, nbe, VXC, ldvxc, - VXC_submat, nbe, union_submat_cut ); - }); - - - // Reset shell_list to be wrt full basis - this->timer_.time_op_accumulate("XCIntegrator.ResetShellList",[&]() { - for( auto _it = task_begin; _it != task_end; ++_it ) - for( auto j = 0ul; j < _it->bfn_screening.shell_list.size(); ++j ) { - _it->bfn_screening.shell_list[j] = union_shell_list[_it->bfn_screening.shell_list[j]]; - } - }); - -} - -} -} - diff --git a/src/xc_integrator/replicated/host/CMakeLists.txt b/src/xc_integrator/replicated/host/CMakeLists.txt index caabe0aa..ae47dc6d 100644 --- a/src/xc_integrator/replicated/host/CMakeLists.txt +++ b/src/xc_integrator/replicated/host/CMakeLists.txt @@ -8,5 +8,6 @@ target_sources( gauxc PRIVATE replicated_xc_host_integrator.cxx reference_replicated_xc_host_integrator.cxx + shell_batched_replicated_xc_host_integrator.cxx ) diff --git a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.cxx b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.cxx index 3b0da926..7471e273 100644 --- a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.cxx +++ b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.cxx @@ -10,8 +10,7 @@ #include "reference_replicated_xc_host_integrator_exc_grad.hpp" #include "reference_replicated_xc_host_integrator_exx.hpp" -namespace GauXC { -namespace detail { +namespace GauXC::detail { template ReferenceReplicatedXCHostIntegrator::~ReferenceReplicatedXCHostIntegrator() noexcept = default; @@ -19,4 +18,3 @@ ReferenceReplicatedXCHostIntegrator::~ReferenceReplicatedXCHostIntegr template class ReferenceReplicatedXCHostIntegrator; } -} diff --git a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.hpp b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.hpp index d1567572..0d25f213 100644 --- 
a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.hpp +++ b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.hpp @@ -9,8 +9,7 @@ #include #include "xc_host_data.hpp" -namespace GauXC { -namespace detail { +namespace GauXC::detail { template class ReferenceReplicatedXCHostIntegrator : @@ -20,34 +19,35 @@ class ReferenceReplicatedXCHostIntegrator : public: + static constexpr bool is_device = false; using value_type = typename base_type::value_type; using basis_type = typename base_type::basis_type; + using task_container = std::vector; + using task_iterator = typename task_container::iterator; + protected: - void integrate_den_( int64_t m, int64_t n, const value_type* P, - int64_t ldp, value_type* N_EL ) override; + // Density Integration + void integrate_den_( int64_t m, int64_t n, const value_type* P, int64_t ldp, value_type* N_EL ) override; - void eval_exc_vxc_( int64_t m, int64_t n, const value_type* P, - int64_t ldp, value_type* VXC, int64_t ldvxc, - value_type* EXC, const IntegratorSettingsXC& ks_settings ) override; + /// RKS EXC/VXC + void eval_exc_vxc_( int64_t m, int64_t n, const value_type* P, int64_t ldp, + value_type* VXC, int64_t ldvxc, value_type* EXC, + const IntegratorSettingsXC& ks_settings ) override; - void eval_exc_vxc_( int64_t m, int64_t n, const value_type* Ps, - int64_t ldps, - const value_type* Pz, - int64_t ldpz, + /// UKS EXC/VXC + void eval_exc_vxc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, value_type* VXCs, int64_t ldvxcs, value_type* VXCz, int64_t ldvxcz, value_type* EXC, const IntegratorSettingsXC& ks_settings ) override; - void eval_exc_vxc_( int64_t m, int64_t n, const value_type* Ps, - int64_t ldps, - const value_type* Pz, - int64_t ldpz, - const value_type* Py, - int64_t ldpy, - const value_type* Px, - int64_t ldpx, + /// GKS EXC/VXC - also serves as the generic implementation + void eval_exc_vxc_( int64_t m, int64_t n, const 
value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, + const value_type* Py, int64_t ldpy, + const value_type* Px, int64_t ldpx, value_type* VXCs, int64_t ldvxcs, value_type* VXCz, int64_t ldvxcz, value_type* VXCy, int64_t ldvxcy, @@ -55,24 +55,23 @@ class ReferenceReplicatedXCHostIntegrator : value_type* EXC, const IntegratorSettingsXC& ks_settings ) override; + /// RKS EXC Gradient void eval_exc_grad_( int64_t m, int64_t n, const value_type* P, int64_t ldp, value_type* EXC_GRAD ) override; + /// sn-LinK void eval_exx_( int64_t m, int64_t n, const value_type* P, int64_t ldp, value_type* K, int64_t ldk, const IntegratorSettingsEXX& settings ) override; - void integrate_den_local_work_( const value_type* P, int64_t ldp, - value_type *N_EL ); - void exc_vxc_local_work_( const value_type* Ps, int64_t ldps, - const value_type* Pz, int64_t ldpz, - value_type* VXCs, int64_t ldvxcs, - value_type* VXCz, int64_t ldvxcz, - value_type* EXC, value_type *N_EL ); + // Implementation details of integrate_den + void integrate_den_local_work_( const value_type* P, int64_t ldp, + value_type *N_EL ); - void exc_vxc_local_work_( const value_type* Ps, int64_t ldps, + // Implementation details of exc_vxc (for RKS/UKS/GKS deduced from input character) + void exc_vxc_local_work_( const basis_type& basis, const value_type* Ps, int64_t ldps, const value_type* Pz, int64_t ldpz, const value_type* Py, int64_t ldpy, const value_type* Px, int64_t ldpx, @@ -80,9 +79,13 @@ class ReferenceReplicatedXCHostIntegrator : value_type* VXCz, int64_t ldvxcz, value_type* VXCy, int64_t ldvxcy, value_type* VXCx, int64_t ldvxcx, - value_type* EXC, value_type *N_EL, const IntegratorSettingsXC& ks_settings ); + value_type* EXC, value_type *N_EL, const IntegratorSettingsXC& ks_settings, + task_iterator task_begin, task_iterator task_end ); + // Implementation details of exc_grad void exc_grad_local_work_( const value_type* P, int64_t ldp, value_type* EXC_GRAD ); + + // Implementation details of 
sn-LinK void exx_local_work_( const value_type* P, int64_t ldp, value_type* K, int64_t ldk, const IntegratorSettingsEXX& settings ); @@ -94,9 +97,15 @@ class ReferenceReplicatedXCHostIntegrator : virtual ~ReferenceReplicatedXCHostIntegrator() noexcept; + + template + void exc_vxc_local_work(Args&&... args) { + exc_vxc_local_work_( std::forward(args)... ); + } + + }; extern template class ReferenceReplicatedXCHostIntegrator; -} -} +} // namespace GauXC::detail diff --git a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exc_grad.hpp b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exc_grad.hpp index a736802d..05972bbf 100644 --- a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exc_grad.hpp +++ b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exc_grad.hpp @@ -13,8 +13,7 @@ #include "host/blas.hpp" #include -namespace GauXC { -namespace detail { +namespace GauXC::detail { template void ReferenceReplicatedXCHostIntegrator:: @@ -449,5 +448,4 @@ void ReferenceReplicatedXCHostIntegrator:: } -} -} +} // namespace GauXC::detail diff --git a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exc_vxc.hpp b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exc_vxc.hpp index d891b543..de8244ee 100644 --- a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exc_vxc.hpp +++ b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exc_vxc.hpp @@ -13,165 +13,67 @@ #include "host/blas.hpp" #include -namespace GauXC { -namespace detail { +namespace GauXC::detail { +/** + * Generic implementation of EXC/VXC for RKS/UKS/GKS + * + * If passed pointers are null-y and the leading dimensions + * are zero, RKS/UKS are deduced. 
RKS/UKS drivers delegate + * to this function. + */ template void ReferenceReplicatedXCHostIntegrator:: - eval_exc_vxc_( int64_t m, int64_t n, const value_type* P, - int64_t ldp, value_type* VXC, int64_t ldvxc, + eval_exc_vxc_( int64_t m, int64_t n, + const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, + const value_type* Py, int64_t ldpy, + const value_type* Px, int64_t ldpx, + value_type* VXCs, int64_t ldvxcs, + value_type* VXCz, int64_t ldvxcz, + value_type* VXCy, int64_t ldvxcy, + value_type* VXCx, int64_t ldvxcx, value_type* EXC, const IntegratorSettingsXC& ks_settings ) { const auto& basis = this->load_balancer_->basis(); - // Check that P / VXC are sane - const int64_t nbf = basis.nbf(); - if( m != n ) - GAUXC_GENERIC_EXCEPTION("P/VXC Must Be Square"); - if( m != nbf ) - GAUXC_GENERIC_EXCEPTION("P/VXC Must Have Same Dimension as Basis"); - if( ldp < nbf ) - GAUXC_GENERIC_EXCEPTION("Invalid LDP"); - if( ldvxc < nbf ) - GAUXC_GENERIC_EXCEPTION("Invalid LDVXC"); - - - // Get Tasks - this->load_balancer_->get_tasks(); - - // Temporary electron count to judge integrator accuracy - value_type N_EL; - - // Compute Local contributions to EXC / VXC - this->timer_.time_op("XCIntegrator.LocalWork", [&](){ - //exc_vxc_local_work_( P, ldp, VXC, ldvxc, EXC, &N_EL ); - exc_vxc_local_work_( P, ldp, nullptr, 0, nullptr, 0, nullptr, 0, - VXC, ldvxc, nullptr, 0, nullptr, 0, nullptr, 0, EXC, &N_EL, ks_settings ); - }); - - - // Reduce Results - this->timer_.time_op("XCIntegrator.Allreduce", [&](){ - - if( not this->reduction_driver_->takes_host_memory() ) - GAUXC_GENERIC_EXCEPTION("This Module Only Works With Host Reductions"); - - this->reduction_driver_->allreduce_inplace( VXC, nbf*nbf, ReductionOp::Sum ); - this->reduction_driver_->allreduce_inplace( EXC, 1 , ReductionOp::Sum ); - this->reduction_driver_->allreduce_inplace( &N_EL, 1 , ReductionOp::Sum ); - - }); - -} - -template -void ReferenceReplicatedXCHostIntegrator:: - eval_exc_vxc_( int64_t m,
int64_t n, const value_type* Ps, - int64_t ldps, - const value_type* Pz, - int64_t ldpz, - value_type* VXCs, int64_t ldvxcs, - value_type* VXCz, int64_t ldvxcz, - value_type* EXC, const IntegratorSettingsXC& ks_settings) { - - const auto& basis = this->load_balancer_->basis(); - // Check that P / VXC are sane const int64_t nbf = basis.nbf(); if( m != n ) GAUXC_GENERIC_EXCEPTION("P/VXC Must Be Square"); if( m != nbf ) GAUXC_GENERIC_EXCEPTION("P/VXC Must Have Same Dimension as Basis"); - if( ldps < nbf ) - GAUXC_GENERIC_EXCEPTION("Invalid LDPSCALAR"); - if( ldpz < nbf ) - GAUXC_GENERIC_EXCEPTION("Invalid LDPZ"); - if( ldvxcs < nbf ) - GAUXC_GENERIC_EXCEPTION("Invalid LDVXCSCALAR"); - if( ldvxcz < nbf ) - GAUXC_GENERIC_EXCEPTION("Invalid LDVXCZ"); - - // Get Tasks - this->load_balancer_->get_tasks(); - - // Temporary electron count to judge integrator accuracy - value_type N_EL; - - // Compute Local contributions to EXC / VXC - this->timer_.time_op("XCIntegrator.LocalWork", [&](){ - exc_vxc_local_work_( Ps, ldps, Pz, ldpz, nullptr, 0,nullptr, 0, - VXCs, ldvxcs, VXCz, ldvxcz, nullptr, 0, nullptr, 0, EXC, &N_EL, ks_settings ); - }); - - - // Reduce Results - this->timer_.time_op("XCIntegrator.Allreduce", [&](){ - - if( not this->reduction_driver_->takes_host_memory() ) - GAUXC_GENERIC_EXCEPTION("This Module Only Works With Host Reductions"); - - this->reduction_driver_->allreduce_inplace( VXCs, nbf*nbf, ReductionOp::Sum ); - this->reduction_driver_->allreduce_inplace( VXCz, nbf*nbf, ReductionOp::Sum ); - this->reduction_driver_->allreduce_inplace( EXC, 1 , ReductionOp::Sum ); - this->reduction_driver_->allreduce_inplace( &N_EL, 1 , ReductionOp::Sum ); - - }); - - -} -template -void ReferenceReplicatedXCHostIntegrator:: - eval_exc_vxc_( int64_t m, int64_t n, const value_type* Ps, - int64_t ldps, - const value_type* Pz, - int64_t ldpz, - const value_type* Py, - int64_t ldpy, - const value_type* Px, - int64_t ldpx, - value_type* VXCs, int64_t ldvxcs, - value_type* VXCz, 
int64_t ldvxcz, - value_type* VXCy, int64_t ldvxcy, - value_type* VXCx, int64_t ldvxcx, - value_type* EXC, const IntegratorSettingsXC& ks_settings ) { - - const auto& basis = this->load_balancer_->basis(); - - // Check that P / VXC are sane - const int64_t nbf = basis.nbf(); - if( m != n ) - GAUXC_GENERIC_EXCEPTION("P/VXC Must Be Square"); - if( m != nbf ) - GAUXC_GENERIC_EXCEPTION("P/VXC Must Have Same Dimension as Basis"); if( ldps < nbf ) - GAUXC_GENERIC_EXCEPTION("Invalid LDPSCALAR"); - if( ldpz < nbf ) + GAUXC_GENERIC_EXCEPTION("Invalid LDPS"); + if( ldpz and ldpz < nbf ) GAUXC_GENERIC_EXCEPTION("Invalid LDPZ"); - if( ldpy < nbf ) + if( ldpy and ldpy < nbf ) GAUXC_GENERIC_EXCEPTION("Invalid LDPX"); - if( ldpx < nbf ) + if( ldpx and ldpx < nbf ) GAUXC_GENERIC_EXCEPTION("Invalid LDPY"); + if( ldvxcs < nbf ) - GAUXC_GENERIC_EXCEPTION("Invalid LDVXCSCALAR"); - if( ldvxcz < nbf ) + GAUXC_GENERIC_EXCEPTION("Invalid LDVXCS"); + if( ldvxcz and ldvxcz < nbf ) GAUXC_GENERIC_EXCEPTION("Invalid LDVXCZ"); - if( ldvxcy < nbf ) + if( ldvxcy and ldvxcy < nbf ) GAUXC_GENERIC_EXCEPTION("Invalid LDVXCX"); - if( ldvxcx < nbf ) + if( ldvxcx and ldvxcx < nbf ) GAUXC_GENERIC_EXCEPTION("Invalid LDVXCY"); // Get Tasks - this->load_balancer_->get_tasks(); + auto& tasks = this->load_balancer_->get_tasks(); // Temporary electron count to judge integrator accuracy value_type N_EL; // Compute Local contributions to EXC / VXC this->timer_.time_op("XCIntegrator.LocalWork", [&](){ - exc_vxc_local_work_( Ps, ldps, Pz, ldpz, Py, ldpy, Px, ldpx, + exc_vxc_local_work_( basis, Ps, ldps, Pz, ldpz, Py, ldpy, Px, ldpx, VXCs, ldvxcs, VXCz, ldvxcz, - VXCy, ldvxcy, VXCx, ldvxcx, EXC, &N_EL, ks_settings ); + VXCy, ldvxcy, VXCx, ldvxcx, EXC, &N_EL, ks_settings, + tasks.begin(), tasks.end() ); }); @@ -182,9 +84,10 @@ void ReferenceReplicatedXCHostIntegrator:: GAUXC_GENERIC_EXCEPTION("This Module Only Works With Host Reductions"); this->reduction_driver_->allreduce_inplace( VXCs, nbf*nbf, ReductionOp::Sum 
); - this->reduction_driver_->allreduce_inplace( VXCz, nbf*nbf, ReductionOp::Sum ); - this->reduction_driver_->allreduce_inplace( VXCy, nbf*nbf, ReductionOp::Sum ); - this->reduction_driver_->allreduce_inplace( VXCx, nbf*nbf, ReductionOp::Sum ); + if(VXCz) this->reduction_driver_->allreduce_inplace( VXCz, nbf*nbf, ReductionOp::Sum ); + if(VXCy) this->reduction_driver_->allreduce_inplace( VXCy, nbf*nbf, ReductionOp::Sum ); + if(VXCx) this->reduction_driver_->allreduce_inplace( VXCx, nbf*nbf, ReductionOp::Sum ); + this->reduction_driver_->allreduce_inplace( EXC, 1 , ReductionOp::Sum ); this->reduction_driver_->allreduce_inplace( &N_EL, 1 , ReductionOp::Sum ); @@ -194,10 +97,11 @@ void ReferenceReplicatedXCHostIntegrator:: } - +/// Generic implementation details of EXC/VXC local work - deduces RKS/UKS/GKS +/// based on null-y / zero parameters template void ReferenceReplicatedXCHostIntegrator:: - exc_vxc_local_work_( const value_type* Ps, int64_t ldps, + exc_vxc_local_work_( const basis_type& basis, const value_type* Ps, int64_t ldps, const value_type* Pz, int64_t ldpz, const value_type* Py, int64_t ldpy, const value_type* Px, int64_t ldpx, @@ -205,7 +109,9 @@ void ReferenceReplicatedXCHostIntegrator:: value_type* VXCz, int64_t ldvxcz, value_type* VXCy, int64_t ldvxcy, value_type* VXCx, int64_t ldvxcx, - value_type* EXC, value_type *N_EL, const IntegratorSettingsXC& settings) { + value_type* EXC, value_type *N_EL, + const IntegratorSettingsXC& settings, + task_iterator task_begin, task_iterator task_end) { const bool is_gks = (Pz != nullptr) and (VXCz != nullptr) and (VXCy != nullptr) and (VXCx != nullptr); const bool is_uks = (Pz != nullptr) and (VXCz != nullptr) and (VXCy == nullptr) and (VXCx == nullptr); @@ -228,7 +134,6 @@ void ReferenceReplicatedXCHostIntegrator:: // Setup Aliases const auto& func = *this->func_; - const auto& basis = this->load_balancer_->basis(); const auto& mol = this->load_balancer_->molecule(); const bool needs_laplacian = 
func.needs_laplacian(); @@ -248,7 +153,7 @@ void ReferenceReplicatedXCHostIntegrator:: }; auto& tasks = this->load_balancer_->get_tasks(); - std::sort( tasks.begin(), tasks.end(), task_comparator ); + std::sort( task_begin, task_end, task_comparator ); // Check that Partition Weights have been calculated @@ -284,7 +189,7 @@ void ReferenceReplicatedXCHostIntegrator:: double NEL_WORK = 0.0; // Loop over tasks - const size_t ntasks = tasks.size(); + const size_t ntasks = std::distance(task_begin, task_end); #pragma omp parallel { @@ -297,7 +202,7 @@ void ReferenceReplicatedXCHostIntegrator:: //std::cout << iT << "/" << ntasks << std::endl; //printf("%lu / %lu\n", iT, ntasks); // Alias current task - const auto& task = tasks[iT]; + const auto& task = *(task_begin + iT); // Get tasks constants const int32_t npts = task.points.size(); @@ -683,5 +588,36 @@ void ReferenceReplicatedXCHostIntegrator:: } + + +/// RKS EXC/VXC driver - delegates to generic GKS impl +template +void ReferenceReplicatedXCHostIntegrator:: + eval_exc_vxc_( int64_t m, int64_t n, + const value_type* P, int64_t ldp, + value_type* VXC, int64_t ldvxc, + value_type* EXC, const IntegratorSettingsXC& ks_settings) { + + eval_exc_vxc_(m, n, P, ldp, nullptr, 0, nullptr, 0, nullptr, 0, + VXC, ldvxc, nullptr, 0, nullptr, 0, nullptr, 0, EXC, ks_settings); + } + + +/// UKS EXC/VXC driver - delegates to generic GKS impl +template +void ReferenceReplicatedXCHostIntegrator:: + eval_exc_vxc_( int64_t m, int64_t n, + const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, + value_type* VXCs, int64_t ldvxcs, + value_type* VXCz, int64_t ldvxcz, + value_type* EXC, const IntegratorSettingsXC& ks_settings) { + + eval_exc_vxc_(m, n, Ps, ldps, Pz, ldpz, nullptr, 0, nullptr, 0, + VXCs, ldvxcs, VXCz, ldvxcz, nullptr, 0, nullptr, 0, + EXC, ks_settings); + } + +} // namespace GauXC::detail diff --git a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exx.hpp 
b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exx.hpp index 7c827734..4898fd29 100644 --- a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exx.hpp +++ b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exx.hpp @@ -27,8 +27,7 @@ ostream& operator<<( ostream& out, const vector& v ) { } } -namespace GauXC { -namespace detail { +namespace GauXC::detail { template void ReferenceReplicatedXCHostIntegrator:: @@ -543,5 +542,4 @@ void ReferenceReplicatedXCHostIntegrator:: } -} -} +} // namespace GauXC::detail diff --git a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_integrate_den.hpp b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_integrate_den.hpp index 669178ca..425d5ac9 100644 --- a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_integrate_den.hpp +++ b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_integrate_den.hpp @@ -12,8 +12,7 @@ #include "host/local_host_work_driver.hpp" #include -namespace GauXC { -namespace detail { +namespace GauXC::detail { template void ReferenceReplicatedXCHostIntegrator:: @@ -25,9 +24,9 @@ void ReferenceReplicatedXCHostIntegrator:: // Check that P / VXC are sane const int64_t nbf = basis.nbf(); if( m != n ) - GAUXC_GENERIC_EXCEPTION("P/VXC Must Be Square"); + GAUXC_GENERIC_EXCEPTION("P Must Be Square"); if( m != nbf ) - GAUXC_GENERIC_EXCEPTION("P/VXC Must Have Same Dimension as Basis"); + GAUXC_GENERIC_EXCEPTION("P Must Have Same Dimension as Basis"); if( ldp < nbf ) GAUXC_GENERIC_EXCEPTION("Invalid LDP"); @@ -164,5 +163,4 @@ void ReferenceReplicatedXCHostIntegrator:: } -} -} +} // namespace GauXC::detail diff --git a/src/xc_integrator/replicated/host/replicated_xc_host_integrator.cxx b/src/xc_integrator/replicated/host/replicated_xc_host_integrator.cxx index b8a4e06f..aaa09ff6 100644 --- 
a/src/xc_integrator/replicated/host/replicated_xc_host_integrator.cxx +++ b/src/xc_integrator/replicated/host/replicated_xc_host_integrator.cxx @@ -7,10 +7,10 @@ */ #include #include "reference_replicated_xc_host_integrator.hpp" +#include "shell_batched_replicated_xc_host_integrator.hpp" #include "host/local_host_work_driver.hpp" -namespace GauXC { -namespace detail { +namespace GauXC::detail { template ReplicatedXCHostIntegrator::~ReplicatedXCHostIntegrator() noexcept = default; @@ -43,6 +43,11 @@ typename ReplicatedXCHostIntegratorFactory::ptr_return_t func, lb, std::move(lwd), rd ); + else if( integrator_kernel == "SHELLBATCHED" ) + return std::make_unique>( + func, lb, std::move(lwd), rd + ); + else GAUXC_GENERIC_EXCEPTION("INTEGRATOR KERNEL: " + integrator_kernel + " NOT RECOGNIZED"); @@ -54,6 +59,5 @@ typename ReplicatedXCHostIntegratorFactory::ptr_return_t template class ReplicatedXCHostIntegratorFactory; -} -} +} // namespace GauXC::detail diff --git a/src/xc_integrator/replicated/host/shell_batched_replicated_xc_host_integrator.cxx b/src/xc_integrator/replicated/host/shell_batched_replicated_xc_host_integrator.cxx new file mode 100644 index 00000000..e5b265c5 --- /dev/null +++ b/src/xc_integrator/replicated/host/shell_batched_replicated_xc_host_integrator.cxx @@ -0,0 +1,23 @@ +/** + * GauXC Copyright (c) 2020-2024, The Regents of the University of California, + * through Lawrence Berkeley National Laboratory (subject to receipt of + * any required approvals from the U.S. Dept. of Energy). All rights reserved. 
+ * + * See LICENSE.txt for details + */ +#include "shell_batched_replicated_xc_host_integrator.hpp" +#include "shell_batched_replicated_xc_integrator_integrate_den.hpp" +#include "shell_batched_replicated_xc_integrator_exc_vxc.hpp" +#include "shell_batched_replicated_xc_integrator_exc_grad.hpp" +#include "shell_batched_replicated_xc_integrator_exx.hpp" + +namespace GauXC { +namespace detail { + +template +ShellBatchedReplicatedXCHostIntegrator::~ShellBatchedReplicatedXCHostIntegrator() noexcept = default; + +template class ShellBatchedReplicatedXCHostIntegrator; + +} +} diff --git a/src/xc_integrator/replicated/host/shell_batched_replicated_xc_host_integrator.hpp b/src/xc_integrator/replicated/host/shell_batched_replicated_xc_host_integrator.hpp new file mode 100644 index 00000000..3c3db085 --- /dev/null +++ b/src/xc_integrator/replicated/host/shell_batched_replicated_xc_host_integrator.hpp @@ -0,0 +1,42 @@ +/** + * GauXC Copyright (c) 2020-2024, The Regents of the University of California, + * through Lawrence Berkeley National Laboratory (subject to receipt of + * any required approvals from the U.S. Dept. of Energy). All rights reserved. + * + * See LICENSE.txt for details + */ +#pragma once +#include +#include "reference_replicated_xc_host_integrator.hpp" +#include "shell_batched_replicated_xc_integrator.hpp" + +namespace GauXC { +namespace detail { + +template +class ShellBatchedReplicatedXCHostIntegrator : + public ShellBatchedReplicatedXCIntegrator< + ReplicatedXCHostIntegrator, + ReferenceReplicatedXCHostIntegrator + > { + + using base_type = ShellBatchedReplicatedXCIntegrator< + ReplicatedXCHostIntegrator, + ReferenceReplicatedXCHostIntegrator + >; + +public: + + template + ShellBatchedReplicatedXCHostIntegrator( Args&&... args ) : + base_type( std::forward(args)... 
) { } + + virtual ~ShellBatchedReplicatedXCHostIntegrator() noexcept; + +}; + +extern template class ShellBatchedReplicatedXCHostIntegrator; + +} +} + diff --git a/src/xc_integrator/shell_batched/CMakeLists.txt b/src/xc_integrator/shell_batched/CMakeLists.txt new file mode 100644 index 00000000..636666c4 --- /dev/null +++ b/src/xc_integrator/shell_batched/CMakeLists.txt @@ -0,0 +1,12 @@ +# +# GauXC Copyright (c) 2020-2024, The Regents of the University of California, +# through Lawrence Berkeley National Laboratory (subject to receipt of +# any required approvals from the U.S. Dept. of Energy). All rights reserved. +# +# See LICENSE.txt for details +# +target_sources( gauxc PRIVATE shell_batched_xc_integrator.cxx ) +target_include_directories( gauxc + PUBLIC + $ +) diff --git a/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator.hpp b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator.hpp new file mode 100644 index 00000000..969257df --- /dev/null +++ b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator.hpp @@ -0,0 +1,116 @@ +/** + * GauXC Copyright (c) 2020-2024, The Regents of the University of California, + * through Lawrence Berkeley National Laboratory (subject to receipt of + * any required approvals from the U.S. Dept. of Energy). All rights reserved. 
+ * + * See LICENSE.txt for details + */ +#pragma once +#include +#include "shell_batched_xc_integrator.hpp" +#ifdef GAUXC_HAS_DEVICE +#include "device/xc_device_data.hpp" +#endif + +namespace GauXC { +namespace detail { + +template +class ShellBatchedReplicatedXCIntegrator : + public BaseIntegratorType, + public ShellBatchedXCIntegratorBase { + + using base_type = BaseIntegratorType; + +public: + + using value_type = typename base_type::value_type; + using basis_type = typename base_type::basis_type; + + using host_task_container = std::vector; + using host_task_iterator = typename host_task_container::iterator; + +protected: + +#ifdef GAUXC_HAS_DEVICE + std::unique_ptr device_data_ptr_; +#endif + + using incore_integrator_type = IncoreIntegratorType; + using incore_task_data = ShellBatchedXCIntegratorBase::incore_task_data; + + // Density Integration + void integrate_den_( int64_t m, int64_t n, const value_type* P, int64_t ldp, value_type* N_EL ) override; + + /// RKS EXC/VXC + void eval_exc_vxc_( int64_t m, int64_t n, const value_type* P, int64_t ldp, + value_type* VXC, int64_t ldvxc, value_type* EXC, + const IntegratorSettingsXC& ks_settings ) override; + + /// UKS EXC/VXC + void eval_exc_vxc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, + value_type* VXCs, int64_t ldvxcs, + value_type* VXCz, int64_t ldvxcz, + value_type* EXC, const IntegratorSettingsXC& ks_settings ) override; + + /// GKS EXC/VXC - also serves as the generic implementation + void eval_exc_vxc_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, + const value_type* Py, int64_t ldpy, + const value_type* Px, int64_t ldpx, + value_type* VXCs, int64_t ldvxcs, + value_type* VXCz, int64_t ldvxcz, + value_type* VXCy, int64_t ldvxcy, + value_type* VXCx, int64_t ldvxcx, + value_type* EXC, const IntegratorSettingsXC& ks_settings ) override; + + + /// RKS EXC Gradient + void eval_exc_grad_( int64_t m, int64_t 
n, const value_type* P, + int64_t ldp, value_type* EXC_GRAD ) override; + + /// sn-LinK + void eval_exx_( int64_t m, int64_t n, const value_type* P, + int64_t ldp, value_type* K, int64_t ldk, + const IntegratorSettingsEXX& settings ) override; + + + + + // Implementation details of exc_vxc (for RKS/UKS/GKS deduced from input character) + void exc_vxc_local_work_( const basis_type& basis, const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, + const value_type* Py, int64_t ldpy, + const value_type* Px, int64_t ldpx, + value_type* VXCs, int64_t ldvxcs, + value_type* VXCz, int64_t ldvxcz, + value_type* VXCy, int64_t ldvxcy, + value_type* VXCx, int64_t ldvxcx, + value_type* EXC, value_type *N_EL, + host_task_iterator task_begin, host_task_iterator task_end, incore_integrator_type& incore_integrator + ); + + + void execute_task_batch( incore_task_data& task, const basis_type& basis, const Molecule& mol, + const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, + const value_type* Py, int64_t ldpy, + const value_type* Px, int64_t ldpx, + value_type* VXCs, int64_t ldvxcs, + value_type* VXCz, int64_t ldvxcz, + value_type* VXCy, int64_t ldvxcy, + value_type* VXCx, int64_t ldvxcx, + value_type* EXC, value_type* N_EL, incore_integrator_type& incore_integrator); +public: + + template + ShellBatchedReplicatedXCIntegrator( Args&&... args ) : + base_type( std::forward(args)... 
) { } + + virtual ~ShellBatchedReplicatedXCIntegrator() noexcept = default; + +}; + +} +} diff --git a/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator_exc_grad.hpp b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exc_grad.hpp similarity index 59% rename from src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator_exc_grad.hpp rename to src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exc_grad.hpp index 50e9b5e2..39901b7d 100644 --- a/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator_exc_grad.hpp +++ b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exc_grad.hpp @@ -5,26 +5,16 @@ * * See LICENSE.txt for details */ -#include "shellbatched_replicated_xc_device_integrator.hpp" -#include "device/local_device_work_driver.hpp" -#include "device/xc_device_aos_data.hpp" -#include "integrator_util/integrator_common.hpp" -#include "host/util.hpp" +#pragma once +#include "shell_batched_replicated_xc_integrator.hpp" #include #include -#include -#include -#include -#include -#include -#include - namespace GauXC { namespace detail { -template -void ShellBatchedReplicatedXCDeviceIntegrator:: +template +void ShellBatchedReplicatedXCIntegrator:: eval_exc_grad_( int64_t m, int64_t n, const value_type* P, int64_t ldp, value_type* EXC_GRAD ) { diff --git a/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exc_vxc.hpp b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exc_vxc.hpp new file mode 100644 index 00000000..7378d32a --- /dev/null +++ b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exc_vxc.hpp @@ -0,0 +1,440 @@ +/** + * GauXC Copyright (c) 2020-2024, The Regents of the University of California, + * through Lawrence Berkeley National Laboratory (subject to receipt of + * any required approvals from the U.S. Dept. of Energy). All rights reserved. 
+ * + * See LICENSE.txt for details + */ +#pragma once +#include "shell_batched_replicated_xc_integrator.hpp" +#ifdef GAUXC_HAS_DEVICE +#include "device/local_device_work_driver.hpp" +#include "device/xc_device_aos_data.hpp" +#endif +#include "integrator_util/integrator_common.hpp" +#include "host/util.hpp" +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace GauXC { +namespace detail { + + +template +void ShellBatchedReplicatedXCIntegrator:: + eval_exc_vxc_( int64_t m, int64_t n, + const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, + const value_type* Py, int64_t ldpy, + const value_type* Px, int64_t ldpx, + value_type* VXCs, int64_t ldvxcs, + value_type* VXCz, int64_t ldvxcz, + value_type* VXCy, int64_t ldvxcy, + value_type* VXCx, int64_t ldvxcx, + value_type* EXC, const IntegratorSettingsXC& ks_settings) { + + + const auto& basis = this->load_balancer_->basis(); + + // Check that P / VXC are sane + const int64_t nbf = basis.nbf(); + if( m != n ) + GAUXC_GENERIC_EXCEPTION("P/VXC Must Be Square"); + if( m != nbf ) + GAUXC_GENERIC_EXCEPTION("P/VXC Must Have Same Dimension as Basis"); + + if( ldps < nbf ) + GAUXC_GENERIC_EXCEPTION("Invalid LDPS"); + if( ldpz and ldpz < nbf ) + GAUXC_GENERIC_EXCEPTION("Invalid LDPZ"); + if( ldpy and ldpy < nbf ) + GAUXC_GENERIC_EXCEPTION("Invalid LDPY"); + if( ldpx and ldpx < nbf ) + GAUXC_GENERIC_EXCEPTION("Invalid LDPX"); + + if( ldvxcs < nbf ) + GAUXC_GENERIC_EXCEPTION("Invalid LDVXCS"); + if( ldvxcz and ldvxcz < nbf ) + GAUXC_GENERIC_EXCEPTION("Invalid LDVXCZ"); + if( ldvxcy and ldvxcy < nbf ) + GAUXC_GENERIC_EXCEPTION("Invalid LDVXCY"); + if( ldvxcx and ldvxcx < nbf ) + GAUXC_GENERIC_EXCEPTION("Invalid LDVXCX"); + + + // Get Tasks + auto& tasks = this->load_balancer_->get_tasks(); + + #ifdef GAUXC_HAS_DEVICE + // Allocate Device memory + auto* lwd = dynamic_cast(this->local_work_driver_.get() ); + auto rt =
detail::as_device_runtime(this->load_balancer_->runtime()); + if constexpr (IncoreIntegratorType::is_device) { + device_data_ptr_ = + this->timer_.time_op("XCIntegrator.DeviceAlloc", + [&](){ return lwd->create_device_data(rt); }); + } + #endif + + // Generate incore integrator instance, transfer ownership of LWD + incore_integrator_type incore_integrator( this->func_, this->load_balancer_, + this->release_local_work_driver(), this->reduction_driver_ ); + + // Temporary electron count to judge integrator accuracy + value_type N_EL; + + // Compute local contributions to EXC/VXC + this->timer_.time_op("XCIntegrator.LocalWork", [&](){ + exc_vxc_local_work_( basis, Ps, ldps, Pz, ldpz, Py, ldpy, Px, ldpx, + VXCs, ldvxcs, VXCz, ldvxcz, VXCy, ldvxcy, VXCx, ldvxcx, EXC, + &N_EL, tasks.begin(), tasks.end(), incore_integrator ); + }); + + // Release ownership of LWD back to this integrator instance + this->local_work_driver_ = std::move( incore_integrator.release_local_work_driver() ); + + + // Reduce Results + this->timer_.time_op("XCIntegrator.Allreduce", [&](){ + if( not this->reduction_driver_->takes_host_memory() ) + GAUXC_GENERIC_EXCEPTION("This Module Only Works With Host Reductions"); + + this->reduction_driver_->allreduce_inplace( VXCs, nbf*nbf, ReductionOp::Sum ); + if(VXCz) this->reduction_driver_->allreduce_inplace( VXCz, nbf*nbf, ReductionOp::Sum ); + if(VXCy) this->reduction_driver_->allreduce_inplace( VXCy, nbf*nbf, ReductionOp::Sum ); + if(VXCx) this->reduction_driver_->allreduce_inplace( VXCx, nbf*nbf, ReductionOp::Sum ); + this->reduction_driver_->allreduce_inplace( EXC, 1 , ReductionOp::Sum ); + this->reduction_driver_->allreduce_inplace( &N_EL, 1 , ReductionOp::Sum ); + }); + + #ifdef GAUXC_HAS_DEVICE + device_data_ptr_.reset(); + #endif + +} + + + +template +void ShellBatchedReplicatedXCIntegrator:: + eval_exc_vxc_( int64_t m, int64_t n, + const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, + value_type* VXCs, int64_t ldvxcs, + 
value_type* VXCz, int64_t ldvxcz, + value_type* EXC, const IntegratorSettingsXC& ks_settings) { + + eval_exc_vxc_(m, n, Ps, ldps, Pz, ldpz, nullptr, 0, nullptr, 0, + VXCs, ldvxcs, VXCz, ldvxcz, nullptr, 0, nullptr, 0, + EXC, ks_settings); + +} + +template +void ShellBatchedReplicatedXCIntegrator:: + eval_exc_vxc_( int64_t m, int64_t n, + const value_type* P, int64_t ldp, + value_type* VXC, int64_t ldvxc, + value_type* EXC, const IntegratorSettingsXC& ks_settings) { + + eval_exc_vxc_(m, n, P, ldp, nullptr, 0, nullptr, 0, nullptr, 0, + VXC, ldvxc, nullptr, 0, nullptr, 0, nullptr, 0, EXC, ks_settings); + +} + +template +void ShellBatchedReplicatedXCIntegrator:: + exc_vxc_local_work_( const basis_type& basis, + const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, + const value_type* Py, int64_t ldpy, + const value_type* Px, int64_t ldpx, + value_type* VXCs, int64_t ldvxcs, + value_type* VXCz, int64_t ldvxcz, + value_type* VXCy, int64_t ldvxcy, + value_type* VXCx, int64_t ldvxcx, + value_type* EXC, value_type *N_EL, + host_task_iterator task_begin, host_task_iterator task_end, + incore_integrator_type& incore_integrator ) { + + //incore_integrator.exc_vxc_local_work( basis, P, ldp, VXC, ldvxc, EXC, N_EL, task_begin, task_end, device_data ); + //return; + + + const auto nbf = basis.nbf(); + const uint32_t nbf_threshold = 8000; + const auto& mol = this->load_balancer_->molecule(); + // Zero out integrands on host + this->timer_.time_op("XCIntegrator.ZeroHost", [&](){ + *EXC = 0.; + *N_EL = 0.; + for( auto j = 0; j < nbf; ++j ) + for( auto i = 0; i < nbf; ++i ) { + VXCs[i + j*ldvxcs] = 0.; + } + if(VXCz) + for( auto j = 0; j < nbf; ++j ) + for( auto i = 0; i < nbf; ++i ) { + VXCz[i + j*ldvxcz] = 0.; + } + if(VXCy) + for( auto j = 0; j < nbf; ++j ) + for( auto i = 0; i < nbf; ++i ) { + VXCy[i + j*ldvxcy] = 0.; + } + if(VXCx) + for( auto j = 0; j < nbf; ++j ) + for( auto i = 0; i < nbf; ++i ) { + VXCx[i + j*ldvxcx] = 0.; + } + }); + + + // Task queue + 
std::queue< incore_task_data > incore_task_data_queue; + + // Task queue modification mutex + std::mutex queue_mod_ex; + + // Lambda for the execution of incore tasks on the device + auto execute_incore_task = [&]() { + + // Early return if there is no task to execute + if( incore_task_data_queue.empty() ) return; + + incore_task_data next_task; + { + std::lock_guard lock(queue_mod_ex); + + // Move the next task into local scope and remove + // from queue + next_task = std::move( incore_task_data_queue.front() ); + incore_task_data_queue.pop(); + } + + // Execute task + execute_task_batch( next_task, basis, mol, Ps, ldps, Pz, ldpz, + Py, ldpy, Px, ldpx, VXCs, ldvxcs, VXCz, ldvxcz, VXCy, ldvxcy, + VXCx, ldvxcx, EXC, N_EL, incore_integrator ); + }; + + + // Setup future to track execution of currently running + // device task + std::future task_future; + + auto task_it = task_begin; + while( task_it != task_end ) { + + // Generate and enqueue task + incore_task_data_queue.emplace( + generate_incore_task( nbf_threshold, basis, task_it, task_end ) + ); + + // Update iterator for next task generation + task_it = incore_task_data_queue.back().task_end; + + if( not task_future.valid() ) { + // No device task to wait on + task_future = std::async( std::launch::async, execute_incore_task ); + } else { + // Check the status of current device task + auto status = task_future.wait_for( std::chrono::milliseconds(5) ); + if( status == std::future_status::ready ) { + // If the status is ready - execute the next task in queue + task_future.get(); + task_future = std::async( std::launch::async, execute_incore_task ); + } + } + + } // Loop until all tasks have been enqueued + + // TODO: Try to merge remaining tasks appropriately + + // Execute remaining tasks sequentially + if( task_future.valid() ) { + task_future.wait(); + task_future.get(); // Propagate trailing exceptions if present + } + while( not incore_task_data_queue.empty() ) { + execute_incore_task(); + } +} + + + +template
+void ShellBatchedReplicatedXCIntegrator:: + execute_task_batch( incore_task_data& task, const basis_type& basis, const Molecule& mol, + const value_type* Ps, int64_t ldps, + const value_type* Pz, int64_t ldpz, + const value_type* Py, int64_t ldpy, + const value_type* Px, int64_t ldpx, + value_type* VXCs, int64_t ldvxcs, + value_type* VXCz, int64_t ldvxcz, + value_type* VXCy, int64_t ldvxcy, + value_type* VXCx, int64_t ldvxcx, + value_type* EXC, value_type *N_EL, + incore_integrator_type& incore_integrator ) { + + + // Alias information + auto task_begin = task.task_begin; + auto task_end = task.task_end; + auto& union_shell_list = task.shell_list; + + + // Extract subbasis + BasisSet basis_subset; basis_subset.reserve(union_shell_list.size()); + this->timer_.time_op_accumulate("XCIntegrator.CopySubBasis",[&]() { + for( auto i : union_shell_list ) { + basis_subset.emplace_back( basis.at(i) ); + } + }); + + // Setup basis maps + BasisSetMap basis_map( basis, mol ); + + //const size_t nshells = basis_subset.nshells(); + const size_t nbe = basis_subset.nbf(); + //std::cout << "TASK_UNION HAS:" << std::endl + // << " NSHELLS = " << nshells << std::endl + // << " NBE = " << nbe << std::endl; + + // Recalculate shell_list based on subbasis + this->timer_.time_op_accumulate("XCIntegrator.RecalcShellList",[&]() { + for( auto _it = task_begin; _it != task_end; ++_it ) { + auto union_list_idx = 0; + auto& cur_shell_list = _it->bfn_screening.shell_list; + for( auto j = 0ul; j < cur_shell_list.size(); ++j ) { + while( union_shell_list[union_list_idx] != cur_shell_list[j] ) + union_list_idx++; + cur_shell_list[j] = union_list_idx; + } + } + } ); + + + // Allocate host temporaries + std::vector Ps_submat_host(nbe*nbe), VXCs_submat_host(nbe*nbe,0.); + double EXC_tmp, NEL_tmp; + double* Ps_submat = Ps_submat_host.data(); + double* VXCs_submat = VXCs_submat_host.data(); + + std::vector Pz_submat_host, Py_submat_host, Px_submat_host; + std::vector VXCz_submat_host, VXCy_submat_host, 
VXCx_submat_host; + double *Pz_submat = nullptr, *Py_submat = nullptr, *Px_submat = nullptr; + double *VXCz_submat = nullptr, *VXCy_submat = nullptr , *VXCx_submat = nullptr; + + if(Pz) { + Pz_submat_host.resize(nbe*nbe); + VXCz_submat_host.resize(nbe*nbe, 0.0); + Pz_submat = Pz_submat_host.data(); + VXCz_submat = VXCz_submat_host.data(); + } + + if(Py) { + Py_submat_host.resize(nbe*nbe); + VXCy_submat_host.resize(nbe*nbe, 0.0); + Py_submat = Py_submat_host.data(); + VXCy_submat = VXCy_submat_host.data(); + } + + if(Px) { + Px_submat_host.resize(nbe*nbe); + VXCx_submat_host.resize(nbe*nbe, 0.0); + Px_submat = Px_submat_host.data(); + VXCx_submat = VXCx_submat_host.data(); + } + + + // Extract subdensity + std::vector> union_submat_cut; + std::vector foo; + std::tie(union_submat_cut,foo) = + gen_compressed_submat_map( basis_map, union_shell_list, + basis.nbf(), basis.nbf() ); + + this->timer_.time_op_accumulate("XCIntegrator.ExtractSubDensity",[&]() { + detail::submat_set( basis.nbf(), basis.nbf(), nbe, nbe, Ps, ldps, + Ps_submat, nbe, union_submat_cut ); + if(Pz) + detail::submat_set( basis.nbf(), basis.nbf(), nbe, nbe, Pz, ldpz, + Pz_submat, nbe, union_submat_cut ); + + if(Py) + detail::submat_set( basis.nbf(), basis.nbf(), nbe, nbe, Py, ldpy, + Py_submat, nbe, union_submat_cut ); + + if(Px) + detail::submat_set( basis.nbf(), basis.nbf(), nbe, nbe, Px, ldpx, + Px_submat, nbe, union_submat_cut ); + } ); + + + // Process selected task batch +#ifdef GAUXC_HAS_DEVICE + if constexpr (IncoreIntegratorType::is_device) { + if(Pz or Py or Px) // UKS/GKS inputs (any of Pz/Py/Px non-null) are not supported here + GAUXC_GENERIC_EXCEPTION("Device UKS/GKS + ShellBatched NYI"); + + incore_integrator.exc_vxc_local_work( basis_subset, Ps_submat, nbe, + VXCs_submat, nbe, &EXC_tmp, &NEL_tmp, task_begin, task_end, + *device_data_ptr_ ); + // TODO: Make this work for UKS/GKS after + // https://github.com/wavefunction91/GauXC/pull/91 + //incore_integrator.exc_vxc_local_work( basis_subset, Ps_submat, nbe, + // Pz_submat, nbe, Py_submat, nbe, Px_submat,
nbe, VXCs_submat, nbe, + // VXCz_submat, nbe, VXCy_submat, nbe, VXCx_submat, nbe, + // &EXC_tmp, &NEL_tmp, task_begin, task_end, *device_data_ptr_ ); + } else if constexpr (not IncoreIntegratorType::is_device) { +#endif + incore_integrator.exc_vxc_local_work( basis_subset, Ps_submat, nbe, + Pz_submat, nbe, Py_submat, nbe, Px_submat, nbe, VXCs_submat, nbe, + VXCz_submat, nbe, VXCy_submat, nbe, VXCx_submat, nbe, + &EXC_tmp, &NEL_tmp, IntegratorSettingsKS{}, task_begin, task_end ); +#ifdef GAUXC_HAS_DEVICE + } +#endif + + + // Update full quantities + *EXC += EXC_tmp; + *N_EL += NEL_tmp; + this->timer_.time_op_accumulate("XCIntegrator.IncrementSubPotential",[&]() { + detail::inc_by_submat( basis.nbf(), basis.nbf(), nbe, nbe, VXCs, ldvxcs, + VXCs_submat, nbe, union_submat_cut ); + if(VXCz) + detail::inc_by_submat( basis.nbf(), basis.nbf(), nbe, nbe, VXCz, ldvxcz, + VXCz_submat, nbe, union_submat_cut ); + + if(VXCy) + detail::inc_by_submat( basis.nbf(), basis.nbf(), nbe, nbe, VXCy, ldvxcy, + VXCy_submat, nbe, union_submat_cut ); + + if(VXCx) + detail::inc_by_submat( basis.nbf(), basis.nbf(), nbe, nbe, VXCx, ldvxcx, + VXCx_submat, nbe, union_submat_cut ); + }); + + + // Reset shell_list to be wrt full basis + this->timer_.time_op_accumulate("XCIntegrator.ResetShellList",[&]() { + for( auto _it = task_begin; _it != task_end; ++_it ) + for( auto j = 0ul; j < _it->bfn_screening.shell_list.size(); ++j ) { + _it->bfn_screening.shell_list[j] = union_shell_list[_it->bfn_screening.shell_list[j]]; + } + }); + +} + +} +} + diff --git a/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator_exx.hpp b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exx.hpp similarity index 60% rename from src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator_exx.hpp rename to src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exx.hpp index 81105897..1f788549 100644 --- 
a/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator_exx.hpp +++ b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exx.hpp @@ -5,26 +5,16 @@ * * See LICENSE.txt for details */ -#include "shellbatched_replicated_xc_device_integrator.hpp" -#include "device/local_device_work_driver.hpp" -#include "device/xc_device_aos_data.hpp" -#include "integrator_util/integrator_common.hpp" -#include "host/util.hpp" +#pragma once +#include "shell_batched_replicated_xc_integrator.hpp" #include #include -#include -#include -#include -#include -#include -#include - namespace GauXC { namespace detail { -template -void ShellBatchedReplicatedXCDeviceIntegrator:: +template +void ShellBatchedReplicatedXCIntegrator:: eval_exx_( int64_t m, int64_t n, const value_type* P, int64_t ldp, value_type* K, int64_t ldk, const IntegratorSettingsEXX& settings ) { diff --git a/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator_integrate_den.hpp b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_integrate_den.hpp similarity index 58% rename from src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator_integrate_den.hpp rename to src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_integrate_den.hpp index f07e64ac..3045ccc4 100644 --- a/src/xc_integrator/replicated/device/shellbatched_replicated_xc_device_integrator_integrate_den.hpp +++ b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_integrate_den.hpp @@ -5,26 +5,16 @@ * * See LICENSE.txt for details */ -#include "shellbatched_replicated_xc_device_integrator.hpp" -#include "device/local_device_work_driver.hpp" -#include "device/xc_device_aos_data.hpp" -#include "integrator_util/integrator_common.hpp" -#include "host/util.hpp" +#pragma once +#include "shell_batched_replicated_xc_integrator.hpp" #include #include -#include -#include -#include -#include -#include -#include - namespace 
GauXC { namespace detail { -template -void ShellBatchedReplicatedXCDeviceIntegrator:: +template +void ShellBatchedReplicatedXCIntegrator:: integrate_den_( int64_t m, int64_t n, const value_type* P, int64_t ldp, value_type* N_EL ) { diff --git a/src/xc_integrator/shell_batched/shell_batched_xc_integrator.cxx b/src/xc_integrator/shell_batched/shell_batched_xc_integrator.cxx new file mode 100644 index 00000000..314f0027 --- /dev/null +++ b/src/xc_integrator/shell_batched/shell_batched_xc_integrator.cxx @@ -0,0 +1,152 @@ +/** + * GauXC Copyright (c) 2020-2024, The Regents of the University of California, + * through Lawrence Berkeley National Laboratory (subject to receipt of + * any required approvals from the U.S. Dept. of Energy). All rights reserved. + * + * See LICENSE.txt for details + */ + +#include "shell_batched_xc_integrator.hpp" +#include +#include +#include +#include +#include + +namespace GauXC::detail { + +ShellBatchedXCIntegratorBase::incore_task_data + ShellBatchedXCIntegratorBase::generate_incore_task( uint32_t nbf_threshold, + const basis_type& basis, host_task_iterator task_begin, + host_task_iterator task_end ) { + + // Find task with largest NBE + auto nbe_comparator = []( const auto& task_a, const auto& task_b ) { + return task_a.bfn_screening.nbe < task_b.bfn_screening.nbe; + }; + auto max_task = std::max_element( task_begin, task_end, nbe_comparator ); + + const auto max_shell_list = max_task->bfn_screening.shell_list; // copy for reset + + // Init union shell list to max shell list outside of loop + std::set union_shell_set(max_shell_list.begin(), + max_shell_list.end()); + + + // Voodoo: once only Manwe and I knew what was happening here, now + // only Manwe knows + int n_overlap_pthresh = 20; + double overlap_pthresh_delta = 1. 
/ n_overlap_pthresh; + std::vector overlap_pthresh; + for( int i = 1; i < n_overlap_pthresh; ++i ) + overlap_pthresh.emplace_back( i*overlap_pthresh_delta ); + + std::vector overlap_pthresh_idx( overlap_pthresh.size() ); + std::iota( overlap_pthresh_idx.begin(), overlap_pthresh_idx.end(), 0 ); + + std::map> + cached_task_ends; + + int cur_partition_pthresh_idx = -1; + + auto _it = std::partition_point( overlap_pthresh_idx.rbegin(), + overlap_pthresh_idx.rend(), + [&](int idx) { + + uint32_t overlap_threshold = + std::max(1., max_shell_list.size() * overlap_pthresh[idx] ); + + + host_task_iterator search_st = task_begin; + host_task_iterator search_en = task_end; + + // Make a local copy of union list + std::set local_union_shell_set; + + // Attempt to limit task search based on current partition + if( cur_partition_pthresh_idx >= 0 ) { + + const auto& last_pthresh = + cached_task_ends.at(cur_partition_pthresh_idx); + + if( cur_partition_pthresh_idx > idx ) { + search_st = last_pthresh.first; + local_union_shell_set = last_pthresh.second; + } else { + search_en = last_pthresh.first; + local_union_shell_set = union_shell_set; + } + + } else { + local_union_shell_set = union_shell_set; + } + + + // Partition tasks into those which overlap max_task up to + // specified threshold + auto local_task_end = std::partition( search_st, search_en, + [&](const auto& t) { + return util::integral_list_intersect( max_shell_list, + t.bfn_screening.shell_list, overlap_threshold ); + } ); + + + + // Take union of shell list for all overlapping tasks + for( auto task_it = search_st; task_it != local_task_end; ++task_it ) { + local_union_shell_set.insert( task_it->bfn_screening.shell_list.begin(), + task_it->bfn_screening.shell_list.end() ); + } + + auto cur_nbe = basis.nbf_subset( local_union_shell_set.begin(), + local_union_shell_set.end() ); + + //std::cout << " Threshold % = " << std::setw(5) << overlap_pthresh[idx] << ", "; + //std::cout << " Overlap Threshold = " << std::setw(8) 
<< overlap_threshold << ", "; + //std::cout << " Current NBE = " << std::setw(8) << cur_nbe << std::endl; + + // Cache the data + cached_task_ends[idx] = std::make_pair( local_task_end, local_union_shell_set ); + + // Update partitioned threshold + cur_partition_pthresh_idx = idx; + + return (uint32_t)cur_nbe < nbf_threshold; + + } ); + + host_task_iterator local_task_end; + auto _idx_partition = (_it == overlap_pthresh_idx.rend()) ? 0 : *_it; + std::tie( local_task_end, union_shell_set ) = + cached_task_ends.at(_idx_partition); + + + + + + //std::cout << "FOUND " << std::distance( task_begin, local_task_end ) + // << " OVERLAPPING TASKS" << std::endl; + + + std::vector union_shell_list( union_shell_set.begin(), + union_shell_set.end() ); + + // Try to add additional tasks given current union list + local_task_end = std::partition( local_task_end, task_end, + [&]( const auto& t ) { + return util::list_subset( union_shell_list, t.bfn_screening.shell_list ); + } ); + + //std::cout << "FOUND " << std::distance( task_begin, local_task_end ) + // << " SUBTASKS" << std::endl; + + + incore_task_data ex_task; + ex_task.task_begin = task_begin; + ex_task.task_end = local_task_end; + ex_task.shell_list = std::move( union_shell_list ); + + return ex_task; +} + +} diff --git a/src/xc_integrator/shell_batched/shell_batched_xc_integrator.hpp b/src/xc_integrator/shell_batched/shell_batched_xc_integrator.hpp new file mode 100644 index 00000000..1d04169d --- /dev/null +++ b/src/xc_integrator/shell_batched/shell_batched_xc_integrator.hpp @@ -0,0 +1,38 @@ +/** + * GauXC Copyright (c) 2020-2024, The Regents of the University of California, + * through Lawrence Berkeley National Laboratory (subject to receipt of + * any required approvals from the U.S. Dept. of Energy). All rights reserved. 
+ * + * See LICENSE.txt for details + */ +#pragma once +#include +#include + +namespace GauXC { +namespace detail { + +struct ShellBatchedXCIntegratorBase { + + using basis_type = BasisSet; + + using host_task_container = std::vector; + using host_task_iterator = typename host_task_container::iterator; + + // Struct to manage data associated with task subset to execute in batch + struct incore_task_data { + host_task_iterator task_begin; + host_task_iterator task_end; + std::vector shell_list; + }; + + incore_task_data generate_incore_task( + uint32_t nbf_threshold, const basis_type& basis, + host_task_iterator task_begin, host_task_iterator task_end ); + + virtual ~ShellBatchedXCIntegratorBase() noexcept = default; + +}; + +} +} diff --git a/tests/xc_integrator.cxx b/tests/xc_integrator.cxx index c70b9e0a..a4a89766 100644 --- a/tests/xc_integrator.cxx +++ b/tests/xc_integrator.cxx @@ -264,8 +264,14 @@ void test_integrator(std::string reference_file, functional_type& func, PruningS #ifdef GAUXC_HAS_HOST SECTION( "Host" ) { - test_xc_integrator( ExecutionSpace::Host, rt, reference_file, func, - pruning_scheme, true, true, true ); + SECTION("Reference") { + test_xc_integrator( ExecutionSpace::Host, rt, reference_file, func, + pruning_scheme, true, true, true ); + } + SECTION("ShellBatched") { + test_xc_integrator( ExecutionSpace::Host, rt, reference_file, func, + pruning_scheme, false, false, false, "ShellBatched" ); + } } #endif