Skip to content

Commit

Permalink
Add a pass to rename parameters.
Browse files Browse the repository at this point in the history
The implementation of this pass is really gross because there is missing IREE Python bindings for introspecting the #stream.parameter.named attribute. I'll clean that up later once I get upstream in better shape. In the meantime, this should work.
  • Loading branch information
stellaraccident committed Nov 4, 2023
1 parent 68b4b3b commit 53c82dd
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 0 deletions.
150 changes: 150 additions & 0 deletions python/shark_turbine/transforms/general/rename_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""This pass will rename any #stream.parameter.named<> attributes on globals.
It can either be used as-is or by sub-classing (i.e. in a model specific
subclass that renames from A->B, etc).
By default, no attributes are touched unless:
* rename_map= has an exact match
* rename_callback= returns a change
"""

from typing import Callable, Dict, List, Optional, Tuple, Union

import re

from iree.compiler.ir import (
Attribute,
Operation,
StringAttr,
)

from ..rewriter import *
from iree.compiler.ir import Context

ScopedName = Tuple[Optional[str], str]
MaybeScopedName = Union[str, ScopedName]


class RenameParametersPass(Pass):
def __init__(
self,
root_op: Operation,
*,
rename_map: Optional[Dict[MaybeScopedName, MaybeScopedName]] = None,
rename_callback: Callable[[Optional[str], str], Optional[ScopedName]] = lambda scope, name: None
):
super().__init__(root_op)
self.context = root_op.context
self.rename_map = rename_map
self.rename_callback = rename_callback
with self.context:
# Make a prototype named parameter attribute so we can get its
# typeid.
self.parameter_attr_typeid = Attribute.parse(
'#stream.parameter.named<""::"">'
).typeid

def run(self):
globals = self.globals
for _, global_op in self.globals.items():
attrs = global_op.op.attributes
try:
initial_value = attrs["initial_value"]
except KeyError:
continue

if initial_value.typeid == self.parameter_attr_typeid:
updated_value = self.remap(initial_value)
if updated_value != initial_value:
attrs["initial_value"] = updated_value

def remap(self, parameter_attr: Attribute) -> Attribute:
comps = _parse_parameter_attr(parameter_attr)
if not comps:
return parameter_attr
if len(comps) == 1:
orig_scope = None
name = comps[0]
else:
orig_scope, name = comps

def norm_map_result(result: MaybeScopedName) -> ScopedName:
if isinstance(result, str):
return orig_scope, result
if len(result) == 1:
return orig_scope, result[0]
else:
return result[0], result[1]

def make_attr(scoped_name: ScopedName) -> Attribute:
if scoped_name[0] is None:
name = StringAttr.get(scoped_name[1])
return Attribute.parse(f"#stream.parameter.named<{name}> : {parameter_attr.type}")
else:
scope = StringAttr.get(scoped_name[0])
name = StringAttr.get(scoped_name[1])
return Attribute.parse(f"#stream.parameter.named<{scope}::{name}> : {parameter_attr.type}")

# Check the rename map.
# Check with a fully-qualified name.
result = self.rename_map.get((orig_scope, name))
if result is not None:
return make_attr(norm_map_result(result))
# Check with just the
result = self.rename_map.get(name)
if result is not None:
return make_attr(norm_map_result(result))

# Check the callback.
result = self.rename_callback(orig_scope, name)
if result is not None:
return make_attr(result)

return parameter_attr


def _parse_parameter_attr(attr: Attribute) -> Optional[List[str]]:
# Returns one of:
# None if failed to parse of not a simple named parameter without attributes.
# [name] for names with a default scope
# [scope, name] for scoped names
# TODO: Burn this with fire. We should add Python bindings for these attributes
# vs string parsing them.
# TODO: The parameter attribute correctly parses C escapes but prints unescaped :(
asm = str(attr)
PREFIX = "#stream.parameter.named<"
STR_PATTERN = re.compile(r'"(.*?)(?!\\)"')
if asm.startswith(PREFIX):
asm = asm[len(PREFIX) :]
results = []
# Parse a str
m = STR_PATTERN.search(asm)
if not m or m.start() != 0:
return None
results.append(m.group(1))
asm = asm[m.end() :]
# Parse ::
if asm.startswith("::"):
asm = asm[2:]
else:
return results
# Parse a str
m = STR_PATTERN.search(asm)
if not m or m.start() != 0:
return None
results.append(m.group(1))
asm = asm[m.end() :]
if not asm.startswith(">"):
return None
return results


if __name__ == "__main__":
pass_main(RenameParametersPass)
63 changes: 63 additions & 0 deletions tests/transforms/general/rename_parameters_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2023 Nod Labs, Inc
# Portions Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path
import logging
import unittest

from iree.compiler.ir import (
Context,
Operation,
)

from shark_turbine.transforms import rewriter
from shark_turbine.transforms.general import rename_parameters

SIMPLE_GLOBALS_ASM = r"""
module {
util.global private @_params.classifier.default {noinline} = #stream.parameter.named<"default"> : tensor<30xf32>
util.global private @_params.classifier.weight {noinline} = #stream.parameter.named<"foo"::"WEIGHT"> : tensor<30x20xf32>
util.global private @_params.classifier.bias {noinline} = #stream.parameter.named<"foo"::"params.classifier.bias"> : tensor<30xf32>
util.global private @_params.classifier.other {noinline} = dense<0.0> : tensor<30xf32>
util.global private @_uninitialized {noinline} : tensor<30xf32>
}
"""


class RenameTest(unittest.TestCase):
def testBasic(self):
with Context() as context:
module_op = Operation.parse(SIMPLE_GLOBALS_ASM)
rename_parameters.RenameParametersPass(
module_op,
rename_map={
"WEIGHT": "weight",
("foo", "params.classifier.bias"): ("bar", "BIAS"),
},
rename_callback=lambda scope, name: ("XXX", "YYY")
if name == "default"
else None,
).run()
module_asm = str(module_op)
print(module_asm)
self.assertIn(
'@_params.classifier.default {noinline} = #stream.parameter.named<"XXX"::"YYY"> : tensor<30xf32>',
module_asm,
)
self.assertIn(
'@_params.classifier.weight {noinline} = #stream.parameter.named<"foo"::"weight"> : tensor<30x20xf32>',
module_asm,
)
self.assertIn(
'@_params.classifier.bias {noinline} = #stream.parameter.named<"bar"::"BIAS"> : tensor<30xf32>',
module_asm,
)


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()

0 comments on commit 53c82dd

Please sign in to comment.