Skip to content

Commit

Permalink
Remove generate upstream (#31)
Browse files Browse the repository at this point in the history
* use upstream generated

* update readme
  • Loading branch information
makslevental authored Dec 17, 2023
1 parent de95702 commit 9bc0b1a
Show file tree
Hide file tree
Showing 24 changed files with 694 additions and 556 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ jobs:
shell: bash
run: |
pip install .[test,mlir] -v -f https://makslevental.github.io/wheels
mlir-python-utils-generate-all-upstream-trampolines
- name: Test
shell: bash
Expand Down Expand Up @@ -81,7 +80,6 @@ jobs:
run: |
export PIP_FIND_LINKS=https://makslevental.github.io/wheels
HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[test,jax] -v
jaxlib-mlir-python-utils-generate-all-upstream-trampolines
- name: Test
shell: bash
Expand Down Expand Up @@ -160,7 +158,6 @@ jobs:
cd /workspace
pip install -q .[test,mlir] -f https://makslevental.github.io/wheels
mlir-python-utils-generate-all-upstream-trampolines
pytest --capture=tee-sys tests
python examples/mwe.py
28 changes: 5 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,6 @@ The few main features/affordances:
See [mlir_utils.types](https://github.com/makslevental/mlir-python-utils/blob/a9885db18096a610d29a26293396d860d40ad213/mlir_utils/types.py) for details.
\
 
4. `register_value_caster` (enables `C[i, j]` above)
\
 
1. This is a mechanism for registering a python wrapper class that should wrap an `ir.Value` when that `ir.Value` is fetched using `ir.Operation.results`.
They are primarily called by the generated wrappers.
They enable you to associate "creative" methods with the `ir.Value` (like `__add__` or `__getitem__`).
They are keyed on the `mlir::TypeID` of the `ir.Type` of the `ir.Value` (through `ir.Type.typeid`, or more robustly, through `ir.Type.static_typeid`).
\
\
See [mlir_utils.types](https://github.com/makslevental/mlir-python-utils/blob/a9885db18096a610d29a26293396d860d40ad213/mlir_utils/types.py) for details.
\
 
4. `Pipeline()`
\
 
Expand All @@ -139,22 +127,17 @@ But, open an issue if something isn't clear.

## Install

First
This package is meant to work in concert with host bindings.
Practically speaking that means you need to have *some* package installed that includes mlir python bindings.

```shell
$ HOST_MLIR_PYTHON_PACKAGE_PREFIX=<YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX> pip install git+https://github.com/makslevental/mlir-python-utils
```
So

This package is meant to work in concert with host bindings.
Practically speaking that means you need to have *some* package installed that includes mlir python bindings, **and you need to `mlir-python-utils-generate-all-upstream-trampolines`**.
So after pip-installing you need to
```shell
$ <YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX>-mlir-python-utils-generate-all-upstream-trampolines
$ YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX=<YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX> pip install git+https://github.com/makslevental/mlir-python-utils
```

where `YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX` is (as it says) the package prefix for your chosen host bindings.
**When in doubt about this prefix**, it is everything up until `ir` when you import your bindings, e.g., in `import torch_mlir.ir`, `torch_mlir` is the `HOST_MLIR_PYTHON_PACKAGE_PREFIX` for the torch-mlir bindings.
Thus, for torch-mlir host bindings, you would execute `torch-mlir-mlir-python-utils-generate-all-upstream-trampolines`.
Note, the underscore in `torch_mlir` becomes a dash in `torch-mlir-mlir-python-utils-generate-all-upstream-trampolines`.

If you don't have any such package, but you want to experiment anyway, you can install the "stock" upstream bindings first:

Expand All @@ -166,7 +149,6 @@ and then

```shell
$ pip install git+https://github.com/makslevental/mlir-python-utils
$ mlir-python-utils-generate-all-upstream-trampolines
```

## Examples/Demo
Expand Down
32 changes: 12 additions & 20 deletions mlir/utils/_configuration/generate_trampolines.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,9 @@ def visit_Call(self, node: ast.Call):
# TODO(max): ops that have symboltables need to be classes but that requires some upstream support for statically
# identifying such ops
def generate_op_trampoline(op_class):
from ..util import (
get_result_or_results,
get_user_code_loc,
)
from ..meta import maybe_cast, region_op
from ..util import get_user_code_loc
from ...dialects._ods_common import get_op_result_or_op_results
from ..meta import region_op

_mod = ast.parse(dedent(inspect.getsource(op_class.__init__)))
init_fn = next(n for n in _mod.body if isinstance(n, ast.FunctionDef))
Expand Down Expand Up @@ -106,7 +104,7 @@ def generate_op_trampoline(op_class):
decorator_list = []
body += [
ast.parse(
f"return {maybe_cast.__name__}({get_result_or_results.__name__}({ast.unparse(ast_call(op_class_name, args.args, keywords))}))"
f"return {get_op_result_or_op_results.__name__}({ast.unparse(ast_call(op_class_name, args.args, keywords))})"
).body[0]
]

Expand Down Expand Up @@ -134,9 +132,9 @@ def generate_op_trampoline(op_class):

def generate_dialect_trampolines_from_module(input_module, skips: set):
from .. import util, meta
from ..util import get_result_or_results, get_user_code_loc
from ..meta import region_op, maybe_cast
from ...dialects import _ods_common
from ..util import get_user_code_loc
from ..meta import region_op
from ...dialects._ods_common import get_op_result_or_op_results
from ... import ir

skips.update({"_Dialect"})
Expand Down Expand Up @@ -174,13 +172,13 @@ def generate_dialect_trampolines_from_module(input_module, skips: set):
module=util.__name__,
names=[
ast.alias(f.__name__)
for f in [get_result_or_results, get_user_code_loc]
for f in [get_op_result_or_op_results, get_user_code_loc]
],
level=0,
),
ast.ImportFrom(
module=meta.__name__,
names=[ast.alias(f.__name__) for f in [maybe_cast, region_op]],
names=[ast.alias(f.__name__) for f in [region_op]],
level=0,
),
]
Expand Down Expand Up @@ -259,10 +257,10 @@ def generate_trampolines(
def generate_linalg(mod_path):
from ... import dialects as mlir_dialects

from .. import dialects, util, meta
from .. import dialects, util
from ...dialects.linalg import DefinedOpCallable, OperandKind
from ...dialects._ods_common import get_op_result_or_op_results
from ..util import get_user_code_loc
from ..meta import maybe_cast

linalg_modu = __import__(mod_path, fromlist=["*"])

Expand All @@ -281,11 +279,6 @@ def generate_linalg(mod_path):
names=[ast.alias(f.__name__) for f in [get_user_code_loc]],
level=0,
),
ast.ImportFrom(
module=meta.__name__,
names=[ast.alias(f.__name__) for f in [maybe_cast]],
level=0,
),
]
_keywords = [
ast.keyword("loc", ast.Name("loc")),
Expand All @@ -306,11 +299,10 @@ def generate_linalg(mod_path):
]

