Skip to content

Commit

Permalink
Merge branch 'master' into issue_75/enable_c_on_discovery
Browse files Browse the repository at this point in the history
  • Loading branch information
wavefunction91 committed Oct 30, 2023
2 parents 987c467 + abb2974 commit 3141809
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 49 deletions.
19 changes: 10 additions & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@ endif()


# GauXC Options
option( GAUXC_ENABLE_HOST "Enable Host Integrator" ON )
option( GAUXC_ENABLE_CUDA "Enable CUDA Bindings" OFF )
option( GAUXC_ENABLE_HIP "Enable HIP Bindings" OFF )
option( GAUXC_ENABLE_MPI "Enable MPI Bindings" ON )
option( GAUXC_ENABLE_OPENMP "Enable OpenMP Compilation" ON )
option( GAUXC_ENABLE_TESTS "Enable Unit Tests" ON )
option( GAUXC_ENABLE_GAU2GRID "Enable Gau2Grid Collocation" ON )
option( GAUXC_ENABLE_HDF5 "Enable HDF5 Bindings" ON )
option( GAUXC_ENABLE_FAST_RSQRT "Enable Fast RSQRT" OFF )
option( GAUXC_ENABLE_HOST "Enable Host Integrator" ON )
option( GAUXC_ENABLE_CUDA "Enable CUDA Bindings" OFF )
option( GAUXC_ENABLE_HIP "Enable HIP Bindings" OFF )
option( GAUXC_ENABLE_MPI "Enable MPI Bindings" ON )
option( GAUXC_ENABLE_OPENMP "Enable OpenMP Compilation" ON )
option( GAUXC_ENABLE_TESTS "Enable Unit Tests" ON )
option( GAUXC_ENABLE_GAU2GRID "Enable Gau2Grid Collocation" ON )
option( GAUXC_ENABLE_HDF5 "Enable HDF5 Bindings" ON )
option( GAUXC_ENABLE_FAST_RSQRT "Enable Fast RSQRT" OFF )
option( GAUXC_BLAS_PREFER_ILP64 "Enable ILP64 for host BLAS" OFF )

