Commit afc8358

Refactor FxImporter to new location.

Parent: 4117974

File tree

9 files changed: +149 -101 lines

python/shark_turbine/aot/builtins/jittable.py

Lines changed: 5 additions & 5 deletions

@@ -22,15 +22,15 @@
 )
 from torch.fx.passes.shape_prop import TensorMetadata
 
-from ...dynamo.importer import (
-    GraphNodeImporter,
-    FxImporter,
-)
-
 from ...dynamo.passes import (
     DEFAULT_DECOMPOSITIONS,
 )
 
+from ...importers.fx_importer import (
+    GraphNodeImporter,
+    FxImporter,
+)
+
 from ..passes import (
     functorch_functionalize,
 )

python/shark_turbine/aot/support/ir_utils.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 import numpy as np
 import torch
 
-from ...dynamo.importer import (
+from ...importers.fx_importer import (
     ContextCache,
     TORCH_DTYPE_TO_MLIR_TYPE_ASM,
 )

python/shark_turbine/importers/README.md

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+# Importers from various systems
+
+This directory is self-contained and intended to be shared with other
+projects, with its source-of-truth in torch-mlir.
+
+All MLIR API dependencies must route through the relative `ir.py`, which
+sub-projects are expected to customize accordingly.

python/shark_turbine/dynamo/importer.py renamed to python/shark_turbine/importers/fx_importer.py

Lines changed: 55 additions & 92 deletions
@@ -3,44 +3,13 @@
 # Licensed under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-import builtins
 import logging
 import operator
 import re
 from types import NoneType, BuiltinMethodType, BuiltinFunctionType
 from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
 import numpy as np
 
-from iree.compiler.ir import (
-    Attribute as MlirAttribute,
-    Block,
-    Context,
-    FloatAttr,
-    BF16Type,
-    ComplexType,
-    F16Type,
-    F32Type,
-    F64Type,
-    FunctionType,
-    InsertionPoint,
-    IntegerAttr,
-    IntegerType,
-    RankedTensorType,
-    Location,
-    Module,
-    Operation,
-    StringAttr,
-    Type as MlirType,
-    Value,
-    DenseResourceElementsAttr,
-)
-
-import iree.compiler.dialects.func as func_dialect
-from iree.compiler.ir import SymbolTable
-
-# import iree.compiler.dialects.torch as torch_dialect
-
-
 import torch
 import torch.fx as torch_fx
 from torch.fx.passes.shape_prop import TensorMetadata
@@ -67,6 +36,36 @@
     Argument as NodeArgument,
 )
 
+from .ir import (
+    Attribute,
+    Block,
+    Context,
+    DenseResourceElementsAttr,
+    FloatAttr,
+    BF16Type,
+    ComplexType,
+    F16Type,
+    F32Type,
+    F64Type,
+    FunctionType,
+    InsertionPoint,
+    IntegerAttr,
+    IntegerType,
+    RankedTensorType,
+    Location,
+    Module,
+    Operation,
+    StringAttr,
+    SymbolTable,
+    IrType,
+    Value,
+    func_dialect,
+)
+
+from .utils import (
+    TypeSubclassMap,
+)
+
 __all__ = [
     "FxImporter",
 ]
@@ -100,7 +99,7 @@
     torch.complex128: "complex<f64>",
 }
 
-TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], MlirType]] = {
+TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = {
     torch.float16: lambda: F16Type.get(),
     torch.bfloat16: lambda: BF16Type.get(),
     torch.float32: lambda: F32Type.get(),
@@ -313,7 +312,7 @@ def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
                 for result_node in node.args[0]:
                     if result_node is None:
                         result_types.append(
-                            MlirType.parse("!torch.none", context=self._c)
+                            IrType.parse("!torch.none", context=self._c)
                         )
                     else:
                         result_types.append(self._cc.node_val_to_type(result_node))
@@ -341,19 +340,19 @@ class ContextCache:
 
     def __init__(self, context: Context):
         self._c = context
-        self._dtype_to_type: Dict[TorchDtype, MlirType] = {}
-        self._tensor_metadata_cache: Dict[Tuple[torch.Size, torch.dtype], MlirType] = {}
+        self._dtype_to_type: Dict[TorchDtype, IrType] = {}
+        self._tensor_metadata_cache: Dict[Tuple[torch.Size, torch.dtype], IrType] = {}
 
         # Common types.
         with context:
-            self.torch_bool_type = MlirType.parse("!torch.bool")
-            self.torch_float_type = MlirType.parse("!torch.float")
-            self.torch_int_type = MlirType.parse("!torch.int")
-            self.torch_none_type = MlirType.parse("!torch.none")
-            self.torch_str_type = MlirType.parse("!torch.str")
-            self.torch_device_type = MlirType.parse("!torch.Device")
-
-    def integer_attr(self, value: int, bits: int) -> MlirAttribute:
+            self.torch_bool_type = IrType.parse("!torch.bool")
+            self.torch_float_type = IrType.parse("!torch.float")
+            self.torch_int_type = IrType.parse("!torch.int")
+            self.torch_none_type = IrType.parse("!torch.none")
+            self.torch_str_type = IrType.parse("!torch.str")
+            self.torch_device_type = IrType.parse("!torch.Device")
+
+    def integer_attr(self, value: int, bits: int) -> Attribute:
         c = self._c
         return IntegerAttr.get(IntegerType.get_signless(bits, c), value)
 
@@ -362,16 +361,16 @@ def integer_attr(self, value: int, bits: int) -> MlirAttribute:
     def format_asm_shape(self, shape: torch.Size) -> str:
         return ",".join("?" if is_symbolic(d) else str(d) for d in list(shape))
 
-    """Return MlirType for !torch.vtensor with the given shape and dtype"""
+    """Return IrType for !torch.vtensor with the given shape and dtype"""
 
     def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype):
         shape_asm = self.format_asm_shape(shape)
         mlir_dtype = str(self.dtype_to_type(dtype))
-        return MlirType.parse(
+        return IrType.parse(
             f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
         )
 
-    def node_val_to_type(self, node: torch_fx.Node) -> MlirType:
+    def node_val_to_type(self, node: torch_fx.Node) -> IrType:
         try:
             tensor_meta = node.meta.get("tensor_meta")
             val = node.meta.get("val")
@@ -393,7 +392,7 @@ def node_val_to_type(self, node: torch_fx.Node) -> MlirType:
 
             t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
             if t is not None:
-                return MlirType.parse(t, self._c)
+                return IrType.parse(t, self._c)
 
             raise NotImplementedError(
                 f"FIXME: Unsupported placeholder node (this often indicates that a necessary) "
@@ -404,7 +403,7 @@ def node_val_to_type(self, node: torch_fx.Node) -> MlirType:
                 f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
             )
 
-    def tensor_metadata_to_type(self, tm: TensorMetadata) -> MlirType:
+    def tensor_metadata_to_type(self, tm: TensorMetadata) -> IrType:
         tm_shape = tuple(
             item.node if is_symbolic(item) else item for item in list(tm.shape)
         )
@@ -416,20 +415,20 @@ def tensor_metadata_to_type(self, tm: TensorMetadata) -> MlirType:
             self._tensor_metadata_cache[key] = t
         return t
 
-    def dtype_to_type(self, dtype: TorchDtype) -> MlirType:
+    def dtype_to_type(self, dtype: TorchDtype) -> IrType:
         t = self._dtype_to_type.get(dtype)
         if t is None:
             try:
                 asm = TORCH_DTYPE_TO_MLIR_TYPE_ASM[dtype]
             except IndexError:
                 raise ValueError(f"Unknown conversion from {dtype} to IREE type")
-            t = MlirType.parse(asm, self._c)
+            t = IrType.parse(asm, self._c)
             self._dtype_to_type[dtype] = t
         return t
 
-    def tensor_to_vtensor_type(self, tensor: torch.Tensor) -> MlirType:
+    def tensor_to_vtensor_type(self, tensor: torch.Tensor) -> IrType:
         dtype_asm = str(self.dtype_to_type(tensor.dtype))
-        return MlirType.parse(f"!torch.vtensor<{list(tensor.size())},{dtype_asm}>")
+        return IrType.parse(f"!torch.vtensor<{list(tensor.size())},{dtype_asm}>")
 
     def get_node_location(self, node: torch_fx.Node) -> Optional[Location]:
         stack_trace = node.meta.get("stack_trace")
@@ -844,7 +843,7 @@ def _import_list_argument(
         else:
             list_type = PY_TYPE_TO_TORCH_LIST_TYPE[element_type]
 
-        result_type = MlirType.parse(list_type, context=self._c)
+        result_type = IrType.parse(list_type, context=self._c)
         operation = Operation.create(
             "torch.prim.ListConstruct",
             results=[result_type],
@@ -869,44 +868,8 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value:
         return cvt(arg, self, self._cc)
 
 
-class TypeSubclassMap:
-    """Mapping of super-types to values.
-
-    Maintains a cache of actual types seen and uses that instead of a linear
-    scan.
-    """
-
-    __slots__ = [
-        "_cache",
-        "_mapping",
-    ]
-
-    def __init__(self):
-        # The linear list of converters.
-        self._mapping: List[Tuple[type, Any]] = []
-        # When there is a hit on the linear mapping, memoize it here.
-        self._cache: Dict[type, Any] = {}
-
-    def map(self, t: type, value: Any):
-        self._mapping.append((t, value))
-        self._cache[t] = value
-
-    def lookup(self, t: type) -> Any:
-        try:
-            return self._cache[t]
-        except KeyError:
-            pass
-        for t_super, value in self._mapping:
-            if issubclass(t, t_super):
-                self._cache[t] = value
-                return value
-        else:
-            self._cache[t] = None
-            return None
-
-
 def _make_constant_op(
-    op_name: str, value_attr: MlirAttribute, result_type: Optional[MlirType] = None
+    op_name: str, value_attr: Attribute, result_type: Optional[IrType] = None
 ) -> Operation:
     return Operation.create(
         op_name,
@@ -915,7 +878,7 @@ def _make_constant_op(
     )
 
 
-def create_mlir_tensor_type(tensor: torch.Tensor) -> MlirType:
+def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
     try:
         dtype = tensor.dtype
         element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
@@ -925,7 +888,7 @@ def create_mlir_tensor_type(tensor: torch.Tensor) -> MlirType:
         raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type")
 
 
-def _make_vtensor_literal_op(tensor: torch.Tensor, vtensor_type: MlirType) -> Operation:
+def _make_vtensor_literal_op(tensor: torch.Tensor, vtensor_type: IrType) -> Operation:
     npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype)
     assert (
         npy_dtype is not None

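Since only the module path changes, callers migrate by updating their imports. Below is a minimal usage sketch at the new location, not part of this commit; it assumes the import_graph_module() entry point and module property this file already defines, plus an FX graph that carries tensor metadata (torch._dynamo.export populates node meta; bare symbolic_trace does not).

import torch
import torch._dynamo as dynamo

from shark_turbine.importers.fx_importer import FxImporter


class AddMul(torch.nn.Module):
    def forward(self, a, b):
        return a * (a + b)


# Capture a metadata-carrying fx.GraphModule, then import it to torch-dialect
# MLIR. Constructor/context defaults are assumptions, not fixed by this diff.
gm, _guards = dynamo.export(AddMul(), torch.randn(4), torch.randn(4))
importer = FxImporter()
importer.import_graph_module(gm)
print(importer.module)
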
python/shark_turbine/importers/ir.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+# Copyright 2023 Nod Labs, Inc
+# Portions Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from iree.compiler.ir import (
+    Attribute as Attribute,
+    Block,
+    Context,
+    DenseResourceElementsAttr,
+    FloatAttr,
+    BF16Type,
+    ComplexType,
+    F16Type,
+    F32Type,
+    F64Type,
+    FunctionType,
+    InsertionPoint,
+    IntegerAttr,
+    IntegerType,
+    RankedTensorType,
+    Location,
+    Module,
+    Operation,
+    StringAttr,
+    SymbolTable,
+    Type as IrType,
+    Value,
+)
+
+from iree.compiler.dialects import (
+    func as func_dialect,
+)
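
The new ir.py is the single aliasing point the README above calls for. As a sketch of the intended customization, a sub-project built on torch-mlir's own Python bindings (package torch_mlir) rather than iree.compiler could ship this hypothetical variant, while every other importer file keeps doing `from .ir import ...` unchanged:

# Hypothetical sub-project ir.py: only the source package changes.
from torch_mlir.ir import (
    Attribute as Attribute,
    Block,
    Context,
    DenseResourceElementsAttr,
    FloatAttr,
    BF16Type,
    ComplexType,
    F16Type,
    F32Type,
    F64Type,
    FunctionType,
    InsertionPoint,
    IntegerAttr,
    IntegerType,
    RankedTensorType,
    Location,
    Module,
    Operation,
    StringAttr,
    SymbolTable,
    Type as IrType,
    Value,
)

from torch_mlir.dialects import (
    func as func_dialect,
)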

python/shark_turbine/importers/utils.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+# Copyright 2023 Nod Labs, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Any, Dict, List, Tuple
+
+
+class TypeSubclassMap:
+    """Mapping of super-types to values.
+
+    Maintains a cache of actual types seen and uses that instead of a linear
+    scan.
+    """
+
+    __slots__ = [
+        "_cache",
+        "_mapping",
+    ]
+
+    def __init__(self):
+        # The linear list of converters.
+        self._mapping: List[Tuple[type, Any]] = []
+        # When there is a hit on the linear mapping, memoize it here.
+        self._cache: Dict[type, Any] = {}
+
+    def map(self, t: type, value: Any):
+        self._mapping.append((t, value))
+        self._cache[t] = value
+
+    def lookup(self, t: type) -> Any:
+        try:
+            return self._cache[t]
+        except KeyError:
+            pass
+        for t_super, value in self._mapping:
+            if issubclass(t, t_super):
+                self._cache[t] = value
+                return value
+        else:
+            self._cache[t] = None
+            return None

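TypeSubclassMap moves here verbatim from the importer. A short usage sketch (hypothetical converter values; the class API is exactly the map() and lookup() shown above) illustrating the subclass-aware, memoized resolution:

from shark_turbine.importers.utils import TypeSubclassMap

m = TypeSubclassMap()
m.map(int, "int-converter")
m.map(float, "float-converter")

print(m.lookup(int))   # "int-converter": exact hit, cached when mapped
print(m.lookup(bool))  # "int-converter": bool subclasses int; result memoized
print(m.lookup(str))   # None: no mapping; the miss is cached as well
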
tests/dynamo/importer_dynamic_test.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 from torch._export import dynamic_dim
 
 # from torch._export.constraints import constrain_as_size, constrain_as_value
-from shark_turbine.dynamo.importer import FxImporter
+from shark_turbine.importers.fx_importer import FxImporter
 from shark_turbine.dynamo.passes import turbine_cpu_pass_pipeline
 import torch
 import torch._dynamo as dynamo
