Skip to content

Commit

Permalink
Adds support for IREE external parameters at module export time. (#155)
Browse files Browse the repository at this point in the history
* Adds kwargs to builtins that create parameters and globals:
  * external=True : Emit as an external parameter
  * external_scope : Set an explicit external scope
  * name_mapper : Callback to map the logical name on the module to the parameter name
  • Loading branch information
stellaraccident authored Nov 4, 2023
1 parent d2c0d49 commit 90bf857
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 95 deletions.
64 changes: 54 additions & 10 deletions python/shark_turbine/aot/builtins/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any
from typing import Any, Callable, Optional

import torch.nn as nn

Expand All @@ -17,6 +17,11 @@
abstractify_single_value,
)

from ..support.ir_utils import (
NameMapCallback,
GlobalAttributes,
)

from ..support.utils import (
TreeSpec,
tree_flatten,
Expand All @@ -34,10 +39,22 @@ def __init__(
value: Any,
*,
name: str = "global",
initialize: bool = True,
mutable: bool = False,
mutable: Optional[bool] = None,
initialize: Optional[bool] = None,
external: Optional[bool] = None,
external_scope: Optional[str] = None,
name_mapper: Optional[NameMapCallback] = None,
attrs: Optional[GlobalAttributes] = None,
):
super().__init__(initialize=initialize, mutable=mutable)
if attrs is None:
attrs = GlobalAttributes(
mutable=mutable,
initialize=initialize,
external=external,
external_scope=external_scope,
name_mapper=name_mapper,
)
super().__init__(attrs)
self._name = name
self._value = value
_, self._schema = tree_flatten(self._value)
Expand All @@ -59,11 +76,22 @@ def __init__(
self,
tree,
*,
name: str = "global",
initialize: bool = True,
mutable: bool = False,
mutable: Optional[bool] = None,
initialize: Optional[bool] = None,
external: Optional[bool] = None,
external_scope: Optional[str] = None,
name_mapper: Optional[NameMapCallback] = None,
attrs: Optional[GlobalAttributes] = None,
):
super().__init__(initialize=initialize, mutable=mutable)
if attrs is None:
attrs = GlobalAttributes(
mutable=mutable,
initialize=initialize,
external=external,
external_scope=external_scope,
name_mapper=name_mapper,
)
super().__init__(attrs)
self._tree = tree
self._items, self._schema = tree_flatten(tree)
self._names, _ = tree_flatten(_transform_tree_to_names("", tree))
Expand Down Expand Up @@ -95,9 +123,25 @@ class export_parameters(GlobalsDef, TreeAbstractifiable):
]

