From 4f9d09b830596fd3884a2c32f96e44b44bbd56cf Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 24 Jan 2024 18:02:49 +0800 Subject: [PATCH] fead(kernel): add max cpu kernel --- .../include/kernel/collectors/select.h | 24 ++++ src/04kernel/include/kernel/selector.h | 16 --- src/04kernel/src/collectors/select.cc | 22 ++++ src/04kernel/src/kernels/select/cpu_kernel.cc | 107 ++++++++++++++++++ src/04kernel/src/kernels/select/cpu_kernel.hh | 29 +++++ .../include/computation/operators/select.h | 8 +- src/05computation/src/operators/select.cc | 6 + 7 files changed, 191 insertions(+), 21 deletions(-) create mode 100644 src/04kernel/include/kernel/collectors/select.h delete mode 100644 src/04kernel/include/kernel/selector.h create mode 100644 src/04kernel/src/collectors/select.cc create mode 100644 src/04kernel/src/kernels/select/cpu_kernel.cc create mode 100644 src/04kernel/src/kernels/select/cpu_kernel.hh diff --git a/src/04kernel/include/kernel/collectors/select.h b/src/04kernel/include/kernel/collectors/select.h new file mode 100644 index 000000000..97f7061cd --- /dev/null +++ b/src/04kernel/include/kernel/collectors/select.h @@ -0,0 +1,24 @@ +#ifndef KERNEL_SELECT_H +#define KERNEL_SELECT_H + +#include "../collector.h" + +namespace refactor::kernel { + + enum class SelectType { + Max, + Min, + }; + + struct SelectCollector final : public InfoCollector { + SelectType selectType; + + SelectCollector(decltype(_target), SelectType) noexcept; + + std::vector + filter(TensorRefs inputs, TensorRefs outputs) const final; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_SELECT_H diff --git a/src/04kernel/include/kernel/selector.h b/src/04kernel/include/kernel/selector.h deleted file mode 100644 index c93f2c913..000000000 --- a/src/04kernel/include/kernel/selector.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef KERNEL_SELECTOR_H -#define KERNEL_SELECTOR_H - -#include "kernel.h" - -namespace refactor::kernel { - - class Selector { - public: - virtual ~Selector() = default; - virtual KernelBox select(std::vector) const = 0; - }; - -}// namespace refactor::kernel - -#endif// KERNEL_SELECTOR_H diff --git a/src/04kernel/src/collectors/select.cc b/src/04kernel/src/collectors/select.cc new file mode 100644 index 000000000..45da6483c --- /dev/null +++ b/src/04kernel/src/collectors/select.cc @@ -0,0 +1,22 @@ +#include "kernel/collectors/select.h" + +namespace refactor::kernel { + + SelectCollector::SelectCollector(decltype(_target) target, SelectType type) noexcept + : InfoCollector(target), selectType(type) {} + + std::vector + SelectCollector::filter(TensorRefs inputs, TensorRefs outputs) const { + std::vector ans; + switch (_target) { + case decltype(_target)::Cpu: + break; + case decltype(_target)::Nvidia: + break; + default: + UNREACHABLEX(void, "Unknown target"); + } + return ans; + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/select/cpu_kernel.cc b/src/04kernel/src/kernels/select/cpu_kernel.cc new file mode 100644 index 000000000..c3abe61c5 --- /dev/null +++ b/src/04kernel/src/kernels/select/cpu_kernel.cc @@ -0,0 +1,107 @@ +#include "cpu_kernel.hh" +#include + +namespace refactor::kernel { + using K = SelectCpu; + using DT = DataType; + + K::SelectCpu( + decltype(dataType) dataType_, + decltype(selectType) selectType_, + decltype(broadcaster) broadcaster_, + decltype(inputsNum) inputsNum_) noexcept + : dataType(dataType_), + selectType(selectType_), + broadcaster(broadcaster_), + inputsNum(inputsNum_) {} + + auto K::build(SelectType selectType_, TensorRefs inputs_) noexcept -> KernelBox { + auto const &x = inputs_[0].get(); + return x.dataType.isCpuNumberic() + ? std::make_unique(x.dataType, selectType_, Broadcaster(inputs_), inputs_.size()) + : nullptr; + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing select operation on generic cpu"; + } + + template + auto lowerTyped(SelectType selectType, Broadcaster broadcaster, size_t inputsNum) noexcept -> RoutineWorkspace { + using namespace runtime; + + T(*op) + (T const a, T const b); + switch (selectType) { + case SelectType::Max: + op = [](T const a, T const b) { return std::max(a, b); }; + break; + case SelectType::Min: + op = [](T const a, T const b) { return std::min(a, b); }; + break; + default: + UNREACHABLE(); + } + + if (broadcaster.needBroadcast()) { + return [n = broadcaster.outputsCount, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { + auto output = reinterpret_cast(outputs[0]); + for (auto i : range0_(n)) { + for (auto inputIdx : range0_(inputsNum)) { + auto input = reinterpret_cast(inputs[inputIdx]); + if (inputIdx == 0) { + output[i] = input[i]; + } else { + output[i] = op(output[i], input[i]); + } + } + }; + }; + } else { + return [broadcaster, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { + auto output = reinterpret_cast(outputs[0]); + for (auto i : range0_(broadcaster.outputsCount)) { + std::vector ans(broadcaster.inputsCount); + broadcaster.locate(i, ans.data()); + for (auto inputIdx : range0_(inputsNum)) { + auto input = reinterpret_cast(inputs[inputIdx]); + if (inputIdx == 0) { + output[i] = input[ans[inputIdx]]; + } else { + output[i] = op(output[i], input[ans[inputIdx]]); + } + } + } + }; + } + } + + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { +#define CASE(DT) \ + case DataType::DT: \ + return lowerTyped::type>(selectType, broadcaster, inputsNum) + + switch (dataType) { + CASE(F32); + CASE(U8); + CASE(I8); + CASE(U16); + CASE(I16); + CASE(I32); + CASE(I64); + CASE(F64); + CASE(U32); + CASE(U64); + default: + UNREACHABLE(); + } + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/select/cpu_kernel.hh b/src/04kernel/src/kernels/select/cpu_kernel.hh new file mode 100644 index 000000000..9b8d9a331 --- /dev/null +++ b/src/04kernel/src/kernels/select/cpu_kernel.hh @@ -0,0 +1,29 @@ +#ifndef KERNEL_SELECT_CPU_KERNEL_HH +#define KERNEL_SELECT_CPU_KERNEL_HH + +#include "kernel/attributes/broadcaster.h" +#include "kernel/collectors/select.h" +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct SelectCpu final : public Kernel { + DataType dataType; + SelectType selectType; + Broadcaster broadcaster; + size_t inputsNum; + + SelectCpu(decltype(dataType), decltype(selectType), decltype(broadcaster), decltype(inputsNum)) noexcept; + + static KernelBox build(SelectType, TensorRefs) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; + RoutineWorkspace lower(Resources &) const noexcept final; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_Select_CPU_KERNEL_HH diff --git a/src/05computation/include/computation/operators/select.h b/src/05computation/include/computation/operators/select.h index 9d6b3f6ca..6d3a64588 100644 --- a/src/05computation/include/computation/operators/select.h +++ b/src/05computation/include/computation/operators/select.h @@ -2,13 +2,10 @@ #define COMPUTATION_SELECT_H #include "../operator.h" +#include "kernel/collectors/select.h" namespace refactor::computation { - - enum class SelectType { - Max, - Min, - }; + using kernel::SelectType; struct Select final : public Operator { SelectType type; @@ -19,6 +16,7 @@ namespace refactor::computation { static size_t typeId(SelectType) noexcept; size_t opTypeId() const noexcept final; std::string_view name() const noexcept final; + kernel::CollectorBox candidateKernels(Target) const noexcept final; }; }// namespace refactor::computation diff --git a/src/05computation/src/operators/select.cc b/src/05computation/src/operators/select.cc index 10452a0f6..90f175657 100644 --- a/src/05computation/src/operators/select.cc +++ b/src/05computation/src/operators/select.cc @@ -1,4 +1,5 @@ #include "computation/operators/select.h" +#include "kernel/collectors/select.h" namespace refactor::computation { @@ -30,4 +31,9 @@ namespace refactor::computation { } } + auto Select::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { + using Collector_ = kernel::SelectCollector; + return std::make_unique(target, type); + } + }// namespace refactor::computation