From 57504b1ca3cd469c2f97d4e0bb01122f610e0c94 Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Fri, 22 Mar 2024 10:50:20 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E6=94=AF=E6=8C=81layernorm=E8=9E=8D?= =?UTF-8?q?=E5=90=88=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/onnx/make_serialize.py | 18 ++- scripts/onnx/to_onnx.py | 53 ++++++- .../include/computation/operators/layernorm.h | 24 ++++ .../include/computation/pass/layernorm_fuse.h | 134 ++++++++++++++++++ .../include/computation/pass_register.h | 2 + src/05computation/src/graph.cc | 1 + src/05computation/src/operators/layernorm.cc | 22 +++ .../test/test_pass/test_layernorm_fuse.cpp | 117 +++++++++++++++ 8 files changed, 363 insertions(+), 8 deletions(-) create mode 100644 src/05computation/include/computation/operators/layernorm.h create mode 100644 src/05computation/include/computation/pass/layernorm_fuse.h create mode 100644 src/05computation/src/operators/layernorm.cc create mode 100644 src/05computation/test/test_pass/test_layernorm_fuse.cpp diff --git a/scripts/onnx/make_serialize.py b/scripts/onnx/make_serialize.py index ccd22b8d..ee2d5fc6 100644 --- a/scripts/onnx/make_serialize.py +++ b/scripts/onnx/make_serialize.py @@ -1,6 +1,8 @@ from refactor_graph.onnx import make_compiler from onnx import load import argparse +from onnx.external_data_helper import load_external_data_for_model + def parse_args(): parser = argparse.ArgumentParser( @@ -9,17 +11,27 @@ def parse_args(): parser.add_argument( "--model", type=str, required=True, help="Path to the model file file." ) - parser.add_argument("--output", type=str, default="./", help="Path to save the output file.") + parser.add_argument( + "--output", type=str, default="./", help="Path to save the output file." + ) args = parser.parse_args() return ( args.model, args.output, ) + def main(): model_path, output_path = parse_args() - compiler = make_compiler(load(model_path)) + model = load(model_path) + # model = load(model_path, load_external_data=False) + # load_external_data_for_model( + # model, + # "/home/zhangyunze/workspace/RefactorGraph/scripts/onnx/bert_bs1.pb", + # ) + compiler = make_compiler(model) compiler.serialize(output_path) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/onnx/to_onnx.py b/scripts/onnx/to_onnx.py index d48b4c22..115b069a 100644 --- a/scripts/onnx/to_onnx.py +++ b/scripts/onnx/to_onnx.py @@ -121,7 +121,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]: ), [], ) - if self.type == "Relu": + if self.type in ["Relu", "Tanh"]: return ( make_node( self.type, @@ -166,6 +166,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]: "Log", "Neg", "Sigmoid", + "Where", ]: return ( make_node( @@ -235,14 +236,14 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]: ), [shape], ) - if self.type in ["Gather", "Concat", "Softmax"]: + if self.type in ["Gather", "Concat", "Softmax", "Split"]: meta = self.meta.split(b"/") axis = int(meta[0]) return ( make_node( self.type, [tensors[i].name for i in self.topo.inputs], - [tensors[self.topo.outputs[0]].name], + [tensors[i].name for i in self.topo.outputs], self.name, domain=DEFAULT_DOMAIN, axis=axis, @@ -251,7 +252,7 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]: ) if self.type == "ReduceMean": meta = self.meta.split(b",") - keepDims = meta[2] == b"true" + keepDims = meta[2] == b" true" axes = [int(x) for x in split_array(meta[0])] return ( make_node( @@ -311,7 +312,35 @@ def to_node(self, tensors: list[Tensor]) -> tuple[NodeProto, list[TensorProto]]: [tensors[i].name for i in self.topo.outputs], self.name, domain="refactor", - epsilon=1e-5, + epsilon=float(self.meta.split(b"=")[0]), + ), + [], + ) + if self.type == "LayerNormalization": + meta = self.meta.split(b",") + epsilon = float(meta[0].split(b"=")[0].strip()) + axis = int(meta[1]) + return ( + make_node( + self.type, + [tensors[i].name for i in self.topo.inputs], + [tensors[i].name for i in self.topo.outputs], + self.name, + domain="refactor", + epsilon=epsilon, + axis=axis, + ), + [], + ) + if self.type == "RotaryPositionEmbedding": + return ( + make_node( + self.type, + [tensors[i].name for i in self.topo.inputs], + [tensors[i].name for i in self.topo.outputs], + self.name, + domain="refactor", + theta=float(self.meta.split(b"=")[0]), ), [], ) @@ -364,7 +393,14 @@ def main(): with open(data_path, "r") as f: with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as m: nodes = [] + # for t in tensors: + # if t.size != 0: + # print(f"tensor_name is {t.name}") initializer = [ + # ( + # , + # print(f"tensor_name is {t.name}"), + # ) make_tensor( t.name, t.dt, @@ -391,6 +427,13 @@ def main(): for t in (tensors[i] for i in graph.outputs) ], initializer, + value_info=[ + make_tensor_value_info(t.name, t.dt, t.shape) + for t in tensors + if t.size == 0 + and t.name not in graph.inputs + and t.name not in graph.outputs + ], ) # model = make_model( # graph, opset_imports=[make_opsetid(domain="", version=13)] diff --git a/src/05computation/include/computation/operators/layernorm.h b/src/05computation/include/computation/operators/layernorm.h new file mode 100644 index 00000000..4bc2f6eb --- /dev/null +++ b/src/05computation/include/computation/operators/layernorm.h @@ -0,0 +1,24 @@ +#ifndef COMPUTATION_LAYER_NORMALIZATION_H +#define COMPUTATION_LAYER_NORMALIZATION_H + +#include "../operator.h" + +namespace refactor::computation { + + struct LayerNormalization final : public Operator { + float epsilon; + int axis; + + constexpr explicit LayerNormalization(float epsilon_, int axis_) noexcept + : Operator(), epsilon(epsilon_), axis(axis_) {} + + static size_t typeId() noexcept; + size_t opTypeId() const noexcept final; + std::string_view name() const noexcept final; + // kernel::CollectorBox candidateKernels(Target) const final; + std::string serialize() const noexcept final; + }; + +}// namespace refactor::computation + +#endif// COMPUTATION_LAYER_NORMALIZATION_H diff --git a/src/05computation/include/computation/pass/layernorm_fuse.h b/src/05computation/include/computation/pass/layernorm_fuse.h new file mode 100644 index 00000000..2a907205 --- /dev/null +++ b/src/05computation/include/computation/pass/layernorm_fuse.h @@ -0,0 +1,134 @@ +#ifndef COMPUTATION_LAYERNORM_FUSE_H +#define COMPUTATION_LAYERNORM_FUSE_H + +#include "../graph.h" +#include "computation/operators/layernorm.h" +#include "computation/operators/reduce.h" +#include "computation/operators/simple_binary.h" +#include "computation/operators/simple_unary.h" +#include "computation/pass/converter.h" + +namespace refactor::computation { + + class LayernormFuse : public Converter { + public: + virtual bool execute(const std::shared_ptr &g) const override { + auto nodesList = g->internal().nodes(); + size_t count = 0; + for (auto opMatch : nodesList) { + if (opMatch->info().op == nullptr) { + continue; + } + size_t optype = opMatch->info().op->opTypeId(); + if (optype != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add)) { + continue; + } + if (opMatch->successors().size() < 2) { + continue; + } + auto input = opMatch->inputs()[0]->info().tensor; + auto targets = opMatch->outputs()[0]->targets(); + auto ReduceMeanOp = *targets.begin(); + auto SubOp1 = *(std::next(targets.begin())); + if (ReduceMeanOp == nullptr || SubOp1 == nullptr || + ReduceMeanOp->info().op->opTypeId() != Reduce::typeId(refactor::kernel::ReduceType::Mean) || + SubOp1->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Sub)) { + continue; + } + auto reduceOp = dynamic_cast(ReduceMeanOp->info().op.get()); + auto axes = reduceOp->axes; + if (axes.size() != 1) { + continue; + } + auto keepDims = reduceOp->keepDims; + if (ReduceMeanOp->successors().size() != 1 || *(ReduceMeanOp->outputs()[0]->targets().begin()) != SubOp1) { + continue; + } + if (SubOp1->successors().size() != 2) { + continue; + } + auto targets1 = SubOp1->outputs()[0]->targets(); + auto PowOp = *targets1.begin(); + auto DivOp = *(std::next(targets1.begin())); + if (PowOp == nullptr || DivOp == nullptr || + PowOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Pow) || + DivOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Div)) { + continue; + } + if (PowOp->successors().size() != 1 || DivOp->successors().size() != 1) { + continue; + } + auto ReduceMeanOp1 = *(PowOp->outputs()[0]->targets().begin()); + auto MulOp = *(DivOp->outputs()[0]->targets().begin()); + if (ReduceMeanOp1 == nullptr || MulOp == nullptr || + ReduceMeanOp1->info().op->opTypeId() != Reduce::typeId(refactor::kernel::ReduceType::Mean) || + MulOp->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Mul)) { + continue; + } + auto reduce1Op = dynamic_cast(ReduceMeanOp1->info().op.get()); + auto axes1 = reduce1Op->axes; + if (axes != axes1) { + continue; + } + if (auto keepDims1 = reduce1Op->keepDims; keepDims != keepDims1) { + continue; + } + if (MulOp->successors().size() != 1 || ReduceMeanOp1->successors().size() != 1) { + continue; + } + auto AddOrSqrtOp = *(ReduceMeanOp1->outputs()[0]->targets().begin()); + auto AddOp2 = *(MulOp->outputs()[0]->targets().begin()); + if (AddOrSqrtOp == nullptr || AddOp2 == nullptr || + AddOp2->info().op->opTypeId() != SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add)) { + continue; + } + if (AddOrSqrtOp->successors().size() != 1) { + continue; + } + float epsilon = 0.0; + if (auto AddOp = AddOrSqrtOp; AddOp->info().op->opTypeId() == SimpleBinary::typeId(refactor::kernel::SimpleBinaryType::Add)) { + auto SqrtOp = *(AddOp->outputs()[0]->targets().begin()); + if (SqrtOp == nullptr || SqrtOp->info().op->opTypeId() != SimpleUnary::typeId(refactor::kernel::SimpleUnaryType::Sqrt)) { + continue; + } + if (SqrtOp->successors().size() != 1 || *(SqrtOp->outputs()[0]->targets().begin()) != DivOp) { + continue; + } + // start replace with LayernormOp + if (auto t = AddOp->inputs()[1]->info().tensor->data; t) { + epsilon = *t->get(); + } + } else if (auto SqrtOp = AddOrSqrtOp; SqrtOp->info().op->opTypeId() == SimpleUnary::typeId(refactor::kernel::SimpleUnaryType::Sqrt)) { + if (*(SqrtOp->outputs()[0]->targets().begin()) != DivOp) { + continue; + } + } else { + continue; + } + + int axis = axes[0]; + auto layernormOp = g->internal().pushNode( + {std::make_unique(epsilon, axis), fmt::format("Layernorm", count)}, + {g->internal().shareEdge({Tensor::share(input->dataType, input->shape), fmt::format("Layernorm_{}_out", count)})}); + layernormOp->connect(0, opMatch->outputs()[0]); + layernormOp->connect(1, MulOp->inputs()[1]); + layernormOp->connect(2, AddOp2->inputs()[1]); + if (AddOp2->outputs()[0]->targets().size() == 0) {//global output + g->internal().replaceOutput(AddOp2->outputs()[0], layernormOp->outputs()[0]); + } else { + for (auto node : AddOp2->outputs()[0]->targets()) { + auto it = std::find(node->inputs().begin(), node->inputs().end(), AddOp2->outputs()[0]); + node->reconnect(node->inputs()[std::distance(node->inputs().begin(), it)], layernormOp->outputs()[0]); + } + } + count++; + g->internal().cleanup(); + } + return true; + }; + }; + + +}// namespace refactor::computation + +#endif// COMPUTATION_LAYERNORM_FUSE_H diff --git a/src/05computation/include/computation/pass_register.h b/src/05computation/include/computation/pass_register.h index 4bdd5e8c..6f883023 100644 --- a/src/05computation/include/computation/pass_register.h +++ b/src/05computation/include/computation/pass_register.h @@ -2,6 +2,7 @@ #define COMPUTATION_PASS_REGISTER_H #include "pass/conv_to_matmul.h" #include "pass/converter.h" +#include "pass/layernorm_fuse.h" #include "pass/matmul_transpose.h" namespace refactor::computation { @@ -10,6 +11,7 @@ namespace refactor::computation { #define REGISTER(PASS, NAME) static ConverterRegister NAME("" #NAME); REGISTER(MatMulTransposeFuse, MatMulTransposeFuse) REGISTER(ConvToMatmul, ConvToMatmul) + REGISTER(LayernormFuse, LayernormFuse) }; diff --git a/src/05computation/src/graph.cc b/src/05computation/src/graph.cc index 28295719..caf4761b 100644 --- a/src/05computation/src/graph.cc +++ b/src/05computation/src/graph.cc @@ -220,6 +220,7 @@ namespace refactor::computation { void Graph::optimize() { auto graphMutant = GraphMutant(*this); std::vector passes = { + "LayernormFuse", // "MatMulTransposeFuse", // "ConvToMatmul", }; diff --git a/src/05computation/src/operators/layernorm.cc b/src/05computation/src/operators/layernorm.cc new file mode 100644 index 00000000..b34d0e91 --- /dev/null +++ b/src/05computation/src/operators/layernorm.cc @@ -0,0 +1,22 @@ +#include "computation/operators/layernorm.h" + +namespace refactor::computation { + using Op = LayerNormalization; + + auto Op::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + auto Op::opTypeId() const noexcept -> size_t { return typeId(); } + auto Op::name() const noexcept -> std::string_view { return "LayerNormalization"; } + auto Op::serialize() const noexcept -> std::string { + union code { + float f; + int32_t i; + }; + return fmt::format(("{}({:e}={:#010x},{})"), + name(), epsilon, + code{epsilon}.i, axis); + } + +}// namespace refactor::computation diff --git a/src/05computation/test/test_pass/test_layernorm_fuse.cpp b/src/05computation/test/test_pass/test_layernorm_fuse.cpp new file mode 100644 index 00000000..13225432 --- /dev/null +++ b/src/05computation/test/test_pass/test_layernorm_fuse.cpp @@ -0,0 +1,117 @@ +#include "computation/graph.h" +#include "computation/operators/reduce.h" +#include "computation/operators/simple_binary.h" +#include "computation/operators/simple_unary.h" +#include +#include + +namespace refactor::computation { + refactor::graph_topo::Builder TestLayernormFuseGraphBuild() { + auto nodes = std::unordered_map{}; + absl::InlinedVector axes = {2}; + uint32_t rank = 3; + bool keepDims = true; + nodes[0] = Node{std::make_unique(refactor::kernel::SimpleBinaryType::Add), "add1"}; + nodes[1] = Node{std::make_unique(refactor::kernel::ReduceType::Mean, axes, rank, keepDims), "reducemean1"}; + nodes[2] = Node{std::make_unique(refactor::kernel::SimpleBinaryType::Sub), "sub"}; + nodes[3] = Node{std::make_unique(refactor::kernel::SimpleBinaryType::Pow), "pow"}; + nodes[4] = Node{std::make_unique(refactor::kernel::ReduceType::Mean, axes, rank, keepDims), "reducemean1"}; + nodes[5] = Node{std::make_unique(refactor::kernel::SimpleBinaryType::Add), "add2"}; + nodes[6] = Node{std::make_unique(refactor::kernel::SimpleUnaryType::Sqrt), "sqrt"}; + nodes[7] = Node{std::make_unique(refactor::kernel::SimpleBinaryType::Div), "div"}; + nodes[8] = Node{std::make_unique(refactor::kernel::SimpleBinaryType::Mul), "mul"}; + nodes[9] = Node{std::make_unique(refactor::kernel::SimpleBinaryType::Add), "add3"}; + nodes[10] = Node{std::make_unique(refactor::kernel::SimpleBinaryType::Add), "add4"}; + + + auto tensor0 = Tensor::share(DataType::F32, {64, 101, 768}, LayoutType::Others); + auto tensor1 = Tensor::share(DataType::F32, {1, 101, 768}, LayoutType::Others); + auto tensor2 = Tensor::share(DataType::F32, {64, 101, 768}, LayoutType::Others); + auto tensor3 = Tensor::share(DataType::F32, {64, 101, 1}, LayoutType::Others); + auto tensor4 = Tensor::share(DataType::F32, {64, 101, 768}, LayoutType::Others); + auto tensor5 = Tensor::share(DataType::F32, {64, 101, 768}, LayoutType::Others); + auto tensor6 = Tensor::share(DataType::F32, {64, 101, 1}, LayoutType::Others); + auto tensor7 = Tensor::share(DataType::F32, {}, LayoutType::Others); + auto tensor8 = Tensor::share(DataType::F32, {64, 101, 1}, LayoutType::Others); + auto tensor9 = Tensor::share(DataType::F32, {64, 101, 1}, LayoutType::Others); + auto tensor10 = Tensor::share(DataType::F32, {64, 101, 768}, LayoutType::Others); + auto tensor11 = Tensor::share(DataType::F32, {768}, LayoutType::Others); + auto tensor12 = Tensor::share(DataType::F32, {64, 101, 768}, LayoutType::Others); + auto tensor13 = Tensor::share(DataType::F32, {768}, LayoutType::Others); + auto tensor14 = Tensor::share(DataType::F32, {64, 101, 768}, LayoutType::Others); + auto tensor15 = Tensor::share(DataType::F32, {}, LayoutType::Others); + auto tensor16 = Tensor::share(DataType::F32, {64, 101, 768}, LayoutType::Others); + + auto scale = reinterpret_cast(tensor11->malloc()); + std::iota(scale, scale + tensor11->elementsSize(), 1.0); + auto bias = reinterpret_cast(tensor13->malloc()); + std::iota(bias, bias + tensor13->elementsSize(), 0.0); + float epsilon_ = 0.000009999999747378752; + std::memcpy(tensor7->malloc(), &epsilon_, tensor7->bytesSize()); + float pow = 2.0; + std::memcpy(tensor15->malloc(), &pow, tensor15->bytesSize()); + + return { + { + {0, {{0, 1}, {2}}}, //add + {1, {{2}, {3}}}, //reducemean + {2, {{3, 2}, {4}}}, //sub + {3, {{4, 15}, {5}}}, //pow + {4, {{5}, {6}}}, //reducemean + {5, {{6, 7}, {8}}}, //add + {6, {{8}, {9}}}, //sqrt + {7, {{9, 4}, {10}}}, //div + {8, {{10, 11}, {12}}},//mul + {9, {{12, 13}, {14}}},//add + {10, {{14, 2}, {16}}}, + }, + { + 0, + 1, + 7, + 11, + 13, + 15, + }, // global inputs + {16},// global outputs + std::move(nodes), + { + {0, {tensor0, "input0"}}, + {1, {tensor1, "input1"}}, + {2, {tensor2, "add1_output"}}, + {3, {tensor3, "reducemean1_output"}}, + {4, {tensor4, "sub_output"}}, + {5, {tensor5, "pow_output"}}, + {6, {tensor6, "reducemean2_output"}}, + {7, {tensor7, "add2_input"}}, + {8, {tensor8, "add2_output"}}, + {9, {tensor9, "sqrt_output"}}, + {10, {tensor10, "div_output"}}, + {11, {tensor11, "mul_input"}}, + {12, {tensor12, "mul_output"}}, + {13, {tensor13, "add3_input"}}, + {14, {tensor14, "output"}}, + {15, {tensor15, "pow_input"}}, + {16, {tensor15, "add4_output"}}, + }, + }; + } + + TEST(Graph, LayerNormFuse) { + auto graphTopo = TestLayernormFuseGraphBuild().build(); + fmt::println("{}", graphTopo.topology.toString()); + Graph g(std::move(graphTopo)); + g.optimize(); + auto const &g_ = g.internal().contiguous(); + fmt::println("{}", g_.topology.toString()); + fmt::println("Nodes info :"); + for (size_t i = 0; i < g_.nodes.size(); ++i) { + fmt::println("{}. \"{}\"", i, g_.nodes[i].name); + } + fmt::println("\n Edges info :"); + for (size_t i = 0; i < g_.edges.size(); ++i) { + fmt::println("{}. \"{}\" Shape is {}, Layout is {}", i, g_.edges[i].name, + vec2str(g_.edges[i].tensor->shape), g_.edges[i].tensor->layout.name()); + } + } +}// namespace refactor::computation \ No newline at end of file