Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for IREE external parameters at module export time. #155

Merged
merged 3 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 54 additions & 10 deletions python/shark_turbine/aot/builtins/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any
from typing import Any, Callable, Optional

import torch.nn as nn

Expand All @@ -17,6 +17,11 @@
abstractify_single_value,
)

from ..support.ir_utils import (
NameMapCallback,
GlobalAttributes,
)

from ..support.utils import (
TreeSpec,
tree_flatten,
Expand All @@ -34,10 +39,22 @@ def __init__(
value: Any,
*,
name: str = "global",
initialize: bool = True,
mutable: bool = False,
mutable: Optional[bool] = None,
initialize: Optional[bool] = None,
external: Optional[bool] = None,
external_scope: Optional[str] = None,
name_mapper: Optional[NameMapCallback] = None,
attrs: Optional[GlobalAttributes] = None,
):
super().__init__(initialize=initialize, mutable=mutable)
if attrs is None:
attrs = GlobalAttributes(
mutable=mutable,
initialize=initialize,
external=external,
external_scope=external_scope,
name_mapper=name_mapper,
)
super().__init__(attrs)
self._name = name
self._value = value
_, self._schema = tree_flatten(self._value)
Expand All @@ -59,11 +76,22 @@ def __init__(
self,
tree,
*,
name: str = "global",
initialize: bool = True,
mutable: bool = False,
mutable: Optional[bool] = None,
initialize: Optional[bool] = None,
external: Optional[bool] = None,
external_scope: Optional[str] = None,
name_mapper: Optional[NameMapCallback] = None,
attrs: Optional[GlobalAttributes] = None,
):
super().__init__(initialize=initialize, mutable=mutable)
if attrs is None:
attrs = GlobalAttributes(
mutable=mutable,
initialize=initialize,
external=external,
external_scope=external_scope,
name_mapper=name_mapper,
)
super().__init__(attrs)
self._tree = tree
self._items, self._schema = tree_flatten(tree)
self._names, _ = tree_flatten(_transform_tree_to_names("", tree))
Expand Down Expand Up @@ -95,9 +123,25 @@ class export_parameters(GlobalsDef, TreeAbstractifiable):
]

