Skip to content

Commit

Permalink
Reland "[NVPTX] deprecate nvvm.rotate.* intrinsics, cleanup funnel-sh…
Browse files Browse the repository at this point in the history
…ift handling" (llvm#110025)

This change deprecates the following intrinsics which can be trivially
converted to llvm funnel-shift intrinsics:

- @llvm.nvvm.rotate.b32
- @llvm.nvvm.rotate.right.b64
- @llvm.nvvm.rotate.b64

This fixes a bug in the previous version (llvm#107655) which flipped the
order of the operands to the PTX funnel shift instruction. In LLVM IR
the high bits are the first arg and the low bits are the second arg,
while in PTX this is reversed.
  • Loading branch information
AlexMaclean authored and xgupta committed Oct 4, 2024
1 parent 74705ec commit 76d8f51
Show file tree
Hide file tree
Showing 9 changed files with 465 additions and 582 deletions.
16 changes: 0 additions & 16 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -4479,22 +4479,6 @@ def int_nvvm_sust_p_3d_v4i32_trap
"llvm.nvvm.sust.p.3d.v4i32.trap">,
ClangBuiltin<"__nvvm_sust_p_3d_v4i32_trap">;


def int_nvvm_rotate_b32
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty],
[IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b32">,
ClangBuiltin<"__nvvm_rotate_b32">;

def int_nvvm_rotate_b64
: DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty],
[IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b64">,
ClangBuiltin<"__nvvm_rotate_b64">;

def int_nvvm_rotate_right_b64
: DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty],
[IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.right.b64">,
ClangBuiltin<"__nvvm_rotate_right_b64">;

def int_nvvm_swap_lo_hi_b64
: DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty],
[IntrNoMem, IntrSpeculatable], "llvm.nvvm.swap.lo.hi.b64">,
Expand Down
184 changes: 106 additions & 78 deletions llvm/lib/IR/AutoUpgrade.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1272,6 +1272,9 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
// nvvm.bitcast.{f2i,i2f,ll2d,d2ll}
Expand =
Name == "f2i" || Name == "i2f" || Name == "ll2d" || Name == "d2ll";
else if (Name.consume_front("rotate."))
// nvvm.rotate.{b32,b64,right.b64}
Expand = Name == "b32" || Name == "b64" || Name == "right.b64";
else
Expand = false;

Expand Down Expand Up @@ -2258,6 +2261,108 @@ void llvm::UpgradeInlineAsmString(std::string *AsmStr) {
}
}

static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
Function *F, IRBuilder<> &Builder) {
Value *Rep = nullptr;

if (Name == "abs.i" || Name == "abs.ll") {
Value *Arg = CI->getArgOperand(0);
Value *Neg = Builder.CreateNeg(Arg, "neg");
Value *Cmp = Builder.CreateICmpSGE(
Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond");
Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs");
} else if (Name.starts_with("atomic.load.add.f32.p") ||
Name.starts_with("atomic.load.add.f64.p")) {
Value *Ptr = CI->getArgOperand(0);
Value *Val = CI->getArgOperand(1);
Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(),
AtomicOrdering::SequentiallyConsistent);
} else if (Name.consume_front("max.") &&
(Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
Name == "ui" || Name == "ull")) {
Value *Arg0 = CI->getArgOperand(0);
Value *Arg1 = CI->getArgOperand(1);
Value *Cmp = Name.starts_with("u")
? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond")
: Builder.CreateICmpSGE(Arg0, Arg1, "max.cond");
Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max");
} else if (Name.consume_front("min.") &&
(Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
Name == "ui" || Name == "ull")) {
Value *Arg0 = CI->getArgOperand(0);
Value *Arg1 = CI->getArgOperand(1);
Value *Cmp = Name.starts_with("u")
? Builder.CreateICmpULE(Arg0, Arg1, "min.cond")
: Builder.CreateICmpSLE(Arg0, Arg1, "min.cond");
Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min");
} else if (Name == "clz.ll") {
// llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64.
Value *Arg = CI->getArgOperand(0);
Value *Ctlz = Builder.CreateCall(
Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz,
{Arg->getType()}),
{Arg, Builder.getFalse()}, "ctlz");
Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc");
} else if (Name == "popc.ll") {
// llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an
// i64.
Value *Arg = CI->getArgOperand(0);
Value *Popc = Builder.CreateCall(
Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop,
{Arg->getType()}),
Arg, "ctpop");
Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc");
} else if (Name == "h2f") {
Rep = Builder.CreateCall(
Intrinsic::getDeclaration(F->getParent(), Intrinsic::convert_from_fp16,
{Builder.getFloatTy()}),
CI->getArgOperand(0), "h2f");
} else if (Name.consume_front("bitcast.") &&
(Name == "f2i" || Name == "i2f" || Name == "ll2d" ||
Name == "d2ll")) {
Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType());
} else if (Name == "rotate.b32") {
Value *Arg = CI->getOperand(0);
Value *ShiftAmt = CI->getOperand(1);
Rep = Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::fshl,
{Arg, Arg, ShiftAmt});
} else if (Name == "rotate.b64") {
Type *Int64Ty = Builder.getInt64Ty();
Value *Arg = CI->getOperand(0);
Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty);
Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshl,
{Arg, Arg, ZExtShiftAmt});
} else if (Name == "rotate.right.b64") {
Type *Int64Ty = Builder.getInt64Ty();
Value *Arg = CI->getOperand(0);
Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty);
Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr,
{Arg, Arg, ZExtShiftAmt});
} else {
Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name);
if (IID != Intrinsic::not_intrinsic &&
!F->getReturnType()->getScalarType()->isBFloatTy()) {
rename(F);
Function *NewFn = Intrinsic::getDeclaration(F->getParent(), IID);
SmallVector<Value *, 2> Args;
for (size_t I = 0; I < NewFn->arg_size(); ++I) {
Value *Arg = CI->getArgOperand(I);
Type *OldType = Arg->getType();
Type *NewType = NewFn->getArg(I)->getType();
Args.push_back(
(OldType->isIntegerTy() && NewType->getScalarType()->isBFloatTy())
? Builder.CreateBitCast(Arg, NewType)
: Arg);
}
Rep = Builder.CreateCall(NewFn, Args);
if (F->getReturnType()->isIntegerTy())
Rep = Builder.CreateBitCast(Rep, F->getReturnType());
}
}

