diff --git a/WORKSPACE b/WORKSPACE index 05ca018c..1a25cc4e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -142,8 +142,8 @@ git_repository( new_git_repository( name = "pybind11", build_file = "@pybind11_bazel//:pybind11.BUILD", + commit = "e7e5d6e5bb0af543a2ded6d34163176c3e6ab745", remote = "https://github.com/pybind/pybind11.git", - tag = "v2.10.3", ) git_repository( diff --git a/gematria/basic_block/python/basic_block.cc b/gematria/basic_block/python/basic_block.cc index e78c6111..75e679a4 100644 --- a/gematria/basic_block/python/basic_block.cc +++ b/gematria/basic_block/python/basic_block.cc @@ -23,6 +23,7 @@ #include "pybind11/cast.h" #include "pybind11/detail/common.h" +#include "pybind11/native_enum.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "pybind11/stl.h" @@ -82,22 +83,19 @@ PYBIND11_MODULE(basic_block, m) { // Python code propagate to C++ code. py::bind_vector>(m, "StringList"); - py::enum_(m, "OperandType", R"( - The type of the operands used in the basic blocks. - - Values: - REGISTER: The operand is a register. - IMMEDIATE_VALUE: The operand is an integer immediate value. This - immediate value can have up to 64-bits. - FP_IMMEDIATE_VALUE: The operand is a floating-point immediate value. - ADDRESS: The operand is an address computation. - MEMORY: The operand is a location in the memory.)") - .value("UNKNOWN", OperandType::kUnknown) - .value("REGISTER", OperandType::kRegister) - .value("IMMEDIATE_VALUE", OperandType::kImmediateValue) - .value("FP_IMMEDIATE_VALUE", OperandType::kFpImmediateValue) - .value("ADDRESS", OperandType::kAddress) - .value("MEMORY", OperandType::kMemory); + py::native_enum(m, "OperandType") + .value("UNKNOWN", OperandType::kUnknown, "The operand type is unknown.") + .value("REGISTER", OperandType::kRegister, "The operand is a register.") + .value("IMMEDIATE_VALUE", OperandType::kImmediateValue, + "The operand is an integer immediate value. This immediate value " + "can have up to 64-bits.") + .value("FP_IMMEDIATE_VALUE", OperandType::kFpImmediateValue, + "The operand is a floating-point immediate value.") + .value("ADDRESS", OperandType::kAddress, + "The operand is an address computation.") + .value("MEMORY", OperandType::kMemory, + "The operand is a location in the memory.") + .finalize(); py::class_ address_tuple(m, "AddressTuple"); address_tuple diff --git a/gematria/basic_block/python/basic_block_test.py b/gematria/basic_block/python/basic_block_test.py index 1628d834..a3af1dd2 100644 --- a/gematria/basic_block/python/basic_block_test.py +++ b/gematria/basic_block/python/basic_block_test.py @@ -22,12 +22,11 @@ class OperandTypeTest(absltest.TestCase): def test_values(self): - self.assertGreaterEqual(len(basic_block.OperandType.__members__), 0) + self.assertGreaterEqual(len(basic_block.OperandType), 0) def test_docstring(self): - docstring = basic_block.OperandType.__doc__ - for value in basic_block.OperandType.__members__: - self.assertIn(value, docstring) + for value in basic_block.OperandType: + self.assertNotEmpty(value.__doc__) class AddressTupleTest(absltest.TestCase): diff --git a/gematria/datasets/python/bhive_to_exegesis.cc b/gematria/datasets/python/bhive_to_exegesis.cc index b737a531..de96e3d1 100644 --- a/gematria/datasets/python/bhive_to_exegesis.cc +++ b/gematria/datasets/python/bhive_to_exegesis.cc @@ -24,6 +24,7 @@ #include "llvm/tools/llvm-exegesis/lib/TargetSelect.h" #include "pybind11/cast.h" #include "pybind11/detail/common.h" +#include "pybind11/native_enum.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" // IWYU pragma: keep #include "pybind11_abseil/import_status_module.h" @@ -39,11 +40,12 @@ PYBIND11_MODULE(bhive_to_exegesis, m) { py::google::ImportStatusModule(); - py::enum_(m, "AnnotatorType") + py::native_enum(m, "AnnotatorType") .value("exegesis", BHiveToExegesis::AnnotatorType::kExegesis) .value("fast", BHiveToExegesis::AnnotatorType::kFast) .value("none", BHiveToExegesis::AnnotatorType::kNone) - .export_values(); + .export_values() + .finalize(); py::class_(m, "BHiveToExegesis") .def_static( diff --git a/gematria/granite/python/graph_builder.cc b/gematria/granite/python/graph_builder.cc index c238e2d3..106ef096 100644 --- a/gematria/granite/python/graph_builder.cc +++ b/gematria/granite/python/graph_builder.cc @@ -14,7 +14,6 @@ #include "gematria/granite/graph_builder.h" -#include #include #include @@ -23,6 +22,7 @@ #include "gematria/proto/canonicalized_instruction.pb.h" #include "pybind11/cast.h" #include "pybind11/detail/common.h" +#include "pybind11/native_enum.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" #include "pybind11_protobuf/native_proto_caster.h" @@ -43,16 +43,17 @@ PYBIND11_MODULE(graph_builder, m) { pybind11_protobuf::ImportNativeProtoCasters(); - py::enum_(m, "NodeType") + py::native_enum(m, "NodeType") .value("INSTRUCTION", NodeType::kInstruction) .value("REGISTER", NodeType::kRegister) .value("IMMEDIATE", NodeType::kImmediate) .value("FP_IMMEDIATE", NodeType::kFpImmediate) .value("ADDRESS_OPERAND", NodeType::kAddressOperand) .value("MEMORY_OPERAND", NodeType::kMemoryOperand) - .export_values(); + .export_values() + .finalize(); - py::enum_(m, "EdgeType") + py::native_enum(m, "EdgeType") .value("STRUCTURAL_DEPENDENCY", EdgeType::kStructuralDependency) .value("REVERSE_STRUCTURAL_DEPENDENCY", EdgeType::kReverseStructuralDependency) @@ -63,7 +64,8 @@ PYBIND11_MODULE(graph_builder, m) { .value("ADDRESS_SEGMENT_REGISTER", EdgeType::kAddressSegmentRegister) .value("ADDRESS_DISPLACEMENT", EdgeType::kAddressDisplacement) .value("INSTRUCTION_PREFIX", EdgeType::kInstructionPrefix) - .export_values(); + .export_values() + .finalize(); py::class_(m, "BasicBlockGraphBuilder") .def( diff --git a/gematria/granite/python/graph_builder_model_base_test.py b/gematria/granite/python/graph_builder_model_base_test.py index d4669455..7585ee72 100644 --- a/gematria/granite/python/graph_builder_model_base_test.py +++ b/gematria/granite/python/graph_builder_model_base_test.py @@ -69,11 +69,7 @@ def _create_graph_network_modules(self): module=graph_nets.modules.GraphIndependent( edge_model_fn=functools.partial( snt.Embed, - # TODO(ondrasej): Pybind11 generated enum types do not - # implement the full Python enum interface. Replace this - # with len(graph_builder.EdgeType) when - # https://github.com/pybind/pybind11/issues/2332 is fixed. - vocab_size=len(graph_builder.EdgeType.__members__), + vocab_size=len(graph_builder.EdgeType), embed_dim=1, initializers=embedding_initializers, ), diff --git a/gematria/granite/python/token_graph_builder_model.py b/gematria/granite/python/token_graph_builder_model.py index 9a26ebc8..c11b9ba0 100644 --- a/gematria/granite/python/token_graph_builder_model.py +++ b/gematria/granite/python/token_graph_builder_model.py @@ -297,11 +297,7 @@ def _create_graph_network_modules( module=graph_nets.modules.GraphIndependent( edge_model_fn=functools.partial( snt.Embed, - # TODO(ondrasej): Pybind11 generated enum types do not - # implement the full Python enum interface. Replace this - # with len(graph_builder.EdgeType) when - # https://github.com/pybind/pybind11/issues/2332 is fixed. - vocab_size=len(graph_builder.EdgeType.__members__), + vocab_size=len(graph_builder.EdgeType), embed_dim=self._edge_embedding_size, initializers=embedding_initializers, ), diff --git a/gematria/model/python/oov_token_behavior.cc b/gematria/model/python/oov_token_behavior.cc index c48f2e67..33e65ca5 100644 --- a/gematria/model/python/oov_token_behavior.cc +++ b/gematria/model/python/oov_token_behavior.cc @@ -14,6 +14,7 @@ #include "gematria/model/oov_token_behavior.h" +#include "pybind11/native_enum.h" #include "pybind11/pybind11.h" namespace gematria { @@ -36,13 +37,14 @@ PYBIND11_MODULE(oov_token_behavior, m) { .def_property_readonly("replacement_token", &OutOfVocabularyTokenBehavior::replacement_token); - py::enum_(oov_token_behavior, - "BehaviorType") + py::native_enum( + oov_token_behavior, "BehaviorType") .value("RETURN_ERROR", OutOfVocabularyTokenBehavior::BehaviorType::kReturnError) .value("REPLACE_TOKEN", OutOfVocabularyTokenBehavior::BehaviorType::kReplaceToken) - .export_values(); + .export_values() + .finalize(); } } // namespace