diff --git a/python/shark_turbine/dynamo/importer.py b/python/shark_turbine/dynamo/importer.py index 4f39e8f90..c2dbfbccc 100644 --- a/python/shark_turbine/dynamo/importer.py +++ b/python/shark_turbine/dynamo/importer.py @@ -353,11 +353,14 @@ def node_val_to_type(self, node: torch_fx.Node) -> MlirType: val = node.meta.get("val") if tensor_meta is not None: assert isinstance(tensor_meta, TensorMetadata) - # TODO: We should probably only be doing this if "vanilla". - # Specifically, there are strides/qparams/etc on there that - # should be annotated somewhere. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/139 - return self.tensor_metadata_to_type(tensor_meta) + # Quantized tensor meta data is not preserved in our lowering, + # so throw error instead of silently doing wrong thing. + if (tensor_meta.is_quantized): + raise NotImplementedError( + f"Quantized tensor meta data is not supported." + ) + else: + return self.tensor_metadata_to_type(tensor_meta) elif val is not None: # some nodes with symbolic inputs pass a 'val' attribute rather than # tensor_meta @@ -630,23 +633,16 @@ def _import_torch_op_overload( op_overload = getattr(op_overload, op_attrs[i]) schema = op_overload._schema - if not self._c.is_registered_operation(mlir_op_name): - # TODO: Implement a config setting to allow these to flow through. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/141 - raise NotImplementedError( - f"Unimplemented torch op in the IREE compiler: '{mlir_op_name}' " - f"(either implement this op/variant or configure the compiler to " - f"allow unknown operations and fallback to PyTorch)." - ) - return_count = len(schema.returns) if return_count == 1: # Unary return directly maps a single meta["val"] and cannot be subscripted. # if "tensor_meta" is None, this will throw unsupported placeholder node error result_types = [self._cc.node_val_to_type(node)] elif return_count == 0: - # TODO: Implement (https://github.com/nod-ai/SHARK-Turbine/issues/142) - raise NotImplementedError("FIXME: Zero ATen results") + # Some torch ops do have 0 returns, and these are supported with ZeroResults + # op trait. Python bindings for IR creation allow us to pass empty result_types + # for such ops. Therefore, we pass an empty result types for these cases. + result_types = [] else: # Multi-return will unpack the meta["val"] and trigger our getitem subscripting # short-circuit above. Note that if we ever choose to also fully reify Python @@ -663,8 +659,6 @@ def _import_torch_op_overload( operands = [] for i, parameter in enumerate(schema.arguments): if parameter.kwarg_only and parameter.name in node.kwargs: - # TODO: Nice error if KeyError. - # See: https://github.com/nod-ai/SHARK-Turbine/issues/143 operands.append( self._import_argument( loc, node.kwargs[parameter.name], parameter.type @@ -681,12 +675,23 @@ def _import_torch_op_overload( ) ) - operation = Operation.create( - mlir_op_name, - results=result_types, - operands=operands, - loc=loc, - ) + # Support unregistered torch ops using torch.operator. + # torch.operator is used to represent ops from registry + # which haven't been generated by torch_ods_gen.py. + if not self._c.is_registered_operation(mlir_op_name): + operation = Operation.create( + "torch.operator", + results=result_types, + operands=operands, + loc=loc, + ) + else: + operation = Operation.create( + mlir_op_name, + results=result_types, + operands=operands, + loc=loc, + ) # Record value mapping. for i, value in enumerate(operation.results): @@ -830,18 +835,15 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value: if isinstance(arg, list): return self._import_list_argument(loc, arg, expected_jit_type) + # The LITERAL_CONVERTER_MAP maps each arg to its respective constant + # of the expected jit IR type (types like torch.dtype will form a chain of + # maps to get to constant of expected_jit_type). cvt = LITERAL_CONVERTER_MAP.lookup(type(arg)) if cvt is None: raise RuntimeError(f"Unhandled default value ({arg.__class__}): {arg})") with loc: return cvt(arg, self, self._cc) - # TODO: Support torch specific types which show up in function schemas. - # These all require an expected_jit_type to convert. - # torch.dtype, torch.device, torch.memory_format, torch.layout - # list - # See: https://github.com/nod-ai/SHARK-Turbine/issues/144 - class TypeSubclassMap: """Mapping of super-types to values.