Skip to content

Commit

Permalink
Get indexing_test.py working.
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaraccident committed Jan 26, 2024
1 parent 4248f01 commit 86161a8
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 32 deletions.
217 changes: 186 additions & 31 deletions python/shark_turbine/kernel/_support/indexing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, ClassVar, Optional, Type, TypeVar, Union, cast

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum

import sympy
import torch

from .. import ops
Expand All @@ -14,6 +16,7 @@
"Grid",
"InputBuffer",
"OutputBuffer",
"IndexingContext",
"IndexExpr",
"IndexSymbol",
"TemporaryBuffer",
Expand Down Expand Up @@ -77,8 +80,6 @@ def ir_type_asm(self) -> str:
# Dimension symbols
###############################################################################

import sympy

IndexExpr = sympy.core.Expr
IndexSymbol = sympy.core.Symbol

Expand All @@ -90,16 +91,38 @@ def index_symbol(name: str) -> IndexSymbol:
return symbol


def index_value(value: Any) -> IndexExpr:
expr = sympy.sympify(value)
if not isinstance(expr, sympy.core.Integer):
raise ValueError(f"Expected Integer from {value}. Got {type(expr)}")
return expr


class _IndexSymbolExpando:
def __getattr__(self, n):
return index_symbol(n)


sym = _IndexSymbolExpando()

sym_0 = index_symbol("0")
sym_1 = index_symbol("1")
sym_2 = index_symbol("2")
sym_n1 = index_symbol("-1")
sym_0 = index_value(0)
sym_1 = index_value(1)
sym_2 = index_value(2)
sym_n1 = index_value(-1)

###############################################################################
# Shape expressions
###############################################################################

SymbolicDimable = Union[str, IndexExpr]
SymbolicShapeable = tuple[SymbolicDimable]
SymbolicShapeExpr = tuple[IndexExpr]


def make_symbolic_shape(elements: SymbolicShapeable) -> SymbolicShapeExpr:
return tuple(
index_symbol(expr) if isinstance(expr, str) else expr for expr in elements
)