return Rep;
}

static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F,
IRBuilder<> &Builder) {
LLVMContext &C = F->getContext();
Expand Down Expand Up @@ -4208,85 +4313,8 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {

if (!IsX86 && Name == "stackprotectorcheck") {
Rep = nullptr;
} else if (IsNVVM && (Name == "abs.i" || Name == "abs.ll")) {
Value *Arg = CI->getArgOperand(0);
Value *Neg = Builder.CreateNeg(Arg, "neg");
Value *Cmp = Builder.CreateICmpSGE(
Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond");
Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs");
} else if (IsNVVM && (Name.starts_with("atomic.load.add.f32.p") ||
Name.starts_with("atomic.load.add.f64.p"))) {
Value *Ptr = CI->getArgOperand(0);
Value *Val = CI->getArgOperand(1);
Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(),
AtomicOrdering::SequentiallyConsistent);
} else if (IsNVVM && Name.consume_front("max.") &&
(Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
Name == "ui" || Name == "ull")) {
Value *Arg0 = CI->getArgOperand(0);
Value *Arg1 = CI->getArgOperand(1);
Value *Cmp = Name.starts_with("u")
? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond")
: Builder.CreateICmpSGE(Arg0, Arg1, "max.cond");
Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max");
} else if (IsNVVM && Name.consume_front("min.") &&
(Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
Name == "ui" || Name == "ull")) {
Value *Arg0 = CI->getArgOperand(0);
Value *Arg1 = CI->getArgOperand(1);
Value *Cmp = Name.starts_with("u")
? Builder.CreateICmpULE(Arg0, Arg1, "min.cond")
: Builder.CreateICmpSLE(Arg0, Arg1, "min.cond");
Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min");
} else if (IsNVVM && Name == "clz.ll") {
// llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64.
Value *Arg = CI->getArgOperand(0);
Value *Ctlz = Builder.CreateCall(
Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz,
{Arg->getType()}),
{Arg, Builder.getFalse()}, "ctlz");
Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc");
} else if (IsNVVM && Name == "popc.ll") {
// llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an
// i64.
Value *Arg = CI->getArgOperand(0);
Value *Popc = Builder.CreateCall(
Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop,
{Arg->getType()}),
Arg, "ctpop");
Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc");
} else if (IsNVVM) {
if (Name == "h2f") {
Rep =
Builder.CreateCall(Intrinsic::getDeclaration(
F->getParent(), Intrinsic::convert_from_fp16,
{Builder.getFloatTy()}),
CI->getArgOperand(0), "h2f");
} else if (Name.consume_front("bitcast.") &&
(Name == "f2i" || Name == "i2f" || Name == "ll2d" ||
Name == "d2ll")) {
Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType());
} else {
Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name);
if (IID != Intrinsic::not_intrinsic &&
!F->getReturnType()->getScalarType()->isBFloatTy()) {
rename(F);
NewFn = Intrinsic::getDeclaration(F->getParent(), IID);
SmallVector<Value *, 2> Args;
for (size_t I = 0; I < NewFn->arg_size(); ++I) {
Value *Arg = CI->getArgOperand(I);
Type *OldType = Arg->getType();
Type *NewType = NewFn->getArg(I)->getType();
Args.push_back((OldType->isIntegerTy() &&
NewType->getScalarType()->isBFloatTy())
? Builder.CreateBitCast(Arg, NewType)
: Arg);
}
Rep = Builder.CreateCall(NewFn, Args);
if (F->getReturnType()->isIntegerTy())
Rep = Builder.CreateBitCast(Rep, F->getReturnType());
}
}
Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder);
} else if (IsX86) {
Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder);
} else if (IsARM) {
Expand Down
33 changes: 13 additions & 20 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,20 +594,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::BITREVERSE, MVT::i32, Legal);
setOperationAction(ISD::BITREVERSE, MVT::i64, Legal);

// TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs
// that don't have h/w rotation we lower them to multi-instruction assembly.
// See ROT*_sw in NVPTXIntrInfo.td
setOperationAction(ISD::ROTL, MVT::i64, Legal);
setOperationAction(ISD::ROTR, MVT::i64, Legal);
setOperationAction(ISD::ROTL, MVT::i32, Legal);
setOperationAction(ISD::ROTR, MVT::i32, Legal);

setOperationAction(ISD::ROTL, MVT::i16, Expand);
setOperationAction(ISD::ROTL, MVT::v2i16, Expand);
setOperationAction(ISD::ROTR, MVT::i16, Expand);
setOperationAction(ISD::ROTR, MVT::v2i16, Expand);
setOperationAction(ISD::ROTL, MVT::i8, Expand);
setOperationAction(ISD::ROTR, MVT::i8, Expand);
setOperationAction({ISD::ROTL, ISD::ROTR},
{MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
Expand);

if (STI.hasHWROT32())
setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal);

setOperationAction(ISD::BSWAP, MVT::i16, Expand);

setOperationAction(ISD::BR_JT, MVT::Other, Custom);
Expand Down Expand Up @@ -958,8 +951,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::LDUV4)
MAKE_CASE(NVPTXISD::StoreV2)
MAKE_CASE(NVPTXISD::StoreV4)
MAKE_CASE(NVPTXISD::FUN_SHFL_CLAMP)
MAKE_CASE(NVPTXISD::FUN_SHFR_CLAMP)
MAKE_CASE(NVPTXISD::FSHL_CLAMP)
MAKE_CASE(NVPTXISD::FSHR_CLAMP)
MAKE_CASE(NVPTXISD::IMAD)
MAKE_CASE(NVPTXISD::BFE)
MAKE_CASE(NVPTXISD::BFI)
Expand Down Expand Up @@ -2490,8 +2483,8 @@ SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
// dLo = shf.r.clamp aLo, aHi, Amt

SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
SDValue Lo = DAG.getNode(NVPTXISD::FUN_SHFR_CLAMP, dl, VT, ShOpLo, ShOpHi,
ShAmt);
SDValue Lo =
DAG.getNode(NVPTXISD::FSHR_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);

SDValue Ops[2] = { Lo, Hi };
return DAG.getMergeValues(Ops, dl);
Expand Down Expand Up @@ -2549,8 +2542,8 @@ SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
// dHi = shf.l.clamp aLo, aHi, Amt
// dLo = aLo << Amt

SDValue Hi = DAG.getNode(NVPTXISD::FUN_SHFL_CLAMP, dl, VT, ShOpLo, ShOpHi,
ShAmt);
SDValue Hi =
DAG.getNode(NVPTXISD::FSHL_CLAMP, dl, VT, ShOpHi, ShOpLo, ShAmt);
SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);

SDValue Ops[2] = { Lo, Hi };
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ enum NodeType : unsigned {
CallSeqEnd,
CallPrototype,
ProxyReg,
FUN_SHFL_CLAMP,
FUN_SHFR_CLAMP,
FSHL_CLAMP,
FSHR_CLAMP,
MUL_WIDE_SIGNED,
MUL_WIDE_UNSIGNED,
IMAD,
Expand Down
Loading

0 comments on commit 76d8f51

Please sign in to comment.