3
3
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4
4
# See https://llvm.org/LICENSE.txt for license information.
5
5
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
- import builtins
7
6
import logging
8
7
import operator
9
8
import re
10
9
from types import NoneType , BuiltinMethodType , BuiltinFunctionType
11
10
from typing import Any , Callable , Dict , List , Optional , Sequence , Set , Tuple , Union
12
11
import numpy as np
13
12
14
- from iree .compiler .ir import (
15
- Attribute as MlirAttribute ,
16
- Block ,
17
- Context ,
18
- FloatAttr ,
19
- BF16Type ,
20
- ComplexType ,
21
- F16Type ,
22
- F32Type ,
23
- F64Type ,
24
- FunctionType ,
25
- InsertionPoint ,
26
- IntegerAttr ,
27
- IntegerType ,
28
- RankedTensorType ,
29
- Location ,
30
- Module ,
31
- Operation ,
32
- StringAttr ,
33
- Type as MlirType ,
34
- Value ,
35
- DenseResourceElementsAttr ,
36
- )
37
-
38
- import iree .compiler .dialects .func as func_dialect
39
- from iree .compiler .ir import SymbolTable
40
-
41
- # import iree.compiler.dialects.torch as torch_dialect
42
-
43
-
44
13
import torch
45
14
import torch .fx as torch_fx
46
15
from torch .fx .passes .shape_prop import TensorMetadata
67
36
Argument as NodeArgument ,
68
37
)
69
38
39
+ from .ir import (
40
+ Attribute ,
41
+ Block ,
42
+ Context ,
43
+ DenseResourceElementsAttr ,
44
+ FloatAttr ,
45
+ BF16Type ,
46
+ ComplexType ,
47
+ F16Type ,
48
+ F32Type ,
49
+ F64Type ,
50
+ FunctionType ,
51
+ InsertionPoint ,
52
+ IntegerAttr ,
53
+ IntegerType ,
54
+ RankedTensorType ,
55
+ Location ,
56
+ Module ,
57
+ Operation ,
58
+ StringAttr ,
59
+ SymbolTable ,
60
+ IrType ,
61
+ Value ,
62
+ func_dialect ,
63
+ )
64
+
65
+ from .utils import (
66
+ TypeSubclassMap ,
67
+ )
68
+
70
69
__all__ = [
71
70
"FxImporter" ,
72
71
]
100
99
torch .complex128 : "complex<f64>" ,
101
100
}
102
101
103
- TORCH_DTYPE_TO_MLIR_TYPE : Dict [torch .dtype , Callable [[], MlirType ]] = {
102
+ TORCH_DTYPE_TO_MLIR_TYPE : Dict [torch .dtype , Callable [[], IrType ]] = {
104
103
torch .float16 : lambda : F16Type .get (),
105
104
torch .bfloat16 : lambda : BF16Type .get (),
106
105
torch .float32 : lambda : F32Type .get (),
@@ -313,7 +312,7 @@ def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
313
312
for result_node in node .args [0 ]:
314
313
if result_node is None :
315
314
result_types .append (
316
- MlirType .parse ("!torch.none" , context = self ._c )
315
+ IrType .parse ("!torch.none" , context = self ._c )
317
316
)
318
317
else :
319
318
result_types .append (self ._cc .node_val_to_type (result_node ))
@@ -341,19 +340,19 @@ class ContextCache:
341
340
342
341
def __init__ (self , context : Context ):
343
342
self ._c = context
344
- self ._dtype_to_type : Dict [TorchDtype , MlirType ] = {}
345
- self ._tensor_metadata_cache : Dict [Tuple [torch .Size , torch .dtype ], MlirType ] = {}
343
+ self ._dtype_to_type : Dict [TorchDtype , IrType ] = {}
344
+ self ._tensor_metadata_cache : Dict [Tuple [torch .Size , torch .dtype ], IrType ] = {}
346
345
347
346
# Common types.
348
347
with context :
349
- self .torch_bool_type = MlirType .parse ("!torch.bool" )
350
- self .torch_float_type = MlirType .parse ("!torch.float" )
351
- self .torch_int_type = MlirType .parse ("!torch.int" )
352
- self .torch_none_type = MlirType .parse ("!torch.none" )
353
- self .torch_str_type = MlirType .parse ("!torch.str" )
354
- self .torch_device_type = MlirType .parse ("!torch.Device" )
355
-
356
- def integer_attr (self , value : int , bits : int ) -> MlirAttribute :
348
+ self .torch_bool_type = IrType .parse ("!torch.bool" )
349
+ self .torch_float_type = IrType .parse ("!torch.float" )
350
+ self .torch_int_type = IrType .parse ("!torch.int" )
351
+ self .torch_none_type = IrType .parse ("!torch.none" )
352
+ self .torch_str_type = IrType .parse ("!torch.str" )
353
+ self .torch_device_type = IrType .parse ("!torch.Device" )
354
+
355
+ def integer_attr (self , value : int , bits : int ) -> Attribute :
357
356
c = self ._c
358
357
return IntegerAttr .get (IntegerType .get_signless (bits , c ), value )
359
358
@@ -362,16 +361,16 @@ def integer_attr(self, value: int, bits: int) -> MlirAttribute:
362
361
def format_asm_shape (self , shape : torch .Size ) -> str :
363
362
return "," .join ("?" if is_symbolic (d ) else str (d ) for d in list (shape ))
364
363
365
- """Return MlirType for !torch.vtensor with the given shape and dtype"""
364
+ """Return IrType for !torch.vtensor with the given shape and dtype"""
366
365
367
366
def get_vtensor_type (self , shape : torch .Size , dtype : torch .dtype ):
368
367
shape_asm = self .format_asm_shape (shape )
369
368
mlir_dtype = str (self .dtype_to_type (dtype ))
370
- return MlirType .parse (
369
+ return IrType .parse (
371
370
f"!torch.vtensor<[{ shape_asm } ],{ str (mlir_dtype )} >" , context = self ._c
372
371
)
373
372
374
- def node_val_to_type (self , node : torch_fx .Node ) -> MlirType :
373
+ def node_val_to_type (self , node : torch_fx .Node ) -> IrType :
375
374
try :
376
375
tensor_meta = node .meta .get ("tensor_meta" )
377
376
val = node .meta .get ("val" )
@@ -393,7 +392,7 @@ def node_val_to_type(self, node: torch_fx.Node) -> MlirType:
393
392
394
393
t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE .get (type (val ))
395
394
if t is not None :
396
- return MlirType .parse (t , self ._c )
395
+ return IrType .parse (t , self ._c )
397
396
398
397
raise NotImplementedError (
399
398
f"FIXME: Unsupported placeholder node (this often indicates that a necessary) "
@@ -404,7 +403,7 @@ def node_val_to_type(self, node: torch_fx.Node) -> MlirType:
404
403
f"FIXME: Illegal access to torch.fx.Node.meta: { e } ({ node .meta .keys ()} : { node .meta } )"
405
404
)
406
405
407
- def tensor_metadata_to_type (self , tm : TensorMetadata ) -> MlirType :
406
+ def tensor_metadata_to_type (self , tm : TensorMetadata ) -> IrType :
408
407
tm_shape = tuple (
409
408
item .node if is_symbolic (item ) else item for item in list (tm .shape )
410
409
)
@@ -416,20 +415,20 @@ def tensor_metadata_to_type(self, tm: TensorMetadata) -> MlirType:
416
415
self ._tensor_metadata_cache [key ] = t
417
416
return t
418
417
419
- def dtype_to_type (self , dtype : TorchDtype ) -> MlirType :
418
+ def dtype_to_type (self , dtype : TorchDtype ) -> IrType :
420
419
t = self ._dtype_to_type .get (dtype )
421
420
if t is None :
422
421
try :
423
422
asm = TORCH_DTYPE_TO_MLIR_TYPE_ASM [dtype ]
424
423
except IndexError :
425
424
raise ValueError (f"Unknown conversion from { dtype } to IREE type" )
426
- t = MlirType .parse (asm , self ._c )
425
+ t = IrType .parse (asm , self ._c )
427
426
self ._dtype_to_type [dtype ] = t
428
427
return t
429
428
430
- def tensor_to_vtensor_type (self , tensor : torch .Tensor ) -> MlirType :
429
+ def tensor_to_vtensor_type (self , tensor : torch .Tensor ) -> IrType :
431
430
dtype_asm = str (self .dtype_to_type (tensor .dtype ))
432
- return MlirType .parse (f"!torch.vtensor<{ list (tensor .size ())} ,{ dtype_asm } >" )
431
+ return IrType .parse (f"!torch.vtensor<{ list (tensor .size ())} ,{ dtype_asm } >" )
433
432
434
433
def get_node_location (self , node : torch_fx .Node ) -> Optional [Location ]:
435
434
stack_trace = node .meta .get ("stack_trace" )
@@ -844,7 +843,7 @@ def _import_list_argument(
844
843
else :
845
844
list_type = PY_TYPE_TO_TORCH_LIST_TYPE [element_type ]
846
845
847
- result_type = MlirType .parse (list_type , context = self ._c )
846
+ result_type = IrType .parse (list_type , context = self ._c )
848
847
operation = Operation .create (
849
848
"torch.prim.ListConstruct" ,
850
849
results = [result_type ],
@@ -869,44 +868,8 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value:
869
868
return cvt (arg , self , self ._cc )
870
869
871
870
872
- class TypeSubclassMap :
873
- """Mapping of super-types to values.
874
-
875
- Maintains a cache of actual types seen and uses that instead of a linear
876
- scan.
877
- """
878
-
879
- __slots__ = [
880
- "_cache" ,
881
- "_mapping" ,
882
- ]
883
-
884
- def __init__ (self ):
885
- # The linear list of converters.
886
- self ._mapping : List [Tuple [type , Any ]] = []
887
- # When there is a hit on the linear mapping, memoize it here.
888
- self ._cache : Dict [type , Any ] = {}
889
-
890
- def map (self , t : type , value : Any ):
891
- self ._mapping .append ((t , value ))
892
- self ._cache [t ] = value
893
-
894
- def lookup (self , t : type ) -> Any :
895
- try :
896
- return self ._cache [t ]
897
- except KeyError :
898
- pass
899
- for t_super , value in self ._mapping :
900
- if issubclass (t , t_super ):
901
- self ._cache [t ] = value
902
- return value
903
- else :
904
- self ._cache [t ] = None
905
- return None
906
-
907
-
908
871
def _make_constant_op (
909
- op_name : str , value_attr : MlirAttribute , result_type : Optional [MlirType ] = None
872
+ op_name : str , value_attr : Attribute , result_type : Optional [IrType ] = None
910
873
) -> Operation :
911
874
return Operation .create (
912
875
op_name ,
@@ -915,7 +878,7 @@ def _make_constant_op(
915
878
)
916
879
917
880
918
- def create_mlir_tensor_type (tensor : torch .Tensor ) -> MlirType :
881
+ def create_mlir_tensor_type (tensor : torch .Tensor ) -> IrType :
919
882
try :
920
883
dtype = tensor .dtype
921
884
element_type = TORCH_DTYPE_TO_MLIR_TYPE [dtype ]()
@@ -925,7 +888,7 @@ def create_mlir_tensor_type(tensor: torch.Tensor) -> MlirType:
925
888
raise TypeError (f"Could not map Torch dtype { dtype } to an IREE type" )
926
889
927
890
928
- def _make_vtensor_literal_op (tensor : torch .Tensor , vtensor_type : MlirType ) -> Operation :
891
+ def _make_vtensor_literal_op (tensor : torch .Tensor , vtensor_type : IrType ) -> Operation :
929
892
npy_dtype = TORCH_DTYPE_TO_NPY_TYPE .get (tensor .dtype )
930
893
assert (
931
894
npy_dtype is not None
0 commit comments