diff --git a/README.md b/README.md
index ac086e6..21d53d9 100644
--- a/README.md
+++ b/README.md
@@ -26,9 +26,11 @@
 This package is meant to work in concert with the upstream bindings.
 Practically speaking that means you need to have *some* package installed that includes mlir python bindings.
 In addition, you have to do one of two things to **configure this package** (after installing it):
-1. `$ configure-mlir-python-utils -y <MLIR_PYTHON_PACKAGE_PREFIX>`, where `MLIR_PYTHON_PACKAGE_PREFIX` is (as it says) the
+1. `$ configure-mlir-python-utils -y <MLIR_PYTHON_PACKAGE_PREFIX>`, where `MLIR_PYTHON_PACKAGE_PREFIX` is (as it says)
+   the
    package prefix for your chosen upstream bindings. So for example, for `torch-mlir`, you would
-   execute `configure-mlir-python-utils torch_mlir`, since `torch-mlir`'s bindings are the root of the `torch-mlir` python
+   execute `configure-mlir-python-utils torch_mlir`, since `torch-mlir`'s bindings are the root of the `torch-mlir`
+   python
    package.
    **When in doubt about this prefix**, it is everything up until `ir` (e.g., as in `from torch_mlir import ir`).
 2. `$ export MLIR_PYTHON_PACKAGE_PREFIX=<MLIR_PYTHON_PACKAGE_PREFIX>`, i.e., you can set this string as an environment
@@ -49,4 +51,12 @@ pip install setuptools -U
 pip install -e .[torch-mlir-test] \
   -f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest \
   -f https://llvm.github.io/torch-mlir/package-index/
