feat: optimize the Pad operator
bitzyz authored and YdrMaster committed Jan 31, 2024
1 parent 48a195e commit 76fd621
Showing 16 changed files with 426 additions and 196 deletions.
22 changes: 22 additions & 0 deletions src/04kernel/cuda/include/kernel/cuda/pad.cuh
@@ -0,0 +1,22 @@
#ifndef KERNEL_CUDA_PAD_CUH
#define KERNEL_CUDA_PAD_CUH

#include "threads_distributer.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {

struct DimInfo {
unsigned int strideI, strideO, padS, dimI;
};

void launchPad(
KernelLaunchParameters const &,
uint8_t const *src, uint8_t const *src_const,
DimInfo const *dims, void *output,
unsigned int rank,
unsigned int blockSize);

}// namespace refactor::kernel::cuda

#endif// KERNEL_CUDA_PAD_CUH
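Each DimInfo entry describes one (possibly merged) output dimension, with both strides measured in blocks of blockSize bytes. As a hypothetical illustration (values invented here, not taken from the commit): padding a 3×4 input to 5×6 with one element of padding on every side, at one element per block, would be encoded row-major as

    DimInfo dims[2] = {
        {4, 6, 1, 3},// outer: strideI 4, strideO 6, padS 1, dimI 3
        {1, 1, 1, 4},// inner: strideI 1, strideO 1, padS 1, dimI 4
    };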
64 changes: 64 additions & 0 deletions src/04kernel/cuda/src/pad.cu
@@ -0,0 +1,64 @@
#include "kernel/cuda/pad.cuh"
#include "macro.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {

__global__ static void padKernel(
unsigned long long n,
uint8_t const *__restrict__ src,
uint8_t const *__restrict__ src_const,
DimInfo const *__restrict__ dims,
uint8_t *__restrict__ dst,
unsigned int rank,
unsigned int blockSize) {
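// grid-stride loop: each thread maps whole output blocks of blockSize bytes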
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step) {
long rem = tid, j = 0;
bool flag = false;
for (auto i = 0; i < rank; ++i) {
auto strideO = __ldg(&(dims[i].strideO));
auto strideI = __ldg(&(dims[i].strideI));
auto padS = __ldg(&(dims[i].padS));
auto dimI = __ldg(&(dims[i].dimI));
auto pos = rem / strideO - padS;
if (pos < 0 || pos >= dimI) {
flag = true;
break;
}
j += pos * strideI;
rem %= strideO;
}
if (flag) {
optimizedMemcpy(dst + tid * blockSize, src_const, blockSize);
} else {
optimizedMemcpy(dst + tid * blockSize, src + j * blockSize, blockSize);
}
}
}

void launchPad(
KernelLaunchParameters const &params,
uint8_t const *src, uint8_t const *src_const,
DimInfo const *dims, void *output,
unsigned int rank,
unsigned int blockSize) {
padKernel<<<
params.gridSize,
params.blockSize,
0,
reinterpret_cast<cudaStream_t>(params.stream)>>>(
params.n,
src,
src_const,
dims,
reinterpret_cast<uint8_t *>(output),
rank,
blockSize);
}

}// namespace refactor::kernel::cuda
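For each output block, the kernel recovers one coordinate per dimension as pos = rem / strideO - padS; any coordinate outside [0, dimI) means the block lies in the padded halo and is filled from src_const, otherwise the accumulated j selects the source block. A minimal host-side mirror of that index math, handy for checking the mapping on the CPU (a sketch only; HostDim and mapPadIndex are illustrative names, not part of this commit):

    #include <vector>

    struct HostDim { long strideI, strideO, padS, dimI; };

    // Returns the source block index for output block `tid`,
    // or -1 if the block falls in the padded region.
    long mapPadIndex(long tid, std::vector<HostDim> const &dims) {
        long rem = tid, j = 0;
        for (auto const &d : dims) {
            long pos = rem / d.strideO - d.padS;
            if (pos < 0 || pos >= d.dimI) { return -1; }// padded -> src_const
            j += pos * d.strideI;
            rem %= d.strideO;
        }
        return j;// copy from src + j * blockSize
    }

For the 3×4 → 5×6 example above, mapPadIndex(7, {{4, 6, 1, 3}, {1, 1, 1, 4}}) returns 0 (output element (1, 1) reads input element (0, 0)), while mapPadIndex(0, ...) with the same dims returns -1, the corner of the halo.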
31 changes: 18 additions & 13 deletions src/04kernel/include/kernel/attributes/pad_info.h
@@ -37,23 +37,28 @@ namespace refactor::kernel {
}
};

using PadsShape = absl::InlinedVector<int64_t, 4>;
namespace pad {
struct Dim {
int64_t dimI, dimO, pads;
};
}// namespace pad

using PadDimension = std::vector<pad::Dim>;

struct PadInfo {
struct Dim {
dim_t strideI, strideO, padS, dimI;

// bool operator==(Dim const &) const noexcept;
// bool operator!=(Dim const &) const noexcept;
};
std::vector<Dim> dims;
dim_t blockCount, blockSize;

PadInfo(decltype(dims), dim_t, dim_t) noexcept;
PadInfo(PadDimension, Tensor const &);
void reform(dim_t) noexcept;
};

}// namespace refactor::kernel

6 changes: 3 additions & 3 deletions src/04kernel/include/kernel/collectors/pad.h
@@ -7,11 +7,11 @@
namespace refactor::kernel {

struct PadCollector final : public InfoCollector {
PadDimension dims;
PadType mode;

explicit PadCollector(decltype(_target) target, PadDimension const &dims_, PadType mode_) noexcept
: InfoCollector(target), dims(std::move(dims_)), mode(mode_) {}

std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
95 changes: 79 additions & 16 deletions src/04kernel/src/attributes/pad_info.cc
@@ -1,25 +1,88 @@
#include "kernel/attributes/pad_info.h"
#include <iostream>
#include <numeric>

namespace refactor::kernel {
using PI = PadInfo;

// bool PI::Dim::operator==(Dim const &rhs) const noexcept {
//     return strideI == rhs.strideI &&
//            strideO == rhs.strideO &&
//            padS == rhs.padS &&
//            dimI == rhs.dimI;
// }
// bool PI::Dim::operator!=(Dim const &rhs) const noexcept {
//     return !operator==(rhs);
// }

PI::PadInfo(decltype(dims) dims_, dim_t blockCount_, dim_t blockSize_) noexcept
: dims(std::move(dims_)), blockCount(blockCount_), blockSize(blockSize_) {}

PI::PadInfo(PadDimension dims_, Tensor const &input) : dims{}, blockCount(1),
blockSize(input.dataType.size()) {
size_t rank = input.rank();
ASSERT(dims_.size() == rank, "Invalid to get PadInfo.");

// std::vector<dim_t> shape;
size_t j = 0;
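// squeeze out size-1 dimensions that receive no padding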
for (auto i : range0_(rank)) {
if (dims_[i].dimI != dims_[i].dimO || dims_[i].dimI != 1) {
if (j < i) { dims_[j] = dims_[i]; }
//shape.push_back(dims_[i].dimI);
j++;
}
}
dims_.resize(rank = j);
// merge trailing contiguous dimensions
for (auto i : range0_(rank).rev()) {
if (auto d = dims_[i].dimI; d == dims_[i].dimO) {
blockSize *= d;
dims_.pop_back();
} else {
dims.reserve(rank = dims_.size());
auto &dim = dims_[i];
if (auto times = std::gcd(std::gcd(dims_[i].dimI, dims_[i].pads), dims_[i].dimO); times > 1) {
blockSize *= times;
dim.dimI /= times;
dim.dimO /= times;
dim.pads /= times;
}
break;
}
}

dim_t strideI = 1, strideO = 1;
for (auto i : range0_(rank).rev()) {
auto const &dim = dims_[i];
dims.push_back({
strideI,
strideO,
static_cast<dim_t>(dim.pads),
static_cast<dim_t>(dim.dimI),
});
strideI *= dim.dimI;
strideO *= dim.dimO;
}
std::reverse(dims.begin(), dims.end());
// for (auto i : range0_(rank)) {
// fmt::println("strideI = {}, strideO = {}, padS = {}, dimI = {}", dims[i].strideI, dims[i].strideO, dims[i].padS, dims[i].dimI);
// }
blockCount = strideO;
}
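As a hypothetical walkthrough of this folding (shapes invented for illustration): padding a float tensor from [2, 3, 8] to [2, 5, 8], one element of padding on each side of the middle dimension, starts from dims_ = {2,2,0} {3,5,1} {8,8,0} with blockSize = 4 (sizeof float). The trailing {8,8,0} has dimI == dimO, so it folds into blockSize = 4 * 8 = 32 bytes and is dropped; {3,5,1} has gcd(3, 1, 5) = 1, so nothing more can be extracted. Building strides back to front then yields

    dims = [{strideI 3, strideO 5, padS 0, dimI 2},
            {strideI 1, strideO 1, padS 1, dimI 3}]
    blockCount = 2 * 5 = 10

so the kernel moves ten 32-byte blocks instead of eighty floats.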

void PI::reform(dim_t maxblockSize) noexcept {
auto blockSize_ = std::gcd(blockSize, maxblockSize);
if (blockSize_ == blockSize) { return; }
auto t = blockSize / blockSize_;
blockCount *= t;
blockSize = blockSize_;
for (auto &d : dims) {
d.strideI *= t;
d.strideO *= t;
}
dims.resize(dims.size() + 1);
dims.back() = {1, 1, 0, t};
}

}// namespace refactor::kernel
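Continuing the same hypothetical example, reform splits blocks that exceed the width a backend wants to copy per thread. Calling reform(16) on the state above finds gcd(32, 16) = 16, so t = 2: blockSize drops to 16, blockCount doubles to 20, the strides double (they are measured in blocks), the pad offsets and extents stay in coordinate units, and a new innermost dimension indexes the two sub-blocks:

    dims = [{strideI 6, strideO 10, padS 0, dimI 2},
            {strideI 2, strideO  2, padS 1, dimI 3},
            {strideI 1, strideO  1, padS 0, dimI 2}]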
24 changes: 12 additions & 12 deletions src/04kernel/src/collectors/pad.cc
@@ -1,32 +1,32 @@
#include "../kernels/pad/cpu_kernel.hh"
// #include "../kernels/pad/cuda_kernel.hh"
#include "kernel/collectors/pad.h"
#include "../kernels/pad/cpu_kernel.hh"
#include "../kernels/pad/cuda_kernel.hh"

namespace refactor::kernel {

std::vector<KernelBox>
PadCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
auto const &input = inputs[0];
auto const &output = outputs[0];
PadInfo info(dims, input);
auto const_value = inputs.size() >= 3 ? std::make_optional(inputs[2]) : std::nullopt;

std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
if (auto ptr = PadCpu::build(std::move(info), mode, const_value); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Nvidia:
if (auto ptr = PadCuda::build(std::move(info), mode, const_value); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
}
return ans;
}

}// namespace refactor::kernel

