Skip to content

Commit

Permalink
enable free floating funcs (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental authored Jan 1, 2024
1 parent d56d2f8 commit 823c14a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 7 deletions.
26 changes: 19 additions & 7 deletions mlir/extras/dialects/ext/func.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import inspect
import sys
from functools import update_wrapper
from typing import Union, Optional

from ...meta import op_region_builder
from ...util import get_user_code_loc, make_maybe_no_args_decorator
from ....dialects.func import *
from ....extras import types as T
from ....ir import (
InsertionPoint,
FlatSymbolRefAttr,
FunctionType,
InsertionPoint,
StringAttr,
TypeAttr,
FlatSymbolRefAttr,
Type,
TypeAttr,
Value,
)

Expand Down Expand Up @@ -91,9 +93,15 @@ def prep_func_types(sig, return_types):
if not p.annotation is inspect.Signature.empty
]
assert all(
isinstance(r, Type) for r in input_types
isinstance(r, (str, Type)) for r in input_types
), f"all input types must be mlir types {input_types=}"
return input_types, return_types, [get_user_code_loc()] * len(sig.parameters)
user_loc = get_user_code_loc()
# If ir.Context is none (like for deferred func emit)
if user_loc is None:
user_locs = None
else:
user_locs = [user_loc] * len(sig.parameters)
return input_types, return_types, user_locs


class FuncBase:
Expand Down Expand Up @@ -169,9 +177,13 @@ def __str__(self):
def emit(self, *call_args) -> FuncOp:
if self._func_op is None:
if len(call_args) == 0:
input_types = self.input_types
input_types = self.input_types[:]
for i, v in enumerate(input_types):
if isinstance(v, str):
input_types[i] = Type(eval(v, {"T": T}))
else:
input_types = [a.type for a in call_args]

function_type = TypeAttr.get(
FunctionType.get(
inputs=input_types,
Expand Down Expand Up @@ -244,7 +256,7 @@ def func(
loc=loc,
ip=ip,
)
func.__name__ = f.__name__
func = update_wrapper(func, f)
if emit:
func.emit()
return func
26 changes: 26 additions & 0 deletions tests/test_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,3 +463,29 @@ def foo1():
"""
)
filecheck(correct, ctx.module)


@func(emit=False)
def matmul_i16_i16(
A: "T.memref(64, 32, T.i16())",
B: "T.memref(32, 64, T.i16())",
C: "T.memref(64, 64, T.i16())",
):
linalg.matmul(A, B, C)


def test_defer_emit(ctx: MLIRContext):

matmul_i16_i16.emit()

correct = dedent(
"""\
module {
func.func @matmul_i16_i16(%arg0: memref<64x32xi16>, %arg1: memref<32x64xi16>, %arg2: memref<64x64xi16>) {
linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : memref<64x32xi16>, memref<32x64xi16>) outs(%arg2 : memref<64x64xi16>)
return
}
}
"""
)
filecheck(correct, ctx.module)

0 comments on commit 823c14a

Please sign in to comment.