+```
+
+There's an annoying bug where, if you try to reconfigure to a different set of host bindings (e.g., going from
+`torch-mlir` to `mlir`), it won't work the first time.
+The workaround is to delete the prefix token before configuring, like so:
+
+```shell
+rm <path-to-installed-mlir_utils>/_configuration/__MLIR_PYTHON_PACKAGE_PREFIX__ && configure-mlir-python-utils mlir
 ```
\ No newline at end of file
diff --git a/mlir_utils/_configuration/module_alias_map.py b/mlir_utils/_configuration/module_alias_map.py
index a6260b6..cc73a47 100644
--- a/mlir_utils/_configuration/module_alias_map.py
+++ b/mlir_utils/_configuration/module_alias_map.py
@@ -87,3 +87,11 @@ def find_spec(
             )
         else:
             return None
+
+
+def maybe_remove_alias_module_loader():
+    for i in range(len(sys.meta_path)):
+        finder = sys.meta_path[i]
+        if isinstance(finder, AliasedModuleFinder):
+            del sys.meta_path[i]
+            return
diff --git a/mlir_utils/dialects/ext/arith.py b/mlir_utils/dialects/ext/arith.py
index 2d5eb39..ac78138 100644
--- a/mlir_utils/dialects/ext/arith.py
+++ b/mlir_utils/dialects/ext/arith.py
@@ -155,7 +155,7 @@ def __call__(cls, *args, **kwargs):
 
 
 class ArithValue(Value, metaclass=ArithValueMeta):
-    """Mixin class for functionality shared by Value subclasses that support
+    """Class for functionality shared by Value subclasses that support
     arithmetic operations.
 
     Note, since we bind the ArithValueMeta here, it is here that the __new__ and
diff --git a/mlir_utils/dialects/ext/tensor.py b/mlir_utils/dialects/ext/tensor.py
index d82f7d2..46138ff 100644
--- a/mlir_utils/dialects/ext/tensor.py
+++ b/mlir_utils/dialects/ext/tensor.py
@@ -6,6 +6,7 @@
 from mlir.ir import Type, Value, RankedTensorType, DenseElementsAttr, ShapedType
 
 from mlir_utils.dialects.ext.arith import ArithValue
+from mlir_utils.dialects.util import register_value_caster
 
 try:
     from mlir_utils.dialects.tensor import *
@@ -64,28 +65,5 @@ def empty(
         return cls(EmptyOp(shape, el_type).result)
 
-    def __class_getitem__(
-        cls, dim_sizes_dtype: Tuple[Union[list[int], tuple[int, ...]], Type]
-    ) -> Type:
-        """A convenience method for creating RankedTensorType.
-
-        Args:
-          dim_sizes_dtype: A tuple of both the shape of the type and the dtype.
-
-        Returns:
-          An instance of RankedTensorType.
-        """
-        if len(dim_sizes_dtype) != 2:
-            raise ValueError(
-                f"Wrong type of argument to {cls.__name__}: {dim_sizes_dtype=}"
-            )
-        dim_sizes, dtype = dim_sizes_dtype
-        if not isinstance(dtype, Type):
-            raise ValueError(f"{dtype=} is not {Type=}")
-        static_sizes = []
-        for s in dim_sizes:
-            if isinstance(s, int):
-                static_sizes.append(s)
-            else:
-                static_sizes.append(ShapedType.get_dynamic_size())
-        return RankedTensorType.get(static_sizes, dtype)
+
+register_value_caster(RankedTensorType.static_typeid, Tensor)
diff --git a/mlir_utils/dialects/util.py b/mlir_utils/dialects/util.py
index 13e8ec2..c520288 100644
--- a/mlir_utils/dialects/util.py
+++ b/mlir_utils/dialects/util.py
@@ -1,9 +1,11 @@
 import ctypes
-from functools import wraps
 import inspect
+from collections import defaultdict
+from functools import wraps
+from typing import Callable
 
 from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
-from mlir.ir import InsertionPoint, Value, Type
+from mlir.ir import InsertionPoint, Value, Type, TypeID
 
 
 def get_result_or_results(op):
@@ -31,20 +33,52 @@
     return maybe_no_args
 
 
+__VALUE_CASTERS: defaultdict[
+    TypeID, list[Callable[[Value], Value | None]]
+] = defaultdict(list)
+
+
+def register_value_caster(
+    typeid: TypeID, caster: Callable[[Value], Value], priority: int | None = None
+):
+    if not isinstance(typeid, TypeID):
+        raise ValueError(f"{typeid=} is not a TypeID")
+    if priority is None:
+        __VALUE_CASTERS[typeid].append(caster)
+    else:
+        __VALUE_CASTERS[typeid].insert(priority, caster)
+
+
+def has_value_caster(typeid: TypeID):
+    if not isinstance(typeid, TypeID):
+        raise ValueError(f"{typeid=} is not a TypeID")
+    if typeid not in __VALUE_CASTERS:
+        return False
+    return True
+
+
+def get_value_caster(typeid: TypeID):
+    if not has_value_caster(typeid):
+        raise ValueError(f"no registered caster for {typeid=}")
+    return __VALUE_CASTERS[typeid]
+
+
 def maybe_cast(val: Value):
     """Maybe cast an ir.Value to one of Tensor, Scalar.
 
     Args:
       val: The ir.Value to maybe cast.
""" - from mlir_utils.dialects.ext.tensor import Tensor from mlir_utils.dialects.ext.arith import Scalar if not isinstance(val, Value): return val - if Tensor.isinstance(val): - return Tensor(val) + if has_value_caster(val.type.typeid): + for caster in get_value_caster(val.type.typeid): + if casted := caster(val): + return casted + raise ValueError(f"no successful casts for {val=}") if Scalar.isinstance(val): return Scalar(val) return val diff --git a/tests/test_value_caster.py b/tests/test_value_caster.py new file mode 100644 index 0000000..4045e6b --- /dev/null +++ b/tests/test_value_caster.py @@ -0,0 +1,35 @@ +import pytest +from mlir.ir import OpResult + +from mlir_utils.dialects.ext.tensor import S, empty +from mlir_utils.dialects.ext.arith import constant +from mlir_utils.dialects.util import register_value_caster + +# noinspection PyUnresolvedReferences +from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext +from mlir_utils.types import f64_t, RankedTensorType + +# needed since the fix isn't defined here nor conftest.py +pytest.mark.usefixtures("ctx") + + +def test_caster_registration(ctx: MLIRContext): + sizes = S, 3, S + ten = empty(sizes, f64_t) + assert repr(ten) == "Tensor(%0, tensor)" + + def dummy_caster(val): + print(val) + return val + + register_value_caster(RankedTensorType.static_typeid, dummy_caster) + ten = empty(sizes, f64_t) + assert repr(ten) == "Tensor(%1, tensor)" + + register_value_caster(RankedTensorType.static_typeid, dummy_caster, 0) + ten = empty(sizes, f64_t) + assert repr(ten) != "Tensor(%1, tensor)" + assert isinstance(ten, OpResult) + + one = constant(1) + assert repr(one) == "Scalar(%3, i64)"