Skip to content

Commit

Permalink
fixed automatic workspace memory
Browse files Browse the repository at this point in the history
  • Loading branch information
TysonRayJones authored Sep 20, 2023
1 parent 7e2362a commit f7c6ee6
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 71 deletions.
18 changes: 12 additions & 6 deletions QuEST/include/QuEST.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,15 @@

// ensure custatevecHandle_t is defined, even if no GPU
# ifdef USE_CUQUANTUM
# include <custatevec.h>
# include <custatevec.h>
typedef struct CuQuantumConfig {
cudaMemPool_t cuMemPool;
cudaStream_t cuStream;
custatevecHandle_t cuQuantumHandle;
custatevecDeviceMemHandler_t cuMemHandler;
} CuQuantumConfig;
# else
# define custatevecHandle_t void*
# define CuQuantumConfig void*
# endif


Expand Down Expand Up @@ -379,10 +385,10 @@ typedef struct Qureg
//! Storage for reduction of probabilities on GPU
qreal *firstLevelReduction, *secondLevelReduction;

//! Storage for wavefunction amplitues and config (copy of QuESTEnv's handle) in the cuQuantum version
//! Storage for wavefunction amplitudes and config (copy of QuESTEnv's handle) in cuQuantum deployment
cuAmp* cuStateVec;
cuAmp* deviceCuStateVec;
custatevecHandle_t cuQuantumHandle;
CuQuantumConfig* cuConfig;

//! Storage for generated QASM output
QASMLogger* qasmLog;
Expand All @@ -403,8 +409,8 @@ typedef struct QuESTEnv
unsigned long int* seeds;
int numSeeds;

// handle to cuQuantum (specifically cuStateVec) used only cuQuantum deployment mode (otherwise is void*)
custatevecHandle_t cuQuantumHandle;
// a copy of the QuESTEnv's config, used only in cuQuantum deployment
CuQuantumConfig* cuConfig;

} QuESTEnv;

Expand Down
7 changes: 6 additions & 1 deletion QuEST/include/QuEST_precision.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@

# include <math.h>

// define CUDA complex types as void if not using cuQuantum

// define CUDA complex types as void if not using cuQuantum.
// note we used cuComplex.h for complex numbers, in lieu of
// Thrust's complex<qreal>, so that the QuEST.h header can
// always be compiled with C99, rather than C++14.
# ifdef USE_CUQUANTUM
# include <cuComplex.h>
# else
# define cuFloatComplex void
# define cuDoubleComplex void
# endif


// set default double precision if not set during compilation
# ifndef QuEST_PREC
# define QuEST_PREC 2
Expand Down
165 changes: 101 additions & 64 deletions QuEST/src/GPU/QuEST_cuQuantum.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@
# define CU_AMP_IN_MATRIX_PREC void // invalid
#endif



