From 50060d73ca7a9cd3514656964c0502415dd25420 Mon Sep 17 00:00:00 2001 From: Andrea Faulds Date: Tue, 4 Jun 2024 12:58:16 +0200 Subject: [PATCH] WIP: Create MLIR functions for ONNX operators that are functions Resolves #3384. Many ONNX operators are defined by functions and therefore could be expanded into simpler ONNX operations during importing, avoiding the need for tools downstream to support these operators directly. This commit changes onnx_importer.py to systematically perform this expansion for all ONNX operators that are not explicitly denylisted. When importing a node, the schema for the node's operation is retrieved. If the schema provides a function for the operator, a specialized version for the node's types and attributes will be created and imported as an MLIR function with private visibility. An MLIR function call will then be omitted, instead of a normal operator node. Caching is used to avoid generating redundant functions within the same module. Note that previously all MLIR functions generated by the importer had no visibility specified. This commit changes this: the main function for a model is now public. This is so that the MLIR inliner pass will automatically discard the (private) operator functions after inlining. Explanations for subtle code changes: - Looking up the correct schema and function for an operator requires knowing the opset version. NodeImporter retrieves this from the opset imports on the ModelProto retained by the GraphInfo. Previously, the model_proto field on GraphInfo was None when importing a subgraph in import_regions, but this conflicts with the new need for opset version info. Since the apparent purpose of setting it to None was to control how GraphInfo generates its input map, a new flag is added to GraphInfo (is_subgraph) to control this behavior, so that the actual ModelProto can now be provided without breaking this. - Some operators' functions are context-dependent, which means the function definition depends on the types of the inputs. Therefore node importing now needs to look up the types of a node's inputs, not just its outputs as was the case previously. Consequently the operand to find_type_proto_for_name() may now be a graph input or initializer in some cases, so it has to be updated. --- .../configs/onnx_backend.py | 2 +- python/torch_mlir/extras/onnx_importer.py | 369 ++++++++++++++++-- 2 files changed, 343 insertions(+), 28 deletions(-) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py index de39475b0dbbc..9cb051cbfed39 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -92,7 +92,7 @@ def _module_lowering( # Lower from ONNX to Torch run_pipeline_with_repro_report( torch_mod, - f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))", + f"builtin.module(inline, func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))", "Lowering Onnx backend contract to Linalg-on-Tensors backend contract", ) diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index e0d3529d942e9..a90c54516d3ba 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -111,7 +111,12 @@ def create_module(self, context: Optional[Context] = None) -> Module: class GraphInfo: """Information about a Graph within a model.""" - def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto): + def __init__( + self, + model_info: ModelInfo, + graph_proto: onnx.GraphProto, + is_subgraph: bool = False, + ): self.model_info = model_info self.graph_proto = graph_proto self.initializer_map: Dict[str, onnx.TensorProto] = { @@ -129,7 +134,11 @@ def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto): # Generate the effective input map, which for old models can be a # subset of the input map. - if model_info and model_info.config.elide_initialized_inputs: + if ( + not is_subgraph + and model_info + and model_info.config.elide_initialized_inputs + ): self.input_map = { k: v for k, v in self.declared_input_map.items() @@ -149,8 +158,18 @@ def find_type_proto_for_name(self, name: str) -> onnx.TypeProto: # Node outputs don't typically have type information, but shape inference # will associate them in the value_info. If not there, it may be a # graph output, which must have type information. - value_info = self.value_info_map.get(name) or self.output_map.get(name) - if value_info is not None: + value_info = ( + self.value_info_map.get(name) + or self.output_map.get(name) + or self.declared_input_map.get(name) + ) + if value_info is None: + tensor_proto = self.initializer_map.get(name) + if tensor_proto is not None: + return onnx.helper.make_tensor_type_proto( + tensor_proto.data_type, tensor_proto.dims + ) + else: return value_info.type # No type information is associated, this can occur when the value is unused: return "" @@ -172,6 +191,8 @@ class NodeImporter: __slots__ = [ "_c", "_cc", + "_m", + "_mc", "_gi", "_p", "_b", @@ -185,9 +206,13 @@ def __init__( parent_op: Operation, block: Block, context_cache: "ContextCache", + module_op: Operation, + module_cache: "ModuleCache", ): self._c = parent_op.context self._cc = context_cache + self._m = module_op + self._mc = module_cache self._gi = graph_info self._p = parent_op self._b = block @@ -195,9 +220,19 @@ def __init__( @classmethod def define_function( - cls, graph_info: GraphInfo, module_op: Operation + cls, + graph_info: GraphInfo, + module_op: Operation, + context_cache: Optional["ContextCache"] = None, + module_cache: Optional["ModuleCache"] = None, + public: bool = True, ) -> "NodeImporter": - cc = ContextCache(module_op.context) + cc = ( + context_cache + if context_cache is not None + else ContextCache(module_op.context) + ) + mc = module_cache if module_cache is not None else ModuleCache(module_op, cc) with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"): body = module_op.regions[0].blocks[0] func_name = graph_info.graph_proto.name @@ -209,11 +244,23 @@ def define_function( for out in graph_info.output_map.values() ] ftype = FunctionType.get(input_types, output_types) - func_op = func_dialect.FuncOp(func_name, ftype, ip=InsertionPoint(body)) + func_op = func_dialect.FuncOp( + func_name, + ftype, + ip=InsertionPoint(body), + visibility="public" if public else "private", + ) block = func_op.add_entry_block( [Location.name(k) for k in graph_info.input_map.keys()] ) - imp = NodeImporter(graph_info, parent_op=func_op, block=block, context_cache=cc) + imp = NodeImporter( + graph_info, + parent_op=func_op, + block=block, + context_cache=cc, + module_op=module_op, + module_cache=mc, + ) for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): imp._nv_map[node_name] = input_value imp._populate_graph_attrs(func_op) @@ -293,6 +340,8 @@ def get_none(self): def import_node(self, node: onnx.NodeProto): with InsertionPoint(self._b), Location.name(node.name): op_type = node.op_type + op_domain = node.domain + # Handle special op types that materialize to non-op IR constructs. # Handlers return True if the op was handled, else this function # should process it as a general node. @@ -303,33 +352,57 @@ def import_node(self, node: onnx.NodeProto): return # General node import. input_values = [] + input_type_protos = [] for input_name in node.input: try: input_values.append(self._nv_map[input_name]) + # Missing optional arguments will have empty types + input_type_protos.append( + self._gi.find_type_proto_for_name(input_name) + or onnx.TypeProto() + ) except KeyError: raise OnnxImportError( f"Non topologically produced ONNX node input '{input_name}': {node}" ) - output_names = list(node.output) - output_types = [ - self._cc.type_proto_to_type(self._gi.find_type_proto_for_name(n)) - for n in output_names - ] - - attrs = self.import_attributes(node.attribute) - attrs["name"] = StringAttr.get(f"onnx.{op_type}") - regions = self.count_regions(node.attribute) - - custom_op = Operation.create( - name="torch.operator", - results=output_types, - operands=input_values, - attributes=attrs, - regions=regions, + output_names = [] + output_type_protos = [] + output_types = [] + for output_name in node.output: + output_names.append(output_name) + type_proto = self._gi.find_type_proto_for_name(output_name) + output_type_protos.append(type_proto) + output_types.append(self._cc.type_proto_to_type(type_proto)) + + for opset_import in self._gi.model_info.model_proto.opset_import: + if opset_import.domain == op_domain: + opset_version = opset_import.version + break + operator_func_op = self._mc.get_operator_function( + op_type, + op_domain, + opset_version, + input_type_protos, + output_type_protos, + node, ) - self.import_regions(node.attribute, custom_op) + if operator_func_op is not None: + custom_op = func_dialect.CallOp(operator_func_op, input_values) + else: + attrs = self.import_attributes(node.attribute) + attrs["name"] = StringAttr.get(f"onnx.{op_type}") + regions = self.count_regions(node.attribute) + custom_op = Operation.create( + name="torch.operator", + results=output_types, + operands=input_values, + attributes=attrs, + regions=regions, + ) + self.import_regions(node.attribute, custom_op) + for output_name, output_value in zip(output_names, custom_op.results): self._nv_map[output_name] = output_value @@ -387,9 +460,14 @@ def import_regions(self, onnx_attrs: List[onnx.AttributeProto], op): *block_types, arg_locs=[op.location] * len(block_types) ) block = region.blocks[0] - graph_info = GraphInfo(None, attr.g) + graph_info = GraphInfo(self._gi.model_info, attr.g, is_subgraph=True) imp = NodeImporter( - graph_info, parent_op=op, block=block, context_cache=self._cc + graph_info, + parent_op=op, + block=block, + context_cache=self._cc, + module_op=self._m, + module_cache=self._mc, ) for node_name, input_value in zip(block_names, block.arguments): @@ -603,6 +681,13 @@ def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: element_type = self.get_optional_element_type(ot.elem_type) return self.get_optional_type(element_type) + # Check if TypeProto is empty (sometimes happens for unused function + # arguments) + if tp.SerializeToString( + deterministic=True + ) == onnx.TypeProto().SerializeToString(deterministic=True): + return self.get_none_type() + # TODO: Others if ever needed. Or we consider ourselves DNN-only. # See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type. raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}") @@ -631,6 +716,236 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: return handler(tp) +class ModuleCache: + """Caches per-module lookups of various things.""" + + __slots__ = [ + "_m", + "_cc", + "_operator_function_map", + ] + + def __init__(self, module_op: Operation, context_cache: ContextCache): + self._m = module_op + self._cc = context_cache + self._operator_function_map: Dict[str, func_dialect.FuncOp] = {} + + def get_operator_function( + self, + op_name: str, + op_domain: str, + opset_version: int, + input_type_protos: list[onnx.TypeProto], + output_type_protos: list[onnx.TypeProto], + caller_node: onnx.NodeProto, + ) -> Optional[func_dialect.FuncOp]: + """ + Get or create MLIR function corresponding to an ONNX operator. + + Returns None for ONNX operators that aren't functions. + """ + + # Functions we do not want to attempt expansion for. + DENYLISTS_BY_DOMAIN = { + # Default domain (ONNX built-in ops) + "": set( + [ + # CastLike's second input `target_type` is used only for its + # type (T2), from which its output's type is inferred, but + # because its value is unused, ONNX's shape inference + # doesn't annotate the input value with a type, so looking + # up the function by the provided input types will fail. + "CastLike", + # ONNX errors when trying to infer the type of the Loop op + # within this function: "[ShapeInferenceError] Inferred + # shape and existing shape differ in rank: (1) vs (0)" + "Range", + ] + ) + } + + if ( + op_domain in DENYLISTS_BY_DOMAIN + and op_name in DENYLISTS_BY_DOMAIN[op_domain] + ): + return None + + op_schema = onnx.defs.get_schema( + op_name, domain=op_domain, max_inclusive_version=opset_version + ) + + # The get_schema() lookup above should get the right version of the + # operator definition, but the function body can change slightly + # within a single operator version, as explained in + # https://github.com/onnx/onnx/blob/093a8d335a66ea136eb1f16b3a1ce6237ee353ab/onnx/defs/schema.h#L1070-L1086 + # There also seem to be cases where a function goes from being not + # context-dependent to context-dependent. + f = lambda ver: ver <= opset_version + ncd_function_version = max( + filter(f, op_schema.function_opset_versions), + default=None, + ) + cd_function_version = max( + filter(f, op_schema.context_dependent_function_opset_versions), + default=None, + ) + if ncd_function_version is None and cd_function_version is None: + # No relevant function definition + return None + elif ncd_function_version is not None and ( + cd_function_version is None or cd_function_version < ncd_function_version + ): + specific_version = ncd_function_version + is_context_dependent = False + else: + specific_version = cd_function_version + is_context_dependent = True + + # This is both a key for memoization of function importing and also a + # name mangling scheme, so it must include all information needed to + # uniquely identify a function and anything it might be parameterized + # over. + key = repr( + ( + op_name, + op_domain, + opset_version, + input_type_protos, + # Though output types can be inferred from input types, it does + # not seem to be the case that there's only one legal set of + # outputs for a given set of inputs. When attemtping to always + # use onnx.shape_inference.infer_function_output_types instead + # of the caller-provided types, sometimes IR verification fails + output_type_protos, + # Avoid including the attributes twice (once on their own and + # once as part of the node) for context-dependent functions, + # avoid including unused parts of the node for other functions. + caller_node if is_context_dependent else caller_node.attribute, + ) + ) + + existing = self._operator_function_map.get(key) + if existing is not None: + return existing + + if is_context_dependent: + function_proto_str = ( + op_schema.get_context_dependent_function_with_opset_version( + specific_version, + caller_node.SerializeToString(), + [ + t.SerializeToString() if not isinstance(t, bytes) else t + for t in input_type_protos + ], + ) + ) + else: + function_proto_str = op_schema.get_function_with_opset_version( + specific_version + ) + if not function_proto_str: + raise OnnxImportError( + f"Function lookup for {op_name}/{op_domain}/{specific_version}/{is_context_dependent} failed unexpectedly. This probably indicates a bug." + ) + function_proto = onnx.onnx_pb.FunctionProto() + function_proto.ParseFromString(function_proto_str) + + # An ONNX function may be polymorphic, parameterized over the types of + # its inputs and values of its attributes (~= compile-time constants). + # We need to monomorphize it for importing into MLIR. It seems like the + # only practical way to do this is by turning it into a model: + # - models can have types on their inputs and outputs, unlike functions + # - ONNX provides a function to do shape inference (providing concrete + # types for everything in the body) for models, but not for functions + # - the rest of the code in this importer can only handle models, not + # functions + + tmp_graph_proto = onnx.GraphProto() + + for input_name, input_type_proto in zip( + function_proto.input, input_type_protos + ): + input_proto = onnx.ValueInfoProto() + input_proto.name = input_name + input_proto.type.CopyFrom(input_type_proto) + tmp_graph_proto.input.append(input_proto) + output_proto = onnx.ValueInfoProto() + + for output_name, output_type_proto in zip( + function_proto.output, output_type_protos + ): + output_proto.name = output_name + output_proto.type.CopyFrom(output_type_proto) + tmp_graph_proto.output.append(output_proto) + + call_attributes = caller_node.attribute + for node in function_proto.node: + # Import referenced attributes from call-site or default values + new_node = onnx.NodeProto() + new_node.CopyFrom(node) + old_attributes = list(node.attribute) + # .clear() isn't available on protobuf lists for some reason + while len(new_node.attribute) > 0: + new_node.attribute.pop() + for node_attribute in old_attributes: + if node_attribute.ref_attr_name: + ref_name = node_attribute.ref_attr_name + for call_attribute in call_attributes: + if call_attribute.name == ref_name: + new_attribute = onnx.AttributeProto() + new_attribute.CopyFrom(call_attribute) + new_attribute.name = node_attribute.name + new_node.attribute.append(new_attribute) + break + else: + # The default value seems to sometimes be empty for + # optional attributes that don't have an actual default. + if ( + op_schema.attributes[ref_name].default_value + and op_schema.attributes[ref_name].default_value.type + ): + new_attribute = onnx.AttributeProto() + new_attribute.CopyFrom( + op_schema.attributes[ref_name].default_value + ) + new_attribute.name = node_attribute.name + new_node.attribute.append(new_attribute) + else: + new_node.attribute.append(node_attribute) + tmp_graph_proto.node.append(new_node) + + tmp_graph_proto.name = key + + tmp_model_proto = onnx.ModelProto() + tmp_model_proto.opset_import.extend(function_proto.opset_import) + # FIXME: is this the correct IR version, or should it be the latest, or + # the one used by the actual model, or something else? + tmp_model_proto.ir_version = onnx.helper.find_min_ir_version_for( + function_proto.opset_import + ) + tmp_model_proto.graph.CopyFrom(tmp_graph_proto) + + tmp_model_proto = onnx.shape_inference.infer_shapes( + tmp_model_proto, check_type=True, strict_mode=True, data_prop=True + ) + tmp_graph_proto = tmp_model_proto.graph + + # Useful for debugging. + # onnx.checker.check_model(tmp_model_proto, full_check=True) + + tmp_model_info = ModelInfo(tmp_model_proto) + tmp_graph_info = GraphInfo(tmp_model_info, tmp_graph_proto) + + imp = NodeImporter.define_function( + tmp_graph_info, self._m, self._cc, self, public=False + ) + imp.import_all() + func_op = imp._p + + self._operator_function_map[key] = func_op + return func_op + + ELEM_TYPE_TO_IR_TYPE_CB = { onnx.TensorProto.DataType.FLOAT: lambda: F32Type.get(), onnx.TensorProto.DataType.UINT8: lambda: IntegerType.get_unsigned(8),