Skip to content

Commit 6b1848d

Browse files
authored
Support PMPI functions in CallDerivatives (#2530)
1 parent 2e6f771 commit 6b1848d

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

enzyme/Enzyme/CallDerivatives.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
906906
// 2. reduce sum diff(buffer) into intermediate
907907
// 3. if root, set shadow(buffer) = intermediate [memcpy] then free
908908
// 3-e. else, set shadow(buffer) = 0 [memset]
909-
if (funcName == "MPI_Bcast") {
909+
if (funcName == "MPI_Bcast" || funcName == "PMPI_Bcast") {
910910
if (Mode == DerivativeMode::ReverseModeGradient ||
911911
Mode == DerivativeMode::ReverseModeCombined ||
912912
Mode == DerivativeMode::ForwardMode ||
@@ -1352,7 +1352,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
13521352
// int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
13531353
// MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
13541354

1355-
if (funcName == "MPI_Allreduce") {
1355+
if (funcName == "MPI_Allreduce" || funcName == "PMPI_Allreduce") {
13561356
if (Mode == DerivativeMode::ReverseModeGradient ||
13571357
Mode == DerivativeMode::ReverseModeCombined ||
13581358
Mode == DerivativeMode::ForwardMode ||
@@ -1533,7 +1533,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
15331533
// void *recvbuf, int recvcount, MPI_Datatype recvtype,
15341534
// int root, MPI_Comm comm)
15351535

1536-
if (funcName == "MPI_Gather") {
1536+
if (funcName == "MPI_Gather" || funcName == "PMPI_Gather") {
15371537
if (Mode == DerivativeMode::ReverseModeGradient ||
15381538
Mode == DerivativeMode::ReverseModeCombined ||
15391539
Mode == DerivativeMode::ForwardMode ||
@@ -1738,7 +1738,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
17381738
// sendtype,
17391739
// void *recvbuf, int recvcount, MPI_Datatype recvtype, int root,
17401740
// MPI_Comm comm)
1741-
if (funcName == "MPI_Scatter") {
1741+
if (funcName == "MPI_Scatter" || funcName == "PMPI_Scatter") {
17421742
if (Mode == DerivativeMode::ReverseModeGradient ||
17431743
Mode == DerivativeMode::ReverseModeCombined ||
17441744
Mode == DerivativeMode::ForwardMode ||
@@ -1978,7 +1978,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
19781978
// void *recvbuf, int recvcount, MPI_Datatype recvtype,
19791979
// MPI_Comm comm)
19801980

1981-
if (funcName == "MPI_Allgather") {
1981+
if (funcName == "MPI_Allgather" || funcName == "PMPI_Allgather") {
19821982
if (Mode == DerivativeMode::ReverseModeGradient ||
19831983
Mode == DerivativeMode::ReverseModeCombined ||
19841984
Mode == DerivativeMode::ForwardMode ||
@@ -2163,7 +2163,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
21632163

21642164
// Adjoint of barrier is to place a barrier at the corresponding
21652165
// location in the reverse.
2166-
if (funcName == "MPI_Barrier") {
2166+
if (funcName == "MPI_Barrier" || funcName == "PMPI_Barrier") {
21672167
if (Mode == DerivativeMode::ReverseModeGradient ||
21682168
Mode == DerivativeMode::ReverseModeCombined) {
21692169
IRBuilder<> Builder2(&call);

0 commit comments

Comments
 (0)