Skip to content

Commit

Permalink
feat(kernel): add test for max/min cpu kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
kilinchange committed Jan 25, 2024
1 parent 4f9d09b commit 5e13361
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 12 deletions.
24 changes: 12 additions & 12 deletions src/04kernel/src/kernels/select/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,34 +51,34 @@ namespace refactor::kernel {
}

if (broadcaster.needBroadcast()) {
return [n = broadcaster.outputsCount, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
return [broadcaster, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
auto output = reinterpret_cast<T *>(outputs[0]);
for (auto i : range0_(n)) {
for (auto i : range0_(broadcaster.outputsCount)) {
std::vector<dim_t> ans(broadcaster.inputsCount);
broadcaster.locate(i, ans.data());
for (auto inputIdx : range0_(inputsNum)) {
auto input = reinterpret_cast<const T *>(inputs[inputIdx]);
if (inputIdx == 0) {
output[i] = input[i];
output[i] = input[ans[inputIdx]];
} else {
output[i] = op(output[i], input[i]);
output[i] = op(output[i], input[ans[inputIdx]]);
}
}
};
}
};
} else {
return [broadcaster, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
return [n = broadcaster.outputsCount, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
auto output = reinterpret_cast<T *>(outputs[0]);
for (auto i : range0_(broadcaster.outputsCount)) {
std::vector<dim_t> ans(broadcaster.inputsCount);
broadcaster.locate(i, ans.data());
for (auto i : range0_(n)) {
for (auto inputIdx : range0_(inputsNum)) {
auto input = reinterpret_cast<const T *>(inputs[inputIdx]);
if (inputIdx == 0) {
output[i] = input[ans[inputIdx]];
output[i] = input[i];
} else {
output[i] = op(output[i], input[ans[inputIdx]]);
output[i] = op(output[i], input[i]);
}
}
}
};
};
}
}
Expand Down
72 changes: 72 additions & 0 deletions src/04kernel/test/kernels/select/test_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#include "../../../src/kernels/select/cpu_kernel.hh"
#include <functional>
#include <gtest/gtest.h>
#include <numeric>

using namespace refactor;
using namespace kernel;

static void testSelect(const SelectType selectType, const std::vector<Shape> &shapes, const std::vector<std::vector<float>> &data,
const std::vector<float> expectData) {
// build routine
TensorRefs dataTensors;
std::vector<Tensor> tensorsVec;
for (size_t i = 0; i < shapes.size(); ++i) {
tensorsVec.push_back(Tensor(DataType::F32, shapes[i], LayoutType::Others, nullptr));
}
for (size_t i = 0; i < shapes.size(); ++i) {
dataTensors.push_back(std::cref(tensorsVec[i]));
}
auto kernel = SelectCpu::build(selectType, dataTensors);
ASSERT_TRUE(kernel);
auto res = runtime::Resources();
auto routine = kernel->lower(res).routine;
// put input output data
void const *inputs[data.size()];
for (size_t i = 0; i < data.size(); ++i) {
inputs[i] = data[i].data();
}
std::vector<float>
out(expectData.size());
void *outputs[]{out.data()};
// inference
routine(res, nullptr, inputs, outputs);
// check
for (auto i : range0_(expectData.size())) {
EXPECT_FLOAT_EQ(expectData[i], out[i]);
}
}

TEST(kernel, SelectCpu) {
// no need broadcast
testSelect(SelectType::Max,
{{1, 3}, {1, 3}, {1, 3}},
{{3, 2, 1}, {1, 4, 4}, {2, 5, 3}},
{3, 5, 4});

testSelect(SelectType::Min,
{{1, 3}, {1, 3}, {1, 3}},
{{3, 2, 1}, {1, 4, 4}, {2, 5, 3}},
{1, 2, 1});

// need broadcast
testSelect(SelectType::Max,
{{3}, {1, 3}, {1, 3}},
{{3, 3, 3}, {1, 4, 4}, {2, 5, 3}},
{3, 5, 4});

testSelect(SelectType::Min,
{{3}, {1, 3}, {1, 3}},
{{3, 3, 3}, {1, 4, 4}, {2, 5, 3}},
{1, 3, 3});

testSelect(SelectType::Max,
{{1}, {1, 3}, {1, 3}},
{{3}, {1, 4, 4}, {2, 5, 3}},
{3, 5, 4});

testSelect(SelectType::Min,
{{1}, {1, 3}, {1, 3}},
{{3}, {1, 4, 4}, {2, 5, 3}},
{1, 3, 3});
}

0 comments on commit 5e13361

Please sign in to comment.