diff --git a/python/shark_turbine/aot/builtins/globals.py b/python/shark_turbine/aot/builtins/globals.py
index 8df9b8929..66ea41022 100644
--- a/python/shark_turbine/aot/builtins/globals.py
+++ b/python/shark_turbine/aot/builtins/globals.py
@@ -5,7 +5,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Any
+from typing import Any, Callable, Optional
 
 import torch.nn as nn
 
@@ -17,6 +17,11 @@
     abstractify_single_value,
 )
 
+from ..support.ir_utils import (
+    NameMapCallback,
+    GlobalAttributes,
+)
+
 from ..support.utils import (
     TreeSpec,
     tree_flatten,
@@ -34,10 +39,22 @@ def __init__(
         self,
         value: Any,
         *,
         name: str = "global",
-        initialize: bool = True,
-        mutable: bool = False,
+        mutable: Optional[bool] = None,
+        initialize: Optional[bool] = None,
+        external: Optional[bool] = None,
+        external_scope: Optional[str] = None,
+        name_mapper: Optional[NameMapCallback] = None,
+        attrs: Optional[GlobalAttributes] = None,
     ):
-        super().__init__(initialize=initialize, mutable=mutable)
+        if attrs is None:
+            attrs = GlobalAttributes(
+                mutable=mutable,
+                initialize=initialize,
+                external=external,
+                external_scope=external_scope,
+                name_mapper=name_mapper,
+            )
+        super().__init__(attrs)
         self._name = name
         self._value = value
         _, self._schema = tree_flatten(self._value)
@@ -59,11 +76,22 @@ def __init__(
         self,
         tree,
         *,
-        name: str = "global",
-        initialize: bool = True,
-        mutable: bool = False,
+        mutable: Optional[bool] = None,
+        initialize: Optional[bool] = None,
+        external: Optional[bool] = None,
+        external_scope: Optional[str] = None,
+        name_mapper: Optional[NameMapCallback] = None,
+        attrs: Optional[GlobalAttributes] = None,
     ):
-        super().__init__(initialize=initialize, mutable=mutable)
+        if attrs is None:
+            attrs = GlobalAttributes(
+                mutable=mutable,
+                initialize=initialize,
+                external=external,
+                external_scope=external_scope,
+                name_mapper=name_mapper,
+            )
+        super().__init__(attrs)
         self._tree = tree
         self._items, self._schema = tree_flatten(tree)
         self._names, _ = tree_flatten(_transform_tree_to_names("", tree))
@@ -95,9 +123,25 @@ class export_parameters(GlobalsDef, TreeAbstractifiable):
     ]
 
     def __init__(
-        self, nn_module: nn.Module, *, initialize: bool = True, mutable: bool = False
+        self,
+        nn_module: nn.Module,
+        *,
+        mutable: Optional[bool] = None,
+        initialize: Optional[bool] = None,
+        external: Optional[bool] = None,
+        external_scope: Optional[str] = None,
+        name_mapper: Optional[NameMapCallback] = None,
+        attrs: Optional[GlobalAttributes] = None,
     ):
-        super().__init__(initialize=initialize, mutable=mutable)
+        if attrs is None:
+            attrs = GlobalAttributes(
+                mutable=mutable,
+                initialize=initialize,
+                external=external,
+                external_scope=external_scope,
+                name_mapper=name_mapper,
+            )
+        super().__init__(attrs)
         self._param_list = list(nn_module.named_parameters())
         self._tree = dict(self._param_list)
         _, self._schema = tree_flatten(self._tree)
diff --git a/python/shark_turbine/aot/support/ir_imports.py b/python/shark_turbine/aot/support/ir_imports.py
index 77cdf1033..9fbc3752f 100644
--- a/python/shark_turbine/aot/support/ir_imports.py
+++ b/python/shark_turbine/aot/support/ir_imports.py
@@ -8,6 +8,7 @@
 """Unifies all imports of iree.compiler.ir into one place."""
 
 from iree.compiler.ir import (
+    Attribute,
     Block,
     BlockArgument,
     Context,
diff --git a/python/shark_turbine/aot/support/ir_utils.py b/python/shark_turbine/aot/support/ir_utils.py
index 094af4ba1..f4877a857 100644
--- a/python/shark_turbine/aot/support/ir_utils.py
+++ b/python/shark_turbine/aot/support/ir_utils.py
@@ -23,6 +23,7 @@
 )
 
 from .ir_imports import (
+    Attribute,
     Block,
     BlockArgument,
     BF16Type,
@@ -101,6 +102,57 @@
     torch.complex128: "complex",
 }
 
+###############################################################################
+# Configuration
+###############################################################################
+
+# Maps a name to an altered name. If it returns None, then the original
+# name is used (this lets dict.get serve as a NameMapCallback).
+NameMapCallback = Callable[[str], Optional[str]]
+
+
+class GlobalAttributes:
+    """Settings for how to initialize the global."""
+
+    __slots__ = [
+        "mutable",
+        "initialize",
+        "external",
+        "external_scope",
+        "name_mapper",
+        "noinline",
+    ]
+
+    def __init__(
+        self,
+        mutable: bool = False,
+        initialize: Optional[bool] = None,
+        external: Optional[bool] = None,
+        external_scope: Optional[str] = None,
+        name_mapper: Optional[NameMapCallback] = None,
+        noinline: bool = True,
+    ):
+        if initialize and external:
+            raise ValueError("Only one of initialize=True or external=True is allowed")
+        if initialize is None and external is None:
+            # Initialize by default.
+            initialize = True
+
+        self.mutable = mutable
+        self.initialize = initialize
+        self.external = external
+        self.external_scope = external_scope
+        self.name_mapper = name_mapper
+        self.noinline = noinline
+
+    def map_name(self, name: str) -> str:
+        if self.name_mapper:
+            new_name = self.name_mapper(name)
+            if new_name is not None:
+                return new_name
+        return name
+
+
 ###############################################################################
 # Builders
 ###############################################################################
@@ -187,23 +239,22 @@ def create_tensor_global(
         self,
         symbol_name: str,
         t: torch.Tensor,
         *,
-        mutable: bool = False,
-        initialize: bool = True,
-        noinline: bool = True,
+        attrs: GlobalAttributes,
+        logical_name: Optional[str] = None,
     ) -> Tuple[str, Operation, IrType]:
         element_type = self.torch_dtype_to_iree_type(t.dtype)
         with self.global_ip, Location.unknown():
             tensor_type = RankedTensorType.get(list(t.shape), element_type)
-            attrs = {
+            ir_attrs = {
                 "sym_name": StringAttr.get(symbol_name),
                 "sym_visibility": StringAttr.get("private"),
                 "type": TypeAttr.get(tensor_type),
             }
-            if noinline:
-                attrs["noinline"] = UnitAttr.get()
-            if mutable:
-                attrs["is_mutable"] = UnitAttr.get()
-            if initialize:
+            if attrs.noinline:
+                ir_attrs["noinline"] = UnitAttr.get()
+            if attrs.mutable:
+                ir_attrs["is_mutable"] = UnitAttr.get()
+            if attrs.initialize:
                 detached_tensor = t.detach().contiguous().cpu()
                 array = np.array(detached_tensor)
                 # We know that a Numpy array is a ReadableBuffer so ignore type error.
@@ -211,9 +262,19 @@
                 # TODO: Add resource elements to Python API and use that.
                 # See: https://github.com/nod-ai/SHARK-Turbine/issues/137
                 elements_attr = DenseElementsAttr.get(contents, type=tensor_type)
-                attrs["initial_value"] = elements_attr
+                ir_attrs["initial_value"] = elements_attr
+            elif attrs.external:
+                external_scope_attr = StringAttr.get(attrs.external_scope or "model")
+                external_name = attrs.map_name(
+                    logical_name if logical_name is not None else symbol_name
+                )
+                external_name_attr = StringAttr.get(external_name)
+                # TODO: Have real Python builders for this.
+                ir_attrs["initial_value"] = Attribute.parse(
+                    f"#stream.parameter.named<{external_scope_attr}::{external_name_attr}> : {tensor_type}"
+                )
 
-            global_op = Operation.create("util.global", attributes=attrs)
+            global_op = Operation.create("util.global", attributes=ir_attrs)
             self.symbol_table.insert(global_op)
             actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value
             return actual_symbol_name, global_op, tensor_type
@@ -223,22 +284,21 @@ def create_typed_global(
         self,
         symbol_name: str,
         global_type: IrType,
         *,
-        mutable: bool = False,
-        initialize: bool = True,
-        noinline: bool = True,
+        attrs: GlobalAttributes,
+        logical_name: Optional[str] = None,
     ) -> Tuple[str, Operation]:
         with self.global_ip, Location.unknown():
-            attrs = {
+            ir_attrs = {
                 "sym_name": StringAttr.get(symbol_name),
                 "sym_visibility": StringAttr.get("private"),
                 "type": TypeAttr.get(global_type),
             }
-            if noinline:
-                attrs["noinline"] = UnitAttr.get()
-            if mutable:
-                attrs["is_mutable"] = UnitAttr.get()
+            if attrs.noinline:
+                ir_attrs["noinline"] = UnitAttr.get()
+            if attrs.mutable:
+                ir_attrs["is_mutable"] = UnitAttr.get()
 
-            global_op = Operation.create("util.global", attributes=attrs)
+            global_op = Operation.create("util.global", attributes=ir_attrs)
             self.symbol_table.insert(global_op)
             actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value
             return actual_symbol_name, global_op
diff --git a/python/shark_turbine/aot/support/procedural/globals.py b/python/shark_turbine/aot/support/procedural/globals.py
index 2d4e96202..be882eb59 100644
--- a/python/shark_turbine/aot/support/procedural/globals.py
+++ b/python/shark_turbine/aot/support/procedural/globals.py
@@ -9,8 +9,10 @@
 from typing import (
     Any,
+    Callable,
     Dict,
     Generator,
+    Optional,
     Sequence,
     Tuple,
 )
 
@@ -25,6 +27,7 @@
 )
 
 from ..ir_utils import (
+    GlobalAttributes,
     ModuleBuilder,
 )
 
@@ -87,13 +90,11 @@ class GlobalsDef:
     """Base class for all exporting descriptors."""
 
     __slots__ = [
-        "_initialize",
-        "_mutable",
+        "_attrs",
     ]
 
-    def __init__(self, *, initialize: bool, mutable: bool):
-        self._initialize = initialize
-        self._mutable = mutable
+    def __init__(self, attrs: GlobalAttributes):
+        self._attrs = attrs
 
     def items(self) -> Generator[Tuple[str, Any], None, None]:
         """Yields tuples of name/value exports."""
@@ -124,8 +125,8 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
                 ) = module_builder.create_tensor_global(
                     f"_{fq_name}",
                     value,
-                    initialize=self._initialize,
-                    mutable=self._mutable,
+                    attrs=self._attrs,
+                    logical_name=fq_name,
                 )
                 mapping.value = IrGlobalTensor(
                     fq_name,
@@ -140,11 +141,14 @@
                 continue
             elif isinstance(value, AbstractTensor):
                 global_type = value.get_ir_type(module_builder)
-                (actual_symbol_name, global_op,) = module_builder.create_typed_global(
+                (
+                    actual_symbol_name,
+                    global_op,
+                ) = module_builder.create_typed_global(
                     f"_{fq_name}",
                     global_type,
-                    initialize=self._initialize,
-                    mutable=self._mutable,
+                    attrs=self._attrs,
+                    logical_name=fq_name,
                 )
                 flat_globals.append(
                     IrGlobalTensor(
@@ -159,11 +163,14 @@
                 continue
             elif isinstance(value, AbstractScalar):
                 global_type = value.get_ir_type(module_builder)
-                (actual_symbol_name, global_op,) = module_builder.create_typed_global(
+                (
+                    actual_symbol_name,
+                    global_op,
+                ) = module_builder.create_typed_global(
                     f"_{fq_name}",
                     global_type,
-                    initialize=self._initialize,
-                    mutable=self._mutable,
+                    attrs=self._attrs,
+                    logical_name=fq_name,
                 )
                 flat_globals.append(
                     IrGlobalScalar(
diff --git a/python/shark_turbine/aot/support/utils.py b/python/shark_turbine/aot/support/utils.py
index edeec2ef9..f2b29962e 100644
--- a/python/shark_turbine/aot/support/utils.py
+++ b/python/shark_turbine/aot/support/utils.py
@@ -30,6 +30,7 @@
 # Reference mapping
 ###############################################################################
 
+
 # Opaque value to indicate something is empty. Used in cases where 'None'
 # may have a different meaning.
 class EmptyType:
diff --git a/tests/aot/globals_test.py b/tests/aot/globals_test.py
index 07618952c..d1486becc 100644
--- a/tests/aot/globals_test.py
+++ b/tests/aot/globals_test.py
@@ -40,12 +40,8 @@ def run(self, x=AbstractTensor(128, 20)):
         inst = GlobalModule(context=Context())
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn(
-            "util.global private @_params.classifier.weight", module_str
-        )
-        self.assertIn(
-            "util.global private @_params.classifier.bias", module_str
-        )
+        self.assertIn("util.global private @_params.classifier.weight", module_str)
+        self.assertIn("util.global private @_params.classifier.bias", module_str)
 
     def testGlobalLoadFromPyTree(self):
         m = SimpleParams()
@@ -104,12 +100,8 @@ def update_params(me, updates=abstractify(params)):
         inst = GlobalModule(context=Context())
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn(
-            "util.global.store %arg0, @_params.classifier.weight", module_str
-        )
-        self.assertIn(
-            "util.global.store %arg1, @_params.classifier.bias", module_str
-        )
+        self.assertIn("util.global.store %arg0, @_params.classifier.weight", module_str)
+        self.assertIn("util.global.store %arg1, @_params.classifier.bias", module_str)
 
     def testGlobalStoreFromLeaf(self):
         m = SimpleParams()
@@ -117,17 +109,13 @@ def testGlobalStoreFromLeaf(self):
         class GlobalModule(CompiledModule):
             params = export_parameters(m, initialize=False, mutable=True)
 
-            def update_bias(
-                self, new_bias=abstractify(params["classifier.bias"])
-            ):
+            def update_bias(self, new_bias=abstractify(params["classifier.bias"])):
                 self.params["classifier.bias"] = new_bias
 
         inst = GlobalModule(context=Context())
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn(
-            "util.global.store %arg0, @_params.classifier.bias", module_str
-        )
+        self.assertIn("util.global.store %arg0, @_params.classifier.bias", module_str)
 
     def testExportSingleGlobalTensor(self):
         state_example = torch.randn(3, 11)
@@ -142,9 +130,7 @@ def read_state(self):
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
         self.assertIn("util.global private @_state0.global", module_str)
-        self.assertIn(
-            "%_state0.global = util.global.load @_state0.global", module_str
-        )
+        self.assertIn("%_state0.global = util.global.load @_state0.global", module_str)
         self.assertIn("return %_state0.global", module_str)
 
     def testExportTreeGlobalTensors(self):
@@ -170,18 +156,10 @@ def read_state(self):
         self.assertIn("util.global private @_state0.seq.1", module_str)
         self.assertIn("util.global private @_state0.seq.2", module_str)
         self.assertIn("util.global private @_state0.data", module_str)
-        self.assertIn(
-            "%_state0.data = util.global.load @_state0.data", module_str
-        )
-        self.assertIn(
-            "%_state0.seq.0 = util.global.load @_state0.seq.0", module_str
-        )
-        self.assertIn(
-            "%_state0.seq.1 = util.global.load @_state0.seq.1", module_str
-        )
-        self.assertIn(
-            "%_state0.seq.2 = util.global.load @_state0.seq.2", module_str
-        )
+        self.assertIn("%_state0.data = util.global.load @_state0.data", module_str)
+        self.assertIn("%_state0.seq.0 = util.global.load @_state0.seq.0", module_str)
+        self.assertIn("%_state0.seq.1 = util.global.load @_state0.seq.1", module_str)
+        self.assertIn("%_state0.seq.2 = util.global.load @_state0.seq.2", module_str)
         self.assertIn(
             "return %_state0.data, %_state0.seq.0, %_state0.seq.1, %_state0.seq.2",
             module_str,
@@ -198,9 +176,7 @@ def testUpdateGlobalStateTree(self):
         }
 
         class SingleState(CompiledModule):
-            state0 = export_global_tree(
-                state_example, mutable=True, initialize=False
-            )
+            state0 = export_global_tree(state_example, mutable=True, initialize=False)
 
             def read_state(self, updates=abstractify(state_example)):
                 self.state0 = updates
@@ -222,9 +198,7 @@ def testTensorUpdateGlobal(self):
         update_example = torch.randn(1, 20)
 
         class UpdateState(CompiledModule):
-            state0 = export_global(
-                state_example, mutable=True, initialize=False
-            )
+            state0 = export_global(state_example, mutable=True, initialize=False)
 
             def tensor_update_state(self, update=abstractify(update_example)):
                 return IREE.tensor_update(self.state0, update, 0, 0)
@@ -242,9 +216,7 @@ def testTensorUpdateGlobalReturnNone(self):
         update_example = torch.randn(1, 1, 4)
 
         class UpdateState(CompiledModule):
-            state0 = export_global(
-                state_example, mutable=True, initialize=False
-            )
+            state0 = export_global(state_example, mutable=True, initialize=False)
 
             def tensor_update_state(self, update=abstractify(update_example)):
                 thing = []
@@ -259,6 +231,85 @@ def tensor_update_state(self, update=abstractify(update_example)):
             module_str,
         )
 
+    def testExternalGlobalParametersDefaults(self):
+        m = SimpleParams()
+
+        class GlobalModule(
+            CompiledModule, export_name="external_global_parameters_defaults"
+        ):
+            params = export_parameters(m, external=True)
+            compute = jittable(m.forward)
+
+            def run(self, x=AbstractTensor(128, 20)):
+                return self.compute(x)
+
+        inst = GlobalModule(context=Context())
+        module_str = str(CompiledModule.get_mlir_module(inst))
+        print(module_str)
+        self.assertIn(
+            '#stream.parameter.named<"model"::"params.classifier.weight"> : tensor<30x20xf32>',
+            module_str,
+        )
+        self.assertIn(
+            '#stream.parameter.named<"model"::"params.classifier.bias"> : tensor<30xf32>',
+            module_str,
+        )
+
+    def testExternalGlobalParametersExplicit(self):
+        m = SimpleParams()
+
+        class GlobalModule(
+            CompiledModule, export_name="external_global_parameters_explicit"
+        ):
+            params = export_parameters(
+                m, external=True, external_scope="foo", name_mapper=lambda s: s.upper()
+            )
+            compute = jittable(m.forward)
+
+            def run(self, x=AbstractTensor(128, 20)):
+                return self.compute(x)
+
+        inst = GlobalModule(context=Context())
+        module_str = str(CompiledModule.get_mlir_module(inst))
+        print(module_str)
+        self.assertIn(
+            '#stream.parameter.named<"foo"::"PARAMS.CLASSIFIER.WEIGHT"> : tensor<30x20xf32>',
+            module_str,
+        )
+        self.assertIn(
+            '#stream.parameter.named<"foo"::"PARAMS.CLASSIFIER.BIAS"> : tensor<30xf32>',
+            module_str,
+        )
+
+    def testExternalGlobalParametersMapDict(self):
+        m = SimpleParams()
+        mapper = {
+            "params.classifier.weight": "WEIGHT",
+        }
+
+        class GlobalModule(
+            CompiledModule, export_name="external_global_parameters_map_dict"
+        ):
+            params = export_parameters(
+                m, external=True, external_scope="foo", name_mapper=mapper.get
+            )
+            compute = jittable(m.forward)
+
+            def run(self, x=AbstractTensor(128, 20)):
+                return self.compute(x)
+
+        inst = GlobalModule(context=Context())
+        module_str = str(CompiledModule.get_mlir_module(inst))
+        print(module_str)
+        self.assertIn(
+            '#stream.parameter.named<"foo"::"WEIGHT"> : tensor<30x20xf32>',
+            module_str,
+        )
+        self.assertIn(
+            '#stream.parameter.named<"foo"::"params.classifier.bias"> : tensor<30xf32>',
+            module_str,
+        )
+
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
diff --git a/tests/aot/iree_procedural_test.py b/tests/aot/iree_procedural_test.py
index 3ec2f159a..3069e4d29 100644
--- a/tests/aot/iree_procedural_test.py
+++ b/tests/aot/iree_procedural_test.py
@@ -69,7 +69,10 @@ def foobar(self, x=AbstractTensor(None), y=AbstractTensor(3)):
         inst = BasicModule(context=Context(), import_to=None)
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn('flow.tensor.trace "DEBUG" = [%arg0 : tensor<?xf32>{%dim}, %arg1 : tensor<3xf32>]', module_str)
+        self.assertIn(
+            'flow.tensor.trace "DEBUG" = [%arg0 : tensor<?xf32>{%dim}, %arg1 : tensor<3xf32>]',
+            module_str,
+        )
 
     def testStoreDynamic(self):
         class BasicModule(CompiledModule):
diff --git a/tests/aot/jittable_test.py b/tests/aot/jittable_test.py
index b8599827d..739bb9222 100644
--- a/tests/aot/jittable_test.py
+++ b/tests/aot/jittable_test.py
@@ -42,9 +42,7 @@ def compute():
 
     def testCallWithStructure(self):
         class ProcArgsModule(CompiledModule):
-            def call_with_dicts(
-                self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)
-            ):
+            def call_with_dicts(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)):
                 intermediate = self.compute({"a": a, "b": b})
                 return self.compute(intermediate)
 
@@ -61,9 +59,7 @@ def compute(struct):
 
     def testCallWithArgsKwargs(self):
         class ProcArgsModule(CompiledModule):
-            def call_with_kwargs(
-                self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)
-            ):
+            def call_with_kwargs(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)):
                 intermediate = self.compute(**{"a": a, "b": b})
                 return self.compute(**intermediate)
 
@@ -78,9 +74,7 @@ def compute(*, a, b):
 
     def testDynamicDims(self):
         class ProcArgsModule(CompiledModule):
-            def dynamic_dim(
-                self, a=AbstractTensor(None, 2), b=AbstractTensor(None, 1)
-            ):
+            def dynamic_dim(self, a=AbstractTensor(None, 2), b=AbstractTensor(None, 1)):
                 return self.compute(
                     a,
                     b,
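Usage sketch: the tests above exercise the new external-parameter path end to end; the snippet below condenses testExternalGlobalParametersExplicit into a standalone program. It is a minimal sketch, not part of the patch: the shark_turbine.aot and iree.compiler.ir import paths are assumed from the test suite's conventions, and SimpleParams is a hypothetical stand-in for the test fixture (the asserted tensor<30x20xf32> weight and tensor<30xf32> bias imply nn.Linear(20, 30)).

import torch.nn as nn

from iree.compiler.ir import Context  # assumed: the Context the tests construct

# Assumed public re-exports; the tests in this diff use these same names.
from shark_turbine.aot import (
    AbstractTensor,
    CompiledModule,
    export_parameters,
    jittable,
)


class SimpleParams(nn.Module):
    # Hypothetical stand-in for the test fixture; shapes match the assertions.
    def __init__(self):
        super().__init__()
        self.classifier = nn.Linear(20, 30)

    def forward(self, x):
        return self.classifier(x)


m = SimpleParams()


class ExternModule(CompiledModule):
    # external=True emits #stream.parameter.named<scope::name> initial values
    # instead of inlining tensor constants; name_mapper rewrites each logical
    # name (returning None keeps the original, so a dict's .get also works).
    params = export_parameters(
        m, external=True, external_scope="foo", name_mapper=lambda s: s.upper()
    )
    compute = jittable(m.forward)

    def run(self, x=AbstractTensor(128, 20)):
        return self.compute(x)


# Prints a module whose globals reference external parameters, e.g.
# #stream.parameter.named<"foo"::"PARAMS.CLASSIFIER.WEIGHT"> : tensor<30x20xf32>
inst = ExternModule(context=Context())
print(CompiledModule.get_mlir_module(inst))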