// convenient operator overloads for cuAmp, for doing complex arithmetic.
// some of these are defined to be used by Thrust's backend, because we
// avoided Thrust's complex<qreal> (see QuEST_precision.h for explanation).
Expand Down Expand Up @@ -124,6 +122,89 @@ std::vector<int> getIndsFromMask(long long int mask, int numBits) {



/*
* CUQUANTUM MEMORY MANAGEMENT
*/

int GPUSupportsMemPools() {

    // query only the first device (guaranteed to already exist)
    int deviceId;
    cudaGetDevice(&deviceId);

    // non-zero when the device supports stream-ordered memory pools
    int poolsSupported;
    cudaDeviceGetAttribute(&poolsSupported, cudaDevAttrMemoryPoolsSupported, deviceId);
    return poolsSupported;
}

// allocation callback for custatevecDeviceMemHandler_t; ctx is the
// cudaMemPool_t bound by initCuConfig()
int memPoolAlloc(void* ctx, void** ptr, size_t size, cudaStream_t stream) {
    cudaMemPool_t* pool = static_cast<cudaMemPool_t*>(ctx);
    return cudaMallocFromPoolAsync(ptr, size, *pool, stream);
}
// deallocation callback for custatevecDeviceMemHandler_t; pool memory is
// freed stream-ordered, so ctx and size are unused here
int memPoolFree(void* ctx, void* ptr, size_t size, cudaStream_t stream) {
    return cudaFreeAsync(ptr, stream);
}

// allocates and populates a CuQuantumConfig on the heap (so that Quregs can
// share the QuESTEnv's pointer); pair with destroyCuConfig() to release it
CuQuantumConfig* createCuConfig() {

    CuQuantumConfig* config = (CuQuantumConfig*) malloc(sizeof(CuQuantumConfig));

    // adopt the device's existing default memory pool; since it pre-exists,
    // it is never explicitly freed by us
    int deviceId;
    cudaGetDevice(&deviceId);
    cudaDeviceGetMemPool(&(config->cuMemPool), deviceId);

    // create a fresh cuStateVec handle and a fresh stream
    custatevecCreate(&(config->cuQuantumHandle));
    cudaStreamCreate(&(config->cuStream));

    // cuMemHandler is plain struct storage inside config; it is populated
    // later by initCuConfig() and needs no explicit creation here

    return config;
}

// binds config's mempool, mem-handler and stream to its cuStateVec handle,
// so that cuQuantum workspace memory is managed automatically
void initCuConfig(CuQuantumConfig* config) {

    // get existing memPool threshold, above which memory gets freed at every stream synch
    size_t currMaxMem;
    cudaMemPoolGetAttribute(config->cuMemPool, cudaMemPoolAttrReleaseThreshold, &currMaxMem);

    // if memPool threshold is smaller than 1 MiB (a 16-qubit statevector at
    // 16 bytes per amp), extend it. Note 1 MiB = 16*(1<<16) bytes; the prior
    // constant 16*(1<<15) was only 512 KiB, contradicting this intent
    size_t desiredMaxMem = 16 * ((size_t) 1 << 16);
    if (currMaxMem < desiredMaxMem)
        cudaMemPoolSetAttribute(config->cuMemPool, cudaMemPoolAttrReleaseThreshold, &desiredMaxMem);

    // bind mempool to deviceMemHandler (ctx points into config, which outlives the handle)
    config->cuMemHandler.ctx = &(config->cuMemPool);
    config->cuMemHandler.device_alloc = memPoolAlloc;
    config->cuMemHandler.device_free = memPoolFree;
    strcpy(config->cuMemHandler.name, "mempool");

    // bind deviceMemHandler to cuQuantum, automating workspace allocation
    custatevecSetDeviceMemHandler(config->cuQuantumHandle, &(config->cuMemHandler));

    // bind stream to cuQuantum
    custatevecSetStream(config->cuQuantumHandle, config->cuStream);
}

// releases the resources created by createCuConfig(), then the config itself
void destroyCuConfig(CuQuantumConfig* config) {

    // tear down the stream and handle we created
    cudaStreamDestroy(config->cuStream);
    custatevecDestroy(config->cuQuantumHandle);

    // cuMemPool pre-existed (device default) and cuMemHandler is inline
    // struct storage, so neither is freed separately

    free(config);
}



/*
* CUQUANTUM WRAPPERS (to reduce boilerplate)
*/
Expand All @@ -141,7 +222,7 @@ void custatevec_applyMatrix(Qureg qureg, std::vector<int> ctrls, std::vector<int
size_t workSize = 0;

custatevecApplyMatrix(
qureg.cuQuantumHandle,
qureg.cuConfig->cuQuantumHandle,
qureg.deviceCuStateVec, CU_AMP_IN_STATE_PREC, qureg.numQubitsInStateVec,
matr.data(), CU_AMP_IN_MATRIX_PREC, CUSTATEVEC_MATRIX_LAYOUT_ROW, adj,
targs.data(), targs.size(),
Expand All @@ -166,7 +247,7 @@ void custatevec_applyDiagonal(Qureg qureg, std::vector<int> ctrls, std::vector<i
size_t workSize = 0;

custatevecApplyGeneralizedPermutationMatrix(
qureg.cuQuantumHandle, qureg.deviceCuStateVec,
qureg.cuConfig->cuQuantumHandle, qureg.deviceCuStateVec,
CU_AMP_IN_STATE_PREC, qureg.numQubitsInStateVec,
perm, elems, CU_AMP_IN_MATRIX_PREC, adj,
targs.data(), targs.size(),
Expand Down Expand Up @@ -424,50 +505,6 @@ extern "C" {
* ENVIRONMENT MANAGEMENT
*/

int GPUSupportsMemPools() {

// consult only the first device (garuanteed already to exist)
int device = 0;

int supports;
cudaDeviceGetAttribute(&supports, cudaDevAttrMemoryPoolsSupported, device);
return supports;
}

int memPoolAlloc(void* ctx, void** ptr, size_t size, cudaStream_t stream) {
cudaMemPool_t pool = * reinterpret_cast<cudaMemPool_t*>(ctx);
return cudaMallocFromPoolAsync(ptr, size, pool, stream);
}
int memPoolFree(void* ctx, void* ptr, size_t size, cudaStream_t stream) {
return cudaFreeAsync(ptr, stream);
}

void setupAutoWorkspaces(custatevecHandle_t cuQuantumHandle) {

// get the current (device's default) stream-ordered memory pool (assuming single GPU)
int device = 0;
cudaMemPool_t memPool;
cudaDeviceGetMemPool(&memPool, device);

// get its current memory threshold, above which memory gets freed at every stream synch
size_t currMaxMem;
cudaMemPoolGetAttribute(memPool, cudaMemPoolAttrReleaseThreshold, &currMaxMem);

// if it's smaller than 1 MiB = 16 qubits, extend it
size_t desiredMaxMem = 16*(1LL<<16);
if (currMaxMem < desiredMaxMem)
cudaMemPoolSetAttribute(memPool, cudaMemPoolAttrReleaseThreshold, &desiredMaxMem);

// create a mem handler around the mem pool
custatevecDeviceMemHandler_t memHandler;
memHandler.ctx = reinterpret_cast<void*>(&memPool);
memHandler.device_alloc = memPoolAlloc;
memHandler.device_free = memPoolFree;

// set cuQuantum to use this handler and pool, to automate workspace memory management
custatevecSetDeviceMemHandler(cuQuantumHandle, &memHandler);
}

QuESTEnv createQuESTEnv(void) {
validateGPUExists(GPUExists(), __func__);
validateGPUIsCuQuantumCompatible(GPUSupportsMemPools(),__func__);
Expand All @@ -480,18 +517,18 @@ QuESTEnv createQuESTEnv(void) {
env.numSeeds = 0;
seedQuESTDefault(&env);

// prepare cuQuantum
custatevecCreate(&env.cuQuantumHandle);
setupAutoWorkspaces(env.cuQuantumHandle);
// prepare cuQuantum with automatic workspaces
env.cuConfig = createCuConfig();
initCuConfig(env.cuConfig);

return env;
}

void destroyQuESTEnv(QuESTEnv env){
free(env.seeds);

// finalise cuQuantum
custatevecDestroy(env.cuQuantumHandle);
destroyCuConfig(env.cuConfig);
}


Expand All @@ -511,8 +548,8 @@ void statevec_createQureg(Qureg *qureg, int numQubits, QuESTEnv env)
qureg->numChunks = 1;
qureg->isDensityMatrix = 0;

// copy env's cuQuantum handle to qureg
qureg->cuQuantumHandle = env.cuQuantumHandle;
// copy env's cuQuantum config handle
qureg->cuConfig = env.cuConfig;

// allocate user-facing CPU memory
qureg->stateVec.real = (qreal*) malloc(numAmps * sizeof(qureg->stateVec.real));
Expand Down Expand Up @@ -627,7 +664,7 @@ void densmatr_initPlusState(Qureg qureg)
void statevec_initZeroState(Qureg qureg)
{
custatevecInitializeStateVector(
qureg.cuQuantumHandle,
qureg.cuConfig->cuQuantumHandle,
qureg.deviceCuStateVec,
CU_AMP_IN_STATE_PREC,
qureg.numQubitsInStateVec,
Expand All @@ -648,7 +685,7 @@ void statevec_initBlankState(Qureg qureg)
void statevec_initPlusState(Qureg qureg)
{
custatevecInitializeStateVector(
qureg.cuQuantumHandle,
qureg.cuConfig->cuQuantumHandle,
qureg.deviceCuStateVec,
CU_AMP_IN_STATE_PREC,
qureg.numQubitsInStateVec,
Expand Down Expand Up @@ -768,7 +805,7 @@ void statevec_multiControlledUnitary(Qureg qureg, long long int ctrlQubitsMask,
ctrlVals[i] = !(ctrlFlipMask & (1LL<<ctrlInds[i]));

custatevecApplyMatrix(
qureg.cuQuantumHandle,
qureg.cuConfig->cuQuantumHandle,
qureg.deviceCuStateVec, CU_AMP_IN_STATE_PREC, qureg.numQubitsInStateVec,
toCuMatr(u).data(), CU_AMP_IN_MATRIX_PREC, CUSTATEVEC_MATRIX_LAYOUT_ROW, 0,
targs, 1, ctrlInds.data(), ctrlVals.data(), ctrlInds.size(),
Expand Down Expand Up @@ -866,7 +903,7 @@ void statevec_multiRotateZ(Qureg qureg, long long int mask, qreal angle)
std::vector<custatevecPauli_t> paulis(targs.size(), CUSTATEVEC_PAULI_Z);

custatevecApplyPauliRotation(
qureg.cuQuantumHandle, qureg.deviceCuStateVec,
qureg.cuConfig->cuQuantumHandle, qureg.deviceCuStateVec,
CU_AMP_IN_STATE_PREC, qureg.numQubitsInStateVec,
theta, paulis.data(), targs.data(), targs.size(),
nullptr, nullptr, 0);
Expand All @@ -880,7 +917,7 @@ void statevec_multiControlledMultiRotateZ(Qureg qureg, long long int ctrlMask, l
std::vector<custatevecPauli_t> paulis(targs.size(), CUSTATEVEC_PAULI_Z);

custatevecApplyPauliRotation(
qureg.cuQuantumHandle, qureg.deviceCuStateVec,
qureg.cuConfig->cuQuantumHandle, qureg.deviceCuStateVec,
CU_AMP_IN_STATE_PREC, qureg.numQubitsInStateVec,
theta, paulis.data(), targs.data(), targs.size(),
ctrls.data(), nullptr, ctrls.size());
Expand Down Expand Up @@ -912,7 +949,7 @@ void statevec_swapQubitAmps(Qureg qureg, int qb1, int qb2)
int numPairs = 1;

custatevecSwapIndexBits(
qureg.cuQuantumHandle, qureg.deviceCuStateVec,
qureg.cuConfig->cuQuantumHandle, qureg.deviceCuStateVec,
CU_AMP_IN_STATE_PREC, qureg.numQubitsInStateVec,
targPairs, numPairs,
nullptr, nullptr, 0);
Expand Down Expand Up @@ -1100,7 +1137,7 @@ qreal statevec_calcTotalProb(Qureg qureg)
int numBasisBits = 1;

custatevecAbs2SumOnZBasis(
qureg.cuQuantumHandle, qureg.deviceCuStateVec,
qureg.cuConfig->cuQuantumHandle, qureg.deviceCuStateVec,
CU_AMP_IN_STATE_PREC, qureg.numQubitsInStateVec,
&abs2sum0, &abs2sum1, basisBits, numBasisBits);

Expand All @@ -1117,7 +1154,7 @@ qreal statevec_calcProbOfOutcome(Qureg qureg, int measureQubit, int outcome)
int numBasisBits = 1;

custatevecAbs2SumOnZBasis(
qureg.cuQuantumHandle, qureg.deviceCuStateVec,
qureg.cuConfig->cuQuantumHandle, qureg.deviceCuStateVec,
CU_AMP_IN_STATE_PREC, qureg.numQubitsInStateVec,
&prob0, prob1, basisBits, numBasisBits);

Expand Down Expand Up @@ -1230,7 +1267,7 @@ void statevec_collapseToKnownProbOutcome(Qureg qureg, int measureQubit, int outc
int basisBits[] = {measureQubit};

custatevecCollapseOnZBasis(
qureg.cuQuantumHandle, qureg.deviceCuStateVec,
qureg.cuConfig->cuQuantumHandle, qureg.deviceCuStateVec,
CU_AMP_IN_STATE_PREC, qureg.numQubitsInStateVec,
outcome, basisBits, 1, outcomeProb);
}
Expand Down

0 comments on commit f7c6ee6

Please sign in to comment.