diff --git a/scripts/onnx/make_serialize.py b/scripts/onnx/make_serialize.py new file mode 100644 index 000000000..ccd22b8db --- /dev/null +++ b/scripts/onnx/make_serialize.py @@ -0,0 +1,25 @@ +from refactor_graph.onnx import make_compiler +from onnx import load +import argparse + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run the Refactor compiler and export the serialized model." + ) + parser.add_argument( + "--model", type=str, required=True, help="Path to the model file." + ) + parser.add_argument("--output", type=str, default="./", help="Directory to save the serialized output files.") + args = parser.parse_args() + return ( + args.model, + args.output, + ) + +def main(): + model_path, output_path = parse_args() + compiler = make_compiler(load(model_path)) + compiler.serialize(output_path) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/onnx/run_convert_to_onnx.sh b/scripts/onnx/run_convert_to_onnx.sh new file mode 100644 index 000000000..75dad715b --- /dev/null +++ b/scripts/onnx/run_convert_to_onnx.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +while getopts ":i:o:" opt; do + case $opt in + i) + model_path=$OPTARG + ;; + o) + output_path=$OPTARG + ;; + \?) + echo "Invalid option: -$OPTARG" + exit 1 + ;; + esac +done +if [ -z "$model_path" ] || [ -z "$output_path" ]; then + echo "Model path and output path are required." + exit 1 +fi + +# Make sure the output directory exists +mkdir -p "$output_path" + +# Run the first Python script: serialize the compiled model into the output directory +python3 make_serialize.py --model "$model_path" --output "$output_path" + +# Run the second Python script: rebuild an ONNX model from the serialized output +python3 to_onnx.py --input "$output_path" + +# Print a completion message +echo "Conversion finished successfully. Outputs are saved in $output_path." \ No newline at end of file diff --git a/scripts/onnx/to_onnx.py b/scripts/onnx/to_onnx.py new file mode 100644 index 000000000..b4eb26e34 --- /dev/null +++ b/scripts/onnx/to_onnx.py @@ -0,0 +1,277 @@ +import mmap +import argparse +from onnx import TensorProto, NodeProto, save_model +from onnx.helper import ( + make_model, + make_node, + make_graph, + make_tensor_value_info, + make_tensor, + make_opsetid, +) +from onnx.checker import check_model +class Topo: + def __init__(self, bytes: bytes): + list = bytes.strip().split(b"<-") + self.inputs = [int(s.strip(b"%")) for s in list[1].split()] + self.outputs = [int(s.strip(b"%")) for s in list[0].split()] + def __str__(self) -> str: + return f"{self.inputs} <- {self.outputs}" + +class Tensor: + def __init__(self, bytes_: bytes): + list = bytes_.split(b"\t") + self.name = str(list[1].strip(), "utf-8") + def map_dt(dt: bytes) -> TensorProto.DataType: + match dt: + case b"F32": + return TensorProto.FLOAT + case b"U8": + return TensorProto.UINT8 + case b"I8": + return TensorProto.INT8 + case b"U16": + return TensorProto.UINT16 + case b"I16": + return TensorProto.INT16 + case b"I32": + return TensorProto.INT32 + case b"I64": + return TensorProto.INT64 + case b"String": + return TensorProto.STRING + case b"Bool": + return TensorProto.BOOL + case b"FP16": + return TensorProto.FLOAT16 + case b"F64": + return TensorProto.DOUBLE + case b"U32": + return TensorProto.UINT32 + case b"U64": + return TensorProto.UINT64 + case b"Complex64": + return TensorProto.COMPLEX64 + case b"Complex128": + return TensorProto.COMPLEX128 + case b"BF16": + return TensorProto.BFLOAT16 + case _: + return TensorProto.UNDEFINED + self.dt = map_dt(list[2].strip()) + layout = list[3].strip() + if layout != b"NCHW" and layout != b"ELSE": + raise ValueError("Unsupported layout") + range = list[4].strip().split() + 
self.offset = int(range[0], 0) + self.size = int(range[1], 0) + self.shape = [int(s) for s in split_array(list[5])] + def __str__(self) -> str: + return f"{self.name} (dt = {self.dt}) {self.shape} {self.offset}..{self.offset + self.size}" + +class Operator: + def __init__(self, bytes: bytes): + list = bytes.split(b"\t") + self.name = str(list[1].strip(), "utf-8") + list = list[2].split(b"(", 1) + self.type = str(list[0].strip(), "utf-8") + list = list[1].rsplit(b")", 1) + self.meta = list[0].strip() + self.topo = Topo(list[1]) + def __str__(self) -> str: + return f"{self.type}: {self.name}, meta = {self.meta}, topo = {self.topo}" + def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]: + if self.type == "BatchNormalization": + return ( + make_node( + self.type, + [tensors[i].name for i in self.topo.inputs], + [tensors[i].name for i in self.topo.outputs], + self.name, + epsilon=float(self.meta.split(b"=")[0]), + ), + [], + ) + if self.type == "Conv": + meta = [int(x) for x in split_array(self.meta)] + rank = int(len(meta) / 4) + return ( + make_node( + self.type, + [tensors[i].name for i in self.topo.inputs], + [tensors[i].name for i in self.topo.outputs], + self.name, + dilations=meta[0:rank], + strides=meta[rank : 2 * rank], + pads=meta[2 * rank : 4 * rank], + ), + [], + ) + if self.type == "Relu": + return ( + make_node( + self.type, + [tensors[i].name for i in self.topo.inputs], + [tensors[i].name for i in self.topo.outputs], + self.name, + ), + [], + ) + if self.type == "MaxPool": + meta = self.meta.split(b",") + ceil_mode = ( + 1 if meta[0] == b"true" else (0 if meta[0] == b"false" else None) + ) + kernel_shape = [int(x) for x in split_array(meta[1])] + meta = [int(x) for x in split_array(meta[2])] + rank = int(len(meta) / 4) + return ( + make_node( + self.type, + [tensors[i].name for i in self.topo.inputs], + [tensors[i].name for i in self.topo.outputs], + self.name, + ceil_mode=ceil_mode, + kernel_shape=kernel_shape, + dilations=meta[0:rank], + strides=meta[rank : 2 * rank], + pads=meta[2 * rank : 4 * rank], + ), + [], + ) + if self.type == "Add": + return ( + make_node( + self.type, + [tensors[i].name for i in self.topo.inputs], + [tensors[i].name for i in self.topo.outputs], + self.name, + ), + [], + ) + if self.type == "GlobalAveragePool": + return ( + make_node( + self.type, + [tensors[i].name for i in self.topo.inputs], + [tensors[i].name for i in self.topo.outputs], + self.name, + ), + [], + ) + if self.type == "MatMul": + meta = self.meta.split(b",") + alpha = float(meta[0].split(b"=")[0].strip()) + beta = float(meta[1].split(b"=")[0].strip()) + transA = 1 if meta[2].strip() == b"AT" else 0 + transB = 1 if meta[3].strip() == b"BT" else 0 + if alpha != 1 or beta != 0 or transA == 1 or transB == 1: + return ( + make_node( + "Gemm", + [tensors[i].name for i in self.topo.inputs], + [tensors[i].name for i in self.topo.outputs], + self.name, + alpha=alpha, + beta=beta, + transA=transA, + transB=transB, + ), + [], + ) + else: + return ( + make_node( + self.type, + [tensors[i].name for i in self.topo.inputs], + [tensors[i].name for i in self.topo.outputs], + self.name, + ), + [], + ) + if self.type == "Reshape" or self.type == "Identity": + output = tensors[self.topo.outputs[0]] + shape_name = f"{output.name}_shape" + shape = output.shape + shape = make_tensor(shape_name, TensorProto.INT64, [len(shape)], shape) + return ( + make_node( + "Reshape", + [tensors[self.topo.inputs[0]].name, shape_name], + [tensors[i].name for i in self.topo.outputs], + self.name, 
+ ), + [shape], + ) + raise ValueError(f"Unsupported operator {self.type}") + +def parse_args(): + parser = argparse.ArgumentParser(description="Analyze the serialized graph files.") + parser.add_argument( + "--input", + type=str, + default="./", + help="Directory containing the serialized graph files.", + ) + args = parser.parse_args() + return ( + args.input + ) + +def split_array(arr: bytes): + return (x for x in arr.strip().strip(b"[").strip(b"]").split()) + +def main(): + path = parse_args() + info_path = path + "/graph.info" + data_path = path + "/graph.data" + outputfile = path + "/model_refactor.onnx" + with open(info_path, "r") as f: + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as m: + operators = [] + for line in iter(m.readline, b""): + if line == b"\n": + break + operators.append(Operator(line)) + graph = Topo(m.readline().strip().strip(b"graph. ")) + _ = m.readline() + tensors = [Tensor(line) for line in iter(m.readline, b"")] + + with open(data_path, "r") as f: + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as m: + nodes = [] + initializer = [ + make_tensor( + t.name, + t.dt, + t.shape, + vals=m[t.offset : t.offset + t.size], + raw=True, + ) + for t in tensors + if t.size != 0 + ] + for o in operators: + node, init = o.to_node(tensors) + nodes.append(node) + initializer.extend(init) + graph = make_graph( + nodes, + "graph", + [ + make_tensor_value_info(t.name, t.dt, t.shape) + for t in (tensors[i] for i in graph.inputs) + ], + [ + make_tensor_value_info(t.name, t.dt, t.shape) + for t in (tensors[i] for i in graph.outputs) + ], + initializer, + ) + model = make_model(graph, opset_imports=[make_opsetid( + domain="", version=13)]) + check_model(model) + save_model(model, outputfile) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/05computation/include/computation/graph.h b/src/05computation/include/computation/graph.h index 9286172ce..c171f92b5 100644 --- a/src/05computation/include/computation/graph.h +++ b/src/05computation/include/computation/graph.h @@ -26,6 +26,7 @@ namespace refactor::computation { Graph(graph_topo::GraphTopo, std::vector, std::vector) noexcept; void layoutPermute(); + void optimize(); kernel::Graph lower(Target) const; auto internal() const -> decltype(_internal) const &; @@ -34,6 +35,17 @@ namespace refactor::computation { -> std::pair>; }; + // GraphMutant is a mutable linked-graph view of Graph used by the graph-optimization passes + class GraphMutant { + graph_topo::LinkedGraph _internal; + + public: + explicit GraphMutant(Graph const &) noexcept; + auto internal() const -> decltype(_internal) const &; + auto internal() -> decltype(_internal) &; + }; + + }// namespace refactor::computation #endif// COMPUTATION_GRAPH_H diff --git a/src/05computation/include/computation/pass/conv_to_matmul.h b/src/05computation/include/computation/pass/conv_to_matmul.h new file mode 100644 index 000000000..a61dd624b --- /dev/null +++ b/src/05computation/include/computation/pass/conv_to_matmul.h @@ -0,0 +1,145 @@ +#ifndef COMPUTATION_CONV_TO_MATMUL_H +#define COMPUTATION_CONV_TO_MATMUL_H + +#include "../graph.h" +#include "computation/operators/conv.h" +#include "computation/operators/mat_mul.h" +#include "computation/operators/reshape.h" +#include "computation/operators/transpose.h" +#include "computation/pass/converter.h" + +namespace refactor::computation { + class ConvToMatmul : public Converter { + + public: + /* + * input weight + * | | + * | | + * transpose transpose + * | | + * | | + * reshape reshape + * \ / + * \ / + * matmul + * | + * reshape + * | + * transpose + * | + * 
output + */ + virtual bool execute(const std::shared_ptr &g) const override { + auto nodesList = g->internal().nodes(); + size_t count = 0; + for (auto opMatch : nodesList) { + if (opMatch->info().op == nullptr) { + continue; + } + size_t optype = opMatch->info().op->opTypeId(); + if (optype != Conv::typeId()) { + continue; + } + auto convOp = dynamic_cast(opMatch->info().op.get()); + auto input = opMatch->inputs()[0]->info().tensor; + auto weight = opMatch->inputs()[1]->info().tensor; + auto shape = weight->shape; + // judge conv is 1x1 convolution + if (shape.size() != 4 || shape[2] != 1 || shape[3] != 1) { + continue; + } + auto attr = convOp->attributes; + auto poolAttrRank = attr.rank(); + auto poolAttrDilation = attr.dilations(); + auto poolAttrStride = attr.strides(); + auto poolAttrPad = attr.pads(); + bool flag = false; + for (auto i : range0_(poolAttrRank)) { + if (poolAttrDilation[i] != 1 || poolAttrStride[i] != 1) { + flag = true; + break; + } + if (poolAttrPad[i] != 0 || poolAttrPad[i + poolAttrRank] != 0) { + flag = true; + break; + } + } + if (flag) { continue; } + // create transpose op + absl::InlinedVector + perm1 = {0, 2, 3, 1}; + Shape shape1 = {input->shape[0], input->shape[2], input->shape[3], input->shape[1]}; + auto newTransposeOp1 = g->internal().pushNode( + {std::make_unique(perm1), fmt::format("ConvToMatmul_transpose1_{}", count)}, + {g->internal().shareEdge({Tensor::share(input->dataType, shape1), fmt::format("ConvToMatmul_transpose1_{}_out", count)})}); + newTransposeOp1->connect(0, opMatch->inputs()[0]); + absl::InlinedVector perm2 = {1, 0, 2, 3}; + Shape shape2 = {weight->shape[1], weight->shape[0], weight->shape[2], weight->shape[3]}; + auto newTransposeOp2 = g->internal().pushNode( + {std::make_unique(perm2), fmt::format("ConvToMatmul_transpose2_{}", count)}, + {g->internal().shareEdge({Tensor::share(weight->dataType, shape2), fmt::format("ConvToMatmul_transpose2_{}_out", count)})}); + newTransposeOp2->connect(0, opMatch->inputs()[1]); + // create reshape op + Shape shape3 = {input->shape[0] * input->shape[2] * input->shape[3], input->shape[1]}; + Shape shape4 = {weight->shape[1], weight->shape[0]}; + int64_t data1[2] = {input->shape[0] * input->shape[2] * input->shape[3], input->shape[1]}; + int64_t data2[2] = {weight->shape[1], weight->shape[0]}; + auto [data1_, ptr1] = refactor::kernel::Blob::share(sizeof(int64_t) * 2); + auto [data2_, ptr2] = refactor::kernel::Blob::share(sizeof(int64_t) * 2); + ptr1 = &data1[0]; + ptr2 = &data2[0]; + auto newReshapeEdge1 = g->internal().shareEdge({Tensor::share(DataType::I64, {2}, LayoutType::Others, data1_), fmt::format("ConvToMatmul_reshape1_shape_{}", count)}); + auto newReshapeEdge2 = g->internal().shareEdge({Tensor::share(DataType::I64, {2}, LayoutType::Others, data2_), fmt::format("ConvToMatmul_reshape2_shape_{}", count)}); + auto newReshapeOp1 = g->internal().pushNode( + {std::make_unique(), fmt::format("ConvToMatmul_reshape1_{}", count)}, + {g->internal().shareEdge({Tensor::share(input->dataType, shape3), fmt::format("ConvToMatmul_reshape1_{}_out", count)})}); + auto newReshapeOp2 = g->internal().pushNode( + {std::make_unique(), fmt::format("ConvToMatmul_reshape2_{}", count)}, + {g->internal().shareEdge({Tensor::share(weight->dataType, shape4), fmt::format("ConvToMatmul_reshape2_{}_out", count)})}); + newReshapeOp1->connect(0, newTransposeOp1->outputs()[0]); + newReshapeOp1->connect(1, newReshapeEdge1); + newReshapeOp2->connect(0, newTransposeOp2->outputs()[0]); + newReshapeOp2->connect(1, newReshapeEdge2); + 
// create matmul op + Shape shape5 = {input->shape[0] * input->shape[2] * input->shape[3], weight->shape[0]}; + auto newMatMulOp = g->internal().pushNode( + {std::make_unique(1.0, 1.0, false, false), fmt::format("ConvToMatmul_matmul_{}", count)}, + {g->internal().shareEdge({Tensor::share(input->dataType, shape5), fmt::format("ConvToMatmul_matmul_{}_out", count)})}); + newMatMulOp->connect(0, newReshapeOp1->outputs()[0]); + newMatMulOp->connect(1, newReshapeOp2->outputs()[0]); + // create reshape op + Shape shape6 = {input->shape[0], input->shape[2], input->shape[3], weight->shape[0]}; + int64_t data3[4] = {input->shape[0], input->shape[2], input->shape[3], weight->shape[0]}; + auto [data3_, ptr3] = refactor::kernel::Blob::share(sizeof(int64_t) * 4); + ptr3 = &data3[0]; + auto newReshapeEdge3 = g->internal().shareEdge({Tensor::share(DataType::I64, {4}, LayoutType::Others, data3_), fmt::format("ConvToMatmul_reshape3_shape_{}", count)}); + auto newReshapeOp3 = g->internal().pushNode( + {std::make_unique(), fmt::format("ConvToMatmul_reshape3_{}", count)}, + {g->internal().shareEdge({Tensor::share(input->dataType, shape6), fmt::format("ConvToMatmul_reshape3_{}_out", count)})}); + newReshapeOp3->connect(0, newMatMulOp->outputs()[0]); + newReshapeOp3->connect(1, newReshapeEdge3); + // create transpose op + absl::InlinedVector perm3 = {0, 3, 1, 2}; + Shape shape7 = {input->shape[0], weight->shape[0], input->shape[2], input->shape[3]}; + auto newTransposeOp3 = g->internal().pushNode( + {std::make_unique(perm3), fmt::format("ConvToMatmul_transpose3_{}", count)}, + {g->internal().shareEdge({Tensor::share(input->dataType, shape7), fmt::format("ConvToMatmul_transpose3_{}_out", count)})}); + newTransposeOp3->connect(0, newReshapeOp3->outputs()[0]); + if (opMatch->outputs()[0]->targets().size() == 0) {// global output + g->internal().replaceOutput(opMatch->outputs()[0], newTransposeOp3->outputs()[0]); + } else { + for (auto node : opMatch->outputs()[0]->targets()) { + auto it = std::find(node->inputs().begin(), node->inputs().end(), opMatch->outputs()[0]); + node->reconnect(node->inputs()[std::distance(node->inputs().begin(), it)], newTransposeOp3->outputs()[0]); + } + } + g->internal().eraseNode(opMatch); + count++; + } + return true; + }; + }; + +}// namespace refactor::computation +#endif// COMPUTATION_CONV_TO_MATMUL_H \ No newline at end of file diff --git a/src/05computation/include/computation/pass/converter.h b/src/05computation/include/computation/pass/converter.h new file mode 100644 index 000000000..174b4d819 --- /dev/null +++ b/src/05computation/include/computation/pass/converter.h @@ -0,0 +1,40 @@ +#ifndef COMPUTATION_CONVERTER_H +#define COMPUTATION_CONVERTER_H + +#include "../graph.h" + +namespace refactor::computation { + + class Converter { + public: + Converter() = default; + virtual ~Converter() = default; + virtual bool execute(const std::shared_ptr &) const = 0; + static Converter *get(std::string_view key) { + //fmt::println("{}", storage().size()); + if (storage().find(key) != storage().end()) { + return storage().at(key).get(); + } + return nullptr; + }; + static void add(std::shared_ptr converter, std::string_view key) { + storage().insert(std::make_pair(key, converter)); + }; + static std::unordered_map> &storage() { + static std::unordered_map> passStorage; + return passStorage; + } + }; + + template + class ConverterRegister { + public: + ConverterRegister(const char *claim) { + T *instance = new T; + Converter::add(std::shared_ptr(instance), claim); + } + }; + +}// namespace 
refactor::computation + +#endif// COMPUTATION_CONVERTER_H \ No newline at end of file diff --git a/src/05computation/include/computation/pass/matmul_transpose.h b/src/05computation/include/computation/pass/matmul_transpose.h new file mode 100644 index 000000000..a84ead727 --- /dev/null +++ b/src/05computation/include/computation/pass/matmul_transpose.h @@ -0,0 +1,94 @@ +#ifndef COMPUTATION_MATMUL_TRANSPOSE_H +#define COMPUTATION_MATMUL_TRANSPOSE_H + +#include "../graph.h" +#include "computation/operators/mat_mul.h" +#include "computation/operators/transpose.h" +#include "computation/pass/converter.h" + +namespace refactor::computation { + class MatMulTransposeFuse : public Converter { + public: + virtual bool execute(const std::shared_ptr &g) const override { + auto nodesList = g->internal().nodes(); + for (auto opMatch : nodesList) { + if (opMatch->info().op == nullptr) { + continue; + } + size_t optype = opMatch->info().op->opTypeId(); + if (optype != MatMul::typeId()) { + continue; + } + auto matmulOp = dynamic_cast(opMatch->info().op.get()); + if (opMatch->predecessors().size() != 0) { + for (size_t i = 0; i < opMatch->inputs().size(); ++i) { + if (auto preOp = opMatch->inputs()[i]->source(); + preOp != nullptr && preOp->info().op->opTypeId() == Transpose::typeId()) { + auto transposeOp = dynamic_cast(preOp->info().op.get()); + auto axis = transposeOp->perm; + bool flag = false; + if (axis[axis.size() - 1] == axis.size() - 2 && axis[axis.size() - 2] == axis.size() - 1) { + flag = true; + } + for (size_t index = 0; index < axis.size() - 2; ++index) { + if (index == axis[index]) { + continue; + } + flag = false; + break; + } + if (flag) { + if (i == 0) { + matmulOp->transA = !matmulOp->transA; + } else { + matmulOp->transB = !matmulOp->transB; + } + opMatch->reconnect(opMatch->inputs()[i], preOp->inputs()[0]); + g->internal().eraseNode(preOp); + } + } + } + } + if (opMatch->successors().size() == 1) { + if (auto postOp = *(opMatch->outputs()[0]->targets().begin()); + postOp != nullptr && postOp->info().op->opTypeId() == Transpose::typeId()) { + auto transposeOp = dynamic_cast(postOp->info().op.get()); + auto axis = transposeOp->perm; + bool flag = false; + if (axis[axis.size() - 1] == axis.size() - 2 && axis[axis.size() - 2] == axis.size() - 1) { + flag = true; + } + for (size_t index = 0; index < axis.size() - 2; ++index) { + if (index == axis[index]) { + continue; + } + flag = false; + break; + } + if (flag) { + matmulOp->transA = !matmulOp->transA; + matmulOp->transB = !matmulOp->transB; + auto inputsA = opMatch->inputs()[0]; + auto inputsB = opMatch->inputs()[1]; + opMatch->connect(0, inputsB); + opMatch->connect(1, inputsA); + opMatch->outputs()[0]->info().tensor->shape = postOp->outputs()[0]->info().tensor->shape; + if (postOp->outputs()[0]->targets().size() == 0) {// global output + g->internal().replaceOutput(postOp->outputs()[0], opMatch->outputs()[0]); + } else { + for (auto node : postOp->outputs()[0]->targets()) { + auto it = std::find(node->inputs().begin(), node->inputs().end(), postOp->outputs()[0]); + node->reconnect(node->inputs()[std::distance(node->inputs().begin(), it)], opMatch->outputs()[0]); + } + } + g->internal().eraseNode(postOp); + } + } + } + } + return true; + }; + }; + +}// namespace refactor::computation +#endif// COMPUTATION_MATMUL_TRANSPOSE_H diff --git a/src/05computation/include/computation/pass_register.h b/src/05computation/include/computation/pass_register.h new file mode 100644 index 000000000..4bdd5e8ce --- /dev/null +++ 
b/src/05computation/include/computation/pass_register.h @@ -0,0 +1,18 @@ +#ifndef COMPUTATION_PASS_REGISTER_H +#define COMPUTATION_PASS_REGISTER_H +#include "pass/conv_to_matmul.h" +#include "pass/converter.h" +#include "pass/matmul_transpose.h" + +namespace refactor::computation { + + void register_() { +#define REGISTER(PASS, NAME) static ConverterRegister NAME("" #NAME); + REGISTER(MatMulTransposeFuse, MatMulTransposeFuse) + REGISTER(ConvToMatmul, ConvToMatmul) + }; + + +}// namespace refactor::computation + +#endif \ No newline at end of file diff --git a/src/05computation/src/graph.cc b/src/05computation/src/graph.cc index eb0b9566c..92ec95b73 100644 --- a/src/05computation/src/graph.cc +++ b/src/05computation/src/graph.cc @@ -1,4 +1,6 @@ #include "computation/graph.h" +#include "computation/pass/converter.h" +#include "computation/pass_register.h" #include namespace refactor::computation { @@ -133,79 +135,104 @@ namespace refactor::computation { std::pair> Graph::serialize(bool withData) const { auto const &graph = _internal.contiguous(); - graph_topo::LinkedGraph cleaner; - { - std::unordered_map identities; - for (auto [nodeIdx, inputs, outputs] : graph.topology) { - if (auto const &op = graph.nodes[nodeIdx].op; op && op->isIdentity()) { - identities.emplace(outputs[0], inputs[0]); - } - } - - auto modifier = graph_topo::InplaceModifier(graph.topology); - modifier.reconnect(identities); - - std::vector - nodes(graph.nodes.size()), - edges(graph.edges.size()); - std::iota(nodes.begin(), nodes.end(), 0); - std::iota(edges.begin(), edges.end(), 0); - - cleaner = graph_topo::LinkedGraph(graph_topo::Graph{ - modifier.take(), - std::move(nodes), - std::move(edges), - }); - cleaner.cleanup(); - - for (auto const &n : cleaner.nodes()) { - auto const &inputs = n->inputs(); - for (auto i : range0_(inputs.size()).rev()) { - if (!graph.edges[inputs[i]->info()].tensor) { - n->disconnect(i); - } else { - break; - } - } - } - } - + // graph_topo::LinkedGraph cleaner; + // { + // std::unordered_map identities; + // for (auto [nodeIdx, inputs, outputs] : graph.topology) { + // if (auto const &op = graph.nodes[nodeIdx].op; op && op->isIdentity()) { + // identities.emplace(outputs[0], inputs[0]); + // } + // } + // auto modifier = graph_topo::InplaceModifier(graph.topology); + // modifier.reconnect(identities); + // std::vector + // nodes(graph.nodes.size()), + // edges(graph.edges.size()); + // std::iota(nodes.begin(), nodes.end(), 0); + // std::iota(edges.begin(), edges.end(), 0); + // cleaner = graph_topo::LinkedGraph(graph_topo::Graph{ + // modifier.take(), + // std::move(nodes), + // std::move(edges), + // }); + // cleaner.cleanup(); + // for (auto const &n : cleaner.nodes()) { + // auto const &inputs = n->inputs(); + // for (auto i : range0_(inputs.size()).rev()) { + // if (!graph.edges[inputs[i]->info()].tensor) { + // n->disconnect(i); + // } else { + // break; + // } + // } + // } + // } EdgeRecorder edges(withData); - for (auto const &edge : cleaner.inputs()) { - edges += graph.edges[edge->info()]; + for (auto const &edge : graph.topology.globalInputs()) { + edges += graph.edges[edge]; } std::stringstream ss; - for (auto n : cleaner.nodes()) { - auto const &[op, name] = graph.nodes[n->info()]; + for (auto const &[nodeIdx, inputs, outputs] : graph.topology) { + auto const &[op, name] = graph.nodes[nodeIdx]; if (op) { - ss << fmt::format("{:>5}.\t{:<32}\t{}", n->info(), name, op->serialize()); + ss << fmt::format("{:>5}.\t{:<32}\t{}", nodeIdx, name, op->serialize()); } else { continue; } - - 
for (auto const &e : n->outputs()) { - ss << " %" << (edges += graph.edges[e->info()]); + for (auto const &e : outputs) { + ss << " %" << (edges += graph.edges[e]); } ss << " <-"; - for (auto const &e : n->inputs()) { - ss << " %" << (edges += graph.edges[e->info()]); + for (auto const &e : inputs) { + ss << " %" << (edges += graph.edges[e]); } ss << std::endl; } ss << std::endl << "graph."; - for (auto const &e : cleaner.outputs()) { - ss << " %" << edges[graph.edges[e->info()]]; + for (auto const &e : graph.topology.globalOutputs()) { + ss << " %" << edges[graph.edges[e]]; } ss << " <-"; - for (auto const &e : cleaner.inputs()) { - ss << " %" << edges[graph.edges[e->info()]]; + for (auto const &e : graph.topology.globalInputs()) { + ss << " %" << edges[graph.edges[e]]; } ss << std::endl << std::endl << edges; - return {ss.str(), edges.takeData()}; } + void RunOptimizePass(std::vector passes, const std::shared_ptr &g) { + for (auto pass : passes) { + auto convert = Converter::get(pass); + if (nullptr == convert) { + fmt::println("Can't find pass of {}.", pass); + continue; + } + bool valid = convert->execute(g); + if (!valid) { + fmt::println("Run {} Error", pass); + } + } + } + + void Graph::optimize() { + auto graphMutant = GraphMutant(*this); + std::vector passes = { + "MatMulTransposeFuse", + "ConvToMatmul", + }; + register_();//all pass insert + auto g = std::make_shared(graphMutant); + RunOptimizePass(passes, g); + _internal = g->internal(); + } + + GraphMutant::GraphMutant(Graph const &g) noexcept { + _internal = g.internal().linked(); + } + auto GraphMutant::internal() const -> decltype(_internal) const & { return _internal; } + auto GraphMutant::internal() -> decltype(_internal) & { return _internal; } + }// namespace refactor::computation diff --git a/src/05computation/test/test_pass/test_cont_to_matmul.cpp b/src/05computation/test/test_pass/test_cont_to_matmul.cpp new file mode 100644 index 000000000..c59c925d8 --- /dev/null +++ b/src/05computation/test/test_pass/test_cont_to_matmul.cpp @@ -0,0 +1,114 @@ +#include "computation/graph.h" +#include "computation/operators/conv.h" +#include "computation/operators/simple_unary.h" +#include + +namespace refactor::computation { + + refactor::graph_topo::Builder TestConvToMatMulGraphBuild1() { + auto nodes = std::unordered_map{}; + int64_t dilations[2] = {1, 1}; + int64_t strides[2] = {1, 1}; + int64_t pads[4] = {0, 0, 0, 0}; + nodes[0] = Node{std::make_unique(PoolAttributes(2, &dilations[0], &pads[0], &strides[0])), "conv"}; + nodes[1] = Node{std::make_unique(refactor::kernel::SimpleUnaryType::Relu), "relu"}; + + auto tensor0 = Tensor::share(DataType::F32, {1, 3, 5, 5}, LayoutType::Others); + auto tensor1 = Tensor::share(DataType::F32, {2, 3, 1, 1}, LayoutType::Others); + auto tensor2 = Tensor::share(DataType::F32, {1, 2, 5, 5}, LayoutType::Others); + auto tensor3 = Tensor::share(DataType::F32, {1, 2, 5, 5}, LayoutType::Others); + + return { + { + {0, {{0, 1}, {2}}}, + {1, {{2}, {3}}}, + }, + {0, 1},// global inputs + {3}, // global outputs + std::move(nodes), + { + {0, {tensor0, "input"}}, + {1, {tensor1, "weight"}}, + {2, {tensor2, "conv_output"}}, + {3, {tensor3, "output"}}, + }, + }; + } + + TEST(Graph, ConvToMatMul1) { + auto graphTopo = TestConvToMatMulGraphBuild1().build(); + fmt::println("{}", graphTopo.topology.toString()); + Graph g(std::move(graphTopo)); + g.optimize(); + auto const &g_ = g.internal().contiguous(); + fmt::println("{}", g_.topology.toString()); + fmt::println("Nodes info :"); + for (size_t i = 0; i < 
g_.nodes.size(); ++i) { + fmt::println("{}. \"{}\"", i, g_.nodes[i].name); + } + fmt::println("\n Edges info :"); + for (size_t i = 0; i < g_.edges.size(); ++i) { + fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name, + vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name()); + } + ASSERT_EQ(g_.nodes.size(), 8); + ASSERT_EQ(g_.edges.size(), 13); + } + + refactor::graph_topo::Builder TestConvToMatMulGraphBuild2() { + auto nodes = std::unordered_map{}; + nodes[0] = Node{std::make_unique(PoolAttributes(2, nullptr, nullptr, nullptr)), "conv0"}; + nodes[1] = Node{std::make_unique(refactor::kernel::SimpleUnaryType::Relu), "relu0"}; + nodes[2] = Node{std::make_unique(PoolAttributes(2, nullptr, nullptr, nullptr)), "conv1"}; + nodes[3] = Node{std::make_unique(refactor::kernel::SimpleUnaryType::Relu), "relu1"}; + + auto tensor0 = Tensor::share(DataType::F32, {1, 3, 5, 5}, LayoutType::Others); + auto tensor1 = Tensor::share(DataType::F32, {2, 3, 1, 1}, LayoutType::Others); + auto tensor2 = Tensor::share(DataType::F32, {1, 2, 5, 5}, LayoutType::Others); + auto tensor3 = Tensor::share(DataType::F32, {1, 2, 5, 5}, LayoutType::Others); + auto tensor4 = Tensor::share(DataType::F32, {4, 3, 1, 1}, LayoutType::Others); + auto tensor5 = Tensor::share(DataType::F32, {1, 4, 5, 5}, LayoutType::Others); + auto tensor6 = Tensor::share(DataType::F32, {1, 4, 5, 5}, LayoutType::Others); + + return { + { + {0, {{0, 1}, {2}}}, + {1, {{2}, {3}}}, + {2, {{3, 4}, {5}}}, + {3, {{5}, {6}}}, + }, + {0, 1, 4},// global inputs + {6}, // global outputs + std::move(nodes), + { + {0, {tensor0, "input0"}}, + {1, {tensor1, "weight0"}}, + {2, {tensor2, "conv0_output"}}, + {3, {tensor3, "relu0_output"}}, + {4, {tensor4, "weight1"}}, + {5, {tensor5, "conv1_output"}}, + {6, {tensor6, "output"}}, + }, + }; + } + + TEST(Graph, ConvToMatMul2) { + auto graphTopo = TestConvToMatMulGraphBuild2().build(); + fmt::println("{}", graphTopo.topology.toString()); + Graph g(std::move(graphTopo)); + g.optimize(); + auto const &g_ = g.internal().contiguous(); + fmt::println("{}", g_.topology.toString()); + fmt::println("Nodes info :"); + for (size_t i = 0; i < g_.nodes.size(); ++i) { + fmt::println("{}. \"{}\"", i, g_.nodes[i].name); + } + fmt::println("\n Edges info :"); + for (size_t i = 0; i < g_.edges.size(); ++i) { + fmt::println("{}. 
\"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name, + vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name()); + } + ASSERT_EQ(g_.nodes.size(), 16); + ASSERT_EQ(g_.edges.size(), 25); + } +}// namespace refactor::computation diff --git a/src/05computation/test/test_pass/test_matmul_transpose_fuse.cpp b/src/05computation/test/test_pass/test_matmul_transpose_fuse.cpp new file mode 100644 index 000000000..6ae0ea75b --- /dev/null +++ b/src/05computation/test/test_pass/test_matmul_transpose_fuse.cpp @@ -0,0 +1,163 @@ +#include "computation/graph.h" +#include "computation/operators/mat_mul.h" +#include "computation/operators/simple_unary.h" +#include "computation/operators/transpose.h" +#include + +namespace refactor::computation { + + refactor::graph_topo::Builder TestMatMulTransposeGraphBuild1() { + absl::InlinedVector perm = {0, 1, 3, 2}; + auto nodes = std::unordered_map{}; + nodes[0] = Node{std::make_unique(perm), "transpose0"}; + nodes[1] = Node{std::make_unique(perm), "transpose1"}; + nodes[2] = Node{std::make_unique(1.0, 1.0, false, false), "matmul"}; + + auto tensor0 = Tensor::share(DataType::F32, {1, 3, 3, 5}, LayoutType::Others); + auto tensor1 = Tensor::share(DataType::F32, {2, 3, 5, 3}, LayoutType::Others); + auto tensor2 = Tensor::share(DataType::F32, {1, 3, 5, 3}, LayoutType::Others); + auto tensor3 = Tensor::share(DataType::F32, {2, 3, 3, 5}, LayoutType::Others); + auto tensor4 = Tensor::share(DataType::F32, {2, 3, 5, 5}, LayoutType::Others); + + return { + { + {0, {{0}, {2}}}, + {1, {{1}, {3}}}, + {2, {{2, 3}, {4}}}, + }, + {0, 1},// global inputs + {4}, // global outputs + std::move(nodes), + { + {0, {tensor0, "input0"}}, + {1, {tensor1, "input1"}}, + {2, {tensor2, "input0_transpose"}}, + {3, {tensor3, "input1_transpose"}}, + {4, {tensor4, "output"}}, + }, + }; + } + + refactor::graph_topo::Builder TestMatMulTransposeGraphBuild2() { + absl::InlinedVector perm = {0, 1, 3, 2}; + auto nodes = std::unordered_map{}; + nodes[0] = Node{std::make_unique(1.0, 1.0, false, false), "matmul"}; + nodes[1] = Node{std::make_unique(perm), "transpose1"}; + + auto tensor0 = Tensor::share(DataType::F32, {1, 3, 3, 5}, LayoutType::Others); + auto tensor1 = Tensor::share(DataType::F32, {2, 3, 5, 4}, LayoutType::Others); + auto tensor2 = Tensor::share(DataType::F32, {2, 3, 3, 4}, LayoutType::Others); + auto tensor3 = Tensor::share(DataType::F32, {2, 3, 4, 3}, LayoutType::Others); + + return { + { + {0, {{0, 1}, {2}}}, + {1, {{2}, {3}}}, + }, + {0, 1},// global inputs + {3}, // global outputs + std::move(nodes), + { + {0, {tensor0, "input0"}}, + {1, {tensor1, "input1"}}, + {2, {tensor2, "matmul_output"}}, + {3, {tensor3, "output"}}, + }, + }; + } + + refactor::graph_topo::Builder TestMatMulTransposeGraphBuild3() { + absl::InlinedVector perm = {0, 1, 3, 2}; + auto nodes = std::unordered_map{}; + nodes[0] = Node{std::make_unique(perm), "transpose0"}; + nodes[1] = Node{std::make_unique(perm), "transpose1"}; + nodes[2] = Node{std::make_unique(1.0, 1.0, false, false), "matmul"}; + nodes[3] = Node{std::make_unique(perm), "transpose3"}; + nodes[4] = Node{std::make_unique(refactor::kernel::SimpleUnaryType::Relu), "relu"}; + + + auto tensor0 = Tensor::share(DataType::F32, {1, 3, 3, 4}, LayoutType::Others); + auto tensor1 = Tensor::share(DataType::F32, {2, 3, 5, 3}, LayoutType::Others); + auto tensor2 = Tensor::share(DataType::F32, {1, 3, 4, 3}, LayoutType::Others); + auto tensor3 = Tensor::share(DataType::F32, {2, 3, 3, 5}, LayoutType::Others); + auto tensor4 = Tensor::share(DataType::F32, {2, 
3, 4, 5}, LayoutType::Others); + auto tensor5 = Tensor::share(DataType::F32, {2, 3, 5, 4}, LayoutType::Others); + auto tensor6 = Tensor::share(DataType::F32, {2, 3, 5, 4}, LayoutType::Others); + + return { + { + {0, {{0}, {2}}}, + {1, {{1}, {3}}}, + {2, {{2, 3}, {4}}}, + {3, {{4}, {5}}}, + {4, {{5}, {6}}}, + }, + {0, 1},// global inputs + {6}, // global outputs + std::move(nodes), + { + {0, {tensor0, "input0"}}, + {1, {tensor1, "input1"}}, + {2, {tensor2, "input0_transpose"}}, + {3, {tensor3, "input1_transpose"}}, + {4, {tensor4, "matmul_output"}}, + {5, {tensor5, "transpose_output"}}, + {6, {tensor6, "output"}}, + }, + }; + } + + TEST(Graph, MatMulTranspose1) { + auto graphTopo = TestMatMulTransposeGraphBuild1().build(); + fmt::println("{}", graphTopo.topology.toString()); + Graph g(std::move(graphTopo)); + g.optimize(); + auto const &g_ = g.internal().contiguous(); + fmt::println("{}", g_.topology.toString()); + fmt::println("Nodes info :"); + for (size_t i = 0; i < g_.nodes.size(); ++i) { + fmt::println("{}. \"{}\"", i, g_.nodes[i].name); + } + fmt::println("\n Edges info :"); + for (size_t i = 0; i < g_.edges.size(); ++i) { + fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name, + vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name()); + } + } + + TEST(Graph, MatMulTranspose2) { + auto graphTopo = TestMatMulTransposeGraphBuild2().build(); + fmt::println("{}", graphTopo.topology.toString()); + Graph g(std::move(graphTopo)); + g.optimize(); + auto const &g_ = g.internal().contiguous(); + fmt::println("{}", g_.topology.toString()); + fmt::println("Nodes info :"); + for (size_t i = 0; i < g_.nodes.size(); ++i) { + fmt::println("{}. \"{}\"", i, g_.nodes[i].name); + } + fmt::println("\n Edges info :"); + for (size_t i = 0; i < g_.edges.size(); ++i) { + fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name, + vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name()); + } + } + + TEST(Graph, MatMulTranspose3) { + auto graphTopo = TestMatMulTransposeGraphBuild3().build(); + fmt::println("{}", graphTopo.topology.toString()); + Graph g(std::move(graphTopo)); + g.optimize(); + auto const &g_ = g.internal().contiguous(); + fmt::println("{}", g_.topology.toString()); + fmt::println("Nodes info :"); + for (size_t i = 0; i < g_.nodes.size(); ++i) { + fmt::println("{}. \"{}\"", i, g_.nodes[i].name); + } + fmt::println("\n Edges info :"); + for (size_t i = 0; i < g_.edges.size(); ++i) { + fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name, + vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name()); + } + } +}// namespace refactor::computation diff --git a/src/06frontend/src/graph.cc b/src/06frontend/src/graph.cc index 7fd2eb57e..0a4ce7b5e 100644 --- a/src/06frontend/src/graph.cc +++ b/src/06frontend/src/graph.cc @@ -198,7 +198,9 @@ namespace refactor::frontend { auto const endTime = high_resolution_clock::now(); logi("lowering cost time: {} μs", duration_cast(endTime - startTime).count()); - return {_internal.topology, std::move(nodes), std::move(edges)}; + computation::Graph graph(_internal.topology, std::move(nodes), std::move(edges)); + graph.optimize(); + return graph; } void Graph::logGraph() const {
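Usage note (editorial addition, not part of the patch): the sketch below shows how the two new scripts are intended to chain together, mirroring run_convert_to_onnx.sh. The model filename, the working directory, and the use of subprocess are illustrative assumptions; only the script names and their flags come from the patch.

    import subprocess
    from pathlib import Path

    model = "model.onnx"              # hypothetical input model
    workdir = Path("./serialized")    # receives graph.info / graph.data
    workdir.mkdir(parents=True, exist_ok=True)

    # Step 1: compile the ONNX model and dump the serialized graph.
    subprocess.run(
        ["python3", "scripts/onnx/make_serialize.py", "--model", model, "--output", str(workdir)],
        check=True,
    )

    # Step 2: rebuild an ONNX model (model_refactor.onnx) from graph.info / graph.data.
    subprocess.run(
        ["python3", "scripts/onnx/to_onnx.py", "--input", str(workdir)],
        check=True,
    )

    print(workdir / "model_refactor.onnx")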
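For reviewers, a standalone numpy check (an editorial sketch assuming numpy, not project code) of the identity the ConvToMatmul pass relies on: a 1x1 convolution with unit strides/dilations and zero padding equals the transpose -> reshape -> matmul -> reshape -> transpose chain the pass builds (shape1 through shape7 above).

    import numpy as np

    N, C, H, W, F = 1, 3, 5, 5, 2          # batch, in-channels, spatial dims, out-channels
    x = np.random.rand(N, C, H, W).astype(np.float32)
    w = np.random.rand(F, C, 1, 1).astype(np.float32)

    # Reference: direct 1x1 convolution (a contraction over the channel axis).
    conv = np.einsum("nchw,fc->nfhw", x, w[:, :, 0, 0])

    # Pass rewrite: NCHW -> NHWC, flatten to (N*H*W, C); weight -> (C, F); matmul; undo.
    a = x.transpose(0, 2, 3, 1).reshape(N * H * W, C)      # transpose1 + reshape1
    b = w.transpose(1, 0, 2, 3).reshape(C, F)              # transpose2 + reshape2
    y = (a @ b).reshape(N, H, W, F).transpose(0, 3, 1, 2)  # matmul + reshape3 + transpose3

    assert np.allclose(conv, y, atol=1e-5)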
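Similarly, a small numpy check (editorial sketch, assumes numpy) of the two identities behind MatMulTransposeFuse: a Transpose of the last two axes feeding a MatMul folds into the transA/transB flags, and a Transpose applied to a MatMul output equals swapping the two inputs and flipping both flags, since (A @ C)^T = C^T @ A^T.

    import numpy as np

    swap_last_two = (0, 1, 3, 2)
    A = np.random.rand(2, 3, 4, 5).astype(np.float32)
    B = np.random.rand(2, 3, 4, 6).astype(np.float32)

    # Input-side fusion: Transpose(A) @ B == MatMul with the transpose folded into transA.
    lhs = A.transpose(swap_last_two) @ B
    rhs = np.einsum("...ji,...jk->...ik", A, B)   # reads A transposed, no explicit Transpose node
    assert np.allclose(lhs, rhs, atol=1e-5)

    # Output-side fusion: Transpose(A2 @ C) == swapped inputs with both trans flags flipped.
    A2 = np.random.rand(2, 3, 4, 5).astype(np.float32)
    C = np.random.rand(2, 3, 5, 6).astype(np.float32)
    lhs = (A2 @ C).transpose(swap_last_two)
    rhs = C.transpose(swap_last_two) @ A2.transpose(swap_last_two)
    assert np.allclose(lhs, rhs, atol=1e-5)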