From a0af0d3cdda16162524fa60f388a2aeaf62f6087 Mon Sep 17 00:00:00 2001 From: max Date: Sun, 17 Dec 2023 11:35:36 -0600 Subject: [PATCH] use upstream generated --- .github/workflows/test.yml | 3 --- mlir/utils/dialects/ext/arith.py | 6 ------ mlir/utils/dialects/ext/gpu.py | 2 +- mlir/utils/dialects/ext/memref.py | 2 +- mlir/utils/dialects/ext/scf.py | 2 +- mlir/utils/dialects/ext/tensor.py | 2 +- mlir/utils/runtime/refbackend.py | 2 +- tests/test_gpu.py | 8 ++++---- tests/test_location_tracking.py | 2 +- tests/test_memref.py | 2 +- tests/test_nvgpu_nvvm.py | 4 ++-- tests/test_other_hosts.py | 13 +------------ tests/test_regions.py | 12 ++++++------ tests/test_runtime.py | 6 +++--- tests/test_scf.py | 2 +- tests/test_transform.py | 4 ++-- 16 files changed, 26 insertions(+), 46 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a53c919..282697f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -40,7 +40,6 @@ jobs: shell: bash run: | pip install .[test,mlir] -v -f https://makslevental.github.io/wheels - mlir-python-utils-generate-all-upstream-trampolines - name: Test shell: bash @@ -81,7 +80,6 @@ jobs: run: | export PIP_FIND_LINKS=https://makslevental.github.io/wheels HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[test,jax] -v - jaxlib-mlir-python-utils-generate-all-upstream-trampolines - name: Test shell: bash @@ -160,7 +158,6 @@ jobs: cd /workspace pip install -q .[test,mlir] -f https://makslevental.github.io/wheels - mlir-python-utils-generate-all-upstream-trampolines pytest --capture=tee-sys tests python examples/mwe.py diff --git a/mlir/utils/dialects/ext/arith.py b/mlir/utils/dialects/ext/arith.py index 450eecc..d2ee166 100644 --- a/mlir/utils/dialects/ext/arith.py +++ b/mlir/utils/dialects/ext/arith.py @@ -45,12 +45,6 @@ from ...util import get_result_or_results, get_user_code_loc from ...meta import register_value_caster, maybe_cast - -try: - from ...dialects.arith import * -except ModuleNotFoundError: - pass - from ...types import infer_mlir_type, mlir_type_to_np_dtype diff --git a/mlir/utils/dialects/ext/gpu.py b/mlir/utils/dialects/ext/gpu.py index 02f321c..c620fb5 100644 --- a/mlir/utils/dialects/ext/gpu.py +++ b/mlir/utils/dialects/ext/gpu.py @@ -32,7 +32,7 @@ from ... import types as T from ...dialects.ext.arith import constant from ...dialects.ext.func import FuncBase -from ...dialects.gpu import block_id, module_end +from ....dialects.gpu import block_id, module_end from ...meta import ( ModuleMeta, make_maybe_no_args_decorator, diff --git a/mlir/utils/dialects/ext/memref.py b/mlir/utils/dialects/ext/memref.py index b606143..2e7d019 100644 --- a/mlir/utils/dialects/ext/memref.py +++ b/mlir/utils/dialects/ext/memref.py @@ -5,7 +5,7 @@ from ....ir import Type, Value, MemRefType, ShapedType, MLIRError from ... import types as T -from ...dialects import memref, arith +from ....dialects import memref, arith from ...dialects.ext.arith import Scalar, constant from ...dialects.ext.tensor import ( _indices_to_indexer, diff --git a/mlir/utils/dialects/ext/scf.py b/mlir/utils/dialects/ext/scf.py index a97e7f6..9e74286 100644 --- a/mlir/utils/dialects/ext/scf.py +++ b/mlir/utils/dialects/ext/scf.py @@ -15,7 +15,7 @@ from ...ast.util import ast_call, set_lineno from ...dialects.ext.arith import constant, index_cast from ...dialects.ext.gpu import get_device_mapping_array_attr -from ...dialects.scf import yield_ as yield__, reduce_return, condition +from ....dialects.scf import yield_ as yield__, reduce_return, condition from ...meta import region_adder, region_op, maybe_cast from ...util import get_result_or_results, get_user_code_loc from ....dialects._ods_common import ( diff --git a/mlir/utils/dialects/ext/tensor.py b/mlir/utils/dialects/ext/tensor.py index d387888..17341b4 100644 --- a/mlir/utils/dialects/ext/tensor.py +++ b/mlir/utils/dialects/ext/tensor.py @@ -15,7 +15,7 @@ ) from ... import types as T -from ...dialects import tensor +from ....dialects import tensor from ...dialects.ext.arith import ArithValue, Scalar, constant from ...meta import ( register_value_caster, diff --git a/mlir/utils/runtime/refbackend.py b/mlir/utils/runtime/refbackend.py index 9889967..dcd844b 100644 --- a/mlir/utils/runtime/refbackend.py +++ b/mlir/utils/runtime/refbackend.py @@ -23,7 +23,7 @@ from .. import types as T -from ..dialects.memref import cast +from ...dialects.memref import cast from ..runtime.passes import Pipeline, run_pipeline from ..types import ( memref_type_to_np_dtype, diff --git a/tests/test_gpu.py b/tests/test_gpu.py index 7671883..846ef19 100644 --- a/tests/test_gpu.py +++ b/tests/test_gpu.py @@ -21,11 +21,11 @@ from mlir.utils.dialects.ext.memref import load, store from mlir.utils.dialects.ext.scf import canonicalizer from mlir.utils.dialects.ext.scf import forall, in_parallel_ -from mlir.utils.dialects.gpu import host_register +from mlir.dialects.gpu import host_register from mlir.utils.dialects.ext.gpu import all_reduce, wait -from mlir.utils.dialects.llvm import mlir_zero -from mlir.utils.dialects.math import fma -from mlir.utils.dialects.memref import cast +from mlir.dialects.llvm import mlir_zero +from mlir.dialects.math import fma +from mlir.dialects.memref import cast from mlir.utils.runtime.passes import run_pipeline, Pipeline # noinspection PyUnresolvedReferences diff --git a/tests/test_location_tracking.py b/tests/test_location_tracking.py index 08eb739..499fffa 100644 --- a/tests/test_location_tracking.py +++ b/tests/test_location_tracking.py @@ -10,7 +10,7 @@ from mlir.utils.dialects.ext.arith import constant from mlir.utils.dialects.ext.scf import canonicalizer from mlir.utils.dialects.ext.tensor import S -from mlir.utils.dialects.tensor import generate, yield_ as tensor_yield, rank +from mlir.dialects.tensor import generate, yield_ as tensor_yield, rank # noinspection PyUnresolvedReferences from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext diff --git a/tests/test_memref.py b/tests/test_memref.py index d18bb4b..9d70068 100644 --- a/tests/test_memref.py +++ b/tests/test_memref.py @@ -13,7 +13,7 @@ yield_, canonicalizer, ) -from mlir.utils.dialects.memref import subview +from mlir.dialects.memref import subview # noinspection PyUnresolvedReferences from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext diff --git a/tests/test_nvgpu_nvvm.py b/tests/test_nvgpu_nvvm.py index c4438ee..080ce60 100644 --- a/tests/test_nvgpu_nvvm.py +++ b/tests/test_nvgpu_nvvm.py @@ -6,8 +6,8 @@ from mlir.utils.dialects.ext.arith import constant from mlir.utils.dialects.ext.func import func from mlir.utils.dialects.ext.nvgpu import tensormap_descriptor -from mlir.utils.dialects.memref import cast -from mlir.utils.dialects.nvgpu import tma_create_descriptor +from mlir.dialects.memref import cast +from mlir.dialects.nvgpu import tma_create_descriptor # noinspection PyUnresolvedReferences from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext diff --git a/tests/test_other_hosts.py b/tests/test_other_hosts.py index 368b15a..09e59e9 100644 --- a/tests/test_other_hosts.py +++ b/tests/test_other_hosts.py @@ -3,21 +3,10 @@ from util import jax_not_installed, mlir_bindings_not_installed -@pytest.mark.skipif( - mlir_bindings_not_installed(), reason="mlir python bindings not installed" -) -def test_dialect_trampolines_smoke(): - from mlir.utils._configuration.generate_trampolines import ( - generate_all_upstream_trampolines, - ) - - generate_all_upstream_trampolines() - - @pytest.mark.skipif(jax_not_installed(), reason="jax not installed") def test_jax_trampolines_smoke(): # noinspection PyUnresolvedReferences - from jaxlib.mlir.utils.dialects import ( + from jaxlib.mlir.dialects import ( arith, builtin, chlo, diff --git a/tests/test_regions.py b/tests/test_regions.py index c299a6e..d5b3fdc 100644 --- a/tests/test_regions.py +++ b/tests/test_regions.py @@ -3,17 +3,17 @@ import pytest import mlir.utils.types as T -from mlir.utils.dialects import linalg +from mlir.dialects import linalg from mlir.utils.dialects.ext import memref from mlir.utils.dialects.ext.arith import constant from mlir.utils.dialects.ext.cf import br, cond_br from mlir.utils.dialects.ext.func import func from mlir.utils.dialects.ext.tensor import S -from mlir.utils.dialects.func import return_ -from mlir.utils.dialects.memref import alloca_scope, alloca_scope_return -from mlir.utils.dialects.scf import execute_region, yield_ as scf_yield -from mlir.utils.dialects.tensor import generate, yield_ as tensor_yield -from mlir.utils.dialects.tensor import rank +from mlir.dialects.func import return_ +from mlir.dialects.memref import alloca_scope, alloca_scope_return +from mlir.dialects.scf import execute_region, yield_ as scf_yield +from mlir.dialects.tensor import generate, yield_ as tensor_yield +from mlir.dialects.tensor import rank from mlir.utils.meta import bb # noinspection PyUnresolvedReferences diff --git a/tests/test_runtime.py b/tests/test_runtime.py index d4c63da..c3bfa61 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -13,8 +13,8 @@ import mlir.utils.types as T from mlir.utils.ast.canonicalize import canonicalize -from mlir.utils.dialects import linalg -from mlir.utils.dialects.arith import sitofp, index_cast +from mlir.dialects import linalg +from mlir.dialects.arith import sitofp, index_cast from mlir.utils.dialects.ext.arith import constant from mlir.utils.dialects.ext.func import func from mlir.utils.dialects.ext.memref import load, store, S @@ -22,7 +22,7 @@ canonicalizer, range_, ) -from mlir.utils.dialects.memref import cast +from mlir.dialects.memref import cast from mlir.utils.runtime.passes import Pipeline, run_pipeline from mlir.utils.runtime.refbackend import ( LLVMJITBackend, diff --git a/tests/test_scf.py b/tests/test_scf.py index 92d02cc..aa02fb5 100644 --- a/tests/test_scf.py +++ b/tests/test_scf.py @@ -25,7 +25,7 @@ while___, ) from mlir.utils.dialects.ext.tensor import empty, Tensor -from mlir.utils.dialects.memref import alloca_scope, alloca_scope_return +from mlir.dialects.memref import alloca_scope, alloca_scope_return # noinspection PyUnresolvedReferences from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext diff --git a/tests/test_transform.py b/tests/test_transform.py index fbeb568..d999834 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -5,7 +5,7 @@ from mlir.utils import types as T from mlir.utils.ast.canonicalize import canonicalize -from mlir.utils.dialects import linalg, arith +from mlir.dialects import linalg, arith from mlir.utils.dialects.ext.func import func from mlir.utils.dialects.ext.gpu import block_attr, thread_attr from mlir.utils.dialects.ext.scf import ( @@ -22,7 +22,7 @@ tile_to_scf_forall, apply_patterns, ) -from mlir.utils.dialects.transform import apply_patterns_canonicalization, apply_cse +from mlir.dialects.transform import apply_patterns_canonicalization, apply_cse from mlir.utils.runtime.passes import run_pipeline, Pipeline # noinspection PyUnresolvedReferences