Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,12 +621,14 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) {
return val != CI->getOperand(0);
}
// only the recv buffer and request is active for mpi isend/irecv
if (Name == "MPI_Irecv" || Name == "MPI_Isend" || Name == "PMPI_Irecv" ||
Name == "PMPI_Isend") {
if (Name == "MPI_Irecv" || Name == "MPI_Isend" || Name == "MPI_Send_init" || Name == "MPI_Recv_init" ||
Name == "PMPI_Irecv" || Name == "PMPI_Isend" || Name == "PMPI_Send_init" || Name == "PMPI_Recv_init") {
return val != CI->getOperand(0) && val != CI->getOperand(6);
}

// only request is active
if (Name == "MPI_Start" || Name == "PMPI_Start")
return val != CI->getOperand(0);

if (Name == "MPI_Wait" || Name == "PMPI_Wait")
return val != CI->getOperand(0);

Expand Down
40 changes: 34 additions & 6 deletions enzyme/Enzyme/CallDerivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
BuilderZ.setFastMathFlags(getFast());

// MPI send / recv can only send float/integers
if (funcName == "PMPI_Isend" || funcName == "MPI_Isend" ||
funcName == "PMPI_Irecv" || funcName == "MPI_Irecv") {
if (funcName == "PMPI_Isend" || funcName == "MPI_Isend" || funcName == "PMPI_Send_init" || funcName == "MPI_Send_init" ||
funcName == "PMPI_Irecv" || funcName == "MPI_Irecv" || funcName == "PMPI_Recv_init" || funcName == "MPI_Recv_init") {
if (!gutils->isConstantInstruction(&call)) {
if (Mode == DerivativeMode::ReverseModePrimal ||
Mode == DerivativeMode::ReverseModeCombined) {
Expand Down Expand Up @@ -73,7 +73,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
getMPIMemberPtr<MPI_Elem::Old>(BuilderZ, impialloc, impi));
BuilderZ.CreateStore(impialloc, d_req);

if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") {
if (funcName == "MPI_Isend" || funcName == "PMPI_Isend" ||
funcName == "MPI_Send_init" || funcName == "PMPI_Send_init") {
Value *tysize =
MPI_TYPE_SIZE(gutils->getNewFromOriginal(call.getOperand(2)),
BuilderZ, call.getType(), called);
Expand Down Expand Up @@ -134,12 +135,22 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
BuilderZ.CreatePointerCast(comm, getInt8PtrTy(call.getContext())),
getMPIMemberPtr<MPI_Elem::Comm>(BuilderZ, impialloc, impi));

MPI_CallType callType;
if (funcName == "MPI_Isend" || funcName == "PMPI_Isend")
callType = MPI_CallType::ISEND;
else if (funcName == "MPI_Irecv" || funcName == "PMPI_Irecv")
callType = MPI_CallType::IRECV;
else if (funcName == "MPI_Send_init" || funcName == "PMPI_Send_init")
callType = MPI_CallType::SEND_INIT;
else if (funcName == "MPI_Recv_init" || funcName == "PMPI_Recv_init")
callType = MPI_CallType::RECV_INIT;
else
assert(0 && "illegal mpi");

BuilderZ.CreateStore(
ConstantInt::get(
Type::getInt8Ty(impialloc->getContext()),
(funcName == "MPI_Isend" || funcName == "PMPI_Isend")
? (int)MPI_CallType::ISEND
: (int)MPI_CallType::IRECV),
(int)callType),
getMPIMemberPtr<MPI_Elem::Call>(BuilderZ, impialloc, impi));
// TODO old
}
Expand Down Expand Up @@ -2178,6 +2189,23 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
return;
}

// Adjoint of MPI_Send is to place a MPI_send at the corresponding
// location in the reverse.
if (func == "MPI_Send" || func == "PMPI_Send") {
if (Mode == DerivativeMode::ReverseModeGradient ||
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did we miss this from earlier??

if so we should add a test

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry this was a brain fart on my side this ought to be MPI_Start not MPI_Send

Mode == DerivativeMode::ReverseModeCombined) {
IRBuilder<> Builder2(&call);
getReverseBuilder(Builder2);
auto callval = call.getCalledOperand();
Value *args[] = {
lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2)};
Builder2.CreateCall(call.getFunctionType(), callval, args);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't cover fwd mode, so we shouldn't return in that case

}
if (Mode == DerivativeMode::ReverseModeGradient)
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
return;
}

