From c0da27c28d95ec475a8247e8c387d86bbecf3f0c Mon Sep 17 00:00:00 2001 From: Viraj Shah Date: Mon, 20 Jan 2025 02:53:51 +0530 Subject: [PATCH 1/6] Add support for storing basic block context in Gematria formats. * Add fields to the `proto` specification to store context. * Add members to the Gematria `BasicBlock` data structure to store context and update methods on it and its Python binding accordingly. * Bonus: Remove dangling TODO. --- gematria/basic_block/basic_block.cc | 30 +++++++++++++++++++--- gematria/basic_block/basic_block.h | 16 +++++++++--- gematria/basic_block/basic_block_protos.cc | 11 ++++++-- gematria/basic_block/basic_block_test.cc | 5 ++-- gematria/basic_block/python/basic_block.cc | 10 ++++++-- gematria/proto/basic_block.proto | 20 ++++++++++++++- 6 files changed, 77 insertions(+), 15 deletions(-) diff --git a/gematria/basic_block/basic_block.cc b/gematria/basic_block/basic_block.cc index 466cab16..f5aebef6 100644 --- a/gematria/basic_block/basic_block.cc +++ b/gematria/basic_block/basic_block.cc @@ -377,11 +377,17 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction) { return os; } -BasicBlock::BasicBlock(std::vector instructions) - : instructions(std::move(instructions)) {} +BasicBlock::BasicBlock(std::vector instructions, + std::vector back_context, + std::vector front_context) + : instructions(std::move(instructions)), + back_context(std::move(back_context)), + front_context(std::move(front_context)) {} bool BasicBlock::operator==(const BasicBlock& other) const { - return instructions == other.instructions; + return instructions == other.instructions && + back_context == other.back_context && + front_context == other.front_context; } std::string BasicBlock::ToString() const { @@ -395,6 +401,24 @@ std::string BasicBlock::ToString() const { if (buffer.back() == ' ') buffer.pop_back(); buffer += "))"; } + if (!back_context.empty()) { + buffer += "back_context=InstructionList(("; + for (const Instruction& instruction : 
back_context) { + buffer += instruction.ToString(); + buffer += ", "; + } + if (buffer.back() == ' ') buffer.pop_back(); + buffer += "))"; + } + if (!front_context.empty()) { + buffer += "front_context=InstructionList(("; + for (const Instruction& instruction : front_context) { + buffer += instruction.ToString(); + buffer += ", "; + } + if (buffer.back() == ' ') buffer.pop_back(); + buffer += "))"; + } buffer.push_back(')'); return buffer; } diff --git a/gematria/basic_block/basic_block.h b/gematria/basic_block/basic_block.h index 43579fb7..735aa3d1 100644 --- a/gematria/basic_block/basic_block.h +++ b/gematria/basic_block/basic_block.h @@ -221,7 +221,7 @@ std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand); // Represents an annotation holding a value such as some measure/statistic // paired with the instruction. struct Annotation { - Annotation() : value(-1){}; + Annotation() : value(-1) {}; // Initializes all fields of the annotation. Annotation(std::string name, double value); @@ -324,9 +324,12 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction); struct BasicBlock { BasicBlock() {} - // Initializes the basic block from a list of instructions. Needed for - // compatibility with the Python code. - explicit BasicBlock(std::vector instructions); + // Initializes the basic block from a list of instructions and optional + // context. Needed for compatibility with the Python code. + explicit BasicBlock( + std::vector instructions, + std::vector back_context = std::vector(), + std::vector front_context = std::vector()); BasicBlock(const BasicBlock&) = default; BasicBlock(BasicBlock&&) = default; @@ -346,6 +349,11 @@ struct BasicBlock { // The list of instructions in the basic block. std::vector instructions; + + // The back and front context instructions, i.e. those preceeding and + // following the instructions in the basic block. 
+ std::vector back_context; + std::vector front_context; }; std::ostream& operator<<(std::ostream& os, const BasicBlock& block); diff --git a/gematria/basic_block/basic_block_protos.cc b/gematria/basic_block/basic_block_protos.cc index 7b8c0a0c..d327901b 100644 --- a/gematria/basic_block/basic_block_protos.cc +++ b/gematria/basic_block/basic_block_protos.cc @@ -180,8 +180,15 @@ CanonicalizedInstructionProto ProtoFromInstruction( BasicBlock BasicBlockFromProto(const BasicBlockProto& proto) { return BasicBlock( - /* instructions = */ ToVector( - proto.canonicalized_instructions(), InstructionFromProto)); + /* instructions = */ + ToVector(proto.canonicalized_instructions(), + InstructionFromProto), + /* back_context = */ + ToVector(proto.canonicalized_back_context(), + InstructionFromProto), + /* front_context = */ + ToVector(proto.canonicalized_front_context(), + InstructionFromProto)); } } // namespace gematria diff --git a/gematria/basic_block/basic_block_test.cc b/gematria/basic_block/basic_block_test.cc index 3977c560..348523d0 100644 --- a/gematria/basic_block/basic_block_test.cc +++ b/gematria/basic_block/basic_block_test.cc @@ -277,7 +277,7 @@ TEST(InstructionOperandTest, Equality) { TEST(InstructionOperandTest, ToString) { const struct { - InstructionOperand opernad; + InstructionOperand operand; const char* expected_string; } kTestCases[] = { {InstructionOperand::Register("RAX"), @@ -292,7 +292,7 @@ TEST(InstructionOperandTest, ToString) { "InstructionOperand.from_memory(32)"}}; for (const auto& test_case : kTestCases) { - EXPECT_EQ(test_case.opernad.ToString(), test_case.expected_string); + EXPECT_EQ(test_case.operand.ToString(), test_case.expected_string); } } @@ -318,7 +318,6 @@ TEST(InstructionOperandTest, AsTokenList) { } } -// TODO(virajbshah): Add tests for Annotation. 
TEST(AnnotationTest, Constructor) { constexpr char kName[] = "cache_miss_freq"; constexpr double kValue = 0.875; diff --git a/gematria/basic_block/python/basic_block.cc b/gematria/basic_block/python/basic_block.cc index e78c6111..b583dabb 100644 --- a/gematria/basic_block/python/basic_block.cc +++ b/gematria/basic_block/python/basic_block.cc @@ -255,9 +255,15 @@ PYBIND11_MODULE(basic_block, m) { py::class_ basic_block(m, "BasicBlock"); basic_block - .def(py::init /* instructions */>(), - py::arg("instructions") = std::vector()) + .def(py::init /* instructions */, + std::vector /* back_context */, + std::vector /* front_context */>(), + py::arg("instructions") = std::vector(), + py::arg("back_context") = std::vector(), + py::arg("front_context") = std::vector()) .def_readwrite("instructions", &BasicBlock::instructions) + .def_readwrite("back_context", &BasicBlock::back_context) + .def_readwrite("front_context", &BasicBlock::front_context) .def("__repr__", &BasicBlock::ToString) .def("__str__", &BasicBlock::ToString) .def("__eq__", &BasicBlock::operator==) diff --git a/gematria/proto/basic_block.proto b/gematria/proto/basic_block.proto index 4276518f..86269cea 100644 --- a/gematria/proto/basic_block.proto +++ b/gematria/proto/basic_block.proto @@ -30,8 +30,26 @@ message BasicBlockProto { // same instruction. repeated CanonicalizedInstructionProto canonicalized_instructions = 2; + // An optional list of machine instructions preceding the basic block, used + // to provide context that lies before `canonicalized_instructions`. These + // instructions are not included in the timing measurements and predictions. + repeated CanonicalizedInstructionProto machine_back_context = 3; + + // An optional list of machine instructions following the basic block, used + // to provide context lying after `canonicalized_instructions`. These + // instructions are not included in the timing measurements and predictions. 
+ repeated CanonicalizedInstructionProto machine_front_context = 4; + + // Canonicalized instructions parallel to `machine_back_context`. May be + // empty in case no back context is provided. + repeated CanonicalizedInstructionProto canonicalized_back_context = 5; + + // Canonicalized instructions parallel to `machine_front_context`. May be + // empty in case no front context is provided. + repeated CanonicalizedInstructionProto canonicalized_front_context = 6; + // The fingerprint-id of this basic block. Might be empty. - string fingerprint = 3; + string fingerprint = 7; } // Represents a raw instruction extracted from binary code. From 502e6891fa075ce90fd55240c9ba8f5045281b53 Mon Sep 17 00:00:00 2001 From: Viraj Shah Date: Tue, 21 Jan 2025 00:30:42 +0530 Subject: [PATCH 2/6] Fix typo. --- gematria/proto/basic_block.proto | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gematria/proto/basic_block.proto b/gematria/proto/basic_block.proto index 86269cea..bbcd57ae 100644 --- a/gematria/proto/basic_block.proto +++ b/gematria/proto/basic_block.proto @@ -33,12 +33,12 @@ message BasicBlockProto { // An optional list of machine instructions preceding the basic block, used // to provide context that lies before `canonicalized_instructions`. These // instructions are not included in the timing measurements and predictions. - repeated CanonicalizedInstructionProto machine_back_context = 3; + repeated MachineInstructionProto machine_back_context = 3; // An optional list of machine instructions following the basic block, used // to provide context lying after `canonicalized_instructions`. These // instructions are not included in the timing measurements and predictions. - repeated CanonicalizedInstructionProto machine_front_context = 4; + repeated MachineInstructionProto machine_front_context = 4; // Canonicalized instructions parallel to `machine_back_context`. May be // empty in case no back context is provided. 
From 22b2eeb74d59788bf3bb7e236132c4886aeeaf7f Mon Sep 17 00:00:00 2001 From: Viraj Shah Date: Sat, 8 Feb 2025 23:14:53 +0530 Subject: [PATCH 3/6] Address reviewer suggestions. * Change `back_context` -> `preceding_context` and `front_context` -> `following_context`. * Add test for updated `BasicBlock::ToString`. * Keep previous proto field numbering. --- gematria/basic_block/basic_block.cc | 26 ++++++++++++---------- gematria/basic_block/basic_block.h | 10 ++++----- gematria/basic_block/basic_block_protos.cc | 8 +++---- gematria/basic_block/basic_block_test.cc | 23 ++++++++++++++++--- gematria/basic_block/python/basic_block.cc | 12 +++++----- gematria/proto/basic_block.proto | 26 +++++++++++----------- 6 files changed, 62 insertions(+), 43 deletions(-) diff --git a/gematria/basic_block/basic_block.cc b/gematria/basic_block/basic_block.cc index f5aebef6..55233a9c 100644 --- a/gematria/basic_block/basic_block.cc +++ b/gematria/basic_block/basic_block.cc @@ -378,16 +378,16 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction) { } BasicBlock::BasicBlock(std::vector instructions, - std::vector back_context, - std::vector front_context) + std::vector preceding_context, + std::vector following_context) : instructions(std::move(instructions)), - back_context(std::move(back_context)), - front_context(std::move(front_context)) {} + preceding_context(std::move(preceding_context)), + following_context(std::move(following_context)) {} bool BasicBlock::operator==(const BasicBlock& other) const { return instructions == other.instructions && - back_context == other.back_context && - front_context == other.front_context; + preceding_context == other.preceding_context && + following_context == other.following_context; } std::string BasicBlock::ToString() const { @@ -400,19 +400,21 @@ std::string BasicBlock::ToString() const { } if (buffer.back() == ' ') buffer.pop_back(); buffer += "))"; + if (!preceding_context.empty()) buffer += ", "; } - if 
(!back_context.empty()) { - buffer += "back_context=InstructionList(("; - for (const Instruction& instruction : back_context) { + if (!preceding_context.empty()) { + buffer += "preceding_context=InstructionList(("; + for (const Instruction& instruction : preceding_context) { buffer += instruction.ToString(); buffer += ", "; } if (buffer.back() == ' ') buffer.pop_back(); buffer += "))"; + if (!following_context.empty()) buffer += ", "; } - if (!front_context.empty()) { - buffer += "front_context=InstructionList(("; - for (const Instruction& instruction : front_context) { + if (!following_context.empty()) { + buffer += "following_context=InstructionList(("; + for (const Instruction& instruction : following_context) { buffer += instruction.ToString(); buffer += ", "; } diff --git a/gematria/basic_block/basic_block.h b/gematria/basic_block/basic_block.h index 735aa3d1..6a8c01bf 100644 --- a/gematria/basic_block/basic_block.h +++ b/gematria/basic_block/basic_block.h @@ -328,8 +328,8 @@ struct BasicBlock { // context. Needed for compatibility with the Python code. explicit BasicBlock( std::vector instructions, - std::vector back_context = std::vector(), - std::vector front_context = std::vector()); + std::vector preceding_context = std::vector(), + std::vector following_context = std::vector()); BasicBlock(const BasicBlock&) = default; BasicBlock(BasicBlock&&) = default; @@ -350,10 +350,10 @@ struct BasicBlock { // The list of instructions in the basic block. std::vector instructions; - // The back and front context instructions, i.e. those preceeding and + // The preceding and following context instructions, i.e. those preceding and // following the instructions in the basic block. 
- std::vector back_context; - std::vector front_context; + std::vector preceding_context; + std::vector following_context; }; std::ostream& operator<<(std::ostream& os, const BasicBlock& block); diff --git a/gematria/basic_block/basic_block_protos.cc b/gematria/basic_block/basic_block_protos.cc index d327901b..b6e441ce 100644 --- a/gematria/basic_block/basic_block_protos.cc +++ b/gematria/basic_block/basic_block_protos.cc @@ -183,11 +183,11 @@ BasicBlock BasicBlockFromProto(const BasicBlockProto& proto) { /* instructions = */ ToVector(proto.canonicalized_instructions(), InstructionFromProto), - /* back_context = */ - ToVector(proto.canonicalized_back_context(), + /* preceding_context = */ + ToVector(proto.canonicalized_preceding_context(), InstructionFromProto), - /* front_context = */ - ToVector(proto.canonicalized_front_context(), + /* following_context = */ + ToVector(proto.canonicalized_following_context(), InstructionFromProto)); } diff --git a/gematria/basic_block/basic_block_test.cc b/gematria/basic_block/basic_block_test.cc index 348523d0..68bb6b72 100644 --- a/gematria/basic_block/basic_block_test.cc +++ b/gematria/basic_block/basic_block_test.cc @@ -596,7 +596,7 @@ TEST(BasicBlockTest, ToString) { /* instruction_annotations = */ {Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)}); - BasicBlock block({instruction}); + BasicBlock block({instruction}, {instruction}, {instruction}); constexpr char kExpectedString[] = "BasicBlock(instructions=InstructionList((Instruction(mnemonic='ADC', " "llvm_mnemonic='ADC32rr', prefixes=('LOCK',), " @@ -606,8 +606,25 @@ TEST(BasicBlockTest, ToString) { "output_operands=(InstructionOperand.from_register('RAX'),), " "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), " "instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', " - "value=0.875),))," - ")))"; + "value=0.875),)),)), " + "preceding_context=InstructionList((Instruction(mnemonic='ADC', " + "llvm_mnemonic='ADC32rr', prefixes=('LOCK',), 
" + "input_operands=(InstructionOperand.from_register('RAX'), " + "InstructionOperand.from_register('RBX'),), " + "implicit_input_operands=(InstructionOperand.from_register('EFLAGS'),), " + "output_operands=(InstructionOperand.from_register('RAX'),), " + "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), " + "instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', " + "value=0.875),)),)), " + "following_context=InstructionList((Instruction(mnemonic='ADC', " + "llvm_mnemonic='ADC32rr', prefixes=('LOCK',), " + "input_operands=(InstructionOperand.from_register('RAX'), " + "InstructionOperand.from_register('RBX'),), " + "implicit_input_operands=(InstructionOperand.from_register('EFLAGS'),), " + "output_operands=(InstructionOperand.from_register('RAX'),), " + "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), " + "instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', " + "value=0.875),)),)))"; EXPECT_EQ(block.ToString(), kExpectedString); } diff --git a/gematria/basic_block/python/basic_block.cc b/gematria/basic_block/python/basic_block.cc index b583dabb..863eb847 100644 --- a/gematria/basic_block/python/basic_block.cc +++ b/gematria/basic_block/python/basic_block.cc @@ -256,14 +256,14 @@ PYBIND11_MODULE(basic_block, m) { py::class_ basic_block(m, "BasicBlock"); basic_block .def(py::init /* instructions */, - std::vector /* back_context */, - std::vector /* front_context */>(), + std::vector /* preceding_context */, + std::vector /* following_context */>(), py::arg("instructions") = std::vector(), - py::arg("back_context") = std::vector(), - py::arg("front_context") = std::vector()) + py::arg("preceding_context") = std::vector(), + py::arg("following_context") = std::vector()) .def_readwrite("instructions", &BasicBlock::instructions) - .def_readwrite("back_context", &BasicBlock::back_context) - .def_readwrite("front_context", &BasicBlock::front_context) + .def_readwrite("preceding_context", 
&BasicBlock::preceding_context) + .def_readwrite("following_context", &BasicBlock::following_context) .def("__repr__", &BasicBlock::ToString) .def("__str__", &BasicBlock::ToString) .def("__eq__", &BasicBlock::operator==) diff --git a/gematria/proto/basic_block.proto b/gematria/proto/basic_block.proto index bbcd57ae..20528791 100644 --- a/gematria/proto/basic_block.proto +++ b/gematria/proto/basic_block.proto @@ -30,26 +30,26 @@ message BasicBlockProto { // same instruction. repeated CanonicalizedInstructionProto canonicalized_instructions = 2; + // The fingerprint-id of this basic block. Might be empty. + string fingerprint = 3; + // An optional list of machine instructions preceding the basic block, used - // to provide context that lies before `canonicalized_instructions`. These + // to provide context that precedes `canonicalized_instructions`. These // instructions are not included in the timing measurements and predictions. - repeated MachineInstructionProto machine_back_context = 3; + repeated MachineInstructionProto machine_preceding_context = 4; // An optional list of machine instructions following the basic block, used - // to provide context lying after `canonicalized_instructions`. These + // to provide context following `canonicalized_instructions`. These // instructions are not included in the timing measurements and predictions. - repeated MachineInstructionProto machine_front_context = 4; + repeated MachineInstructionProto machine_following_context = 5; - // Canonicalized instructions parallel to `machine_back_context`. May be - // empty in case no back context is provided. - repeated CanonicalizedInstructionProto canonicalized_back_context = 5; + // Canonicalized instructions parallel to `machine_preceding_context`. May be + // empty in case no preceding context is provided. + repeated CanonicalizedInstructionProto canonicalized_preceding_context = 6; - // Canonicalized instructions parallel to `machine_front_context`. 
May be - // empty in case no front context is provided. - repeated CanonicalizedInstructionProto canonicalized_front_context = 6; - - // The fingerprint-id of this basic block. Might be empty. - string fingerprint = 7; + // Canonicalized instructions parallel to `machine_following_context`. May be + // empty in case no following context is provided. + repeated CanonicalizedInstructionProto canonicalized_following_context = 7; } // Represents a raw instruction extracted from binary code. From ec7dc81958b58c1be2e8b7fe59a6681dc0d2b8e9 Mon Sep 17 00:00:00 2001 From: Viraj Shah Date: Mon, 20 Jan 2025 02:58:04 +0530 Subject: [PATCH 4/6] Add support for adding context instructions to basic block graphs. * Update the graph builder and its Python bindings to add context instructions to basic block graphs and store context node mask to later be used by models. * Add tests for the new graph builder functionality. --- gematria/granite/graph_builder.cc | 110 +++++++++++++---------- gematria/granite/graph_builder.h | 30 +++++-- gematria/granite/graph_builder_test.cc | 96 ++++++++++++++++++++ gematria/granite/python/BUILD.bazel | 1 + gematria/granite/python/graph_builder.cc | 10 ++- 5 files changed, 189 insertions(+), 58 deletions(-) diff --git a/gematria/granite/graph_builder.cc b/gematria/granite/graph_builder.cc index 30f6c4ab..f4f6098d 100644 --- a/gematria/granite/graph_builder.cc +++ b/gematria/granite/graph_builder.cc @@ -190,7 +190,9 @@ BasicBlockGraphBuilder::BasicBlockGraphBuilder( } bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( - const std::vector& instructions) { + const std::vector& instructions, + const std::vector& back_context, + const std::vector& front_context) { if (instructions.empty()) return false; AddBasicBlockTransaction transaction(this); @@ -202,58 +204,65 @@ bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( const int prev_num_edges = num_edges(); NodeIndex previous_instruction_node = kInvalidNode; - for (const Instruction& 
instruction : instructions) { - // Add the instruction node. - const NodeIndex instruction_node = - AddNode(NodeType::kInstruction, instruction.mnemonic); - if (instruction_node == kInvalidNode) { - return false; - } + const struct { + const std::vector& instruction_group; + bool is_context; + } instruction_groups[] = { + {back_context, true}, {instructions, false}, {front_context, true}}; + for (const auto [instruction_group, is_context] : instruction_groups) { + for (const Instruction& instruction : instruction_group) { + // Add the instruction node. + const NodeIndex instruction_node = + AddNode(NodeType::kInstruction, instruction.mnemonic, is_context); + if (instruction_node == kInvalidNode) { + return false; + } - // Store the annotations for later use (inclusion in embeddings), using -1 - // as a default value wherever annotations are missing. - std::vector row = std::vector(annotation_names_.size(), -1); - for (const auto& [name, value] : instruction.instruction_annotations) { - const auto annotation_index = annotation_name_to_idx_.find(name); - if (annotation_index == annotation_name_to_idx_.end()) continue; - row[annotation_index->second] = value; - } - instruction_annotations_.push_back(row); + // Store the annotations for later use (inclusion in embeddings), using -1 + // as a default value wherever annotations are missing. + std::vector row = std::vector(annotation_names_.size(), -1); + for (const auto& [name, value] : instruction.instruction_annotations) { + const auto annotation_index = annotation_name_to_idx_.find(name); + if (annotation_index == annotation_name_to_idx_.end()) continue; + row[annotation_index->second] = value; + } + instruction_annotations_.push_back(row); - // Add nodes for prefixes of the instruction. - for (const std::string& prefix : instruction.prefixes) { - const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix); - if (prefix_node == kInvalidNode) { - return false; + // Add nodes for prefixes of the instruction. 
+ for (const std::string& prefix : instruction.prefixes) { + const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix); + if (prefix_node == kInvalidNode) { + return false; + } + AddEdge(EdgeType::kInstructionPrefix, prefix_node, instruction_node); } - AddEdge(EdgeType::kInstructionPrefix, prefix_node, instruction_node); - } - // Add a structural dependency edge from the previous instruction. - if (previous_instruction_node >= 0) { - AddEdge(EdgeType::kStructuralDependency, previous_instruction_node, - instruction_node); - } + // Add a structural dependency edge from the previous instruction. + if (previous_instruction_node >= 0) { + AddEdge(EdgeType::kStructuralDependency, previous_instruction_node, + instruction_node); + } - // Add edges for input operands. And nodes too, if necessary. - for (const InstructionOperand& operand : instruction.input_operands) { - if (!AddInputOperand(instruction_node, operand)) return false; - } - for (const InstructionOperand& operand : - instruction.implicit_input_operands) { - if (!AddInputOperand(instruction_node, operand)) return false; - } + // Add edges for input operands. And nodes too, if necessary. + for (const InstructionOperand& operand : instruction.input_operands) { + if (!AddInputOperand(instruction_node, operand)) return false; + } + for (const InstructionOperand& operand : + instruction.implicit_input_operands) { + if (!AddInputOperand(instruction_node, operand)) return false; + } - // Add edges and nodes for output operands. - for (const InstructionOperand& operand : instruction.output_operands) { - if (!AddOutputOperand(instruction_node, operand)) return false; - } - for (const InstructionOperand& operand : - instruction.implicit_output_operands) { - if (!AddOutputOperand(instruction_node, operand)) return false; - } + // Add edges and nodes for output operands. 
+ for (const InstructionOperand& operand : instruction.output_operands) { + if (!AddOutputOperand(instruction_node, operand)) return false; + } + for (const InstructionOperand& operand : + instruction.implicit_output_operands) { + if (!AddOutputOperand(instruction_node, operand)) return false; + } - previous_instruction_node = instruction_node; + previous_instruction_node = instruction_node; + } } global_features_.emplace_back(num_node_tokens(), 0); @@ -276,6 +285,7 @@ void BasicBlockGraphBuilder::Reset() { node_types_.clear(); node_features_.clear(); + context_node_mask_.clear(); edge_senders_.clear(); edge_receivers_.clear(); @@ -404,15 +414,16 @@ bool BasicBlockGraphBuilder::AddDependencyOnRegister( } BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode( - NodeType node_type, TokenIndex token_index) { + NodeType node_type, TokenIndex token_index, bool is_context) { const NodeIndex new_node_index = num_nodes(); node_types_.push_back(node_type); node_features_.push_back(token_index); + context_node_mask_.push_back(is_context); return new_node_index; } BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode( - NodeType node_type, const std::string& token) { + NodeType node_type, const std::string& token, bool is_context) { const auto it = node_tokens_.find(token); TokenIndex token_index = kInvalidTokenIndex; if (it != node_tokens_.end()) { @@ -427,7 +438,7 @@ BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode( token_index = replacement_token_; } } - return AddNode(node_type, token_index); + return AddNode(node_type, token_index, is_context); } void BasicBlockGraphBuilder::AddEdge(EdgeType edge_type, NodeIndex sender, @@ -505,6 +516,7 @@ std::string BasicBlockGraphBuilder::DebugString() const { StrAppendList(buffer, "num_nodes_per_block", num_nodes_per_block()); StrAppendList(buffer, "num_edges_per_block", num_edges_per_block()); StrAppendList(buffer, "node_types", node_types()); + StrAppendList(buffer, "context_node_mask", 
context_node_mask()); StrAppendList(buffer, "edge_senders", edge_senders()); StrAppendList(buffer, "edge_receivers", edge_receivers()); StrAppendList(buffer, "edge_types", edge_types()); diff --git a/gematria/granite/graph_builder.h b/gematria/granite/graph_builder.h index c4684eea..fcb69cd3 100644 --- a/gematria/granite/graph_builder.h +++ b/gematria/granite/graph_builder.h @@ -187,14 +187,23 @@ class BasicBlockGraphBuilder { // method encountered an unknown token and the unknown token behavior is not // kReplaceToken or when the basic block does not contain any instructions. // When this happens, the graph builder is left in the previous state, i.e. no - // basic block is added to it. - bool AddBasicBlock(const BasicBlock& block) { + // basic block is added to it. The basic block context is added to the graph + // if and only if `add_context` is true. + bool AddBasicBlock(const BasicBlock& block, bool add_context = false) { + if (add_context) { + return AddBasicBlockFromInstructions( + block.instructions, block.back_context, block.front_context); + } return AddBasicBlockFromInstructions(block.instructions); } // A version of AddBasicBlock that takes the list of instructions in the basic - // block instead of the basic block object itself. + // block and optionally its back and front contexts instead of the basic block + // object itself. bool AddBasicBlockFromInstructions( - const std::vector& instructions); + const std::vector& instructions, + const std::vector& back_context = std::vector(), + const std::vector& front_context = + std::vector()); // Resets the graph builder so that it can be used to create a new graph from // scratch. @@ -242,6 +251,12 @@ class BasicBlockGraphBuilder { // Feature value of the nodes in the batch (i.e. the indices of the tokens // corresponding to the nodes). Corresponds to `GraphsTuple.nodes`. 
const std::vector& node_features() const { return node_features_; } + // Whether or not the corresponding node belongs to either the back or front + // context of the basic block, and not the basic block itself. Used by the + // models to exclude context nodes from predictions. + const std::vector& context_node_mask() const { + return context_node_mask_; + } // Names of types of instruction annotations stored. const std::vector& annotation_names() const { @@ -375,11 +390,13 @@ class BasicBlockGraphBuilder { // Adds a new node to the batch; the feature of the node is given directly by // the caller. - NodeIndex AddNode(NodeType node_type, TokenIndex token_index); + NodeIndex AddNode(NodeType node_type, TokenIndex token_index, + bool is_context = false); // Adds a new edge to the batch; the feature of the node is determined from // the token associated with the node. Returns kInvalidNode when the node was // not added. - NodeIndex AddNode(NodeType node_type, const std::string& token); + NodeIndex AddNode(NodeType node_type, const std::string& token, + bool is_context = false); // Adds a new edge to the batch. void AddEdge(EdgeType edge_type, NodeIndex sender, NodeIndex receiver); @@ -406,6 +423,7 @@ class BasicBlockGraphBuilder { std::vector node_types_; std::vector node_features_; + std::vector context_node_mask_; // Mapping from annotation type names to corresponding row index in the // `instruction_annotations_` matrix. diff --git a/gematria/granite/graph_builder_test.cc b/gematria/granite/graph_builder_test.cc index 82cdbca4..30ebaa00 100644 --- a/gematria/granite/graph_builder_test.cc +++ b/gematria/granite/graph_builder_test.cc @@ -33,8 +33,11 @@ namespace gematria { namespace { +using ::testing::_; +using ::testing::Each; using ::testing::ElementsAre; using ::testing::IsEmpty; +using ::testing::IsFalse; using ::testing::Pair; // Tokens used in the basic blocks in tests. 
For simplicity, we do not use the @@ -490,6 +493,7 @@ TEST_F(BasicBlockGraphBuilderTest, MultipleInstructions) { true, false, // Fourth instruction. true, false)); + EXPECT_THAT(builder_->context_node_mask(), Each(IsFalse())); EXPECT_THAT( builder_->edge_types(), @@ -571,5 +575,97 @@ TEST_F(BasicBlockGraphBuilderTest, TwoNops) { )pb")))); } +// Tests that the graph is built and context node mask is set correctly when +// context is supplied. +TEST_F(BasicBlockGraphBuilderTest, MultipleBasicBlocksWithContext) { + CreateBuilder(OutOfVocabularyTokenBehavior::ReturnError()); + ASSERT_TRUE( + builder_->AddBasicBlock(BasicBlockFromProto(ParseTextProto(R"pb( + canonicalized_instructions: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + } + canonicalized_back_context: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + } + canonicalized_front_context: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + })pb")), + true)); + ASSERT_TRUE( + builder_->AddBasicBlock(BasicBlockFromProto(ParseTextProto(R"pb( + canonicalized_instructions: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + } + canonicalized_back_context: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + })pb")), + true)); + + EXPECT_EQ(builder_->num_graphs(), 2); + EXPECT_EQ(builder_->num_node_tokens(), std::size(kTokens)); + + EXPECT_EQ(builder_->num_nodes(), 7 + 5); + EXPECT_THAT(builder_->num_nodes_per_block(), ElementsAre(7, 5)); + + EXPECT_EQ(builder_->num_edges(), 8 + 5); + EXPECT_THAT(builder_->num_edges_per_block(), ElementsAre(8, 5)); + + EXPECT_THAT(builder_->node_types(), + 
ElementsAre(NodeType::kInstruction, NodeType::kRegister, + NodeType::kRegister, NodeType::kInstruction, + NodeType::kRegister, NodeType::kInstruction, + NodeType::kRegister, NodeType::kInstruction, + NodeType::kRegister, NodeType::kRegister, + NodeType::kInstruction, NodeType::kRegister)); + EXPECT_THAT( + builder_->node_features(), + ElementsAre(TokenIndex("NOT"), TokenIndex("RCX"), TokenIndex("RCX"), + TokenIndex("NOT"), TokenIndex("RCX"), TokenIndex("NOT"), + TokenIndex("RCX"), TokenIndex("NOT"), TokenIndex("RCX"), + TokenIndex("RCX"), TokenIndex("NOT"), TokenIndex("RCX"))); + EXPECT_THAT(builder_->InstructionNodeMask(), + ElementsAre(true, false, false, true, false, true, false, true, + false, false, true, false)); + EXPECT_THAT(builder_->context_node_mask(), + ElementsAre(true, _, _, false, _, true, _, true, _, _, false, _)); + + EXPECT_THAT( + builder_->edge_types(), + ElementsAre(EdgeType::kInputOperands, EdgeType::kOutputOperands, + EdgeType::kStructuralDependency, EdgeType::kInputOperands, + EdgeType::kOutputOperands, EdgeType::kStructuralDependency, + EdgeType::kInputOperands, EdgeType::kOutputOperands, + EdgeType::kInputOperands, EdgeType::kOutputOperands, + EdgeType::kStructuralDependency, EdgeType::kInputOperands, + EdgeType::kOutputOperands)); + + EXPECT_THAT(builder_->edge_senders(), + ElementsAre(1, 0, 0, 2, 3, 3, 4, 5, 8, 7, 7, 9, 10)); + EXPECT_THAT(builder_->edge_receivers(), + ElementsAre(0, 2, 3, 3, 4, 5, 5, 6, 7, 9, 10, 10, 11)); + + EXPECT_THAT( + builder_->global_features(), + ElementsAre(ElementsAre(0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 4, 0, 0, 0, 0), + ElementsAre(0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0))); + + EXPECT_THAT(builder_->DeltaBlockIndex(), ElementsAre(0, 0, 0, 1, 1)); +} + } // namespace } // namespace gematria diff --git a/gematria/granite/python/BUILD.bazel b/gematria/granite/python/BUILD.bazel index ec928b24..a46e812c 100644 --- a/gematria/granite/python/BUILD.bazel +++ b/gematria/granite/python/BUILD.bazel @@ -48,6 +48,7 
@@ gematria_pybind_extension( srcs = ["graph_builder.cc"], visibility = ["//:internal_users"], deps = [ + "//gematria/basic_block", "//gematria/granite:graph_builder", "//gematria/model:oov_token_behavior", "//gematria/proto:canonicalized_instruction_cc_proto", diff --git a/gematria/granite/python/graph_builder.cc b/gematria/granite/python/graph_builder.cc index c238e2d3..74cba048 100644 --- a/gematria/granite/python/graph_builder.cc +++ b/gematria/granite/python/graph_builder.cc @@ -14,11 +14,11 @@ #include "gematria/granite/graph_builder.h" -#include #include #include #include "absl/strings/string_view.h" +#include "gematria/basic_block/basic_block.h" #include "gematria/model/oov_token_behavior.h" #include "gematria/proto/canonicalized_instruction.pb.h" #include "pybind11/cast.h" @@ -81,10 +81,12 @@ PYBIND11_MODULE(graph_builder, m) { py::arg("annotation_names") = std::vector(), py::arg("out_of_vocabulary_behavior")) .def("add_basic_block", &BasicBlockGraphBuilder::AddBasicBlock, - py::arg("block")) + py::arg("block"), py::arg("add_context") = false) .def("add_basic_block_from_instructions", &BasicBlockGraphBuilder::AddBasicBlockFromInstructions, - py::arg("instructions")) + py::arg("instructions"), + py::arg("back_context") = std::vector(), + py::arg("front_context") = std::vector()) .def("reset", &BasicBlockGraphBuilder::Reset) .def_property_readonly("num_node_tokens", &BasicBlockGraphBuilder::num_node_tokens) @@ -99,6 +101,8 @@ PYBIND11_MODULE(graph_builder, m) { &BasicBlockGraphBuilder::node_features) .def_property_readonly("instruction_node_mask", &BasicBlockGraphBuilder::InstructionNodeMask) + .def_property_readonly("context_node_mask", + &BasicBlockGraphBuilder::context_node_mask) .def_property_readonly("annotation_names", &BasicBlockGraphBuilder::annotation_names) .def_property_readonly("instruction_annotations", From 29fb29d020771066503867c079408d1dac2b6897 Mon Sep 17 00:00:00 2001 From: Viraj Shah Date: Sun, 9 Feb 2025 01:10:58 +0530 Subject: [PATCH 
5/6] Address reviewer suggestions. * Change `back_context` -> `preceding_context` and `front_context` -> `following_context`. * Set `context_node_mask_` appropriately for non-instruction nodes as well, and update tests to reflect this. --- gematria/granite/graph_builder.cc | 56 +++++++++++++++--------- gematria/granite/graph_builder.h | 20 ++++----- gematria/granite/graph_builder_test.cc | 9 ++-- gematria/granite/python/graph_builder.cc | 4 +- 4 files changed, 53 insertions(+), 36 deletions(-) diff --git a/gematria/granite/graph_builder.cc b/gematria/granite/graph_builder.cc index f4f6098d..2c54e952 100644 --- a/gematria/granite/graph_builder.cc +++ b/gematria/granite/graph_builder.cc @@ -191,8 +191,8 @@ BasicBlockGraphBuilder::BasicBlockGraphBuilder( bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( const std::vector& instructions, - const std::vector& back_context, - const std::vector& front_context) { + const std::vector& preceding_context, + const std::vector& following_context) { if (instructions.empty()) return false; AddBasicBlockTransaction transaction(this); @@ -207,9 +207,10 @@ bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( const struct { const std::vector& instruction_group; bool is_context; - } instruction_groups[] = { - {back_context, true}, {instructions, false}, {front_context, true}}; - for (const auto [instruction_group, is_context] : instruction_groups) { + } instruction_groups[] = {{preceding_context, true}, + {instructions, false}, + {following_context, true}}; + for (const auto& [instruction_group, is_context] : instruction_groups) { for (const Instruction& instruction : instruction_group) { // Add the instruction node. 
const NodeIndex instruction_node = @@ -301,49 +302,53 @@ bool BasicBlockGraphBuilder::AddInputOperand( assert(instruction_node >= 0); assert(instruction_node < num_nodes()); + bool is_context = context_node_mask_[instruction_node]; switch (operand.type()) { case OperandType::kRegister: { if (!AddDependencyOnRegister(instruction_node, operand.register_name(), - EdgeType::kInputOperands)) { + EdgeType::kInputOperands, is_context)) { return false; } } break; case OperandType::kImmediateValue: { AddEdge(EdgeType::kInputOperands, - AddNode(NodeType::kImmediate, immediate_token_), + AddNode(NodeType::kImmediate, immediate_token_, is_context), instruction_node); } break; case OperandType::kFpImmediateValue: { AddEdge(EdgeType::kInputOperands, - AddNode(NodeType::kFpImmediate, fp_immediate_token_), + AddNode(NodeType::kFpImmediate, fp_immediate_token_, is_context), instruction_node); } break; case OperandType::kAddress: { const NodeIndex address_node = - AddNode(NodeType::kAddressOperand, address_token_); + AddNode(NodeType::kAddressOperand, address_token_, is_context); const AddressTuple& address_tuple = operand.address(); if (!address_tuple.base_register.empty()) { if (!AddDependencyOnRegister(address_node, address_tuple.base_register, - EdgeType::kAddressBaseRegister)) { + EdgeType::kAddressBaseRegister, + is_context)) { return false; } } if (!address_tuple.index_register.empty()) { if (!AddDependencyOnRegister(address_node, address_tuple.index_register, - EdgeType::kAddressIndexRegister)) { + EdgeType::kAddressIndexRegister, + is_context)) { return false; } } if (!address_tuple.segment_register.empty()) { - if (!AddDependencyOnRegister(address_node, - address_tuple.segment_register, - EdgeType::kAddressSegmentRegister)) { + if (!AddDependencyOnRegister( + address_node, address_tuple.segment_register, + EdgeType::kAddressSegmentRegister, is_context)) { return false; } } if (address_tuple.displacement != 0) { AddEdge(EdgeType::kAddressDisplacement, - 
AddNode(NodeType::kImmediate, immediate_token_), address_node); + AddNode(NodeType::kImmediate, immediate_token_, is_context), + address_node); } // NOTE(ondrasej): For now, we explicitly ignore the scaling. AddEdge(EdgeType::kInputOperands, address_node, instruction_node); @@ -352,7 +357,13 @@ bool BasicBlockGraphBuilder::AddInputOperand( NodeIndex& alias_group_node = LookupOrInsert( alias_group_nodes_, operand.alias_group_id(), kInvalidNode); if (alias_group_node == kInvalidNode) { - alias_group_node = AddNode(NodeType::kMemoryOperand, memory_token_); + alias_group_node = + AddNode(NodeType::kMemoryOperand, memory_token_, is_context); + } else if (context_node_mask_[alias_group_node] && !is_context) { + // Update `context_node_mask_` to indicate that `alias_group_node`, + // previously marked as context, is also part of the main block i.e. not + // context. + context_node_mask_[alias_group_node] = false; } AddEdge(EdgeType::kInputOperands, alias_group_node, instruction_node); } break; @@ -369,10 +380,11 @@ bool BasicBlockGraphBuilder::AddOutputOperand( assert(instruction_node >= 0); assert(instruction_node < num_nodes()); + bool is_context = context_node_mask_[instruction_node]; switch (operand.type()) { case OperandType::kRegister: { const NodeIndex register_node = - AddNode(NodeType::kRegister, operand.register_name()); + AddNode(NodeType::kRegister, operand.register_name(), is_context); if (register_node == kInvalidNode) return false; AddEdge(EdgeType::kOutputOperands, instruction_node, register_node); register_nodes_[operand.register_name()] = register_node; @@ -386,7 +398,7 @@ bool BasicBlockGraphBuilder::AddOutputOperand( break; case OperandType::kMemory: { const NodeIndex alias_group_node = - AddNode(NodeType::kMemoryOperand, memory_token_); + AddNode(NodeType::kMemoryOperand, memory_token_, is_context); alias_group_nodes_[operand.alias_group_id()] = alias_group_node; AddEdge(EdgeType::kOutputOperands, instruction_node, alias_group_node); } break; @@ 
-400,13 +412,17 @@ bool BasicBlockGraphBuilder::AddOutputOperand( bool BasicBlockGraphBuilder::AddDependencyOnRegister( NodeIndex dependent_node, const std::string& register_name, - EdgeType edge_type) { + EdgeType edge_type, bool is_context) { NodeIndex& operand_node = LookupOrInsert(register_nodes_, register_name, kInvalidNode); if (operand_node == kInvalidNode) { // Add a node for the register if it doesn't exist. This also updates the // node index in `node_by_register`. - operand_node = AddNode(NodeType::kRegister, register_name); + operand_node = AddNode(NodeType::kRegister, register_name, is_context); + } else if (context_node_mask_[operand_node] && !is_context) { + // Update `context_node_mask_` to indicate that `operand_node`, previously + // marked as context, is also part of the main block i.e. not context. + context_node_mask_[operand_node] = false; } if (operand_node == kInvalidNode) return false; AddEdge(edge_type, operand_node, dependent_node); diff --git a/gematria/granite/graph_builder.h b/gematria/granite/graph_builder.h index fcb69cd3..4025daca 100644 --- a/gematria/granite/graph_builder.h +++ b/gematria/granite/graph_builder.h @@ -91,7 +91,6 @@ #include #include -#include #include #include #include @@ -192,17 +191,18 @@ class BasicBlockGraphBuilder { bool AddBasicBlock(const BasicBlock& block, bool add_context = false) { if (add_context) { return AddBasicBlockFromInstructions( - block.instructions, block.back_context, block.front_context); + block.instructions, block.preceding_context, block.following_context); } return AddBasicBlockFromInstructions(block.instructions); } // A version of AddBasicBlock that takes the list of instructions in the basic - // block and optionally its back and front contexts instead of the basic block - // object itself. + // block and optionally its preceding and following contexts instead of the + // basic block object itself. 
bool AddBasicBlockFromInstructions( const std::vector& instructions, - const std::vector& back_context = std::vector(), - const std::vector& front_context = + const std::vector& preceding_context = + std::vector(), + const std::vector& following_context = std::vector()); // Resets the graph builder so that it can be used to create a new graph from @@ -251,9 +251,9 @@ class BasicBlockGraphBuilder { // Feature value of the nodes in the batch (i.e. the indices of the tokens // corresponding to the nodes). Corresponds to `GraphsTuple.nodes`. const std::vector& node_features() const { return node_features_; } - // Whether or not the corresponding node belongs to either the back or front - // context of the basic block, and not the basic block itself. Used by the - // models to exclude context nodes from predictions. + // Whether or not the corresponding node belongs to either the preceding or + // following context of the basic block, and not the basic block itself. Used + // by the models to exclude context nodes from predictions. const std::vector& context_node_mask() const { return context_node_mask_; } @@ -386,7 +386,7 @@ class BasicBlockGraphBuilder { // a register. Adds the register node if it doesn't exist in the graph. bool AddDependencyOnRegister(NodeIndex dependent_node, const std::string& register_name, - EdgeType edge_type); + EdgeType edge_type, bool is_context = false); // Adds a new node to the batch; the feature of the node is given directly by // the caller. 
diff --git a/gematria/granite/graph_builder_test.cc b/gematria/granite/graph_builder_test.cc index 30ebaa00..68ac52f1 100644 --- a/gematria/granite/graph_builder_test.cc +++ b/gematria/granite/graph_builder_test.cc @@ -587,13 +587,13 @@ TEST_F(BasicBlockGraphBuilderTest, MultipleBasicBlocksWithContext) { output_operands: { register_name: "RCX" } input_operands: { register_name: "RCX" } } - canonicalized_back_context: { + canonicalized_preceding_context: { mnemonic: "NOT" llvm_mnemonic: "NOT64r" output_operands: { register_name: "RCX" } input_operands: { register_name: "RCX" } } - canonicalized_front_context: { + canonicalized_following_context: { mnemonic: "NOT" llvm_mnemonic: "NOT64r" output_operands: { register_name: "RCX" } @@ -608,7 +608,7 @@ TEST_F(BasicBlockGraphBuilderTest, MultipleBasicBlocksWithContext) { output_operands: { register_name: "RCX" } input_operands: { register_name: "RCX" } } - canonicalized_back_context: { + canonicalized_preceding_context: { mnemonic: "NOT" llvm_mnemonic: "NOT64r" output_operands: { register_name: "RCX" } @@ -642,7 +642,8 @@ TEST_F(BasicBlockGraphBuilderTest, MultipleBasicBlocksWithContext) { ElementsAre(true, false, false, true, false, true, false, true, false, false, true, false)); EXPECT_THAT(builder_->context_node_mask(), - ElementsAre(true, _, _, false, _, true, _, true, _, _, false, _)); + ElementsAre(true, true, false, false, false, true, true, true, + true, false, false, false)); EXPECT_THAT( builder_->edge_types(), diff --git a/gematria/granite/python/graph_builder.cc b/gematria/granite/python/graph_builder.cc index 74cba048..c90e9d63 100644 --- a/gematria/granite/python/graph_builder.cc +++ b/gematria/granite/python/graph_builder.cc @@ -85,8 +85,8 @@ PYBIND11_MODULE(graph_builder, m) { .def("add_basic_block_from_instructions", &BasicBlockGraphBuilder::AddBasicBlockFromInstructions, py::arg("instructions"), - py::arg("back_context") = std::vector(), - py::arg("front_context") = std::vector()) + 
py::arg("preceding_context") = std::vector(), + py::arg("following_context") = std::vector()) .def("reset", &BasicBlockGraphBuilder::Reset) .def_property_readonly("num_node_tokens", &BasicBlockGraphBuilder::num_node_tokens) From aa1fe798a75eac379fbe87c68755e64652b56a5f Mon Sep 17 00:00:00 2001 From: Viraj Shah Date: Mon, 9 Jun 2025 19:35:52 +0530 Subject: [PATCH 6/6] Address reviews. --- gematria/granite/graph_builder.cc | 24 ++++++++++++++---------- gematria/granite/graph_builder.h | 6 ++++-- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/gematria/granite/graph_builder.cc b/gematria/granite/graph_builder.cc index 2c54e952..6a0c913c 100644 --- a/gematria/granite/graph_builder.cc +++ b/gematria/granite/graph_builder.cc @@ -246,20 +246,24 @@ bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( // Add edges for input operands. And nodes too, if necessary. for (const InstructionOperand& operand : instruction.input_operands) { - if (!AddInputOperand(instruction_node, operand)) return false; + if (!AddInputOperand(instruction_node, operand, is_context)) + return false; } for (const InstructionOperand& operand : instruction.implicit_input_operands) { - if (!AddInputOperand(instruction_node, operand)) return false; + if (!AddInputOperand(instruction_node, operand, is_context)) + return false; } // Add edges and nodes for output operands. 
for (const InstructionOperand& operand : instruction.output_operands) { - if (!AddOutputOperand(instruction_node, operand)) return false; + if (!AddOutputOperand(instruction_node, operand, is_context)) + return false; } for (const InstructionOperand& operand : instruction.implicit_output_operands) { - if (!AddOutputOperand(instruction_node, operand)) return false; + if (!AddOutputOperand(instruction_node, operand, is_context)) + return false; } previous_instruction_node = instruction_node; @@ -297,12 +301,12 @@ void BasicBlockGraphBuilder::Reset() { instruction_annotations_.clear(); } -bool BasicBlockGraphBuilder::AddInputOperand( - NodeIndex instruction_node, const InstructionOperand& operand) { +bool BasicBlockGraphBuilder::AddInputOperand(NodeIndex instruction_node, + const InstructionOperand& operand, + bool is_context) { assert(instruction_node >= 0); assert(instruction_node < num_nodes()); - bool is_context = context_node_mask_[instruction_node]; switch (operand.type()) { case OperandType::kRegister: { if (!AddDependencyOnRegister(instruction_node, operand.register_name(), @@ -375,12 +379,12 @@ bool BasicBlockGraphBuilder::AddInputOperand( return true; } -bool BasicBlockGraphBuilder::AddOutputOperand( - NodeIndex instruction_node, const InstructionOperand& operand) { +bool BasicBlockGraphBuilder::AddOutputOperand(NodeIndex instruction_node, + const InstructionOperand& operand, + bool is_context) { assert(instruction_node >= 0); assert(instruction_node < num_nodes()); - bool is_context = context_node_mask_[instruction_node]; switch (operand.type()) { case OperandType::kRegister: { const NodeIndex register_node = diff --git a/gematria/granite/graph_builder.h b/gematria/granite/graph_builder.h index 4025daca..8f110dd4 100644 --- a/gematria/granite/graph_builder.h +++ b/gematria/granite/graph_builder.h @@ -377,10 +377,12 @@ class BasicBlockGraphBuilder { // Adds nodes and edges for a single input operand of an instruction. 
bool AddInputOperand(NodeIndex instruction_node, - const InstructionOperand& operand); + const InstructionOperand& operand, + bool is_context = false); // Adds nodes and edges for a single output operand of an instruction. bool AddOutputOperand(NodeIndex instruction_node, - const InstructionOperand& operand); + const InstructionOperand& operand, + bool is_context = false); // Adds dependency of a node (instruction or an address computation node) on // a register. Adds the register node if it doesn't exist in the graph.