Skip to content

Commit

Permalink
feat(kernel): add max cpu kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
kilinchange committed Jan 25, 2024
1 parent 594e06b commit 4f9d09b
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 21 deletions.
24 changes: 24 additions & 0 deletions src/04kernel/include/kernel/collectors/select.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef KERNEL_SELECT_H
#define KERNEL_SELECT_H

#include "../collector.h"

namespace refactor::kernel {

/// Element-wise selection variant applied across the inputs of a select kernel.
enum class SelectType {
/// Keep the larger of the compared values.
Max,
/// Keep the smaller of the compared values.
Min,
};

/// Kernel collector for the select (max / min) operator.
struct SelectCollector final : public InfoCollector {
/// Selection variant this collector was created for.
SelectType selectType;

/// Stores the target (forwarded to `InfoCollector`) and the selection variant.
SelectCollector(decltype(_target), SelectType) noexcept;

/// Returns the list of candidate kernels for the given inputs / outputs,
/// chosen according to the stored target.
std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
};

}// namespace refactor::kernel

#endif// KERNEL_SELECT_H
16 changes: 0 additions & 16 deletions src/04kernel/include/kernel/selector.h

This file was deleted.

22 changes: 22 additions & 0 deletions src/04kernel/src/collectors/select.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "kernel/collectors/select.h"

namespace refactor::kernel {

/// Creates a collector for `target_` that remembers the requested
/// selection variant (`selectType_`).
SelectCollector::SelectCollector(decltype(_target) target_, SelectType selectType_) noexcept
    : InfoCollector(target_),
      selectType(selectType_) {}

/// Collects the candidate kernels for the stored target.
///
/// NOTE(review): neither known target registers a kernel here, so the
/// returned list is always empty — presumably a placeholder until the
/// concrete kernels are wired in.
std::vector<KernelBox>
SelectCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
    std::vector<KernelBox> candidates;
    switch (_target) {
        case decltype(_target)::Cpu:
        case decltype(_target)::Nvidia:
            // Known targets: nothing is offered for them at the moment.
            break;
        default:
            UNREACHABLEX(void, "Unknown target");
    }
    return candidates;
}

}// namespace refactor::kernel
107 changes: 107 additions & 0 deletions src/04kernel/src/kernels/select/cpu_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#include "cpu_kernel.hh"
#include <execution>
#include <utility>

namespace refactor::kernel {
using K = SelectCpu;
using DT = DataType;

/// Stores the kernel configuration.
///
/// `broadcaster_` is taken by value and moved into the member so that no
/// second copy is made inside this `noexcept` constructor.
K::SelectCpu(
    decltype(dataType) dataType_,
    decltype(selectType) selectType_,
    decltype(broadcaster) broadcaster_,
    decltype(inputsNum) inputsNum_) noexcept
    : dataType(dataType_),
      selectType(selectType_),
      broadcaster(std::move(broadcaster_)),
      inputsNum(inputsNum_) {}

/// Tries to create a CPU select kernel for the given inputs.
///
/// Returns `nullptr` when there are no inputs or when the data type of the
/// first input is not reported as CPU-numeric by `isCpuNumberic()`.
/// Only the first input's data type is inspected.
auto K::build(SelectType selectType_, TensorRefs inputs_) noexcept -> KernelBox {
    // Guard the `inputs_[0]` access below: this function is noexcept and an
    // empty list would otherwise be undefined behaviour.
    if (inputs_.empty()) {
        return nullptr;
    }
    auto const &first = inputs_[0].get();
    if (!first.dataType.isCpuNumberic()) {
        return nullptr;
    }
    return std::make_unique<K>(first.dataType, selectType_, Broadcaster(inputs_), inputs_.size());
}
/// Returns the id of this kernel type: the address of a function-local
/// static object, converted to `size_t`.
auto K::typeId() noexcept -> size_t {
    static uint8_t marker = 1;
    auto *const address = &marker;
    return reinterpret_cast<size_t>(address);
}

/// Instance-level accessor for the same value as the static `typeId()`.
auto K::kernelTypeId() const noexcept -> size_t {
    return K::typeId();
}
/// Human-readable description of this kernel.
auto K::description() const noexcept -> std::string_view {
    using namespace std::string_view_literals;
    return "Performing select operation on generic cpu"sv;
}

/// Builds the routine that folds all inputs element-wise into `outputs[0]`
/// with max or min, for element type `T`.
///
/// @param selectType  which binary operation (std::max / std::min) to fold with.
/// @param broadcaster supplies the output element count and, when broadcasting
///                    is needed, the per-input offset of every output element.
/// @param inputsNum   number of entries read from the `inputs` array.
///
/// All inputs and the output are reinterpreted as arrays of `T`.
template<class T>
auto lowerTyped(SelectType selectType, Broadcaster broadcaster, size_t inputsNum) noexcept -> RoutineWorkspace {
    using namespace runtime;

    // Captureless lambdas convert to a plain function pointer, so both
    // variants can be stored in the same variable and captured by value.
    T (*op)(T const a, T const b);
    switch (selectType) {
        case SelectType::Max:
            op = [](T const a, T const b) { return std::max(a, b); };
            break;
        case SelectType::Min:
            op = [](T const a, T const b) { return std::min(a, b); };
            break;
        default:
            UNREACHABLE();
    }

    if (broadcaster.needBroadcast()) {
        // Broadcasting path: every output index has to be translated into one
        // offset per input via `locate` before the inputs can be read.
        // NOTE(review): relies on `needBroadcast()` being true exactly when the
        // inputs cannot be indexed directly with the output index.
        return [broadcaster, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
            auto output = reinterpret_cast<T *>(outputs[0]);
            // Allocated once per invocation; re-zeroed before each `locate`
            // call so it starts from the same state as a fresh vector.
            std::vector<dim_t> offsets(broadcaster.inputsCount);
            for (auto i : range0_(broadcaster.outputsCount)) {
                std::fill(offsets.begin(), offsets.end(), dim_t{});
                broadcaster.locate(i, offsets.data());
                for (auto inputIdx : range0_(inputsNum)) {
                    auto input = reinterpret_cast<const T *>(inputs[inputIdx]);
                    if (inputIdx == 0) {
                        // The first input seeds the accumulator.
                        output[i] = input[offsets[inputIdx]];
                    } else {
                        output[i] = op(output[i], input[offsets[inputIdx]]);
                    }
                }
            }
        };
    } else {
        // No broadcasting: every input is indexed with the output index directly.
        return [n = broadcaster.outputsCount, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
            auto output = reinterpret_cast<T *>(outputs[0]);
            for (auto i : range0_(n)) {
                for (auto inputIdx : range0_(inputsNum)) {
                    auto input = reinterpret_cast<const T *>(inputs[inputIdx]);
                    if (inputIdx == 0) {
                        // The first input seeds the accumulator.
                        output[i] = input[i];
                    } else {
                        output[i] = op(output[i], input[i]);
                    }
                }
            }
        };
    }
}

/// Dispatches on the stored data type and returns the routine produced by
/// `lowerTyped` for the matching primitive type. Data types not listed
/// below hit `UNREACHABLE()`.
auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
// Local helper macro: one `case` per supported data type. The parameter is
// deliberately not called `DT`, which is also a file-level type alias.
#define CASE(TYPE)       \
    case DataType::TYPE: \
        return lowerTyped<primitive<DataType::TYPE>::type>(selectType, broadcaster, inputsNum)

