diff --git a/src/04kernel/cuda/include/kernel/cuda/pad.cuh b/src/04kernel/cuda/include/kernel/cuda/pad.cuh new file mode 100644 index 00000000..79d36cdd --- /dev/null +++ b/src/04kernel/cuda/include/kernel/cuda/pad.cuh @@ -0,0 +1,22 @@ +#ifndef KERNEL_CUDA_PAD_CUH +#define KERNEL_CUDA_PAD_CUH + +#include "threads_distributer.cuh" +#include + +namespace refactor::kernel::cuda { + + struct DimInfo { + unsigned int strideI, strideO, padS, dimI; + }; + + void launchPad( + KernelLaunchParameters const &, + uint8_t const *src, uint8_t const *src_const, + DimInfo const *dims, void *output, + unsigned int rank, + unsigned int blockSize); + +}// namespace refactor::kernel::cuda + +#endif// KERNEL_CUDA_PAD_CUH diff --git a/src/04kernel/cuda/src/pad.cu b/src/04kernel/cuda/src/pad.cu new file mode 100644 index 00000000..f66d1479 --- /dev/null +++ b/src/04kernel/cuda/src/pad.cu @@ -0,0 +1,64 @@ +#include "kernel/cuda/pad.cuh" +#include "macro.cuh" +#include + +namespace refactor::kernel::cuda { + + __global__ static void padKernel( + unsigned long long n, + uint8_t const *__restrict__ src, + uint8_t const *__restrict__ src_const, + DimInfo const *__restrict__ dims, + uint8_t *__restrict__ dst, + unsigned int rank, + unsigned int blockSize) { + for (auto tid = blockIdx.x * blockDim.x + threadIdx.x, + step = blockDim.x * gridDim.x; + tid < n; + tid += step) { + long rem = tid, j = 0; + bool flag = false; + for (auto i = 0; i < rank; ++i) { + auto strideO = __ldg(&(dims[i].strideO)); + auto strideI = __ldg(&(dims[i].strideI)); + auto padS = __ldg(&(dims[i].padS)); + auto dimI = __ldg(&(dims[i].dimI)); + auto pos = rem / strideO - padS; + if (pos < 0 || pos >= dimI) { + flag = true; + break; + } + j += pos * strideI; + rem %= strideO; + } + if (flag) { + optimizedMemcpy(dst + tid * blockSize, src_const, blockSize); + } else { + optimizedMemcpy(dst + tid * blockSize, src + j * blockSize, blockSize); + } + } + } + + void launchPad( + KernelLaunchParameters const ¶ms, + uint8_t const *src, uint8_t const *src_const, + DimInfo const *dims, void *output, + unsigned int rank, + unsigned int blockSize) { + + + padKernel<<< + params.gridSize, + params.blockSize, + 0, + reinterpret_cast(params.stream)>>>( + params.n, + src, + src_const, + dims, + reinterpret_cast(output), + rank, + blockSize); + } + +}// namespace refactor::kernel::cuda diff --git a/src/04kernel/include/kernel/attributes/pad_info.h b/src/04kernel/include/kernel/attributes/pad_info.h index 40165724..9bdc4611 100644 --- a/src/04kernel/include/kernel/attributes/pad_info.h +++ b/src/04kernel/include/kernel/attributes/pad_info.h @@ -37,23 +37,28 @@ namespace refactor::kernel { } }; - using PadsShape = absl::InlinedVector; + namespace pad { + struct Dim { + int64_t dimI, dimO, pads; + }; + }// namespace pad + using PadDimension = std::vector; struct PadInfo { - int rank; - PadType mode; - PadsShape pads; - PadsShape wholeNDim; - PadsShape partNDim; - PadsShape partStride; - DataType type; - bool have_value; - size_t size; - - explicit PadInfo(PadsShape, PadType, Tensor const &, Tensor const &, bool) noexcept; - }; + struct Dim { + dim_t strideI, strideO, padS, dimI; + + // bool operator==(Dim const &) const noexcept; + // bool operator!=(Dim const &) const noexcept; + }; + std::vector dims; + dim_t blockCount, blockSize; + PadInfo(decltype(dims), dim_t, dim_t) noexcept; + PadInfo(PadDimension, Tensor const &); + void reform(dim_t) noexcept; + }; }// namespace refactor::kernel diff --git a/src/04kernel/include/kernel/collectors/pad.h b/src/04kernel/include/kernel/collectors/pad.h index 53073827..fd7d9744 100644 --- a/src/04kernel/include/kernel/collectors/pad.h +++ b/src/04kernel/include/kernel/collectors/pad.h @@ -7,11 +7,11 @@ namespace refactor::kernel { struct PadCollector final : public InfoCollector { - PadsShape pads; + PadDimension dims; PadType mode; - explicit PadCollector(decltype(_target) target, PadsShape const &pads_, PadType mode_) noexcept - : InfoCollector(target), pads(std::move(pads_)), mode(mode_) {} + explicit PadCollector(decltype(_target) target, PadDimension const &dims_, PadType mode_) noexcept + : InfoCollector(target), dims(std::move(dims_)), mode(mode_) {} std::vector filter(TensorRefs inputs, TensorRefs outputs) const final; diff --git a/src/04kernel/src/attributes/pad_info.cc b/src/04kernel/src/attributes/pad_info.cc index 62ad3721..a0ffe0c1 100644 --- a/src/04kernel/src/attributes/pad_info.cc +++ b/src/04kernel/src/attributes/pad_info.cc @@ -1,25 +1,88 @@ #include "kernel/attributes/pad_info.h" -#include #include namespace refactor::kernel { + using PI = PadInfo; - PadInfo::PadInfo( - PadsShape pads_, - PadType mode_, - Tensor const &x, - Tensor const &y, - bool have_value_) noexcept : rank(x.rank()), mode(mode_), pads(std::move(pads_)), wholeNDim(rank, 0), - partNDim(rank, 0), partStride(rank, 1), type(x.dataType), have_value(have_value_), - size(0) { - int64_t p = 1; - for (auto i = rank - 1; i >= 0; --i) { - wholeNDim[i] = y.shape[i]; - partNDim[i] = x.shape[i]; - partStride[i] = p; - p = p * partNDim[i]; + // bool PI::Dim::operator==(Dim const &rhs) const noexcept { + // return strideI == rhs.strideI && + // strideO == rhs.strideO && + // padStride == rhs.padStride && + // dimt.dimI == rhs.dimI &&; + // } + // bool PI::Dim::operator!=(Dim const &rhs) const noexcept { + // return !operator==(rhs); + // } + + PI::PadInfo(decltype(dims) dims_, dim_t blockCount_, dim_t blockSize_) noexcept + : dims(std::move(dims_)), blockCount(blockCount_), blockSize(blockSize_) {} + + PI::PadInfo(PadDimension dims_, Tensor const &input) : dims{}, blockCount(1), + blockSize(input.dataType.size()) { + size_t rank = input.rank(); + ASSERT(dims_.size() == rank, "Invalid to get PadInfo."); + + // std::vector shape; + size_t j = 0; + for (auto i : range0_(rank)) { + if (dims_[i].dimI != dims_[i].dimO || dims_[i].dimI != 1) { + if (j < i) { dims_[j] = dims_[i]; } + //shape.push_back(dims_[i].dimI); + j++; + } + } + dims_.resize(rank = j); + // 合并末尾连续维度 + for (auto i : range0_(rank).rev()) { + if (auto d = dims_[i].dimI; d == dims_[i].dimO) { + blockSize *= d; + dims_.pop_back(); + } else { + dims.reserve(rank = dims_.size()); + auto &dim = dims_[i]; + if (auto times = std::gcd(std::gcd(dims_[i].dimI, dims_[i].pads), dims_[i].dimO); times > 1) { + blockSize *= times; + dim.dimI /= times; + dim.dimO /= times; + dim.pads /= times; + } + break; + } + } + + dim_t strideI = 1, strideO = 1; + for (auto i : range0_(rank).rev()) { + auto const &dim = dims_[i]; + dims.push_back({ + strideI, + strideO, + static_cast(dim.pads), + static_cast(dim.dimI), + }); + strideI *= dim.dimI; + strideO *= dim.dimO; + } + std::reverse(dims.begin(), dims.end()); + // for (auto i : range0_(rank)) { + // fmt::println("strideI = {}, strideO = {}, padS = {}, dimI = {}", dims[i].strideI, dims[i].strideO, dims[i].padS, dims[i].dimI); + // } + blockCount = strideO; + } + + void PI::reform(dim_t maxblockSize) noexcept { + auto blockSize_ = std::gcd(blockSize, maxblockSize); + if (blockSize_ == blockSize) { return; } + auto t = blockSize / blockSize_; + blockCount *= t; + blockSize = blockSize_; + for (auto &d : dims) { + d.strideI *= t; + d.strideO *= t; + d.padS *= t; + d.dimI *= t; } - size = std::accumulate(wholeNDim.begin(), wholeNDim.end(), 1, std::multiplies<>()); + dims.resize(dims.size() + 1); + dims.back() = {1, 1, 0, t}; } }// namespace refactor::kernel diff --git a/src/04kernel/src/collectors/pad.cc b/src/04kernel/src/collectors/pad.cc index c00429cc..f4c995e0 100644 --- a/src/04kernel/src/collectors/pad.cc +++ b/src/04kernel/src/collectors/pad.cc @@ -1,32 +1,32 @@ -#include "../kernels/pad/cpu_kernel.hh" -// #include "../kernels/pad/cuda_kernel.hh" #include "kernel/collectors/pad.h" +#include "../kernels/pad/cpu_kernel.hh" +#include "../kernels/pad/cuda_kernel.hh" namespace refactor::kernel { std::vector PadCollector::filter(TensorRefs inputs, TensorRefs outputs) const { auto const &input = inputs[0]; - auto const &output = outputs[0]; - bool have_value = inputs.size() >= 3 ? true : false; - PadInfo info(pads, mode, input, output, have_value); + PadInfo info(dims, input); + auto const_value = inputs.size() >= 3 ? std::make_optional(inputs[2]) : std::nullopt; std::vector ans; switch (_target) { case decltype(_target)::Cpu: - if (auto ptr = PadCpu::build(std::move(info)); ptr) { + if (auto ptr = PadCpu::build(std::move(info), mode, const_value); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; + case decltype(_target)::Nvidia: + if (auto ptr = PadCuda::build(std::move(info), mode, const_value); ptr) { ans.emplace_back(std::move(ptr)); } break; - // case decltype(_target)::Nvidia: - // if (auto ptr = PadCuda::build(); ptr) { - // ans.emplace_back(std::move(ptr)); - // } - // break; default: UNREACHABLEX(void, "Unknown target"); } return ans; } -}// namespace refactor::kernel \ No newline at end of file +}// namespace refactor::kernel + diff --git a/src/04kernel/src/kernels/pad/cpu_kernel.cc b/src/04kernel/src/kernels/pad/cpu_kernel.cc index 76d23b36..ab58c704 100644 --- a/src/04kernel/src/kernels/pad/cpu_kernel.cc +++ b/src/04kernel/src/kernels/pad/cpu_kernel.cc @@ -4,14 +4,23 @@ namespace refactor::kernel { using K = PadCpu; - K::PadCpu(PadInfo info_) noexcept - : Kernel(), info(std::move(info_)) {} + K::PadCpu(PadInfo info_, PadType mode_, size_t value_) noexcept + : Kernel(), info(std::move(info_)), mode(mode_), valueLength(value_) {} - auto K::build(PadInfo info) noexcept -> KernelBox { - if (info.mode != PadType::Constant) { + auto K::build(PadInfo info, PadType mode, std::optional> value_) noexcept -> KernelBox { + if (mode != PadType::Constant) { return nullptr; } - return std::make_unique(std::move(info)); + size_t value = value_ ? value_->get().dataType.size() : 0; + // std::vector constValue(info.blockSize, 0); + // if (value_) { + // auto constValueSize = value_->get().dataType.size(); + // auto n = constValueSize / info.blockSize; + // for (auto i : range0_(n)) { + // std::memcpy(constValue.data() + i * info.blockSize, (void const *) *value_->get().data, constValueSize); + // } + // } + return std::make_unique(std::move(info), mode, value); } auto K::typeId() noexcept -> size_t { static uint8_t ID = 1; @@ -25,50 +34,42 @@ namespace refactor::kernel { return "Performing pad operation on generic cpu"; } - template - static Routine lowerTyped(PadInfo info) { + + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { using namespace runtime; - return [info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { - auto x = reinterpret_cast(inputs[0]); - auto const_value = info.have_value ? reinterpret_cast(inputs[2])[0] : static_cast(0); - auto y = reinterpret_cast(outputs[0]); - auto getValue = [&](auto tid) { - int offset = 0; - for (int i = info.rank - 1; i >= 0; --i) { - auto wholePos = tid % info.wholeNDim[i]; - auto pos = wholePos - info.pads[i]; - // if pos belongs to pad range, then return -1 - if (pos < 0 || pos >= info.partNDim[i]) { return -1; } - tid = tid / info.wholeNDim[i]; - offset += pos * info.partStride[i]; + + return [info = this->info, value = this->valueLength](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { + auto src = reinterpret_cast(inputs[0]); + auto dst = reinterpret_cast(outputs[0]); + std::vector defaultValue(info.blockSize, 0); + // fmt::println("value = {}, blockSize = {}", value, info.blockSize); + if (value != 0) { + auto constValue = reinterpret_cast(inputs[2]); + for (auto i : range0_(info.blockSize / value)) { + std::memcpy(defaultValue.data() + i * value, constValue, value); } - return offset; - }; - std::for_each_n(std::execution::par_unseq, natural_t(0), info.size, [&](auto i) { - auto axis = getValue(i); - y[i] = axis < 0 ? const_value : x[axis]; - }); + } + std::for_each_n(std::execution::par_unseq, + natural_t(0), info.blockCount, + [=, &info](auto i) { + long rem = i, j = 0; + bool flag = false; + for (auto const &dim : info.dims) { + auto pos = rem / dim.strideO - dim.padS; + if (pos < 0 || pos >= dim.dimI) { + flag = true; + break; + } + j += pos * dim.strideI; + rem %= dim.strideO; + } + if (flag) { + std::memcpy(dst + i * info.blockSize, defaultValue.data(), info.blockSize); + } else { + std::memcpy(dst + i * info.blockSize, src + j * info.blockSize, info.blockSize); + } + }); }; } - auto K::lower(Resources &) const noexcept -> RoutineWorkspace { -#define CASE_DT(T) \ - case DataType::T: \ - return lowerTyped::type>(std::move(info)); - switch (info.type) { - CASE_DT(U8) - CASE_DT(I8) - CASE_DT(U16) - CASE_DT(I16) - CASE_DT(U32) - CASE_DT(I32) - CASE_DT(U64) - CASE_DT(I64) - CASE_DT(F32) - CASE_DT(F64) - default: - UNREACHABLE(); - } - } - }// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/pad/cpu_kernel.hh b/src/04kernel/src/kernels/pad/cpu_kernel.hh index 24ea6a9d..d314c520 100644 --- a/src/04kernel/src/kernels/pad/cpu_kernel.hh +++ b/src/04kernel/src/kernels/pad/cpu_kernel.hh @@ -8,10 +8,12 @@ namespace refactor::kernel { struct PadCpu final : public Kernel { PadInfo info; + PadType mode; + size_t valueLength; - explicit PadCpu(PadInfo) noexcept; + explicit PadCpu(PadInfo, PadType, size_t) noexcept; - static KernelBox build(PadInfo) noexcept; + static KernelBox build(PadInfo, PadType, std::optional>) noexcept; static size_t typeId() noexcept; size_t kernelTypeId() const noexcept final; @@ -21,4 +23,5 @@ namespace refactor::kernel { }// namespace refactor::kernel -#endif// KERNEL_PAD_CPU_KERNEL_HH \ No newline at end of file +#endif// KERNEL_PAD_CPU_KERNEL_HH + diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.cc b/src/04kernel/src/kernels/pad/cuda_kernel.cc index a281f15b..495f20e0 100644 --- a/src/04kernel/src/kernels/pad/cuda_kernel.cc +++ b/src/04kernel/src/kernels/pad/cuda_kernel.cc @@ -1,25 +1,29 @@ #include "cuda_kernel.hh" -#ifdef USE_CUDA -#include "../../generator/nvrtc_repo.h" -#include "kernel/cuda/threads_distributer.cuh" -#include -#endif - namespace refactor::kernel { using K = PadCuda; - K::PadCuda(PadInfo info_) noexcept - : Kernel(), info(std::move(info_)) {} + K::PadCuda(PadInfo info_, PadType mode_, size_t value_) noexcept + : Kernel(), info(std::move(info_)), mode(mode_), valueLength(value_) {} - auto K::build(PadInfo info) noexcept -> KernelBox { + auto K::build(PadInfo info, PadType mode, std::optional> value_) noexcept -> KernelBox { #ifndef USE_CUDA return nullptr; #endif - if (info.mode != PadType::Constant) { + if (mode != PadType::Constant) { return nullptr; } - return std::make_unique(std::move(info)); + size_t value = value_ ? value_->get().dataType.size() : 0; + info.reform(16); + // std::vector constValue(info.blockSize, 0); + // if (value_) { + // auto constValueSize = value_->get().dataType.size(); + // auto n = constValueSize / info.blockSize; + // for (auto i : range0_(n)) { + // std::memcpy(constValue.data() + i * info.blockSize, (void const *) *value_->get().data, constValueSize); + // } + // } + return std::make_unique(std::move(info), mode, value); } auto K::typeId() noexcept -> size_t { @@ -32,61 +36,4 @@ namespace refactor::kernel { return "Performing Pad using CUDA"; } -#ifdef USE_CUDA - constexpr static const char *TEMPLATE = R"~( -#include "kernel/attributes/pad_info.h" - -__device__ int WholeTensorOffset2PartTensorOffset(int tid, - PadInfo info) {{ - int offset = 0; - for (int i = nDims - 1; i >= 0; --i) {{ - auto wholePos = tid % info.wholeNDim[i]; - auto pos = wholePos - info.begNum[i]; - // if pos belongs to pad range, then return -1 - if (pos < 0 || pos >= info.partNDim[i]) - return -1; - tid = tid / info.wholeNDim[i]; - - offset += pos * info.partStride[i]; - }} - - return offset; -}} -extern "C" __global__ void kernel( - {0:} *__restrict__ y, - {0:} const *__restrict__ x, - {0:} const *__restrict__ value, - PadInfo info, - size_t n -) {{ - auto const_value = info.have_value ? value[0] : static_cast<{0:}>(0); - for (auto tid = blockIdx.x * blockDim.x + threadIdx.x, - step = blockDim.x * gridDim.x; - tid < n; - tid += step){{ - auto axis = WholeTensorOffset2PartTensorOffset(tid, info); - y[tid] = axis < 0 ? const_value : x[tid]; - }} -}} - )~"; - auto K::lower(Resources &res) const noexcept -> RoutineWorkspace { - using namespace runtime; - - auto name = fmt::format("Pad_{}", info.type.name()); - auto code = fmt::format(TEMPLATE, nvrtc::dataType(info.type)); - return [info = this->info, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"), - params = cuda::ThreadsDistributer()(info.size)]( - Resources &, void *, void const *const *inputs, void *const *outputs) { - auto y = outputs[0]; - auto x = inputs[0]; - auto const_value = info.have_value ? inputs[2] : nullptr; - auto n = params.n; - void *args[]{&y, &x, &const_value, const_cast(&info), &n}; - h->launch(params.gridSize, 1, 1, - params.blockSize, 1, 1, - 0, args); - }; - } -#endif - }// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.cu b/src/04kernel/src/kernels/pad/cuda_kernel.cu new file mode 100644 index 00000000..89d3ba9e --- /dev/null +++ b/src/04kernel/src/kernels/pad/cuda_kernel.cu @@ -0,0 +1,41 @@ +#include "cuda_kernel.hh" +#include "kernel/cuda/pad.cuh" +#include +#include + +namespace refactor::kernel { + using namespace runtime; + + auto PadCuda::lower(Resources &) const noexcept -> RoutineWorkspace { + thrust::host_vector dims(info.dims.size()); + std::transform(info.dims.begin(), info.dims.end(), + dims.begin(), + [](auto const &d) { + return cuda::DimInfo{ + d.strideI, + d.strideO, + d.padS, + d.dimI, + }; + }); + return [dims = thrust::device_vector(dims), + params = cuda::ThreadsDistributer()(info.blockCount), + blockSize = info.blockSize, + value = this->valueLength](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { + auto src = reinterpret_cast(inputs[0]); + thrust::device_vector defaultValue(blockSize, 0); + if (value != 0) { + auto constValue = reinterpret_cast(inputs[2]); + for (auto i : range0_(blockSize / value)) { + // std::memcpy(defaultValueHost.data() + i * value, constValue, value); + cudaMemcpy(defaultValue.data().get() + i * value, constValue, value, cudaMemcpyDeviceToDevice); + } + } + cuda::launchPad(params, src, defaultValue.data().get(), dims.data().get(), outputs[0], + dims.size(), + blockSize); + }; + } + +}// namespace refactor::kernel + diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.hh b/src/04kernel/src/kernels/pad/cuda_kernel.hh index fe6526c9..b0f915a5 100644 --- a/src/04kernel/src/kernels/pad/cuda_kernel.hh +++ b/src/04kernel/src/kernels/pad/cuda_kernel.hh @@ -8,9 +8,11 @@ namespace refactor::kernel { struct PadCuda final : public Kernel { PadInfo info; + PadType mode; + size_t valueLength; - PadCuda(PadInfo) noexcept; - static KernelBox build(PadInfo) noexcept; + PadCuda(PadInfo, PadType, size_t) noexcept; + static KernelBox build(PadInfo, PadType, std::optional>) noexcept; static size_t typeId() noexcept; size_t kernelTypeId() const noexcept final; diff --git a/src/04kernel/test/kernels/pad/test_cpu.cpp b/src/04kernel/test/kernels/pad/test_cpu.cpp index d834bc04..48b1bbb0 100644 --- a/src/04kernel/test/kernels/pad/test_cpu.cpp +++ b/src/04kernel/test/kernels/pad/test_cpu.cpp @@ -9,13 +9,15 @@ using namespace kernel; TEST(kernel, PadCpu) { // no constant_value { + PadDimension dims{ + {2, 4, 1}, + {3, 5, 1}, + }; // build routine auto xTensor = Tensor::share(DataType::F32, Shape{2, 3}); auto yTensor = Tensor::share(DataType::F32, Shape{4, 5}); - PadsShape pads = {1, 1, 1, 1}; - PadType type = PadType::Constant; - PadInfo info = PadInfo(pads, type, *xTensor, *yTensor, false); - auto kernel = PadCpu::build(std::move(info)); + PadType mode = PadType::Constant; + auto kernel = PadCpu::build(PadInfo(dims, *xTensor), mode, std::nullopt); ASSERT_TRUE(kernel); auto res = runtime::Resources(); auto routine = kernel->lower(res).routine; @@ -37,15 +39,17 @@ TEST(kernel, PadCpu) { } // have constant_value { + PadDimension dims{ + {2, 4, 1}, + {3, 5, 1}, + }; // build routine auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3}); auto t2Tensor = Tensor::share(DataType::I64, Shape{4}); auto t3Tensor = Tensor::share(DataType::F32, Shape{}); auto yTensor = Tensor::share(DataType::F32, Shape{4, 5}); - PadsShape pads = {1, 1, 1, 1}; PadType type = PadType::Constant; - PadInfo info = PadInfo(pads, type, *t1Tensor, *yTensor, true); - auto kernel = PadCpu::build(std::move(info)); + auto kernel = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); ASSERT_TRUE(kernel); auto res = runtime::Resources(); auto routine = kernel->lower(res).routine; @@ -67,4 +71,56 @@ TEST(kernel, PadCpu) { EXPECT_FLOAT_EQ(output[i], result[i]); } } + { + PadDimension dims{ + {2, 4, 1}, + {3, 5, 1}, + {1, 1, 0}, + {4, 8, 2}, + }; + // build routine + auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4}); + auto t2Tensor = Tensor::share(DataType::I64, Shape{8}); + auto t3Tensor = Tensor::share(DataType::F32, Shape{}); + auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 1, 8}); + PadType type = PadType::Constant; + auto kernel = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + // set input data + std::vector + data(t1Tensor->elementsSize(), 1), + result(yTensor->elementsSize()); + std::vector constant_value(1, 1.2); + std::vector pads_value{1, 1, 0, 2, 1, 1, 0, 2}; + // inference + { + void const *inputs[]{data.data(), pads_value.data(), constant_value.data()}; + void *outputs[]{result.data()}; + routine(res, nullptr, inputs, outputs); + } + // check + std::vector output = {1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.0000, 1.0000, 1.0000, 1.0000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.0000, 1.0000, 1.0000, 1.0000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.0000, 1.0000, 1.0000, 1.0000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.0000, 1.0000, 1.0000, 1.0000, 1.2000, 1.2000, 1.2000, 1.2000, 1.0000, + 1.0000, 1.0000, 1.0000, 1.2000, 1.2000, 1.2000, 1.2000, 1.0000, 1.0000, + 1.0000, 1.0000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, + 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000}; + for (auto i : range0_(result.size())) { + EXPECT_FLOAT_EQ(output[i], result[i]); + } + } } diff --git a/src/04kernel/test/kernels/pad/test_cuda.cpp b/src/04kernel/test/kernels/pad/test_cuda.cpp index 38e1196f..0a9ef9fe 100644 --- a/src/04kernel/test/kernels/pad/test_cuda.cpp +++ b/src/04kernel/test/kernels/pad/test_cuda.cpp @@ -10,36 +10,50 @@ using namespace kernel; using namespace hardware; TEST(kernel, PadCuda) { + PadDimension dims{ + {2, 4, 1}, + {3, 5, 1}, + {1, 1, 0}, + {4, 8, 2}, + }; // build routine - auto xTensor = Tensor::share(DataType::F32, Shape{2, 3, 5}); - auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 5}); - PadsShape pads = {1, 1, 0, 1, 1, 0}; + auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4}); + auto t2Tensor = Tensor::share(DataType::I64, Shape{8}); + auto t3Tensor = Tensor::share(DataType::F32, Shape{}); + auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 1, 8}); PadType type = PadType::Constant; - auto kernel = PadCuda::build(PadInfo(pads, type, *xTensor, *yTensor, false)); - auto kCpu = PadCpu::build(PadInfo(pads, type, *xTensor, *yTensor, false)); + auto kCpu = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); + auto kernel = PadCuda::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); ASSERT_TRUE(kernel && kCpu); auto res = runtime::Resources(); auto routine = kernel->lower(res).routine, rCpu = kCpu->lower(res).routine; // malloc auto &dev = *device::init(Device::Type::Nvidia, 0, ""); - auto gpuIn = dev.malloc(xTensor->bytesSize()), + auto gpuIn = dev.malloc(t1Tensor->bytesSize()), + gpuIn2 = dev.malloc(t2Tensor->bytesSize()), + gpuIn3 = dev.malloc(t3Tensor->bytesSize()), gpuOut = dev.malloc(yTensor->bytesSize()); // put input data - std::vector data(xTensor->elementsSize()), + std::vector data(t1Tensor->elementsSize(), 1.f), + constvalue(1, 1.2f), cpuOut(yTensor->elementsSize()); + std::vector pads{1, 1, 0, 2, 1, 1, 0, 2}; for (auto i : range0_(data.size())) { data[i] = i; } - gpuIn->copyFromHost(data.data(), xTensor->bytesSize()); + gpuIn->copyFromHost(data.data(), t1Tensor->bytesSize()); + gpuIn2->copyFromHost(pads.data(), t2Tensor->bytesSize()); + gpuIn3->copyFromHost(constvalue.data(), t3Tensor->bytesSize()); + // inference { - void const *inputs[]{*gpuIn}; + void const *inputs[]{*gpuIn, *gpuIn2, *gpuIn3}; void *outputs[]{*gpuOut}; routine(res, nullptr, inputs, outputs); } { - void const *inputs[]{data.data()}; + void const *inputs[]{data.data(), pads.data(), constvalue.data()}; void *outputs[]{cpuOut.data()}; rCpu(res, nullptr, inputs, outputs); } @@ -47,7 +61,8 @@ TEST(kernel, PadCuda) { std::vector result(yTensor->elementsSize()); gpuOut->copyToHost(result.data(), yTensor->bytesSize()); // check - for (auto i : range0_(data.size())) { + for (auto i : range0_(cpuOut.size())) { + // fmt::println("i = {}, cpuout = {}, gpuout = {}", i, cpuOut[i], result[i]); EXPECT_FLOAT_EQ(cpuOut[i], result[i]); } } diff --git a/src/05computation/include/computation/operators/pad.h b/src/05computation/include/computation/operators/pad.h index 49e6ead5..173fcae7 100644 --- a/src/05computation/include/computation/operators/pad.h +++ b/src/05computation/include/computation/operators/pad.h @@ -5,14 +5,14 @@ #include "kernel/collectors/pad.h" namespace refactor::computation { - using kernel::PadsShape; using kernel::PadType; + using Dimensions = kernel::PadDimension; struct Pad final : public LayoutDependentOperator { - PadsShape pads; + Dimensions dims; PadType mode; - Pad(decltype(pads), PadType) noexcept; + Pad(decltype(dims), PadType) noexcept; static size_t typeId() noexcept; size_t opTypeId() const noexcept final; diff --git a/src/05computation/src/operators/pad.cc b/src/05computation/src/operators/pad.cc index 1e2e23f0..243f8536 100644 --- a/src/05computation/src/operators/pad.cc +++ b/src/05computation/src/operators/pad.cc @@ -4,8 +4,8 @@ namespace refactor::computation { using Op = Pad; - Op::Pad(decltype(pads) pads_, - PadType mode_) noexcept : LayoutDependentOperator(), pads(std::move(pads_)), mode(mode_) {} + Op::Pad(decltype(dims) dims_, + PadType mode_) noexcept : LayoutDependentOperator(), dims(std::move(dims_)), mode(mode_) {} auto Op::typeId() noexcept -> size_t { static uint8_t ID = 1; @@ -15,13 +15,16 @@ namespace refactor::computation { auto Op::name() const noexcept -> std::string_view { return "Pad"; } auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { using Collector_ = kernel::PadCollector; - return std::make_unique(target, std::move(pads), mode); + return std::make_unique(target, std::move(dims), mode); } auto Op::serialize() const noexcept -> std::string { - return fmt::format("{}({}, {})", - name(), - vec2str(pads), - mode.toString()); + std::stringstream ss; + ss << name() << "(["; + for (auto const &d : dims) { + ss << "input = " << d.dimI << ", output = " << d.dimO << ", pads = " << d.pads; + } + ss << "mode = " << mode.toString() << " ])"; + return ss.str(); } }// namespace refactor::computation diff --git a/src/07onnx/src/operators/pad.cc b/src/07onnx/src/operators/pad.cc index f18d7888..817deabe 100644 --- a/src/07onnx/src/operators/pad.cc +++ b/src/07onnx/src/operators/pad.cc @@ -10,7 +10,8 @@ namespace refactor::onnx { Op::Pad(Pm mode_) : Operator(), mode(mode_) {} auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { - auto mode = defaultOr(attributes, "mode", {"constant"}).string(); + //auto mode = defaultOr(attributes, "mode", {"constant"}).string(); + auto mode = attributes.getOrInsert("mode", {"constant"}).string(); Pm pm; if (mode == "constant") { pm = Pm::Constant; @@ -104,11 +105,12 @@ namespace refactor::onnx { auto Op::lower(TensorRefs inputs) const -> computation::OpBox { using Ty_ = computation::PadType; using Op_ = computation::Pad; - using Shape_ = computation::PadsShape; + using Dimension = computation::Dimensions; auto rank = inputs[0].rank(); int64_t const *pads_ = inputs[1].data->get(); - Shape_ pads_info(2 * rank, 0); + std::vector pads_info(2 * rank, 0); + Dimension dims(rank); if (inputs.size() != 4) { for (auto i : range0_(inputs[1].shape[0].value())) { pads_info[i] = pads_[i]; } } else { @@ -123,6 +125,11 @@ namespace refactor::onnx { pads_info[axis + rank] = pads_[i + axes_len]; } } + for (auto i : range0_(rank)) { + auto dimI = inputs[0].shape[i].value(); + dims[i] = { + dimI, dimI + pads_info[i] + pads_info[i + rank], pads_info[i]}; + } Ty_ mode_; switch (mode) { case Pm::Constant: @@ -136,7 +143,8 @@ namespace refactor::onnx { default: UNREACHABLE(); } - return std::make_unique(std::move(pads_info), mode_); + return std::make_unique(std::move(dims), mode_); } -}// namespace refactor::onnx \ No newline at end of file +}// namespace refactor::onnx +