Skip to content

Commit

Permalink
feat: equality of objects
Browse files Browse the repository at this point in the history
  • Loading branch information
Ziqi-Yang committed Aug 3, 2024
1 parent 92f1937 commit b535e75
Show file tree
Hide file tree
Showing 14 changed files with 112 additions and 53 deletions.
6 changes: 3 additions & 3 deletions example/parse_ir_assmebly.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
module = f.parent
# assert m == module
assert f.kind == core.ValueKind.Function
print()
for i, a in enumerate(f.args, 1):
print(f'Argument | name: "{a.name}", type: "{a.type}"')
print(f'\tArgument | name: "{a.name}", type: "{a.type}"')
attrs = f.get_attributes_at_index(i)
print(f"\tattrs: {attrs}")
print(f"\t\tattrs: {attrs}")


print("\n----------------------------\n")

36 changes: 26 additions & 10 deletions src/llvm/Core/miscClasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ void bindOtherClasses(nb::module_ &m) {
"exist simultaneously. A single context is not thread safe. However,"
"different contexts can execute on different threads simultaneously.");

auto AttributeClass = nb::class_<PyAttribute>(m, "Attribute", "Attribute");
auto AttributeClass =
nb::class_<PyAttribute, PyLLVMObject<PyAttribute, LLVMAttributeRef>>
(m, "Attribute", "Attribute");
auto EnumAttributeClass = nb::class_<PyEnumAttribute, PyAttribute>
(m, "EnumAttribute", "EnumAttribute");
auto TypeAttributeClass = nb::class_<PyTypeAttribute, PyAttribute>
Expand All @@ -35,12 +37,16 @@ void bindOtherClasses(nb::module_ &m) {
(m, "StringAttribute", "StringAttribute");


auto BasicBlockClass = nb::class_<PyBasicBlock>
(m, "BasicBlock", "BasicBlock");
auto DiagnosticInfoClass = nb::class_<PyDiagnosticInfo>
(m, "DiagnosticInfo", "DiagnosticInfo");
auto BasicBlockClass =
nb::class_<PyBasicBlock, PyLLVMObject<PyBasicBlock, LLVMBasicBlockRef>>
(m, "BasicBlock", "BasicBlock");
auto DiagnosticInfoClass =
nb::class_<PyDiagnosticInfo, PyLLVMObject<PyDiagnosticInfo, LLVMDiagnosticInfoRef>>
(m, "DiagnosticInfo", "DiagnosticInfo");

auto NamedMDNodeClass = nb::class_<PyNamedMDNode>(m, "NamedMDNode", "NamedMDNode");
auto NamedMDNodeClass =
nb::class_<PyNamedMDNode, PyLLVMObject<PyNamedMDNode, LLVMNamedMDNodeRef>>
(m, "NamedMDNode", "NamedMDNode");
auto ModuleClass =
nb::class_<PyModule>
(m, "Module",
Expand All @@ -50,7 +56,9 @@ void bindOtherClasses(nb::module_ &m) {

auto ModuleFlagEntriesClass = nb::class_<PyModuleFlagEntries>
(m, "ModuleFlagEntry", "ModuleFlagEntry");
auto MetadataClass = nb::class_<PyMetadata>(m, "Metadata", "Metadata");
auto MetadataClass =
nb::class_<PyMetadata, PyLLVMObject<PyMetadata, LLVMMetadataRef>>
(m, "Metadata", "Metadata");
auto MDNodeClass = nb::class_<PyMDNode, PyMetadata>(m, "MDNode", "MDNode");
auto MDStringClass = nb::class_<PyMDString, PyMetadata>(m, "MDString", "MDString");
auto ValueAsMetadata = nb::class_<PyValueAsMetadata, PyMetadata>
Expand All @@ -60,12 +68,18 @@ void bindOtherClasses(nb::module_ &m) {
auto MetadataEntriesClass = nb::class_<PyMetadataEntries>
(m, "MetadataEntry", "MetadataEntry");

auto UseClass = nb::class_<PyUse>(m, "Use", "Use");
auto UseClass =
nb::class_<PyUse, PyLLVMObject<PyUse, LLVMUseRef>>
(m, "Use", "Use");

auto IntrinsicClass = nb::class_<PyIntrinsic>(m, "Intrinsic", "Intrinsic");
auto IntrinsicClass =
nb::class_<PyIntrinsic, PyLLVMObject<PyIntrinsic, unsigned>>
(m, "Intrinsic", "Intrinsic");
auto OperandBundleClass = nb::class_<PyOperandBundle>(m, "OperandBundle",
"OperandBundle");
auto BuilderClass = nb::class_<PyBuilder>(m, "Builder", "Builder");
auto BuilderClass =
nb::class_<PyBuilder, PyLLVMObject<PyBuilder, LLVMBuilderRef>>
(m, "Builder", "Builder");
auto ModuleProviderClass = nb::class_<PyModuleProvider>
(m, "ModuleProvider", "ModuleProvider");
auto MemoryBufferClass = nb::class_<PyMemoryBuffer>
Expand Down Expand Up @@ -921,6 +935,8 @@ void bindOtherClasses(nb::module_ &m) {
auto name = std::string(raw_name, len);
return fmt::format("<Instrinsic id={} name={}>", self.get(), name);
})
.def("__bool__", // override default behavior
[](PyIntrinsic &self) { return true; })
.def_static("lookup",
[](std::string &name) {
return PyIntrinsic(LLVMLookupIntrinsicID(name.c_str(), name.size()));
Expand Down
2 changes: 1 addition & 1 deletion src/llvm/Core/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ template <typename T>
using optional = std::optional<T>;

void bindTypeClasses(nb::module_ &m) {
auto TypeClass = nb::class_<PyType>(m, "Type", "Type");
auto TypeClass = nb::class_<PyType, PyLLVMObject<PyType, LLVMTypeRef>>(m, "Type", "Type");
auto TypeIntClass = nb::class_<PyTypeInt, PyType>(m, "IntType", "IntType");
auto TypeRealClass = nb::class_<PyTypeReal, PyType>(m, "RealType", "RealType");
auto TypeFunctionClass = nb::class_<PyTypeFunction, PyType> (m, "FunctionType", "FunctionType");
Expand Down
61 changes: 35 additions & 26 deletions src/llvm/_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "_types/PyPassManagerBase.h"
#include "_types/PyMemoryBuffer.h"
#include "_types/PyModuleProvider.h"
#include "_types/PyLLVMObject.h"


#define DEFINE_PY_WRAPPER_CLASS(ClassName, UnderlyingType) \
Expand Down Expand Up @@ -232,6 +233,14 @@
macro(PyType, PyTypeTargetExt)


#define BIND_PYLLVMOBJECT_(ClassName, UnderlyingType, PyClassName) \
nb::class_<PyLLVMObject<ClassName, UnderlyingType>> \
(m, #PyClassName, "The base class.") \
.def("__bool__", &PyLLVMObject<ClassName, UnderlyingType>::__bool__) \
.def("__eq__", &PyLLVMObject<ClassName, UnderlyingType>::__equal__) \
.def("__hash__", &PyLLVMObject<ClassName, UnderlyingType>::__hash__);


enum class PyAttributeIndex {
Return = LLVMAttributeReturnIndex,
Function = LLVMAttributeFunctionIndex
Expand All @@ -249,31 +258,32 @@ enum class PyLLVMFastMathFlags {
All = LLVMFastMathAll
};

template <typename Derived, typename UnderlyingType>
class PyLLVMObject {
public:
virtual ~PyLLVMObject() = default;

UnderlyingType get() const {
return const_cast<const Derived*>(static_cast<const Derived*>(this))->get();
}

bool __bool__() const {
UnderlyingType raw = static_cast<UnderlyingType>(get());
if (!raw) return false;
return true;
}

// `__equal__` and `__hash__` works well on pointer type UnderlyingType
bool __equal__(const PyLLVMObject& other) const {
return this->get() == other.get();
}

std::size_t __hash__() const {
return std::hash<UnderlyingType>{}(this->get());
}
};

// NOTE the `__bool__` method of PyIntrinsic is overridden
#define BIND_PYLLVMOBJECT() \
BIND_PYLLVMOBJECT_(PyValue, LLVMValueRef, PyValueObject) \
BIND_PYLLVMOBJECT_(PyType, LLVMTypeRef, PyTypeObject) \
BIND_PYLLVMOBJECT_(PyDiagnosticInfo, LLVMDiagnosticInfoRef, PyDiagnosticInfoObject) \
BIND_PYLLVMOBJECT_(PyAttribute, LLVMAttributeRef, PyAttributeObject) \
BIND_PYLLVMOBJECT_(PyNamedMDNode, LLVMNamedMDNodeRef, PyNamedMDNodeObject) \
BIND_PYLLVMOBJECT_(PyUse, LLVMUseRef, PyUseObject) \
BIND_PYLLVMOBJECT_(PyBasicBlock, LLVMBasicBlockRef, PyBasicBlockObject) \
BIND_PYLLVMOBJECT_(PyBuilder, LLVMBuilderRef, PyBuilderObject) \
BIND_PYLLVMOBJECT_(PyMetadata, LLVMMetadataRef, PyMetadataObject) \
BIND_PYLLVMOBJECT_(PyIntrinsic, unsigned, PyIntrinsicObject) \
\
BIND_PYLLVMOBJECT_(PyContext, LLVMContextRef, PyContextObject) \
BIND_PYLLVMOBJECT_(PyMemoryBuffer, LLVMMemoryBufferRef, PyMemoryBufferObject) \
BIND_PYLLVMOBJECT_(PyMetadataEntries, LLVMValueMetadataEntries, PyMetadataEntriesObject) \
BIND_PYLLVMOBJECT_(PyModuleFlagEntries, LLVMModuleFlagEntries, PyModuleFlagEntriesObject) \
BIND_PYLLVMOBJECT_(PyModule, LLVMModuleRef, PyModuleObject) \
BIND_PYLLVMOBJECT_(PyModuleProvider, LLVMModuleProviderRef, PyModuleProviderObject) \
BIND_PYLLVMOBJECT_(PyOperandBundle, LLVMOperandBundleRef, PyOperandBundleObject) \
BIND_PYLLVMOBJECT_(PyPassManagerBase, LLVMPassManagerRef, PyPassManagerBaseObject)





DEFINE_PY_WRAPPER_CLASS_POLYMORPHIC(PyValue, LLVMValueRef)
DEFINE_PY_WRAPPER_CLASS_POLYMORPHIC(PyType, LLVMTypeRef)
Expand All @@ -285,6 +295,7 @@ DEFINE_PY_WRAPPER_CLASS(PyBasicBlock, LLVMBasicBlockRef)
DEFINE_PY_WRAPPER_CLASS(PyBuilder, LLVMBuilderRef)

DEFINE_PY_WRAPPER_CLASS(PyMetadata, LLVMMetadataRef)

DEFINE_DIRECT_SUB_CLASS(PyMetadata, PyMDNode)
DEFINE_DIRECT_SUB_CLASS(PyMetadata, PyValueAsMetadata)
DEFINE_DIRECT_SUB_CLASS(PyMetadata, PyMDString)
Expand All @@ -301,8 +312,6 @@ DEFINE_PY_WRAPPER_CLASS(PyIntrinsic, unsigned)





DEFINE_DIRECT_SUB_CLASS(PyPassManagerBase, PyPassManager);
DEFINE_DIRECT_SUB_CLASS(PyPassManagerBase, PyFunctionPassManager);

Expand Down
3 changes: 2 additions & 1 deletion src/llvm/_types/PyContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
#include <memory>
#include <unordered_map>
#include <mutex>
#include "PyLLVMObject.h"

class PyContext {
class PyContext : public PyLLVMObject<PyContext, LLVMContextRef> {
public:
explicit PyContext();
explicit PyContext(LLVMContextRef context, bool is_global_context);
Expand Down
30 changes: 30 additions & 0 deletions src/llvm/_types/PyLLVMObject.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef PYLLVMOBJECT_H
#define PYLLVMOBJECT_H

template <typename Derived, typename UnderlyingType>
class PyLLVMObject {
public:
virtual ~PyLLVMObject() = default;

UnderlyingType get() const {
return const_cast<const Derived*>(static_cast<const Derived*>(this))->get();
}

bool __bool__() const {
UnderlyingType raw = static_cast<UnderlyingType>(get());
if (!raw) return false;
return true;
}

// `__equal__` and `__hash__` works well on pointer type UnderlyingType
bool __equal__(const PyLLVMObject& other) const {
return this->get() == other.get();
}

std::size_t __hash__() const {
return std::hash<UnderlyingType>{}(this->get());
}
};


#endif
3 changes: 2 additions & 1 deletion src/llvm/_types/PyMemoryBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
#include <memory>
#include <unordered_map>
#include <mutex>
#include "PyLLVMObject.h"

class PyMemoryBuffer {
class PyMemoryBuffer : public PyLLVMObject<PyMemoryBuffer, LLVMMemoryBufferRef> {
public:
explicit PyMemoryBuffer(LLVMMemoryBufferRef mb);
LLVMMemoryBufferRef get() const;
Expand Down
3 changes: 2 additions & 1 deletion src/llvm/_types/PyMetadataEntries.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
#include <memory>
#include <unordered_map>
#include <mutex>
#include "PyLLVMObject.h"

typedef LLVMValueMetadataEntry *LLVMValueMetadataEntries;

class PyMetadataEntries {
class PyMetadataEntries : public PyLLVMObject<PyMetadataEntries, LLVMValueMetadataEntries> {
public:
explicit PyMetadataEntries(LLVMValueMetadataEntries entries, size_t len);
LLVMValueMetadataEntries get() const;
Expand Down
3 changes: 2 additions & 1 deletion src/llvm/_types/PyModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
#include <unordered_map>
#include <mutex>
#include <string>
#include "PyLLVMObject.h"

class PyModule {
class PyModule : public PyLLVMObject<PyModule, LLVMModuleRef> {
public:
explicit PyModule(const std::string &id);
explicit PyModule(const std::string &id, LLVMContextRef context);
Expand Down
3 changes: 2 additions & 1 deletion src/llvm/_types/PyModuleFlagEntries.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
#include <memory>
#include <unordered_map>
#include <mutex>
#include "PyLLVMObject.h"

typedef LLVMModuleFlagEntry *LLVMModuleFlagEntries;

class PyModuleFlagEntries {
class PyModuleFlagEntries : public PyLLVMObject<PyModuleFlagEntries, LLVMModuleFlagEntries>{
public:
explicit PyModuleFlagEntries(LLVMModuleFlagEntries entries, size_t len);
LLVMModuleFlagEntries get() const;
Expand Down
3 changes: 2 additions & 1 deletion src/llvm/_types/PyModuleProvider.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
#include <memory>
#include <unordered_map>
#include <mutex>
#include "PyLLVMObject.h"

class PyModuleProvider {
class PyModuleProvider : public PyLLVMObject<PyModuleProvider, LLVMModuleProviderRef>{
public:
explicit PyModuleProvider(LLVMModuleProviderRef mp);
LLVMModuleProviderRef get() const;
Expand Down
3 changes: 2 additions & 1 deletion src/llvm/_types/PyOperandBundle.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
#include <memory>
#include <unordered_map>
#include <mutex>
#include "PyLLVMObject.h"

class PyOperandBundle {
class PyOperandBundle : public PyLLVMObject<PyOperandBundle, LLVMOperandBundleRef> {
public:
explicit PyOperandBundle(LLVMOperandBundleRef bundle);
LLVMOperandBundleRef get() const;
Expand Down
3 changes: 2 additions & 1 deletion src/llvm/_types/PyPassManagerBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
#include <memory>
#include <unordered_map>
#include <mutex>
#include "PyLLVMObject.h"

class PyPassManagerBase {
class PyPassManagerBase : public PyLLVMObject<PyPassManagerBase, LLVMPassManagerRef> {
public:
explicit PyPassManagerBase(LLVMPassManagerRef pm);
LLVMPassManagerRef get() const;
Expand Down
6 changes: 1 addition & 5 deletions src/llvmpym_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@ using namespace nb::literals;
NB_MODULE(llvmpym_ext, m) {
m.doc() = "LLVM Python Native Extension";

nb::class_<PyLLVMObject<PyValue, LLVMValueRef>>
(m, "LLVMObject", "The base of for all LLVM object classes.")
.def("__bool__", &PyLLVMObject<PyValue, LLVMValueRef>::__bool__)
.def("__eq__", &PyLLVMObject<PyValue, LLVMValueRef>::__equal__)
.def("__hash__", &PyLLVMObject<PyValue, LLVMValueRef>::__hash__);
BIND_PYLLVMOBJECT();

auto coreModule = m.def_submodule("core", "LLVM Core");
populateCore(coreModule);
Expand Down

0 comments on commit b535e75

Please sign in to comment.