From 4f10667d0838f99b96d573628d7e1abcef6e9d26 Mon Sep 17 00:00:00 2001 From: max Date: Mon, 25 Dec 2023 14:41:46 -0600 Subject: [PATCH 1/2] move upate vars out of meta so it doesn't collide with upstream --- mlir/extras/dialects/ext/tensor.py | 4 +-- mlir/extras/meta.py | 38 ------------------------ mlir/extras/util.py | 46 ++++++++++++++++++++++++++---- 3 files changed, 43 insertions(+), 45 deletions(-) diff --git a/mlir/extras/dialects/ext/tensor.py b/mlir/extras/dialects/ext/tensor.py index 91ff9e8..e7a75c6 100644 --- a/mlir/extras/dialects/ext/tensor.py +++ b/mlir/extras/dialects/ext/tensor.py @@ -18,9 +18,9 @@ from ....dialects import tensor from ....dialects.tensor import * from .arith import ArithValue, Scalar, constant -from ...meta import region_op, _update_caller_vars +from ...meta import region_op from ...._mlir_libs._mlir import register_value_caster -from ...util import get_user_code_loc +from ...util import get_user_code_loc, _update_caller_vars from ....dialects._ods_common import get_op_result_or_op_results S = ShapedType.get_dynamic_size() diff --git a/mlir/extras/meta.py b/mlir/extras/meta.py index d137347..4d4c496 100644 --- a/mlir/extras/meta.py +++ b/mlir/extras/meta.py @@ -1,9 +1,7 @@ import contextlib -import ctypes import inspect import warnings from functools import wraps -from typing import Sequence from ..ir import Type, InsertionPoint, OpResultList, OpView from ..dialects._ods_common import get_op_result_or_op_results @@ -142,42 +140,6 @@ def maybe_no_args(*args, **kwargs): return maybe_no_args -def _update_caller_vars(previous_frame, args: Sequence, replacements: Sequence): - """Update caller vars passed as args. - - This function uses CPython API to update the values - of the caller's args (not the caller of this function but the caller of caller of this function). - It does this by searching for a match in the caller's f_locals based on identity (A is A) and then - updating all corresponding values in the f_locals dict. Finally, it uses PyFrame_LocalsToFast to signal - to the CPython runtime that an update has been made to f_locals. - - Args: - previous_frame: The frame in which vars will be updated. - args: The args to the callee. - replacements: The values that should replace the values of the vars in the caller. - """ - - if len(args) != len(replacements): - raise ValueError(f"updates must be 1-1: {args=} {replacements=}") - # find the name of the iter args in the previous frame - var_names = [ - [ - var_name - for var_name, var_val in previous_frame.f_locals.items() - if var_val is arg - ] - for arg in args - ] - for i, var_names in enumerate(var_names): - for var_name in var_names: - previous_frame.f_locals[var_name] = replacements[i] - # signal to update - # for some reason you can only update one at a time? - ctypes.pythonapi.PyFrame_LocalsToFast( - ctypes.py_object(previous_frame), ctypes.c_int(1) - ) - - class ModuleMeta(type): def __new__(cls, name, bases, classdict, **kwargs): ip = classdict.pop("ip") diff --git a/mlir/extras/util.py b/mlir/extras/util.py index 9df23bf..4165414 100644 --- a/mlir/extras/util.py +++ b/mlir/extras/util.py @@ -6,7 +6,7 @@ import warnings from dataclasses import dataclass from pathlib import Path -from typing import Callable, Optional, Union +from typing import Callable, Optional, Union, Sequence import numpy as np @@ -184,13 +184,13 @@ def infer_mlir_type( if isinstance(py_val, bool): return T.bool() elif isinstance(py_val, int): - if -(2**31) <= py_val < 2**31: + if -(2 ** 31) <= py_val < 2 ** 31: return T.i32() - elif 2**31 <= py_val < 2**32: + elif 2 ** 31 <= py_val < 2 ** 32: return T.ui32() - elif -(2**63) <= py_val < 2**63: + elif -(2 ** 63) <= py_val < 2 ** 63: return T.i64() - elif 2**63 <= py_val < 2**64: + elif 2 ** 63 <= py_val < 2 ** 64: return T.ui64() else: raise RuntimeError(f"Nonrepresentable integer {py_val}.") @@ -224,3 +224,39 @@ def memref_type_to_np_dtype(memref_type): T.memref(T.i64()): np.int64, } return _memref_type_to_np_dtype.get(memref_type) + + +def _update_caller_vars(previous_frame, args: Sequence, replacements: Sequence): + """Update caller vars passed as args. + + This function uses CPython API to update the values + of the caller's args (not the caller of this function but the caller of caller of this function). + It does this by searching for a match in the caller's f_locals based on identity (A is A) and then + updating all corresponding values in the f_locals dict. Finally, it uses PyFrame_LocalsToFast to signal + to the CPython runtime that an update has been made to f_locals. + + Args: + previous_frame: The frame in which vars will be updated. + args: The args to the callee. + replacements: The values that should replace the values of the vars in the caller. + """ + + if len(args) != len(replacements): + raise ValueError(f"updates must be 1-1: {args=} {replacements=}") + # find the name of the iter args in the previous frame + var_names = [ + [ + var_name + for var_name, var_val in previous_frame.f_locals.items() + if var_val is arg + ] + for arg in args + ] + for i, var_names in enumerate(var_names): + for var_name in var_names: + previous_frame.f_locals[var_name] = replacements[i] + # signal to update + # for some reason you can only update one at a time? + ctypes.pythonapi.PyFrame_LocalsToFast( + ctypes.py_object(previous_frame), ctypes.c_int(1) + ) From bcc31ad20c427ad4cc623d5634a96e2d24fee45c Mon Sep 17 00:00:00 2001 From: max Date: Mon, 25 Dec 2023 14:50:18 -0600 Subject: [PATCH 2/2] remove meta to not collide (move stuff to util) --- mlir/extras/dialects/ext/cf.py | 2 +- mlir/extras/dialects/ext/func.py | 4 +- mlir/extras/dialects/ext/gpu.py | 4 +- mlir/extras/dialects/ext/scf.py | 4 +- mlir/extras/meta.py | 155 ------------------------------- mlir/extras/util.py | 88 +++++++++++++++++- tests/test_regions.py | 2 +- tests/test_scf.py | 2 +- 8 files changed, 92 insertions(+), 169 deletions(-) delete mode 100644 mlir/extras/meta.py diff --git a/mlir/extras/dialects/ext/cf.py b/mlir/extras/dialects/ext/cf.py index 8796e97..ddb0cbb 100644 --- a/mlir/extras/dialects/ext/cf.py +++ b/mlir/extras/dialects/ext/cf.py @@ -8,7 +8,7 @@ _cext, ) from ....ir import Value, InsertionPoint, Block, OpView -from ...meta import get_user_code_loc, Successor +from ...util import get_user_code_loc, Successor @_cext.register_operation(_Dialect, replace=True) diff --git a/mlir/extras/dialects/ext/func.py b/mlir/extras/dialects/ext/func.py index 70d0c27..2f0bec3 100644 --- a/mlir/extras/dialects/ext/func.py +++ b/mlir/extras/dialects/ext/func.py @@ -2,8 +2,8 @@ import sys from typing import Union, Optional -from ...meta import make_maybe_no_args_decorator, op_region_builder -from ...util import get_user_code_loc +from ...meta import op_region_builder +from ...util import get_user_code_loc, make_maybe_no_args_decorator from ....dialects.func import * from ....ir import ( InsertionPoint, diff --git a/mlir/extras/dialects/ext/gpu.py b/mlir/extras/dialects/ext/gpu.py index 36cfdf0..68c6ac5 100644 --- a/mlir/extras/dialects/ext/gpu.py +++ b/mlir/extras/dialects/ext/gpu.py @@ -21,12 +21,10 @@ from .arith import constant from .func import FuncBase from ...meta import ( - ModuleMeta, - make_maybe_no_args_decorator, region_op, ) from ....dialects._ods_common import get_op_result_or_op_results -from ...util import get_user_code_loc +from ...util import get_user_code_loc, make_maybe_no_args_decorator, ModuleMeta def block_id_x(): diff --git a/mlir/extras/dialects/ext/scf.py b/mlir/extras/dialects/ext/scf.py index 8e71cc7..7ec47d1 100644 --- a/mlir/extras/dialects/ext/scf.py +++ b/mlir/extras/dialects/ext/scf.py @@ -13,8 +13,8 @@ ) from ...ast.util import ast_call, set_lineno from .gpu import get_device_mapping_array_attr -from ...meta import region_adder, region_op -from ...util import get_user_code_loc +from ...meta import region_op +from ...util import get_user_code_loc, region_adder from ....dialects._ods_common import ( get_op_result_or_op_results, get_default_loc_context, diff --git a/mlir/extras/meta.py b/mlir/extras/meta.py deleted file mode 100644 index 4d4c496..0000000 --- a/mlir/extras/meta.py +++ /dev/null @@ -1,155 +0,0 @@ -import contextlib -import inspect -import warnings -from functools import wraps - -from ..ir import Type, InsertionPoint, OpResultList, OpView -from ..dialects._ods_common import get_op_result_or_op_results - -try: - from ..ir import TypeID -except ImportError: - warnings.warn( - f"TypeID not supported by host bindings; value casting won't work correctly" - ) - TypeID = object - -from .util import get_user_code_loc, Successor - - -# builds the decorator -def make_maybe_no_args_decorator(decorator): - """ - a decorator decorator, allowing the decorator to be used as: - @decorator(with, arguments, and=kwargs) - or - @decorator - """ - - @wraps(decorator) - def new_dec(*args, **kwargs): - if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): - # actual decorated function - return decorator(args[0]) - else: - # decorator arguments - return lambda realf: decorator(realf, *args, **kwargs) - - return new_dec - - -@contextlib.contextmanager -def bb(*preds: tuple[Successor | OpView]): - current_ip = InsertionPoint.current - op = current_ip.block.owner - op_region = op.regions[0] - args = [] - if len(preds): - if isinstance(preds[0], OpView): - args = preds[0].operands - elif isinstance(preds[0], Successor): - args = preds[0].operands - else: - raise NotImplementedError(f"{preds[0]=} not supported.") - arg_locs = list(filter(None, [get_user_code_loc()] * len(args))) - if len(arg_locs) == 0: - arg_locs = None - block = op_region.blocks.append(*[a.type for a in args], arg_locs=arg_locs) - for p in preds: - if isinstance(p, OpView): - p.operation.successors[0] = block - elif isinstance(p, Successor): - for i, b in enumerate(p.block.owner.successors): - if i == p.pos: - p.op.successors[i] = block - p.block = block - break - with InsertionPoint(block): - yield block, list(block.arguments) - - -def op_region_builder(op, op_region, terminator=None): - def builder_wrapper(body_builder): - # add a block with block args having types ... - if len(op_region.blocks) == 0: - sig = inspect.signature(body_builder) - types = [p.annotation for p in sig.parameters.values()] - if not ( - len(types) == len(sig.parameters) - and all(isinstance(t, Type) for t in types) - ): - raise ValueError( - f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}" - ) - - arg_locs = list(filter(None, [get_user_code_loc()] * len(sig.parameters))) - if len(arg_locs) == 0: - arg_locs = None - op_region.blocks.append(*types, arg_locs=arg_locs) - - with InsertionPoint(op_region.blocks[0]): - results = body_builder(*list(op_region.blocks[0].arguments)) - - with InsertionPoint(list(op_region.blocks)[-1]): - if terminator is not None: - res = [] - if isinstance(results, (tuple, list)): - res.extend(results) - elif results is not None: - res.append(results) - terminator(res) - - res = get_op_result_or_op_results(op) - if isinstance(res, (OpResultList, list, tuple)): - return tuple(res) - else: - return res - - return builder_wrapper - - -def region_adder(terminator=None): - def wrapper(op_region_adder): - def region_adder_decorator(op, *args, **kwargs): - region = op_region_adder(op, *args, **kwargs) - - return op_region_builder(op, region, terminator) - - return region_adder_decorator - - return wrapper - - -def region_op(op_constructor, terminator=None): - # the decorator itself - def op_decorator(*args, **kwargs): - op = op_constructor(*args, **kwargs) - op_region = op.regions[0] - - return op_region_builder(op, op_region, terminator) - - # this is like make_maybe_no_args_decorator but a little different because the decorators here - # are already wrapped (or something like that) - @wraps(op_decorator) - def maybe_no_args(*args, **kwargs): - if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): - return op_decorator()(args[0]) - else: - return op_decorator(*args, **kwargs) - - return maybe_no_args - - -class ModuleMeta(type): - def __new__(cls, name, bases, classdict, **kwargs): - ip = classdict.pop("ip") - loc = classdict.pop("loc") - module_terminator = classdict.pop("module_terminator", None) - new = super().__new__(cls, name, bases, classdict) - if module_terminator is not None: - module_terminator(loc=loc, ip=ip) - for k, v in classdict.items(): - if callable(v): - v.qualname = name - ip.__exit__(None, None, None) - return new diff --git a/mlir/extras/util.py b/mlir/extras/util.py index 4165414..c9c0384 100644 --- a/mlir/extras/util.py +++ b/mlir/extras/util.py @@ -5,25 +5,28 @@ import sys import warnings from dataclasses import dataclass +from functools import wraps from pathlib import Path from typing import Callable, Optional, Union, Sequence import numpy as np +from .meta import op_region_builder from ..ir import ( Block, Context, + F32Type, + F64Type, + InsertionPoint, + IntegerType, Location, OpResult, OpResultList, OpView, Operation, + RankedTensorType, Value, _GlobalDebug, - IntegerType, - F32Type, - F64Type, - RankedTensorType, ) from ..extras import types as T @@ -260,3 +263,80 @@ def _update_caller_vars(previous_frame, args: Sequence, replacements: Sequence): ctypes.pythonapi.PyFrame_LocalsToFast( ctypes.py_object(previous_frame), ctypes.c_int(1) ) + + +def make_maybe_no_args_decorator(decorator): + """ + a decorator decorator, allowing the decorator to be used as: + @decorator(with, arguments, and=kwargs) + or + @decorator + """ + + @wraps(decorator) + def new_dec(*args, **kwargs): + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # actual decorated function + return decorator(args[0]) + else: + # decorator arguments + return lambda realf: decorator(realf, *args, **kwargs) + + return new_dec + + +@contextlib.contextmanager +def bb(*preds: tuple[Successor | OpView]): + current_ip = InsertionPoint.current + op = current_ip.block.owner + op_region = op.regions[0] + args = [] + if len(preds): + if isinstance(preds[0], OpView): + args = preds[0].operands + elif isinstance(preds[0], Successor): + args = preds[0].operands + else: + raise NotImplementedError(f"{preds[0]=} not supported.") + arg_locs = list(filter(None, [get_user_code_loc()] * len(args))) + if len(arg_locs) == 0: + arg_locs = None + block = op_region.blocks.append(*[a.type for a in args], arg_locs=arg_locs) + for p in preds: + if isinstance(p, OpView): + p.operation.successors[0] = block + elif isinstance(p, Successor): + for i, b in enumerate(p.block.owner.successors): + if i == p.pos: + p.op.successors[i] = block + p.block = block + break + with InsertionPoint(block): + yield block, list(block.arguments) + + +def region_adder(terminator=None): + def wrapper(op_region_adder): + def region_adder_decorator(op, *args, **kwargs): + region = op_region_adder(op, *args, **kwargs) + + return op_region_builder(op, region, terminator) + + return region_adder_decorator + + return wrapper + + +class ModuleMeta(type): + def __new__(cls, name, bases, classdict, **kwargs): + ip = classdict.pop("ip") + loc = classdict.pop("loc") + module_terminator = classdict.pop("module_terminator", None) + new = super().__new__(cls, name, bases, classdict) + if module_terminator is not None: + module_terminator(loc=loc, ip=ip) + for k, v in classdict.items(): + if callable(v): + v.qualname = name + ip.__exit__(None, None, None) + return new diff --git a/tests/test_regions.py b/tests/test_regions.py index db580b9..13bca3d 100644 --- a/tests/test_regions.py +++ b/tests/test_regions.py @@ -17,7 +17,7 @@ from mlir.dialects.tensor import yield_ as tensor_yield from mlir.extras.dialects.ext.tensor import generate from mlir.dialects.tensor import rank -from mlir.extras.meta import bb +from mlir.extras.util import bb # noinspection PyUnresolvedReferences from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext diff --git a/tests/test_scf.py b/tests/test_scf.py index 600fa77..b116f0c 100644 --- a/tests/test_scf.py +++ b/tests/test_scf.py @@ -73,7 +73,7 @@ def forfoo(i, *iter_args): return one, one assert len(forfoo) == 2 and all(isinstance(i, Scalar) for i in forfoo) - assert repr(forfoo) == "(Scalar(%0#0, f32), Scalar(%0#1, f32))" + assert repr(forfoo) == "[Scalar(%0#0, f32), Scalar(%0#1, f32)]" ctx.module.operation.verify() correct = dedent( """\