Skip to content

Commit

Permalink
Added reduce_const_array_to_struct_pass to optimizer passes.
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasBanana committed Mar 1, 2024
1 parent b6e5ce5 commit 656d7d8
Showing 2 changed files with 306 additions and 0 deletions.
266 changes: 266 additions & 0 deletions source/opt/reduce_const_array_to_struct_pass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "reduce_const_array_to_struct_pass.h"

#include <set>

#include "source/opt/instruction.h"
#include "source/opt/ir_context.h"

namespace spvtools {
namespace opt {

Pass::Status ReduceConstArrayToStructPass::Process() {
bool modified = false;
std::vector<Instruction*> ArrayInst;

context()->module()->ForEachInst([&ArrayInst](Instruction* inst) {
if (inst->opcode() == spv::Op::OpTypeArray) {
ArrayInst.push_back(inst);
}
});

for (Instruction* inst : ArrayInst) {
modified |= ReduceArray(inst);
}

return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}

bool ReduceConstArrayToStructPass::ReduceArray(Instruction* inst) {

Instruction* arrayType = context()->get_def_use_mgr()->GetDef(inst->GetOperand(0).words[0]);
Instruction* structType = nullptr;
Instruction* decorateType = nullptr;

// Look for structs which use the array type
context()->get_def_use_mgr()->ForEachUser(arrayType, [&structType, &decorateType, &arrayType](Instruction* user) {
if (user->opcode() == spv::Op::OpTypeStruct) {
// Only consider structs that contains a single array
if(user->GetOperand(1).words[0] == arrayType->GetOperand(0).words[0] && user->NumOperands() == 2) {
structType = user;
}
}

if (user->opcode() == spv::Op::OpDecorate) {
if (spv::Decoration(user->GetOperand(1).words[0]) == spv::Decoration::ArrayStride && user->GetOperand(2).words[0] == 16) {
decorateType = user;
}
}
});

if (structType == nullptr || decorateType == nullptr)
return false;

bool bIsGlobal = false;

// We ignore global structures
context()->get_def_use_mgr()->ForEachUser(structType, [&bIsGlobal](Instruction* user) {
if (user->opcode() == spv::Op::OpName) {
if (!user->GetOperand(1).AsString().compare("type.$Globals")) {
bIsGlobal = true;
}
}
});

if (bIsGlobal) {
return false;
}

Instruction* pointerType = nullptr;
Instruction* memberDecorateType = nullptr;
Instruction* memberNameType = nullptr;

// Find the instructions related to the structure
context()->get_def_use_mgr()->ForEachUser(structType, [&pointerType, &memberDecorateType, &memberNameType, &structType](Instruction* user) {
if (user->opcode() == spv::Op::OpTypePointer) {
if(user->GetOperand(2).words[0] == structType->GetOperand(0).words[0]) {
pointerType = user;
}
} else if (user->opcode() == spv::Op::OpMemberDecorate) {
if (user->GetOperand(0).words[0] == structType->GetOperand(0).words[0]) {
memberDecorateType = user;
}
} else if (user->opcode() == spv::Op::OpMemberName) {
if (user->GetOperand(0).words[0] == structType->GetOperand(0).words[0]) {
memberNameType = user;
}
}
});

if (pointerType == nullptr) {
return false;
}

Instruction* variableType = nullptr;
context()->get_def_use_mgr()->ForEachUser(
pointerType, [&variableType, &pointerType](Instruction* user) {
if (user->opcode() == spv::Op::OpVariable) {
if (user->GetOperand(0).words[0] == pointerType->GetOperand(0).words[0]) {
variableType = user;
}
}
});

if (variableType == nullptr) {
return false;
}

struct AccessChainData {
uint32_t constantValue;
uint32_t offset;
Instruction* accessChain;
};

std::vector<AccessChainData> accessChains;
bool bInvalid = false;

// Check for const access and that usage of variable is only OpAccessChain
context()->get_def_use_mgr()->ForEachUser(
variableType, [&variableType, &accessChains, &bInvalid, this](Instruction* user) {
if (user->opcode() == spv::Op::OpAccessChain) {
if (user->GetOperand(2).words[0] == variableType->GetOperand(1).words[0]) {
if (user->NumOperands() < 5) {
bInvalid = true;
} else {
Operand constOperand = user->GetOperand(4);
const Instruction* ConstInst = context()->get_def_use_mgr()->GetDef(constOperand.words[0]);
if (ConstInst->opcode() != spv::Op::OpConstant) {
bInvalid = true;
} else {
uint32_t ConstVal = ConstInst->GetOperand(2).words[0];
accessChains.push_back({ConstVal, ConstVal * 4 * 4, user});
}
}
}
} else if (user->opcode() == spv::Op::OpInBoundsAccessChain ||
user->opcode() == spv::Op::OpPtrAccessChain) {
bInvalid = true;
}

});

std::sort(accessChains.begin(), accessChains.end(),
[](const AccessChainData& a, const AccessChainData& b) -> bool {
return a.offset < b.offset;
});

if (bInvalid) {
return false;
}

std::vector<Instruction*> newOpMemberNames;
std::vector<Instruction*> newOpMemberDecorates;
std::string structName = memberNameType->GetOperand(2).AsString();
std::map<uint32_t, uint32_t> uniqueKeys;

uint32_t n = 0;
for (auto & AccessChainData : accessChains) {

// Create the OpMemberName instructions
if (uniqueKeys.find(AccessChainData.constantValue) == uniqueKeys.end()) {
{
std::vector<Operand> operands;
operands.push_back({SPV_OPERAND_TYPE_ID, {structType->GetOperand(0).words[0]}});
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {n}});

std::string MemberName = structName + "_" + std::to_string(AccessChainData.constantValue);
auto MemberNameVector = utils::MakeVector(MemberName);
operands.push_back({SPV_OPERAND_TYPE_LITERAL_STRING, std::move(MemberNameVector)});

Instruction* NewVar = new Instruction(context(), spv::Op::OpMemberName, 0, 0, operands);
newOpMemberNames.push_back(NewVar);
}

// Create the OpMemberDecorate instructions
{
std::vector<Operand> operands;
operands.push_back({SPV_OPERAND_TYPE_ID, {structType->GetOperand(0).words[0]}});
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {n}});
operands.push_back({SPV_OPERAND_TYPE_DECORATION, {uint32_t(spv::Decoration::Offset)}});
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {AccessChainData.offset}});

