From 48a855f127d8576954a552da300bab2b4e7ae2f1 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Wed, 8 Nov 2023 18:59:11 -0800 Subject: [PATCH] Dynamo fx completeness (#162) This PR takes care of https://github.com/nod-ai/SHARK-Turbine/issues/139. For the tensor metadata handling, we should only throw an error if the tensor is quantized. If we also start checking other fields such as stride, memory_format, etc., the pytest suite fails. For the TODO about raising a nicer error on KeyError, we already check `parameter.name in node.kwargs` in the enclosing if statement, so we won't run into an invalid key. Other issues that were part of this importer completeness task have been addressed and documented. Fixes #139 Fixes #140 Fixes #141 Fixes #142 Fixes #143 Fixes #144 --- python/shark_turbine/dynamo/importer.py | 62 +++++++++++++------------ 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/python/shark_turbine/dynamo/importer.py b/python/shark_turbine/dynamo/importer.py index 4f39e8f90..c2dbfbccc 100644 --- a/python/shark_turbine/dynamo/importer.py +++ b/python/shark_turbine/dynamo/importer.py @@ -353,11 +353,14 @@ def node_val_to_type(self, node: torch_fx.Node) -> MlirType: val = node.meta.get("val") if tensor_meta is not None: assert isinstance(tensor_meta, TensorMetadata) - # TODO: We should probably only be doing this if "vanilla". - # Specifically, there are strides/qparams/etc on there that - # should be annotated somewhere. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/139 - return self.tensor_metadata_to_type(tensor_meta) + # Quantized tensor meta data is not preserved in our lowering, + # so throw error instead of silently doing wrong thing. + if (tensor_meta.is_quantized): + raise NotImplementedError( + f"Quantized tensor meta data is not supported." 
+ ) + else: + return self.tensor_metadata_to_type(tensor_meta) elif val is not None: # some nodes with symbolic inputs pass a 'val' attribute rather than # tensor_meta @@ -630,23 +633,16 @@ def _import_torch_op_overload( op_overload = getattr(op_overload, op_attrs[i]) schema = op_overload._schema - if not self._c.is_registered_operation(mlir_op_name): - # TODO: Implement a config setting to allow these to flow through. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/141 - raise NotImplementedError( - f"Unimplemented torch op in the IREE compiler: '{mlir_op_name}' " - f"(either implement this op/variant or configure the compiler to " - f"allow unknown operations and fallback to PyTorch)." - ) - return_count = len(schema.returns) if return_count == 1: # Unary return directly maps a single meta["val"] and cannot be subscripted. # if "tensor_meta" is None, this will throw unsupported placeholder node error result_types = [self._cc.node_val_to_type(node)] elif return_count == 0: - # TODO: Implement (https://github.com/nod-ai/SHARK-Turbine/issues/142) - raise NotImplementedError("FIXME: Zero ATen results") + # Some torch ops do have 0 returns, and these are supported with ZeroResults + # op trait. Python bindings for IR creation allow us to pass empty result_types + # for such ops. Therefore, we pass an empty result types for these cases. + result_types = [] else: # Multi-return will unpack the meta["val"] and trigger our getitem subscripting # short-circuit above. Note that if we ever choose to also fully reify Python @@ -663,8 +659,6 @@ def _import_torch_op_overload( operands = [] for i, parameter in enumerate(schema.arguments): if parameter.kwarg_only and parameter.name in node.kwargs: - # TODO: Nice error if KeyError. 
- # See: https://github.com/nod-ai/SHARK-Turbine/issues/143 operands.append( self._import_argument( loc, node.kwargs[parameter.name], parameter.type @@ -681,12 +675,23 @@ def _import_torch_op_overload( ) ) - operation = Operation.create( - mlir_op_name, - results=result_types, - operands=operands, - loc=loc, - ) + # Support unregistered torch ops using torch.operator. + # torch.operator is used to represent ops from registry + # which haven't been generated by torch_ods_gen.py. + if not self._c.is_registered_operation(mlir_op_name): + operation = Operation.create( + "torch.operator", + results=result_types, + operands=operands, + loc=loc, + ) + else: + operation = Operation.create( + mlir_op_name, + results=result_types, + operands=operands, + loc=loc, + ) # Record value mapping. for i, value in enumerate(operation.results): @@ -830,18 +835,15 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value: if isinstance(arg, list): return self._import_list_argument(loc, arg, expected_jit_type) + # The LITERAL_CONVERTER_MAP maps each arg to its respective constant + # of the expected jit IR type (types like torch.dtype will form a chain of + # maps to get to constant of expected_jit_type). cvt = LITERAL_CONVERTER_MAP.lookup(type(arg)) if cvt is None: raise RuntimeError(f"Unhandled default value ({arg.__class__}): {arg})") with loc: return cvt(arg, self, self._cc) - # TODO: Support torch specific types which show up in function schemas. - # These all require an expected_jit_type to convert. - # torch.dtype, torch.device, torch.memory_format, torch.layout - # list - # See: https://github.com/nod-ai/SHARK-Turbine/issues/144 - class TypeSubclassMap: """Mapping of super-types to values.