include(CMakeDependentOption)
cmake_dependent_option( GAUXC_ENABLE_MAGMA
Expand Down
9 changes: 8 additions & 1 deletion cmake/gauxc-config.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,21 @@ set( GAUXC_ENABLE_MPI @GAUXC_ENABLE_MPI@ )
set( GAUXC_ENABLE_OPENMP @GAUXC_ENABLE_OPENMP@ )
set( GAUXC_ENABLE_GAU2GRID @GAUXC_ENABLE_GAU2GRID@ )
set( GAUXC_HAS_HDF5 @GAUXC_HAS_HDF5@ )
set( GAUXC_BLAS_IS_LP64 @GAUXC_BLAS_IS_LP64@ )

# Make sure C / CXX are enabled (former for BLAS discovery)
enable_language(C)
enable_language(CXX)


if( GAUXC_ENABLE_HOST )
find_dependency( BLAS )
if(GAUXC_BLAS_IS_LP64)
set( _blas_components lp64 )
else()
set( _blas_components ilp64 )
endif()
find_dependency( BLAS COMPONENTS "${_blas_components}")
unset( _blas_components )
endif()

if( GAUXC_ENABLE_CUDA )
Expand Down
16 changes: 15 additions & 1 deletion src/xc_integrator/local_work_driver/host/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
#
# See LICENSE.txt for details
#
find_package( BLAS REQUIRED )
if(GAUXC_BLAS_PREFER_ILP64)
find_package( BLAS REQUIRED OPTIONAL_COMPONENTS ilp64 )
else()
find_package( BLAS REQUIRED )
endif()
include( gauxc-gau2grid )

target_sources( gauxc PRIVATE
Expand All @@ -19,6 +23,16 @@ target_sources( gauxc PRIVATE
blas.cxx
)

if(BLAS_IS_LP64)
message(STATUS "Discovered BLAS is LP64")
set_target_properties(gauxc PROPERTIES COMPILE_DEFINITIONS BLAS_IS_LP64=1)
else()
message(STATUS "Discovered BLAS is ILP64")
set_target_properties(gauxc PROPERTIES COMPILE_DEFINITIONS BLAS_IS_LP64=0)
endif()
set(GAUXC_BLAS_IS_LP64 ${BLAS_IS_LP64} CACHE BOOL "BLAS used in GauXC is LP64" FORCE)


target_link_libraries( gauxc PUBLIC BLAS::BLAS )

if( GAUXC_ENABLE_GAU2GRID AND TARGET gau2grid::gg )
Expand Down
104 changes: 66 additions & 38 deletions src/xc_integrator/local_work_driver/host/blas.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,48 @@
#include <type_traits>
#include <gauxc/exceptions.hpp>

#if BLAS_IS_LP64
#define blas_int int32_t
#else
#define blas_int int64_t
#endif

extern "C" {

//void dlacpy_( const char* UPLO, const int* M, const int* N, const double* A,
// const int* LDA, double* B, const int* LDB );
//void slacpy_( const char* UPLO, const int* M, const int* N, const float* A,
// const int* LDA, float* B, const int* LDB );

void dgemm_( const char* TA, const char* TB, const int* M, const int* N,
const int* K, const double* ALPHA, const double* A,
const int* LDA, const double* B, const int* LDB,
const double* BETA, double* C, const int* LDC );
void sgemm_( const char* TA, const char* TB, const int* M, const int* N,
const int* K, const float* ALPHA, const float* A,
const int* LDA, const float* B, const int* LDB,
const float* BETA, float* C, const int* LDC );

void dsyr2k_( const char* UPLO, const char* TRANS, const int* N, const int* K,
const double* ALPHA, const double* A, const int* LDA, const double* B,
const int* LDB, const double* BETA, double* C, const int* LDC );
void ssyr2k_( const char* UPLO, const char* TRANS, const int* N, const int* K,
const float* ALPHA, const float* A, const int* LDA, const float* B,
const int* LDB, const float* BETA, float* C, const int* LDC );

double ddot_( const int* N, const double* X, const int* INCX, const double* Y,
const int* INCY );
float sdot_( const int* N, const float* X, const int* INCX, const float* Y,
const int* INCY );


void daxpy_( const int* N, const double* ALPHA, const double* A, const int* INCX,
double* Y, const int* INCY );
void saxpy_( const int* N, const float* ALPHA, const float* A, const int* INCX,
float* Y, const int* INCY );

void dscal_( const int* N, const double* ALPHA, const double* X, const int* INCX );
void sscal_( const int* N, const float* ALPHA, const float* X, const int* INCX );
void dgemm_( const char* TA, const char* TB, const blas_int* M, const blas_int* N,
const blas_int* K, const double* ALPHA, const double* A,
const blas_int* LDA, const double* B, const blas_int* LDB,
const double* BETA, double* C, const blas_int* LDC );
void sgemm_( const char* TA, const char* TB, const blas_int* M, const blas_int* N,
const blas_int* K, const float* ALPHA, const float* A,
const blas_int* LDA, const float* B, const blas_int* LDB,
const float* BETA, float* C, const blas_int* LDC );

void dsyr2k_( const char* UPLO, const char* TRANS, const blas_int* N, const blas_int* K,
const double* ALPHA, const double* A, const blas_int* LDA, const double* B,
const blas_int* LDB, const double* BETA, double* C, const blas_int* LDC );
void ssyr2k_( const char* UPLO, const char* TRANS, const blas_int* N, const blas_int* K,
const float* ALPHA, const float* A, const blas_int* LDA, const float* B,
const blas_int* LDB, const float* BETA, float* C, const blas_int* LDC );

double ddot_( const blas_int* N, const double* X, const blas_int* INCX, const double* Y,
const blas_int* INCY );
float sdot_( const blas_int* N, const float* X, const blas_int* INCX, const float* Y,
const blas_int* INCY );


void daxpy_( const blas_int* N, const double* ALPHA, const double* A, const blas_int* INCX,
double* Y, const blas_int* INCY );
void saxpy_( const blas_int* N, const float* ALPHA, const float* A, const blas_int* INCX,
float* Y, const blas_int* INCY );

void dscal_( const blas_int* N, const double* ALPHA, const double* X, const blas_int* INCX );
void sscal_( const blas_int* N, const float* ALPHA, const float* X, const blas_int* INCX );
}

namespace GauXC::blas {
Expand Down Expand Up @@ -97,10 +103,16 @@ template void lacpy( char UPLO, int M, int N, const double* A, int LDA,


template <typename T>
void gemm( char TA, char TB, int M, int N, int K, T ALPHA,
const T* A, int LDA, const T* B, int LDB, T BETA,
T* C, int LDC ) {
void gemm( char TA, char TB, int _M, int _N, int _K, T ALPHA,
const T* A, int _LDA, const T* B, int _LDB, T BETA,
T* C, int _LDC ) {

blas_int M = _M;
blas_int N = _N;
blas_int K = _K;
blas_int LDA = _LDA;
blas_int LDB = _LDB;
blas_int LDC = _LDC;

if constexpr ( std::is_same_v<T,float> )
sgemm_( &TA, &TB, &M, &N, &K, &ALPHA, A, &LDA, B, &LDB, &BETA, C, &LDC );
Expand All @@ -126,10 +138,15 @@ void gemm( char doubleA, char doubleB, int M, int N, int K, double ALPHA,


template <typename T>
void syr2k( char UPLO, char TRANS, int N, int K, T ALPHA,
const T* A, int LDA, const T* B, int LDB, T BETA,
T* C, int LDC ) {
void syr2k( char UPLO, char TRANS, int _N, int _K, T ALPHA,
const T* A, int _LDA, const T* B, int _LDB, T BETA,
T* C, int _LDC ) {

blas_int N = _N;
blas_int K = _K;
blas_int LDA = _LDA;
blas_int LDB = _LDB;
blas_int LDC = _LDC;

if constexpr ( std::is_same_v<T,float> )
ssyr2k_( &UPLO, &TRANS, &N, &K, &ALPHA, A, &LDA, B, &LDB, &BETA, C, &LDC );
Expand All @@ -156,7 +173,11 @@ void syr2k( char UPLO, char doubleRANS, int N, int K, double ALPHA,


template <typename T>
T dot( int N, const T* X, int INCX, const T* Y, int INCY ) {
T dot( int _N, const T* X, int _INCX, const T* Y, int _INCY ) {

blas_int N = _N;
blas_int INCX = _INCX;
blas_int INCY = _INCY;

if constexpr ( std::is_same_v<T,float> )
return sdot_(&N, X, &INCX, Y, &INCY);
Expand All @@ -178,7 +199,11 @@ double dot( int N, const double* X, int INCX, const double* Y, int INCY );


template <typename T>
void axpy( int N, T ALPHA, const T* X, int INCX, T* Y, int INCY ) {
void axpy( int _N, T ALPHA, const T* X, int _INCX, T* Y, int _INCY ) {

blas_int N = _N;
blas_int INCX = _INCX;
blas_int INCY = _INCY;

if constexpr ( std::is_same_v<T,float> )
saxpy_(&N, &ALPHA, X, &INCX, Y, &INCY );
Expand All @@ -201,7 +226,10 @@ void axpy( int N, double ALPHA, const double* A, int INCX, double* Y,


template <typename T>
void scal( int N, T ALPHA, T* X, int INCX ) {
void scal( int _N, T ALPHA, T* X, int _INCX ) {

blas_int N = _N;
blas_int INCX = _INCX;

if constexpr ( std::is_same_v<T,float> )
sscal_(&N, &ALPHA, X, &INCX );
Expand Down

0 comments on commit 3141809

Please sign in to comment.