Skip to content

Commit

Permalink
use star import in ext dir
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Dec 18, 2023
1 parent dac8023 commit 69c8f26
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 84 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ on:
# At minute 0 past hour 6. (see https://crontab.guru)
- cron: '00 06 * * *'

concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
# queued and in-progress runs for the same PR (presubmit) or commit
# (postsubmit).
group: ci-build-test-${{ github.event.number || github.sha }}
cancel-in-progress: true

jobs:

test-mlir-bindings:
Expand Down
1 change: 1 addition & 0 deletions mlir/extras/dialects/ext/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
from ....dialects import arith as arith_dialect
from ....dialects.arith import *
from ....dialects import complex as complex_dialect
from ....dialects._arith_enum_gen import (
_arith_cmpfpredicateattr,
Expand Down
2 changes: 1 addition & 1 deletion mlir/extras/dialects/ext/cf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ....dialects.cf import BranchOp, CondBranchOp
from ....dialects.cf import *
from ....dialects._cf_ops_gen import _Dialect
from ....dialects._ods_common import (
get_op_result_or_value,
Expand Down
2 changes: 1 addition & 1 deletion mlir/extras/dialects/ext/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ...meta import make_maybe_no_args_decorator, op_region_builder
from ...util import get_user_code_loc
from ....dialects.func import FuncOp, ReturnOp, CallOp
from ....dialects.func import *
from ....ir import (
InsertionPoint,
FunctionType,
Expand Down
55 changes: 19 additions & 36 deletions mlir/extras/dialects/ext/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,7 @@
from typing import Optional, Any

from ....dialects._ods_common import get_default_loc_context, _cext
from ....dialects.gpu import (
AddressSpace,
MappingId,
GPUModuleOp,
GPUFuncOp,
LaunchFuncOp,
LaunchOp,
ReturnOp,
AllReduceOp,
YieldOp,
TerminatorOp,
WaitOp,
)
from ....dialects.gpu import *
from ....dialects._gpu_ops_gen import _Dialect
from ....ir import (
Type,
Expand All @@ -30,9 +18,8 @@
)

from ... import types as T
from ...dialects.ext.arith import constant
from ...dialects.ext.func import FuncBase
from ....dialects.gpu import block_id, module_end
from .arith import constant
from .func import FuncBase
from ...meta import (
ModuleMeta,
make_maybe_no_args_decorator,
Expand Down Expand Up @@ -318,20 +305,18 @@ def __call__(
size[i] = constant(s, index=True)

loc = get_user_code_loc()
return (
get_op_result_or_op_results(
LaunchFuncOp(
[self.qualname, self.func_name]
if self.qualname is not None
else [self.func_name],
grid_size,
block_size,
kernel_operands,
async_dependencies,
dynamic_shared_memory_size,
async_object=stream,
loc=loc,
)
return get_op_result_or_op_results(
LaunchFuncOp(
[self.qualname, self.func_name]
if self.qualname is not None
else [self.func_name],
grid_size,
block_size,
kernel_operands,
async_dependencies,
dynamic_shared_memory_size,
async_object=stream,
loc=loc,
)
)

Expand Down Expand Up @@ -409,10 +394,8 @@ def all_reduce__(value: Value, *, op=None, uniform=None, loc=None, ip=None):


def all_reduce_(value: Value, *, op=None, uniform=None, loc=None, ip=None):
return (
get_op_result_or_op_results(
all_reduce__(value, op=op, uniform=uniform, loc=loc, ip=ip)
)
return get_op_result_or_op_results(
all_reduce__(value, op=op, uniform=uniform, loc=loc, ip=ip)
)


Expand All @@ -425,6 +408,6 @@ def wait(async_dependencies: Optional[list[Value]] = None, *, loc=None, ip=None)
if async_dependencies is None:
async_dependencies = []
async_token = gpu_async_token()
return (
get_op_result_or_op_results(WaitOp(async_token, async_dependencies, loc=loc, ip=ip))
return get_op_result_or_op_results(
WaitOp(async_token, async_dependencies, loc=loc, ip=ip)
)
1 change: 1 addition & 0 deletions mlir/extras/dialects/ext/llvm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ....ir import Type
from ....dialects.llvm import *


def llvm_ptr_t():
Expand Down
20 changes: 8 additions & 12 deletions mlir/extras/dialects/ext/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
from ....ir import Type, Value, MemRefType, ShapedType, MLIRError

from ... import types as T
from ....dialects.memref import *
from ....dialects import memref, arith
from ...dialects.ext.arith import Scalar, constant
from ...dialects.ext.tensor import (
_indices_to_indexer,
compute_result_shape_reassoc_list,
)
from .arith import Scalar, constant
from .tensor import _indices_to_indexer, compute_result_shape_reassoc_list
from ...meta import region_op
from ...._mlir_libs._mlir import register_value_caster
from ...util import get_user_code_loc
Expand Down Expand Up @@ -39,7 +37,7 @@ def _alloc(
def alloc(sizes: Sequence[Union[int, Value]], element_type: Type, *, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
return _alloc(memref.AllocOp, sizes, element_type, loc=loc, ip=ip)
return _alloc(AllocOp, sizes, element_type, loc=loc, ip=ip)


def alloca(
Expand All @@ -48,7 +46,7 @@ def alloca(
if loc is None:
loc = get_user_code_loc()
return get_op_result_or_op_results(
_alloc(memref.AllocaOp, sizes, element_type, loc=loc, ip=ip)
_alloc(AllocaOp, sizes, element_type, loc=loc, ip=ip)
)


Expand All @@ -59,7 +57,7 @@ def load(mem: Value, indices: Sequence[Value | int], *, loc=None, ip=None):
for idx, i in enumerate(indices):
if isinstance(i, int):
indices[idx] = constant(i, index=True)
return get_op_result_or_op_results(memref.LoadOp(mem, indices, loc=loc, ip=ip))
return get_op_result_or_op_results(LoadOp(mem, indices, loc=loc, ip=ip))


def store(
Expand All @@ -71,9 +69,7 @@ def store(
for idx, i in enumerate(indices):
if isinstance(i, int):
indices[idx] = constant(i, index=True)
return get_op_result_or_op_results(
memref.StoreOp(value, mem, indices, loc=loc, ip=ip)
)
return get_op_result_or_op_results(StoreOp(value, mem, indices, loc=loc, ip=ip))


def subview(
Expand Down Expand Up @@ -345,4 +341,4 @@ def _copy_to_subview(
return memref.copy(source, dest_subview, loc=loc, ip=ip)


alloca_scope = region_op(memref.AllocaScopeOp)
alloca_scope = region_op(AllocaScopeOp)
23 changes: 5 additions & 18 deletions mlir/extras/dialects/ext/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,37 @@
from copy import deepcopy
from typing import Optional, Sequence, Union

from bytecode import ConcreteBytecode, ConcreteInstr
from bytecode import ConcreteBytecode

from ... import types as T
from ...ast.canonicalize import (
StrictTransformer,
Canonicalizer,
BytecodePatcher,
)
from ...ast.util import ast_call, set_lineno
from ...dialects.ext.arith import constant, index_cast
from ...dialects.ext.gpu import get_device_mapping_array_attr
from ....dialects.scf import yield_ as yield__, reduce_return, condition
from .gpu import get_device_mapping_array_attr
from ...meta import region_adder, region_op
from ...util import get_user_code_loc
from ....dialects._ods_common import (
get_op_results_or_values,
get_op_result_or_op_results,
get_default_loc_context,
_cext,
)
from ....dialects.linalg.opdsl.lang.emitter import _is_index_type
from ....dialects.scf import (
_Dialect,
IfOp,
ForOp,
ForallOp,
ParallelOp,
InParallelOp,
ReduceOp,
WhileOp,
ExecuteRegionOp,
)
from ....dialects.scf import *
from ....dialects.scf import _Dialect, yield_ as yield__, reduce_return, condition
from ....ir import (
InsertionPoint,
Value,
OpResultList,
OpResult,
Operation,
OpView,
IndexType,
_denseI64ArrayAttr,
Attribute,
OpaqueType,
)
from .arith import constant, index_cast

logger = logging.getLogger(__name__)

Expand Down
3 changes: 2 additions & 1 deletion mlir/extras/dialects/ext/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

from ... import types as T
from ....dialects import tensor
from ...dialects.ext.arith import ArithValue, Scalar, constant
from ....dialects.tensor import *
from .arith import ArithValue, Scalar, constant
from ...meta import region_op, _update_caller_vars
from ...._mlir_libs._mlir import register_value_caster
from ...util import get_user_code_loc
Expand Down
21 changes: 6 additions & 15 deletions mlir/extras/dialects/ext/transform.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,23 @@
from typing import Optional, Union, Sequence

from ... import types as T
from ...meta import region_op
from ...util import get_user_code_loc
from ....dialects.transform.structured import _dispatch_mixed_values, TileUsingForOp
from ....dialects import pdl
from ....dialects.transform import *
from ....dialects.transform.loop import *
from ....dialects.transform.structured import *
from ....dialects._ods_common import get_op_result_or_op_results
from ....dialects._structured_transform_ops_gen import (
TileUsingForallOp,
MatchOp,
)
from ....dialects.transform import ApplyPatternsOp
from ....dialects.transform import (
SequenceOp,
FailurePropagationMode,
YieldOp,
AnyOpType,
OperationType,
)
from ....dialects.transform.loop import LoopUnrollOp
from ....dialects.transform import GetParentOp
from ....dialects.transform.structured import _dispatch_mixed_values
from ....ir import (
Type,
Value,
Operation,
StringAttr,
)
from ....dialects._ods_common import get_op_result_or_op_results
from ....dialects import pdl


pdl_operation_t = lambda: pdl.OperationType.get()

Expand Down

0 comments on commit 69c8f26

Please sign in to comment.