Skip to content

Commit

Permalink
add py312 tests
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Oct 3, 2023
1 parent e97d5c0 commit d80dfd3
Show file tree
Hide file tree
Showing 12 changed files with 342 additions and 159 deletions.
66 changes: 51 additions & 15 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ on:

jobs:

test-all:
test-mlir-bindings:

runs-on: ${{ matrix.os }}

strategy:
fail-fast: false
matrix:
os: [ ubuntu-22.04, macos-11, windows-2022 ]
py_version: [ "3.10", "3.11" ]
py_version: [ "3.10", "3.11", "3.12" ]

steps:
- name: Checkout
Expand All @@ -34,14 +34,53 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.py_version }}
allow-prereleases: true

- name: Install and configure
shell: bash
run: |
pip install .[test,mlir] -v -f https://makslevental.github.io/wheels
mlir-python-utils-generate-all-upstream-trampolines
HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[jax] -v
- name: Test
shell: bash
run: |
if [ ${{ matrix.os }} == 'windows-2022' ]; then
pytest -s --ignore-glob=*test_other_hosts* tests
else
pytest --capture=tee-sys --ignore-glob=*test_other_hosts* tests
fi
- name: Test mwe
shell: bash
run: |
python examples/mwe.py
test-other-host-bindings:

runs-on: ${{ matrix.os }}

strategy:
fail-fast: false
matrix:
os: [ ubuntu-22.04, macos-11, windows-2022 ]
py_version: [ "3.10", "3.11" ]

steps:
- name: Checkout
uses: actions/checkout@v2

- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.py_version }}
allow-prereleases: true

- name: Install and configure
shell: bash
run: |
export PIP_FIND_LINKS=https://makslevental.github.io/wheels
HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[test,jax] -v
jaxlib-mlir-python-utils-generate-all-upstream-trampolines
pip install aie -f https://github.com/Xilinx/mlir-aie/releases/expanded_assets/latest-wheels --no-index
Expand All @@ -53,16 +92,11 @@ jobs:
shell: bash
run: |
if [ ${{ matrix.os }} == 'windows-2022' ]; then
pytest -s tests
pytest -s tests/test_other_hosts.py
else
pytest --capture=tee-sys tests
pytest --capture=tee-sys tests/test_other_hosts.py
fi
- name: Test mwe
shell: bash
run: |
python examples/mwe.py
test-jupyter:

runs-on: ${{ matrix.os }}
Expand All @@ -81,6 +115,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.py_version }}
allow-prereleases: true

- name: Run notebook
shell: bash
Expand All @@ -98,7 +133,7 @@ jobs:
fail-fast: false
matrix:
os: [ ubuntu-22.04 ]
py_version: [ "3.10", "3.11" ]
py_version: [ "3.10", "3.11", "3.12" ]

steps:
- name: Checkout
Expand All @@ -120,16 +155,17 @@ jobs:
bash miniconda.sh -b -u -p /root/miniconda3
eval "$(/root/miniconda3/bin/conda shell.bash hook)"
conda init
conda install -q -y python=${{ matrix.py_version }}
run: |
eval "$(/root/miniconda3/bin/conda shell.bash hook)"
conda create -n env -q -y -c conda-forge/label/python_rc python=${{ matrix.py_version }}
conda activate env
cd /workspace
pip install .[test,mlir] -f https://makslevental.github.io/wheels
pip install -q .[test,mlir] -f https://makslevental.github.io/wheels
mlir-python-utils-generate-all-upstream-trampolines
pytest --capture=tee-sys --ignore-glob=*test_smoke* tests
pytest --capture=tee-sys --ignore-glob=*test_other_hosts* tests
python examples/mwe.py
2 changes: 2 additions & 0 deletions mlir/utils/_configuration/generate_trampolines.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def generate_op_trampoline(op_class):
args=args,
body=body,
decorator_list=decorator_list,
type_params=[],
)
ast.fix_missing_locations(n)
return n
Expand Down Expand Up @@ -323,6 +324,7 @@ def generate_linalg(mod_path):
),
body=body,
decorator_list=[],
type_params=[],
)
ast.fix_missing_locations(n)
functions.append(n)
Expand Down
2 changes: 1 addition & 1 deletion mlir/utils/_configuration/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def add_file_to_sources_txt_file(file_path: Path):

