From d80dfd363cb3897c02a1545a04fdefc6366bb82c Mon Sep 17 00:00:00 2001 From: max Date: Mon, 2 Oct 2023 11:26:48 -0500 Subject: [PATCH] add py312 tests --- .github/workflows/test.yml | 66 +++++-- .../_configuration/generate_trampolines.py | 2 + mlir/utils/_configuration/util.py | 2 +- mlir/utils/dialects/ext/func.py | 14 +- mlir/utils/dialects/ext/scf.py | 48 ++--- mlir/utils/util.py | 10 +- tests/test_func.py | 12 +- tests/test_location_tracking.py | 6 +- tests/test_other_hosts.py | 174 ++++++++++++++++++ tests/test_smoke.py | 87 --------- tests/test_transformers.py | 56 +++++- tests/util.py | 24 +++ 12 files changed, 342 insertions(+), 159 deletions(-) create mode 100644 tests/test_other_hosts.py delete mode 100644 tests/test_smoke.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e09ccb8..7b9727d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,7 +16,7 @@ on: jobs: - test-all: + test-mlir-bindings: runs-on: ${{ matrix.os }} @@ -24,7 +24,7 @@ jobs: fail-fast: false matrix: os: [ ubuntu-22.04, macos-11, windows-2022 ] - py_version: [ "3.10", "3.11" ] + py_version: [ "3.10", "3.11", "3.12" ] steps: - name: Checkout @@ -34,14 +34,53 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.py_version }} + allow-prereleases: true - name: Install and configure shell: bash run: | pip install .[test,mlir] -v -f https://makslevental.github.io/wheels mlir-python-utils-generate-all-upstream-trampolines - - HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[jax] -v + + - name: Test + shell: bash + run: | + if [ ${{ matrix.os }} == 'windows-2022' ]; then + pytest -s --ignore-glob=*test_other_hosts* tests + else + pytest --capture=tee-sys --ignore-glob=*test_other_hosts* tests + fi + + - name: Test mwe + shell: bash + run: | + python examples/mwe.py + + test-other-host-bindings: + + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [ ubuntu-22.04, macos-11, windows-2022 ] + py_version: [ "3.10", "3.11" ] + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.py_version }} + allow-prereleases: true + + - name: Install and configure + shell: bash + run: | + export PIP_FIND_LINKS=https://makslevental.github.io/wheels + HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[test,jax] -v jaxlib-mlir-python-utils-generate-all-upstream-trampolines pip install aie -f https://github.com/Xilinx/mlir-aie/releases/expanded_assets/latest-wheels --no-index @@ -53,16 +92,11 @@ jobs: shell: bash run: | if [ ${{ matrix.os }} == 'windows-2022' ]; then - pytest -s tests + pytest -s tests/test_other_hosts.py else - pytest --capture=tee-sys tests + pytest --capture=tee-sys tests/test_other_hosts.py fi - - name: Test mwe - shell: bash - run: | - python examples/mwe.py - test-jupyter: runs-on: ${{ matrix.os }} @@ -81,6 +115,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.py_version }} + allow-prereleases: true - name: Run notebook shell: bash @@ -98,7 +133,7 @@ jobs: fail-fast: false matrix: os: [ ubuntu-22.04 ] - py_version: [ "3.10", "3.11" ] + py_version: [ "3.10", "3.11", "3.12" ] steps: - name: Checkout @@ -120,16 +155,17 @@ jobs: bash miniconda.sh -b -u -p /root/miniconda3 eval "$(/root/miniconda3/bin/conda shell.bash hook)" conda init - conda install -q -y python=${{ matrix.py_version }} run: | eval "$(/root/miniconda3/bin/conda shell.bash hook)" + conda create -n env -q -y -c conda-forge/label/python_rc python=${{ matrix.py_version }} + conda activate env cd /workspace - pip install .[test,mlir] -f https://makslevental.github.io/wheels + pip install -q .[test,mlir] -f https://makslevental.github.io/wheels mlir-python-utils-generate-all-upstream-trampolines - pytest --capture=tee-sys --ignore-glob=*test_smoke* tests + pytest --capture=tee-sys --ignore-glob=*test_other_hosts* tests python examples/mwe.py diff --git a/mlir/utils/_configuration/generate_trampolines.py b/mlir/utils/_configuration/generate_trampolines.py index c62c587..193ee7f 100644 --- a/mlir/utils/_configuration/generate_trampolines.py +++ b/mlir/utils/_configuration/generate_trampolines.py @@ -126,6 +126,7 @@ def generate_op_trampoline(op_class): args=args, body=body, decorator_list=decorator_list, + type_params=[], ) ast.fix_missing_locations(n) return n @@ -323,6 +324,7 @@ def generate_linalg(mod_path): ), body=body, decorator_list=[], + type_params=[], ) ast.fix_missing_locations(n) functions.append(n) diff --git a/mlir/utils/_configuration/util.py b/mlir/utils/_configuration/util.py index 9c22760..ce0c87b 100644 --- a/mlir/utils/_configuration/util.py +++ b/mlir/utils/_configuration/util.py @@ -24,7 +24,7 @@ def add_file_to_sources_txt_file(file_path: Path): assert file_path.exists(), f"file being added doesn't exist at {file_path}" relative_file_path = Path(package) / file_path.relative_to(package_root_path) - if dist._read_files_egginfo() is not None: + if hasattr(dist, "_read_files_egginfo") and dist._read_files_egginfo() is not None: with open(dist._path / "SOURCES.txt", "a") as sources_file: sources_file.write(f"\n{relative_file_path}") if dist._read_files_distinfo(): diff --git a/mlir/utils/dialects/ext/func.py b/mlir/utils/dialects/ext/func.py index 59af544..15daa05 100644 --- a/mlir/utils/dialects/ext/func.py +++ b/mlir/utils/dialects/ext/func.py @@ -1,6 +1,9 @@ import inspect +import sys from typing import Union, Optional +from ...meta import make_maybe_no_args_decorator, maybe_cast +from ...util import get_result_or_results, get_user_code_loc from ....dialects.func import FuncOp, ReturnOp, CallOp from ....ir import ( InsertionPoint, @@ -12,9 +15,6 @@ Value, ) -from ...util import get_result_or_results, get_user_code_loc, is_311 -from ...meta import make_maybe_no_args_decorator, maybe_cast - def call( callee_or_results: Union[FuncOp, list[Type]], @@ -145,10 +145,14 @@ def __init__( def _is_decl(self): # magic constant found from looking at the code for an empty fn - if is_311(): + if sys.version_info.minor == 12: + return self.body_builder.__code__.co_code == b"\x97\x00y\x00" + elif sys.version_info.minor == 11: return self.body_builder.__code__.co_code == b"\x97\x00d\x00S\x00" - else: + elif sys.version_info.minor == 10: return self.body_builder.__code__.co_code == b"d\x00S\x00" + else: + raise NotImplementedError(f"{sys.version_info.minor} not supported.") def __str__(self): return str(f"{self.__class__} {self.__dict__}") diff --git a/mlir/utils/dialects/ext/scf.py b/mlir/utils/dialects/ext/scf.py index 08e63a0..fa6863c 100644 --- a/mlir/utils/dialects/ext/scf.py +++ b/mlir/utils/dialects/ext/scf.py @@ -5,6 +5,20 @@ from typing import Optional, Sequence, Union from bytecode import ConcreteBytecode, ConcreteInstr + +from ... import types as T +from ...ast.canonicalize import ( + StrictTransformer, + Canonicalizer, + BytecodePatcher, + OpCode, +) +from ...ast.util import ast_call, set_lineno +from ...dialects.ext.arith import constant, index_cast +from ...dialects.ext.gpu import get_device_mapping_array_attr +from ...dialects.scf import yield_ as yield__, reduce_return, condition +from ...meta import region_adder, region_op, maybe_cast +from ...util import get_result_or_results, get_user_code_loc from ....dialects._ods_common import get_op_results_or_values, get_default_loc_context from ....dialects.linalg.opdsl.lang.emitter import _is_index_type from ....dialects.scf import ( @@ -28,20 +42,6 @@ Attribute, ) -from ... import types as T -from ...ast.canonicalize import ( - StrictTransformer, - Canonicalizer, - BytecodePatcher, - OpCode, -) -from ...ast.util import ast_call, set_lineno -from ...dialects.ext.arith import constant, index_cast -from ...dialects.ext.gpu import get_device_mapping_array_attr -from ...dialects.scf import yield_ as yield__, reduce_return, condition -from ...meta import region_adder, region_op, maybe_cast -from ...util import get_result_or_results, get_user_code_loc, is_311 - logger = logging.getLogger(__name__) @@ -633,25 +633,7 @@ def visit_If(self, updated_node: ast.If) -> ast.With | list[ast.With, ast.With]: class RemoveJumpsAndInsertGlobals(BytecodePatcher): def patch_bytecode(self, code: ConcreteBytecode, f): - early_returns = [] - for i, c in enumerate(code): - c: ConcreteInstr - if c.opcode == int(OpCode.RETURN_VALUE): - early_returns.append(i) - - if c.opcode in { - # this is the first test condition jump from python <= 3.10 - # "POP_JUMP_IF_FALSE", - # this is the test condition jump from python >= 3.11 - int(OpCode.POP_JUMP_FORWARD_IF_FALSE) - if is_311() - else int(OpCode.POP_JUMP_IF_FALSE), - }: - code[i] = ConcreteInstr( - str(OpCode.POP_TOP), lineno=c.lineno, location=c.location - ) - - # TODO(max): this is bad + # TODO(max): this is bad and should be in the closure rather than as a global f.__globals__[yield_.__name__] = yield_ f.__globals__[if_ctx_manager.__name__] = if_ctx_manager f.__globals__[else_ctx_manager.__name__] = else_ctx_manager diff --git a/mlir/utils/util.py b/mlir/utils/util.py index 381c261..e9a1da4 100644 --- a/mlir/utils/util.py +++ b/mlir/utils/util.py @@ -34,10 +34,6 @@ def get_result_or_results( ) -def is_311(): - return sys.version_info.minor > 10 - - def get_user_code_loc(user_base: Optional[Path] = None): from .. import utils @@ -54,12 +50,14 @@ def get_user_code_loc(user_base: Optional[Path] = None): ): prev_frame = prev_frame.f_back frame_info = inspect.getframeinfo(prev_frame) - if is_311(): + if sys.version_info.minor >= 11: return Location.file( frame_info.filename, frame_info.lineno, frame_info.positions.col_offset ) - else: + elif sys.version_info.minor == 10: return Location.file(frame_info.filename, frame_info.lineno, col=0) + else: + raise NotImplementedError(f"{sys.version_info.minor} not supported.") @contextlib.contextmanager diff --git a/tests/test_func.py b/tests/test_func.py index d4dcc58..5fee634 100644 --- a/tests/test_func.py +++ b/tests/test_func.py @@ -1,15 +1,15 @@ import inspect +import sys from textwrap import dedent import pytest +import mlir.utils.types as T from mlir.utils.dialects.ext.arith import constant from mlir.utils.dialects.ext.func import func # noinspection PyUnresolvedReferences from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext -import mlir.utils.types as T -from mlir.utils.util import is_311 # needed since the fix isn't defined here nor conftest.py pytest.mark.usefixtures("ctx") @@ -41,10 +41,14 @@ def test_declare_byte_rep(ctx: MLIRContext): def demo_fun1(): ... - if is_311(): + if sys.version_info.minor == 12: + assert demo_fun1.__code__.co_code == b"\x97\x00y\x00" + elif sys.version_info.minor == 11: assert demo_fun1.__code__.co_code == b"\x97\x00d\x00S\x00" - else: + elif sys.version_info.minor == 10: assert demo_fun1.__code__.co_code == b"d\x00S\x00" + else: + raise NotImplementedError(f"{sys.version_info.minor} not supported.") def test_declare(ctx: MLIRContext): diff --git a/tests/test_location_tracking.py b/tests/test_location_tracking.py index c370fec..a59dba8 100644 --- a/tests/test_location_tracking.py +++ b/tests/test_location_tracking.py @@ -1,3 +1,4 @@ +import sys from os import sep from pathlib import Path from textwrap import dedent @@ -13,7 +14,6 @@ # noinspection PyUnresolvedReferences from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext -from mlir.utils.util import is_311 # needed since the fix isn't defined here nor conftest.py pytest.mark.usefixtures("ctx") @@ -27,7 +27,7 @@ def get_asm(operation): ) -@pytest.mark.skipif(not is_311(), reason="310 doesn't have col numbers") +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_replace_yield_5(ctx: MLIRContext): @canonicalize(using=canonicalizer) def iffoo(): @@ -72,7 +72,7 @@ def iffoo(): filecheck(correct, asm) -@pytest.mark.skipif(not is_311(), reason="310 doesn't have col numbers") +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_block_args(ctx: MLIRContext): one = constant(1, T.index) two = constant(2, T.index) diff --git a/tests/test_other_hosts.py b/tests/test_other_hosts.py new file mode 100644 index 0000000..c7c5f36 --- /dev/null +++ b/tests/test_other_hosts.py @@ -0,0 +1,174 @@ +from textwrap import dedent + +import aie +import aie.mlir.utils.types as T +import pytest +from aie.mlir.ir import Value +from aie.mlir.utils.context import mlir_mod_ctx +from aie.mlir.utils.dialects.aie import tile, buffer, CoreOp, end +from aie.mlir.utils.dialects.ext.arith import constant +from aie.mlir.utils.dialects.ext.memref import load +from aie.mlir.utils.dialects.memref import store +from aie.mlir.utils.meta import region_op +# noinspection PyUnresolvedReferences +from aie.mlir.utils.testing import filecheck, MLIRContext +from aie.mlir.utils.util import get_user_code_loc + +from util import ( + skip_jax_not_installed, + mlir_bindings_installed, + aie_bindings_installed, +) + + +@pytest.mark.skipif( + mlir_bindings_installed(), reason="mlir python bindings not installed" +) +def test_smoke(): + from mlir.utils.context import mlir_mod_ctx + from mlir.utils.testing import filecheck + + with mlir_mod_ctx(allow_unregistered_dialects=True) as ctx: + correct = dedent( + """\ + module { + } + """ + ) + + filecheck(correct, ctx.module) + + +@pytest.mark.skipif( + mlir_bindings_installed(), reason="mlir python bindings not installed" +) +def test_dialect_trampolines_smoke(): + from mlir.utils._configuration.generate_trampolines import ( + generate_all_upstream_trampolines, + ) + + generate_all_upstream_trampolines() + # noinspection PyUnresolvedReferences + from mlir.utils.dialects import ( + arith, + bufferization, + builtin, + cf, + complex, + func, + gpu, + linalg, + math, + memref, + ml_program, + pdl, + scf, + shape, + sparse_tensor, + tensor, + tosa, + transform, + vector, + ) + + +@pytest.mark.skipif(skip_jax_not_installed(), reason="jax not installed") +def test_jax_trampolines_smoke(): + # noinspection PyUnresolvedReferences + from jaxlib.mlir.utils.dialects import ( + arith, + builtin, + chlo, + func, + math, + memref, + mhlo, + ml_program, + scf, + sparse_tensor, + stablehlo, + vector, + ) + + +@pytest.fixture +def ctx() -> MLIRContext: + with mlir_mod_ctx(allow_unregistered_dialects=True) as ctx: + aie.dialects.aie.register_dialect(ctx.context) + yield ctx + + +# needed since the fix isn't defined here nor conftest.py +pytest.mark.usefixtures("ctx") + + +def core(tile: Value, *, stack_size=None, loc=None, ip=None): + if loc is None: + loc = get_user_code_loc() + return CoreOp(T.index, tile, stackSize=stack_size, loc=loc, ip=ip) + + +core = region_op(core, terminator=lambda *args: end()) + + +def test_basic(ctx: MLIRContext): + tile13 = tile(T.index, 1, 3) + + @core(tile13) + def demo_fun1(): + one = constant(1) + + correct = dedent( + """\ + module { + %0 = AIE.tile(1, 3) + %1 = AIE.core(%0) { + %c1_i32 = arith.constant 1 : i32 + AIE.end + } + } + """ + ) + filecheck(correct, ctx.module) + + +@pytest.mark.skipif(aie_bindings_installed(), reason="aie bindings not installed") +def test01_memory_read_write(ctx: MLIRContext): + tile13 = tile(T.index, 1, 3) + buf13_0 = buffer(T.memref(256, T.i32), tile13) + + @core(tile13) + def core13(): + val1 = constant(7) + idx1 = constant(7, index=True) + two = val1 + val1 + store(two, buf13_0, [idx1]) + val2 = constant(8) + idx2 = constant(5, index=True) + store(val2, buf13_0, [idx2]) + val3 = load(buf13_0, [idx1]) + idx3 = constant(9, index=True) + store(val3, buf13_0, [idx3]) + + correct = dedent( + """\ + module { + %0 = AIE.tile(1, 3) + %1 = AIE.buffer(%0) : memref<256xi32> + %2 = AIE.core(%0) { + %c7_i32 = arith.constant 7 : i32 + %c7 = arith.constant 7 : index + %3 = arith.addi %c7_i32, %c7_i32 : i32 + memref.store %3, %1[%c7] : memref<256xi32> + %c8_i32 = arith.constant 8 : i32 + %c5 = arith.constant 5 : index + memref.store %c8_i32, %1[%c5] : memref<256xi32> + %4 = memref.load %1[%c7] : memref<256xi32> + %c9 = arith.constant 9 : index + memref.store %4, %1[%c9] : memref<256xi32> + AIE.end + } + } + """ + ) + filecheck(correct, ctx.module) diff --git a/tests/test_smoke.py b/tests/test_smoke.py deleted file mode 100644 index d372f5b..0000000 --- a/tests/test_smoke.py +++ /dev/null @@ -1,87 +0,0 @@ -from pathlib import Path -from textwrap import dedent - -import pytest - -from util import skip_jax_not_installed, skip_torch_mlir_not_installed - - -def test_smoke(): - from mlir.utils.context import mlir_mod_ctx - from mlir.utils.testing import filecheck - - with mlir_mod_ctx(allow_unregistered_dialects=True) as ctx: - correct = dedent( - """\ - module { - } - """ - ) - - filecheck(correct, ctx.module) - - -def test_dialect_trampolines_smoke(): - from mlir.utils._configuration.generate_trampolines import ( - generate_all_upstream_trampolines, - ) - - generate_all_upstream_trampolines() - # noinspection PyUnresolvedReferences - from mlir.utils.dialects import ( - arith, - bufferization, - builtin, - cf, - complex, - func, - gpu, - linalg, - math, - memref, - ml_program, - pdl, - scf, - shape, - sparse_tensor, - tensor, - tosa, - transform, - vector, - ) - - -@pytest.mark.skipif(skip_torch_mlir_not_installed(), reason="torch_mlir not installed") -def test_torch_dialect_trampolines_smoke(): - import torch_mlir.utils.dialects - - from torch_mlir.utils._configuration.generate_trampolines import ( - generate_trampolines, - ) - - generate_trampolines( - "torch_mlir.dialects.torch", - Path(torch_mlir.utils.dialects.__path__[0]), - "torch", - ) - # noinspection PyUnresolvedReferences - from torch_mlir.utils.dialects import torch - - -@pytest.mark.skipif(skip_jax_not_installed(), reason="jax not installed") -def test_jax_trampolines_smoke(): - # noinspection PyUnresolvedReferences - from jaxlib.mlir.utils.dialects import ( - arith, - builtin, - chlo, - func, - math, - memref, - mhlo, - ml_program, - scf, - sparse_tensor, - stablehlo, - vector, - ) diff --git a/tests/test_transformers.py b/tests/test_transformers.py index 8dade0e..b8dc2f0 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -1,4 +1,5 @@ import ast +import sys from textwrap import dedent import astpretty @@ -16,7 +17,6 @@ # noinspection PyUnresolvedReferences from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext -from mlir.utils.util import is_311 # needed since the fix isn't defined here nor conftest.py pytest.mark.usefixtures("ctx") @@ -33,6 +33,7 @@ def _fields(n: ast.AST, show_offsets: bool = True) -> tuple[str, ...]: astpretty._fields = _fields +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_handle_yield_1(): def iffoo(): one = constant(1.0) @@ -121,6 +122,7 @@ def iffoo(): Return(lineno=7, value=None), ], returns=None, + type_params=[], ), ], ) @@ -129,6 +131,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_handle_yield_2(): def iffoo(): one = constant(1.0) @@ -211,6 +214,7 @@ def iffoo(): Return(lineno=6, value=None), ], returns=None, + type_params=[], ), ], ) @@ -219,6 +223,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_handle_yield_3(): def iffoo(): one = constant(1.0) @@ -308,6 +313,7 @@ def iffoo(): Return(lineno=7, value=None), ], returns=None, + type_params=[], ), ], ) @@ -317,6 +323,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_handle_yield_4(): def iffoo(): one = constant(1.0) @@ -328,7 +335,7 @@ def iffoo(): mod = transform_func(iffoo, ReplaceYieldWithSCFYield) - if is_311(): + if sys.version_info.minor >= 11: correct = dedent( """\ def iffoo(): @@ -340,7 +347,7 @@ def iffoo(): return """ ) - else: + elif sys.version_info.minor == 10: correct = dedent( """\ def iffoo(): @@ -352,6 +359,9 @@ def iffoo(): return """ ) + else: + raise NotImplementedError(f"{sys.version_info.minor} not supported.") + assert correct.strip() == ast.unparse(mod) dump = astpretty.pformat(mod, show_offsets=True) @@ -430,6 +440,7 @@ def iffoo(): Return(lineno=7, value=None), ], returns=None, + type_params=[], ), ], ) @@ -439,6 +450,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_nested_no_else_no_yield(): def iffoo(): one = constant(1.0) @@ -552,6 +564,7 @@ def iffoo(): Return(lineno=8, value=None), ], returns=None, + type_params=[], ), ], ) @@ -561,6 +574,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_replace_cond_1(): def iffoo(): one = constant(1.0) @@ -662,6 +676,7 @@ def iffoo(): Return(lineno=7, value=None), ], returns=None, + type_params=[], ), ], ) @@ -671,6 +686,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_replace_cond_2(): def iffoo(): one = constant(1.0) @@ -783,6 +799,7 @@ def iffoo(): Return(lineno=7, value=None), ], returns=None, + type_params=[], ), ], ) @@ -792,6 +809,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_replace_cond_3(): def iffoo(): one = constant(1.0) @@ -803,7 +821,7 @@ def iffoo(): mod = transform_func(iffoo, ReplaceYieldWithSCFYield, ReplaceIfWithWith) - if is_311(): + if sys.version_info.minor >= 11: correct = dedent( """\ def iffoo(): @@ -815,7 +833,7 @@ def iffoo(): return """ ) - else: + elif sys.version_info.minor == 10: correct = dedent( """\ def iffoo(): @@ -827,6 +845,9 @@ def iffoo(): return """ ) + else: + raise NotImplementedError(f"{sys.version_info.minor} not supported.") + assert correct.strip() == ast.unparse(mod) dump = astpretty.pformat(mod, show_offsets=True) @@ -933,6 +954,7 @@ def iffoo(): Return(lineno=7, value=None), ], returns=None, + type_params=[], ), ], ) @@ -941,6 +963,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_nested_with_else_no_yield(): def iffoo(): one = constant(1.0) @@ -1073,6 +1096,7 @@ def iffoo(): Return(lineno=10, value=None), ], returns=None, + type_params=[], ), ], ) @@ -1082,6 +1106,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_insert_end_ifs_yield(): def iffoo(): one = constant(1.0) @@ -1226,6 +1251,7 @@ def iffoo(): Return(lineno=8, value=None), ], returns=None, + type_params=[], ), ], ) @@ -1235,6 +1261,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_else_with_nested_no_yields_yield_results(): def iffoo(): one = constant(1.0) @@ -1441,6 +1468,7 @@ def iffoo(): Return(lineno=12, value=None), ], returns=None, + type_params=[], ), ], ) @@ -1450,6 +1478,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_else_with_nested_no_yields_yield_multiple_results(): def iffoo(): one = constant(1.0) @@ -1668,6 +1697,7 @@ def iffoo(): Return(lineno=12, value=None), ], returns=None, + type_params=[], ), ], ) @@ -1677,6 +1707,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_with_else_else_with_yields(): def iffoo(): one = constant(1.0) @@ -1901,6 +1932,7 @@ def iffoo(): Return(lineno=13, value=None), ], returns=None, + type_params=[], ), ], ) @@ -1909,6 +1941,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_if_canonicalize_elif_elif(): def iffoo(): one = constant(1.0) @@ -2220,6 +2253,7 @@ def iffoo(): Return(lineno=17, value=None), ], returns=None, + type_params=[], ), ], ) @@ -2229,6 +2263,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_elif_1(): def iffoo(): one = constant(1.0) @@ -2475,6 +2510,7 @@ def iffoo(): Return(lineno=13, value=None), ], returns=None, + type_params=[], ), ], ) @@ -2484,6 +2520,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_elif_2(): def iffoo(): one = constant(1.0) @@ -2817,6 +2854,7 @@ def iffoo(): Return(lineno=16, value=None), ], returns=None, + type_params=[], ), ], ) @@ -2826,6 +2864,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_elif_3(): def iffoo(): one = constant(1.0) @@ -3223,6 +3262,7 @@ def iffoo(): Return(lineno=18, value=None), ], returns=None, + type_params=[], ), ], ) @@ -3232,6 +3272,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_elif_nested_else_branch(): def iffoo(): one = constant(1.0) @@ -3660,6 +3701,7 @@ def iffoo(): Return(lineno=24, value=None), ], returns=None, + type_params=[], ), ], ) @@ -3669,6 +3711,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_elif_nested_else_branch_multiple_yield(ctx: MLIRContext): def iffoo(): one = constant(1.0) @@ -4209,6 +4252,7 @@ def iffoo(): Return(lineno=24, value=None), ], returns=None, + type_params=[], ), ], ) @@ -4218,6 +4262,7 @@ def iffoo(): assert correct.strip() == dump +@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest") def test_while_canonicalize(ctx: MLIRContext): one = constant(1) two = constant(2) @@ -4296,6 +4341,7 @@ def foo(): ), ], returns=None, + type_params=[], ), ], ) diff --git a/tests/util.py b/tests/util.py index 9f021eb..dc0744f 100644 --- a/tests/util.py +++ b/tests/util.py @@ -20,3 +20,27 @@ def skip_jax_not_installed(): except ImportError: # skip return True + + +def mlir_bindings_installed(): + try: + import mlir + + # don't skip + return False + + except ImportError: + # skip + return True + + +def aie_bindings_installed(): + try: + import aie + + # don't skip + return False + + except ImportError: + # skip + return True