Skip to content

Commit

Permalink
data type utility for dynamo/aot
Browse files Browse the repository at this point in the history
  • Loading branch information
saienduri committed Nov 2, 2023
1 parent f603634 commit b135238
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion python/shark_turbine/dynamo/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
Block,
Context,
FloatAttr,
BF16Type,
ComplexType,
F16Type,
F32Type,
F64Type,
FunctionType,
InsertionPoint,
IntegerAttr,
Expand Down Expand Up @@ -95,6 +100,24 @@
torch.complex128: "complex<f64>",
}

TORCH_DTYPE_TO_IREE_TYPE: Dict[torch.dtype, Callable[[], MlirType]] = {
torch.float16: lambda: F16Type.get(),
torch.bfloat16: lambda: BF16Type.get(),
torch.float32: lambda: F32Type.get(),
torch.float64: lambda: F64Type.get(),
torch.uint8: lambda: IntegerType.get_unsigned(8),
torch.int8: lambda: IntegerType.get_signed(8),
torch.int16: lambda: IntegerType.get_signed(16),
torch.int32: lambda: IntegerType.get_signed(32),
torch.int64: lambda: IntegerType.get_signed(64),
torch.bool: lambda: IntegerType.get_signless(1),
torch.qint8: lambda: IntegerType.get_signless(8),
torch.quint8: lambda: IntegerType.get_signless(8),
torch.complex32: lambda: ComplexType.get(F16Type.get()),
torch.complex64: lambda: ComplexType.get(F32Type.get()),
torch.complex128: lambda: ComplexType.get(F64Type.get()),
}

TORCH_DTYPE_TO_NPY_TYPE = {
# torch.qint8: None, # no equivalent np datatype
# torch.quint8: None,
Expand Down Expand Up @@ -892,7 +915,6 @@ def _make_constant_op(

def create_iree_tensor_type(tensor: torch.Tensor) -> MlirType:
try:
from ..aot.support.ir_utils import TORCH_DTYPE_TO_IREE_TYPE
dtype = tensor.dtype
element_type = TORCH_DTYPE_TO_IREE_TYPE[dtype]()
tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
Expand Down

0 comments on commit b135238

Please sign in to comment.