assert file_path.exists(), f"file being added doesn't exist at {file_path}"
relative_file_path = Path(package) / file_path.relative_to(package_root_path)
if dist._read_files_egginfo() is not None:
if hasattr(dist, "_read_files_egginfo") and dist._read_files_egginfo() is not None:
with open(dist._path / "SOURCES.txt", "a") as sources_file:
sources_file.write(f"\n{relative_file_path}")
if dist._read_files_distinfo():
Expand Down
14 changes: 9 additions & 5 deletions mlir/utils/dialects/ext/func.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import inspect
import sys
from typing import Union, Optional

from ...meta import make_maybe_no_args_decorator, maybe_cast
from ...util import get_result_or_results, get_user_code_loc
from ....dialects.func import FuncOp, ReturnOp, CallOp
from ....ir import (
InsertionPoint,
Expand All @@ -12,9 +15,6 @@
Value,
)

from ...util import get_result_or_results, get_user_code_loc, is_311
from ...meta import make_maybe_no_args_decorator, maybe_cast


def call(
callee_or_results: Union[FuncOp, list[Type]],
Expand Down Expand Up @@ -145,10 +145,14 @@ def __init__(

def _is_decl(self):
# magic constant found from looking at the code for an empty fn
if is_311():
if sys.version_info.minor == 12:
return self.body_builder.__code__.co_code == b"\x97\x00y\x00"
elif sys.version_info.minor == 11:
return self.body_builder.__code__.co_code == b"\x97\x00d\x00S\x00"
else:
elif sys.version_info.minor == 10:
return self.body_builder.__code__.co_code == b"d\x00S\x00"
else:
raise NotImplementedError(f"{sys.version_info.minor} not supported.")

def __str__(self):
return str(f"{self.__class__} {self.__dict__}")
Expand Down
48 changes: 15 additions & 33 deletions mlir/utils/dialects/ext/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,20 @@
from typing import Optional, Sequence, Union

from bytecode import ConcreteBytecode, ConcreteInstr

from ... import types as T
from ...ast.canonicalize import (
StrictTransformer,
Canonicalizer,
BytecodePatcher,
OpCode,
)
from ...ast.util import ast_call, set_lineno
from ...dialects.ext.arith import constant, index_cast
from ...dialects.ext.gpu import get_device_mapping_array_attr
from ...dialects.scf import yield_ as yield__, reduce_return, condition
from ...meta import region_adder, region_op, maybe_cast
from ...util import get_result_or_results, get_user_code_loc
from ....dialects._ods_common import get_op_results_or_values, get_default_loc_context
from ....dialects.linalg.opdsl.lang.emitter import _is_index_type
from ....dialects.scf import (
Expand All @@ -28,20 +42,6 @@
Attribute,
)

from ... import types as T
from ...ast.canonicalize import (
StrictTransformer,
Canonicalizer,
BytecodePatcher,
OpCode,
)
from ...ast.util import ast_call, set_lineno
from ...dialects.ext.arith import constant, index_cast
from ...dialects.ext.gpu import get_device_mapping_array_attr
from ...dialects.scf import yield_ as yield__, reduce_return, condition
from ...meta import region_adder, region_op, maybe_cast
from ...util import get_result_or_results, get_user_code_loc, is_311

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -633,25 +633,7 @@ def visit_If(self, updated_node: ast.If) -> ast.With | list[ast.With, ast.With]:

class RemoveJumpsAndInsertGlobals(BytecodePatcher):
def patch_bytecode(self, code: ConcreteBytecode, f):
early_returns = []
for i, c in enumerate(code):
c: ConcreteInstr
if c.opcode == int(OpCode.RETURN_VALUE):
early_returns.append(i)

if c.opcode in {
# this is the first test condition jump from python <= 3.10
# "POP_JUMP_IF_FALSE",
# this is the test condition jump from python >= 3.11
int(OpCode.POP_JUMP_FORWARD_IF_FALSE)
if is_311()
else int(OpCode.POP_JUMP_IF_FALSE),
}:
code[i] = ConcreteInstr(
str(OpCode.POP_TOP), lineno=c.lineno, location=c.location
)

# TODO(max): this is bad
# TODO(max): this is bad and should be in the closure rather than as a global
f.__globals__[yield_.__name__] = yield_
f.__globals__[if_ctx_manager.__name__] = if_ctx_manager
f.__globals__[else_ctx_manager.__name__] = else_ctx_manager
Expand Down
10 changes: 4 additions & 6 deletions mlir/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ def get_result_or_results(
)


def is_311():
return sys.version_info.minor > 10


def get_user_code_loc(user_base: Optional[Path] = None):
from .. import utils

Expand All @@ -54,12 +50,14 @@ def get_user_code_loc(user_base: Optional[Path] = None):
):
prev_frame = prev_frame.f_back
frame_info = inspect.getframeinfo(prev_frame)
if is_311():
if sys.version_info.minor >= 11:
return Location.file(
frame_info.filename, frame_info.lineno, frame_info.positions.col_offset
)
else:
elif sys.version_info.minor == 10:
return Location.file(frame_info.filename, frame_info.lineno, col=0)
else:
raise NotImplementedError(f"{sys.version_info.minor} not supported.")


@contextlib.contextmanager
Expand Down
12 changes: 8 additions & 4 deletions tests/test_func.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import inspect
import sys
from textwrap import dedent

import pytest

import mlir.utils.types as T
from mlir.utils.dialects.ext.arith import constant
from mlir.utils.dialects.ext.func import func

# noinspection PyUnresolvedReferences
from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
import mlir.utils.types as T
from mlir.utils.util import is_311

# needed since the fix isn't defined here nor conftest.py
pytest.mark.usefixtures("ctx")
Expand Down Expand Up @@ -41,10 +41,14 @@ def test_declare_byte_rep(ctx: MLIRContext):
def demo_fun1():
...

if is_311():
if sys.version_info.minor == 12:
assert demo_fun1.__code__.co_code == b"\x97\x00y\x00"
elif sys.version_info.minor == 11:
assert demo_fun1.__code__.co_code == b"\x97\x00d\x00S\x00"
else:
elif sys.version_info.minor == 10:
assert demo_fun1.__code__.co_code == b"d\x00S\x00"
else:
raise NotImplementedError(f"{sys.version_info.minor} not supported.")


def test_declare(ctx: MLIRContext):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_location_tracking.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from os import sep
from pathlib import Path
from textwrap import dedent
Expand All @@ -13,7 +14,6 @@

# noinspection PyUnresolvedReferences
from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
from mlir.utils.util import is_311

# needed since the fix isn't defined here nor conftest.py
pytest.mark.usefixtures("ctx")
Expand All @@ -27,7 +27,7 @@ def get_asm(operation):
)


@pytest.mark.skipif(not is_311(), reason="310 doesn't have col numbers")
@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest")
def test_if_replace_yield_5(ctx: MLIRContext):
@canonicalize(using=canonicalizer)
def iffoo():
Expand Down Expand Up @@ -72,7 +72,7 @@ def iffoo():
filecheck(correct, asm)


@pytest.mark.skipif(not is_311(), reason="310 doesn't have col numbers")
@pytest.mark.skipif(sys.version_info.minor != 12, reason="only check latest")
def test_block_args(ctx: MLIRContext):
one = constant(1, T.index)
two = constant(2, T.index)
Expand Down
Loading

0 comments on commit d80dfd3

Please sign in to comment.