Skip to content

Commit

Permalink
fix(kernel): 修改pad/slice的diminfo; 删除部分注释
Browse files Browse the repository at this point in the history
  • Loading branch information
bitzyz authored and YdrMaster committed Jan 31, 2024
1 parent 76fd621 commit 237c494
Show file tree
Hide file tree
Showing 14 changed files with 214 additions and 143 deletions.
4 changes: 2 additions & 2 deletions src/04kernel/cuda/include/kernel/cuda/pad.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

namespace refactor::kernel::cuda {

struct DimInfo {
// Per-dimension descriptor handed to launchPad as the `dims` array, alongside
// `rank`. PadCuda::lower builds one of these from each `info.dims` entry,
// copying its strideI, strideO, padS and dimI in that order.
// NOTE(review): judging by the names these are the input/output strides, the
// leading pad amount and the input extent of the dimension — confirm against
// the pad kernel before relying on it.
struct PadDimInfo {
unsigned int strideI, strideO, padS, dimI;
};

void launchPad(
KernelLaunchParameters const &,
uint8_t const *src, uint8_t const *src_const,
DimInfo const *dims, void *output,
PadDimInfo const *dims, void *output,
unsigned int rank,
unsigned int blockSize);

Expand Down
4 changes: 2 additions & 2 deletions src/04kernel/cuda/include/kernel/cuda/slice.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

namespace refactor::kernel::cuda {

struct DimInfo {
// Per-dimension descriptor handed to launchSlice as the `dims` array,
// alongside `rank`. SliceCuda::lower builds one of these from each `info.dims`
// entry, copying its strideO, skip and strideI in that order.
// strideI is the only signed field.
// NOTE(review): presumably signed so a negative step can walk the input
// backwards — confirm against the slice kernel.
struct SliceDimInfo {
unsigned int strideO, skip;
int strideI;
};

void launchSlice(
KernelLaunchParameters const &,
void const *src, DimInfo const *dims, void *output,
void const *src, SliceDimInfo const *dims, void *output,
unsigned int rank,
unsigned int blockSize);

Expand Down
4 changes: 2 additions & 2 deletions src/04kernel/cuda/src/pad.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace refactor::kernel::cuda {
unsigned long long n,
uint8_t const *__restrict__ src,
uint8_t const *__restrict__ src_const,
DimInfo const *__restrict__ dims,
PadDimInfo const *__restrict__ dims,
uint8_t *__restrict__ dst,
unsigned int rank,
unsigned int blockSize) {
Expand Down Expand Up @@ -42,7 +42,7 @@ namespace refactor::kernel::cuda {
void launchPad(
KernelLaunchParameters const &params,
uint8_t const *src, uint8_t const *src_const,
DimInfo const *dims, void *output,
PadDimInfo const *dims, void *output,
unsigned int rank,
unsigned int blockSize) {

Expand Down
4 changes: 2 additions & 2 deletions src/04kernel/cuda/src/slice.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace refactor::kernel::cuda {
__global__ static void sliceKernel(
unsigned long long n,
uint8_t const *__restrict__ src,
DimInfo const *__restrict__ dims,
SliceDimInfo const *__restrict__ dims,
uint8_t *__restrict__ dst,
unsigned int rank,
unsigned int blockSize) {
Expand All @@ -29,7 +29,7 @@ namespace refactor::kernel::cuda {

void launchSlice(
KernelLaunchParameters const &params,
void const *src, DimInfo const *dims, void *output,
void const *src, SliceDimInfo const *dims, void *output,
unsigned int rank,
unsigned int blockSize) {
sliceKernel<<<
Expand Down
3 changes: 0 additions & 3 deletions src/04kernel/include/kernel/attributes/pad_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ namespace refactor::kernel {
struct PadInfo {
struct Dim {
dim_t strideI, strideO, padS, dimI;

// bool operator==(Dim const &) const noexcept;
// bool operator!=(Dim const &) const noexcept;
};
std::vector<Dim> dims;
dim_t blockCount, blockSize;
Expand Down
18 changes: 2 additions & 16 deletions src/04kernel/src/attributes/pad_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,6 @@
namespace refactor::kernel {
using PI = PadInfo;

// bool PI::Dim::operator==(Dim const &rhs) const noexcept {
// return strideI == rhs.strideI &&
// strideO == rhs.strideO &&
// padStride == rhs.padStride &&
// dimt.dimI == rhs.dimI &&;
// }
// bool PI::Dim::operator!=(Dim const &rhs) const noexcept {
// return !operator==(rhs);
// }

PI::PadInfo(decltype(dims) dims_, dim_t blockCount_, dim_t blockSize_) noexcept
: dims(std::move(dims_)), blockCount(blockCount_), blockSize(blockSize_) {}

Expand All @@ -22,23 +12,21 @@ namespace refactor::kernel {
size_t rank = input.rank();
ASSERT(dims_.size() == rank, "Invalid to get PadInfo.");

// std::vector<dim_t> shape;
size_t j = 0;
for (auto i : range0_(rank)) {
if (dims_[i].dimI != dims_[i].dimO || dims_[i].dimI != 1) {
if (j < i) { dims_[j] = dims_[i]; }
//shape.push_back(dims_[i].dimI);
j++;
}
}
dims_.resize(rank = j);

// 合并末尾连续维度
for (auto i : range0_(rank).rev()) {
if (auto d = dims_[i].dimI; d == dims_[i].dimO) {
blockSize *= d;
dims_.pop_back();
} else {
dims.reserve(rank = dims_.size());
auto &dim = dims_[i];
if (auto times = std::gcd(std::gcd(dims_[i].dimI, dims_[i].pads), dims_[i].dimO); times > 1) {
blockSize *= times;
Expand All @@ -49,6 +37,7 @@ namespace refactor::kernel {
break;
}
}
dims.reserve(rank = dims_.size());

dim_t strideI = 1, strideO = 1;
for (auto i : range0_(rank).rev()) {
Expand All @@ -63,9 +52,6 @@ namespace refactor::kernel {
strideO *= dim.dimO;
}
std::reverse(dims.begin(), dims.end());
// for (auto i : range0_(rank)) {
// fmt::println("strideI = {}, strideO = {}, padS = {}, dimI = {}", dims[i].strideI, dims[i].strideO, dims[i].padS, dims[i].dimI);
// }
blockCount = strideO;
}

Expand Down
2 changes: 1 addition & 1 deletion src/04kernel/src/attributes/slice_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ namespace refactor::kernel {
shape.pop_back();
dims_.pop_back();
} else {
dims.resize(rank = shape.size());
if (auto &dim = dims_[i]; dim.step == 1) {
if (auto times = std::gcd(std::gcd(dim.start, dim.length), shape[i]); times > 1) {
blockSize *= times;
Expand All @@ -58,6 +57,7 @@ namespace refactor::kernel {
break;
}
}
dims.resize(rank = shape.size());
dim_t strideI = 1;
for (auto i : range0_(rank).rev()) {
auto const &dim = dims_[i];
Expand Down
9 changes: 0 additions & 9 deletions src/04kernel/src/kernels/pad/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,6 @@ namespace refactor::kernel {
return nullptr;
}
size_t value = value_ ? value_->get().dataType.size() : 0;
// std::vector<uint8_t> constValue(info.blockSize, 0);
// if (value_) {
// auto constValueSize = value_->get().dataType.size();
// auto n = constValueSize / info.blockSize;
// for (auto i : range0_(n)) {
// std::memcpy(constValue.data() + i * info.blockSize, (void const *) *value_->get().data, constValueSize);
// }
// }
return std::make_unique<K>(std::move(info), mode, value);
}
auto K::typeId() noexcept -> size_t {
Expand All @@ -42,7 +34,6 @@ namespace refactor::kernel {
auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
auto dst = reinterpret_cast<uint8_t *>(outputs[0]);
std::vector<uint8_t> defaultValue(info.blockSize, 0);
// fmt::println("value = {}, blockSize = {}", value, info.blockSize);
if (value != 0) {
auto constValue = reinterpret_cast<uint8_t const *>(inputs[2]);
for (auto i : range0_(info.blockSize / value)) {
Expand Down
8 changes: 0 additions & 8 deletions src/04kernel/src/kernels/pad/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@ namespace refactor::kernel {
}
size_t value = value_ ? value_->get().dataType.size() : 0;
info.reform(16);
// std::vector<uint8_t> constValue(info.blockSize, 0);
// if (value_) {
// auto constValueSize = value_->get().dataType.size();
// auto n = constValueSize / info.blockSize;
// for (auto i : range0_(n)) {
// std::memcpy(constValue.data() + i * info.blockSize, (void const *) *value_->get().data, constValueSize);
// }
// }
return std::make_unique<K>(std::move(info), mode, value);
}

Expand Down
8 changes: 3 additions & 5 deletions src/04kernel/src/kernels/pad/cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ namespace refactor::kernel {
using namespace runtime;

auto PadCuda::lower(Resources &) const noexcept -> RoutineWorkspace {
thrust::host_vector<cuda::DimInfo> dims(info.dims.size());
thrust::host_vector<cuda::PadDimInfo> dims(info.dims.size());
std::transform(info.dims.begin(), info.dims.end(),
dims.begin(),
[](auto const &d) {
return cuda::DimInfo{
return cuda::PadDimInfo{
d.strideI,
d.strideO,
d.padS,
d.dimI,
};
});
return [dims = thrust::device_vector<cuda::DimInfo>(dims),
return [dims = thrust::device_vector<cuda::PadDimInfo>(dims),
params = cuda::ThreadsDistributer()(info.blockCount),
blockSize = info.blockSize,
value = this->valueLength](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
Expand All @@ -27,7 +27,6 @@ namespace refactor::kernel {
if (value != 0) {
auto constValue = reinterpret_cast<uint8_t const *>(inputs[2]);
for (auto i : range0_(blockSize / value)) {
// std::memcpy(defaultValueHost.data() + i * value, constValue, value);
cudaMemcpy(defaultValue.data().get() + i * value, constValue, value, cudaMemcpyDeviceToDevice);
}
}
Expand All @@ -38,4 +37,3 @@ namespace refactor::kernel {
}

}// namespace refactor::kernel

6 changes: 3 additions & 3 deletions src/04kernel/src/kernels/slice/cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@ namespace refactor::kernel {
using namespace runtime;

auto SliceCuda::lower(Resources &) const noexcept -> RoutineWorkspace {
thrust::host_vector<cuda::DimInfo> dims(info.dims.size());
thrust::host_vector<cuda::SliceDimInfo> dims(info.dims.size());
std::transform(info.dims.begin(), info.dims.end(),
dims.begin(),
[](auto const &d) {
return cuda::DimInfo{
return cuda::SliceDimInfo{
d.strideO,
d.skip,
d.strideI,
};
});
return [dims = thrust::device_vector<cuda::DimInfo>(dims),
return [dims = thrust::device_vector<cuda::SliceDimInfo>(dims),
params = cuda::ThreadsDistributer()(info.blockCount),
blockSize = info.blockSize](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
Expand Down
Loading

0 comments on commit 237c494

Please sign in to comment.