diff --git a/gematria/granite/graph_builder.cc b/gematria/granite/graph_builder.cc index 30f6c4ab..7452cd91 100644 --- a/gematria/granite/graph_builder.cc +++ b/gematria/granite/graph_builder.cc @@ -34,6 +34,7 @@ namespace { constexpr BasicBlockGraphBuilder::NodeIndex kInvalidNode(-1); constexpr BasicBlockGraphBuilder::TokenIndex kInvalidTokenIndex(-1); +constexpr double kDefaultInstructionAnnotation(-1); std::unordered_map MakeIndex( std::vector items) { @@ -172,7 +173,7 @@ BasicBlockGraphBuilder::BasicBlockGraphBuilder( : FindTokenOrDie( node_tokens_, out_of_vocabulary_behavior.replacement_token())) { - instruction_annotations_ = std::vector>(); + instruction_annotations_ = std::vector>(); // Make sure annotations are stored in a stable order as long the same // annotation names are used. @@ -183,12 +184,60 @@ BasicBlockGraphBuilder::BasicBlockGraphBuilder( // Store row indices corresponding to specific annotation names. int annotation_idx = 0; - for (auto& annotation_name : annotation_names_) { + for (const std::string& annotation_name : annotation_names_) { annotation_name_to_idx_[annotation_name] = annotation_idx; ++annotation_idx; } } +BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddInstruction( + const Instruction& instruction, NodeIndex previous_instruction_node) { + // Add the instruction node. + const NodeIndex instruction_node = + AddNode(NodeType::kInstruction, instruction.mnemonic); + if (instruction_node == kInvalidNode) { + return kInvalidNode; + } + + // Store instruction annotations. + AddInstructionAnnotations(instruction); + + // Add nodes for prefixes of the instruction. + for (const std::string& prefix : instruction.prefixes) { + const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix); + if (prefix_node == kInvalidNode) { + return kInvalidNode; + } + AddEdge(EdgeType::kInstructionPrefix, prefix_node, instruction_node); + } + + // Add a structural dependency edge from the previous instruction. + if (previous_instruction_node >= 0) { + AddEdge(EdgeType::kStructuralDependency, previous_instruction_node, + instruction_node); + } + + // Add edges for input operands. And nodes too, if necessary. + for (const InstructionOperand& operand : instruction.input_operands) { + if (!AddInputOperand(instruction_node, operand)) return kInvalidNode; + } + for (const InstructionOperand& operand : + instruction.implicit_input_operands) { + if (!AddInputOperand(instruction_node, operand)) return kInvalidNode; + } + + // Add edges and nodes for output operands. + for (const InstructionOperand& operand : instruction.output_operands) { + if (!AddOutputOperand(instruction_node, operand)) return kInvalidNode; + } + for (const InstructionOperand& operand : + instruction.implicit_output_operands) { + if (!AddOutputOperand(instruction_node, operand)) return kInvalidNode; + } + + return instruction_node; +} + bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( const std::vector& instructions) { if (instructions.empty()) return false; @@ -203,56 +252,11 @@ bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( NodeIndex previous_instruction_node = kInvalidNode; for (const Instruction& instruction : instructions) { - // Add the instruction node. - const NodeIndex instruction_node = - AddNode(NodeType::kInstruction, instruction.mnemonic); + NodeIndex instruction_node = + AddInstruction(instruction, previous_instruction_node); if (instruction_node == kInvalidNode) { return false; } - - // Store the annotations for later use (inclusion in embeddings), using -1 - // as a default value wherever annotations are missing. - std::vector row = std::vector(annotation_names_.size(), -1); - for (const auto& [name, value] : instruction.instruction_annotations) { - const auto annotation_index = annotation_name_to_idx_.find(name); - if (annotation_index == annotation_name_to_idx_.end()) continue; - row[annotation_index->second] = value; - } - instruction_annotations_.push_back(row); - - // Add nodes for prefixes of the instruction. - for (const std::string& prefix : instruction.prefixes) { - const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix); - if (prefix_node == kInvalidNode) { - return false; - } - AddEdge(EdgeType::kInstructionPrefix, prefix_node, instruction_node); - } - - // Add a structural dependency edge from the previous instruction. - if (previous_instruction_node >= 0) { - AddEdge(EdgeType::kStructuralDependency, previous_instruction_node, - instruction_node); - } - - // Add edges for input operands. And nodes too, if necessary. - for (const InstructionOperand& operand : instruction.input_operands) { - if (!AddInputOperand(instruction_node, operand)) return false; - } - for (const InstructionOperand& operand : - instruction.implicit_input_operands) { - if (!AddInputOperand(instruction_node, operand)) return false; - } - - // Add edges and nodes for output operands. - for (const InstructionOperand& operand : instruction.output_operands) { - if (!AddOutputOperand(instruction_node, operand)) return false; - } - for (const InstructionOperand& operand : - instruction.implicit_output_operands) { - if (!AddOutputOperand(instruction_node, operand)) return false; - } - previous_instruction_node = instruction_node; } @@ -441,6 +445,20 @@ void BasicBlockGraphBuilder::AddEdge(EdgeType edge_type, NodeIndex sender, edge_types_.push_back(edge_type); } +void BasicBlockGraphBuilder::AddInstructionAnnotations( + const Instruction& instruction) { + // Store the annotations for later use, using `kDefaultInstructionAnnotation` + // as a default value wherever annotations are missing. + std::vector row = std::vector(annotation_names_.size(), + kDefaultInstructionAnnotation); + for (const auto& [name, value] : instruction.instruction_annotations) { + const auto annotation_index = annotation_name_to_idx_.find(name); + if (annotation_index == annotation_name_to_idx_.end()) continue; + row[annotation_index->second] = value; + } + instruction_annotations_.push_back(row); +} + std::vector BasicBlockGraphBuilder::EdgeFeatures() const { std::vector edge_features(num_edges()); for (int i = 0; i < num_edges(); ++i) { diff --git a/gematria/granite/graph_builder.h b/gematria/granite/graph_builder.h index c4684eea..6e09b08e 100644 --- a/gematria/granite/graph_builder.h +++ b/gematria/granite/graph_builder.h @@ -91,7 +91,6 @@ #include #include -#include #include #include #include @@ -251,7 +250,7 @@ class BasicBlockGraphBuilder { // `num_instructions` x `annotation_names.size()` matrix, each entry of which // represents the value of the annotation of the type corresponding to the // column for the instruction corresponding to the row. - const std::vector>& instruction_annotations() const { + const std::vector>& instruction_annotations() const { return instruction_annotations_; } @@ -360,6 +359,10 @@ class BasicBlockGraphBuilder { size_t prev_global_features_size_; }; + // Adds nodes and edges for a single instruction of a basic block. + NodeIndex AddInstruction(const Instruction& instruction, + NodeIndex previous_instruction_node); + // Adds nodes and edges for a single input operand of an instruction. bool AddInputOperand(NodeIndex instruction_node, const InstructionOperand& operand); @@ -383,6 +386,10 @@ class BasicBlockGraphBuilder { // Adds a new edge to the batch. void AddEdge(EdgeType edge_type, NodeIndex sender, NodeIndex receiver); + // Updates the `instruction_annotations_` tensor with annotations from + // `instruction`. + void AddInstructionAnnotations(const Instruction& instruction); + // Mapping from string node tokens to indices of embedding vectors used in // the models. const std::unordered_map node_tokens_; @@ -410,7 +417,7 @@ class BasicBlockGraphBuilder { // Mapping from annotation type names to corresponding row index in the // `instruction_annotations_` matrix. std::unordered_map annotation_name_to_idx_; - std::vector> instruction_annotations_; + std::vector> instruction_annotations_; std::vector edge_senders_; std::vector edge_receivers_; diff --git a/gematria/granite/graph_builder_model_inference.cc b/gematria/granite/graph_builder_model_inference.cc index ca3c436c..46a52e7a 100644 --- a/gematria/granite/graph_builder_model_inference.cc +++ b/gematria/granite/graph_builder_model_inference.cc @@ -305,7 +305,7 @@ llvm::Error FillTensorFromStdVectorMatrix( auto* const tensor_data = interpreter->typed_input_tensor(tensor_index); for (int row = 0; row < input_matrix.size(); ++row) { - const std::vector& row_data = input_matrix[row]; + const std::vector& row_data = input_matrix[row]; if (expected_size != row_data.size()) { return llvm::createStringError( llvm::errc::invalid_argument, @@ -676,7 +676,7 @@ GraphBuilderModelInference::RunInference() { const std::vector instruction_node_mask = graph_builder_->InstructionNodeMask(); - const std::vector>& instruction_annotations = + const std::vector>& instruction_annotations = graph_builder_->instruction_annotations(); const std::vector delta_block_index = graph_builder_->DeltaBlockIndex(); diff --git a/gematria/granite/python/graph_builder.cc b/gematria/granite/python/graph_builder.cc index c238e2d3..6949cdc6 100644 --- a/gematria/granite/python/graph_builder.cc +++ b/gematria/granite/python/graph_builder.cc @@ -14,13 +14,11 @@ #include "gematria/granite/graph_builder.h" -#include #include #include #include "absl/strings/string_view.h" #include "gematria/model/oov_token_behavior.h" -#include "gematria/proto/canonicalized_instruction.pb.h" #include "pybind11/cast.h" #include "pybind11/detail/common.h" #include "pybind11/pybind11.h"