-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
594e06b
commit 4f9d09b
Showing
7 changed files
with
191 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#ifndef KERNEL_SELECT_H | ||
#define KERNEL_SELECT_H | ||
|
||
#include "../collector.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
enum class SelectType { | ||
Max, | ||
Min, | ||
}; | ||
|
||
struct SelectCollector final : public InfoCollector { | ||
SelectType selectType; | ||
|
||
SelectCollector(decltype(_target), SelectType) noexcept; | ||
|
||
std::vector<KernelBox> | ||
filter(TensorRefs inputs, TensorRefs outputs) const final; | ||
}; | ||
|
||
}// namespace refactor::kernel | ||
|
||
#endif// KERNEL_SELECT_H |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#include "kernel/collectors/select.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
SelectCollector::SelectCollector(decltype(_target) target, SelectType type) noexcept | ||
: InfoCollector(target), selectType(type) {} | ||
|
||
std::vector<KernelBox> | ||
SelectCollector::filter(TensorRefs inputs, TensorRefs outputs) const { | ||
std::vector<KernelBox> ans; | ||
switch (_target) { | ||
case decltype(_target)::Cpu: | ||
break; | ||
case decltype(_target)::Nvidia: | ||
break; | ||
default: | ||
UNREACHABLEX(void, "Unknown target"); | ||
} | ||
return ans; | ||
} | ||
|
||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#include "cpu_kernel.hh" | ||
#include <execution> | ||
|
||
namespace refactor::kernel { | ||
using K = SelectCpu; | ||
using DT = DataType; | ||
|
||
K::SelectCpu( | ||
decltype(dataType) dataType_, | ||
decltype(selectType) selectType_, | ||
decltype(broadcaster) broadcaster_, | ||
decltype(inputsNum) inputsNum_) noexcept | ||
: dataType(dataType_), | ||
selectType(selectType_), | ||
broadcaster(broadcaster_), | ||
inputsNum(inputsNum_) {} | ||
|
||
auto K::build(SelectType selectType_, TensorRefs inputs_) noexcept -> KernelBox { | ||
auto const &x = inputs_[0].get(); | ||
return x.dataType.isCpuNumberic() | ||
? std::make_unique<K>(x.dataType, selectType_, Broadcaster(inputs_), inputs_.size()) | ||
: nullptr; | ||
} | ||
auto K::typeId() noexcept -> size_t { | ||
static uint8_t ID = 1; | ||
return reinterpret_cast<size_t>(&ID); | ||
} | ||
|
||
auto K::kernelTypeId() const noexcept -> size_t { | ||
return typeId(); | ||
} | ||
auto K::description() const noexcept -> std::string_view { | ||
return "Performing select operation on generic cpu"; | ||
} | ||
|
||
template<class T> | ||
auto lowerTyped(SelectType selectType, Broadcaster broadcaster, size_t inputsNum) noexcept -> RoutineWorkspace { | ||
using namespace runtime; | ||
|
||
T(*op) | ||
(T const a, T const b); | ||
switch (selectType) { | ||
case SelectType::Max: | ||
op = [](T const a, T const b) { return std::max(a, b); }; | ||
break; | ||
case SelectType::Min: | ||
op = [](T const a, T const b) { return std::min(a, b); }; | ||
break; | ||
default: | ||
UNREACHABLE(); | ||
} | ||
|
||
if (broadcaster.needBroadcast()) { | ||
return [n = broadcaster.outputsCount, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { | ||
auto output = reinterpret_cast<T *>(outputs[0]); | ||
for (auto i : range0_(n)) { | ||
for (auto inputIdx : range0_(inputsNum)) { | ||
auto input = reinterpret_cast<const T *>(inputs[inputIdx]); | ||
if (inputIdx == 0) { | ||
output[i] = input[i]; | ||
} else { | ||
output[i] = op(output[i], input[i]); | ||
} | ||
} | ||
}; | ||
}; | ||
} else { | ||
return [broadcaster, inputsNum, op](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { | ||
auto output = reinterpret_cast<T *>(outputs[0]); | ||
for (auto i : range0_(broadcaster.outputsCount)) { | ||
std::vector<dim_t> ans(broadcaster.inputsCount); | ||
broadcaster.locate(i, ans.data()); | ||
for (auto inputIdx : range0_(inputsNum)) { | ||
auto input = reinterpret_cast<const T *>(inputs[inputIdx]); | ||
if (inputIdx == 0) { | ||
output[i] = input[ans[inputIdx]]; | ||
} else { | ||
output[i] = op(output[i], input[ans[inputIdx]]); | ||
} | ||
} | ||
} | ||
}; | ||
} | ||
} | ||
|
||
auto K::lower(Resources &) const noexcept -> RoutineWorkspace { | ||
#define CASE(DT) \ | ||
case DataType::DT: \ | ||
return lowerTyped<primitive<DataType::DT>::type>(selectType, broadcaster, inputsNum) | ||
|
||
switch (dataType) { | ||
CASE(F32); | ||
CASE(U8); | ||
CASE(I8); | ||
CASE(U16); | ||
CASE(I16); | ||
CASE(I32); | ||
CASE(I64); | ||
CASE(F64); | ||
CASE(U32); | ||
CASE(U64); | ||
default: | ||
UNREACHABLE(); | ||
} | ||
} | ||
|
||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#ifndef KERNEL_SELECT_CPU_KERNEL_HH | ||
#define KERNEL_SELECT_CPU_KERNEL_HH | ||
|
||
#include "kernel/attributes/broadcaster.h" | ||
#include "kernel/collectors/select.h" | ||
#include "kernel/kernel.h" | ||
#include "kernel/tensor.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
struct SelectCpu final : public Kernel { | ||
DataType dataType; | ||
SelectType selectType; | ||
Broadcaster broadcaster; | ||
size_t inputsNum; | ||
|
||
SelectCpu(decltype(dataType), decltype(selectType), decltype(broadcaster), decltype(inputsNum)) noexcept; | ||
|
||
static KernelBox build(SelectType, TensorRefs) noexcept; | ||
static size_t typeId() noexcept; | ||
|
||
size_t kernelTypeId() const noexcept final; | ||
std::string_view description() const noexcept final; | ||
RoutineWorkspace lower(Resources &) const noexcept final; | ||
}; | ||
|
||
}// namespace refactor::kernel | ||
|
||
#endif// KERNEL_Select_CPU_KERNEL_HH |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters