Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tk5] Implement slice analysis and sufficient op coverage for a softmax kernel. #222

Merged
merged 5 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 188 additions & 4 deletions python/shark_turbine/kernel/_support/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,18 @@
from . import context

__all__ = [
"BoundedSymbolicValue",
"KernelBuffer",
"Grid",
"InputBuffer",
"OutputBuffer",
"SymbolDef",
"TemporaryBuffer",
"sym",
"sym_0",
"sym_1",
"sym_2",
"sym_n1",
]


Expand Down Expand Up @@ -73,7 +78,37 @@ def ir_type_asm(self) -> str:
###############################################################################


class SymbolDef:
class SymbolExpr:
def is_one(self) -> Optional[bool]:
"""Returns True if the symbol is known to be 1.

Return False if known to be != 1 and None if not known.
"""
raise NotImplementedError

def is_non_negative(self) -> Optional[bool]:
"""Returns True is the symbol is known to be non-negative.

Returns False if known to be negative and None if not known.
"""
raise NotImplementedError

def is_positive(self) -> Optional[bool]:
"""Returns True is the symbol is known to be greater than zero.

Returns False if known to be <= 0 and None if not known.
"""
raise NotImplementedError

def is_negative(self) -> Optional[bool]:
"""Returns True is the symbol is known to be greater than zero.

Returns False if known to be <= 0 and None if not known.
"""
raise NotImplementedError


class SymbolDef(SymbolExpr):
"""Represents a named symbol representing a dimension in a shape."""

ALL_SYMBOLS: ClassVar[dict[str, "SymbolDef"]] = dict()
Expand Down Expand Up @@ -101,9 +136,153 @@ def __getattr__(self, n):

return Expando()

def is_one(self) -> Optional[bool]:
value = IndexingContext.current().get_static_value(self)
if value is None:
return None
return value == 1

def is_non_negative(self) -> Optional[bool]:
value = IndexingContext.current().get_static_value(self)
if value is None:
return None
return value >= 0

def is_positive(self) -> Optional[bool]:
value = IndexingContext.current().get_static_value(self)
if value is None:
return None
return value > 0

def is_negative(self) -> Optional[bool]:
value = IndexingContext.current().get_static_value(self)
if value is None:
return None
return value < 0


sym = SymbolDef.create_expando()

sym_0 = SymbolDef("0")
sym_1 = SymbolDef("1")
sym_2 = SymbolDef("2")
sym_n1 = SymbolDef("-1")


###############################################################################
# Bounded symbolic value.
###############################################################################

BoundedRangeExprT = tuple[Optional[SymbolExpr], Optional[SymbolExpr]]


class _BoundedSymbolicValueMeta(type):
"""Meta-class for deriving new bounded symbolic values."""

range: BoundedRangeExprT

def __new__(mcls, name: str, bases, dct, *, range: BoundedRangeExprT):
dct["range"] = range
dct["__qualname__"] = _bounded_symbolic_value_repr(range=range)
new_class = type.__new__(mcls, name, bases, dct)
return new_class

def __repr__(cls):
return _bounded_symbolic_value_repr(range=cls.range)

@property
def min_bound(cls) -> Optional[SymbolExpr]:
return cls.range[0]

@property
def max_bound(cls) -> Optional[SymbolExpr]:
return cls.range[1]

def bound(
cls: Type[SubtypeT],
min_bound: Optional[SymbolExpr],
max_bound: Optional[SymbolExpr],
) -> Type[SubtypeT]:
class Bounded(BoundedSymbolicValue, range=(min_bound, max_bound)):
...

return Bounded

def narrow(
cls: Type[SubtypeT],
*,
min_bound: Optional[SymbolExpr] = None,
max_bound: Optional[SymbolExpr] = None,
) -> Type[SubtypeT]:
class Bounded(
BoundedSymbolicValue,
range=(
min_bound if min_bound is not None else cls.min_bound,
max_bound if max_bound is not None else cls.max_bound,
),
):
...

