Skip to content

Commit

Permalink
use upstream generated
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Dec 17, 2023
1 parent 2148970 commit dc74d75
Show file tree
Hide file tree
Showing 13 changed files with 22 additions and 42 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ jobs:
shell: bash
run: |
pip install .[test,mlir] -v -f https://makslevental.github.io/wheels
mlir-python-utils-generate-all-upstream-trampolines
- name: Test
shell: bash
Expand Down Expand Up @@ -81,7 +80,6 @@ jobs:
run: |
export PIP_FIND_LINKS=https://makslevental.github.io/wheels
HOST_MLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir pip install .[test,jax] -v
jaxlib-mlir-python-utils-generate-all-upstream-trampolines
- name: Test
shell: bash
Expand Down Expand Up @@ -160,7 +158,6 @@ jobs:
cd /workspace
pip install -q .[test,mlir] -f https://makslevental.github.io/wheels
mlir-python-utils-generate-all-upstream-trampolines
pytest --capture=tee-sys tests
python examples/mwe.py
6 changes: 0 additions & 6 deletions mlir/utils/dialects/ext/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,6 @@

from ...util import get_result_or_results, get_user_code_loc
from ...meta import register_value_caster, maybe_cast

try:
from ...dialects.arith import *
except ModuleNotFoundError:
pass

from ...types import infer_mlir_type, mlir_type_to_np_dtype


