diff --git a/llvm-spirv/lib/SPIRV/SPIRVInternal.h b/llvm-spirv/lib/SPIRV/SPIRVInternal.h index e752951794319..754b7f670fd48 100644 --- a/llvm-spirv/lib/SPIRV/SPIRVInternal.h +++ b/llvm-spirv/lib/SPIRV/SPIRVInternal.h @@ -60,6 +60,7 @@ using namespace llvm; namespace llvm { class IntrinsicInst; +class IRBuilderBase; } namespace SPIRV { @@ -551,6 +552,10 @@ std::string mapLLVMTypeToOCLType(const Type *Ty, bool Signed, Type *PointerElementType = nullptr); SPIRVDecorate *mapPostfixToDecorate(StringRef Postfix, SPIRVEntry *Target); +/// Return vector V extended with poison elements to match the number of +/// components of NewType. +Value *extendVector(Value *V, FixedVectorType *NewType, IRBuilderBase &Builder); + /// Add decorations to a SPIR-V entry. /// \param Decs Each string is a postfix without _ at the beginning. SPIRVValue *addDecorations(SPIRVValue *Target, diff --git a/llvm-spirv/lib/SPIRV/SPIRVReader.cpp b/llvm-spirv/lib/SPIRV/SPIRVReader.cpp index f674abd7b41a6..cac125acdb3fb 100644 --- a/llvm-spirv/lib/SPIRV/SPIRVReader.cpp +++ b/llvm-spirv/lib/SPIRV/SPIRVReader.cpp @@ -2309,10 +2309,37 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F, if (BB) { Builder.SetInsertPoint(BB); } - return mapValue(BV, Builder.CreateShuffleVector( - transValue(VS->getVector1(), F, BB), - transValue(VS->getVector2(), F, BB), - ConstantVector::get(Components), BV->getName())); + Value *Vec1 = transValue(VS->getVector1(), F, BB); + Value *Vec2 = transValue(VS->getVector2(), F, BB); + auto *Vec1Ty = cast(Vec1->getType()); + auto *Vec2Ty = cast(Vec2->getType()); + if (Vec1Ty->getNumElements() != Vec2Ty->getNumElements()) { + // LLVM's shufflevector requires that the two vector operands have the + // same type; SPIR-V's OpVectorShuffle allows the vector operands to + // differ in the number of components. Adjust for that by extending + // the smaller vector. + if (Vec1Ty->getNumElements() < Vec2Ty->getNumElements()) { + Vec1 = extendVector(Vec1, Vec2Ty, Builder); + // Extending Vec1 requires offsetting any Vec2 indices in Components by + // the number of new elements. + unsigned Offset = Vec2Ty->getNumElements() - Vec1Ty->getNumElements(); + unsigned Vec2Start = Vec1Ty->getNumElements(); + for (auto &C : Components) { + if (auto *CI = dyn_cast(C)) { + uint64_t V = CI->getZExtValue(); + if (V >= Vec2Start) { + // This is a Vec2 index; add the offset to it. + C = ConstantInt::get(Int32Ty, V + Offset); + } + } + } + } else { + Vec2 = extendVector(Vec2, Vec1Ty, Builder); + } + } + return mapValue( + BV, Builder.CreateShuffleVector( + Vec1, Vec2, ConstantVector::get(Components), BV->getName())); } case OpBitReverse: { diff --git a/llvm-spirv/lib/SPIRV/SPIRVUtil.cpp b/llvm-spirv/lib/SPIRV/SPIRVUtil.cpp index 6e2a0cae19e76..2deb589d7a4f7 100644 --- a/llvm-spirv/lib/SPIRV/SPIRVUtil.cpp +++ b/llvm-spirv/lib/SPIRV/SPIRVUtil.cpp @@ -94,6 +94,24 @@ void removeFnAttr(CallInst *Call, Attribute::AttrKind Attr) { Call->removeFnAttr(Attr); } +Value *extendVector(Value *V, FixedVectorType *NewType, + IRBuilderBase &Builder) { + unsigned OldSize = cast(V->getType())->getNumElements(); + unsigned NewSize = NewType->getNumElements(); + assert(OldSize < NewSize); + std::vector Components; + IntegerType *Int32Ty = Builder.getInt32Ty(); + for (unsigned I = 0; I < NewSize; I++) { + if (I < OldSize) + Components.push_back(ConstantInt::get(Int32Ty, I)); + else + Components.push_back(PoisonValue::get(Int32Ty)); + } + + return Builder.CreateShuffleVector(V, PoisonValue::get(V->getType()), + ConstantVector::get(Components), "vecext"); +} + void saveLLVMModule(Module *M, const std::string &OutputFile) { std::error_code EC; ToolOutputFile Out(OutputFile.c_str(), EC, sys::fs::OF_None); diff --git a/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h index 733b77ac08d8b..3a2fec68606ae 100644 --- a/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -2213,15 +2213,11 @@ class SPIRVVectorShuffleBase : public SPIRVInstTemplateBase { protected: void validate() const override { SPIRVInstruction::validate(); - SPIRVId Vector1 = Ops[0]; - SPIRVId Vector2 = Ops[1]; + [[maybe_unused]] SPIRVId Vector1 = Ops[0]; assert(OpCode == OpVectorShuffle); assert(Type->isTypeVector()); assert(Type->getVectorComponentType() == getValueType(Vector1)->getVectorComponentType()); - if (getValue(Vector1)->isForward() || getValue(Vector2)->isForward()) - return; - assert(getValueType(Vector1) == getValueType(Vector2)); assert(Ops.size() - 2 == Type->getVectorComponentCount()); } }; diff --git a/llvm-spirv/test/OpVectorShuffle.spvasm b/llvm-spirv/test/OpVectorShuffle.spvasm new file mode 100644 index 0000000000000..b648595396e0d --- /dev/null +++ b/llvm-spirv/test/OpVectorShuffle.spvasm @@ -0,0 +1,36 @@ +; REQUIRES: spirv-as +; RUN: spirv-as --target-env spv1.0 -o %t.spv %s +; RUN: spirv-val %t.spv +; RUN: llvm-spirv -r -o - %t.spv | llvm-dis | FileCheck %s + OpCapability Addresses + OpCapability Kernel + OpMemoryModel Physical32 OpenCL + OpEntryPoint Kernel %1 "testVecShuffle" + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %uintv2 = OpTypeVector %uint 2 + %uintv3 = OpTypeVector %uint 3 + %uintv4 = OpTypeVector %uint 4 + %func = OpTypeFunction %void %uintv2 %uintv3 + + %1 = OpFunction %void None %func + %pv2 = OpFunctionParameter %uintv2 + %pv3 = OpFunctionParameter %uintv3 + %entry = OpLabel + + ; Same vector lengths + %vs1 = OpVectorShuffle %uintv4 %pv3 %pv3 0 1 3 5 +; CHECK: shufflevector <3 x i32> %[[#]], <3 x i32> %[[#]], <4 x i32> + + ; vec1 smaller than vec2 + %vs2 = OpVectorShuffle %uintv4 %pv2 %pv3 0 1 3 4 +; CHECK: %[[VS2EXT:[0-9a-z]+]] = shufflevector <2 x i32> %0, <2 x i32> poison, <3 x i32> +; CHECK: shufflevector <3 x i32> %[[VS2EXT]], <3 x i32> %[[#]], <4 x i32> + + ; vec1 larger than vec2 + %vs3 = OpVectorShuffle %uintv4 %pv3 %pv2 0 1 3 4 +; CHECK: %[[VS3EXT:[0-9a-z]+]] = shufflevector <2 x i32> %0, <2 x i32> poison, <3 x i32> +; CHECK: shufflevector <3 x i32> %[[#]], <3 x i32> %[[VS3EXT]], <4 x i32> + + OpReturn + OpFunctionEnd