Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 143 additions & 32 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52829,6 +52829,91 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
return DAG.getSetCC(SDLoc(N), VT, Shift.getOperand(0), Ones, ISD::SETGT);
}

// Check whether this is a shuffle that interleaves the 128-bit lanes of the
// two input vectors. e.g. when interleaving two v8i32 into a single v16i32
// that mask is <0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23>.
// Indices are based on the target type.
static bool isLaneInterleaveMask(ArrayRef<int> Mask, MVT VT) {
  assert(VT.isVector() && "Expected vector VT.");

  unsigned NumElts = VT.getVectorNumElements();
  unsigned EltBits = VT.getScalarType().getSizeInBits();

  if (Mask.size() != NumElts)
    return false;

  // A lane is 128 bits; the element type must evenly divide it.
  if (EltBits == 0 || (128u % EltBits) != 0)
    return false;

  // Elements per 128-bit lane: 4 for i32, 8 for i16, etc. Each group in the
  // mask is one lane of A immediately followed by one lane of B.
  unsigned LaneElts = 128u / EltBits;
  unsigned GroupElts = 2 * LaneElts;

  if (NumElts % GroupElts != 0)
    return false;

  for (unsigned P = 0; P != NumElts; ++P) {
    unsigned Group = P / GroupElts;
    unsigned Within = P % GroupElts;
    // The first half of a group selects from A, the second from B. Operand B
    // indices are offset by NumElts since shuffle indices span both inputs.
    unsigned Expected = Group * LaneElts + (Within % LaneElts) +
                        (Within < LaneElts ? 0u : NumElts);
    if (Mask[P] != (int)Expected)
      return false;
  }

  return true;
}

// Check whether \p Shuf interleaves the 128-bit lanes of two half-width
// source vectors, e.g. a v16i32 shuffle with mask
// <0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23>.
// On success, \p A and \p B are set to the half-width sources (suitable as
// PACKSS/PACKUS operands) and true is returned.
static bool isLaneInterleaveShuffle(MVT VT, SDValue Shuf, SDValue &A,
                                    SDValue &B, const SelectionDAG &DAG,
                                    const X86Subtarget &Subtarget) {
  // For the _mm_pack{u|s}s variants, the shuffle is trivial and therefore
  // elided: the two sources appear directly as the operands of a 2-way
  // concat.
  if (VT == MVT::v16i16 || VT == MVT::v8i32) {
    if (Shuf.getOpcode() == ISD::CONCAT_VECTORS && Shuf.getNumOperands() == 2) {
      A = Shuf->getOperand(0);
      B = Shuf->getOperand(1);
      return true;
    }

    return false;
  }

  auto *SVN = dyn_cast<ShuffleVectorSDNode>(Shuf.getNode());
  if (!SVN)
    return false;

  if (!isLaneInterleaveMask(SVN->getMask(), VT))
    return false;

  // The interleave mask only ever reads the low half of each shuffle operand,
  // so each operand must be a widening 2-way concat whose first operand is
  // the real half-width source. If an operand is not such a concat we cannot
  // recover a half-width value without building new nodes, so fail the match
  // rather than hand back a full-width operand — a full-width A/B would
  // violate the PACKSS/PACKUS type profile (operands must be the same total
  // size as the result, with twice-as-wide elements).
  auto peelConcat = [](SDValue V) -> SDValue {
    if (V.getOpcode() == ISD::CONCAT_VECTORS && V.getNumOperands() == 2)
      return V.getOperand(0);
    return SDValue();
  };

  A = peelConcat(SVN->getOperand(0));
  B = peelConcat(SVN->getOperand(1));
  return A && B;
}

/// Detect patterns of truncation with unsigned saturation:
///
/// 1. (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
Expand Down Expand Up @@ -52973,42 +53058,68 @@ static SDValue combineTruncateWithSat(SDValue In, EVT VT, const SDLoc &DL,
Subtarget);
}

if (!(SVT == MVT::i32 || SVT == MVT::i16 || SVT == MVT::i8))
return SDValue();

unsigned TruncOpc = 0;
SDValue SatVal;
if (SDValue SSatVal = detectSSatPattern(In, VT)) {
SatVal = SSatVal;
TruncOpc = X86ISD::VTRUNCS;
} else if (SDValue USatVal = detectUSatPattern(In, VT, DAG, DL)) {
SatVal = USatVal;
TruncOpc = X86ISD::VTRUNCUS;
} else {
return SDValue();
}

unsigned ResElts = VT.getVectorNumElements();

bool IsEpi16 = (SVT == MVT::i8 && InSVT == MVT::i16);
bool IsEpi32 = (SVT == MVT::i16 && InSVT == MVT::i32);

// Is there an advantageous pack given the current types and features?
unsigned Width = VT.getSizeInBits();
bool HasPackForWidth =
(Width == 128 && Subtarget.hasSSE41()) ||
(Width == 256 && Subtarget.hasAVX2()) ||
(Width == 512 && Subtarget.hasBWI() && Subtarget.hasVLX());

