From b535e754af13f2c8cfafe6e4daeb5fc8eea168f4 Mon Sep 17 00:00:00 2001 From: Meow King Date: Sat, 3 Aug 2024 22:04:53 +0800 Subject: [PATCH] feat: equality of objects --- example/parse_ir_assmebly.py | 6 +-- src/llvm/Core/miscClasses.cpp | 36 +++++++++++----- src/llvm/Core/type.cpp | 2 +- src/llvm/_types.h | 61 +++++++++++++++------------ src/llvm/_types/PyContext.h | 3 +- src/llvm/_types/PyLLVMObject.h | 30 +++++++++++++ src/llvm/_types/PyMemoryBuffer.h | 3 +- src/llvm/_types/PyMetadataEntries.h | 3 +- src/llvm/_types/PyModule.h | 3 +- src/llvm/_types/PyModuleFlagEntries.h | 3 +- src/llvm/_types/PyModuleProvider.h | 3 +- src/llvm/_types/PyOperandBundle.h | 3 +- src/llvm/_types/PyPassManagerBase.h | 3 +- src/llvmpym_ext.cpp | 6 +-- 14 files changed, 112 insertions(+), 53 deletions(-) create mode 100644 src/llvm/_types/PyLLVMObject.h diff --git a/example/parse_ir_assmebly.py b/example/parse_ir_assmebly.py index e95b0e8..490aa93 100644 --- a/example/parse_ir_assmebly.py +++ b/example/parse_ir_assmebly.py @@ -31,11 +31,11 @@ module = f.parent # assert m == module assert f.kind == core.ValueKind.Function - print() for i, a in enumerate(f.args, 1): - print(f'Argument | name: "{a.name}", type: "{a.type}"') + print(f'\tArgument | name: "{a.name}", type: "{a.type}"') attrs = f.get_attributes_at_index(i) - print(f"\tattrs: {attrs}") + print(f"\t\tattrs: {attrs}") + print("\n----------------------------\n") diff --git a/src/llvm/Core/miscClasses.cpp b/src/llvm/Core/miscClasses.cpp index 0984d3b..b41214a 100644 --- a/src/llvm/Core/miscClasses.cpp +++ b/src/llvm/Core/miscClasses.cpp @@ -26,7 +26,9 @@ void bindOtherClasses(nb::module_ &m) { "exist simultaneously. A single context is not thread safe. However," "different contexts can execute on different threads simultaneously."); - auto AttributeClass = nb::class_(m, "Attribute", "Attribute"); + auto AttributeClass = + nb::class_> + (m, "Attribute", "Attribute"); auto EnumAttributeClass = nb::class_ (m, "EnumAttribute", "EnumAttribute"); auto TypeAttributeClass = nb::class_ @@ -35,12 +37,16 @@ void bindOtherClasses(nb::module_ &m) { (m, "StringAttribute", "StringAttribute"); - auto BasicBlockClass = nb::class_ - (m, "BasicBlock", "BasicBlock"); - auto DiagnosticInfoClass = nb::class_ - (m, "DiagnosticInfo", "DiagnosticInfo"); + auto BasicBlockClass = + nb::class_> + (m, "BasicBlock", "BasicBlock"); + auto DiagnosticInfoClass = + nb::class_> + (m, "DiagnosticInfo", "DiagnosticInfo"); - auto NamedMDNodeClass = nb::class_(m, "NamedMDNode", "NamedMDNode"); + auto NamedMDNodeClass = + nb::class_> + (m, "NamedMDNode", "NamedMDNode"); auto ModuleClass = nb::class_ (m, "Module", @@ -50,7 +56,9 @@ void bindOtherClasses(nb::module_ &m) { auto ModuleFlagEntriesClass = nb::class_ (m, "ModuleFlagEntry", "ModuleFlagEntry"); - auto MetadataClass = nb::class_(m, "Metadata", "Metadata"); + auto MetadataClass = + nb::class_> + (m, "Metadata", "Metadata"); auto MDNodeClass = nb::class_(m, "MDNode", "MDNode"); auto MDStringClass = nb::class_(m, "MDString", "MDString"); auto ValueAsMetadata = nb::class_ @@ -60,12 +68,18 @@ void bindOtherClasses(nb::module_ &m) { auto MetadataEntriesClass = nb::class_ (m, "MetadataEntry", "MetadataEntry"); - auto UseClass = nb::class_(m, "Use", "Use"); + auto UseClass = + nb::class_> + (m, "Use", "Use"); - auto IntrinsicClass = nb::class_(m, "Intrinsic", "Intrinsic"); + auto IntrinsicClass = + nb::class_> + (m, "Intrinsic", "Intrinsic"); auto OperandBundleClass = nb::class_(m, "OperandBundle", "OperandBundle"); - auto BuilderClass = nb::class_(m, "Builder", "Builder"); + auto BuilderClass = + nb::class_> + (m, "Builder", "Builder"); auto ModuleProviderClass = nb::class_ (m, "ModuleProvider", "ModuleProvider"); auto MemoryBufferClass = nb::class_ @@ -921,6 +935,8 @@ void bindOtherClasses(nb::module_ &m) { auto name = std::string(raw_name, len); return fmt::format("", self.get(), name); }) + .def("__bool__", // override default behavior + [](PyIntrinsic &self) { return true; }) .def_static("lookup", [](std::string &name) { return PyIntrinsic(LLVMLookupIntrinsicID(name.c_str(), name.size())); diff --git a/src/llvm/Core/type.cpp b/src/llvm/Core/type.cpp index e3c0fd6..fc093bf 100644 --- a/src/llvm/Core/type.cpp +++ b/src/llvm/Core/type.cpp @@ -17,7 +17,7 @@ template using optional = std::optional; void bindTypeClasses(nb::module_ &m) { - auto TypeClass = nb::class_(m, "Type", "Type"); + auto TypeClass = nb::class_>(m, "Type", "Type"); auto TypeIntClass = nb::class_(m, "IntType", "IntType"); auto TypeRealClass = nb::class_(m, "RealType", "RealType"); auto TypeFunctionClass = nb::class_ (m, "FunctionType", "FunctionType"); diff --git a/src/llvm/_types.h b/src/llvm/_types.h index cb25960..c383c16 100644 --- a/src/llvm/_types.h +++ b/src/llvm/_types.h @@ -21,6 +21,7 @@ #include "_types/PyPassManagerBase.h" #include "_types/PyMemoryBuffer.h" #include "_types/PyModuleProvider.h" +#include "_types/PyLLVMObject.h" #define DEFINE_PY_WRAPPER_CLASS(ClassName, UnderlyingType) \ @@ -232,6 +233,14 @@ macro(PyType, PyTypeTargetExt) +#define BIND_PYLLVMOBJECT_(ClassName, UnderlyingType, PyClassName) \ + nb::class_> \ + (m, #PyClassName, "The base class.") \ + .def("__bool__", &PyLLVMObject::__bool__) \ + .def("__eq__", &PyLLVMObject::__equal__) \ + .def("__hash__", &PyLLVMObject::__hash__); + + enum class PyAttributeIndex { Return = LLVMAttributeReturnIndex, Function = LLVMAttributeFunctionIndex @@ -249,31 +258,32 @@ enum class PyLLVMFastMathFlags { All = LLVMFastMathAll }; -template -class PyLLVMObject { -public: - virtual ~PyLLVMObject() = default; - - UnderlyingType get() const { - return const_cast(static_cast(this))->get(); - } - - bool __bool__() const { - UnderlyingType raw = static_cast(get()); - if (!raw) return false; - return true; - } - - // `__equal__` and `__hash__` works well on pointer type UnderlyingType - bool __equal__(const PyLLVMObject& other) const { - return this->get() == other.get(); - } - - std::size_t __hash__() const { - return std::hash{}(this->get()); - } -}; +// NOTE the `__bool__` method of PyIntrinsic is overridden +#define BIND_PYLLVMOBJECT() \ + BIND_PYLLVMOBJECT_(PyValue, LLVMValueRef, PyValueObject) \ + BIND_PYLLVMOBJECT_(PyType, LLVMTypeRef, PyTypeObject) \ + BIND_PYLLVMOBJECT_(PyDiagnosticInfo, LLVMDiagnosticInfoRef, PyDiagnosticInfoObject) \ + BIND_PYLLVMOBJECT_(PyAttribute, LLVMAttributeRef, PyAttributeObject) \ + BIND_PYLLVMOBJECT_(PyNamedMDNode, LLVMNamedMDNodeRef, PyNamedMDNodeObject) \ + BIND_PYLLVMOBJECT_(PyUse, LLVMUseRef, PyUseObject) \ + BIND_PYLLVMOBJECT_(PyBasicBlock, LLVMBasicBlockRef, PyBasicBlockObject) \ + BIND_PYLLVMOBJECT_(PyBuilder, LLVMBuilderRef, PyBuilderObject) \ + BIND_PYLLVMOBJECT_(PyMetadata, LLVMMetadataRef, PyMetadataObject) \ + BIND_PYLLVMOBJECT_(PyIntrinsic, unsigned, PyIntrinsicObject) \ +\ + BIND_PYLLVMOBJECT_(PyContext, LLVMContextRef, PyContextObject) \ + BIND_PYLLVMOBJECT_(PyMemoryBuffer, LLVMMemoryBufferRef, PyMemoryBufferObject) \ + BIND_PYLLVMOBJECT_(PyMetadataEntries, LLVMValueMetadataEntries, PyMetadataEntriesObject) \ + BIND_PYLLVMOBJECT_(PyModuleFlagEntries, LLVMModuleFlagEntries, PyModuleFlagEntriesObject) \ + BIND_PYLLVMOBJECT_(PyModule, LLVMModuleRef, PyModuleObject) \ + BIND_PYLLVMOBJECT_(PyModuleProvider, LLVMModuleProviderRef, PyModuleProviderObject) \ + BIND_PYLLVMOBJECT_(PyOperandBundle, LLVMOperandBundleRef, PyOperandBundleObject) \ + BIND_PYLLVMOBJECT_(PyPassManagerBase, LLVMPassManagerRef, PyPassManagerBaseObject) + + + + DEFINE_PY_WRAPPER_CLASS_POLYMORPHIC(PyValue, LLVMValueRef) DEFINE_PY_WRAPPER_CLASS_POLYMORPHIC(PyType, LLVMTypeRef) @@ -285,6 +295,7 @@ DEFINE_PY_WRAPPER_CLASS(PyBasicBlock, LLVMBasicBlockRef) DEFINE_PY_WRAPPER_CLASS(PyBuilder, LLVMBuilderRef) DEFINE_PY_WRAPPER_CLASS(PyMetadata, LLVMMetadataRef) + DEFINE_DIRECT_SUB_CLASS(PyMetadata, PyMDNode) DEFINE_DIRECT_SUB_CLASS(PyMetadata, PyValueAsMetadata) DEFINE_DIRECT_SUB_CLASS(PyMetadata, PyMDString) @@ -301,8 +312,6 @@ DEFINE_PY_WRAPPER_CLASS(PyIntrinsic, unsigned) - - DEFINE_DIRECT_SUB_CLASS(PyPassManagerBase, PyPassManager); DEFINE_DIRECT_SUB_CLASS(PyPassManagerBase, PyFunctionPassManager); diff --git a/src/llvm/_types/PyContext.h b/src/llvm/_types/PyContext.h index 5e8ab31..cb35c3e 100644 --- a/src/llvm/_types/PyContext.h +++ b/src/llvm/_types/PyContext.h @@ -6,8 +6,9 @@ #include #include #include +#include "PyLLVMObject.h" -class PyContext { +class PyContext : public PyLLVMObject { public: explicit PyContext(); explicit PyContext(LLVMContextRef context, bool is_global_context); diff --git a/src/llvm/_types/PyLLVMObject.h b/src/llvm/_types/PyLLVMObject.h new file mode 100644 index 0000000..689e9a7 --- /dev/null +++ b/src/llvm/_types/PyLLVMObject.h @@ -0,0 +1,30 @@ +#ifndef PYLLVMOBJECT_H +#define PYLLVMOBJECT_H + +template +class PyLLVMObject { +public: + virtual ~PyLLVMObject() = default; + + UnderlyingType get() const { + return const_cast(static_cast(this))->get(); + } + + bool __bool__() const { + UnderlyingType raw = static_cast(get()); + if (!raw) return false; + return true; + } + + // `__equal__` and `__hash__` works well on pointer type UnderlyingType + bool __equal__(const PyLLVMObject& other) const { + return this->get() == other.get(); + } + +std::size_t __hash__() const { + return std::hash{}(this->get()); +} +}; + + +#endif diff --git a/src/llvm/_types/PyMemoryBuffer.h b/src/llvm/_types/PyMemoryBuffer.h index 1acf15b..54f9c58 100644 --- a/src/llvm/_types/PyMemoryBuffer.h +++ b/src/llvm/_types/PyMemoryBuffer.h @@ -5,8 +5,9 @@ #include #include #include +#include "PyLLVMObject.h" -class PyMemoryBuffer { +class PyMemoryBuffer : public PyLLVMObject { public: explicit PyMemoryBuffer(LLVMMemoryBufferRef mb); LLVMMemoryBufferRef get() const; diff --git a/src/llvm/_types/PyMetadataEntries.h b/src/llvm/_types/PyMetadataEntries.h index b505156..486d7ea 100644 --- a/src/llvm/_types/PyMetadataEntries.h +++ b/src/llvm/_types/PyMetadataEntries.h @@ -5,10 +5,11 @@ #include #include #include +#include "PyLLVMObject.h" typedef LLVMValueMetadataEntry *LLVMValueMetadataEntries; -class PyMetadataEntries { +class PyMetadataEntries : public PyLLVMObject { public: explicit PyMetadataEntries(LLVMValueMetadataEntries entries, size_t len); LLVMValueMetadataEntries get() const; diff --git a/src/llvm/_types/PyModule.h b/src/llvm/_types/PyModule.h index f647050..c46e29e 100644 --- a/src/llvm/_types/PyModule.h +++ b/src/llvm/_types/PyModule.h @@ -7,8 +7,9 @@ #include #include #include +#include "PyLLVMObject.h" -class PyModule { +class PyModule : public PyLLVMObject { public: explicit PyModule(const std::string &id); explicit PyModule(const std::string &id, LLVMContextRef context); diff --git a/src/llvm/_types/PyModuleFlagEntries.h b/src/llvm/_types/PyModuleFlagEntries.h index 7b1def7..9d20a7d 100644 --- a/src/llvm/_types/PyModuleFlagEntries.h +++ b/src/llvm/_types/PyModuleFlagEntries.h @@ -5,10 +5,11 @@ #include #include #include +#include "PyLLVMObject.h" typedef LLVMModuleFlagEntry *LLVMModuleFlagEntries; -class PyModuleFlagEntries { +class PyModuleFlagEntries : public PyLLVMObject{ public: explicit PyModuleFlagEntries(LLVMModuleFlagEntries entries, size_t len); LLVMModuleFlagEntries get() const; diff --git a/src/llvm/_types/PyModuleProvider.h b/src/llvm/_types/PyModuleProvider.h index 7e0cf31..929050f 100644 --- a/src/llvm/_types/PyModuleProvider.h +++ b/src/llvm/_types/PyModuleProvider.h @@ -5,8 +5,9 @@ #include #include #include +#include "PyLLVMObject.h" -class PyModuleProvider { +class PyModuleProvider : public PyLLVMObject{ public: explicit PyModuleProvider(LLVMModuleProviderRef mp); LLVMModuleProviderRef get() const; diff --git a/src/llvm/_types/PyOperandBundle.h b/src/llvm/_types/PyOperandBundle.h index 41350bb..511660d 100644 --- a/src/llvm/_types/PyOperandBundle.h +++ b/src/llvm/_types/PyOperandBundle.h @@ -5,8 +5,9 @@ #include #include #include +#include "PyLLVMObject.h" -class PyOperandBundle { +class PyOperandBundle : public PyLLVMObject { public: explicit PyOperandBundle(LLVMOperandBundleRef bundle); LLVMOperandBundleRef get() const; diff --git a/src/llvm/_types/PyPassManagerBase.h b/src/llvm/_types/PyPassManagerBase.h index 2d4cccc..ae4fe2b 100644 --- a/src/llvm/_types/PyPassManagerBase.h +++ b/src/llvm/_types/PyPassManagerBase.h @@ -5,8 +5,9 @@ #include #include #include +#include "PyLLVMObject.h" -class PyPassManagerBase { +class PyPassManagerBase : public PyLLVMObject { public: explicit PyPassManagerBase(LLVMPassManagerRef pm); LLVMPassManagerRef get() const; diff --git a/src/llvmpym_ext.cpp b/src/llvmpym_ext.cpp index 9ae5183..a4da023 100644 --- a/src/llvmpym_ext.cpp +++ b/src/llvmpym_ext.cpp @@ -11,11 +11,7 @@ using namespace nb::literals; NB_MODULE(llvmpym_ext, m) { m.doc() = "LLVM Python Native Extension"; - nb::class_> - (m, "LLVMObject", "The base of for all LLVM object classes.") - .def("__bool__", &PyLLVMObject::__bool__) - .def("__eq__", &PyLLVMObject::__equal__) - .def("__hash__", &PyLLVMObject::__hash__); + BIND_PYLLVMOBJECT(); auto coreModule = m.def_submodule("core", "LLVM Core"); populateCore(coreModule);