Skip to content

Commit

Permalink
更新
Browse files Browse the repository at this point in the history
  • Loading branch information
xinetzone committed Aug 2, 2023
1 parent 35c7797 commit 09207a0
Show file tree
Hide file tree
Showing 16 changed files with 576 additions and 13 deletions.
4 changes: 2 additions & 2 deletions doc/tutorials/basic/start/te-mm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. TE (Tensor Expression) 实现矩阵乘法\n",
"# TE (Tensor Expression) 实现矩阵乘法\n",
"\n",
"## 用 TE 实现原始程序\n",
"\n",
Expand Down Expand Up @@ -282,7 +282,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
},
"orig_nbformat": 4
},
Expand Down
2 changes: 1 addition & 1 deletion doc/tutorials/basic/start/tensor-program-abstraction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
},
"orig_nbformat": 4
},
Expand Down
2 changes: 1 addition & 1 deletion doc/tutorials/datasets/imagenet.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# TensorFlow 下的 ImageNet\n",
"# ImageNet 接口\n",
"\n",
"## ImageNet 标签信息\n",
"\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
},
"orig_nbformat": 4
},
Expand Down
1 change: 1 addition & 0 deletions doc/tutorials/quantize/auto-quantize/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
```{toctree}
intro
AutomatedQuantization
parse
```
2 changes: 1 addition & 1 deletion doc/tutorials/quantize/auto-quantize/intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
},
"orig_nbformat": 4
},
Expand Down
2 changes: 1 addition & 1 deletion doc/tutorials/quantize/canonicalizations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
},
"orig_nbformat": 4
},
Expand Down
2 changes: 1 addition & 1 deletion doc/tutorials/quantize/fake-quantization-to-integer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1211,7 +1211,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
},
"orig_nbformat": 4
},
Expand Down
416 changes: 416 additions & 0 deletions doc/tutorials/quantize/parse.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion doc/tutorials/quantize/resnet18.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
},
"orig_nbformat": 4
},
Expand Down
2 changes: 1 addition & 1 deletion doc/tutorials/relay/frontend/from-tensorflow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion doc/tutorials/relay/frontend/from-tf_slim/tf2-keras.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
},
"orig_nbformat": 4
},
Expand Down
2 changes: 1 addition & 1 deletion doc/tutorials/relay/frontend/from-tf_slim/tf2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
},
"orig_nbformat": 4
},
Expand Down
2 changes: 1 addition & 1 deletion doc/tutorials/relay/frontend/pb2onnx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
},
"orig_nbformat": 4
},
Expand Down
File renamed without changes.
146 changes: 146 additions & 0 deletions src/tvm_book/tvm_utils/split_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import tvm
from tvm import relay
from tvm.relay import transform, build_module
from tvm.relay.testing import run_opt_pass
# from tvm.contrib import graph_executor
# from tvm._ffi import get_global_func
# from tvm.contrib import cc as _cc


