Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Oct 4, 2023
1 parent d80dfd3 commit b62d3d5
Show file tree
Hide file tree
Showing 12 changed files with 987 additions and 233 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ jobs:
shell: bash
run: |
if [ ${{ matrix.os }} == 'windows-2022' ]; then
pytest -s --ignore-glob=*test_other_hosts* tests
pytest -s tests
else
pytest --capture=tee-sys --ignore-glob=*test_other_hosts* tests
pytest --capture=tee-sys tests
fi
- name: Test mwe
Expand Down Expand Up @@ -167,5 +167,5 @@ jobs:
pip install -q .[test,mlir] -f https://makslevental.github.io/wheels
mlir-python-utils-generate-all-upstream-trampolines
pytest --capture=tee-sys --ignore-glob=*test_other_hosts* tests
pytest --capture=tee-sys tests
python examples/mwe.py
9 changes: 4 additions & 5 deletions examples/mwe.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import numpy as np

# you need this to register the memref value caster
# noinspection PyUnresolvedReferences
import mlir.utils.dialects.ext.memref
import mlir.utils.types as T
from mlir.utils.ast.canonicalize import canonicalize
from mlir.utils.context import MLIRContext, mlir_mod_ctx
from mlir.utils.dialects.ext.arith import constant
from mlir.utils.dialects.ext.func import func
from mlir.utils.dialects.ext.scf import canonicalizer as scf, range_ as range
from mlir.utils.runtime.passes import Pipeline, run_pipeline
from mlir.utils.runtime.passes import Pipeline
from mlir.utils.runtime.refbackend import LLVMJITBackend

# you need this to register the memref value caster
# noinspection PyUnresolvedReferences
import mlir.utils.dialects.ext.memref


def setting_memref(ctx: MLIRContext, backend: LLVMJITBackend):
K = 10
Expand Down
159 changes: 159 additions & 0 deletions mlir/utils/dialects/ext/cf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from ...meta import maybe_cast
from ...util import get_user_code_loc, get_result_or_results, Successor
from ....dialects import cf
from ....dialects._ods_common import (
get_op_results_or_values,
get_default_loc_context,
segmented_accessor,
get_op_result_or_value,
)
from ....ir import Value, InsertionPoint, Block, OpView


class BranchOp(cf.BranchOp.__base__):
OPERATION_NAME = "cf.br"

_ODS_REGIONS = (0, True)

def __init__(self, destOperands, dest=None, *, loc=None, ip=None):
operands = []
results = []
attributes = {}
regions = None
operands.extend(get_op_results_or_values(destOperands))
_ods_context = get_default_loc_context(loc)
_ods_successors = []
if dest is not None:
_ods_successors.append(dest)
super().__init__(
self.build_generic(
attributes=attributes,
results=results,
operands=operands,
successors=_ods_successors,
regions=regions,
loc=loc,
ip=ip,
)
)


class CondBranchOp(OpView):
OPERATION_NAME = "cf.cond_br"

_ODS_OPERAND_SEGMENTS = [1, -1, -1]

_ODS_REGIONS = (0, True)

def __init__(
self,
condition,
trueDestOperands=None,
falseDestOperands=None,
trueDest=None,
falseDest=None,
*,
loc=None,
ip=None
):
operands = []
results = []
attributes = {}
regions = None
operands.append(get_op_result_or_value(condition))
if trueDestOperands is None:
trueDestOperands = []
if falseDestOperands is None:
falseDestOperands = []
operands.append(get_op_results_or_values(trueDestOperands))
operands.append(get_op_results_or_values(falseDestOperands))
_ods_context = get_default_loc_context(loc)
_ods_successors = []
if trueDest is not None:
_ods_successors.append(trueDest)
if falseDest is not None:
_ods_successors.append(falseDest)
super().__init__(
self.build_generic(
attributes=attributes,
results=results,
operands=operands,
successors=_ods_successors,
regions=regions,
loc=loc,
ip=ip,
)
)

@property
def condition(self):
operand_range = segmented_accessor(
self.operation.operands, self.operation.attributes["operandSegmentSizes"], 0
)
return operand_range[0]

@property
def trueDestOperands(self):
operand_range = segmented_accessor(
self.operation.operands, self.operation.attributes["operandSegmentSizes"], 1
)
return operand_range

@property
def falseDestOperands(self):
operand_range = segmented_accessor(
self.operation.operands, self.operation.attributes["operandSegmentSizes"], 2
)
return operand_range

@property
def true(self):
return Successor(self, self.trueDestOperands, self.successors[0], 0)

@property
def false(self):
return Successor(self, self.falseDestOperands, self.successors[1], 1)


def br(dest: Value | Block = None, *dest_operands: list[Value], loc=None, ip=None):
if isinstance(dest, Value):
dest_operands = [dest] + list(dest_operands)
dest = None
if dest is None:
dest = InsertionPoint.current.block
if loc is None:
loc = get_user_code_loc()
return maybe_cast(
get_result_or_results(BranchOp(dest_operands, dest, loc=loc, ip=ip))
)


def cond_br(
condition: Value,
true_dest: Value | Block = None,
false_dest: Value | Block = None,
true_dest_operands: list[Value] = None,
false_dest_operands: list[Value] = None,
*,
loc=None,
ip=None
):
if true_dest is None:
true_dest = InsertionPoint.current.block
if false_dest is None:
false_dest = InsertionPoint.current.block
if loc is None:
loc = get_user_code_loc()
return maybe_cast(
get_result_or_results(
CondBranchOp(
condition,
true_dest_operands,
false_dest_operands,
true_dest,
false_dest,
loc=loc,
ip=ip,
)
)
)
Loading

0 comments on commit b62d3d5

Please sign in to comment.