def __init__(
self, nn_module: nn.Module, *, initialize: bool = True, mutable: bool = False
self,
nn_module: nn.Module,
*,
mutable: Optional[bool] = None,
initialize: Optional[bool] = None,
external: Optional[bool] = None,
external_scope: Optional[str] = None,
name_mapper: Optional[NameMapCallback] = None,
attrs: Optional[GlobalAttributes] = None,
):
super().__init__(initialize=initialize, mutable=mutable)
if attrs is None:
attrs = GlobalAttributes(
mutable=mutable,
initialize=initialize,
external=external,
external_scope=external_scope,
name_mapper=name_mapper,
)
super().__init__(attrs)
self._param_list = list(nn_module.named_parameters())
self._tree = dict(self._param_list)
_, self._schema = tree_flatten(self._tree)
Expand Down
1 change: 1 addition & 0 deletions python/shark_turbine/aot/support/ir_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""Unifies all imports of iree.compiler.ir into one place."""

from iree.compiler.ir import (
Attribute,
Block,
BlockArgument,
Context,
Expand Down
100 changes: 80 additions & 20 deletions python/shark_turbine/aot/support/ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)

from .ir_imports import (
Attribute,
Block,
BlockArgument,
BF16Type,
Expand Down Expand Up @@ -101,6 +102,57 @@
torch.complex128: "complex<f64>",
}

###############################################################################
# Configuration
###############################################################################

# Maps a logical name to an altered name. A callback that returns None means
# "keep the original name" (this lets Dict.get serve as a NameMapCallback).
NameMapCallback = Callable[[str], Optional[str]]


class GlobalAttributes:
    """Settings for how to initialize the global."""

    __slots__ = [
        "mutable",
        "initialize",
        "external",
        "external_scope",
        "name_mapper",
        "noinline",
    ]

    def __init__(
        self,
        mutable: bool = False,
        initialize: Optional[bool] = None,
        external: Optional[bool] = None,
        external_scope: Optional[str] = None,
        name_mapper: Optional[NameMapCallback] = None,
        noinline: bool = True,
    ):
        # Initializing from the value and externalizing as a parameter are
        # mutually exclusive ways of giving the global its contents.
        if initialize is None and external is None:
            # Caller expressed no preference either way: initialize by default.
            # NOTE(review): an explicit external=False with initialize=None
            # leaves both unset (global emitted with no initial value) —
            # confirm that is intended.
            initialize = True
        elif initialize and external:
            raise ValueError("Only one of initialize=True or external=True is allowed")

        self.mutable = mutable
        self.initialize = initialize
        self.external = external
        self.external_scope = external_scope
        self.name_mapper = name_mapper
        self.noinline = noinline

    def map_name(self, name: str) -> str:
        """Applies name_mapper to `name`, falling back to `name` itself."""
        mapper = self.name_mapper
        if not mapper:
            return name
        mapped = mapper(name)
        return name if mapped is None else mapped


###############################################################################
# Builders
###############################################################################
Expand Down Expand Up @@ -187,33 +239,42 @@ def create_tensor_global(
symbol_name: str,
t: torch.Tensor,
*,
mutable: bool = False,
initialize: bool = True,
noinline: bool = True,
attrs: GlobalAttributes,
logical_name: Optional[str] = None,
) -> Tuple[str, Operation, IrType]:
element_type = self.torch_dtype_to_iree_type(t.dtype)
with self.global_ip, Location.unknown():
tensor_type = RankedTensorType.get(list(t.shape), element_type)
attrs = {
ir_attrs = {
"sym_name": StringAttr.get(symbol_name),
"sym_visibility": StringAttr.get("private"),
"type": TypeAttr.get(tensor_type),
}
if noinline:
attrs["noinline"] = UnitAttr.get()
if mutable:
attrs["is_mutable"] = UnitAttr.get()
if initialize:
if attrs.noinline:
ir_attrs["noinline"] = UnitAttr.get()
if attrs.mutable:
ir_attrs["is_mutable"] = UnitAttr.get()
if attrs.initialize:
detached_tensor = t.detach().contiguous().cpu()
array = np.array(detached_tensor)
# We know that a Numpy array is a ReadableBuffer so ignore type error.
contents = memoryview(array) # type: ignore
# TODO: Add resource elements to Python API and use that.
# See: https://github.com/nod-ai/SHARK-Turbine/issues/137
elements_attr = DenseElementsAttr.get(contents, type=tensor_type)
attrs["initial_value"] = elements_attr
ir_attrs["initial_value"] = elements_attr
elif attrs.external:
external_scope_attr = StringAttr.get(attrs.external_scope or "model")
external_name = attrs.map_name(
logical_name if logical_name is not None else symbol_name
)
external_name_attr = StringAttr.get(external_name)
# TODO: Have real Python builders for this.
ir_attrs["initial_value"] = Attribute.parse(
f"#stream.parameter.named<{external_scope_attr}::{external_name_attr}> : {tensor_type}"
)

global_op = Operation.create("util.global", attributes=attrs)
global_op = Operation.create("util.global", attributes=ir_attrs)
self.symbol_table.insert(global_op)
actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value
return actual_symbol_name, global_op, tensor_type
Expand All @@ -223,22 +284,21 @@ def create_typed_global(
symbol_name: str,
global_type: IrType,
*,
mutable: bool = False,
initialize: bool = True,
noinline: bool = True,
attrs: GlobalAttributes,
logical_name: Optional[str] = None,
) -> Tuple[str, Operation]:
with self.global_ip, Location.unknown():
attrs = {
ir_attrs = {
"sym_name": StringAttr.get(symbol_name),
"sym_visibility": StringAttr.get("private"),
"type": TypeAttr.get(global_type),
}
if noinline:
attrs["noinline"] = UnitAttr.get()
if mutable:
attrs["is_mutable"] = UnitAttr.get()
if attrs.noinline:
ir_attrs["noinline"] = UnitAttr.get()
if attrs.mutable:
ir_attrs["is_mutable"] = UnitAttr.get()

global_op = Operation.create("util.global", attributes=attrs)
global_op = Operation.create("util.global", attributes=ir_attrs)
self.symbol_table.insert(global_op)
actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value
return actual_symbol_name, global_op
Expand Down
33 changes: 20 additions & 13 deletions python/shark_turbine/aot/support/procedural/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

from typing import (
Any,
Callable,
Dict,
Generator,
Optional,
Sequence,
Tuple,
)
Expand All @@ -25,6 +27,7 @@
)

from ..ir_utils import (
GlobalAttributes,
ModuleBuilder,
)

Expand Down Expand Up @@ -87,13 +90,11 @@ class GlobalsDef:
"""Base class for all exporting descriptors."""

__slots__ = [
"_initialize",
"_mutable",
"_attrs",
]

def __init__(self, *, initialize: bool, mutable: bool):
self._initialize = initialize
self._mutable = mutable
def __init__(self, attrs: GlobalAttributes):
self._attrs = attrs

def items(self) -> Generator[Tuple[str, Any], None, None]:
"""Yields tuples of name/value exports."""
Expand Down Expand Up @@ -124,8 +125,8 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
) = module_builder.create_tensor_global(
f"_{fq_name}",
value,
initialize=self._initialize,
mutable=self._mutable,
attrs=self._attrs,
logical_name=fq_name,
)
mapping.value = IrGlobalTensor(
fq_name,
Expand All @@ -140,11 +141,14 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
continue
elif isinstance(value, AbstractTensor):
global_type = value.get_ir_type(module_builder)
(actual_symbol_name, global_op,) = module_builder.create_typed_global(
(
actual_symbol_name,
global_op,
) = module_builder.create_typed_global(
f"_{fq_name}",
global_type,
initialize=self._initialize,
mutable=self._mutable,
attrs=self._attrs,
logical_name=fq_name,
)
flat_globals.append(
IrGlobalTensor(
Expand All @@ -159,11 +163,14 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
continue
elif isinstance(value, AbstractScalar):
global_type = value.get_ir_type(module_builder)
(actual_symbol_name, global_op,) = module_builder.create_typed_global(
(
actual_symbol_name,
global_op,
) = module_builder.create_typed_global(
f"_{fq_name}",
global_type,
initialize=self._initialize,
mutable=self._mutable,
attrs=self._attrs,
logical_name=fq_name,
)
flat_globals.append(
IrGlobalScalar(
Expand Down
1 change: 1 addition & 0 deletions python/shark_turbine/aot/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
# Reference mapping
###############################################################################


# Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning.
class EmptyType:
Expand Down
Loading

0 comments on commit 90bf857

Please sign in to comment.