Skip to content

Commit 12f9d4e

Browse files
committed
Added reduce_const_array_to_struct_pass to optimizer passes.
1 parent 4704f23 commit 12f9d4e

File tree

2 files changed

+306
-0
lines changed

2 files changed

+306
-0
lines changed
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
// Copyright (c) 2019 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "reduce_const_array_to_struct_pass.h"
16+
17+
#include <set>
18+
19+
#include "source/opt/instruction.h"
20+
#include "source/opt/ir_context.h"
21+
22+
namespace spvtools {
23+
namespace opt {
24+
25+
Pass::Status ReduceConstArrayToStructPass::Process() {
26+
bool modified = false;
27+
std::vector<Instruction*> ArrayInst;
28+
29+
context()->module()->ForEachInst([&ArrayInst](Instruction* inst) {
30+
if (inst->opcode() == spv::Op::OpTypeArray) {
31+
ArrayInst.push_back(inst);
32+
}
33+
});
34+
35+
for (Instruction* inst : ArrayInst) {
36+
modified |= ReduceArray(inst);
37+
}
38+
39+
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
40+
}
41+
42+
bool ReduceConstArrayToStructPass::ReduceArray(Instruction* inst) {
43+
44+
Instruction* arrayType = context()->get_def_use_mgr()->GetDef(inst->GetOperand(0).words[0]);
45+
Instruction* structType = nullptr;
46+
Instruction* decorateType = nullptr;
47+
48+
// Look for structs which use the array type
49+
context()->get_def_use_mgr()->ForEachUser(arrayType, [&structType, &decorateType, &arrayType](Instruction* user) {
50+
if (user->opcode() == spv::Op::OpTypeStruct) {
51+
// Only consider structs that contains a single array
52+
if(user->GetOperand(1).words[0] == arrayType->GetOperand(0).words[0] && user->NumOperands() == 2) {
53+
structType = user;
54+
}
55+
}
56+
57+
if (user->opcode() == spv::Op::OpDecorate) {
58+
if (spv::Decoration(user->GetOperand(1).words[0]) == spv::Decoration::ArrayStride && user->GetOperand(2).words[0] == 16) {
59+
decorateType = user;
60+
}
61+
}
62+
});
63+
64+
if (structType == nullptr || decorateType == nullptr)
65+
return false;
66+
67+
bool bIsGlobal = false;
68+
69+
// We ignore global structures
70+
context()->get_def_use_mgr()->ForEachUser(structType, [&bIsGlobal](Instruction* user) {
71+
if (user->opcode() == spv::Op::OpName) {
72+
if (!user->GetOperand(1).AsString().compare("type.$Globals")) {
73+
bIsGlobal = true;
74+
}
75+
}
76+
});
77+
78+
if (bIsGlobal) {
79+
return false;
80+
}
81+
82+
Instruction* pointerType = nullptr;
83+
Instruction* memberDecorateType = nullptr;
84+
Instruction* memberNameType = nullptr;
85+
86+
// Find the instructions related to the structure
87+
context()->get_def_use_mgr()->ForEachUser(structType, [&pointerType, &memberDecorateType, &memberNameType, &structType](Instruction* user) {
88+
if (user->opcode() == spv::Op::OpTypePointer) {
89+
if(user->GetOperand(2).words[0] == structType->GetOperand(0).words[0]) {
90+
pointerType = user;
91+
}
92+
} else if (user->opcode() == spv::Op::OpMemberDecorate) {
93+
if (user->GetOperand(0).words[0] == structType->GetOperand(0).words[0]) {
94+
memberDecorateType = user;
95+
}
96+
} else if (user->opcode() == spv::Op::OpMemberName) {
97+
if (user->GetOperand(0).words[0] == structType->GetOperand(0).words[0]) {
98+
memberNameType = user;
99+
}
100+
}
101+
});
102+
103+
if (pointerType == nullptr) {
104+
return false;
105+
}
106+
107+
Instruction* variableType = nullptr;
108+
context()->get_def_use_mgr()->ForEachUser(
109+
pointerType, [&variableType, &pointerType](Instruction* user) {
110+
if (user->opcode() == spv::Op::OpVariable) {
111+
if (user->GetOperand(0).words[0] == pointerType->GetOperand(0).words[0]) {
112+
variableType = user;
113+
}
114+
}
115+
});
116+
117+
if (variableType == nullptr) {
118+
return false;
119+
}
120+
121+
struct AccessChainData {
122+
uint32_t constantValue;
123+
uint32_t offset;
124+
Instruction* accessChain;
125+
};
126+
127+
std::vector<AccessChainData> accessChains;
128+
bool bInvalid = false;
129+
130+
// Check for const access and that usage of variable is only OpAccessChain
131+
context()->get_def_use_mgr()->ForEachUser(
132+
variableType, [&variableType, &accessChains, &bInvalid, this](Instruction* user) {
133+
if (user->opcode() == spv::Op::OpAccessChain) {
134+
if (user->GetOperand(2).words[0] == variableType->GetOperand(1).words[0]) {
135+
if (user->NumOperands() < 5) {
136+
bInvalid = true;
137+
} else {
138+
Operand constOperand = user->GetOperand(4);
139+
const Instruction* ConstInst = context()->get_def_use_mgr()->GetDef(constOperand.words[0]);
140+
if (ConstInst->opcode() != spv::Op::OpConstant) {
141+
bInvalid = true;
142+
} else {
143+
uint32_t ConstVal = ConstInst->GetOperand(2).words[0];
144+
accessChains.push_back({ConstVal, ConstVal * 4 * 4, user});
145+
}
146+
}
147+
}
148+
} else if (user->opcode() == spv::Op::OpInBoundsAccessChain ||
149+
user->opcode() == spv::Op::OpPtrAccessChain) {
150+
bInvalid = true;
151+
}
152+
153+
});
154+
155+
std::sort(accessChains.begin(), accessChains.end(),
156+
[](const AccessChainData& a, const AccessChainData& b) -> bool {
157+
return a.offset < b.offset;
158+
});
159+
160+
if (bInvalid) {
161+
return false;
162+
}
163+
164+
std::vector<Instruction*> newOpMemberNames;
165+
std::vector<Instruction*> newOpMemberDecorates;
166+
std::string structName = memberNameType->GetOperand(2).AsString();
167+
std::map<uint32_t, uint32_t> uniqueKeys;
168+
169+
uint32_t n = 0;
170+
for (auto & AccessChainData : accessChains) {
171+
172+
// Create the OpMemberName instructions
173+
if (uniqueKeys.find(AccessChainData.constantValue) == uniqueKeys.end()) {
174+
{
175+
std::vector<Operand> operands;
176+
operands.push_back({SPV_OPERAND_TYPE_ID, {structType->GetOperand(0).words[0]}});
177+
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {n}});
178+
179+
std::string MemberName = structName + "_" + std::to_string(AccessChainData.constantValue);
180+
auto MemberNameVector = utils::MakeVector(MemberName);
181+
operands.push_back({SPV_OPERAND_TYPE_LITERAL_STRING, std::move(MemberNameVector)});
182+
183+
Instruction* NewVar = new Instruction(context(), spv::Op::OpMemberName, 0, 0, operands);
184+
newOpMemberNames.push_back(NewVar);
185+
}
186+
187+
// Create the OpMemberDecorate instructions
188+
{
189+
std::vector<Operand> operands;
190+
operands.push_back({SPV_OPERAND_TYPE_ID, {structType->GetOperand(0).words[0]}});
191+
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {n}});
192+
operands.push_back({SPV_OPERAND_TYPE_DECORATION, {uint32_t(spv::Decoration::Offset)}});
193+
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {AccessChainData.offset}});
194+
195+
Instruction* NewVar = new Instruction(context(), spv::Op::OpMemberDecorate, 0, 0, operands);
196+
newOpMemberDecorates.push_back(NewVar);
197+
}
198+
uniqueKeys.insert({AccessChainData.constantValue, n});
199+
200+
n++;
201+
}
202+
203+
// Create the new Accesses to the struct
204+
{
205+
analysis::Integer unsigned_int_type(32, false);
206+
analysis::Type* registered_unsigned_int_type = context()->get_type_mgr()->GetRegisteredType(&unsigned_int_type);
207+
const analysis::Constant* NewConstant = context()->get_constant_mgr()->GetConstant(registered_unsigned_int_type, {uniqueKeys[AccessChainData.constantValue]});
208+
Instruction* ConstInst = context()->get_constant_mgr()->GetDefiningInstruction(NewConstant);
209+
210+
get_def_use_mgr()->AnalyzeInstDef(ConstInst);
211+
get_def_use_mgr()->AnalyzeInstUse(ConstInst);
212+
213+
std::vector<Operand> operands;
214+
215+
operands.push_back({SPV_OPERAND_TYPE_ID, {AccessChainData.accessChain->GetOperand(2).words[0]}});
216+
operands.push_back({SPV_OPERAND_TYPE_ID, {ConstInst->result_id()}});
217+
if(AccessChainData.accessChain->NumOperands() > 5) {
218+
operands.push_back({SPV_OPERAND_TYPE_ID, {AccessChainData.accessChain->GetOperand(5).words[0]}});
219+
}
220+
221+
Instruction* newVar = new Instruction(context(), spv::Op::OpAccessChain, AccessChainData.accessChain->GetOperand(0).words[0], AccessChainData.accessChain->result_id(), operands);
222+
223+
get_def_use_mgr()->AnalyzeInstDef(newVar);
224+
get_def_use_mgr()->AnalyzeInstUse(newVar);
225+
226+
newVar->InsertBefore(AccessChainData.accessChain);
227+
AccessChainData.accessChain->RemoveFromList();
228+
}
229+
}
230+
231+
for (Instruction* newMemberName : newOpMemberNames) {
232+
get_def_use_mgr()->AnalyzeInstDef(newMemberName);
233+
get_def_use_mgr()->AnalyzeInstUse(newMemberName);
234+
newMemberName->InsertBefore(memberNameType);
235+
}
236+
memberNameType->RemoveFromList();
237+
238+
for (Instruction* newMemberDecorate : newOpMemberDecorates) {
239+
240+
get_def_use_mgr()->AnalyzeInstDef(newMemberDecorate);
241+
get_def_use_mgr()->AnalyzeInstUse(newMemberDecorate);
242+
newMemberDecorate->InsertBefore(memberDecorateType);
243+
}
244+
memberDecorateType->RemoveFromList();
245+
246+
{
247+
std::vector<Operand> operands;
248+
for (uint32_t i = 0; i < uniqueKeys.size(); ++i) {
249+
operands.push_back(arrayType->GetOperand(1));
250+
}
251+
252+
// Create the new struct
253+
Instruction* newTypeStructVar = new Instruction(context(), spv::Op::OpTypeStruct, structType->GetOperand(0).words[0], 0, operands);
254+
255+
get_def_use_mgr()->AnalyzeInstDef(newTypeStructVar);
256+
get_def_use_mgr()->AnalyzeInstUse(newTypeStructVar);
257+
258+
newTypeStructVar->InsertBefore(structType);
259+
structType->RemoveFromList();
260+
}
261+
262+
return true;
263+
}
264+
265+
} // namespace opt
266+
} // namespace spvtools
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright (c) 2019 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef SOURCE_OPT_REDUCE_CONST_ARRAY_TO_STRUCT_PASS_
16+
#define SOURCE_OPT_REDUCE_CONST_ARRAY_TO_STRUCT_PASS_
17+
18+
#include <unordered_map>
19+
20+
#include "source/opt/ir_context.h"
21+
#include "source/opt/module.h"
22+
#include "source/opt/pass.h"
23+
24+
namespace spvtools {
25+
namespace opt {
26+
27+
// This pass attempts to reduce array with constant access to structs to minimize size of CPU to GPU transfer
28+
class ReduceConstArrayToStructPass : public Pass {
29+
public:
30+
const char* name() const override { return "reduce-const-array-to-struct"; }
31+
Status Process() override;
32+
33+
private:
34+
bool ReduceArray(Instruction* inst);
35+
};
36+
37+
} // namespace opt
38+
} // namespace spvtools
39+
40+
#endif // SOURCE_OPT_ANDROID_DRIVER_PATCH_

0 commit comments

Comments
 (0)