diff --git a/gematria/basic_block/basic_block.cc b/gematria/basic_block/basic_block.cc index 466cab16..55233a9c 100644 --- a/gematria/basic_block/basic_block.cc +++ b/gematria/basic_block/basic_block.cc @@ -377,11 +377,17 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction) { return os; } -BasicBlock::BasicBlock(std::vector instructions) - : instructions(std::move(instructions)) {} +BasicBlock::BasicBlock(std::vector instructions, + std::vector preceding_context, + std::vector following_context) + : instructions(std::move(instructions)), + preceding_context(std::move(preceding_context)), + following_context(std::move(following_context)) {} bool BasicBlock::operator==(const BasicBlock& other) const { - return instructions == other.instructions; + return instructions == other.instructions && + preceding_context == other.preceding_context && + following_context == other.following_context; } std::string BasicBlock::ToString() const { @@ -394,6 +400,26 @@ std::string BasicBlock::ToString() const { } if (buffer.back() == ' ') buffer.pop_back(); buffer += "))"; + if (!preceding_context.empty()) buffer += ", "; + } + if (!preceding_context.empty()) { + buffer += "preceding_context=InstructionList(("; + for (const Instruction& instruction : preceding_context) { + buffer += instruction.ToString(); + buffer += ", "; + } + if (buffer.back() == ' ') buffer.pop_back(); + buffer += "))"; + if (!following_context.empty()) buffer += ", "; + } + if (!following_context.empty()) { + buffer += "following_context=InstructionList(("; + for (const Instruction& instruction : following_context) { + buffer += instruction.ToString(); + buffer += ", "; + } + if (buffer.back() == ' ') buffer.pop_back(); + buffer += "))"; } buffer.push_back(')'); return buffer; diff --git a/gematria/basic_block/basic_block.h b/gematria/basic_block/basic_block.h index 43579fb7..6a8c01bf 100644 --- a/gematria/basic_block/basic_block.h +++ b/gematria/basic_block/basic_block.h @@ -221,7 +221,7 @@ std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand); // Represents an annotation holding a value such as some measure/statistic // paired with the instruction. struct Annotation { - Annotation() : value(-1){}; + Annotation() : value(-1) {}; // Initializes all fields of the annotation. Annotation(std::string name, double value); @@ -324,9 +324,12 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction); struct BasicBlock { BasicBlock() {} - // Initializes the basic block from a list of instructions. Needed for - // compatibility with the Python code. - explicit BasicBlock(std::vector instructions); + // Initializes the basic block from a list of instructions and optional + // context. Needed for compatibility with the Python code. + explicit BasicBlock( + std::vector instructions, + std::vector preceding_context = std::vector(), + std::vector following_context = std::vector()); BasicBlock(const BasicBlock&) = default; BasicBlock(BasicBlock&&) = default; @@ -346,6 +349,11 @@ struct BasicBlock { // The list of instructions in the basic block. std::vector instructions; + + // The preceding and following context instructions, i.e. those preceeding and + // following the instructions in the basic block. + std::vector preceding_context; + std::vector following_context; }; std::ostream& operator<<(std::ostream& os, const BasicBlock& block); diff --git a/gematria/basic_block/basic_block_protos.cc b/gematria/basic_block/basic_block_protos.cc index 7b8c0a0c..b6e441ce 100644 --- a/gematria/basic_block/basic_block_protos.cc +++ b/gematria/basic_block/basic_block_protos.cc @@ -180,8 +180,15 @@ CanonicalizedInstructionProto ProtoFromInstruction( BasicBlock BasicBlockFromProto(const BasicBlockProto& proto) { return BasicBlock( - /* instructions = */ ToVector( - proto.canonicalized_instructions(), InstructionFromProto)); + /* instructions = */ + ToVector(proto.canonicalized_instructions(), + InstructionFromProto), + /* preceding_context = */ + ToVector(proto.canonicalized_preceding_context(), + InstructionFromProto), + /* following_context = */ + ToVector(proto.canonicalized_following_context(), + InstructionFromProto)); } } // namespace gematria diff --git a/gematria/basic_block/basic_block_test.cc b/gematria/basic_block/basic_block_test.cc index 3977c560..68bb6b72 100644 --- a/gematria/basic_block/basic_block_test.cc +++ b/gematria/basic_block/basic_block_test.cc @@ -277,7 +277,7 @@ TEST(InstructionOperandTest, Equality) { TEST(InstructionOperandTest, ToString) { const struct { - InstructionOperand opernad; + InstructionOperand operand; const char* expected_string; } kTestCases[] = { {InstructionOperand::Register("RAX"), @@ -292,7 +292,7 @@ TEST(InstructionOperandTest, ToString) { "InstructionOperand.from_memory(32)"}}; for (const auto& test_case : kTestCases) { - EXPECT_EQ(test_case.opernad.ToString(), test_case.expected_string); + EXPECT_EQ(test_case.operand.ToString(), test_case.expected_string); } } @@ -318,7 +318,6 @@ TEST(InstructionOperandTest, AsTokenList) { } } -// TODO(virajbshah): Add tests for Annotation. TEST(AnnotationTest, Constructor) { constexpr char kName[] = "cache_miss_freq"; constexpr double kValue = 0.875; @@ -597,7 +596,7 @@ TEST(BasicBlockTest, ToString) { /* instruction_annotations = */ {Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)}); - BasicBlock block({instruction}); + BasicBlock block({instruction}, {instruction}, {instruction}); constexpr char kExpectedString[] = "BasicBlock(instructions=InstructionList((Instruction(mnemonic='ADC', " "llvm_mnemonic='ADC32rr', prefixes=('LOCK',), " @@ -607,8 +606,25 @@ TEST(BasicBlockTest, ToString) { "output_operands=(InstructionOperand.from_register('RAX'),), " "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), " "instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', " - "value=0.875),))," - ")))"; + "value=0.875),)),)), " + "preceding_context=InstructionList((Instruction(mnemonic='ADC', " + "llvm_mnemonic='ADC32rr', prefixes=('LOCK',), " + "input_operands=(InstructionOperand.from_register('RAX'), " + "InstructionOperand.from_register('RBX'),), " + "implicit_input_operands=(InstructionOperand.from_register('EFLAGS'),), " + "output_operands=(InstructionOperand.from_register('RAX'),), " + "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), " + "instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', " + "value=0.875),)),)), " + "following_context=InstructionList((Instruction(mnemonic='ADC', " + "llvm_mnemonic='ADC32rr', prefixes=('LOCK',), " + "input_operands=(InstructionOperand.from_register('RAX'), " + "InstructionOperand.from_register('RBX'),), " + "implicit_input_operands=(InstructionOperand.from_register('EFLAGS'),), " + "output_operands=(InstructionOperand.from_register('RAX'),), " + "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), " + "instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', " + "value=0.875),)),)))"; EXPECT_EQ(block.ToString(), kExpectedString); } diff --git a/gematria/basic_block/python/basic_block.cc b/gematria/basic_block/python/basic_block.cc index e78c6111..863eb847 100644 --- a/gematria/basic_block/python/basic_block.cc +++ b/gematria/basic_block/python/basic_block.cc @@ -255,9 +255,15 @@ PYBIND11_MODULE(basic_block, m) { py::class_ basic_block(m, "BasicBlock"); basic_block - .def(py::init /* instructions */>(), - py::arg("instructions") = std::vector()) + .def(py::init /* instructions */, + std::vector /* preceding_context */, + std::vector /* following_context */>(), + py::arg("instructions") = std::vector(), + py::arg("preceding_context") = std::vector(), + py::arg("following_context") = std::vector()) .def_readwrite("instructions", &BasicBlock::instructions) + .def_readwrite("preceding_context", &BasicBlock::preceding_context) + .def_readwrite("following_context", &BasicBlock::following_context) .def("__repr__", &BasicBlock::ToString) .def("__str__", &BasicBlock::ToString) .def("__eq__", &BasicBlock::operator==) diff --git a/gematria/proto/basic_block.proto b/gematria/proto/basic_block.proto index 4276518f..20528791 100644 --- a/gematria/proto/basic_block.proto +++ b/gematria/proto/basic_block.proto @@ -32,6 +32,24 @@ message BasicBlockProto { // The fingerprint-id of this basic block. Might be empty. string fingerprint = 3; + + // An optional list of machine instructions preceding the basic block, used + // to provide context that precedes `canonicalized_instructions`. These + // instructions are not included in the timing measurements and predictions. + repeated MachineInstructionProto machine_preceding_context = 4; + + // An optional list of machine instructions following the basic block, used + // to provide context following `canonicalized_instructions`. These + // instructions are not included in the timing measurements and predictions. + repeated MachineInstructionProto machine_following_context = 5; + + // Canonicalized instructions parallel to `machine_preceding_context`. May be + // empty in case no preceding context is provided. + repeated CanonicalizedInstructionProto canonicalized_preceding_context = 6; + + // Canonicalized instructions parallel to `machine_following_context`. May be + // empty in case no following context is provided. + repeated CanonicalizedInstructionProto canonicalized_following_context = 7; } // Represents a raw instruction extracted from binary code.