Skip to content

Commit

Permalink
feat: equality of object / Value
Browse files Browse the repository at this point in the history
  • Loading branch information
Ziqi-Yang committed Aug 3, 2024
1 parent e6b3f6e commit 92f1937
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 45 deletions.
11 changes: 11 additions & 0 deletions docs/source/development.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,14 @@ As a contrast, the following went smooth.
In conclusion, it is suggested that classes which do self memory control shouldn't be
appeared as a default argument of some functions.


Resources / References
----------------------

LLVM online reference are all of the latest version. To view reference of a certain version,
please manually build the docs.

- `LLVM C API doxygen <https://llvm.org/docs/doxygen/group__LLVMCCore.html>`_
- `LLVM Reference Manual <https://llvm.org/docs/LangRef.html>`_

14 changes: 11 additions & 3 deletions example/parse_ir_assmebly.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,15 @@
m = utils.parse_assembly(asm_str)

for f in m.functions:
print(f'Function: {f.name}/`{f.type}`')
m = f.parent
print(f'Function | name: "{f.name}", type: "{f.type}"')
module = f.parent
# assert m == module
assert f.kind == core.ValueKind.Function
print(f"Functoin attributes: {}")
print()
for i, a in enumerate(f.args, 1):
print(f'Argument | name: "{a.name}", type: "{a.type}"')
attrs = f.get_attributes_at_index(i)
print(f"\tattrs: {attrs}")

print("\n----------------------------\n")