    switch (dataType) {
        CASE(F32);
        CASE(U8);
        CASE(I8);
        CASE(U16);
        CASE(I16);
        CASE(I32);
        CASE(I64);
        CASE(F64);
        CASE(U32);
        CASE(U64);
        default:
            UNREACHABLE();
    }

// Keep the helper macro from leaking past this function.
#undef CASE
}

}// namespace refactor::kernel
29 changes: 29 additions & 0 deletions src/04kernel/src/kernels/select/cpu_kernel.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef KERNEL_SELECT_CPU_KERNEL_HH
#define KERNEL_SELECT_CPU_KERNEL_HH

#include "kernel/attributes/broadcaster.h"
#include "kernel/collectors/select.h"
#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

/// CPU kernel for the select (max / min) operator.
struct SelectCpu final : public Kernel {
/// Element type used to interpret all inputs and the output.
DataType dataType;
/// Whether the kernel computes the maximum or the minimum.
SelectType selectType;
/// Broadcasting information for the inputs.
Broadcaster broadcaster;
/// Number of input tensors.
size_t inputsNum;

/// Stores the given configuration; see `build` for the usual entry point.
SelectCpu(decltype(dataType), decltype(selectType), decltype(broadcaster), decltype(inputsNum)) noexcept;

/// Creates a kernel for the given inputs, or returns a null box when the
/// inputs are not supported.
static KernelBox build(SelectType, TensorRefs) noexcept;
/// Id value shared by all instances of this kernel type.
static size_t typeId() noexcept;

/// Returns the same value as `typeId()`.
size_t kernelTypeId() const noexcept final;
/// Human-readable description of the kernel.
std::string_view description() const noexcept final;
/// Produces the routine that executes the kernel.
RoutineWorkspace lower(Resources &) const noexcept final;
};

}// namespace refactor::kernel

#endif// KERNEL_SELECT_CPU_KERNEL_HH
8 changes: 3 additions & 5 deletions src/05computation/include/computation/operators/select.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
#define COMPUTATION_SELECT_H

#include "../operator.h"
#include "kernel/collectors/select.h"

namespace refactor::computation {

enum class SelectType {
Max,
Min,
};
using kernel::SelectType;

struct Select final : public Operator {
SelectType type;
Expand All @@ -19,6 +16,7 @@ namespace refactor::computation {
static size_t typeId(SelectType) noexcept;
size_t opTypeId() const noexcept final;
std::string_view name() const noexcept final;
kernel::CollectorBox candidateKernels(Target) const noexcept final;
};

}// namespace refactor::computation
Expand Down
6 changes: 6 additions & 0 deletions src/05computation/src/operators/select.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "computation/operators/select.h"
#include "kernel/collectors/select.h"

namespace refactor::computation {

Expand Down Expand Up @@ -30,4 +31,9 @@ namespace refactor::computation {
}
}

/// Creates the kernel collector for this operator on `target`, passing
/// along the operator's select type.
auto Select::candidateKernels(Target target) const noexcept -> kernel::CollectorBox {
    return std::make_unique<kernel::SelectCollector>(target, type);
}

}// namespace refactor::computation

0 comments on commit 4f9d09b

Please sign in to comment.