// Remove free's in forward pass so the comm can be used in the reverse
// pass
if (funcName == "MPI_Comm_free" || funcName == "MPI_Comm_disconnect") {
Expand Down
8 changes: 6 additions & 2 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4952,8 +4952,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
return;
}
if (funcName == "MPI_Isend" || funcName == "MPI_Irecv" ||
funcName == "PMPI_Isend" || funcName == "PMPI_Irecv") {
if (funcName == "MPI_Isend" || funcName == "MPI_Irecv" || funcName == "MPI_Send_init" || funcName == "MPI_Recv_init" ||
funcName == "PMPI_Isend" || funcName == "PMPI_Irecv" || funcName == "PMPI_Send_init" || funcName == "PMPI_Recv_init") {
TypeTree buf = TypeTree(BaseType::Pointer);

if (Constant *C = dyn_cast<Constant>(call.getOperand(2))) {
Expand Down Expand Up @@ -4989,6 +4989,10 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
TypeTree(BaseType::Pointer).Only(-1, &call), &call);
return;
}
if (funcName == "MPI_Start" || funcName == "PMPI_Start") {
// TODO
return;
}
if (funcName == "MPI_Wait") {
updateAnalysis(call.getOperand(0),
TypeTree(BaseType::Pointer).Only(-1, &call), &call);
Expand Down
39 changes: 36 additions & 3 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1771,6 +1771,9 @@ llvm::Function *getOrInsertDifferentialMPI_Wait(llvm::Module &M,
BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
BasicBlock *isend = BasicBlock::Create(M.getContext(), "invertISend", F);
BasicBlock *irecv = BasicBlock::Create(M.getContext(), "invertIRecv", F);
BasicBlock *send_init = BasicBlock::Create(M.getContext(), "invertSendInit", F);
BasicBlock *recv_init = BasicBlock::Create(M.getContext(), "invertRecvInit", F);
BasicBlock *error = BasicBlock::Create(M.getContext(), "invertError", F);

#if 0
/*0 */getInt8PtrTy(call.getContext())
Expand Down Expand Up @@ -1831,10 +1834,25 @@ llvm::Function *getOrInsertDifferentialMPI_Wait(llvm::Module &M,
buf, count, datatype, source, tag, comm, d_req,
};

B.CreateCondBr(B.CreateICmpEQ(fn, ConstantInt::get(fn->getType(),
(int)MPI_CallType::ISEND)),
isend, irecv);
auto *SI = B.CreateSwitch(fn, error, 4);

SI->addCase(
ConstantInt::get(fn->getType(), (int)MPI_CallType::ISEND),
isend);
SI->addCase(
ConstantInt::get(fn->getType(), (int)MPI_CallType::IRECV),
irecv);
SI->addCase(
ConstantInt::get(fn->getType(), (int)MPI_CallType::SEND_INIT),
send_init);
SI->addCase(
ConstantInt::get(fn->getType(), (int)MPI_CallType::RECV_INIT),
recv_init);

{
B.SetInsertPoint(error);
B.CreateUnreachable();
}
{
B.SetInsertPoint(isend);
auto fcall = B.CreateCall(irecvfn, args);
Expand All @@ -1848,6 +1866,21 @@ llvm::Function *getOrInsertDifferentialMPI_Wait(llvm::Module &M,
fcall->setCallingConv(isendfn->getCallingConv());
B.CreateRetVoid();
}

{
B.SetInsertPoint(send_init);
auto fcall = B.CreateCall(recv_initfn, args);
fcall->setCallingConv(isendfn->getCallingConv());
B.CreateRetVoid();
}

{
B.SetInsertPoint(recv_init);
auto fcall = B.CreateCall(send_initfn, args);
fcall->setCallingConv(isendfn->getCallingConv());
B.CreateRetVoid();
}

return F;
}

Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,8 @@ allInstructionsBetween(llvm::LoopInfo &LI, llvm::Instruction *inst1,
enum class MPI_CallType {
ISEND = 1,
IRECV = 2,
SEND_INIT = 3,
RECV_INIT = 4
};

enum class MPI_Elem {
Expand Down
Loading