###############################################################################
Expand All @@ -116,7 +139,7 @@ def __new__(
bases,
dct,
*,
symbolic_shape: Optional[tuple[IndexExpr]],
symbolic_shape: Optional[SymbolicShapeExpr],
):
new_class = type.__new__(mcls, name, bases, dct)
new_class.symbolic_shape = symbolic_shape
Expand All @@ -125,23 +148,23 @@ def __new__(
return new_class

def __class_getitem__(
cls, symbolic_shape: Union[IndexExpr, tuple[IndexExpr]]
cls, symbolic_shape: Union[SymbolicDimable, tuple[SymbolicShapeable]]
) -> Type["Grid"]:
if not isinstance(symbolic_shape, tuple):
symbolic_shape = (symbolic_shape,)
return cast(Grid, _make_shaped_grid(cls, symbolic_shape))
return cast(Grid, _make_shaped_grid(cls, make_symbolic_shape(symbolic_shape)))

def __repr__(self):
if self.symbolic_shape:
return f"Grid[{', '.join(s.name for s in self.symbolic_shape)}]"
return f"Grid[{', '.join(repr(s) for s in self.symbolic_shape)}]"
else:
return "Grid"


class Grid(metaclass=_GridMeta, symbolic_shape=None):
"""Grid with bounding symbolic shape information in the type."""

symbolic_shape: ClassVar[Optional[tuple[IndexExpr]]]
symbolic_shape: ClassVar[Optional[SymbolicShapeExpr]]
rank: int

def __init__(self, *dims: int):
Expand Down Expand Up @@ -210,7 +233,7 @@ class _KernelBufferMeta(type):

element_type: ElementType
usage: KernelBufferUsage
symbolic_shape: Optional[tuple[IndexExpr]]
symbolic_shape: Optional[SymbolicShapeExpr]
rank: Optional[int]

def __new__(
Expand Down Expand Up @@ -238,7 +261,7 @@ def new_subtype(
cls: Type[SubtypeT],
*,
element_type: Union[NotSetType, ElementType] = NotSet,
symbolic_shape: Union[NotSetType, Optional[tuple[IndexExpr]]] = NotSet,
symbolic_shape: Union[NotSetType, Optional[SymbolicShapeable]] = NotSet,
usage: Union[NotSetType, KernelBufferUsage] = NotSet,
) -> Type[SubtypeT]:
init_element_type = (
Expand All @@ -251,7 +274,7 @@ def new_subtype(

class Subtype(cls):
element_type = init_element_type
symbolic_shape = init_symbolic_shape
symbolic_shape = make_symbolic_shape(init_symbolic_shape)
usage = init_usage

return Subtype
Expand Down Expand Up @@ -281,7 +304,7 @@ def _kernel_buffer_type_repr(
) -> str:
root = KernelBufferUsage._type_name(usage)
if symbolic_shape:
stem = f"{root}[{', '.join(s.name for s in symbolic_shape)}]"
stem = f"{root}[{', '.join(repr(s) for s in symbolic_shape)}]"
else:
stem = f"{root}"
if element_type != DefaultElementType:
Expand All @@ -303,7 +326,7 @@ class KernelBuffer(metaclass=_KernelBufferMeta):
"""

usage: ClassVar[KernelBufferUsage]
symbolic_shape: ClassVar[Optional[tuple[IndexExpr]]]
symbolic_shape: ClassVar[Optional[SymbolicShapeExpr]]
rank: Optional[int]

def __init__(self, tensor: torch.Tensor):
Expand All @@ -318,11 +341,13 @@ def __init__(self, tensor: torch.Tensor):
self.rank = tensor_rank

def __class_getitem__(
cls, symbolic_shape: Union[IndexExpr, tuple[IndexExpr]]
cls, symbolic_shape: Union[IndexExpr, SymbolicShapeExpr]
) -> Type["KernelBuffer"]:
if not isinstance(symbolic_shape, tuple):
symbolic_shape = (symbolic_shape,)
return cast(cls, cls.new_subtype(symbolic_shape=symbolic_shape))
return cast(
cls, cls.new_subtype(symbolic_shape=make_symbolic_shape(symbolic_shape))
)

def __repr__(self):
return f"{type(self)}({self._tensor})"
Expand Down Expand Up @@ -350,33 +375,163 @@ class TemporaryBuffer(KernelBuffer):
# IndexingContext
###############################################################################

ShapedType = Union[Type[KernelBuffer], Type[Grid]]
Dims = list[Union[None, IndexSymbol, int]]


@dataclass(slots=True)
class _ShapedBinding:
# The instance of shaped_type. Can be anything. We resolve dimension values
# against this.
instance: Any

# Shaped type that backes the instance.
shaped_type: ShapedType

# The symbolic shape (tuple of index expressions).
symbolic_shape: list[IndexExpr]

# Concrete dimensions instantiated with. Each is an integer or a dynamic
# dim symbol. It can also be None if the value is not dynamic and must be
# inferred from context.
dims: Dims


class IndexingContext:
"""The indexing context is responsible handling the binding of indexed
symbols to concrete values.
"""

__slots__ = [
"subs",
"shaped_bindings",
"dyn_dims",
"frozen_subs",
]

__tk_context_idname__ = "IndexingContext"

def __init__(self):
self.constant_bindings: dict[IndexSymbol, int] = {
sym_0: 0,
sym_1: 1,
sym_2: 2,
sym_n1: -1,
}
self.subs: dict[IndexSymbol, int] = {}
# Indexed by .instance
self.shaped_bindings: dict[Any, _ShapedBinding] = {}
self.dyn_dims: list[IndexSymbol] = []
self.frozen_subs: list[IndexSymbol, int] = []

def next_dyn_dim(self) -> IndexSymbol:
s = index_symbol(f"D{len(self.dyn_dims)}")
self.dyn_dims.append(s)
return s

def bind_shaped(
self, instance: Any, shaped_type: ShapedType, dims: Dims
) -> _ShapedBinding:
symbolic_shape = shaped_type.symbolic_shape
rank = shaped_type.rank
if rank != len(dims):
raise ValueError(
f"For {shaped_type} mismatched symbolic shape vs dim arity: {symbolic_shape} vs {dims}"
)
binding = _ShapedBinding(
instance, shaped_type, list(symbolic_shape), list(dims)
)
self.shaped_bindings[instance] = binding

def bind_constant(self, sym: IndexSymbol, value: int):
existing = self.constant_bindings.get(sym)
if existing is not None and existing != value:
try:
self._bind_symbol(sym, value)
except ValueError:
raise ValueError(
f"Attempt to rebind symbol {sym} to different constant ({value} vs {existing})"
f"Attempt to bind symbol {sym}={value} conflicts with previous "
f"{self.subs[sym]}"
)
self.constant_bindings[sym] = value

def get_static_value(self, sym: IndexExpr) -> Optional[int]:
"""If the symbol can be resolved to a static value, returns it."""
return self.constant_bindings.get(sym)
def _bind_symbol(self, symbol: IndexSymbol, value: int):
existing = self.subs.get(symbol)
if existing is not None and existing != value:
raise ValueError
self.subs[symbol] = value

def finalize(self):
assert len(self.frozen_subs) == 0
# Go over everything we know and bind all free symbols.
for _sb in self.shaped_bindings.values():
for i in range(_sb.shaped_type.rank):
dim_expr = _sb.symbolic_shape[i]
dim_value = _sb.dims[i]
if dim_value is not None:
if isinstance(dim_expr, IndexSymbol):
try:
self._bind_symbol(dim_expr, dim_value)
except ValueError as e:
raise ValueError(
f"For {_sb.instance} of {_sb.shaped_type} attempt to bind dim "
f"{dim_expr}={dim_value} conflicts with previous "
f"{self.subs[dim_expr]}"
)

# Note: At this point, we could solve the set of equation based
# bindings and maybe elicit some additional information, but for now
# we do forward-only inference.
frozen_subs = self.frozen_subs
frozen_subs.extend(self.subs.items())

# Check any equation based dims.
errors = []
for _sb in self.shaped_bindings.values():
for i in range(_sb.shaped_type.rank):
dim_expr = _sb.symbolic_shape[i]
dim_value = _sb.dims[i]
dim_expr = dim_expr.subs(frozen_subs).simplify()
_sb.symbolic_shape[i] = dim_expr
if dim_value is None:
# Ensure resolves to a known value.
if not isinstance(dim_expr, sympy.Integer):
errors.append(
f" {_sb.instance} of {_sb.shaped_type}[{i}]={dim_expr} did not "
f"resolve to a known value"
)
continue
# Notate the inferred dim.
_sb.dims[i] = int(dim_expr)
elif isinstance(dim_expr, sympy.Integer):
dim_expr_value = int(dim_expr)
if isinstance(dim_value, IndexExpr):
# If dynamic, then it turns out we have enough static information,
# so replace.
_sb.dims[i] = dim_expr_value
else:
# If static, make sure it matches the runtime value.
if dim_value is not None and dim_expr_value != dim_value:
errors.append(
f" {_sb.instance} of {_sb.shaped_type}[{i}]={dim_expr} was initialized with a "
f"mismatched runtime value of {dim_value}"
)
continue

# Error check.
if errors:
joined = "\n".join(errors)
raise ValueError(f"Indexing mismatches were encountered:\n{joined}")

def eval_dim(self, instance: Any, shaped_type: ShapedType, pos: int) -> IndexExpr:
# TODO: Could see if shaped_type is in self.shaped_bindings: it has some
# precomputed values that may save cycles to use.
symbolic_shape = shaped_type.symbolic_shape
try:
expr = symbolic_shape[pos]
except IndexError:
raise IndexError(f"Attempt to access out of range {shaped_type}[{pos}]")
return expr.subs(self.frozen_subs).simplify()

def eval_static_dim(
self, instance: Any, shaped_type: ShapedType, pos: int
) -> Optional[int]:
expr = self.eval_dim(instance, shaped_type, pos)
if isinstance(expr, sympy.Integer):
return int(expr)
else:
return None

##### Context management.
@staticmethod
Expand Down
Loading

0 comments on commit 86161a8

Please sign in to comment.