diff --git a/gematria/basic_block/basic_block.cc b/gematria/basic_block/basic_block.cc index 466cab16..55233a9c 100644 --- a/gematria/basic_block/basic_block.cc +++ b/gematria/basic_block/basic_block.cc @@ -377,11 +377,17 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction) { return os; } -BasicBlock::BasicBlock(std::vector instructions) - : instructions(std::move(instructions)) {} +BasicBlock::BasicBlock(std::vector instructions, + std::vector preceding_context, + std::vector following_context) + : instructions(std::move(instructions)), + preceding_context(std::move(preceding_context)), + following_context(std::move(following_context)) {} bool BasicBlock::operator==(const BasicBlock& other) const { - return instructions == other.instructions; + return instructions == other.instructions && + preceding_context == other.preceding_context && + following_context == other.following_context; } std::string BasicBlock::ToString() const { @@ -394,6 +400,26 @@ std::string BasicBlock::ToString() const { } if (buffer.back() == ' ') buffer.pop_back(); buffer += "))"; + if (!preceding_context.empty()) buffer += ", "; + } + if (!preceding_context.empty()) { + buffer += "preceding_context=InstructionList(("; + for (const Instruction& instruction : preceding_context) { + buffer += instruction.ToString(); + buffer += ", "; + } + if (buffer.back() == ' ') buffer.pop_back(); + buffer += "))"; + if (!following_context.empty()) buffer += ", "; + } + if (!following_context.empty()) { + buffer += "following_context=InstructionList(("; + for (const Instruction& instruction : following_context) { + buffer += instruction.ToString(); + buffer += ", "; + } + if (buffer.back() == ' ') buffer.pop_back(); + buffer += "))"; } buffer.push_back(')'); return buffer; diff --git a/gematria/basic_block/basic_block.h b/gematria/basic_block/basic_block.h index 43579fb7..6a8c01bf 100644 --- a/gematria/basic_block/basic_block.h +++ b/gematria/basic_block/basic_block.h @@ -221,7 +221,7 @@ std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand); // Represents an annotation holding a value such as some measure/statistic // paired with the instruction. struct Annotation { - Annotation() : value(-1){}; + Annotation() : value(-1) {}; // Initializes all fields of the annotation. Annotation(std::string name, double value); @@ -324,9 +324,12 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction); struct BasicBlock { BasicBlock() {} - // Initializes the basic block from a list of instructions. Needed for - // compatibility with the Python code. - explicit BasicBlock(std::vector instructions); + // Initializes the basic block from a list of instructions and optional + // context. Needed for compatibility with the Python code. + explicit BasicBlock( + std::vector instructions, + std::vector preceding_context = std::vector(), + std::vector following_context = std::vector()); BasicBlock(const BasicBlock&) = default; BasicBlock(BasicBlock&&) = default; @@ -346,6 +349,11 @@ struct BasicBlock { // The list of instructions in the basic block. std::vector instructions; + + // The preceding and following context instructions, i.e. those preceeding and + // following the instructions in the basic block. + std::vector preceding_context; + std::vector following_context; }; std::ostream& operator<<(std::ostream& os, const BasicBlock& block); diff --git a/gematria/basic_block/basic_block_protos.cc b/gematria/basic_block/basic_block_protos.cc index 7b8c0a0c..b6e441ce 100644 --- a/gematria/basic_block/basic_block_protos.cc +++ b/gematria/basic_block/basic_block_protos.cc @@ -180,8 +180,15 @@ CanonicalizedInstructionProto ProtoFromInstruction( BasicBlock BasicBlockFromProto(const BasicBlockProto& proto) { return BasicBlock( - /* instructions = */ ToVector( - proto.canonicalized_instructions(), InstructionFromProto)); + /* instructions = */ + ToVector(proto.canonicalized_instructions(), + InstructionFromProto), + /* preceding_context = */ + ToVector(proto.canonicalized_preceding_context(), + InstructionFromProto), + /* following_context = */ + ToVector(proto.canonicalized_following_context(), + InstructionFromProto)); } } // namespace gematria diff --git a/gematria/basic_block/basic_block_test.cc b/gematria/basic_block/basic_block_test.cc index 3977c560..68bb6b72 100644 --- a/gematria/basic_block/basic_block_test.cc +++ b/gematria/basic_block/basic_block_test.cc @@ -277,7 +277,7 @@ TEST(InstructionOperandTest, Equality) { TEST(InstructionOperandTest, ToString) { const struct { - InstructionOperand opernad; + InstructionOperand operand; const char* expected_string; } kTestCases[] = { {InstructionOperand::Register("RAX"), @@ -292,7 +292,7 @@ TEST(InstructionOperandTest, ToString) { "InstructionOperand.from_memory(32)"}}; for (const auto& test_case : kTestCases) { - EXPECT_EQ(test_case.opernad.ToString(), test_case.expected_string); + EXPECT_EQ(test_case.operand.ToString(), test_case.expected_string); } } @@ -318,7 +318,6 @@ TEST(InstructionOperandTest, AsTokenList) { } } -// TODO(virajbshah): Add tests for Annotation. TEST(AnnotationTest, Constructor) { constexpr char kName[] = "cache_miss_freq"; constexpr double kValue = 0.875; @@ -597,7 +596,7 @@ TEST(BasicBlockTest, ToString) { /* instruction_annotations = */ {Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)}); - BasicBlock block({instruction}); + BasicBlock block({instruction}, {instruction}, {instruction}); constexpr char kExpectedString[] = "BasicBlock(instructions=InstructionList((Instruction(mnemonic='ADC', " "llvm_mnemonic='ADC32rr', prefixes=('LOCK',), " @@ -607,8 +606,25 @@ TEST(BasicBlockTest, ToString) { "output_operands=(InstructionOperand.from_register('RAX'),), " "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), " "instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', " - "value=0.875),))," - ")))"; + "value=0.875),)),)), " + "preceding_context=InstructionList((Instruction(mnemonic='ADC', " + "llvm_mnemonic='ADC32rr', prefixes=('LOCK',), " + "input_operands=(InstructionOperand.from_register('RAX'), " + "InstructionOperand.from_register('RBX'),), " + "implicit_input_operands=(InstructionOperand.from_register('EFLAGS'),), " + "output_operands=(InstructionOperand.from_register('RAX'),), " + "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), " + "instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', " + "value=0.875),)),)), " + "following_context=InstructionList((Instruction(mnemonic='ADC', " + "llvm_mnemonic='ADC32rr', prefixes=('LOCK',), " + "input_operands=(InstructionOperand.from_register('RAX'), " + "InstructionOperand.from_register('RBX'),), " + "implicit_input_operands=(InstructionOperand.from_register('EFLAGS'),), " + "output_operands=(InstructionOperand.from_register('RAX'),), " + "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), " + "instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', " + "value=0.875),)),)))"; EXPECT_EQ(block.ToString(), kExpectedString); } diff --git a/gematria/basic_block/python/basic_block.cc b/gematria/basic_block/python/basic_block.cc index e78c6111..863eb847 100644 --- a/gematria/basic_block/python/basic_block.cc +++ b/gematria/basic_block/python/basic_block.cc @@ -255,9 +255,15 @@ PYBIND11_MODULE(basic_block, m) { py::class_ basic_block(m, "BasicBlock"); basic_block - .def(py::init /* instructions */>(), - py::arg("instructions") = std::vector()) + .def(py::init /* instructions */, + std::vector /* preceding_context */, + std::vector /* following_context */>(), + py::arg("instructions") = std::vector(), + py::arg("preceding_context") = std::vector(), + py::arg("following_context") = std::vector()) .def_readwrite("instructions", &BasicBlock::instructions) + .def_readwrite("preceding_context", &BasicBlock::preceding_context) + .def_readwrite("following_context", &BasicBlock::following_context) .def("__repr__", &BasicBlock::ToString) .def("__str__", &BasicBlock::ToString) .def("__eq__", &BasicBlock::operator==) diff --git a/gematria/granite/graph_builder.cc b/gematria/granite/graph_builder.cc index 30f6c4ab..6a0c913c 100644 --- a/gematria/granite/graph_builder.cc +++ b/gematria/granite/graph_builder.cc @@ -190,7 +190,9 @@ BasicBlockGraphBuilder::BasicBlockGraphBuilder( } bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( - const std::vector& instructions) { + const std::vector& instructions, + const std::vector& preceding_context, + const std::vector& following_context) { if (instructions.empty()) return false; AddBasicBlockTransaction transaction(this); @@ -202,58 +204,70 @@ bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( const int prev_num_edges = num_edges(); NodeIndex previous_instruction_node = kInvalidNode; - for (const Instruction& instruction : instructions) { - // Add the instruction node. - const NodeIndex instruction_node = - AddNode(NodeType::kInstruction, instruction.mnemonic); - if (instruction_node == kInvalidNode) { - return false; - } + const struct { + const std::vector& instruction_group; + bool is_context; + } instruction_groups[] = {{preceding_context, true}, + {instructions, false}, + {following_context, true}}; + for (const auto& [instruction_group, is_context] : instruction_groups) { + for (const Instruction& instruction : instruction_group) { + // Add the instruction node. + const NodeIndex instruction_node = + AddNode(NodeType::kInstruction, instruction.mnemonic, is_context); + if (instruction_node == kInvalidNode) { + return false; + } - // Store the annotations for later use (inclusion in embeddings), using -1 - // as a default value wherever annotations are missing. - std::vector row = std::vector(annotation_names_.size(), -1); - for (const auto& [name, value] : instruction.instruction_annotations) { - const auto annotation_index = annotation_name_to_idx_.find(name); - if (annotation_index == annotation_name_to_idx_.end()) continue; - row[annotation_index->second] = value; - } - instruction_annotations_.push_back(row); + // Store the annotations for later use (inclusion in embeddings), using -1 + // as a default value wherever annotations are missing. + std::vector row = std::vector(annotation_names_.size(), -1); + for (const auto& [name, value] : instruction.instruction_annotations) { + const auto annotation_index = annotation_name_to_idx_.find(name); + if (annotation_index == annotation_name_to_idx_.end()) continue; + row[annotation_index->second] = value; + } + instruction_annotations_.push_back(row); - // Add nodes for prefixes of the instruction. - for (const std::string& prefix : instruction.prefixes) { - const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix); - if (prefix_node == kInvalidNode) { - return false; + // Add nodes for prefixes of the instruction. + for (const std::string& prefix : instruction.prefixes) { + const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix); + if (prefix_node == kInvalidNode) { + return false; + } + AddEdge(EdgeType::kInstructionPrefix, prefix_node, instruction_node); } - AddEdge(EdgeType::kInstructionPrefix, prefix_node, instruction_node); - } - // Add a structural dependency edge from the previous instruction. - if (previous_instruction_node >= 0) { - AddEdge(EdgeType::kStructuralDependency, previous_instruction_node, - instruction_node); - } + // Add a structural dependency edge from the previous instruction. + if (previous_instruction_node >= 0) { + AddEdge(EdgeType::kStructuralDependency, previous_instruction_node, + instruction_node); + } - // Add edges for input operands. And nodes too, if necessary. - for (const InstructionOperand& operand : instruction.input_operands) { - if (!AddInputOperand(instruction_node, operand)) return false; - } - for (const InstructionOperand& operand : - instruction.implicit_input_operands) { - if (!AddInputOperand(instruction_node, operand)) return false; - } + // Add edges for input operands. And nodes too, if necessary. + for (const InstructionOperand& operand : instruction.input_operands) { + if (!AddInputOperand(instruction_node, operand, is_context)) + return false; + } + for (const InstructionOperand& operand : + instruction.implicit_input_operands) { + if (!AddInputOperand(instruction_node, operand, is_context)) + return false; + } - // Add edges and nodes for output operands. - for (const InstructionOperand& operand : instruction.output_operands) { - if (!AddOutputOperand(instruction_node, operand)) return false; - } - for (const InstructionOperand& operand : - instruction.implicit_output_operands) { - if (!AddOutputOperand(instruction_node, operand)) return false; - } + // Add edges and nodes for output operands. + for (const InstructionOperand& operand : instruction.output_operands) { + if (!AddOutputOperand(instruction_node, operand, is_context)) + return false; + } + for (const InstructionOperand& operand : + instruction.implicit_output_operands) { + if (!AddOutputOperand(instruction_node, operand, is_context)) + return false; + } - previous_instruction_node = instruction_node; + previous_instruction_node = instruction_node; + } } global_features_.emplace_back(num_node_tokens(), 0); @@ -276,6 +290,7 @@ void BasicBlockGraphBuilder::Reset() { node_types_.clear(); node_features_.clear(); + context_node_mask_.clear(); edge_senders_.clear(); edge_receivers_.clear(); @@ -286,54 +301,58 @@ void BasicBlockGraphBuilder::Reset() { instruction_annotations_.clear(); } -bool BasicBlockGraphBuilder::AddInputOperand( - NodeIndex instruction_node, const InstructionOperand& operand) { +bool BasicBlockGraphBuilder::AddInputOperand(NodeIndex instruction_node, + const InstructionOperand& operand, + bool is_context) { assert(instruction_node >= 0); assert(instruction_node < num_nodes()); switch (operand.type()) { case OperandType::kRegister: { if (!AddDependencyOnRegister(instruction_node, operand.register_name(), - EdgeType::kInputOperands)) { + EdgeType::kInputOperands, is_context)) { return false; } } break; case OperandType::kImmediateValue: { AddEdge(EdgeType::kInputOperands, - AddNode(NodeType::kImmediate, immediate_token_), + AddNode(NodeType::kImmediate, immediate_token_, is_context), instruction_node); } break; case OperandType::kFpImmediateValue: { AddEdge(EdgeType::kInputOperands, - AddNode(NodeType::kFpImmediate, fp_immediate_token_), + AddNode(NodeType::kFpImmediate, fp_immediate_token_, is_context), instruction_node); } break; case OperandType::kAddress: { const NodeIndex address_node = - AddNode(NodeType::kAddressOperand, address_token_); + AddNode(NodeType::kAddressOperand, address_token_, is_context); const AddressTuple& address_tuple = operand.address(); if (!address_tuple.base_register.empty()) { if (!AddDependencyOnRegister(address_node, address_tuple.base_register, - EdgeType::kAddressBaseRegister)) { + EdgeType::kAddressBaseRegister, + is_context)) { return false; } } if (!address_tuple.index_register.empty()) { if (!AddDependencyOnRegister(address_node, address_tuple.index_register, - EdgeType::kAddressIndexRegister)) { + EdgeType::kAddressIndexRegister, + is_context)) { return false; } } if (!address_tuple.segment_register.empty()) { - if (!AddDependencyOnRegister(address_node, - address_tuple.segment_register, - EdgeType::kAddressSegmentRegister)) { + if (!AddDependencyOnRegister( + address_node, address_tuple.segment_register, + EdgeType::kAddressSegmentRegister, is_context)) { return false; } } if (address_tuple.displacement != 0) { AddEdge(EdgeType::kAddressDisplacement, - AddNode(NodeType::kImmediate, immediate_token_), address_node); + AddNode(NodeType::kImmediate, immediate_token_, is_context), + address_node); } // NOTE(ondrasej): For now, we explicitly ignore the scaling. AddEdge(EdgeType::kInputOperands, address_node, instruction_node); @@ -342,7 +361,13 @@ bool BasicBlockGraphBuilder::AddInputOperand( NodeIndex& alias_group_node = LookupOrInsert( alias_group_nodes_, operand.alias_group_id(), kInvalidNode); if (alias_group_node == kInvalidNode) { - alias_group_node = AddNode(NodeType::kMemoryOperand, memory_token_); + alias_group_node = + AddNode(NodeType::kMemoryOperand, memory_token_, is_context); + } else if (context_node_mask_[alias_group_node] && !is_context) { + // Update `context_node_mask_` to indicate that `alias_group_node`, + // previously marked as context, is also part of the main block i.e. not + // context. + context_node_mask_[alias_group_node] = false; } AddEdge(EdgeType::kInputOperands, alias_group_node, instruction_node); } break; @@ -354,15 +379,16 @@ bool BasicBlockGraphBuilder::AddInputOperand( return true; } -bool BasicBlockGraphBuilder::AddOutputOperand( - NodeIndex instruction_node, const InstructionOperand& operand) { +bool BasicBlockGraphBuilder::AddOutputOperand(NodeIndex instruction_node, + const InstructionOperand& operand, + bool is_context) { assert(instruction_node >= 0); assert(instruction_node < num_nodes()); switch (operand.type()) { case OperandType::kRegister: { const NodeIndex register_node = - AddNode(NodeType::kRegister, operand.register_name()); + AddNode(NodeType::kRegister, operand.register_name(), is_context); if (register_node == kInvalidNode) return false; AddEdge(EdgeType::kOutputOperands, instruction_node, register_node); register_nodes_[operand.register_name()] = register_node; @@ -376,7 +402,7 @@ bool BasicBlockGraphBuilder::AddOutputOperand( break; case OperandType::kMemory: { const NodeIndex alias_group_node = - AddNode(NodeType::kMemoryOperand, memory_token_); + AddNode(NodeType::kMemoryOperand, memory_token_, is_context); alias_group_nodes_[operand.alias_group_id()] = alias_group_node; AddEdge(EdgeType::kOutputOperands, instruction_node, alias_group_node); } break; @@ -390,13 +416,17 @@ bool BasicBlockGraphBuilder::AddOutputOperand( bool BasicBlockGraphBuilder::AddDependencyOnRegister( NodeIndex dependent_node, const std::string& register_name, - EdgeType edge_type) { + EdgeType edge_type, bool is_context) { NodeIndex& operand_node = LookupOrInsert(register_nodes_, register_name, kInvalidNode); if (operand_node == kInvalidNode) { // Add a node for the register if it doesn't exist. This also updates the // node index in `node_by_register`. - operand_node = AddNode(NodeType::kRegister, register_name); + operand_node = AddNode(NodeType::kRegister, register_name, is_context); + } else if (context_node_mask_[operand_node] && !is_context) { + // Update `context_node_mask_` to indicate that `operand_node`, previously + // marked as context, is also part of the main block i.e. not context. + context_node_mask_[operand_node] = false; } if (operand_node == kInvalidNode) return false; AddEdge(edge_type, operand_node, dependent_node); @@ -404,15 +434,16 @@ bool BasicBlockGraphBuilder::AddDependencyOnRegister( } BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode( - NodeType node_type, TokenIndex token_index) { + NodeType node_type, TokenIndex token_index, bool is_context) { const NodeIndex new_node_index = num_nodes(); node_types_.push_back(node_type); node_features_.push_back(token_index); + context_node_mask_.push_back(is_context); return new_node_index; } BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode( - NodeType node_type, const std::string& token) { + NodeType node_type, const std::string& token, bool is_context) { const auto it = node_tokens_.find(token); TokenIndex token_index = kInvalidTokenIndex; if (it != node_tokens_.end()) { @@ -427,7 +458,7 @@ BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode( token_index = replacement_token_; } } - return AddNode(node_type, token_index); + return AddNode(node_type, token_index, is_context); } void BasicBlockGraphBuilder::AddEdge(EdgeType edge_type, NodeIndex sender, @@ -505,6 +536,7 @@ std::string BasicBlockGraphBuilder::DebugString() const { StrAppendList(buffer, "num_nodes_per_block", num_nodes_per_block()); StrAppendList(buffer, "num_edges_per_block", num_edges_per_block()); StrAppendList(buffer, "node_types", node_types()); + StrAppendList(buffer, "context_node_mask", context_node_mask()); StrAppendList(buffer, "edge_senders", edge_senders()); StrAppendList(buffer, "edge_receivers", edge_receivers()); StrAppendList(buffer, "edge_types", edge_types()); diff --git a/gematria/granite/graph_builder.h b/gematria/granite/graph_builder.h index c4684eea..8f110dd4 100644 --- a/gematria/granite/graph_builder.h +++ b/gematria/granite/graph_builder.h @@ -91,7 +91,6 @@ #include #include -#include #include #include #include @@ -187,14 +186,24 @@ class BasicBlockGraphBuilder { // method encountered an unknown token and the unknown token behavior is not // kReplaceToken or when the basic block does not contain any instructions. // When this happens, the graph builder is left in the previous state, i.e. no - // basic block is added to it. - bool AddBasicBlock(const BasicBlock& block) { + // basic block is added to it. The basic block context is added to the graph + // if and only if `add_context` is true. + bool AddBasicBlock(const BasicBlock& block, bool add_context = false) { + if (add_context) { + return AddBasicBlockFromInstructions( + block.instructions, block.preceding_context, block.following_context); + } return AddBasicBlockFromInstructions(block.instructions); } // A version of AddBasicBlock that takes the list of instructions in the basic - // block instead of the basic block object itself. + // block and optionally its preceding and following contexts instead of the + // basic block object itself. bool AddBasicBlockFromInstructions( - const std::vector& instructions); + const std::vector& instructions, + const std::vector& preceding_context = + std::vector(), + const std::vector& following_context = + std::vector()); // Resets the graph builder so that it can be used to create a new graph from // scratch. @@ -242,6 +251,12 @@ class BasicBlockGraphBuilder { // Feature value of the nodes in the batch (i.e. the indices of the tokens // corresponding to the nodes). Corresponds to `GraphsTuple.nodes`. const std::vector& node_features() const { return node_features_; } + // Whether or not the corresponding node belongs to either the preceding or + // following context of the basic block, and not the basic block itself. Used + // by the models to exclude context nodes from predictions. + const std::vector& context_node_mask() const { + return context_node_mask_; + } // Names of types of instruction annotations stored. const std::vector& annotation_names() const { @@ -362,24 +377,28 @@ class BasicBlockGraphBuilder { // Adds nodes and edges for a single input operand of an instruction. bool AddInputOperand(NodeIndex instruction_node, - const InstructionOperand& operand); + const InstructionOperand& operand, + bool is_context = false); // Adds nodes and edges for a single output operand of an instruction. bool AddOutputOperand(NodeIndex instruction_node, - const InstructionOperand& operand); + const InstructionOperand& operand, + bool is_context = false); // Adds dependency of a node (instruction or an address computation node) on // a register. Adds the register node if it doesn't exist in the graph. bool AddDependencyOnRegister(NodeIndex dependent_node, const std::string& register_name, - EdgeType edge_type); + EdgeType edge_type, bool is_context = false); // Adds a new node to the batch; the feature of the node is given directly by // the caller. - NodeIndex AddNode(NodeType node_type, TokenIndex token_index); + NodeIndex AddNode(NodeType node_type, TokenIndex token_index, + bool is_context = false); // Adds a new edge to the batch; the feature of the node is determined from // the token associated with the node. Returns kInvalidNode when the node was // not added. - NodeIndex AddNode(NodeType node_type, const std::string& token); + NodeIndex AddNode(NodeType node_type, const std::string& token, + bool is_context = false); // Adds a new edge to the batch. void AddEdge(EdgeType edge_type, NodeIndex sender, NodeIndex receiver); @@ -406,6 +425,7 @@ class BasicBlockGraphBuilder { std::vector node_types_; std::vector node_features_; + std::vector context_node_mask_; // Mapping from annotation type names to corresponding row index in the // `instruction_annotations_` matrix. diff --git a/gematria/granite/graph_builder_test.cc b/gematria/granite/graph_builder_test.cc index 82cdbca4..68ac52f1 100644 --- a/gematria/granite/graph_builder_test.cc +++ b/gematria/granite/graph_builder_test.cc @@ -33,8 +33,11 @@ namespace gematria { namespace { +using ::testing::_; +using ::testing::Each; using ::testing::ElementsAre; using ::testing::IsEmpty; +using ::testing::IsFalse; using ::testing::Pair; // Tokens used in the basic blocks in tests. For simplicity, we do not use the @@ -490,6 +493,7 @@ TEST_F(BasicBlockGraphBuilderTest, MultipleInstructions) { true, false, // Fourth instruction. true, false)); + EXPECT_THAT(builder_->context_node_mask(), Each(IsFalse())); EXPECT_THAT( builder_->edge_types(), @@ -571,5 +575,98 @@ TEST_F(BasicBlockGraphBuilderTest, TwoNops) { )pb")))); } +// Tests that the graph is built and context node mask is set correctly when +// context is supplied. +TEST_F(BasicBlockGraphBuilderTest, MultipleBasicBlocksWithContext) { + CreateBuilder(OutOfVocabularyTokenBehavior::ReturnError()); + ASSERT_TRUE( + builder_->AddBasicBlock(BasicBlockFromProto(ParseTextProto(R"pb( + canonicalized_instructions: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + } + canonicalized_preceding_context: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + } + canonicalized_following_context: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + })pb")), + true)); + ASSERT_TRUE( + builder_->AddBasicBlock(BasicBlockFromProto(ParseTextProto(R"pb( + canonicalized_instructions: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + } + canonicalized_preceding_context: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + })pb")), + true)); + + EXPECT_EQ(builder_->num_graphs(), 2); + EXPECT_EQ(builder_->num_node_tokens(), std::size(kTokens)); + + EXPECT_EQ(builder_->num_nodes(), 7 + 5); + EXPECT_THAT(builder_->num_nodes_per_block(), ElementsAre(7, 5)); + + EXPECT_EQ(builder_->num_edges(), 8 + 5); + EXPECT_THAT(builder_->num_edges_per_block(), ElementsAre(8, 5)); + + EXPECT_THAT(builder_->node_types(), + ElementsAre(NodeType::kInstruction, NodeType::kRegister, + NodeType::kRegister, NodeType::kInstruction, + NodeType::kRegister, NodeType::kInstruction, + NodeType::kRegister, NodeType::kInstruction, + NodeType::kRegister, NodeType::kRegister, + NodeType::kInstruction, NodeType::kRegister)); + EXPECT_THAT( + builder_->node_features(), + ElementsAre(TokenIndex("NOT"), TokenIndex("RCX"), TokenIndex("RCX"), + TokenIndex("NOT"), TokenIndex("RCX"), TokenIndex("NOT"), + TokenIndex("RCX"), TokenIndex("NOT"), TokenIndex("RCX"), + TokenIndex("RCX"), TokenIndex("NOT"), TokenIndex("RCX"))); + EXPECT_THAT(builder_->InstructionNodeMask(), + ElementsAre(true, false, false, true, false, true, false, true, + false, false, true, false)); + EXPECT_THAT(builder_->context_node_mask(), + ElementsAre(true, true, false, false, false, true, true, true, + true, false, false, false)); + + EXPECT_THAT( + builder_->edge_types(), + ElementsAre(EdgeType::kInputOperands, EdgeType::kOutputOperands, + EdgeType::kStructuralDependency, EdgeType::kInputOperands, + EdgeType::kOutputOperands, EdgeType::kStructuralDependency, + EdgeType::kInputOperands, EdgeType::kOutputOperands, + EdgeType::kInputOperands, EdgeType::kOutputOperands, + EdgeType::kStructuralDependency, EdgeType::kInputOperands, + EdgeType::kOutputOperands)); + + EXPECT_THAT(builder_->edge_senders(), + ElementsAre(1, 0, 0, 2, 3, 3, 4, 5, 8, 7, 7, 9, 10)); + EXPECT_THAT(builder_->edge_receivers(), + ElementsAre(0, 2, 3, 3, 4, 5, 5, 6, 7, 9, 10, 10, 11)); + + EXPECT_THAT( + builder_->global_features(), + ElementsAre(ElementsAre(0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 4, 0, 0, 0, 0), + ElementsAre(0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0))); + + EXPECT_THAT(builder_->DeltaBlockIndex(), ElementsAre(0, 0, 0, 1, 1)); +} + } // namespace } // namespace gematria diff --git a/gematria/granite/python/BUILD.bazel b/gematria/granite/python/BUILD.bazel index ec928b24..a46e812c 100644 --- a/gematria/granite/python/BUILD.bazel +++ b/gematria/granite/python/BUILD.bazel @@ -48,6 +48,7 @@ gematria_pybind_extension( srcs = ["graph_builder.cc"], visibility = ["//:internal_users"], deps = [ + "//gematria/basic_block", "//gematria/granite:graph_builder", "//gematria/model:oov_token_behavior", "//gematria/proto:canonicalized_instruction_cc_proto", diff --git a/gematria/granite/python/graph_builder.cc b/gematria/granite/python/graph_builder.cc index c238e2d3..c90e9d63 100644 --- a/gematria/granite/python/graph_builder.cc +++ b/gematria/granite/python/graph_builder.cc @@ -14,11 +14,11 @@ #include "gematria/granite/graph_builder.h" -#include #include #include #include "absl/strings/string_view.h" +#include "gematria/basic_block/basic_block.h" #include "gematria/model/oov_token_behavior.h" #include "gematria/proto/canonicalized_instruction.pb.h" #include "pybind11/cast.h" @@ -81,10 +81,12 @@ PYBIND11_MODULE(graph_builder, m) { py::arg("annotation_names") = std::vector(), py::arg("out_of_vocabulary_behavior")) .def("add_basic_block", &BasicBlockGraphBuilder::AddBasicBlock, - py::arg("block")) + py::arg("block"), py::arg("add_context") = false) .def("add_basic_block_from_instructions", &BasicBlockGraphBuilder::AddBasicBlockFromInstructions, - py::arg("instructions")) + py::arg("instructions"), + py::arg("preceding_context") = std::vector(), + py::arg("following_context") = std::vector()) .def("reset", &BasicBlockGraphBuilder::Reset) .def_property_readonly("num_node_tokens", &BasicBlockGraphBuilder::num_node_tokens) @@ -99,6 +101,8 @@ PYBIND11_MODULE(graph_builder, m) { &BasicBlockGraphBuilder::node_features) .def_property_readonly("instruction_node_mask", &BasicBlockGraphBuilder::InstructionNodeMask) + .def_property_readonly("context_node_mask", + &BasicBlockGraphBuilder::context_node_mask) .def_property_readonly("annotation_names", &BasicBlockGraphBuilder::annotation_names) .def_property_readonly("instruction_annotations", diff --git a/gematria/proto/basic_block.proto b/gematria/proto/basic_block.proto index 4276518f..20528791 100644 --- a/gematria/proto/basic_block.proto +++ b/gematria/proto/basic_block.proto @@ -32,6 +32,24 @@ message BasicBlockProto { // The fingerprint-id of this basic block. Might be empty. string fingerprint = 3; + + // An optional list of machine instructions preceding the basic block, used + // to provide context that precedes `canonicalized_instructions`. These + // instructions are not included in the timing measurements and predictions. + repeated MachineInstructionProto machine_preceding_context = 4; + + // An optional list of machine instructions following the basic block, used + // to provide context following `canonicalized_instructions`. These + // instructions are not included in the timing measurements and predictions. + repeated MachineInstructionProto machine_following_context = 5; + + // Canonicalized instructions parallel to `machine_preceding_context`. May be + // empty in case no preceding context is provided. + repeated CanonicalizedInstructionProto canonicalized_preceding_context = 6; + + // Canonicalized instructions parallel to `machine_following_context`. May be + // empty in case no following context is provided. + repeated CanonicalizedInstructionProto canonicalized_following_context = 7; } // Represents a raw instruction extracted from binary code.