def graph_split(expr, split_conf, params=None):
    """Split a Relay graph into an ordered list of subgraph ``IRModule``s.

    The expression is converted to A-normal form, walked Let-by-Let, and cut
    after each operator occurrence named in ``split_conf``. Values produced in
    one subgraph and consumed in a later one are re-exposed as new inputs named
    ``data_n_*``.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The graph (typically a ``relay.Function``) to split.
    split_conf : list of dict
        Split points, e.g. ``[{"op_name": "nn.relu", "op_index": 0}]`` — cut
        after the ``op_index``-th occurrence of ``op_name``. Entries are
        consumed in order; the caller's list is not mutated (a shallow copy
        is popped).
    params : dict, optional
        Parameters bound into ``expr`` by name before splitting.

    Returns
    -------
    list of tvm.IRModule
        The subgraphs, in execution order.
    """

    def get_dep_var(sub_var_dep):
        # Variables of the most recently finished subgraph that later
        # subgraphs reference; these become that subgraph's outputs.
        return list(sub_var_dep[-1]["ref_nodes"])

    def parse_dependency(value, snode_dep, new_input_idx):
        # Rewrite call arguments that were produced by an *earlier* subgraph
        # (and are therefore free in the current one) as fresh input vars.
        new_args = []
        need_update = False
        for var in value.args:
            is_free_var = False
            for dep in snode_dep[:-1]:
                if var in dep["nodes"]:
                    # Mark the previous-subgraph node as a dependency so it is
                    # emitted as an output of that subgraph.
                    dep["nodes"][var] += 1
                    dep["ref_nodes"][var] = dep["nodes"][var]
                    # This argument is a free variable of the current subgraph.
                    is_free_var = True
            # Recreate free variables with a stable input name 'data_n_*'.
            if is_free_var:
                need_update = True
                new_args.append(relay.var(f"data_n_{new_input_idx}", var.checked_type))
                new_input_idx += 1
            else:
                new_args.append(var)
        # Rebuild the Call only when at least one argument was replaced.
        if need_update:
            value = tvm.relay.expr.Call(
                value.op, new_args, value.attrs, value.type_args, value.span
            )
        return value, snode_dep, new_input_idx

    def merge_constant_expr(constant_expr, expr):
        # Merge the recorded chain of constant Let-bindings into *expr*, so a
        # subgraph that references an earlier constant keeps it in scope.
        if not isinstance(constant_expr.body, tvm.relay.expr.Let):
            return tvm.relay.expr.Let(constant_expr.var, constant_expr.value, expr)

        return tvm.relay.expr.Let(
            constant_expr.var, constant_expr.value, merge_constant_expr(constant_expr.body, expr)
        )

    def _recursion(anf, pipeline_mods, split_conf, constant_expr):
        """Walk the ANF Let-chain, cutting subgraphs at the configured ops."""
        nonlocal operator_index_map
        nonlocal new_input_idx
        nonlocal snode_dep
        cur_node_dep = snode_dep[-1]
        if isinstance(anf, tvm.relay.Function):
            return tvm.relay.Function(
                anf.params,
                _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
                anf.ret_type,
                anf.type_params,
                anf.attrs,
            )
        elif isinstance(anf, tvm.relay.expr.Let):
            value = anf.value
            # Record constant expressions so every subgraph can find the
            # correct constants.
            if isinstance(value, tvm.relay.expr.Constant):
                if not constant_expr:
                    constant_expr = tvm.relay.expr.Let(anf.var, value, anf.var)
                else:
                    constant_expr = tvm.relay.expr.Let(anf.var, value, constant_expr)
            if isinstance(value, tvm.relay.expr.Call):
                # Register the bound variable in the current subgraph.
                cur_node_dep["nodes"][anf.var] = 0
                # Resolve cross-subgraph dependencies of the call arguments.
                value, snode_dep, new_input_idx = parse_dependency(value, snode_dep, new_input_idx)
                if isinstance(value.op, tvm.ir.Op):
                    if value.op.name in operator_index_map:
                        operator_index_map[value.op.name] += 1
                    else:
                        operator_index_map[value.op.name] = 0
                    split_operator_name = split_conf[0]["op_name"] if split_conf else ""
                    split_operator_index = split_conf[0]["op_index"] if split_conf else ""
                    # When the operator name and its occurrence count match the
                    # split configuration, cut the graph here.
                    if (
                        split_conf
                        and split_operator_name in operator_index_map
                        and operator_index_map[split_operator_name] >= split_operator_index
                    ):
                        # Perform the graph split.
                        split_conf.pop(0)
                        snode_dep.append({"nodes": {}, "ref_nodes": {}})
                        ann = _recursion(
                            anf.body,
                            pipeline_mods,
                            split_conf,
                            constant_expr,
                        )
                        snode_dep.pop()
                        dep_vars = get_dep_var(snode_dep)
                        # Nodes of the current subgraph that another subgraph
                        # depends on must become this subgraph's outputs.
                        body = relay.Tuple(dep_vars) if len(dep_vars) > 1 else anf.var
                        # An operator here may use a constant from a previous
                        # subgraph as a Call argument; if that constant is not
                        # in this subgraph it would become a free variable, so
                        # merge the recorded constants into this subgraph.
                        if constant_expr:
                            ann = merge_constant_expr(constant_expr, ann)
                        ann = run_opt_pass(ann, transform.ToGraphNormalForm())
                        mod = tvm.IRModule.from_expr(ann)
                        pipeline_mods.insert(0, mod)
                        # Return the last node of the current subgraph.
                        return tvm.relay.expr.Let(anf.var, value, body)
            return tvm.relay.expr.Let(
                anf.var,
                value,
                _recursion(anf.body, pipeline_mods, split_conf, constant_expr),
            )
        else:
            return anf

    snode_dep = [{"nodes": {}, "ref_nodes": {}}]
    pipeline_mods = []
    operator_index_map = {}
    # Tracks new inputs introduced by graph splitting ('data_n_*' names).
    new_input_idx = 0
    constant_expr = None
    subgraph_split_conf = split_conf.copy()
    # Bind the parameters, if any.
    if params:
        expr = build_module.bind_params_by_name(expr, params)
    anf = run_opt_pass(expr, transform.ToANormalForm())
    anf = run_opt_pass(anf, transform.InferType())
    ann = _recursion(
        anf,
        pipeline_mods,
        subgraph_split_conf,
        constant_expr,
    )
    ann = run_opt_pass(ann.body, transform.ToGraphNormalForm())
    mod = tvm.IRModule.from_expr(ann)
    pipeline_mods.insert(0, mod)
    return pipeline_mods

0 comments on commit 09207a0

Please sign in to comment.