Skip to content

Commit

Permalink
Add support for reading OpMatrixTimesScalar
Browse files Browse the repository at this point in the history
Consequently, reading for OpTypeMatrix is added as well.

Signed-off-by: Qinglai Xiao <q.xiao@think-silicon.com>
  • Loading branch information
jigsawecho authored and svenvh committed Sep 11, 2019
1 parent 03389df commit 9ed3c9d
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 2 deletions.
28 changes: 28 additions & 0 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,9 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool IsClassMember) {
case OpTypeVector:
return mapType(T, VectorType::get(transType(T->getVectorComponentType()),
T->getVectorComponentCount()));
case OpTypeMatrix:
return mapType(T, ArrayType::get(transType(T->getMatrixColumnType()),
T->getMatrixColumnCount()));
case OpTypeOpaque:
return mapType(T, StructType::create(*Context, T->getName()));
case OpTypeFunction: {
Expand Down Expand Up @@ -511,6 +514,9 @@ std::string SPIRVToLLVM::transTypeToOCLTypeName(SPIRVType *T, bool IsSigned) {
case OpTypeVector:
return transTypeToOCLTypeName(T->getVectorComponentType()) +
T->getVectorComponentCount();
case OpTypeMatrix:
return transTypeToOCLTypeName(T->getMatrixColumnType()) +
T->getMatrixColumnCount();
case OpTypeOpaque:
return T->getName();
case OpTypeFunction:
Expand Down Expand Up @@ -1228,6 +1234,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
switch (BV->getType()->getOpCode()) {
case OpTypeVector:
return mapValue(BV, ConstantVector::get(CV));
case OpTypeMatrix:
case OpTypeArray:
return mapValue(
BV, ConstantArray::get(dyn_cast<ArrayType>(transType(BCC->getType())),
Expand Down Expand Up @@ -1552,6 +1559,27 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
return mapValue(BV, Scale);
}

case OpMatrixTimesScalar: {
auto MTS = static_cast<SPIRVMatrixTimesScalar *>(BV);
IRBuilder<> Builder(BB);
auto Scalar = transValue(MTS->getScalar(), F, BB);
auto Matrix = transValue(MTS->getMatrix(), F, BB);
uint64_t ColNum = Matrix->getType()->getArrayNumElements();
auto ColType = cast<ArrayType>(Matrix->getType())->getElementType();
auto VecSize = ColType->getVectorNumElements();
auto NewVec = Builder.CreateVectorSplat(VecSize, Scalar, Scalar->getName());
NewVec->takeName(Scalar);

Value *V = UndefValue::get(Matrix->getType());
for (uint64_t Idx = 0; Idx != ColNum; Idx++) {
auto Col = Builder.CreateExtractValue(Matrix, Idx);
auto I = Builder.CreateFMul(Col, NewVec);
V = Builder.CreateInsertValue(V, I, Idx);
}

return mapValue(BV, V);
}

case OpCopyObject: {
SPIRVCopyObject *CO = static_cast<SPIRVCopyObject *>(BV);
AllocaInst *AI =
Expand Down
2 changes: 0 additions & 2 deletions lib/SPIRV/libSPIRV/SPIRVEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,6 @@ template <spv::Op OC> bool isa(SPIRVEntry *E) {
#define _SPIRV_OP(x) typedef SPIRVEntryUnimplemented<Op##x> SPIRV##x;
_SPIRV_OP(Nop)
_SPIRV_OP(SourceContinued)
_SPIRV_OP(TypeMatrix)
_SPIRV_OP(TypeRuntimeArray)
_SPIRV_OP(SpecConstantTrue)
_SPIRV_OP(SpecConstantFalse)
Expand All @@ -788,7 +787,6 @@ _SPIRV_OP(QuantizeToF16)
_SPIRV_OP(Transpose)
_SPIRV_OP(ArrayLength)
_SPIRV_OP(SMod)
_SPIRV_OP(MatrixTimesScalar)
_SPIRV_OP(VectorTimesMatrix)
_SPIRV_OP(MatrixTimesVector)
_SPIRV_OP(MatrixTimesMatrix)
Expand Down
53 changes: 53 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,59 @@ class SPIRVVectorTimesScalar : public SPIRVInstruction {
SPIRVId Scalar;
};

class SPIRVMatrixTimesScalar : public SPIRVInstruction {
public:
static const Op OC = OpMatrixTimesScalar;
static const SPIRVWord FixedWordCount = 4;
// Complete constructor
SPIRVMatrixTimesScalar(SPIRVType *TheType, SPIRVId TheId, SPIRVId TheMatrix,
SPIRVId TheScalar, SPIRVBasicBlock *BB)
: SPIRVInstruction(5, OC, TheType, TheId, BB), Matrix(TheMatrix),
Scalar(TheScalar) {
validate();
assert(BB && "Invalid BB");
}
// Incomplete constructor
SPIRVMatrixTimesScalar()
: SPIRVInstruction(OC), Matrix(SPIRVID_INVALID), Scalar(SPIRVID_INVALID) {
}
SPIRVValue *getMatrix() const { return getValue(Matrix); }
SPIRVValue *getScalar() const { return getValue(Scalar); }

std::vector<SPIRVValue *> getOperands() override {
std::vector<SPIRVId> Operands;
Operands.push_back(Matrix);
Operands.push_back(Scalar);
return getValues(Operands);
}

void setWordCount(SPIRVWord FixedWordCount) override {
SPIRVEntry::setWordCount(FixedWordCount);
}

_SPIRV_DEF_ENCDEC4(Type, Id, Matrix, Scalar)

void validate() const override {
SPIRVInstruction::validate();
if (getValue(Matrix)->isForward() || getValue(Scalar)->isForward())
return;

SPIRVType *Ty = getType()->getScalarType();
SPIRVType *MTy = getValueType(Matrix)->getScalarType();
SPIRVType *STy = getValueType(Scalar);

assert(Ty->isTypeFloat() && "Invalid result type for OpMatrixTimesScalar");
assert(MTy->isTypeFloat() && "Invalid Matrix type for OpMatrixTimesScalar");
assert(STy->isTypeFloat() && "Invalid Scalar type for OpMatrixTimesScalar");

SPIRVInstruction::validate();
}

private:
SPIRVId Matrix;
SPIRVId Scalar;
};

class SPIRVUnary : public SPIRVInstTemplateBase {
protected:
void validate() const override {
Expand Down
12 changes: 12 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,10 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVId TheVector,
SPIRVId TheScalar,
SPIRVBasicBlock *BB) override;
SPIRVInstruction *addMatrixTimesScalarInst(SPIRVType *TheType,
SPIRVId TheMatrix,
SPIRVId TheScalar,
SPIRVBasicBlock *BB) override;
SPIRVInstruction *addUnaryInst(Op, SPIRVType *, SPIRVValue *,
SPIRVBasicBlock *) override;
SPIRVInstruction *addVariable(SPIRVType *, bool, SPIRVLinkageTypeKind,
Expand Down Expand Up @@ -1063,6 +1067,14 @@ SPIRVModuleImpl::addVectorTimesScalarInst(SPIRVType *TheType, SPIRVId TheVector,
new SPIRVVectorTimesScalar(TheType, getId(), TheVector, TheScalar, BB));
}

SPIRVInstruction *
SPIRVModuleImpl::addMatrixTimesScalarInst(SPIRVType *TheType, SPIRVId TheMatrix,
SPIRVId TheScalar,
SPIRVBasicBlock *BB) {
return BB->addInstruction(
new SPIRVMatrixTimesScalar(TheType, getId(), TheMatrix, TheScalar, BB));
}

SPIRVInstruction *
SPIRVModuleImpl::addGroupInst(Op OpCode, SPIRVType *Type, Scope Scope,
const std::vector<SPIRVValue *> &Ops,
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ class SPIRVModule {
SPIRVId TheVector,
SPIRVId TheScalar,
SPIRVBasicBlock *BB) = 0;
virtual SPIRVInstruction *addMatrixTimesScalarInst(SPIRVType *TheType,
SPIRVId TheMatrix,
SPIRVId TheScalar,
SPIRVBasicBlock *BB) = 0;
virtual SPIRVInstruction *addUnaryInst(Op, SPIRVType *, SPIRVValue *,
SPIRVBasicBlock *) = 0;
virtual SPIRVInstruction *addVariable(SPIRVType *, bool, SPIRVLinkageTypeKind,
Expand Down
30 changes: 30 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,36 @@ SPIRVType *SPIRVType::getVectorComponentType() const {
return static_cast<const SPIRVTypeVector *>(this)->getComponentType();
}

SPIRVWord SPIRVType::getMatrixColumnCount() const {
assert(OpCode == OpTypeMatrix && "Not matrix type");
return static_cast<const SPIRVTypeMatrix *const>(this)->getColumnCount();
}

SPIRVType *SPIRVType::getMatrixColumnType() const {
assert(OpCode == OpTypeMatrix && "Not matrix type");
return static_cast<const SPIRVTypeMatrix *const>(this)->getColumnType();
}

SPIRVType *SPIRVType::getScalarType() const {
switch (OpCode) {
case OpTypePointer:
return getPointerElementType()->getScalarType();
case OpTypeArray:
return getArrayElementType();
case OpTypeVector:
return getVectorComponentType();
case OpTypeMatrix:
return getMatrixColumnType()->getVectorComponentType();
case OpTypeInt:
case OpTypeFloat:
case OpTypeBool:
return const_cast<SPIRVType *>(this);
default:
break;
}
return nullptr;
}

bool SPIRVType::isTypeVoid() const { return OpCode == OpTypeVoid; }
bool SPIRVType::isTypeArray() const { return OpCode == OpTypeArray; }

Expand Down
45 changes: 45 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVType.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class SPIRVType : public SPIRVEntry {
SPIRVWord getStructMemberCount() const;
SPIRVWord getVectorComponentCount() const;
SPIRVType *getVectorComponentType() const;
SPIRVWord getMatrixColumnCount() const;
SPIRVType *getMatrixColumnType() const;
SPIRVType *getScalarType() const;

bool isTypeVoid() const;
bool isTypeArray() const;
Expand Down Expand Up @@ -310,6 +313,48 @@ class SPIRVTypeVector : public SPIRVType {
SPIRVWord CompCount; // Component Count
};

class SPIRVTypeMatrix : public SPIRVType {
public:
// Complete constructor
SPIRVTypeMatrix(SPIRVModule *M, SPIRVId TheId, SPIRVType *TheColType,
SPIRVWord TheColCount)
: SPIRVType(M, 4, OpTypeMatrix, TheId), ColType(TheColType),
ColCount(TheColCount) {
validate();
}
// Incomplete constructor
SPIRVTypeMatrix() : SPIRVType(OpTypeMatrix), ColType(nullptr), ColCount(0) {}

SPIRVType *getColumnType() const { return ColType; }
SPIRVWord getColumnCount() const { return ColCount; }

bool isValidIndex(SPIRVWord Index) const { return Index < ColCount; }

SPIRVCapVec getRequiredCapability() const override {
SPIRVCapVec V(getColumnType()->getRequiredCapability());
if (ColCount >= 8)
V.push_back(CapabilityVector16);
return V;
}

virtual std::vector<SPIRVEntry *> getNonLiteralOperands() const override {
return std::vector<SPIRVEntry *>(1, ColType);
}

void validate() const override {
SPIRVEntry::validate();
ColType->validate();
assert(ColCount >= 2);
}

protected:
_SPIRV_DEF_ENCDEC3(Id, ColType, ColCount)

private:
SPIRVType *ColType; // Column Type
SPIRVWord ColCount; // Column Count
};

class SPIRVConstant;
class SPIRVTypeArray : public SPIRVType {
public:
Expand Down
60 changes: 60 additions & 0 deletions test/matrix_times_scalar.spt
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
119734787 65536 458752 21 0
2 Capability Addresses
2 Capability Linkage
2 Capability Kernel
2 Capability Float64
2 Capability Matrix
3 MemoryModel 2 2
8 EntryPoint 6 20 "matrix_times_scalar"
3 Source 3 102000
3 Name 12 "res"
3 Name 13 "lhs"
3 Name 14 "rhs"

2 TypeVoid 5
3 TypeFloat 6 32
4 TypeVector 7 6 4
4 TypeMatrix 8 7 4
4 TypePointer 9 7 8
4 TypePointer 10 7 6
6 TypeFunction 11 5 9 9 10

5 Function 5 20 0 11
3 FunctionParameter 9 12
3 FunctionParameter 9 13
3 FunctionParameter 10 14

2 Label 15
4 Load 8 16 13
4 Load 6 17 14
5 MatrixTimesScalar 8 18 16 17
3 Store 12 18
1 Return

1 FunctionEnd

; FIXME: LIT comments/commands are moved at the end because llvm-spirv stops
; reading the file after first ';' symbol

; RUN: llvm-spirv %s -to-binary -o %t.spv
; RUN: spirv-val %t.spv
; RUN: llvm-spirv -r %t.spv -o %t.bc
; RUN: llvm-dis < %t.bc | FileCheck %s --check-prefix=CHECK-LLVM

; CHECK-LLVM: %1 = load [4 x <4 x float>], [4 x <4 x float>]* %lhs
; CHECK-LLVM: %2 = load float, float* %rhs
; CHECK-LLVM: %.splatinsert = insertelement <4 x float> undef, float %2, i32 0
; CHECK-LLVM: %3 = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
; CHECK-LLVM: %4 = extractvalue [4 x <4 x float>] %1, 0
; CHECK-LLVM: %5 = fmul <4 x float> %4, %3
; CHECK-LLVM: %6 = insertvalue [4 x <4 x float>] undef, <4 x float> %5, 0
; CHECK-LLVM: %7 = extractvalue [4 x <4 x float>] %1, 1
; CHECK-LLVM: %8 = fmul <4 x float> %7, %3
; CHECK-LLVM: %9 = insertvalue [4 x <4 x float>] %6, <4 x float> %8, 1
; CHECK-LLVM: %10 = extractvalue [4 x <4 x float>] %1, 2
; CHECK-LLVM: %11 = fmul <4 x float> %10, %3
; CHECK-LLVM: %12 = insertvalue [4 x <4 x float>] %9, <4 x float> %11, 2
; CHECK-LLVM: %13 = extractvalue [4 x <4 x float>] %1, 3
; CHECK-LLVM: %14 = fmul <4 x float> %13, %3
; CHECK-LLVM: %15 = insertvalue [4 x <4 x float>] %12, <4 x float> %14, 3
; CHECK-LLVM: store [4 x <4 x float>] %15, [4 x <4 x float>]* %res

0 comments on commit 9ed3c9d

Please sign in to comment.