From 4f9d09b830596fd3884a2c32f96e44b44bbd56cf Mon Sep 17 00:00:00 2001
From: kilinchange <kilinchange@163.com>
Date: Wed, 24 Jan 2024 18:02:49 +0800
Subject: [PATCH] fead(kernel): add max cpu kernel

---
 .../include/kernel/collectors/select.h        |  24 ++++
 src/04kernel/include/kernel/selector.h        |  16 ---
 src/04kernel/src/collectors/select.cc         |  22 ++++
 src/04kernel/src/kernels/select/cpu_kernel.cc | 107 ++++++++++++++++++
 src/04kernel/src/kernels/select/cpu_kernel.hh |  29 +++++
 .../include/computation/operators/select.h    |   8 +-
 src/05computation/src/operators/select.cc     |   6 +
 7 files changed, 191 insertions(+), 21 deletions(-)
 create mode 100644 src/04kernel/include/kernel/collectors/select.h
 delete mode 100644 src/04kernel/include/kernel/selector.h
 create mode 100644 src/04kernel/src/collectors/select.cc
 create mode 100644 src/04kernel/src/kernels/select/cpu_kernel.cc
 create mode 100644 src/04kernel/src/kernels/select/cpu_kernel.hh

diff --git a/src/04kernel/include/kernel/collectors/select.h b/src/04kernel/include/kernel/collectors/select.h
new file mode 100644
index 000000000..97f7061cd
--- /dev/null
+++ b/src/04kernel/include/kernel/collectors/select.h
@@ -0,0 +1,24 @@
+#ifndef KERNEL_SELECT_H
+#define KERNEL_SELECT_H
+
+#include "../collector.h"
+
+namespace refactor::kernel {
+
+    enum class SelectType {
+        Max,
+        Min,
+    };
+
+    struct SelectCollector final : public InfoCollector {
+        SelectType selectType;
+
+        SelectCollector(decltype(_target), SelectType) noexcept;
+
+        std::vector<KernelBox>
+        filter(TensorRefs inputs, TensorRefs outputs) const final;
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_SELECT_H
diff --git a/src/04kernel/include/kernel/selector.h b/src/04kernel/include/kernel/selector.h
deleted file mode 100644
index c93f2c913..000000000
--- a/src/04kernel/include/kernel/selector.h
+++ /dev/null
@@ -1,16 +0,0 @@
-#ifndef KERNEL_SELECTOR_H
-#define KERNEL_SELECTOR_H
-
-#include "kernel.h"
-
-namespace refactor::kernel {
-
-    class Selector {
-    public:
-        virtual ~Selector() = default;
-        virtual KernelBox select(std::vector<KernelBox>) const = 0;
-    };
-
-}// namespace refactor::kernel
-
-#endif// KERNEL_SELECTOR_H
diff --git a/src/04kernel/src/collectors/select.cc b/src/04kernel/src/collectors/select.cc
new file mode 100644
index 000000000..45da6483c
--- /dev/null
+++ b/src/04kernel/src/collectors/select.cc
@@ -0,0 +1,22 @@
+#include "kernel/collectors/select.h"
+
+namespace refactor::kernel {
+
+    SelectCollector::SelectCollector(decltype(_target) target, SelectType type) noexcept
+        : InfoCollector(target), selectType(type) {}
+
+    std::vector<KernelBox>
+    SelectCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
+        std::vector<KernelBox> ans;
+        switch (_target) {
+            case decltype(_target)::Cpu:
+                break;
+            case decltype(_target)::Nvidia:
+                break;
+            default:
+                UNREACHABLEX(void, "Unknown target");
+        }
+        return ans;
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/select/cpu_kernel.cc b/src/04kernel/src/kernels/select/cpu_kernel.cc
new file mode 100644
index 000000000..c3abe61c5
--- /dev/null
+++ b/src/04kernel/src/kernels/select/cpu_kernel.cc
@@ -0,0 +1,107 @@
+#include "cpu_kernel.hh"
+#include <execution>
+
+namespace refactor::kernel {
+    using K = SelectCpu;
+    using DT = DataType;
+
+    K::SelectCpu(
+        decltype(dataType) dataType_,
+        decltype(selectType) selectType_,
+        decltype(broadcaster) broadcaster_,
+        decltype(inputsNum) inputsNum_) noexcept
+        : dataType(dataType_),
+          selectType(selectType_),
+          broadcaster(broadcaster_),
+          inputsNum(inputsNum_) {}
+
+    auto K::build(SelectType selectType_, TensorRefs inputs_) noexcept -> KernelBox {
+        auto const &x = inputs_[0].get();
+        return x.dataType.isCpuNumberic()
+                   ? std::make_unique<K>(x.dataType, selectType_, Broadcaster(inputs_), inputs_.size())
+                   : nullptr;
+    }
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing select operation on generic cpu";
+    }
+
+    template<class T>
+    auto lowerTyped(SelectType selectType, Broadcaster broadcaster, size_t inputsNum) noexcept -> RoutineWorkspace {
+        using namespace runtime;
+
+        T(*op)
+        (T const a, T const b);
+        switch (selectType) {
+            case SelectType::Max:
+                op = [](T const a, T const b) { return std::max(a, b); };
+                break;
+            case SelectType::Min:
+                op = [](T const a, T const b) { return std::min(a, b); };
+                break;
+            default:
+                UNREACHABLE();
+        }
+
+        if (broadcaster.needBroadcast()) {
+            return [n = broadcaster.outputsCount, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
+                auto output = reinterpret_cast<T *>(outputs[0]);
+                for (auto i : range0_(n)) {
+                    for (auto inputIdx : range0_(inputsNum)) {
+                        auto input = reinterpret_cast<const T *>(inputs[inputIdx]);
+                        if (inputIdx == 0) {
+                            output[i] = input[i];
+                        } else {
+                            output[i] = op(output[i], input[i]);
+                        }
+                    }
+                };
+            };
+        } else {
+            return [broadcaster, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
+                auto output = reinterpret_cast<T *>(outputs[0]);
+                for (auto i : range0_(broadcaster.outputsCount)) {
+                    std::vector<dim_t> ans(broadcaster.inputsCount);
+                    broadcaster.locate(i, ans.data());
+                    for (auto inputIdx : range0_(inputsNum)) {
+                        auto input = reinterpret_cast<const T *>(inputs[inputIdx]);
+                        if (inputIdx == 0) {
+                            output[i] = input[ans[inputIdx]];
+                        } else {
+                            output[i] = op(output[i], input[ans[inputIdx]]);
+                        }
+                    }
+                }
+            };
+        }
+    }
+
+    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
+#define CASE(DT)       \
+    case DataType::DT: \
+        return lowerTyped<primitive<DataType::DT>::type>(selectType, broadcaster, inputsNum)
+
+        switch (dataType) {
+            CASE(F32);
+            CASE(U8);
+            CASE(I8);
+            CASE(U16);
+            CASE(I16);
+            CASE(I32);
+            CASE(I64);
+            CASE(F64);
+            CASE(U32);
+            CASE(U64);
+            default:
+                UNREACHABLE();
+        }
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/select/cpu_kernel.hh b/src/04kernel/src/kernels/select/cpu_kernel.hh
new file mode 100644
index 000000000..9b8d9a331
--- /dev/null
+++ b/src/04kernel/src/kernels/select/cpu_kernel.hh
@@ -0,0 +1,29 @@
+#ifndef KERNEL_SELECT_CPU_KERNEL_HH
+#define KERNEL_SELECT_CPU_KERNEL_HH
+
+#include "kernel/attributes/broadcaster.h"
+#include "kernel/collectors/select.h"
+#include "kernel/kernel.h"
+#include "kernel/tensor.h"
+
+namespace refactor::kernel {
+
+    struct SelectCpu final : public Kernel {
+        DataType dataType;
+        SelectType selectType;
+        Broadcaster broadcaster;
+        size_t inputsNum;
+
+        SelectCpu(decltype(dataType), decltype(selectType), decltype(broadcaster), decltype(inputsNum)) noexcept;
+
+        static KernelBox build(SelectType, TensorRefs) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+        RoutineWorkspace lower(Resources &) const noexcept final;
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_Select_CPU_KERNEL_HH
diff --git a/src/05computation/include/computation/operators/select.h b/src/05computation/include/computation/operators/select.h
index 9d6b3f6ca..6d3a64588 100644
--- a/src/05computation/include/computation/operators/select.h
+++ b/src/05computation/include/computation/operators/select.h
@@ -2,13 +2,10 @@
 #define COMPUTATION_SELECT_H
 
 #include "../operator.h"
+#include "kernel/collectors/select.h"
 
 namespace refactor::computation {
-
-    enum class SelectType {
-        Max,
-        Min,
-    };
+    using kernel::SelectType;
 
     struct Select final : public Operator {
         SelectType type;
@@ -19,6 +16,7 @@ namespace refactor::computation {
         static size_t typeId(SelectType) noexcept;
         size_t opTypeId() const noexcept final;
         std::string_view name() const noexcept final;
+        kernel::CollectorBox candidateKernels(Target) const noexcept final;
     };
 
 }// namespace refactor::computation
diff --git a/src/05computation/src/operators/select.cc b/src/05computation/src/operators/select.cc
index 10452a0f6..90f175657 100644
--- a/src/05computation/src/operators/select.cc
+++ b/src/05computation/src/operators/select.cc
@@ -1,4 +1,5 @@
 #include "computation/operators/select.h"
+#include "kernel/collectors/select.h"
 
 namespace refactor::computation {
 
@@ -30,4 +31,9 @@ namespace refactor::computation {
         }
     }
 
+    auto Select::candidateKernels(Target target) const noexcept -> kernel::CollectorBox {
+        using Collector_ = kernel::SelectCollector;
+        return std::make_unique<Collector_>(target, type);
+    }
+
 }// namespace refactor::computation