Instruction* NewVar = new Instruction(context(), spv::Op::OpMemberDecorate, 0, 0, operands);
newOpMemberDecorates.push_back(NewVar);
}
uniqueKeys.insert({AccessChainData.constantValue, n});

n++;
}

// Create the new Accesses to the struct
{
analysis::Integer unsigned_int_type(32, false);
analysis::Type* registered_unsigned_int_type = context()->get_type_mgr()->GetRegisteredType(&unsigned_int_type);
const analysis::Constant* NewConstant = context()->get_constant_mgr()->GetConstant(registered_unsigned_int_type, {uniqueKeys[AccessChainData.constantValue]});
Instruction* ConstInst = context()->get_constant_mgr()->GetDefiningInstruction(NewConstant);

get_def_use_mgr()->AnalyzeInstDef(ConstInst);
get_def_use_mgr()->AnalyzeInstUse(ConstInst);

std::vector<Operand> operands;

operands.push_back({SPV_OPERAND_TYPE_ID, {AccessChainData.accessChain->GetOperand(2).words[0]}});
operands.push_back({SPV_OPERAND_TYPE_ID, {ConstInst->result_id()}});
if(AccessChainData.accessChain->NumOperands() > 5) {
operands.push_back({SPV_OPERAND_TYPE_ID, {AccessChainData.accessChain->GetOperand(5).words[0]}});
}

Instruction* newVar = new Instruction(context(), spv::Op::OpAccessChain, AccessChainData.accessChain->GetOperand(0).words[0], AccessChainData.accessChain->result_id(), operands);

get_def_use_mgr()->AnalyzeInstDef(newVar);
get_def_use_mgr()->AnalyzeInstUse(newVar);

newVar->InsertBefore(AccessChainData.accessChain);
AccessChainData.accessChain->RemoveFromList();
}
}

for (Instruction* newMemberName : newOpMemberNames) {
get_def_use_mgr()->AnalyzeInstDef(newMemberName);
get_def_use_mgr()->AnalyzeInstUse(newMemberName);
newMemberName->InsertBefore(memberNameType);
}
memberNameType->RemoveFromList();

for (Instruction* newMemberDecorate : newOpMemberDecorates) {

get_def_use_mgr()->AnalyzeInstDef(newMemberDecorate);
get_def_use_mgr()->AnalyzeInstUse(newMemberDecorate);
newMemberDecorate->InsertBefore(memberDecorateType);
}
memberDecorateType->RemoveFromList();

{
std::vector<Operand> operands;
for (uint32_t i = 0; i < uniqueKeys.size(); ++i) {
operands.push_back(arrayType->GetOperand(1));
}

// Create the new struct
Instruction* newTypeStructVar = new Instruction(context(), spv::Op::OpTypeStruct, structType->GetOperand(0).words[0], 0, operands);

get_def_use_mgr()->AnalyzeInstDef(newTypeStructVar);
get_def_use_mgr()->AnalyzeInstUse(newTypeStructVar);

newTypeStructVar->InsertBefore(structType);
structType->RemoveFromList();
}

return true;
}

} // namespace opt
} // namespace spvtools
40 changes: 40 additions & 0 deletions source/opt/reduce_const_array_to_struct_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SOURCE_OPT_REDUCE_CONST_ARRAY_TO_STRUCT_PASS_
#define SOURCE_OPT_REDUCE_CONST_ARRAY_TO_STRUCT_PASS_

#include <unordered_map>

#include "source/opt/ir_context.h"
#include "source/opt/module.h"
#include "source/opt/pass.h"

namespace spvtools {
namespace opt {

// This pass attempts to reduce array with constant access to structs to minimize size of CPU to GPU transfer
class ReduceConstArrayToStructPass : public Pass {
public:
const char* name() const override { return "reduce-const-array-to-struct"; }
Status Process() override;

private:
bool ReduceArray(Instruction* inst);
};

} // namespace opt
} // namespace spvtools

#endif // SOURCE_OPT_ANDROID_DRIVER_PATCH_

0 comments on commit 656d7d8

Please sign in to comment.