keywords = _keywords + [ast.keyword("outs", ast.List(outputs))]
# body = [ast.Str(op_callable.op_def.metadata.doc)]
body = [ast.parse(f"if loc is None: loc = {get_user_code_loc.__name__}()")]
body += [
ast.parse(
f"return {maybe_cast.__name__}({ast.unparse(ast_call('linalg.' + name, inputs, keywords))})"
f"return {ast.unparse(ast_call('linalg.' + name, inputs, keywords))}"
).body[0]
]
n = ast.FunctionDef(
Expand Down
44 changes: 18 additions & 26 deletions mlir/utils/dialects/ext/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
_arith_cmpipredicateattr,
)
from ....dialects.arith import _is_integer_like_type
from ....dialects._ods_common import get_op_result_or_value
from ....dialects._ods_common import get_op_result_or_value, get_op_result_or_op_results
from ....dialects.linalg.opdsl.lang.emitter import (
_is_floating_point_type,
_is_integer_type,
Expand Down Expand Up @@ -43,14 +43,8 @@
FloatAttr,
)

from ...util import get_result_or_results, get_user_code_loc
from ...meta import register_value_caster, maybe_cast

try:
from ...dialects.arith import *
except ModuleNotFoundError:
pass

from ...util import get_user_code_loc
from ...._mlir_libs._mlir import register_value_caster
from ...types import infer_mlir_type, mlir_type_to_np_dtype


Expand Down Expand Up @@ -87,19 +81,17 @@ def constant(

if _is_complex_type(type):
value = complex(value)
return maybe_cast(
get_result_or_results(
complex_dialect.ConstantOp(
type,
list(
map(
lambda x: FloatAttr.get(type.element_type, x),
[value.real, value.imag],
)
),
loc=loc,
ip=ip,
)
return get_op_result_or_op_results(
complex_dialect.ConstantOp(
type,
list(
map(
lambda x: FloatAttr.get(type.element_type, x),
[value.real, value.imag],
)
),
loc=loc,
ip=ip,
)
)

Expand All @@ -120,8 +112,8 @@ def constant(
type=type,
)

return maybe_cast(
get_result_or_results(arith_dialect.ConstantOp(type, value, loc=loc, ip=ip))
return get_op_result_or_op_results(
arith_dialect.ConstantOp(type, value, loc=loc, ip=ip)
)


Expand All @@ -136,8 +128,8 @@ def index_cast(
loc = get_user_code_loc()
if to is None:
to = IndexType.get()
return maybe_cast(
get_result_or_results(arith_dialect.IndexCastOp(to, value, loc=loc, ip=ip))
return get_op_result_or_op_results(
arith_dialect.IndexCastOp(to, value, loc=loc, ip=ip)
)


Expand Down
Loading

0 comments on commit 9bc0b1a

Please sign in to comment.