From 5e13361af052517ef6dede30882f6366b10a17e5 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Thu, 25 Jan 2024 14:56:36 +0800 Subject: [PATCH] feat(kernel): add test for max/min cpu kernel --- src/04kernel/src/kernels/select/cpu_kernel.cc | 24 +++---- src/04kernel/test/kernels/select/test_cpu.cpp | 72 +++++++++++++++++++ 2 files changed, 84 insertions(+), 12 deletions(-) create mode 100644 src/04kernel/test/kernels/select/test_cpu.cpp diff --git a/src/04kernel/src/kernels/select/cpu_kernel.cc b/src/04kernel/src/kernels/select/cpu_kernel.cc index c3abe61c..bf5a2805 100644 --- a/src/04kernel/src/kernels/select/cpu_kernel.cc +++ b/src/04kernel/src/kernels/select/cpu_kernel.cc @@ -51,34 +51,34 @@ namespace refactor::kernel { } if (broadcaster.needBroadcast()) { - return [n = broadcaster.outputsCount, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { + return [broadcaster, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { auto output = reinterpret_cast(outputs[0]); - for (auto i : range0_(n)) { + for (auto i : range0_(broadcaster.outputsCount)) { + std::vector ans(broadcaster.inputsCount); + broadcaster.locate(i, ans.data()); for (auto inputIdx : range0_(inputsNum)) { auto input = reinterpret_cast(inputs[inputIdx]); if (inputIdx == 0) { - output[i] = input[i]; + output[i] = input[ans[inputIdx]]; } else { - output[i] = op(output[i], input[i]); + output[i] = op(output[i], input[ans[inputIdx]]); } } - }; + } }; } else { - return [broadcaster, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { + return [n = broadcaster.outputsCount, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { auto output = reinterpret_cast(outputs[0]); - for (auto i : range0_(broadcaster.outputsCount)) { - std::vector ans(broadcaster.inputsCount); - broadcaster.locate(i, ans.data()); + for (auto i : range0_(n)) { for (auto inputIdx : range0_(inputsNum)) { auto input = reinterpret_cast(inputs[inputIdx]); if (inputIdx == 0) { - output[i] = input[ans[inputIdx]]; + output[i] = input[i]; } else { - output[i] = op(output[i], input[ans[inputIdx]]); + output[i] = op(output[i], input[i]); } } - } + }; }; } } diff --git a/src/04kernel/test/kernels/select/test_cpu.cpp b/src/04kernel/test/kernels/select/test_cpu.cpp new file mode 100644 index 00000000..2db98c17 --- /dev/null +++ b/src/04kernel/test/kernels/select/test_cpu.cpp @@ -0,0 +1,72 @@ +#include "../../../src/kernels/select/cpu_kernel.hh" +#include +#include +#include + +using namespace refactor; +using namespace kernel; + +static void testSelect(const SelectType selectType, const std::vector &shapes, const std::vector> &data, + const std::vector expectData) { + // build routine + TensorRefs dataTensors; + std::vector tensorsVec; + for (size_t i = 0; i < shapes.size(); ++i) { + tensorsVec.push_back(Tensor(DataType::F32, shapes[i], LayoutType::Others, nullptr)); + } + for (size_t i = 0; i < shapes.size(); ++i) { + dataTensors.push_back(std::cref(tensorsVec[i])); + } + auto kernel = SelectCpu::build(selectType, dataTensors); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + // put input output data + void const *inputs[data.size()]; + for (size_t i = 0; i < data.size(); ++i) { + inputs[i] = data[i].data(); + } + std::vector + out(expectData.size()); + void *outputs[]{out.data()}; + // inference + routine(res, nullptr, inputs, outputs); + // check + for (auto i : range0_(expectData.size())) { + EXPECT_FLOAT_EQ(expectData[i], out[i]); + } +} + +TEST(kernel, SelectCpu) { + // no need broadcast + testSelect(SelectType::Max, + {{1, 3}, {1, 3}, {1, 3}}, + {{3, 2, 1}, {1, 4, 4}, {2, 5, 3}}, + {3, 5, 4}); + + testSelect(SelectType::Min, + {{1, 3}, {1, 3}, {1, 3}}, + {{3, 2, 1}, {1, 4, 4}, {2, 5, 3}}, + {1, 2, 1}); + + // need broadcast + testSelect(SelectType::Max, + {{3}, {1, 3}, {1, 3}}, + {{3, 3, 3}, {1, 4, 4}, {2, 5, 3}}, + {3, 5, 4}); + + testSelect(SelectType::Min, + {{3}, {1, 3}, {1, 3}}, + {{3, 3, 3}, {1, 4, 4}, {2, 5, 3}}, + {1, 3, 3}); + + testSelect(SelectType::Max, + {{1}, {1, 3}, {1, 3}}, + {{3}, {1, 4, 4}, {2, 5, 3}}, + {3, 5, 4}); + + testSelect(SelectType::Min, + {{1}, {1, 3}, {1, 3}}, + {{3}, {1, 4, 4}, {2, 5, 3}}, + {1, 3, 3}); +}