Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions gematria/basic_block/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,11 +377,17 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction) {
return os;
}

BasicBlock::BasicBlock(std::vector<Instruction> instructions)
: instructions(std::move(instructions)) {}
BasicBlock::BasicBlock(std::vector<Instruction> instructions,
std::vector<Instruction> preceding_context,
std::vector<Instruction> following_context)
: instructions(std::move(instructions)),
preceding_context(std::move(preceding_context)),
following_context(std::move(following_context)) {}

bool BasicBlock::operator==(const BasicBlock& other) const {
return instructions == other.instructions;
return instructions == other.instructions &&
preceding_context == other.preceding_context &&
following_context == other.following_context;
}

std::string BasicBlock::ToString() const {
Expand All @@ -394,6 +400,26 @@ std::string BasicBlock::ToString() const {
}
if (buffer.back() == ' ') buffer.pop_back();
buffer += "))";
if (!preceding_context.empty()) buffer += ", ";
}
if (!preceding_context.empty()) {
buffer += "preceding_context=InstructionList((";
for (const Instruction& instruction : preceding_context) {
buffer += instruction.ToString();
buffer += ", ";
}
if (buffer.back() == ' ') buffer.pop_back();
buffer += "))";
if (!following_context.empty()) buffer += ", ";
}
if (!following_context.empty()) {
buffer += "following_context=InstructionList((";
for (const Instruction& instruction : following_context) {
buffer += instruction.ToString();
buffer += ", ";
}
if (buffer.back() == ' ') buffer.pop_back();
buffer += "))";
}
buffer.push_back(')');
return buffer;
Expand Down
16 changes: 12 additions & 4 deletions gematria/basic_block/basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand);
// Represents an annotation holding a value such as some measure/statistic
// paired with the instruction.
struct Annotation {
Annotation() : value(-1){};
Annotation() : value(-1) {};

// Initializes all fields of the annotation.
Annotation(std::string name, double value);
Expand Down Expand Up @@ -324,9 +324,12 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction);
struct BasicBlock {
BasicBlock() {}

// Initializes the basic block from a list of instructions. Needed for
// compatibility with the Python code.
explicit BasicBlock(std::vector<Instruction> instructions);
// Initializes the basic block from a list of instructions and optional
// context. Needed for compatibility with the Python code.
explicit BasicBlock(
std::vector<Instruction> instructions,
std::vector<Instruction> preceding_context = std::vector<Instruction>(),
std::vector<Instruction> following_context = std::vector<Instruction>());

BasicBlock(const BasicBlock&) = default;
BasicBlock(BasicBlock&&) = default;
Expand All @@ -346,6 +349,11 @@ struct BasicBlock {

// The list of instructions in the basic block.
std::vector<Instruction> instructions;

// The preceding and following context instructions, i.e. those preceeding and
// following the instructions in the basic block.
std::vector<Instruction> preceding_context;
std::vector<Instruction> following_context;
};

std::ostream& operator<<(std::ostream& os, const BasicBlock& block);
Expand Down
11 changes: 9 additions & 2 deletions gematria/basic_block/basic_block_protos.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,15 @@ CanonicalizedInstructionProto ProtoFromInstruction(

BasicBlock BasicBlockFromProto(const BasicBlockProto& proto) {
return BasicBlock(
/* instructions = */ ToVector<Instruction>(
proto.canonicalized_instructions(), InstructionFromProto));
/* instructions = */
ToVector<Instruction>(proto.canonicalized_instructions(),
InstructionFromProto),
/* preceding_context = */
ToVector<Instruction>(proto.canonicalized_preceding_context(),
InstructionFromProto),
/* following_context = */
ToVector<Instruction>(proto.canonicalized_following_context(),
InstructionFromProto));
}

} // namespace gematria
28 changes: 22 additions & 6 deletions gematria/basic_block/basic_block_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ TEST(InstructionOperandTest, Equality) {

TEST(InstructionOperandTest, ToString) {
const struct {
InstructionOperand opernad;
InstructionOperand operand;
const char* expected_string;
} kTestCases[] = {
{InstructionOperand::Register("RAX"),
Expand All @@ -292,7 +292,7 @@ TEST(InstructionOperandTest, ToString) {
"InstructionOperand.from_memory(32)"}};

for (const auto& test_case : kTestCases) {
EXPECT_EQ(test_case.opernad.ToString(), test_case.expected_string);
EXPECT_EQ(test_case.operand.ToString(), test_case.expected_string);
}
}

Expand All @@ -318,7 +318,6 @@ TEST(InstructionOperandTest, AsTokenList) {
}
}

// TODO(virajbshah): Add tests for Annotation.
TEST(AnnotationTest, Constructor) {
constexpr char kName[] = "cache_miss_freq";
constexpr double kValue = 0.875;
Expand Down Expand Up @@ -597,7 +596,7 @@ TEST(BasicBlockTest, ToString) {
/* instruction_annotations = */
{Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)});

BasicBlock block({instruction});
BasicBlock block({instruction}, {instruction}, {instruction});
constexpr char kExpectedString[] =
"BasicBlock(instructions=InstructionList((Instruction(mnemonic='ADC', "
"llvm_mnemonic='ADC32rr', prefixes=('LOCK',), "
Expand All @@ -607,8 +606,25 @@ TEST(BasicBlockTest, ToString) {
"output_operands=(InstructionOperand.from_register('RAX'),), "
"implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), "
"instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', "
"value=0.875),)),"
")))";
"value=0.875),)),)), "
"preceding_context=InstructionList((Instruction(mnemonic='ADC', "
"llvm_mnemonic='ADC32rr', prefixes=('LOCK',), "
"input_operands=(InstructionOperand.from_register('RAX'), "
"InstructionOperand.from_register('RBX'),), "
"implicit_input_operands=(InstructionOperand.from_register('EFLAGS'),), "
"output_operands=(InstructionOperand.from_register('RAX'),), "
"implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), "
"instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', "
"value=0.875),)),)), "
"following_context=InstructionList((Instruction(mnemonic='ADC', "
"llvm_mnemonic='ADC32rr', prefixes=('LOCK',), "
"input_operands=(InstructionOperand.from_register('RAX'), "
"InstructionOperand.from_register('RBX'),), "
"implicit_input_operands=(InstructionOperand.from_register('EFLAGS'),), "
"output_operands=(InstructionOperand.from_register('RAX'),), "
"implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), "
"instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', "
"value=0.875),)),)))";
EXPECT_EQ(block.ToString(), kExpectedString);
}

Expand Down
10 changes: 8 additions & 2 deletions gematria/basic_block/python/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,15 @@ PYBIND11_MODULE(basic_block, m) {

py::class_<BasicBlock> basic_block(m, "BasicBlock");
basic_block
.def(py::init<std::vector<Instruction> /* instructions */>(),
py::arg("instructions") = std::vector<Instruction>())
.def(py::init<std::vector<Instruction> /* instructions */,
std::vector<Instruction> /* preceding_context */,
std::vector<Instruction> /* following_context */>(),
py::arg("instructions") = std::vector<Instruction>(),
py::arg("preceding_context") = std::vector<Instruction>(),
py::arg("following_context") = std::vector<Instruction>())
.def_readwrite("instructions", &BasicBlock::instructions)
.def_readwrite("preceding_context", &BasicBlock::preceding_context)
.def_readwrite("following_context", &BasicBlock::following_context)
.def("__repr__", &BasicBlock::ToString)
.def("__str__", &BasicBlock::ToString)
.def("__eq__", &BasicBlock::operator==)
Expand Down
Loading
Loading