From 167b72da0cabaf538ac0f09d867bc622ae736532 Mon Sep 17 00:00:00 2001 From: Rob Elliott Date: Fri, 22 Sep 2023 10:21:28 -0700 Subject: [PATCH] Summary: Initial TOSA support in executorch (#161) Summary: An implementation of a TOSA Partitioner and TOSA backend supporting a small list of operators in fp32 and int8/32. This includes a set of integer networks which will then compile to target Ethos-U NPU using the Vela compiler. * A small set of test networks in fp32 * A subset of these supporting integer and compiling with Vela * A test wrapper to register and test networks through the compile stack * A flow pt2 -> executorch -> tosa_flatbuffer.tosa file -> tosa_reference_model * A flow pt2 -> executorch -> tosa_flatbuffer.tosa file -> vela -> command_stream * This work depends on a few different python modules * https://review.mlplatform.org/plugins/gitiles/tosa/serialization_lib/ * https://review.mlplatform.org/plugins/gitiles/ml/ethos-u/ethos-u-vela/ * This work uses other projects for validation * https://review.mlplatform.org/plugins/gitiles/tosa/reference_model/ Pull Request resolved: https://github.com/pytorch/executorch/pull/161 Reviewed By: mergennachin Differential Revision: D49542254 Pulled By: digantdesai fbshipit-source-id: bb4074d5b66233c85452c2373a08ddc1cdd0826f --- .gitmodules | 3 + LICENSE | 1 + backends/arm/README.md | 24 + backends/arm/arm_backend.py | 613 +++++++++++++++++++++ backends/arm/test/test_models.py | 151 +++++ backends/arm/test/test_tosa.py | 62 +++ backends/arm/third-party/serialization_lib | 1 + backends/arm/tosa_mapping.py | 102 ++++ examples/README.md | 1 + examples/arm/arm_tosa_e2e.py | 159 ++++++ pyproject.toml | 1 + pytest.ini | 4 +- setup.py | 2 + 13 files changed, 1122 insertions(+), 2 deletions(-) create mode 100644 backends/arm/README.md create mode 100644 backends/arm/arm_backend.py create mode 100644 backends/arm/test/test_models.py create mode 100644 backends/arm/test/test_tosa.py create mode 160000 
backends/arm/third-party/serialization_lib create mode 100644 backends/arm/tosa_mapping.py create mode 100644 examples/arm/arm_tosa_e2e.py diff --git a/.gitmodules b/.gitmodules index a36f2ed6a2..980a999eff 100644 --- a/.gitmodules +++ b/.gitmodules @@ -28,3 +28,6 @@ [submodule "backends/xnnpack/third-party/XNNPACK"] path = backends/xnnpack/third-party/XNNPACK url = https://github.com/google/XNNPACK.git +[submodule "backends/arm/third-party/serialization_lib"] + path = backends/arm/third-party/serialization_lib + url = https://git.mlplatform.org/tosa/serialization_lib.git diff --git a/LICENSE b/LICENSE index d79f406aab..aa8394b44d 100644 --- a/LICENSE +++ b/LICENSE @@ -3,6 +3,7 @@ BSD License For "ExecuTorch" software Copyright (c) Meta Platforms, Inc. and affiliates. +Copyright 2023 Arm Limited and/or its affiliates. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/backends/arm/README.md b/backends/arm/README.md new file mode 100644 index 0000000000..01722e613a --- /dev/null +++ b/backends/arm/README.md @@ -0,0 +1,24 @@ +# Executorch Arm/TOSA Delegate + +This subtree contains the Arm Delegate implementation for Executorch. + +This delegate is structured to, over time, support a number of different Arm devices +through an AoT flow which targets multiple Arm IP using the TOSA standard. + +The expected flow is: + * torch.nn.module -> TOSA -> command_stream for fully AoT flows e.g. embedded. + * torch.nn.module -> TOSA for flows supporting a JiT compilation step. + +Current backend support is being developed for TOSA to Ethos-U55/65 via the +ethos-u-vela compilation stack, which follows the fully AoT flow. + +## Layout +- `arm_backend.py` - AoT Partitioner which maps to a subset of Base Inference and Main Inference TOSA profiles, where the subset may be further constrained for early support devices like Ethos-U55. 
AoT Backend which implements the preprocess step which converts to TOSA and can emit files for ethos-u-vela as shown in `executorch/examples/arm/` +- `test/` - unit test and test support functions +- `third-party/` - source dependencies - currently just on TOSA serialization_lib +- `tosa_mapping.py` - helper functions for mapping edge dialect to TOSA + +## Help & Improvements +If you have problems or questions, or have suggestions for ways to make +implementation and testing better, please reach out to the Arm team developing this delegate, or +create an issue on [github](https://www.github.com/pytorch/executorch/issues). diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py new file mode 100644 index 0000000000..5f90e94d81 --- /dev/null +++ b/backends/arm/arm_backend.py @@ -0,0 +1,613 @@ +# Copyright 2023 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# Main implementation of AoT flow to partition and preprocess for Arm target +# backends. Converts via TOSA as an intermediate form supported by AoT and +# JIT compiler flows. +# + +import logging +import operator +import os +import tempfile + +from typing import final, List + +import serializer.tosa_serializer as ts + +import torch +from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) + +from executorch.exir.dialects._ops import ops as exir_ops +from serializer.tosa_serializer import TosaOp +from torch._export.exported_program import ExportedProgram +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner + +from torch.fx.passes.operator_support import OperatorSupportBase + +from . 
import tosa_mapping + +# TOSA backend debug functionality +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) +TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1" +if TOSA_DBG_VERBOSE: + logging.basicConfig(level=logging.INFO) + logger.setLevel(logging.INFO) + + +def dbg_node(node): + # Debug output of node information + logger.info("OP") + logger.info(f" op is {node.op}") + logger.info(f" name is {node.name}") + logger.info(f" node target is {node.target}") + logger.info(f" node args is {node.args}") + logger.info(f" node kwargs is {node.kwargs}") + logger.info(" node.meta = ") + for k, v in node.meta.items(): + logger.info(f" '{k}' = {v}") + if type([]) == type(v): + for i in v: + logger.info(f" {i} ") + + +class TOSASupportedOperators(OperatorSupportBase): + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + supported = node.op == "call_function" and node.target in [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.addmm.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten._native_batch_norm_legit_no_training.default, + exir_ops.edge.aten.avg_pool2d.default, + ] + return supported + + +def attr_torch_to_tosa(op, node): + if TosaOp.Op().MATMUL == op: + attr = ts.TosaSerializerAttribute() + attr.MatMulAttribute(0, 0) + return attr + if TosaOp.Op().MUL == op: + attr = ts.TosaSerializerAttribute() + attr.MulAttribute(0) + return attr + return None + + +@final +class ArmPartitioner(Partitioner): + compile_spec = [] + + def __init__(self) -> None: + self.delegation_spec = DelegationSpec(ArmBackend.__name__, self.compile_spec) + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + # Run the CapabilityBasedPartitioner to return the largest possible + # subgraphs containing the nodes with the tags + logger.info("ArmPartitioner::partition") + partition_tags = 
{} + + capability_partitioner = CapabilityBasedPartitioner( + exported_program.graph_module, + TOSASupportedOperators(), + allows_single_node_partition=True, + ) + partition_list = capability_partitioner.propose_partitions() + for partition in partition_list: + for node in partition.nodes: + tag = f"tag{partition.id}" + node.meta["delegation_tag"] = tag + partition_tags[tag] = self.delegation_spec + + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) + + +# Output TOSA flatbuffer and test harness file +def dbg_tosa_dump(tosa_fb, path): + filename = "output.tosa" + + logger.info(f"Emitting debug output to {path}") + + os.makedirs(path, exist_ok=True) + + fb = tosa_fb.serialize() + js = tosa_fb.writeJson(filename) + + f = open(path + filename, "wb") + f.write(fb) + f.close() + + f = open(path + "desc.json", "w") + f.write(js) + f.close() + + +def dbg_fail(node, tosa_fb, path): + dbg_tosa_dump(tosa_fb, path) + logger.warn("Internal error due to poorly handled node:") + dbg_node(node) + logger.warn(f"Debug output captured in '{path}'.") + raise RuntimeError("TOSA Internal Error on node, enable logging for further info") + + +@final +class ArmBackend(BackendDetails): + @staticmethod + def preprocess( # noqa: C901 + edge_program: ExportedProgram, + compile_spec: List[CompileSpec], + ) -> bytes: + logger.info("ArmBackend::preprocess") + + # if a debug/test build capture output files from TOSA stage + path = None + debug_output = False + for spec in compile_spec: + if spec.key == "debug_tosa_path": + path = spec.value.decode() + debug_output = True + + # in non debug builds we still pass files to vela + if path is None: + path = tempfile.mkdtemp(prefix="arm_tosa_") + + # Converted output for this subgraph, serializer needs path early as it emits + # const data directly. Path created and data written only in debug builds. 
+ tosa_fb = ts.TosaSerializer(path) + + for node in edge_program.graph.nodes: + if node.op == "call_function": + # Unpack arguments and convert + inputs = [] + for arg in node.args: + inputs.append(tosa_mapping.TosaArg(arg)) + + # Convert output (this node itself) + outp = tosa_mapping.TosaArg(node) + + # All paths have a single output + tosa_fb.currRegion.currBasicBlock.addTensor( + outp.name, outp.shape, outp.dtype + ) + + op = tosa_mapping.op(node.target) + attr = attr_torch_to_tosa(op, node) + + if op: + # a simple 1:1 mapping of operator taking 2 tensor arguments + assert len(inputs) == 2 + assert inputs[0].dtype == outp.dtype + assert inputs[1].dtype == outp.dtype + tosa_fb.addOperator( + op, [inputs[0].name, inputs[1].name], [outp.name], attr + ) + else: + # A more complex mapping of operator + if exir_ops.edge.aten.addmm.default == node.target: + input = inputs[1] + weight = inputs[2] + bias = inputs[0] + + # Reshape input tensor + # TODO: move shape compatibility promotion to function + # Many TOSA ops require a shape including a batch size so we make the implicit + # batch size from the edge graph explicit in TOSA + input_reshape_res = tosa_fb.addIntermediate( + (1,) + input.shape, outp.dtype + ) + attr_input = ts.TosaSerializerAttribute() + attr_input.ReshapeAttribute((1,) + input.shape) + tosa_fb.addOperator( + TosaOp.Op().RESHAPE, + [input.name], + [input_reshape_res.name], + attr_input, + ) + + # Reshape weight tensor + weight_reshape_res = tosa_fb.addIntermediate( + (1,) + weight.shape, outp.dtype + ) + attr_weight = ts.TosaSerializerAttribute() + attr_weight.ReshapeAttribute((1,) + weight.shape) + tosa_fb.addOperator( + TosaOp.Op().RESHAPE, + [weight.name], + [weight_reshape_res.name], + attr_weight, + ) + + # Reshape bias tensor + bias_reshape_res = tosa_fb.addIntermediate( + ( + 1, + 1, + ) + + bias.shape, + outp.dtype, + ) + attr_bias = ts.TosaSerializerAttribute() + attr_bias.ReshapeAttribute( + ( + 1, + 1, + ) + + bias.shape + ) + 
tosa_fb.addOperator( + TosaOp.Op().RESHAPE, + [bias.name], + [bias_reshape_res.name], + attr_bias, + ) + + # Add dummy batch 1 to mm_shape + mm_shape = (1, input.shape[0], weight.shape[1]) + # Define Intermediate tensor for MatMul res + mm_res = tosa_fb.addIntermediate(mm_shape, outp.dtype) + + # Add MatMulOp + tosa_fb.addOperator( + TosaOp.Op().MATMUL, + [input_reshape_res.name, weight_reshape_res.name], + [mm_res.name], + attr_torch_to_tosa(TosaOp.Op().MATMUL, node), + ) + + # Add AddOp + add_res = tosa_fb.addIntermediate(mm_shape, outp.dtype) + tosa_fb.addOperator( + TosaOp.Op().ADD, + [bias_reshape_res.name, mm_res.name], + [add_res.name], + None, + ) + + # Reshape final result to original shape + attr_out = ts.TosaSerializerAttribute() + attr_out.ReshapeAttribute(outp.shape) + tosa_fb.addOperator( + TosaOp.Op().RESHAPE, [add_res.name], [outp.name], attr_out + ) + elif exir_ops.edge.aten.permute_copy.default == node.target: + attr = ts.TosaSerializerAttribute() + attr.TransposeAttribute(inputs[1].special) + tosa_fb.addOperator( + TosaOp.Op().TRANSPOSE, [inputs[0].name], [outp.name], attr + ) + elif exir_ops.edge.aten.hardtanh.default == node.target: + attr = ts.TosaSerializerAttribute() + attr.ClampAttribute( + tosa_fb.builder, + int(inputs[1].threshold), + int(inputs[2].threshold), + inputs[1].threshold, + inputs[2].threshold, + ) + tosa_fb.addOperator( + TosaOp.Op().CLAMP, [inputs[0].name], [outp.name], attr + ) + elif exir_ops.edge.aten.convolution.default == node.target: + ## RESHAPE input tensor to NHWC_Order = [0, 2, 3, 1] + NHWC_Order = [0, 2, 3, 1] + attr_input_reshape = ts.TosaSerializerAttribute() + input_shape_NHWC = [inputs[0].shape[i] for i in NHWC_Order] + attr_input_reshape.ReshapeAttribute(input_shape_NHWC) + input_reshaped = tosa_fb.addIntermediate( + input_shape_NHWC, outp.dtype + ) + tosa_fb.addOperator( + TosaOp.Op().RESHAPE, + [inputs[0].name], + [input_reshaped.name], + attr_input_reshape, + ) + + ## CONV2DOp + attr = 
ts.TosaSerializerAttribute() + # PAD + pad_attr = [val for val in inputs[4].special for _ in (0, 1)] + # Stride + stride_attr = inputs[3].special + # Dilation + dilation_attr = inputs[5].special + attr.ConvAttribute(pad_attr, stride_attr, dilation_attr, 0, 0) + + ## TOSA output shape is [NHWO] (num_batch, height, width, num_output) + NHWO_Order = [0, 2, 3, 1] + out_shape_TOSA_CONV2D = [outp.shape[i] for i in NHWO_Order] + conv2d_res = tosa_fb.addIntermediate( + out_shape_TOSA_CONV2D, outp.dtype + ) + tosa_fb.addOperator( + TosaOp.Op().CONV2D, + [input_reshaped.name, inputs[1].name, inputs[2].name], + [conv2d_res.name], + attr, + ) + + ## Torch output shape is [NOHW] + NOHW_Order = [0, 3, 1, 2] + attr_output_transpose = ts.TosaSerializerAttribute() + attr_output_transpose.TransposeAttribute(NOHW_Order) + tosa_fb.addOperator( + TosaOp.Op().TRANSPOSE, + [conv2d_res.name], + [outp.name], + attr_output_transpose, + ) + elif exir_ops.edge.aten.div.Tensor == node.target: + # Div is implemented as x/y = x*1/y + recip = tosa_fb.addIntermediate( + inputs[1].shape, inputs[1].dtype + ) + tosa_fb.addOperator( + TosaOp.Op().RECIPROCAL, [inputs[1].name], [recip.name] + ) + + attr = ts.TosaSerializerAttribute() + attr.MulAttribute(0) + tosa_fb.addOperator( + TosaOp.Op().MUL, + [inputs[0].name, recip.name], + [outp.name], + attr, + ) + elif ( + exir_ops.edge.aten._native_batch_norm_legit_no_training.default + == node.target + ): + # Decompose batch norm into sequence + ( + activations, + _, + _, + running_mean, + running_var, + momentum, + epsilon, + ) = inputs + + input_dtype = activations.dtype + input_shape = activations.shape + + assert ( + 0.1 == momentum.threshold + ), "Expected 0.1 momentum, not currently encoded into TOSA" + + # %op1 = tosa.SUB(%x, %bmean) + # %op2 = tosa.ADD(%variance, %epsilon_const) + # %op3 = tosa.RSQRT(%op2) + # %op4 = tosa.MUL(%op1, %op3) + # %op5 = tosa.MUL(%op4, %weight) + # %output = tosa.ADD(%op5, %bias) + + # Reshape mean to match rank of 
activations + mean_reshaped_res = tosa_fb.addIntermediate( + (1,) + + running_mean.shape + + ( + 1, + 1, + ), + input_dtype, + ) + attr_mean = ts.TosaSerializerAttribute() + attr_mean.ReshapeAttribute( + (1,) + + running_mean.shape + + ( + 1, + 1, + ) + ) + tosa_fb.addOperator( + TosaOp.Op().RESHAPE, + [running_mean.name], + [mean_reshaped_res.name], + attr_mean, + ) + + # Subtract mean + int1 = tosa_fb.addIntermediate(input_shape, input_dtype) + tosa_fb.addOperator( + TosaOp.Op().SUB, + [activations.name, mean_reshaped_res.name], + [int1.name], + ) + # Adding eplison to variance + epsilon_const = tosa_fb.addConst( + [1], input_dtype, [epsilon.threshold] + ) + int2 = tosa_fb.addIntermediate(running_var.shape, input_dtype) + tosa_fb.addOperator( + TosaOp.Op().ADD, + [running_var.name, epsilon_const.name], + [int2.name], + ) + # Push downward the variance + int3 = tosa_fb.addIntermediate(running_var.shape, input_dtype) + tosa_fb.addOperator(TosaOp.Op().RSQRT, [int2.name], [int3.name]) + + # Reshape variable to match rank of activations + var_reshaped_res = tosa_fb.addIntermediate( + (1,) + + running_var.shape + + ( + 1, + 1, + ), + input_dtype, + ) + attr_var = ts.TosaSerializerAttribute() + attr_var.ReshapeAttribute( + (1,) + + running_var.shape + + ( + 1, + 1, + ) + ) + tosa_fb.addOperator( + TosaOp.Op().RESHAPE, + [int3.name], + [var_reshaped_res.name], + attr_var, + ) + + # Multiple shifted activations with reciprocal variance + # int4 = tosa_fb.addIntermediate( input_shape, input_dtype ) + tosa_fb.addOperator( + TosaOp.Op().MUL, + [int1.name, var_reshaped_res.name], + [outp.name], + attr_torch_to_tosa(TosaOp.Op().MUL, node), + ) + elif exir_ops.edge.aten.avg_pool2d.default == node.target: + input_tensor = inputs[0] + kernel_size_list = inputs[1].special + stride_size_list = inputs[2].special + try: + pad_size_list = inputs[3].special + except IndexError: + pad_size_list = [0, 0, 0, 0] + + attr = ts.TosaSerializerAttribute() + attr.PoolAttribute( + 
kernel=kernel_size_list, + stride=stride_size_list, + pad=pad_size_list, + input_zp=0, + output_zp=0, + accum_dtype=8, + ) # FP32 accum type + + # Torch's input is [N,C,H,W], TOSA is [N, H, W, C], + # Transpose to align with TOSA + NHWC_Order = [0, 2, 3, 1] + attr_input_transpose = ts.TosaSerializerAttribute() + attr_input_transpose.TransposeAttribute(NHWC_Order) + + transeposed_input_shape = [ + input_tensor.shape[i] for i in NHWC_Order + ] + input_transposed = tosa_fb.addIntermediate( + transeposed_input_shape, outp.dtype + ) + tosa_fb.addOperator( + TosaOp.Op().TRANSPOSE, + [input_tensor.name], + [input_transposed.name], + attr_input_transpose, + ) + + avg_pool2d_res_shape = [outp.shape[i] for i in NHWC_Order] + avg_pool2d_res = tosa_fb.addIntermediate( + avg_pool2d_res_shape, outp.dtype + ) + tosa_fb.addOperator( + TosaOp.Op().AVG_POOL2D, + [input_transposed.name], + [avg_pool2d_res.name], + attr, + ) + + # TOSA is [N, H, W, C], Transpose back to Torch's [N, C, H, W] + NCHW_Order = [0, 3, 1, 2] + attr_output_transpose = ts.TosaSerializerAttribute() + attr_output_transpose.TransposeAttribute(NCHW_Order) + tosa_fb.addOperator( + TosaOp.Op().TRANSPOSE, + [avg_pool2d_res.name], + [outp.name], + attr_output_transpose, + ) + elif operator.getitem == node.target: + item_name = inputs[0].name + ## Simply add an identityOp + tosa_fb.addOperator( + TosaOp.Op().IDENTITY, [item_name], [outp.name] + ) + else: + raise RuntimeError(f"Unknown operator {node.target}") + + continue + + elif node.op == "placeholder": + assert ( + node.name == node.target + ), "Expect placeholder name and target to match" + assert 0 == len(node.args), "Can't handle default input values" + + # TODO: this may fail on int64 constant input + inputs = [tosa_mapping.TosaArg(node)] + out = node.name + + if out in edge_program.graph_signature.inputs_to_parameters: + parameter_name = edge_program.graph_signature.inputs_to_parameters[ + node.name + ] + p_data = edge_program.state_dict[parameter_name] + + 
assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor" + weight_values = p_data.detach().numpy() + tosa_fb.addConst( + inputs[0].shape, inputs[0].dtype, weight_values, name=out + ) + elif out in edge_program.graph_signature.inputs_to_buffers: + parameter_name = edge_program.graph_signature.inputs_to_buffers[ + node.name + ] + p_data = edge_program.state_dict[parameter_name] + + assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor" + weight_values = p_data.detach().numpy() + tosa_fb.addConst( + inputs[0].shape, inputs[0].dtype, weight_values, name=out + ) + else: + # Input argument + tensor = ts.TosaSerializerTensor( + inputs[0].name, + inputs[0].shape, + inputs[0].dtype, + data=None, + placeholderFilename=inputs[0].name + ".npy", + ) + tosa_fb.addInputTensor(tensor) + continue + + elif node.op == "output": + for output in node.args[0]: + tosa_fb.addOutputTensor( + tosa_fb.currRegion.currBasicBlock.tensors[output.name] + ) + continue + + else: + # This will only happen if an unpartitioned graph is passed without + # any checking of compatibility. + dbg_fail(node, tosa_fb, path) + + if debug_output is True: + dbg_tosa_dump(tosa_fb, path) + + # Serialize and return the tosa flatbuffer + fb = tosa_fb.serialize() + return PreprocessResult(processed_bytes=bytes(fb)) diff --git a/backends/arm/test/test_models.py b/backends/arm/test/test_models.py new file mode 100644 index 0000000000..3f66a1d076 --- /dev/null +++ b/backends/arm/test/test_models.py @@ -0,0 +1,151 @@ +# Copyright 2023 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# +# Set of simple models for smoke testing TOSA conversion flow +# + +from enum import Enum + +import torch + +TestList = {} + + +def register_test(cls): + TestList[cls.__name__] = cls() + return cls + + +# Which TOSA profile to target with a model/inputs +# See https://www.mlplatform.org/tosa/tosa_spec.html#_profiles +class TosaProfile(Enum): + BI = 0 # Base Inference + MI = 1 # Main Inference + MT = 2 # Main Training + + +class TorchBuilder: + """The member functions build the PyTorch operators into small networks + for our tests""" + + def __init__(self): + pass + + @register_test + class simple_add(torch.nn.Module): + inputs = { + TosaProfile.BI: (torch.ones(5, dtype=torch.int32),), + TosaProfile.MI: (torch.ones(5),), + } + + def __init__(self): + super().__init__() + + def forward(self, x): + return x + x + + @register_test + class simple_add_broadcast(torch.nn.Module): + inputs = { + TosaProfile.BI: ( + torch.ones(10, 1, dtype=torch.int32), + torch.ones(10, 10, dtype=torch.int32), + ), + TosaProfile.MI: ( + torch.ones(10, 1), + torch.ones(10, 10), + ), + } + + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x + y + + @register_test + class simple_linear(torch.nn.Module): + inputs = { + # TODO: RuntimeError: mat1 and mat2 must have the same dtype, but got Int and Float + # TosaProfile.BI: ( torch.ones(128,20, dtype=torch.int32), ), + TosaProfile.MI: (torch.ones(128, 20),), + } + + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(20, 30) + self.relu6 = torch.nn.ReLU6() + + def forward(self, x): + x = self.fc(x) + x = self.relu6(x) + return x + x + + @register_test + class simple_conv2d(torch.nn.Module): + inputs = { + # TODO: fails input char, bias float + # TosaProfile.BI: ( torch.ones(1,3,256,256, dtype=torch.int8), ), + # TODO: this is segfaulting on model.forward in the nightly torch - disabling for now + # TosaProfile.MI: ( torch.ones(1,3,256,256), ), + } + + def __init__(self): + super().__init__() + 
self.conv2d = torch.nn.Conv2d( + in_channels=3, out_channels=10, kernel_size=3, stride=1 + ) + + def forward(self, x): + x = self.conv2d(x) + return x + + @register_test + class simple_div(torch.nn.Module): + inputs = { + # TODO: BUG: need to codegen for integer div, current float/recip one is not valid BI + # TosaProfile.BI: ( torch.ones(5, dtype=torch.int8), torch.ones(5, dtype=torch.int8), ), + TosaProfile.MI: ( + torch.ones(5), + torch.ones(5), + ), + } + + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.div(x, y) + + @register_test + class simple_batch_norm(torch.nn.Module): + inputs = { + # "RuntimeError: "batch_norm" not implemented for 'Char'" + # TosaProfile.BI: ( torch.ones(20,100,35,45, dtype=torch.int8), ), + TosaProfile.MI: (torch.ones(20, 100, 35, 45),), + } + + def __init__(self): + super().__init__() + self.batch_norm_2d = torch.nn.BatchNorm2d(100, affine=False) + self.eval() + + def forward(self, x): + return self.batch_norm_2d(x) + + @register_test + class simple_avg_pool2d(torch.nn.Module): + inputs = { + # TosaProfile.BI: ( torch.ones(20, 16, 50, 32, dtype=torch.int8), ), + TosaProfile.MI: (torch.ones(20, 16, 50, 32),), + } + + def __init__(self): + super().__init__() + self.avg_pool_2d = torch.nn.AvgPool2d(4, stride=2, padding=0) + + def forward(self, x): + return self.avg_pool_2d(x) diff --git a/backends/arm/test/test_tosa.py b/backends/arm/test/test_tosa.py new file mode 100644 index 0000000000..a04118bae4 --- /dev/null +++ b/backends/arm/test/test_tosa.py @@ -0,0 +1,62 @@ +# Copyright 2023 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# Test first-stage conversion to TOSA within the Arm backend. 
+# + +import unittest + +import executorch.exir as exir +from executorch.backends.arm.arm_backend import ArmPartitioner +from executorch.backends.arm.test.test_models import TestList, TosaProfile + +from executorch.exir.backend.backend_api import to_backend + +# Config for Capturing the weights, will be moved in the future +_CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True) +_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig() + + +class TestBasicNN(unittest.TestCase): + def test_minimal_MI(self): + for test_model in TestList: + print(f"Running test {test_model}") + model, inputs, outputs = prepare_model_and_ref(test_model, TosaProfile.MI) + if inputs is None: + print(" Skipping, no inputs for this profile") + continue + model_edge, exec_prog = export_model(model, inputs, []) + # TODO: check there is a tosa delegate blob in the output + + def test_minimal_BI(self): + for test_model in TestList: + print(f"Running test {test_model}") + model, inputs, outputs = prepare_model_and_ref(test_model, TosaProfile.BI) + if inputs is None: + print(" Skipping, no inputs for this profile") + continue + model_edge, exec_prog = export_model(model, inputs, []) + # TODO: check there is a tosa delegate blob in the output + + +def prepare_model_and_ref(test_model, profile=TosaProfile.MI): + model = TestList[test_model] + model_inputs = model.inputs.get(profile) + if model_inputs is not None: + model_outputs = model.forward(*model_inputs) + return model, model_inputs, model_outputs + return model, model_inputs, None + + +def export_model(model, inputs, compile_spec): + model_capture = exir.capture(model, inputs, _CAPTURE_CONFIG) + model_edge = model_capture.to_edge(_EDGE_COMPILE_CONFIG) + ArmPartitioner.compile_spec = compile_spec + model_edge.exported_program = to_backend( + model_edge.exported_program, ArmPartitioner + ) + exec_prog = model_edge.to_executorch() + return model_edge, exec_prog diff --git a/backends/arm/third-party/serialization_lib 
b/backends/arm/third-party/serialization_lib new file mode 160000 index 0000000000..9601cbda5f --- /dev/null +++ b/backends/arm/third-party/serialization_lib @@ -0,0 +1 @@ +Subproject commit 9601cbda5ff42dc4762e364d90093670931e1261 diff --git a/backends/arm/tosa_mapping.py b/backends/arm/tosa_mapping.py new file mode 100644 index 0000000000..26cb1bb581 --- /dev/null +++ b/backends/arm/tosa_mapping.py @@ -0,0 +1,102 @@ +# Copyright 2023 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# PyTorch to Tosa mapping - simple mapping functions and multi-type extraction +# of key information. These are used by the initial compile stage which captures +# the standardised TOSA representation. +# + +import serializer.tosa_serializer as ts +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from serializer.tosa_serializer import TosaOp + + +def map_dtype(data_type): + unsupported = ( + torch.float64, + torch.double, + torch.complex64, + torch.cfloat, + torch.complex128, + torch.cdouble, + torch.uint8, + torch.int64, + torch.long, + ) + + dmap = { + torch.float32: ts.DType.FP32, + torch.float: ts.DType.FP32, + torch.float16: ts.DType.FP16, + torch.half: ts.DType.FP16, + torch.bfloat16: ts.DType.BF16, + torch.int8: ts.DType.INT8, + torch.int16: ts.DType.INT16, + torch.short: ts.DType.INT16, + torch.int32: ts.DType.INT32, + torch.int: ts.DType.INT32, + torch.bool: ts.DType.BOOL, + } + + assert unsupported.count(data_type) == 0, "Unsupported type" + rtype = dmap.get(data_type) + assert rtype is not None, "Unknown type" + return rtype + + +# Returns the shape and type of a node +# TODO: other types, can be +# SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None +def extract_tensor_meta(thing): + if type(thing) is tuple: + # TODO: should use first concrete representation + thing = thing[0] + + assert 
torch._subclasses.fake_tensor.FakeTensor == type(thing) + + dtype = map_dtype(thing.dtype) + shape = tuple(thing.size()) + return (dtype, shape) + + +def op(op): + ops = {exir_ops.edge.aten.add.Tensor: TosaOp.Op().ADD} + return ops.get(op, None) + + +# Class to capture arguments and turn into tensor references for TOSA OPs +class TosaArg: + def process_node(self, argument): + assert isinstance(argument, torch.fx.node.Node) + assert argument.meta.get("val") is not None + self.name = argument.name + self.dtype, self.shape = extract_tensor_meta(argument.meta["val"]) + + def process_list(self, argument): + self.special = list(argument) + + def process_float(self, argument): + self.threshold = argument + + def __init__(self, argument) -> None: + self.name = None + self.dtype = None + self.shape = None + self.special = None + self.threshold = None + + if isinstance(argument, torch.fx.node.Node): + self.process_node(argument) + return + if issubclass(type(argument), list): + self.process_list(argument) + return + if isinstance(type(argument), type(float)): + self.process_float(argument) + return + + RuntimeError("Unhandled node input argument") diff --git a/examples/README.md b/examples/README.md index e54980d9c3..6bd71fdb39 100644 --- a/examples/README.md +++ b/examples/README.md @@ -15,6 +15,7 @@ examples ├── ios_demo_apps # Contains iOS demo apps ├── models # Contains a set of simple to PyTorch models ├── quantization # Contains examples of quantization workflow +├── arm # Contains examples of the Arm TOSA and Ethos-U NPU flows └── README.md # This file ``` diff --git a/examples/arm/arm_tosa_e2e.py b/examples/arm/arm_tosa_e2e.py new file mode 100644 index 0000000000..a91cf89f5b --- /dev/null +++ b/examples/arm/arm_tosa_e2e.py @@ -0,0 +1,159 @@ +# Copyright 2023 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
import json
import os
import subprocess
import tempfile

import numpy as np
from executorch.backends.arm.test.test_models import TestList, TosaProfile
from executorch.backends.arm.test.test_tosa import export_model, prepare_model_and_ref

from executorch.exir.backend.compile_spec_schema import CompileSpec

# Assumes you have these two tools on your path
TOSA_REF_MODEL_PATH = "tosa_reference_model"
VELA_COMPILER_PATH = "vela"

# Temp directory that any debug output is written to
DEBUG_OUTPUT_PATH = tempfile.mkdtemp(prefix="arm_tosa_")


def tosa_ref_dump_inputs(model_edge, inputs, path):
    """Save the user-supplied model inputs as ``<placeholder name>.npy`` files.

    Emits TOSA test data from the model inputs in the layout the TOSA
    reference model expects. Assumes the whole graph was lowered, so the
    leading placeholders are the only graph inputs.

    - Placeholders that the graph signature maps to parameters or buffers are
      skipped (they are already captured as constants).
    - Assumes argument order is fixed: the remaining placeholder names are
      paired positionally with ``inputs``.

    Args:
        model_edge: Exported edge program; must expose
            ``exported_program.graph.nodes`` and
            ``exported_program.graph_signature``.
        inputs: Iterable of tensors supporting ``.detach().numpy()``.
        path: Directory the ``.npy`` files are written to.
    """
    # The signature does not change while walking the graph; look it up once.
    graph_signature = model_edge.exported_program.graph_signature

    argument_names = []
    for node in model_edge.exported_program.graph.nodes:
        # Only the leading run of placeholder nodes is inspected.
        if node.op != "placeholder":
            break
        print("got placholder", node.target)
        is_captured_constant = (
            node.name in graph_signature.inputs_to_parameters
            or node.name in graph_signature.inputs_to_buffers
        )
        if not is_captured_constant:
            argument_names.append(node.name)

    for name, tensor in zip(argument_names, inputs):
        file_path = os.path.join(path, name + ".npy")
        np.save(file_path, tensor.detach().numpy(), allow_pickle=False)


def tosa_run_test(op, profile=TosaProfile.MI):  # noqa: C901
    """Lower one test model to TOSA and compare it against the torch result.

    Minimal sequence to take a model through the TOSA partitioner and emit,
    under ``DEBUG_OUTPUT_PATH/<op>/``:

    - ``tosa/``: debug directory containing the TOSA flatbuffer, its JSON
      rendering (via ``flatc``), the dumped inputs and the reference-model
      outputs. Only the last output listed in ``desc.json`` is compared.
    - ``torch/``: the torch ground-truth output and ``delegated.pte``, the
      executorch binary containing the delegated flatbuffer.

    When ``profile`` is ``TosaProfile.BI`` the flatbuffer is additionally
    compiled with Vela; a Vela failure is reported but does not abort the run.

    Args:
        op: Name of the test model; also used as a directory name.
        profile: TOSA profile passed to ``prepare_model_and_ref``.

    Raises:
        subprocess.CalledProcessError: ``flatc`` or the TOSA reference model
            exited with a non-zero status.
        FileNotFoundError: ``flatc`` or the TOSA reference model is not on
            the path.
        RuntimeError: ``desc.json`` lists no output files to compare.
    """
    print(f"\n\033[96mProcessing:::{op}\033[0m")
    print(f"\033[96m Debug output path for intermediates: {DEBUG_OUTPUT_PATH}\033[0m")

    # Debug output for TORCH (the trailing "" keeps a trailing path separator)
    TORCH_OUT_PATH = os.path.join(DEBUG_OUTPUT_PATH, op, "torch", "")
    os.makedirs(TORCH_OUT_PATH, exist_ok=True)

    # Debug output for TOSA
    TOSA_OUT_PATH = os.path.join(DEBUG_OUTPUT_PATH, op, "tosa", "")
    os.makedirs(TOSA_OUT_PATH, exist_ok=True)

    # Debug flag for compilers
    compile_spec = [CompileSpec("debug_tosa_path", bytes(TOSA_OUT_PATH, "utf8"))]

    model, inputs, torch_output = prepare_model_and_ref(op, profile)

    if inputs is None:
        print("\033[96m Skipping, no inputs for TOSA profile \033[0m")
        return

    captured_model, exec_prog = export_model(model, inputs, compile_spec)

    # Save ground truth results to file
    torch_output_path = os.path.join(TORCH_OUT_PATH, "torch_output.npy")
    with open(torch_output_path, "wb") as f:
        np.save(f, torch_output.detach().numpy())

    tosa_ref_dump_inputs(captured_model, inputs, TOSA_OUT_PATH)

    print(TORCH_OUT_PATH, TOSA_OUT_PATH)

    # this is the .pte binary file
    with open(os.path.join(TORCH_OUT_PATH, "delegated.pte"), "wb") as fh:
        fh.write(exec_prog.buffer)

    # Convert TOSA Flatbuffer into JSON format for human debugging.
    # Arguments are passed as a list (no shell) so paths are never re-parsed.
    subprocess.run(
        [
            "flatc",
            "-o",
            TOSA_OUT_PATH,
            "--raw-binary",
            "-t",
            "./backends/arm/third-party/serialization_lib/schema/tosa.fbs",
            "--",
            os.path.join(TOSA_OUT_PATH, "output.tosa"),
        ],
        check=True,
    )

    # Run the TOSA flatbuffer through TOSA Ref_Model
    desc_file_path = os.path.join(TOSA_OUT_PATH, "desc.json")
    subprocess.run([TOSA_REF_MODEL_PATH, "--test_desc", desc_file_path], check=True)

    # Load in the JSON file and read the TOSA output(s); the last one wins.
    with open(desc_file_path) as desc_file:
        desc_json = json.load(desc_file)

    tosa_output = None
    for tosa_out_fm_file_name in desc_json["ofm_file"]:
        with open(os.path.join(TOSA_OUT_PATH, tosa_out_fm_file_name), "rb") as f:
            tosa_output = np.load(f)
    if tosa_output is None:
        raise RuntimeError(
            "No TOSA output files listed under 'ofm_file' in " + desc_file_path
        )

    # Read the Torch output back from disk
    with open(torch_output_path, "rb") as torch_file:
        torch_output = np.load(torch_file)

    # Compare Tosa and Torch results (relative tolerance of 0.1)
    if np.allclose(tosa_output, torch_output, rtol=1e-1, equal_nan=True):
        print(
            "\033[92m"
            + "Torch and Tosa Reference results are matching for operator: "
            + op
            + "\033[0m"
        )
    else:
        print("\033[91m" + "Sorry, Torch and Tosa Reference Results Do not Match!")
        print("============================")
        print("TOSA Output Shape is: " + str(tosa_output.shape))
        print("TOSA Output is: ")
        print(tosa_output)
        print("\033[93m")
        print("============================")
        print("Torch Output Shape is: " + str(torch_output.shape))
        print("Torch Output is: ")
        print(torch_output)
        print("\033[0m")

    if profile == TosaProfile.BI:
        try:
            # Run Vela from inside the TOSA directory (replaces "cd ...; vela").
            subprocess.run(
                [VELA_COMPILER_PATH, "./output.tosa"],
                cwd=TOSA_OUT_PATH,
                check=True,
            )
        except (subprocess.CalledProcessError, OSError):
            # Best effort: report the failure (non-zero exit or tool/directory
            # missing) and carry on, without hiding KeyboardInterrupt.
            print("\033[91m" + "Vela compile failed for: " + op + "\033[0m")
        else:
            print("\033[92m" + "Vela compile worked for: " + op + "\033[0m")
    else:
        print("\033[96m" + "Skipping Vela test on non-BI profile." + "\033[0m")


# Temp systest mode for running all models against both inference profiles
if __name__ == "__main__":
    for test_profile in (TosaProfile.MI, TosaProfile.BI):
        for op in TestList:
            tosa_run_test(op, profile=test_profile)

# NOTE(review): the line below is not part of this script. It is the remainder
# of the enclosing patch (pyproject.toml / pytest.ini hunks) that shared these
# source lines, preserved verbatim as a comment so that no content is lost.
# diff --git a/pyproject.toml b/pyproject.toml index 9617004c8b..a93312cd1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ version = "0.1.0" # Python dependencies required for development dependencies=[ "expecttest", + "flatbuffers", "hypothesis", "numpy", "packaging", diff --git a/pytest.ini b/pytest.ini index 664fa249aa..df818eba0a 100644 --- a/pytest.ini +++ b/pytest.ini @@ -34,8 +34,8 @@ addopts = # kernels/ kernels/prim_ops/test/prim_ops_test.py kernels/test/test_case_gen.py - # backends/tosa - backends/tosa/test + # backends/arm + backends/arm/test # run the same tests multiple times to determine their # flakiness status.
Default to 50 re-runs diff --git a/setup.py b/setup.py index 8eb0e9ab65..ece11e955c 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,8 @@ def run(self): "executorch/schema": "schema", "executorch/extension": "extension", "executorch/bundled_program": "bundled_program", + "tosa": "backends/arm/third-party/serialization_lib/python/tosa", + "serializer": "backends/arm/third-party/serialization_lib/python/serializer", }, cmdclass={ "install": CustomInstallCommand,