@@ -906,7 +906,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
906906 // 2. reduce sum diff(buffer) into intermediate
907907 // 3. if root, set shadow(buffer) = intermediate [memcpy] then free
908908 // 3-e. else, set shadow(buffer) = 0 [memset]
909- if (funcName == " MPI_Bcast" ) {
909+ if (funcName == " MPI_Bcast" || funcName == " PMPI_Bcast " ) {
910910 if (Mode == DerivativeMode::ReverseModeGradient ||
911911 Mode == DerivativeMode::ReverseModeCombined ||
912912 Mode == DerivativeMode::ForwardMode ||
@@ -1352,7 +1352,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
13521352 // int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
13531353 // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
13541354
1355- if (funcName == " MPI_Allreduce" ) {
1355+ if (funcName == " MPI_Allreduce" || funcName == " PMPI_Allreduce " ) {
13561356 if (Mode == DerivativeMode::ReverseModeGradient ||
13571357 Mode == DerivativeMode::ReverseModeCombined ||
13581358 Mode == DerivativeMode::ForwardMode ||
@@ -1533,7 +1533,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
15331533 // void *recvbuf, int recvcount, MPI_Datatype recvtype,
15341534 // int root, MPI_Comm comm)
15351535
1536- if (funcName == " MPI_Gather" ) {
1536+ if (funcName == " MPI_Gather" || funcName == " PMPI_Gather " ) {
15371537 if (Mode == DerivativeMode::ReverseModeGradient ||
15381538 Mode == DerivativeMode::ReverseModeCombined ||
15391539 Mode == DerivativeMode::ForwardMode ||
@@ -1738,7 +1738,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
17381738 // sendtype,
17391739 // void *recvbuf, int recvcount, MPI_Datatype recvtype, int root,
17401740 // MPI_Comm comm)
1741- if (funcName == " MPI_Scatter" ) {
1741+ if (funcName == " MPI_Scatter" || funcName == " PMPI_Scatter " ) {
17421742 if (Mode == DerivativeMode::ReverseModeGradient ||
17431743 Mode == DerivativeMode::ReverseModeCombined ||
17441744 Mode == DerivativeMode::ForwardMode ||
@@ -1978,7 +1978,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
19781978 // void *recvbuf, int recvcount, MPI_Datatype recvtype,
19791979 // MPI_Comm comm)
19801980
1981- if (funcName == " MPI_Allgather" ) {
1981+ if (funcName == " MPI_Allgather" || funcName == " PMPI_Allgather " ) {
19821982 if (Mode == DerivativeMode::ReverseModeGradient ||
19831983 Mode == DerivativeMode::ReverseModeCombined ||
19841984 Mode == DerivativeMode::ForwardMode ||
@@ -2163,7 +2163,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
21632163
21642164 // Adjoint of barrier is to place a barrier at the corresponding
21652165 // location in the reverse.
2166- if (funcName == " MPI_Barrier" ) {
2166+ if (funcName == " MPI_Barrier" || funcName == " PMPI_Barrier " ) {
21672167 if (Mode == DerivativeMode::ReverseModeGradient ||
21682168 Mode == DerivativeMode::ReverseModeCombined) {
21692169 IRBuilder<> Builder2 (&call);
0 commit comments