return Bounded


def _bounded_symbolic_value_repr(*, range: BoundedRangeExprT) -> str:
min_expr, max_expr = range
min_s = repr(min_expr) if min_expr is not None else "*"
max_s = repr(max_expr) if max_expr is not None else "*"
return f"BoundedSymbolicValue({min_s} : {max_s})"


class BoundedSymbolicValue(
SymbolExpr, metaclass=_BoundedSymbolicValueMeta, range=(None, None)
):
"""Represents a symbolic value that is bounded to a range fixed for the type."""

def __init__(self, value: Optional[int] = None):
self.value = value

def __repr__(self):
return f"{type(self)}({'proxy' if self.value is None else self.value})"

@property
def static_range(self) -> Optional[tuple[int, int]]:
# TODO: This is a hack until shape derivation is in place.
ctx = IndexingContext.current()
mn, mx = type(self).range
if mn is not None:
mn = ctx.get_static_value(mn)
if mx is not None:
mx = ctx.get_static_value(mx)
if mn is not None and mx is not None:
return mn, mx
else:
return None

def is_one(self) -> Optional[bool]:
r = self.static_range
if r:
return r[0] == 1 and r[1] == 2
return None

def is_non_negative(self) -> Optional[bool]:
r = self.static_range
if r:
return r[0] >= 0
return None

def is_positive(self) -> Optional[bool]:
r = self.static_range
if r:
return r[0] > 0
return None

def is_negative(self) -> Optional[bool]:
r = self.static_range
if r:
return r[1] < 0
return None


###############################################################################
# Grid
###############################################################################
Expand Down Expand Up @@ -271,7 +450,7 @@ def __repr__(cls):
)


def _is_kernel_buffer_meta_derived(t: type) -> bool:
def is_kernel_buffer_meta_derived(t: type) -> bool:
return isinstance(t, _KernelBufferMeta)


Expand Down Expand Up @@ -361,7 +540,12 @@ class IndexingContext:
__tk_context_idname__ = "IndexingContext"

def __init__(self):
self.constant_bindings: dict[SymbolDef, int] = {}
self.constant_bindings: dict[SymbolDef, int] = {
sym_0: 0,
sym_1: 1,
sym_2: 2,
sym_n1: -1,
}

def bind_constant(self, sym: SymbolDef, value: int):
existing = self.constant_bindings.get(sym)
Expand All @@ -371,7 +555,7 @@ def bind_constant(self, sym: SymbolDef, value: int):
)
self.constant_bindings[sym] = value

def get_static_value(self, sym: SymbolDef) -> Optional[int]:
def get_static_value(self, sym: SymbolExpr) -> Optional[int]:
"""If the symbol can be resolved to a static value, returns it."""
return self.constant_bindings.get(sym)

Expand Down
17 changes: 15 additions & 2 deletions python/shark_turbine/kernel/_support/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import torch.fx as fx

from .indexing import (
BoundedSymbolicValue,
Grid,
KernelBuffer,
sym_0,
)

from ..lang.types import (
Expand Down Expand Up @@ -106,13 +109,23 @@ def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, ite


class CompiledContext(BaseContext):
def __init__(self, tracer: KernelTracer):
def __init__(self, tracer: KernelTracer, *, grid_type: Type[Grid]):
super().__init__(eager=False)
self.tracer = tracer
self.grid_type = grid_type

def handle_thread_program_id(self, op, axis: int) -> Index:
grid_shape = self.grid_type.symbolic_shape
if axis < 0 or axis >= len(grid_shape):
raise IndexError(
f"Illegal index into grid of rank {len(grid_shape)}: {axis}"
)
proxy = self.tracer.create_proxy(
"call_function", op, args=(axis,), kwargs={}, type_expr=Index
"call_function",
op,
args=(axis,),
kwargs={},
type_expr=BoundedSymbolicValue.bound(sym_0, grid_shape[axis]),
)
return proxy

Expand Down
Loading
Loading