Expand Down
2 changes: 1 addition & 1 deletion mlir/utils/dialects/ext/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ... import types as T
from ...dialects.ext.arith import constant
from ...dialects.ext.func import FuncBase
from ...dialects.gpu import block_id, module_end
from ....dialects.gpu import block_id, module_end
from ...meta import (
ModuleMeta,
make_maybe_no_args_decorator,
Expand Down
2 changes: 1 addition & 1 deletion mlir/utils/dialects/ext/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ...ast.util import ast_call, set_lineno
from ...dialects.ext.arith import constant, index_cast
from ...dialects.ext.gpu import get_device_mapping_array_attr
from ...dialects.scf import yield_ as yield__, reduce_return, condition
from ....dialects.scf import yield_ as yield__, reduce_return, condition
from ...meta import region_adder, region_op, maybe_cast
from ...util import get_result_or_results, get_user_code_loc
from ....dialects._ods_common import (
Expand Down
8 changes: 4 additions & 4 deletions tests/test_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from mlir.utils.dialects.ext.memref import load, store
from mlir.utils.dialects.ext.scf import canonicalizer
from mlir.utils.dialects.ext.scf import forall, in_parallel_
from mlir.utils.dialects.gpu import host_register
from mlir.dialects.gpu import host_register
from mlir.utils.dialects.ext.gpu import all_reduce, wait
from mlir.utils.dialects.llvm import mlir_zero
from mlir.utils.dialects.math import fma
from mlir.utils.dialects.memref import cast
from mlir.dialects.llvm import mlir_zero
from mlir.dialects.math import fma
from mlir.dialects.memref import cast
from mlir.utils.runtime.passes import run_pipeline, Pipeline

# noinspection PyUnresolvedReferences
Expand Down
2 changes: 1 addition & 1 deletion tests/test_location_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mlir.utils.dialects.ext.arith import constant
from mlir.utils.dialects.ext.scf import canonicalizer
from mlir.utils.dialects.ext.tensor import S
from mlir.utils.dialects.tensor import generate, yield_ as tensor_yield, rank
from mlir.dialects.tensor import generate, yield_ as tensor_yield, rank

# noinspection PyUnresolvedReferences
from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
Expand Down
2 changes: 1 addition & 1 deletion tests/test_memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
yield_,
canonicalizer,
)
from mlir.utils.dialects.memref import subview
from mlir.dialects.memref import subview

# noinspection PyUnresolvedReferences
from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
Expand Down
4 changes: 2 additions & 2 deletions tests/test_nvgpu_nvvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from mlir.utils.dialects.ext.arith import constant
from mlir.utils.dialects.ext.func import func
from mlir.utils.dialects.ext.nvgpu import tensormap_descriptor
from mlir.utils.dialects.memref import cast
from mlir.utils.dialects.nvgpu import tma_create_descriptor
from mlir.dialects.memref import cast
from mlir.dialects.nvgpu import tma_create_descriptor

# noinspection PyUnresolvedReferences
from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
Expand Down
13 changes: 1 addition & 12 deletions tests/test_other_hosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,10 @@
from util import jax_not_installed, mlir_bindings_not_installed


@pytest.mark.skipif(
mlir_bindings_not_installed(), reason="mlir python bindings not installed"
)
def test_dialect_trampolines_smoke():
from mlir.utils._configuration.generate_trampolines import (
generate_all_upstream_trampolines,
)

generate_all_upstream_trampolines()


@pytest.mark.skipif(jax_not_installed(), reason="jax not installed")
def test_jax_trampolines_smoke():
# noinspection PyUnresolvedReferences
from jaxlib.mlir.utils.dialects import (
from jaxlib.mlir.dialects import (
arith,
builtin,
chlo,
Expand Down
10 changes: 5 additions & 5 deletions tests/test_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
import pytest

import mlir.utils.types as T
from mlir.utils.dialects import linalg
from mlir.dialects import linalg
from mlir.utils.dialects.ext import memref
from mlir.utils.dialects.ext.arith import constant
from mlir.utils.dialects.ext.cf import br, cond_br
from mlir.utils.dialects.ext.func import func
from mlir.utils.dialects.ext.tensor import S
from mlir.utils.dialects.func import return_
from mlir.utils.dialects.memref import alloca_scope, alloca_scope_return
from mlir.dialects.func import return_
from mlir.dialects.memref import alloca_scope, alloca_scope_return
from mlir.utils.dialects.scf import execute_region, yield_ as scf_yield
from mlir.utils.dialects.tensor import generate, yield_ as tensor_yield
from mlir.utils.dialects.tensor import rank
from mlir.dialects.tensor import generate, yield_ as tensor_yield
from mlir.dialects.tensor import rank
from mlir.utils.meta import bb

# noinspection PyUnresolvedReferences
Expand Down
6 changes: 3 additions & 3 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@

import mlir.utils.types as T
from mlir.utils.ast.canonicalize import canonicalize
from mlir.utils.dialects import linalg
from mlir.utils.dialects.arith import sitofp, index_cast
from mlir.dialects import linalg
from mlir.dialects.arith import sitofp, index_cast
from mlir.utils.dialects.ext.arith import constant
from mlir.utils.dialects.ext.func import func
from mlir.utils.dialects.ext.memref import load, store, S
from mlir.utils.dialects.ext.scf import (
canonicalizer,
range_,
)
from mlir.utils.dialects.memref import cast
from mlir.dialects.memref import cast
from mlir.utils.runtime.passes import Pipeline, run_pipeline
from mlir.utils.runtime.refbackend import (
LLVMJITBackend,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
while___,
)
from mlir.utils.dialects.ext.tensor import empty, Tensor
from mlir.utils.dialects.memref import alloca_scope, alloca_scope_return
from mlir.dialects.memref import alloca_scope, alloca_scope_return

# noinspection PyUnresolvedReferences
from mlir.utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
Expand Down
4 changes: 2 additions & 2 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from mlir.utils import types as T
from mlir.utils.ast.canonicalize import canonicalize
from mlir.utils.dialects import linalg, arith
from mlir.dialects import linalg, arith
from mlir.utils.dialects.ext.func import func
from mlir.utils.dialects.ext.gpu import block_attr, thread_attr
from mlir.utils.dialects.ext.scf import (
Expand All @@ -22,7 +22,7 @@
tile_to_scf_forall,
apply_patterns,
)
from mlir.utils.dialects.transform import apply_patterns_canonicalization, apply_cse
from mlir.dialects.transform import apply_patterns_canonicalization, apply_cse
from mlir.utils.runtime.passes import run_pipeline, Pipeline

# noinspection PyUnresolvedReferences
Expand Down

0 comments on commit dc74d75

Please sign in to comment.