const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (TLI.isTypeLegal(InVT) && InVT.isVector() && SVT != MVT::i1 &&
Subtarget.hasAVX512() && (InSVT != MVT::i16 || Subtarget.hasBWI()) &&
(SVT == MVT::i32 || SVT == MVT::i16 || SVT == MVT::i8)) {
unsigned TruncOpc = 0;
SDValue SatVal;
if (SDValue SSatVal = detectSSatPattern(In, VT)) {
SatVal = SSatVal;
TruncOpc = X86ISD::VTRUNCS;
} else if (SDValue USatVal = detectUSatPattern(In, VT, DAG, DL)) {
SatVal = USatVal;
TruncOpc = X86ISD::VTRUNCUS;
}
if (SatVal) {
unsigned ResElts = VT.getVectorNumElements();
// If the input type is less than 512 bits and we don't have VLX, we need
// to widen to 512 bits.
if (!Subtarget.hasVLX() && !InVT.is512BitVector()) {
unsigned NumConcats = 512 / InVT.getSizeInBits();
ResElts *= NumConcats;
SmallVector<SDValue, 4> ConcatOps(NumConcats, DAG.getUNDEF(InVT));
ConcatOps[0] = SatVal;
InVT = EVT::getVectorVT(*DAG.getContext(), InSVT,
NumConcats * InVT.getVectorNumElements());
SatVal = DAG.getNode(ISD::CONCAT_VECTORS, DL, InVT, ConcatOps);
}
// Widen the result if its narrower than 128 bits.
if (ResElts * SVT.getSizeInBits() < 128)
ResElts = 128 / SVT.getSizeInBits();
EVT TruncVT = EVT::getVectorVT(*DAG.getContext(), SVT, ResElts);
SDValue Res = DAG.getNode(TruncOpc, DL, TruncVT, SatVal);
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
DAG.getVectorIdxConstant(0, DL));
if (HasPackForWidth && (IsEpi16 || IsEpi32)) {
SDValue A, B;
if (isLaneInterleaveShuffle(InVT.getSimpleVT(), SatVal, A, B, DAG,
Subtarget)) {
unsigned PackOpc =
TruncOpc == X86ISD::VTRUNCS ? X86ISD::PACKSS : X86ISD::PACKUS;

return DAG.getNode(PackOpc, DL, VT, A, B);
}
}

if (TLI.isTypeLegal(InVT) && InVT.isVector() && SVT != MVT::i1 &&
Subtarget.hasAVX512() && (InSVT != MVT::i16 || Subtarget.hasBWI())) {

// If the input type is less than 512 bits and we don't have VLX, we
// need to widen to 512 bits.
if (!Subtarget.hasVLX() && !InVT.is512BitVector()) {
unsigned NumConcats = 512 / InVT.getSizeInBits();
ResElts *= NumConcats;
SmallVector<SDValue, 4> ConcatOps(NumConcats, DAG.getUNDEF(InVT));
ConcatOps[0] = SatVal;
InVT = EVT::getVectorVT(*DAG.getContext(), InSVT,
NumConcats * InVT.getVectorNumElements());
SatVal = DAG.getNode(ISD::CONCAT_VECTORS, DL, InVT, ConcatOps);
}
// Widen the result if its narrower than 128 bits.
if (ResElts * SVT.getSizeInBits() < 128)
ResElts = 128 / SVT.getSizeInBits();
EVT TruncVT = EVT::getVectorVT(*DAG.getContext(), SVT, ResElts);
SDValue Res = DAG.getNode(TruncOpc, DL, TruncVT, SatVal);
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
DAG.getVectorIdxConstant(0, DL));
}

return SDValue();
}

Expand Down
6 changes: 0 additions & 6 deletions llvm/test/CodeGen/X86/combine-sub-usat.ll
Original file line number Diff line number Diff line change
Expand Up @@ -251,18 +251,12 @@ define <8 x i16> @combine_trunc_v8i32_v8i16(<8 x i16> %a0, <8 x i32> %a1) {
;
; SSE41-LABEL: combine_trunc_v8i32_v8i16:
; SSE41: # %bb.0:
; SSE41-NEXT: pmovsxbw {{.*#+}} xmm3 = [65535,0,65535,0,65535,0,65535,0]
; SSE41-NEXT: pminud %xmm3, %xmm2
; SSE41-NEXT: pminud %xmm3, %xmm1
; SSE41-NEXT: packusdw %xmm2, %xmm1
; SSE41-NEXT: psubusw %xmm1, %xmm0
; SSE41-NEXT: retq
;
; SSE42-LABEL: combine_trunc_v8i32_v8i16:
; SSE42: # %bb.0:
; SSE42-NEXT: pmovsxbw {{.*#+}} xmm3 = [65535,0,65535,0,65535,0,65535,0]
; SSE42-NEXT: pminud %xmm3, %xmm2
; SSE42-NEXT: pminud %xmm3, %xmm1
; SSE42-NEXT: packusdw %xmm2, %xmm1
; SSE42-NEXT: psubusw %xmm1, %xmm0
; SSE42-NEXT: retq
Expand Down
Loading