2 changes: 1 addition & 1 deletion src/llvm/Core/iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace nb::literals;
void bindIterators(nb::module_ &m) {
BIND_ITERATOR_CLASS(PyUseIterator, "UseIterator")
BIND_ITERATOR_CLASS(PyBasicBlockIterator, "BasicBlockIterator")
BIND_ITERATOR_CLASS(PyArgumentIterator, "ArgumentIterator")
// BIND_ITERATOR_CLASS(PyArgumentIterator, "ArgumentIterator")
BIND_ITERATOR_CLASS(PyInstructionIterator, "InstructionIterator")
BIND_ITERATOR_CLASS(PyGlobalVariableIterator, "GlobalVariableIterator")
BIND_ITERATOR_CLASS(PyGlobalIFuncIterator, "GlobalIFuncIterator")
Expand Down
37 changes: 19 additions & 18 deletions src/llvm/Core/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,24 @@ template <typename T>
using optional = std::optional<T>;

void bindTypeClasses(nb::module_ &m) {
nb::class_<PyType>(m, "Type", "Type")
auto TypeClass = nb::class_<PyType>(m, "Type", "Type");
auto TypeIntClass = nb::class_<PyTypeInt, PyType>(m, "IntType", "IntType");
auto TypeRealClass = nb::class_<PyTypeReal, PyType>(m, "RealType", "RealType");
auto TypeFunctionClass = nb::class_<PyTypeFunction, PyType> (m, "FunctionType", "FunctionType");
auto TypeStructClass = nb::class_<PyTypeStruct, PyType> (m, "StructType", "StructType");
auto TypeSequenceClass = nb::class_<PyTypeSequence, PyType>(m, "SequenceType", "SequenceType");
auto TypeArrayClass = nb::class_<PyTypeArray, PyTypeSequence>(m, "ArrayType", "ArrayType");
auto TypePointerClass = nb::class_<PyTypePointer, PyTypeSequence>(m, "PointerType", "PointerType");
auto TypeVectorClass = nb::class_<PyTypeVector, PyTypeSequence>(m, "VectorType", "VectorType");
auto TypeVoidClass = nb::class_<PyTypeVoid, PyType>(m, "VoidType", "VoidType");
auto TypeLabelClass = nb::class_<PyTypeLabel, PyType>(m, "LabelType", "LabelType");
auto TypeX86MMXClass = nb::class_<PyTypeX86MMX, PyType>(m, "X86MMXType", "X86MMXType");
auto TypeX86AMXClass = nb::class_<PyTypeX86AMX, PyType>(m, "X86AMXType", "X86AMXType");
auto TypeTokenClass = nb::class_<PyTypeToken, PyType>(m, "TokenType", "TokenType");
auto TypeMetadataClass = nb::class_<PyTypeMetadata, PyType>(m, "MetadataType", "MetadataType");
auto TypeTargetExtClass = nb::class_<PyTypeTargetExt, PyType>(m, "TargetExtType", "TargetExtType");

TypeClass
.def("__repr__",
[](PyType &self) {
auto kind = get_repr_str(LLVMGetTypeKind(self.get()));
Expand Down Expand Up @@ -80,24 +97,8 @@ void bindTypeClasses(nb::module_ &m) {
return LLVMDumpType(self.get());
},
"Dump a representation of a type to stderr.");

auto TypeIntClass = nb::class_<PyTypeInt, PyType>(m, "IntType", "IntType");
auto TypeRealClass = nb::class_<PyTypeReal, PyType>(m, "RealType", "RealType");
auto TypeFunctionClass = nb::class_<PyTypeFunction, PyType> (m, "FunctionType", "FunctionType");
auto TypeStructClass = nb::class_<PyTypeStruct, PyType> (m, "StructType", "StructType");
auto TypeSequenceClass = nb::class_<PyTypeSequence, PyType>(m, "SequenceType", "SequenceType");
auto TypeArrayClass = nb::class_<PyTypeArray, PyTypeSequence>(m, "ArrayType", "ArrayType");
auto TypePointerClass = nb::class_<PyTypePointer, PyTypeSequence>(m, "PointerType", "PointerType");
auto TypeVectorClass = nb::class_<PyTypeVector, PyTypeSequence>(m, "VectorType", "VectorType");
auto TypeVoidClass = nb::class_<PyTypeVoid, PyType>(m, "VoidType", "VoidType");
auto TypeLabelClass = nb::class_<PyTypeLabel, PyType>(m, "LabelType", "LabelType");
auto TypeX86MMXClass = nb::class_<PyTypeX86MMX, PyType>(m, "X86MMXType", "X86MMXType");
auto TypeX86AMXClass = nb::class_<PyTypeX86AMX, PyType>(m, "X86AMXType", "X86AMXType");
auto TypeTokenClass = nb::class_<PyTypeToken, PyType>(m, "TokenType", "TokenType");
auto TypeMetadataClass = nb::class_<PyTypeMetadata, PyType>(m, "MetadataType", "MetadataType");
auto TypeTargetExtClass = nb::class_<PyTypeTargetExt, PyType>(m, "TargetExtType", "TargetExtType");



TypeIntClass
.def("__repr__",
[](PyTypeInt &self) {
Expand Down
35 changes: 15 additions & 20 deletions src/llvm/Core/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ using optional = std::optional<T>;


void bindValueClasses(nb::module_ &m) {
auto ValueClass = nb::class_<PyValue>(m, "Value", "Value");

auto ValueClass = nb::class_<PyValue, PyLLVMObject<PyValue, LLVMValueRef>>
(m, "Value", "Value");
nb::class_<PyMetadataAsValue, PyValue>(m, "MetadataAsValue", "MetadataAsValue");
auto MDNodeValueClass = nb::class_<PyMDNodeValue, PyMetadataAsValue>
(m, "MDNodeValue", "MDNodeValue");
Expand Down Expand Up @@ -696,7 +696,7 @@ void bindValueClasses(nb::module_ &m) {
},
"index"_a,
"Obtain the operand bundle attached to this instruction at the given index.")
.def("set_param_alignment",
.def("set_arg_alignment",
[](PyCallBase &self, LLVMAttributeIndex idx, unsigned align) {
return LLVMSetInstrParamAlignment(self.get(), idx, align);
},
Expand Down Expand Up @@ -1086,34 +1086,29 @@ void bindValueClasses(nb::module_ &m) {
.def_prop_ro("debug_loc_line",
[](PyFunction &f) { return LLVMGetDebugLocLine(f.get()); },
"Return the line number of the debug location for this value")
.def_prop_ro("param_num",
.def_prop_ro("arg_num",
[](PyFunction &self) {
return LLVMCountParams(self.get());
})
.def_prop_ro("params",
[](PyFunction &self) {
unsigned param_num = LLVMCountParams(self.get());
std::vector<LLVMValueRef> params(param_num);
LLVMGetParams(self.get(), params.data());
WRAP_VECTOR_FROM_DEST(PyArgument, param_num, res, params);
return res;
})
.def_prop_ro("first_param",
.def_prop_ro("first_arg",
[](PyFunction &self) -> optional<PyArgument> {
auto res = LLVMGetFirstParam(self.get());
WRAP_OPTIONAL_RETURN(res, PyArgument);
})
.def_prop_ro("last_param",
.def_prop_ro("last_arg",
[](PyFunction &self) -> optional<PyArgument> {
auto res = LLVMGetLastParam(self.get());
WRAP_OPTIONAL_RETURN(res, PyArgument);
})
// .def_prop_ro("params", // also have the same name method
// [](PyFunction &self) {
// auto res = LLVMGetFirstParam(self.get());
// return PyArgumentIterator(PyArgument(res));
// })
.def("get_param",
.def_prop_ro("args",
[](PyFunction &self) {
unsigned param_num = LLVMCountParams(self.get());
std::vector<LLVMValueRef> params(param_num);
LLVMGetParams(self.get(), params.data());
WRAP_VECTOR_FROM_DEST(PyArgument, param_num, res, params);
return res;
})
.def("get_arg",
[](PyFunction &self, unsigned index) {
return PyArgument(LLVMGetParam(self.get(), index));
},
Expand Down
35 changes: 32 additions & 3 deletions src/llvm/_types.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
// For type/value class hierarchy see doxygen/group__LLVMCCore.html
//
// ====================================

#ifndef LLVMPYM__TYPES_H
#define LLVMPYM__TYPES_H

Expand All @@ -20,7 +24,7 @@


#define DEFINE_PY_WRAPPER_CLASS(ClassName, UnderlyingType) \
class ClassName { \
class ClassName: public PyLLVMObject<ClassName, UnderlyingType> { \
public: \
explicit ClassName(UnderlyingType raw) \
: raw(raw) {} \
Expand All @@ -34,7 +38,7 @@
};

#define DEFINE_PY_WRAPPER_CLASS_POLYMORPHIC(ClassName, UnderlyingType) \
class ClassName { \
class ClassName: public PyLLVMObject<ClassName, UnderlyingType> { \
public: \
virtual ~ClassName() = default; \
explicit ClassName(UnderlyingType raw) \
Expand Down Expand Up @@ -245,6 +249,31 @@ enum class PyLLVMFastMathFlags {
All = LLVMFastMathAll
};

template <typename Derived, typename UnderlyingType>
class PyLLVMObject {
public:
virtual ~PyLLVMObject() = default;

UnderlyingType get() const {
return const_cast<const Derived*>(static_cast<const Derived*>(this))->get();
}

bool __bool__() const {
UnderlyingType raw = static_cast<UnderlyingType>(get());
if (!raw) return false;
return true;
}

// `__equal__` and `__hash__` works well on pointer type UnderlyingType
bool __equal__(const PyLLVMObject& other) const {
return this->get() == other.get();
}

std::size_t __hash__() const {
return std::hash<UnderlyingType>{}(this->get());
}
};


DEFINE_PY_WRAPPER_CLASS_POLYMORPHIC(PyValue, LLVMValueRef)
DEFINE_PY_WRAPPER_CLASS_POLYMORPHIC(PyType, LLVMTypeRef)
Expand Down Expand Up @@ -309,7 +338,7 @@ DEFINE_DIRECT_SUB_CLASS(PyPassManagerBase, PyFunctionPassManager);

DEFINE_ITERATOR_CLASS(PyUseIterator, PyUse, LLVMGetNextUse)
DEFINE_ITERATOR_CLASS(PyBasicBlockIterator, PyBasicBlock, LLVMGetNextBasicBlock)
DEFINE_ITERATOR_CLASS(PyArgumentIterator, PyArgument, LLVMGetNextParam)
// DEFINE_ITERATOR_CLASS(PyArgumentIterator, PyArgument, LLVMGetNextParam)
DEFINE_ITERATOR_CLASS(PyInstructionIterator, PyInstruction, LLVMGetNextInstruction)
DEFINE_ITERATOR_CLASS(PyGlobalVariableIterator, PyGlobalVariable, LLVMGetNextGlobal)
DEFINE_ITERATOR_CLASS(PyGlobalIFuncIterator, PyGlobalIFunc, LLVMGetNextGlobalIFunc)
Expand Down
7 changes: 7 additions & 0 deletions src/llvmpym_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "llvm/Core.h"
#include "llvm/ErrorHandling.h"
#include "llvm/Utils.h"
#include "llvm/_types.h"

namespace nb = nanobind;
using namespace nb::literals;
Expand All @@ -10,6 +11,12 @@ using namespace nb::literals;
NB_MODULE(llvmpym_ext, m) {
m.doc() = "LLVM Python Native Extension";

nb::class_<PyLLVMObject<PyValue, LLVMValueRef>>
(m, "LLVMObject", "The base of for all LLVM object classes.")
.def("__bool__", &PyLLVMObject<PyValue, LLVMValueRef>::__bool__)
.def("__eq__", &PyLLVMObject<PyValue, LLVMValueRef>::__equal__)
.def("__hash__", &PyLLVMObject<PyValue, LLVMValueRef>::__hash__);

auto coreModule = m.def_submodule("core", "LLVM Core");
populateCore(coreModule);

Expand Down
6 changes: 6 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,9 @@ def test_module(self):
pass


class TestEquality:
# TODO
def test_value(self):
x = ConstantInt(IntType.GlobalInt32, 100, True)
y = ConstantInt(IntType.GlobalInt32, 100, True)
assert x == y

0 comments on commit 92f1937

Please sign in to comment.