diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 127d1e36a9b..6c0d8ae6379 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -621,12 +621,14 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { return val != CI->getOperand(0); } // only the recv buffer and request is active for mpi isend/irecv - if (Name == "MPI_Irecv" || Name == "MPI_Isend" || Name == "PMPI_Irecv" || - Name == "PMPI_Isend") { + if (Name == "MPI_Irecv" || Name == "MPI_Isend" || Name == "MPI_Send_init" || Name == "MPI_Recv_init" || + Name == "PMPI_Irecv" || Name == "PMPI_Isend" || Name == "PMPI_Send_init" || Name == "PMPI_Recv_init") { return val != CI->getOperand(0) && val != CI->getOperand(6); } - // only request is active + if (Name == "MPI_Start" || Name == "PMPI_Start") + return val != CI->getOperand(0); + if (Name == "MPI_Wait" || Name == "PMPI_Wait") return val != CI->getOperand(0); diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 22b3318ced6..1f7cfbc8a71 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -44,8 +44,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, BuilderZ.setFastMathFlags(getFast()); // MPI send / recv can only send float/integers - if (funcName == "PMPI_Isend" || funcName == "MPI_Isend" || - funcName == "PMPI_Irecv" || funcName == "MPI_Irecv") { + if (funcName == "PMPI_Isend" || funcName == "MPI_Isend" || funcName == "PMPI_Send_init" || funcName == "MPI_Send_init" || + funcName == "PMPI_Irecv" || funcName == "MPI_Irecv" || funcName == "PMPI_Recv_init" || funcName == "MPI_Recv_init") { if (!gutils->isConstantInstruction(&call)) { if (Mode == DerivativeMode::ReverseModePrimal || Mode == DerivativeMode::ReverseModeCombined) { @@ -73,7 +73,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, getMPIMemberPtr(BuilderZ, impialloc, 
impi)); BuilderZ.CreateStore(impialloc, d_req); - if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") { + if (funcName == "MPI_Isend" || funcName == "PMPI_Isend" || + funcName == "MPI_Send_init" || funcName == "PMPI_Send_init") { Value *tysize = MPI_TYPE_SIZE(gutils->getNewFromOriginal(call.getOperand(2)), BuilderZ, call.getType(), called); @@ -134,12 +135,22 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, BuilderZ.CreatePointerCast(comm, getInt8PtrTy(call.getContext())), getMPIMemberPtr(BuilderZ, impialloc, impi)); + MPI_CallType callType = MPI_CallType::IRECV; + if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") + callType = MPI_CallType::ISEND; + else if (funcName == "MPI_Irecv" || funcName == "PMPI_Irecv") + callType = MPI_CallType::IRECV; + else if (funcName == "MPI_Send_init" || funcName == "PMPI_Send_init") + callType = MPI_CallType::SEND_INIT; + else if (funcName == "MPI_Recv_init" || funcName == "PMPI_Recv_init") + callType = MPI_CallType::RECV_INIT; + else + assert(0 && "illegal mpi"); + BuilderZ.CreateStore( ConstantInt::get( Type::getInt8Ty(impialloc->getContext()), - (funcName == "MPI_Isend" || funcName == "PMPI_Isend") - ? (int)MPI_CallType::ISEND - : (int)MPI_CallType::IRECV), + (int)callType), getMPIMemberPtr(BuilderZ, impialloc, impi)); // TODO old } @@ -2178,6 +2189,23 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, return; } + // Adjoint of MPI_Send is to place a MPI_send at the corresponding + // location in the reverse. 
+ if (funcName == "MPI_Send" || funcName == "PMPI_Send") { + if (Mode == DerivativeMode::ReverseModeGradient || + Mode == DerivativeMode::ReverseModeCombined) { + IRBuilder<> Builder2(&call); + getReverseBuilder(Builder2); + auto callval = call.getCalledOperand(); + Value *args[] = { + lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2)}; + Builder2.CreateCall(call.getFunctionType(), callval, args); + } + if (Mode == DerivativeMode::ReverseModeGradient) + eraseIfUnused(call, /*erase*/ true, /*check*/ false); + return; + } + // Remove free's in forward pass so the comm can be used in the reverse // pass if (funcName == "MPI_Comm_free" || funcName == "MPI_Comm_disconnect") { diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index a1f75b6e88d..16e4189ec03 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4952,8 +4952,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); return; } - if (funcName == "MPI_Isend" || funcName == "MPI_Irecv" || - funcName == "PMPI_Isend" || funcName == "PMPI_Irecv") { + if (funcName == "MPI_Isend" || funcName == "MPI_Irecv" || funcName == "MPI_Send_init" || funcName == "MPI_Recv_init" || + funcName == "PMPI_Isend" || funcName == "PMPI_Irecv" || funcName == "PMPI_Send_init" || funcName == "PMPI_Recv_init") { TypeTree buf = TypeTree(BaseType::Pointer); if (Constant *C = dyn_cast<Constant>(call.getOperand(2))) { @@ -4989,6 +4989,10 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { TypeTree(BaseType::Pointer).Only(-1, &call), &call); return; } + if (funcName == "MPI_Start" || funcName == "PMPI_Start") { + // TODO + return; + } if (funcName == "MPI_Wait") { updateAnalysis(call.getOperand(0), TypeTree(BaseType::Pointer).Only(-1, &call), &call); diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 0b7c05a6f43..3790ed949b2 100644 --- 
a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -1771,6 +1771,9 @@ llvm::Function *getOrInsertDifferentialMPI_Wait(llvm::Module &M, BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F); BasicBlock *isend = BasicBlock::Create(M.getContext(), "invertISend", F); BasicBlock *irecv = BasicBlock::Create(M.getContext(), "invertIRecv", F); + BasicBlock *send_init = BasicBlock::Create(M.getContext(), "invertSendInit", F); + BasicBlock *recv_init = BasicBlock::Create(M.getContext(), "invertRecvInit", F); + BasicBlock *error = BasicBlock::Create(M.getContext(), "invertError", F); #if 0 /*0 */getInt8PtrTy(call.getContext()) @@ -1831,10 +1834,25 @@ llvm::Function *getOrInsertDifferentialMPI_Wait(llvm::Module &M, buf, count, datatype, source, tag, comm, d_req, }; - B.CreateCondBr(B.CreateICmpEQ(fn, ConstantInt::get(fn->getType(), - (int)MPI_CallType::ISEND)), - isend, irecv); + auto *SI = B.CreateSwitch(fn, error, 4); + + SI->addCase( + cast<ConstantInt>(ConstantInt::get(fn->getType(), (int)MPI_CallType::ISEND)), + isend); + SI->addCase( + cast<ConstantInt>(ConstantInt::get(fn->getType(), (int)MPI_CallType::IRECV)), + irecv); + SI->addCase( + cast<ConstantInt>(ConstantInt::get(fn->getType(), (int)MPI_CallType::SEND_INIT)), + send_init); + SI->addCase( + cast<ConstantInt>(ConstantInt::get(fn->getType(), (int)MPI_CallType::RECV_INIT)), + recv_init); + { + B.SetInsertPoint(error); + B.CreateUnreachable(); + } { B.SetInsertPoint(isend); auto fcall = B.CreateCall(irecvfn, args); @@ -1848,6 +1866,21 @@ llvm::Function *getOrInsertDifferentialMPI_Wait(llvm::Module &M, fcall->setCallingConv(isendfn->getCallingConv()); B.CreateRetVoid(); } + + { + B.SetInsertPoint(send_init); + auto fcall = B.CreateCall(recv_initfn, args); + fcall->setCallingConv(recv_initfn->getCallingConv()); + B.CreateRetVoid(); + } + + { + B.SetInsertPoint(recv_init); + auto fcall = B.CreateCall(send_initfn, args); + fcall->setCallingConv(send_initfn->getCallingConv()); + B.CreateRetVoid(); + } + return F; } diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 
ead3d2baf75..c5592de3a48 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1132,6 +1132,8 @@ allInstructionsBetween(llvm::LoopInfo &LI, llvm::Instruction *inst1, enum class MPI_CallType { ISEND = 1, IRECV = 2, + SEND_INIT = 3, + RECV_INIT = 4 }; enum class MPI_Elem {