From 09207a027927dabcca01904d8b0ae412cece7cc6 Mon Sep 17 00:00:00 2001
From: xinetzone
Date: Wed, 2 Aug 2023 13:47:09 +0800
Subject: [PATCH] Update
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 doc/tutorials/basic/start/te-mm.ipynb              |   4 +-
 .../start/tensor-program-abstraction.ipynb         |   2 +-
 doc/tutorials/datasets/imagenet.ipynb              |   2 +-
 .../auto-quantize/AutomatedQuantization.ipynb      |   2 +-
 doc/tutorials/quantize/auto-quantize/index.md      |   1 +
 .../quantize/auto-quantize/intro.ipynb             |   2 +-
 .../quantize/canonicalizations.ipynb               |   2 +-
 .../fake-quantization-to-integer.ipynb             |   2 +-
 doc/tutorials/quantize/parse.ipynb                 | 416 ++++++++++++++++++
 doc/tutorials/quantize/resnet18.ipynb              |   2 +-
 .../relay/frontend/from-tensorflow.ipynb           |   2 +-
 .../frontend/from-tf_slim/tf2-keras.ipynb          |   2 +-
 .../relay/frontend/from-tf_slim/tf2.ipynb          |   2 +-
 doc/tutorials/relay/frontend/pb2onnx.ipynb         |   2 +-
 .../image/{process.py => processing.py}            |   0
 src/tvm_book/tvm_utils/split_graph.py              | 146 ++++++
 16 files changed, 576 insertions(+), 13 deletions(-)
 create mode 100644 doc/tutorials/quantize/parse.ipynb
 rename src/tvm_book/image/{process.py => processing.py} (100%)
 create mode 100644 src/tvm_book/tvm_utils/split_graph.py

diff --git a/doc/tutorials/basic/start/te-mm.ipynb b/doc/tutorials/basic/start/te-mm.ipynb
index af0974ae..361a070c 100644
--- a/doc/tutorials/basic/start/te-mm.ipynb
+++ b/doc/tutorials/basic/start/te-mm.ipynb
@@ -4,7 +4,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# 2. Matrix multiplication with TE (Tensor Expression)\n",
+    "# Matrix multiplication with TE (Tensor Expression)\n",
     "\n",
     "## Implementing the original program with TE\n",
     "\n",
@@ -282,7 +282,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.11"
+   "version": "3.10.12"
   },
   "orig_nbformat": 4
 },
diff --git a/doc/tutorials/basic/start/tensor-program-abstraction.ipynb b/doc/tutorials/basic/start/tensor-program-abstraction.ipynb
index e82751f0..d3527bdb 100644
--- a/doc/tutorials/basic/start/tensor-program-abstraction.ipynb
+++ b/doc/tutorials/basic/start/tensor-program-abstraction.ipynb
@@ -641,7 +641,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.11"
+   "version": "3.10.12"
   },
   "orig_nbformat": 4
 },
diff --git a/doc/tutorials/datasets/imagenet.ipynb b/doc/tutorials/datasets/imagenet.ipynb
index eb7c0ec8..39c8a391 100644
--- a/doc/tutorials/datasets/imagenet.ipynb
+++ b/doc/tutorials/datasets/imagenet.ipynb
@@ -5,7 +5,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# ImageNet under TensorFlow\n",
+    "# The ImageNet interface\n",
     "\n",
     "## ImageNet label information\n",
     "\n",
diff --git a/doc/tutorials/quantize/auto-quantize/AutomatedQuantization.ipynb b/doc/tutorials/quantize/auto-quantize/AutomatedQuantization.ipynb
index ce4cd309..038401a7 100644
--- a/doc/tutorials/quantize/auto-quantize/AutomatedQuantization.ipynb
+++ b/doc/tutorials/quantize/auto-quantize/AutomatedQuantization.ipynb
@@ -603,7 +603,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.11"
+   "version": "3.10.12"
   },
   "orig_nbformat": 4
 },
diff --git a/doc/tutorials/quantize/auto-quantize/index.md b/doc/tutorials/quantize/auto-quantize/index.md
index 6885ec02..8de092ee 100644
--- a/doc/tutorials/quantize/auto-quantize/index.md
+++ b/doc/tutorials/quantize/auto-quantize/index.md
@@ -3,4 +3,5 @@
 ```{toctree}
 intro
 AutomatedQuantization
+parse
 ```
\ No newline at end of file
diff --git a/doc/tutorials/quantize/auto-quantize/intro.ipynb b/doc/tutorials/quantize/auto-quantize/intro.ipynb
index 222f00e0..46c90b9c 100644
--- a/doc/tutorials/quantize/auto-quantize/intro.ipynb
+++ b/doc/tutorials/quantize/auto-quantize/intro.ipynb
@@ -840,7 +840,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.11"
+   "version": "3.10.12"
   },
   "orig_nbformat": 4
 },
diff --git a/doc/tutorials/quantize/canonicalizations.ipynb b/doc/tutorials/quantize/canonicalizations.ipynb
index 8c12ddc4..8ca903b0 100644
--- a/doc/tutorials/quantize/canonicalizations.ipynb
+++ b/doc/tutorials/quantize/canonicalizations.ipynb
@@ -237,7 +237,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.11"
+   "version": "3.10.12"
   },
   "orig_nbformat": 4
 },
diff --git a/doc/tutorials/quantize/fake-quantization-to-integer.ipynb b/doc/tutorials/quantize/fake-quantization-to-integer.ipynb
index 7c6eb166..94f93718 100644
--- a/doc/tutorials/quantize/fake-quantization-to-integer.ipynb
+++ b/doc/tutorials/quantize/fake-quantization-to-integer.ipynb
@@ -1211,7 +1211,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.11"
+   "version": "3.10.12"
   },
   "orig_nbformat": 4
 },
diff --git a/doc/tutorials/quantize/parse.ipynb b/doc/tutorials/quantize/parse.ipynb
new file mode 100644
index 00000000..5d2f56d8
--- /dev/null
+++ b/doc/tutorials/quantize/parse.ipynb
@@ -0,0 +1,416 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Dissecting TVM's automatic quantization\n",
+    "\n",
+    "This notebook dissects TVM's automatic quantization pipeline, using the PyTorch resnet18 model as an example.\n",
+    "\n",
+    "## Translating the PyTorch model to a Relay module\n",
+    "\n",
+    "Load the PyTorch model:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torchvision.models import resnet18, ResNet18_Weights\n",
+    "import torch\n",
+    "\n",
+    "def load_model(input_shape):\n",
+    "    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).eval()\n",
+    "    data = torch.randn(*input_shape)\n",
+    "    model = torch.jit.trace(model, data)\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Translate the traced PyTorch model to a Relay module:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from tvm import relay\n",
+    "\n",
+    "\n",
+    "input_shape = 1, 3, 224, 224\n",
+    "input_name = \"data\"\n",
+    "traced_model = load_model(input_shape)\n",
+    "mod, params = relay.frontend.from_pytorch(\n",
+    "    traced_model, \n",
+    "    [(input_name, input_shape)], \n",
+    "    # use_parser_friendly_name=True\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Loading the data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import tvm\n",
+    "import tvm.testing\n",
+    "from tvm import relay\n",
+    "from tvm.relay import transform, build_module\n",
+    "from tvm.relay.testing import run_opt_pass"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dataclasses import dataclass\n",
+    "from pathlib import Path\n",
+    "from PIL import Image\n",
+    "import numpy as np\n",
+    "from tvm_book.data.classification import ImageFolderDataset\n",
+    "\n",
+    "def preprocess_image(\n",
+    "    image: np.ndarray,\n",
+    "    size: tuple[int, int] = (224, 224),\n",
+    "    mean: tuple[float, ...] = (0.485, 0.456, 0.406),\n",
+    "    std: tuple[float, ...] = (0.229, 0.224, 0.225)\n",
+    "):\n",
+    "    im = Image.fromarray(image)\n",
+    "    im = im.resize((256, 256), Image.Resampling.BILINEAR)\n",
+    "    ori_W, ori_H = im.size  # PIL reports (width, height)\n",
+    "    H, W = size\n",
+    "    space_W, space_H = (ori_W - W)//2, (ori_H - H)//2\n",
+    "    im = im.crop((space_W, space_H, ori_W-space_W, ori_H-space_H))  # center crop\n",
+    "    image = np.array(im)\n",
+    "    im.close()\n",
+    "    image = image/256\n",
+    "    image -= mean\n",
+    "    image /= std\n",
+    "    return image.astype(np.float32)\n",
+    "\n",
+    "\n",
+    "@dataclass\n",
+    "class ImageNet:\n",
+    "    root: str\n",
+    "    size: tuple[int, int] = (224, 224)\n",
+    "    mean: tuple[float, ...] = (0.485, 0.456, 0.406)\n",
+    "    std: tuple[float, ...] = (0.229, 0.224, 0.225)\n",
+    "\n",
+    "    def __post_init__(self):\n",
+    "        self.root = Path(self.root)  # dataset root directory\n",
+    "        self.valset = ImageFolderDataset(f\"{self.root}/val\")\n",
+    "        self.trainset = ImageFolderDataset(f\"{self.root}/train\")\n",
+    "\n",
+    "    def calibrateset(self, calibrate_num: int = 200):\n",
+    "        \"\"\"Calibration dataset for TVM quantization.\"\"\"\n",
+    "        for k, (data, label) in enumerate(self.trainset):\n",
+    "            if k >= calibrate_num:\n",
+    "                break\n",
+    "            image = preprocess_image(data, self.size, self.mean, self.std)\n",
+    "            images = np.expand_dims(image, 0)\n",
+    "            images = images.transpose((0, 3, 1, 2))\n",
+    "            yield {\"data\": images}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = ImageNet(\"/media/pc/data/lxw/home/data/datasets/ILSVRC/\")"
+   ]
+  },
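+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before handing the generator to the calibration pass, it is worth sanity-checking one batch. The next cell is a small sketch added for illustration (it assumes the `ILSVRC` layout above is available); it only verifies the NCHW shape and dtype that the TVM calibration pass will receive."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative check (not part of the original pipeline):\n",
+    "# each calibration batch should be a dict of NCHW float32 arrays.\n",
+    "batch = next(iter(dataset.calibrateset(calibrate_num=1)))\n",
+    "assert batch[\"data\"].shape == (1, 3, 224, 224)\n",
+    "assert batch[\"data\"].dtype == np.float32"
+   ]
+  },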
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Graph splitting"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import tvm\n",
+    "# Reference: https://tvm.apache.org/docs/how_to/work_with_relay/using_pipeline_executor.html?highlight=graph_split\n",
+    "from tvm_book.tvm_utils.split_graph import graph_split"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "split_conf = [{\"op_name\": \"nn.max_pool2d\", \"op_index\": 0}]\n",
+    "pipeline_mods = graph_split(mod[\"main\"], split_conf, params)\n",
+    "run_mod = pipeline_mods[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/script/highlight.py:117: UserWarning: No module named 'black'\n",
+      "To print formatted TVM script, please install the formatter 'Black':\n",
+      "/media/pc/data/tmp/cache/conda/envs/tvmz/bin/python -m pip install \"black==22.3.0\" --upgrade --user\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<pre style=\"white-space: pre-wrap;\">def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) {\n",
+       "  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
+       "  %1 = nn.batch_norm(%0, meta[relay.Constant][1] /* ty=Tensor[(64), float32] */, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */, meta[relay.Constant][4] /* ty=Tensor[(64), float32] */) /* ty=(Tensor[(1, 64, 112, 112), float32], Tensor[(64), float32], Tensor[(64), float32]) */;\n",
+       "  %2 = %1.0 /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
+       "  %3 = nn.relu(%2) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n",
+       "  nn.max_pool2d(%3, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "run_mod.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 量化模型" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# 简化\n", + "run_mod = relay.quantize.prerequisite_optimize(run_mod, params)\n", + "# 分区\n", + "partition_mod = relay.quantize.partition()(run_mod)\n", + "# 校准\n", + "calibrate_pass = tvm.transform.module_pass(\n", + " relay.quantize.calibrate(dataset.calibrateset(calibrate_num=200)), opt_level=1, name=\"QuantizeCalibrate\"\n", + ")\n", + "calibrate_mod = tvm.transform.Sequential([relay.quantize.annotate(), calibrate_pass,])(partition_mod)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "fn (%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {\n", + " %0 = relay.op.annotation.simulated_quantize(%data, 0.0625f, -127f, 127f, kind=1) /* ty=Tensor[(1, 3, 224, 224), float32] */;\n", + " %1 = relay.op.annotation.simulated_quantize(meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, 0.015625f, -127f, 127f, kind=2) /* ty=Tensor[(64, 3, 7, 7), float32] */;\n", + " %2 = nn.conv2d(%0, %1, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n", + " %3 = relay.op.annotation.simulated_quantize(%2, 0.0625f, -127f, 127f, kind=1) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n", + " %4 = relay.op.annotation.simulated_quantize(meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), float32] */, 0.00390625f, -127f, 127f, kind=2) /* ty=Tensor[(64, 1, 1), float32] */;\n", + " %5 = multiply(%3, %4) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n", + " %6 = relay.op.annotation.simulated_quantize(meta[relay.Constant][2] /* ty=Tensor[(64, 1, 1), float32] */, 0.0078125f, -127f, 127f, kind=2) /* ty=Tensor[(64, 1, 1), float32] */;\n", + " %7 = add(%5, %6) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n", + " %8 = nn.relu(%7) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n", + " %9 = relay.op.annotation.simulated_quantize(%8, 0.0625f, -127f, 127f, kind=1) /* ty=Tensor[(1, 64, 112, 112), float32] */;\n", + " %10 = nn.max_pool2d(%9, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float32] */;\n", + " %11 = annotation.cast_hint(%10, dtype=\"int8\") /* ty=Tensor[(1, 64, 56, 56), float32] */;\n", + " annotation.stop_fusion(%11) /* ty=Tensor[(1, 64, 56, 56), float32] */\n", + "}\n", + "\n" + ] + } + ], + "source": [ + "print(calibrate_mod[\"main\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# from tvm.ir import IRModule, structural_equal\n", + "import tvm\n", + "\n", + "with tvm.transform.PassContext(opt_level=3):\n", + " with relay.quantize.qconfig(\n", + " skip_conv_layers=[],\n", + " calibrate_mode=\"kl_divergence\", \n", + " weight_scale=\"max\",\n", + " round_for_shift=True,\n", + " # rounding=\"TONEAREST\", # \"UPWARD\" or \"TONEAREST\"\n", + " calibrate_skip_layers=[],\n", + " skip_dense_layer=False,\n", + " ):\n", + " qmod = relay.quantize.quantize(run_mod, params, dataset.calibrateset(calibrate_num=200))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + 
"name": "stdout", + "output_type": "stream", + "text": [ + "def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 64, 56, 56), float32] {\n", + " %0 = multiply(%data, 48.8058f /* ty=float32 */) /* ty=Tensor[(1, 3, 224, 224), float32] */;\n", + " %1 = round(%0) /* ty=Tensor[(1, 3, 224, 224), float32] */;\n", + " %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 224, 224), float32] */;\n", + " %3 = cast(%2, dtype=\"int8\") /* ty=Tensor[(1, 3, 224, 224), int8] */;\n", + " %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), int8] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7], out_dtype=\"int32\") /* ty=Tensor[(1, 64, 112, 112), int32] */;\n", + " %5 = add(%4, meta[relay.Constant][1] /* ty=Tensor[(64, 1, 1), int32] */) /* ty=Tensor[(1, 64, 112, 112), int32] */;\n", + " %6 = nn.relu(%5) /* ty=Tensor[(1, 64, 112, 112), int32] */;\n", + " %7 = cast(%6, dtype=\"int64\") /* ty=Tensor[(1, 64, 112, 112), int64] */;\n", + " %8 = fixed_point_multiply(%7, multiplier=1562234368, shift=-8) /* ty=Tensor[(1, 64, 112, 112), int64] */;\n", + " %9 = clip(%8, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 112, 112), int64] */;\n", + " %10 = cast(%9, dtype=\"int32\") /* ty=Tensor[(1, 64, 112, 112), int32] */;\n", + " %11 = cast(%10, dtype=\"int8\") /* ty=Tensor[(1, 64, 112, 112), int8] */;\n", + " %12 = nn.max_pool2d(%11, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), int8] */;\n", + " %13 = cast(%12, dtype=\"int8\") /* ty=Tensor[(1, 64, 56, 56), int8] */;\n", + " %14 = annotation.stop_fusion(%13) /* ty=Tensor[(1, 64, 56, 56), int8] */;\n", + " %15 = cast(%14, dtype=\"float32\") /* ty=Tensor[(1, 64, 56, 56), float32] */;\n", + " multiply(%15, 0.0221871f /* ty=float32 */) /* ty=Tensor[(1, 64, 56, 56), float32] */\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(qmod)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:autotvm:One or more operators have not been tuned. Please tune your model for better performance. 
+      "0it [00:00, ?it/s]\n"
+     ]
+    },
+    {
+     "ename": "TypeError",
+     "evalue": "exceptions must derive from BaseException",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[17], line 25\u001b[0m\n\u001b[1;32m     23\u001b[0m label \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39marray([label])\n\u001b[1;32m     24\u001b[0m \u001b[39m# 精度度量\u001b[39;00m\n\u001b[0;32m---> 25\u001b[0m metric_top1\u001b[39m.\u001b[39;49mupdate(preds \u001b[39m=\u001b[39;49m output, labels \u001b[39m=\u001b[39;49m label)\n\u001b[1;32m     26\u001b[0m metric_top5\u001b[39m.\u001b[39mupdate(preds \u001b[39m=\u001b[39m output, labels \u001b[39m=\u001b[39m label)\n\u001b[1;32m     27\u001b[0m qmetric_top1\u001b[39m.\u001b[39mupdate(preds \u001b[39m=\u001b[39m quant_output, labels \u001b[39m=\u001b[39m label)\n",
+      "File \u001b[0;32m/media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/tvm_book/metric/classification.py:122\u001b[0m, in \u001b[0;36mAccuracy.update\u001b[0;34m(self, labels, preds)\u001b[0m\n\u001b[1;32m    120\u001b[0m     pred_labels \u001b[39m=\u001b[39m preds\n\u001b[1;32m    121\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 122\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m暂时未支持预测标签维度为 \u001b[39m\u001b[39m{\u001b[39;00mpreds\u001b[39m.\u001b[39mndim\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m    123\u001b[0m pred_labels \u001b[39m=\u001b[39m pred_labels\u001b[39m.\u001b[39mastype(\u001b[39m'\u001b[39m\u001b[39mint32\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[1;32m    124\u001b[0m labels \u001b[39m=\u001b[39m labels\u001b[39m.\u001b[39mastype(\u001b[39m'\u001b[39m\u001b[39mint32\u001b[39m\u001b[39m'\u001b[39m)\n",
+      "\u001b[0;31mTypeError\u001b[0m: exceptions must derive from BaseException"
+     ]
+    }
+   ],
+   "source": [
+    "from tqdm import tqdm\n",
+    "from tvm.runtime.vm import VirtualMachine\n",
+    "from tvm_book.metric.classification import Accuracy, TopKAccuracy\n",
+    "\n",
+    "with tvm.transform.PassContext(opt_level=3):\n",
+    "    vm_exec = relay.vm.compile(run_mod, target=\"llvm\", params=params)\n",
+    "vm = VirtualMachine(vm_exec, tvm.cpu())\n",
+    "with tvm.transform.PassContext(opt_level=3):\n",
+    "    qvm_exec = relay.vm.compile(qmod, target=\"llvm\", params=params)\n",
+    "qvm = VirtualMachine(qvm_exec, tvm.cpu())\n",
+    "\n",
+    "metric_top1 = Accuracy(\"float\")\n",
+    "metric_top5 = TopKAccuracy(top_k=5)\n",
+    "qmetric_top1 = Accuracy(\"quantized\")\n",
+    "qmetric_top5 = TopKAccuracy(top_k=5)\n",
+    "for k, (data, label) in tqdm(enumerate(dataset.valset)):\n",
+    "    image = preprocess_image(data, dataset.size, dataset.mean, dataset.std)\n",
+    "    images = np.expand_dims(image, 0)\n",
+    "    images = images.transpose((0, 3, 1, 2))\n",
+    "    input_dict = {\"data\": images}\n",
+    "    output = vm.run(**input_dict).asnumpy()\n",
+    "    quant_output = qvm.run(**input_dict).asnumpy()\n",
+    "    label = np.array([label])\n",
+    "    # Accuracy metrics\n",
+    "    metric_top1.update(preds=output, labels=label)\n",
+    "    metric_top5.update(preds=output, labels=label)\n",
+    "    qmetric_top1.update(preds=quant_output, labels=label)\n",
+    "    qmetric_top5.update(preds=quant_output, labels=label)\n",
+    "    if k % 1000 == 0:\n",
+    "        print(f\"float: {metric_top1.get(), metric_top5.get()}||quantized: {qmetric_top1.get(), qmetric_top5.get()}\")\n",
+    "    # break"
+   ]
+  },
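+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note: the evaluation loop above currently aborts with `TypeError: exceptions must derive from BaseException`. The traceback shows why: `Accuracy.update` in `tvm_book/metric/classification.py` executes `raise f\"...\"` on an unexpected prediction shape, i.e. it raises a plain string. Raising a proper exception there (e.g. `raise ValueError(f\"...\")`) would surface the real shape problem instead of this secondary `TypeError`."
+   ]
+  }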
"code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tvmz", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/doc/tutorials/quantize/resnet18.ipynb b/doc/tutorials/quantize/resnet18.ipynb index 405c91f7..d59b2f09 100644 --- a/doc/tutorials/quantize/resnet18.ipynb +++ b/doc/tutorials/quantize/resnet18.ipynb @@ -763,7 +763,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" }, "orig_nbformat": 4 }, diff --git a/doc/tutorials/relay/frontend/from-tensorflow.ipynb b/doc/tutorials/relay/frontend/from-tensorflow.ipynb index 3fed5ee1..61939a2d 100644 --- a/doc/tutorials/relay/frontend/from-tensorflow.ipynb +++ b/doc/tutorials/relay/frontend/from-tensorflow.ipynb @@ -175,7 +175,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/doc/tutorials/relay/frontend/from-tf_slim/tf2-keras.ipynb b/doc/tutorials/relay/frontend/from-tf_slim/tf2-keras.ipynb index ad421661..86aa89b8 100644 --- a/doc/tutorials/relay/frontend/from-tf_slim/tf2-keras.ipynb +++ b/doc/tutorials/relay/frontend/from-tf_slim/tf2-keras.ipynb @@ -655,7 +655,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" }, "orig_nbformat": 4 }, diff --git a/doc/tutorials/relay/frontend/from-tf_slim/tf2.ipynb b/doc/tutorials/relay/frontend/from-tf_slim/tf2.ipynb index 857b6e36..6be5b22b 100644 --- a/doc/tutorials/relay/frontend/from-tf_slim/tf2.ipynb +++ b/doc/tutorials/relay/frontend/from-tf_slim/tf2.ipynb @@ -312,7 +312,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" }, "orig_nbformat": 4 }, diff --git a/doc/tutorials/relay/frontend/pb2onnx.ipynb b/doc/tutorials/relay/frontend/pb2onnx.ipynb index 1094554b..6ea4f8be 100644 --- a/doc/tutorials/relay/frontend/pb2onnx.ipynb +++ b/doc/tutorials/relay/frontend/pb2onnx.ipynb @@ -214,7 +214,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" }, "orig_nbformat": 4 }, diff --git a/src/tvm_book/image/process.py b/src/tvm_book/image/processing.py similarity index 100% rename from src/tvm_book/image/process.py rename to src/tvm_book/image/processing.py diff --git a/src/tvm_book/tvm_utils/split_graph.py b/src/tvm_book/tvm_utils/split_graph.py new file mode 100644 index 00000000..70034e26 --- /dev/null +++ b/src/tvm_book/tvm_utils/split_graph.py @@ -0,0 +1,146 @@ +import tvm +from tvm import relay +from tvm.relay import transform, build_module +from tvm.relay.testing import run_opt_pass +# from tvm.contrib import graph_executor +# from tvm._ffi import get_global_func +# from tvm.contrib import cc as _cc + + +def graph_split(expr, split_conf, params=None): + """Splitting the graph into a list of subgraphs + + e.g.:split_conf = [{"op_name": "nn.relu", "op_index": 0}] + """ + + def get_dep_var(sub_var_dep): + return [var for var in sub_var_dep[len(sub_var_dep) - 
1]["ref_nodes"]] + + def parse_dependency(value, snode_dep, new_input_idx): + new_args = [] + need_update = False + for var in value.args: + is_free_var = False + for dep in snode_dep[:-1]: + if var in dep["nodes"]: + # Mark the previous subgraph node as a dependency. + dep["nodes"][var] += 1 + dep["ref_nodes"][var] = dep["nodes"][var] + # The var of this call is a free_var + is_free_var = True + # if the var of this call is a free_var, recreate it and give it a fixed input name. + if is_free_var: + need_update = True + new_args.append(relay.var(f"data_n_{new_input_idx}", var.checked_type)) + new_input_idx += 1 + else: + new_args.append(var) + # if the 'tvm.relay.expr.Call' has a free_var, recreate it with new name as 'data_n_*'. + if need_update: + value = tvm.relay.expr.Call( + value.op, new_args, value.attrs, value.type_args, value.span + ) + return value, snode_dep, new_input_idx + + def merge_constant_expr(constant_expr, expr): + # merge constant express with a express + if not isinstance(constant_expr.body, tvm.relay.expr.Let): + return tvm.relay.expr.Let(constant_expr.var, constant_expr.value, expr) + + return tvm.relay.expr.Let( + constant_expr.var, constant_expr.value, merge_constant_expr(constant_expr.body, expr) + ) + + def _recursion(anf, pipeline_mods, split_conf, constant_expr): + """列举计算图中的所有算子,然后将计算图分成一组子图。""" + nonlocal operator_index_map + nonlocal new_input_idx + nonlocal snode_dep + cur_node_dep = snode_dep[len(snode_dep) - 1] + if isinstance(anf, tvm.relay.Function): + return tvm.relay.Function( + anf.params, + _recursion(anf.body, pipeline_mods, split_conf, constant_expr), + anf.ret_type, + anf.type_params, + anf.attrs, + ) + elif isinstance(anf, tvm.relay.expr.Let): + value = anf.value + # 记录常量表达式,以确保所有子图都能找到正确的常量。 + if isinstance(value, tvm.relay.expr.Constant): + if not constant_expr: + constant_expr = tvm.relay.expr.Let(anf.var, value, anf.var) + else: + constant_expr = tvm.relay.expr.Let(anf.var, value, constant_expr) + if isinstance(value, tvm.relay.expr.Call): + new_args = [] + # 构建当前变量列表 + cur_node_dep["nodes"][anf.var] = 0 + # 获得节点的依赖信息。 + value, snode_dep, new_input_idx = parse_dependency(value, snode_dep, new_input_idx) + if isinstance(value.op, tvm.ir.Op): + if value.op.name in operator_index_map: + operator_index_map[value.op.name] += 1 + else: + operator_index_map[value.op.name] = 0 + split_operator_name = split_conf[0]["op_name"] if split_conf else "" + split_operator_index = split_conf[0]["op_index"] if split_conf else "" + # 如果网络中的算子名称和重复计数与“分割配置”的值匹配,则应该在这里执行图分割。 + if ( + split_conf + and split_operator_name in operator_index_map + and operator_index_map[split_operator_name] >= split_operator_index + ): + # 执行图分割 + split_conf.pop(0) + snode_dep.append({"nodes": {}, "ref_nodes": {}}) + ann = _recursion( + anf.body, + pipeline_mods, + split_conf, + constant_expr, + ) + snode_dep.pop() + dep_vars = get_dep_var(snode_dep) + # 当前子图的节点是另一个子图的依赖节点时,需要将它们设置为当前子图的输出。 + body = relay.Tuple(dep_vars) if len(dep_vars) > 1 else anf.var + # 当当前子图的算子使用先前子图的常量作为 ``relay.expr.call`` 的参数时,如果该常量不在当前子图中,则可能会成为自由变量。为了避免这个问题,可以将先前的常量与当前子图合并。 + if constant_expr: + ann = merge_constant_expr(constant_expr, ann) + ann = run_opt_pass(ann, transform.ToGraphNormalForm()) + mod = tvm.IRModule.from_expr(ann) + pipeline_mods.insert(0, mod) + # 返回当前子图的最后一个节点。 + return tvm.relay.expr.Let(anf.var, value, body) + return tvm.relay.expr.Let( + anf.var, + value, + _recursion(anf.body, pipeline_mods, split_conf, constant_expr), + ) + else: + return anf + + snode_dep = [{"nodes": {}, 
"ref_nodes": {}}] + pipeline_mods = [] + operator_index_map = {} + # Used to tracking new input which caused by graph splitting. + new_input_idx = 0 + constant_expr = None + subgraph_split_conf = split_conf.copy() + # Binding the parameters. + if params: + expr = build_module.bind_params_by_name(expr, params) + anf = run_opt_pass(expr, transform.ToANormalForm()) + anf = run_opt_pass(anf, transform.InferType()) + ann = _recursion( + anf, + pipeline_mods, + subgraph_split_conf, + constant_expr, + ) + ann = run_opt_pass(ann.body, transform.ToGraphNormalForm()) + mod = tvm.IRModule.from_expr(ann) + pipeline_mods.insert(0, mod) + return pipeline_mods +