def __init__(
self, nn_module: nn.Module, *, initialize: bool = True, mutable: bool = False
self,
nn_module: nn.Module,
*,
mutable: Optional[bool] = None,
initialize: Optional[bool] = None,
external: Optional[bool] = None,
external_scope: Optional[str] = None,
name_mapper: Optional[NameMapCallback] = None,
attrs: Optional[GlobalAttributes] = None,
):
super().__init__(initialize=initialize, mutable=mutable)
if attrs is None:
attrs = GlobalAttributes(
mutable=mutable,
initialize=initialize,
external=external,
external_scope=external_scope,
name_mapper=name_mapper,
)
super().__init__(attrs)
self._param_list = list(nn_module.named_parameters())
self._tree = dict(self._param_list)
_, self._schema = tree_flatten(self._tree)
Expand Down
1 change: 1 addition & 0 deletions python/shark_turbine/aot/support/ir_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""Unifies all imports of iree.compiler.ir into one place."""

from iree.compiler.ir import (
Attribute,
Block,
BlockArgument,
Context,
Expand Down
100 changes: 80 additions & 20 deletions python/shark_turbine/aot/support/ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)

from .ir_imports import (
Attribute,
Block,
BlockArgument,
BF16Type,
Expand Down Expand Up @@ -101,6 +102,57 @@
torch.complex128: "complex<f64>",
}

###############################################################################
# Configuration
###############################################################################

# Maps a name to an altered name. If returns None, then the original
# name is used (this lets Dict.get serve as a NameMapCallback).
NameMapCallback = Callable[[str], Optional[str]]


class GlobalAttributes:
"""Settings for how to initialize the global."""

__slots__ = [
"mutable",
"initialize",
"external",
"external_scope",
"name_mapper",
"noinline",
]

def __init__(
self,
mutable: bool = False,
initialize: Optional[bool] = None,
external: Optional[bool] = None,
external_scope: Optional[str] = None,
name_mapper: Optional[NameMapCallback] = None,
noinline: bool = True,
):
if initialize and external:
raise ValueError("Only one of initialize=True or external=True is allowed")
if initialize is None and external is None:
# Initialize by default.
initialize = True

self.mutable = mutable
self.initialize = initialize
self.external = external
self.external_scope = external_scope
self.name_mapper = name_mapper
self.noinline = noinline

def map_name(self, name: str) -> str:
if self.name_mapper:
new_name = self.name_mapper(name)
if new_name is not None:
return new_name
return name


###############################################################################
# Builders
###############################################################################
Expand Down Expand Up @@ -187,33 +239,42 @@ def create_tensor_global(
symbol_name: str,
t: torch.Tensor,
*,
mutable: bool = False,
initialize: bool = True,
noinline: bool = True,
attrs: GlobalAttributes,
logical_name: Optional[str] = None,
) -> Tuple[str, Operation, IrType]:
element_type = self.torch_dtype_to_iree_type(t.dtype)
with self.global_ip, Location.unknown():
tensor_type = RankedTensorType.get(list(t.shape), element_type)
attrs = {
ir_attrs = {
"sym_name": StringAttr.get(symbol_name),
"sym_visibility": StringAttr.get("private"),
"type": TypeAttr.get(tensor_type),
}
if noinline:
attrs["noinline"] = UnitAttr.get()
if mutable:
attrs["is_mutable"] = UnitAttr.get()
if initialize:
if attrs.noinline:
ir_attrs["noinline"] = UnitAttr.get()
if attrs.mutable:
ir_attrs["is_mutable"] = UnitAttr.get()
if attrs.initialize:
detached_tensor = t.detach().contiguous().cpu()
array = np.array(detached_tensor)
# We know that a Numpy array is a ReadableBuffer so ignore type error.
contents = memoryview(array) # type: ignore
# TODO: Add resource elements to Python API and use that.
# See: https://github.com/nod-ai/SHARK-Turbine/issues/137
elements_attr = DenseElementsAttr.get(contents, type=tensor_type)
attrs["initial_value"] = elements_attr
ir_attrs["initial_value"] = elements_attr
elif attrs.external:
external_scope_attr = StringAttr.get(attrs.external_scope or "model")
external_name = attrs.map_name(
logical_name if logical_name is not None else symbol_name
)
external_name_attr = StringAttr.get(external_name)
# TODO: Have real Python builders for this.
ir_attrs["initial_value"] = Attribute.parse(
f"#stream.parameter.named<{external_scope_attr}::{external_name_attr}> : {tensor_type}"
)

global_op = Operation.create("util.global", attributes=attrs)
global_op = Operation.create("util.global", attributes=ir_attrs)
self.symbol_table.insert(global_op)
actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value
return actual_symbol_name, global_op, tensor_type
Expand All @@ -223,22 +284,21 @@ def create_typed_global(
symbol_name: str,
global_type: IrType,
*,
mutable: bool = False,
initialize: bool = True,
noinline: bool = True,
attrs: GlobalAttributes,
logical_name: Optional[str] = None,
) -> Tuple[str, Operation]:
with self.global_ip, Location.unknown():
attrs = {
ir_attrs = {
"sym_name": StringAttr.get(symbol_name),
"sym_visibility": StringAttr.get("private"),
"type": TypeAttr.get(global_type),
}
if noinline:
attrs["noinline"] = UnitAttr.get()
if mutable:
attrs["is_mutable"] = UnitAttr.get()
if attrs.noinline:
ir_attrs["noinline"] = UnitAttr.get()
if attrs.mutable:
ir_attrs["is_mutable"] = UnitAttr.get()

global_op = Operation.create("util.global", attributes=attrs)
global_op = Operation.create("util.global", attributes=ir_attrs)
self.symbol_table.insert(global_op)
actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value
return actual_symbol_name, global_op
Expand Down
33 changes: 20 additions & 13 deletions python/shark_turbine/aot/support/procedural/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

from typing import (
Any,
Callable,
Dict,
Generator,
Optional,
Sequence,
Tuple,
)
Expand All @@ -25,6 +27,7 @@
)

from ..ir_utils import (
GlobalAttributes,
ModuleBuilder,
)

Expand Down Expand Up @@ -87,13 +90,11 @@ class GlobalsDef:
"""Base class for all exporting descriptors."""

__slots__ = [
"_initialize",
"_mutable",
"_attrs",
]

def __init__(self, *, initialize: bool, mutable: bool):
self._initialize = initialize
self._mutable = mutable
def __init__(self, attrs: GlobalAttributes):
self._attrs = attrs

def items(self) -> Generator[Tuple[str, Any], None, None]:
"""Yields tuples of name/value exports."""
Expand Down Expand Up @@ -124,8 +125,8 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
) = module_builder.create_tensor_global(
f"_{fq_name}",
value,
initialize=self._initialize,
mutable=self._mutable,
attrs=self._attrs,
logical_name=fq_name,
)
mapping.value = IrGlobalTensor(
fq_name,
Expand All @@ -140,11 +141,14 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
continue
elif isinstance(value, AbstractTensor):
global_type = value.get_ir_type(module_builder)
(actual_symbol_name, global_op,) = module_builder.create_typed_global(
(
actual_symbol_name,
global_op,
) = module_builder.create_typed_global(
f"_{fq_name}",
global_type,
initialize=self._initialize,
mutable=self._mutable,
attrs=self._attrs,
logical_name=fq_name,
)
flat_globals.append(
IrGlobalTensor(
Expand All @@ -159,11 +163,14 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
continue
elif isinstance(value, AbstractScalar):
global_type = value.get_ir_type(module_builder)
(actual_symbol_name, global_op,) = module_builder.create_typed_global(
(
actual_symbol_name,
global_op,
) = module_builder.create_typed_global(
f"_{fq_name}",
global_type,
initialize=self._initialize,
mutable=self._mutable,
attrs=self._attrs,
logical_name=fq_name,
)
flat_globals.append(
IrGlobalScalar(
Expand Down
1 change: 1 addition & 0 deletions python/shark_turbine/aot/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
# Reference mapping
###############################################################################


# Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning.
class EmptyType:
Expand Down
Loading