diff --git a/gematria/basic_block/basic_block.cc b/gematria/basic_block/basic_block.cc index 466cab16..f5aebef6 100644 --- a/gematria/basic_block/basic_block.cc +++ b/gematria/basic_block/basic_block.cc @@ -377,11 +377,17 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction) { return os; } -BasicBlock::BasicBlock(std::vector instructions) - : instructions(std::move(instructions)) {} +BasicBlock::BasicBlock(std::vector instructions, + std::vector back_context, + std::vector front_context) + : instructions(std::move(instructions)), + back_context(std::move(back_context)), + front_context(std::move(front_context)) {} bool BasicBlock::operator==(const BasicBlock& other) const { - return instructions == other.instructions; + return instructions == other.instructions && + back_context == other.back_context && + front_context == other.front_context; } std::string BasicBlock::ToString() const { @@ -395,6 +401,24 @@ std::string BasicBlock::ToString() const { if (buffer.back() == ' ') buffer.pop_back(); buffer += "))"; } + if (!back_context.empty()) { + buffer += "back_context=InstructionList(("; + for (const Instruction& instruction : back_context) { + buffer += instruction.ToString(); + buffer += ", "; + } + if (buffer.back() == ' ') buffer.pop_back(); + buffer += "))"; + } + if (!front_context.empty()) { + buffer += "front_context=InstructionList(("; + for (const Instruction& instruction : front_context) { + buffer += instruction.ToString(); + buffer += ", "; + } + if (buffer.back() == ' ') buffer.pop_back(); + buffer += "))"; + } buffer.push_back(')'); return buffer; } diff --git a/gematria/basic_block/basic_block.h b/gematria/basic_block/basic_block.h index 43579fb7..735aa3d1 100644 --- a/gematria/basic_block/basic_block.h +++ b/gematria/basic_block/basic_block.h @@ -221,7 +221,7 @@ std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand); // Represents an annotation holding a value such as some measure/statistic // paired 
with the instruction. struct Annotation { - Annotation() : value(-1){}; + Annotation() : value(-1) {}; // Initializes all fields of the annotation. Annotation(std::string name, double value); @@ -324,9 +324,12 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instruction); struct BasicBlock { BasicBlock() {} - // Initializes the basic block from a list of instructions. Needed for - // compatibility with the Python code. - explicit BasicBlock(std::vector instructions); + // Initializes the basic block from a list of instructions and optional + // context. Needed for compatibility with the Python code. + explicit BasicBlock( + std::vector instructions, + std::vector back_context = std::vector(), + std::vector front_context = std::vector()); BasicBlock(const BasicBlock&) = default; BasicBlock(BasicBlock&&) = default; @@ -346,6 +349,11 @@ struct BasicBlock { // The list of instructions in the basic block. std::vector instructions; + + // The back and front context instructions, i.e. those preceding and + // following the instructions in the basic block. 
+ std::vector back_context; + std::vector front_context; }; std::ostream& operator<<(std::ostream& os, const BasicBlock& block); diff --git a/gematria/basic_block/basic_block_protos.cc b/gematria/basic_block/basic_block_protos.cc index 7b8c0a0c..d327901b 100644 --- a/gematria/basic_block/basic_block_protos.cc +++ b/gematria/basic_block/basic_block_protos.cc @@ -180,8 +180,15 @@ CanonicalizedInstructionProto ProtoFromInstruction( BasicBlock BasicBlockFromProto(const BasicBlockProto& proto) { return BasicBlock( - /* instructions = */ ToVector( - proto.canonicalized_instructions(), InstructionFromProto)); + /* instructions = */ + ToVector(proto.canonicalized_instructions(), + InstructionFromProto), + /* back_context = */ + ToVector(proto.canonicalized_back_context(), + InstructionFromProto), + /* front_context = */ + ToVector(proto.canonicalized_front_context(), + InstructionFromProto)); } } // namespace gematria diff --git a/gematria/basic_block/basic_block_test.cc b/gematria/basic_block/basic_block_test.cc index 3977c560..348523d0 100644 --- a/gematria/basic_block/basic_block_test.cc +++ b/gematria/basic_block/basic_block_test.cc @@ -277,7 +277,7 @@ TEST(InstructionOperandTest, Equality) { TEST(InstructionOperandTest, ToString) { const struct { - InstructionOperand opernad; + InstructionOperand operand; const char* expected_string; } kTestCases[] = { {InstructionOperand::Register("RAX"), @@ -292,7 +292,7 @@ TEST(InstructionOperandTest, ToString) { "InstructionOperand.from_memory(32)"}}; for (const auto& test_case : kTestCases) { - EXPECT_EQ(test_case.opernad.ToString(), test_case.expected_string); + EXPECT_EQ(test_case.operand.ToString(), test_case.expected_string); } } @@ -318,7 +318,6 @@ TEST(InstructionOperandTest, AsTokenList) { } } -// TODO(virajbshah): Add tests for Annotation. 
TEST(AnnotationTest, Constructor) { constexpr char kName[] = "cache_miss_freq"; constexpr double kValue = 0.875; diff --git a/gematria/basic_block/python/basic_block.cc b/gematria/basic_block/python/basic_block.cc index e78c6111..b583dabb 100644 --- a/gematria/basic_block/python/basic_block.cc +++ b/gematria/basic_block/python/basic_block.cc @@ -255,9 +255,15 @@ PYBIND11_MODULE(basic_block, m) { py::class_ basic_block(m, "BasicBlock"); basic_block - .def(py::init /* instructions */>(), - py::arg("instructions") = std::vector()) + .def(py::init /* instructions */, + std::vector /* back_context */, + std::vector /* front_context */>(), + py::arg("instructions") = std::vector(), + py::arg("back_context") = std::vector(), + py::arg("front_context") = std::vector()) .def_readwrite("instructions", &BasicBlock::instructions) + .def_readwrite("back_context", &BasicBlock::back_context) + .def_readwrite("front_context", &BasicBlock::front_context) .def("__repr__", &BasicBlock::ToString) .def("__str__", &BasicBlock::ToString) .def("__eq__", &BasicBlock::operator==) diff --git a/gematria/granite/graph_builder.cc b/gematria/granite/graph_builder.cc index 30f6c4ab..f4f6098d 100644 --- a/gematria/granite/graph_builder.cc +++ b/gematria/granite/graph_builder.cc @@ -190,7 +190,9 @@ BasicBlockGraphBuilder::BasicBlockGraphBuilder( } bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( - const std::vector& instructions) { + const std::vector& instructions, + const std::vector& back_context, + const std::vector& front_context) { if (instructions.empty()) return false; AddBasicBlockTransaction transaction(this); @@ -202,58 +204,65 @@ bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( const int prev_num_edges = num_edges(); NodeIndex previous_instruction_node = kInvalidNode; - for (const Instruction& instruction : instructions) { - // Add the instruction node. 
- const NodeIndex instruction_node = - AddNode(NodeType::kInstruction, instruction.mnemonic); - if (instruction_node == kInvalidNode) { - return false; - } + const struct { + const std::vector& instruction_group; + bool is_context; + } instruction_groups[] = { + {back_context, true}, {instructions, false}, {front_context, true}}; + for (const auto [instruction_group, is_context] : instruction_groups) { + for (const Instruction& instruction : instruction_group) { + // Add the instruction node. + const NodeIndex instruction_node = + AddNode(NodeType::kInstruction, instruction.mnemonic, is_context); + if (instruction_node == kInvalidNode) { + return false; + } - // Store the annotations for later use (inclusion in embeddings), using -1 - // as a default value wherever annotations are missing. - std::vector row = std::vector(annotation_names_.size(), -1); - for (const auto& [name, value] : instruction.instruction_annotations) { - const auto annotation_index = annotation_name_to_idx_.find(name); - if (annotation_index == annotation_name_to_idx_.end()) continue; - row[annotation_index->second] = value; - } - instruction_annotations_.push_back(row); + // Store the annotations for later use (inclusion in embeddings), using -1 + // as a default value wherever annotations are missing. + std::vector row = std::vector(annotation_names_.size(), -1); + for (const auto& [name, value] : instruction.instruction_annotations) { + const auto annotation_index = annotation_name_to_idx_.find(name); + if (annotation_index == annotation_name_to_idx_.end()) continue; + row[annotation_index->second] = value; + } + instruction_annotations_.push_back(row); - // Add nodes for prefixes of the instruction. - for (const std::string& prefix : instruction.prefixes) { - const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix); - if (prefix_node == kInvalidNode) { - return false; + // Add nodes for prefixes of the instruction. 
+ for (const std::string& prefix : instruction.prefixes) { + const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix); + if (prefix_node == kInvalidNode) { + return false; + } + AddEdge(EdgeType::kInstructionPrefix, prefix_node, instruction_node); } - AddEdge(EdgeType::kInstructionPrefix, prefix_node, instruction_node); - } - // Add a structural dependency edge from the previous instruction. - if (previous_instruction_node >= 0) { - AddEdge(EdgeType::kStructuralDependency, previous_instruction_node, - instruction_node); - } + // Add a structural dependency edge from the previous instruction. + if (previous_instruction_node >= 0) { + AddEdge(EdgeType::kStructuralDependency, previous_instruction_node, + instruction_node); + } - // Add edges for input operands. And nodes too, if necessary. - for (const InstructionOperand& operand : instruction.input_operands) { - if (!AddInputOperand(instruction_node, operand)) return false; - } - for (const InstructionOperand& operand : - instruction.implicit_input_operands) { - if (!AddInputOperand(instruction_node, operand)) return false; - } + // Add edges for input operands. And nodes too, if necessary. + for (const InstructionOperand& operand : instruction.input_operands) { + if (!AddInputOperand(instruction_node, operand)) return false; + } + for (const InstructionOperand& operand : + instruction.implicit_input_operands) { + if (!AddInputOperand(instruction_node, operand)) return false; + } - // Add edges and nodes for output operands. - for (const InstructionOperand& operand : instruction.output_operands) { - if (!AddOutputOperand(instruction_node, operand)) return false; - } - for (const InstructionOperand& operand : - instruction.implicit_output_operands) { - if (!AddOutputOperand(instruction_node, operand)) return false; - } + // Add edges and nodes for output operands. 
+ for (const InstructionOperand& operand : instruction.output_operands) { + if (!AddOutputOperand(instruction_node, operand)) return false; + } + for (const InstructionOperand& operand : + instruction.implicit_output_operands) { + if (!AddOutputOperand(instruction_node, operand)) return false; + } - previous_instruction_node = instruction_node; + previous_instruction_node = instruction_node; + } } global_features_.emplace_back(num_node_tokens(), 0); @@ -276,6 +285,7 @@ void BasicBlockGraphBuilder::Reset() { node_types_.clear(); node_features_.clear(); + context_node_mask_.clear(); edge_senders_.clear(); edge_receivers_.clear(); @@ -404,15 +414,16 @@ bool BasicBlockGraphBuilder::AddDependencyOnRegister( } BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode( - NodeType node_type, TokenIndex token_index) { + NodeType node_type, TokenIndex token_index, bool is_context) { const NodeIndex new_node_index = num_nodes(); node_types_.push_back(node_type); node_features_.push_back(token_index); + context_node_mask_.push_back(is_context); return new_node_index; } BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode( - NodeType node_type, const std::string& token) { + NodeType node_type, const std::string& token, bool is_context) { const auto it = node_tokens_.find(token); TokenIndex token_index = kInvalidTokenIndex; if (it != node_tokens_.end()) { @@ -427,7 +438,7 @@ BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode( token_index = replacement_token_; } } - return AddNode(node_type, token_index); + return AddNode(node_type, token_index, is_context); } void BasicBlockGraphBuilder::AddEdge(EdgeType edge_type, NodeIndex sender, @@ -505,6 +516,7 @@ std::string BasicBlockGraphBuilder::DebugString() const { StrAppendList(buffer, "num_nodes_per_block", num_nodes_per_block()); StrAppendList(buffer, "num_edges_per_block", num_edges_per_block()); StrAppendList(buffer, "node_types", node_types()); + StrAppendList(buffer, "context_node_mask", 
context_node_mask()); StrAppendList(buffer, "edge_senders", edge_senders()); StrAppendList(buffer, "edge_receivers", edge_receivers()); StrAppendList(buffer, "edge_types", edge_types()); diff --git a/gematria/granite/graph_builder.h b/gematria/granite/graph_builder.h index c4684eea..fcb69cd3 100644 --- a/gematria/granite/graph_builder.h +++ b/gematria/granite/graph_builder.h @@ -187,14 +187,23 @@ class BasicBlockGraphBuilder { // method encountered an unknown token and the unknown token behavior is not // kReplaceToken or when the basic block does not contain any instructions. // When this happens, the graph builder is left in the previous state, i.e. no - // basic block is added to it. - bool AddBasicBlock(const BasicBlock& block) { + // basic block is added to it. The basic block context is added to the graph + // if and only if `add_context` is true. + bool AddBasicBlock(const BasicBlock& block, bool add_context = false) { + if (add_context) { + return AddBasicBlockFromInstructions( + block.instructions, block.back_context, block.front_context); + } return AddBasicBlockFromInstructions(block.instructions); } // A version of AddBasicBlock that takes the list of instructions in the basic - // block instead of the basic block object itself. + // block and optionally its back and front contexts instead of the basic block + // object itself. bool AddBasicBlockFromInstructions( - const std::vector& instructions); + const std::vector& instructions, + const std::vector& back_context = std::vector(), + const std::vector& front_context = + std::vector()); // Resets the graph builder so that it can be used to create a new graph from // scratch. @@ -242,6 +251,12 @@ class BasicBlockGraphBuilder { // Feature value of the nodes in the batch (i.e. the indices of the tokens // corresponding to the nodes). Corresponds to `GraphsTuple.nodes`. 
const std::vector& node_features() const { return node_features_; } + // Whether or not the corresponding node belongs to either the back or front + // context of the basic block, and not the basic block itself. Used by the + // models to exclude context nodes from predictions. + const std::vector& context_node_mask() const { + return context_node_mask_; + } // Names of types of instruction annotations stored. const std::vector& annotation_names() const { @@ -375,11 +390,13 @@ class BasicBlockGraphBuilder { // Adds a new node to the batch; the feature of the node is given directly by // the caller. - NodeIndex AddNode(NodeType node_type, TokenIndex token_index); + NodeIndex AddNode(NodeType node_type, TokenIndex token_index, + bool is_context = false); // Adds a new edge to the batch; the feature of the node is determined from // the token associated with the node. Returns kInvalidNode when the node was // not added. - NodeIndex AddNode(NodeType node_type, const std::string& token); + NodeIndex AddNode(NodeType node_type, const std::string& token, + bool is_context = false); // Adds a new edge to the batch. void AddEdge(EdgeType edge_type, NodeIndex sender, NodeIndex receiver); @@ -406,6 +423,7 @@ class BasicBlockGraphBuilder { std::vector node_types_; std::vector node_features_; + std::vector context_node_mask_; // Mapping from annotation type names to corresponding row index in the // `instruction_annotations_` matrix. diff --git a/gematria/granite/graph_builder_test.cc b/gematria/granite/graph_builder_test.cc index 82cdbca4..30ebaa00 100644 --- a/gematria/granite/graph_builder_test.cc +++ b/gematria/granite/graph_builder_test.cc @@ -33,8 +33,11 @@ namespace gematria { namespace { +using ::testing::_; +using ::testing::Each; using ::testing::ElementsAre; using ::testing::IsEmpty; +using ::testing::IsFalse; using ::testing::Pair; // Tokens used in the basic blocks in tests. 
For simplicity, we do not use the @@ -490,6 +493,7 @@ TEST_F(BasicBlockGraphBuilderTest, MultipleInstructions) { true, false, // Fourth instruction. true, false)); + EXPECT_THAT(builder_->context_node_mask(), Each(IsFalse())); EXPECT_THAT( builder_->edge_types(), @@ -571,5 +575,97 @@ TEST_F(BasicBlockGraphBuilderTest, TwoNops) { )pb")))); } +// Tests that the graph is built and context node mask is set correctly when +// context is supplied. +TEST_F(BasicBlockGraphBuilderTest, MultipleBasicBlocksWithContext) { + CreateBuilder(OutOfVocabularyTokenBehavior::ReturnError()); + ASSERT_TRUE( + builder_->AddBasicBlock(BasicBlockFromProto(ParseTextProto(R"pb( + canonicalized_instructions: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + } + canonicalized_back_context: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + } + canonicalized_front_context: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + })pb")), + true)); + ASSERT_TRUE( + builder_->AddBasicBlock(BasicBlockFromProto(ParseTextProto(R"pb( + canonicalized_instructions: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + } + canonicalized_back_context: { + mnemonic: "NOT" + llvm_mnemonic: "NOT64r" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RCX" } + })pb")), + true)); + + EXPECT_EQ(builder_->num_graphs(), 2); + EXPECT_EQ(builder_->num_node_tokens(), std::size(kTokens)); + + EXPECT_EQ(builder_->num_nodes(), 7 + 5); + EXPECT_THAT(builder_->num_nodes_per_block(), ElementsAre(7, 5)); + + EXPECT_EQ(builder_->num_edges(), 8 + 5); + EXPECT_THAT(builder_->num_edges_per_block(), ElementsAre(8, 5)); + + EXPECT_THAT(builder_->node_types(), + 
ElementsAre(NodeType::kInstruction, NodeType::kRegister, + NodeType::kRegister, NodeType::kInstruction, + NodeType::kRegister, NodeType::kInstruction, + NodeType::kRegister, NodeType::kInstruction, + NodeType::kRegister, NodeType::kRegister, + NodeType::kInstruction, NodeType::kRegister)); + EXPECT_THAT( + builder_->node_features(), + ElementsAre(TokenIndex("NOT"), TokenIndex("RCX"), TokenIndex("RCX"), + TokenIndex("NOT"), TokenIndex("RCX"), TokenIndex("NOT"), + TokenIndex("RCX"), TokenIndex("NOT"), TokenIndex("RCX"), + TokenIndex("RCX"), TokenIndex("NOT"), TokenIndex("RCX"))); + EXPECT_THAT(builder_->InstructionNodeMask(), + ElementsAre(true, false, false, true, false, true, false, true, + false, false, true, false)); + EXPECT_THAT(builder_->context_node_mask(), + ElementsAre(true, _, _, false, _, true, _, true, _, _, false, _)); + + EXPECT_THAT( + builder_->edge_types(), + ElementsAre(EdgeType::kInputOperands, EdgeType::kOutputOperands, + EdgeType::kStructuralDependency, EdgeType::kInputOperands, + EdgeType::kOutputOperands, EdgeType::kStructuralDependency, + EdgeType::kInputOperands, EdgeType::kOutputOperands, + EdgeType::kInputOperands, EdgeType::kOutputOperands, + EdgeType::kStructuralDependency, EdgeType::kInputOperands, + EdgeType::kOutputOperands)); + + EXPECT_THAT(builder_->edge_senders(), + ElementsAre(1, 0, 0, 2, 3, 3, 4, 5, 8, 7, 7, 9, 10)); + EXPECT_THAT(builder_->edge_receivers(), + ElementsAre(0, 2, 3, 3, 4, 5, 5, 6, 7, 9, 10, 10, 11)); + + EXPECT_THAT( + builder_->global_features(), + ElementsAre(ElementsAre(0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 4, 0, 0, 0, 0), + ElementsAre(0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0))); + + EXPECT_THAT(builder_->DeltaBlockIndex(), ElementsAre(0, 0, 0, 1, 1)); +} + } // namespace } // namespace gematria diff --git a/gematria/granite/python/BUILD.bazel b/gematria/granite/python/BUILD.bazel index ec928b24..a46e812c 100644 --- a/gematria/granite/python/BUILD.bazel +++ b/gematria/granite/python/BUILD.bazel @@ -48,6 +48,7 
@@ gematria_pybind_extension( srcs = ["graph_builder.cc"], visibility = ["//:internal_users"], deps = [ + "//gematria/basic_block", "//gematria/granite:graph_builder", "//gematria/model:oov_token_behavior", "//gematria/proto:canonicalized_instruction_cc_proto", diff --git a/gematria/granite/python/graph_builder.cc b/gematria/granite/python/graph_builder.cc index c238e2d3..74cba048 100644 --- a/gematria/granite/python/graph_builder.cc +++ b/gematria/granite/python/graph_builder.cc @@ -14,11 +14,11 @@ #include "gematria/granite/graph_builder.h" -#include #include #include #include "absl/strings/string_view.h" +#include "gematria/basic_block/basic_block.h" #include "gematria/model/oov_token_behavior.h" #include "gematria/proto/canonicalized_instruction.pb.h" #include "pybind11/cast.h" @@ -81,10 +81,12 @@ PYBIND11_MODULE(graph_builder, m) { py::arg("annotation_names") = std::vector(), py::arg("out_of_vocabulary_behavior")) .def("add_basic_block", &BasicBlockGraphBuilder::AddBasicBlock, - py::arg("block")) + py::arg("block"), py::arg("add_context") = false) .def("add_basic_block_from_instructions", &BasicBlockGraphBuilder::AddBasicBlockFromInstructions, - py::arg("instructions")) + py::arg("instructions"), + py::arg("back_context") = std::vector(), + py::arg("front_context") = std::vector()) .def("reset", &BasicBlockGraphBuilder::Reset) .def_property_readonly("num_node_tokens", &BasicBlockGraphBuilder::num_node_tokens) @@ -99,6 +101,8 @@ PYBIND11_MODULE(graph_builder, m) { &BasicBlockGraphBuilder::node_features) .def_property_readonly("instruction_node_mask", &BasicBlockGraphBuilder::InstructionNodeMask) + .def_property_readonly("context_node_mask", + &BasicBlockGraphBuilder::context_node_mask) .def_property_readonly("annotation_names", &BasicBlockGraphBuilder::annotation_names) .def_property_readonly("instruction_annotations", diff --git a/gematria/granite/python/graph_builder_model_base.py b/gematria/granite/python/graph_builder_model_base.py index bd32485c..2e035a9d 
100644 --- a/gematria/granite/python/graph_builder_model_base.py +++ b/gematria/granite/python/graph_builder_model_base.py @@ -58,6 +58,9 @@ class GraphBuilderModelBase( 'GraphBuilderModelBase.instruction_node_mask' ) + # The name of the input tensor that receives the context node mask. + CONTEXT_NODE_MASK_TENSOR_NAME = 'GraphBuilderModelBase.context_node_mask' + # The name of the input tensor that holds the instruction annotations. INSTRUCTION_ANNOTATIONS_TENSOR_NAME = ( 'GraphBuilderModelBase.instruction_annotations' @@ -75,6 +78,12 @@ class GraphBuilderModelBase( # further processing during readout. _instruction_node_mask: tf.Tensor + # A Boolean tensor placeholder that receives a mask for context nodes of the + # same shape as `_instruction_node_mask`. A given element is True if the + # corresponding node is an instruction belonging to either the back or front + # context. The mask is used to exclude context nodes from predictions. + _context_node_mask: tf.Tensor + # A tensor that contains feature vectors of nodes representing instructions in # the order in which they are in the basic block, i.e. in the same order # instructions appear in ModelBase._output_tensor_deltas. 
@@ -237,12 +246,18 @@ def _create_graph_network_resources(self) -> None: shape=(None,), name=GraphBuilderModelBase.INSTRUCTION_NODE_MASK_TENSOR_NAME, ) + self._context_node_mask = tf.placeholder( + dtype=tf.dtypes.bool, + shape=(None,), + name=GraphBuilderModelBase.CONTEXT_NODE_MASK_TENSOR_NAME, + ) # @Override def _create_readout_network_resources(self) -> None: super()._create_readout_network_resources() self._instruction_features = tf.boolean_mask( - self._graphs_tuple_outputs.nodes, self._instruction_node_mask + self._graphs_tuple_outputs.nodes, + self._instruction_node_mask & ~self._context_node_mask, ) # @Override @@ -256,6 +271,9 @@ def _make_batch_feed_dict(self) -> model_base.FeedDict: feed_dict[self._instruction_node_mask] = np.array( self._batch_graph_builder.instruction_node_mask, dtype=bool ) + feed_dict[self._context_node_mask] = np.array( + self._batch_graph_builder.context_node_mask, dtype=bool + ) feed_dict[self._instruction_annotations] = ( self._batch_graph_builder.instruction_annotations ) @@ -311,7 +329,10 @@ def _make_batch_graphs_tuple(self): # @Override def _add_basic_block_to_batch(self, block: basic_block.BasicBlock) -> None: - basic_block_was_added = self._batch_graph_builder.add_basic_block(block) + # Add context to the basic block graph only for seq2seq models. + basic_block_was_added = self._batch_graph_builder.add_basic_block( + block, add_context=self.use_deltas + ) if not basic_block_was_added: # TODO(ondrasej): Better handling of blocks that can't be added to the # batch. 
For now, we just let the exception propagate out of the model and diff --git a/gematria/granite/python/graph_builder_model_base_test.py b/gematria/granite/python/graph_builder_model_base_test.py index d4669455..85f87b86 100644 --- a/gematria/granite/python/graph_builder_model_base_test.py +++ b/gematria/granite/python/graph_builder_model_base_test.py @@ -215,6 +215,39 @@ def test_train_seq2seq_model(self, loss_type, loss_normalization): model, self.blocks_with_throughput[0:1], num_epochs=50 ) + @parameterized.named_parameters( + *model_test.LOSS_TYPES_AND_LOSS_NORMALIZATIONS + ) + def test_train_seq2seq_context_model(self, loss_type, loss_normalization): + blocks_with_throughput = self.blocks_with_throughput[0:1] + # Altered version of the regular basic blocks where the instructions of each + # block are additionally supplied as both its back context and its front + # context. + altered_blocks_with_throughput = [ + throughput.BasicBlockWithThroughput( + block=basic_block.BasicBlock( + instructions=original.block.instructions, + back_context=original.block.instructions, + front_context=original.block.instructions, + ), + throughputs=original.throughputs, + ) + for original in blocks_with_throughput + ] + + model = TestGraphBuilderModel( + tokens=self.tokens, + loss_type=loss_type, + use_deltas=True, + use_delta_loss=False, + loss_normalization=loss_normalization, + num_message_passing_iterations=1, + ) + model.initialize() + self.check_training_model( + model, altered_blocks_with_throughput, num_epochs=50 + ) + def test_validate_basic_block(self): model = TestGraphBuilderModel( tokens=self.tokens, num_message_passing_iterations=1 diff --git a/gematria/proto/basic_block.proto b/gematria/proto/basic_block.proto index 4276518f..86269cea 100644 --- a/gematria/proto/basic_block.proto +++ b/gematria/proto/basic_block.proto @@ -30,8 +30,26 @@ message BasicBlockProto { // same instruction. 
repeated CanonicalizedInstructionProto canonicalized_instructions = 2; + // An optional list of machine instructions preceding the basic block, used + // to provide context that lies before `canonicalized_instructions`. These + // instructions are not included in the timing measurements and predictions. + repeated CanonicalizedInstructionProto machine_back_context = 3; + + // An optional list of machine instructions following the basic block, used + // to provide context lying after `canonicalized_instructions`. These + // instructions are not included in the timing measurements and predictions. + repeated CanonicalizedInstructionProto machine_front_context = 4; + + // Canonicalized instructions parallel to `machine_back_context`. May be + // empty in case no back context is provided. + repeated CanonicalizedInstructionProto canonicalized_back_context = 5; + + // Canonicalized instructions parallel to `machine_front_context`. May be + // empty in case no front context is provided. + repeated CanonicalizedInstructionProto canonicalized_front_context = 6; + // The fingerprint-id of this basic block. Might be empty. - string fingerprint = 3; + string fingerprint = 7; } // Represents a raw instruction extracted from binary code.