diff --git a/src/04kernel/src/collectors/global_pool.cc b/src/04kernel/src/collectors/global_pool.cc
index 1ae1d7fc..e6a278c1 100644
--- a/src/04kernel/src/collectors/global_pool.cc
+++ b/src/04kernel/src/collectors/global_pool.cc
@@ -1,5 +1,6 @@
 #include "kernel/collectors/global_pool.h"
 #include "../kernels/pool/cudnn_kernel.hh"
+#include "../kernels/pool/cnnl_kernel.hh"
 
 namespace refactor::kernel {
 
@@ -28,6 +29,11 @@ namespace refactor::kernel {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;
+            case decltype(_target)::Mlu:
+                if (auto ptr = PoolCnnl::build(type, false, kernelShape, attributes, x, y); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
             default:
                 UNREACHABLEX(void, "Unknown target");
         }
diff --git a/src/04kernel/src/kernels/gather/cnnl_kernel.cc b/src/04kernel/src/kernels/gather/cnnl_kernel.cc
index d5ddcace..b4d5aa15 100644
--- a/src/04kernel/src/kernels/gather/cnnl_kernel.cc
+++ b/src/04kernel/src/kernels/gather/cnnl_kernel.cc
@@ -4,6 +4,7 @@
 #include "../../utilities/bang/cnnl_context.hh"
 #include "../../utilities/bang/cnnl_functions.h"
 #endif
+#include
 
 namespace refactor::kernel {
     using K = GatherCnnl;
@@ -15,11 +16,11 @@ namespace refactor::kernel {
 #ifndef USE_BANG
         return nullptr;
 #endif
-
+
         return std::make_unique<K>(decltype(info){
             input.dataType, DataType::I32,
-            axis,
+            axis ? axis : 0,
             std::vector<int>(input.shape.begin(), input.shape.end()),
             std::vector<int>(index.shape.begin(), index.shape.end()),
             std::vector<int>(output.shape.begin(), output.shape.end()),
@@ -70,15 +71,16 @@ namespace refactor::kernel {
         res.fetchOrStore<CnnlContext>();
         auto routine = [d = std::move(d),
-                        shape = info.inDim.data(), workspaceSize,
+                        shape = std::vector<int>(info.inDim.begin(), info.inDim.end()),
+                        workspaceSize,
                         dim = info.axis](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
-            BANG_ASSERT(cnrtMemcpy(workspace, (void*) shape, workspaceSize, CNRT_MEM_TRANS_DIR_HOST2DEV));
+            res.fetchOrStore<CnnlContext>()->copyFromCPU(workspace, shape.data(), workspaceSize);
             CNNL_ASSERT(cnnlGatherV2(res.fetchOrStore<CnnlContext>()->handle, dim,
                                      d->inDesc, inputs[0], reinterpret_cast<const int *>(workspace),
-                                     d->indexDesc, reinterpret_cast<const int *>(inputs[1]),
+                                     d->indexDesc, reinterpret_cast<const int *>(inputs[1]),
                                      d->outDesc, outputs[0]));
             BANG_ASSERT(cnrtQueueSync(res.fetchOrStore<CnnlContext>()->queue));
-        };
+        };
         return {std::move(routine), workspaceSize};
     }
diff --git a/src/04kernel/src/kernels/reduce/cnnl_kernel.cc b/src/04kernel/src/kernels/reduce/cnnl_kernel.cc
index 6b22d793..4ea6fd82 100644
--- a/src/04kernel/src/kernels/reduce/cnnl_kernel.cc
+++ b/src/04kernel/src/kernels/reduce/cnnl_kernel.cc
@@ -71,14 +71,15 @@ namespace refactor::kernel {
         std::vector<int> dimsI(shape.begin(), shape.end()),
-            dimsO(shape.begin(), shape.end());
+            dimsO(shape.begin(), shape.end()),
+            indices(axes.begin(), axes.end());
         for (auto axis : axes) {
             dimsO[axis] = 1;
         }
         // setCnnlTensor(d->x, dataType, slice(dimsI.data(), dimsI.size()));
         // setCnnlTensor(d->y, dataType, slice(dimsO.data(), dimsO.size()));
-        CNNL_ASSERT(cnnlSetTensorDescriptor(d->x, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(dataType), dimsI.size(), dimsI.data()));
-        CNNL_ASSERT(cnnlSetTensorDescriptor(d->y, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(dataType), dimsO.size(), dimsO.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(d->x, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(dataType), dimsI.size(), dimsI.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(d->y, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(dataType), dimsO.size(), dimsO.data()));
         // clang-format off
         auto reduceOp = reduceType == ReduceType::Mean ? CNNL_REDUCE_AVG
@@ -91,12 +92,12 @@ namespace refactor::kernel {
                       : UNREACHABLEX(cnnlReduceOp_t, "");
         // clang-format on
         CNNL_ASSERT(cnnlSetReduceDescriptor_v2(
-            d->reduce, (int *) (axes.data()), axes.size(), reduceOp,
+            d->reduce, indices.data(), indices.size(), reduceOp,
            cnnlDataTypeConvert(d->f32 ? DataType::F32 : DataType::F64),
             CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES, 0.0));
 
         auto handler = res.fetchOrStore<CnnlContext>()->handle;
-        size_t idxWorkspaceSize = axes.size() * sizeof(int);
+        size_t idxWorkspaceSize = indices.size() * sizeof(int);
         // idxWorkspaceSize = hardware::alignBytes(idxWorkspaceSize, 256);
         size_t workspaceSize;
         // get workspace
diff --git a/src/04kernel/src/kernels/softmax/cnnl_kernel.cc b/src/04kernel/src/kernels/softmax/cnnl_kernel.cc
index 865e452e..0633195d 100644
--- a/src/04kernel/src/kernels/softmax/cnnl_kernel.cc
+++ b/src/04kernel/src/kernels/softmax/cnnl_kernel.cc
@@ -59,9 +59,11 @@ namespace refactor::kernel {
                                                static_cast<cnnlSoftmaxAlgorithm_t>(algo),
                                                dataType != DataType::F64);
         int dims[]{pre, mid, post};
-        cnnlSoftmaxMode_t mode = (post == 1)  ? CNNL_SOFTMAX_MODE_HIGH_DIMENSION
-                                 : (pre == 1) ? CNNL_SOFTMAX_MODE_LOW_DIMENSION
-                                              : CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;
+        // cnnlSoftmaxMode_t mode = (pre == 1)    ? CNNL_SOFTMAX_MODE_HIGH_DIMENSION
+        //                          : (post == 1) ? CNNL_SOFTMAX_MODE_LOW_DIMENSION
+        //                                        : CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;
+        // FIXME(bolun): CNNL Softmax mode
+        cnnlSoftmaxMode_t mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;
 
         // cnnlSoftmaxForward_v2 is applied to a 3D input tensor only
         CNNL_ASSERT(cnnlSetTensorDescriptor(d->t, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(dataType), 3, dims));
@@ -78,6 +80,7 @@ namespace refactor::kernel {
                                      CNNL_COMPUTATION_ULTRAHIGH_PRECISION,
                                      &a, d->t, inputs[0],
                                      &b, d->t, outputs[0]));
+            res.fetchOrStore<CnnlContext>()->queueSync();
         };
     }
diff --git a/src/04kernel/src/kernels/where/cnnl_kernel.cc b/src/04kernel/src/kernels/where/cnnl_kernel.cc
index 774c5513..50b7c9d8 100644
--- a/src/04kernel/src/kernels/where/cnnl_kernel.cc
+++ b/src/04kernel/src/kernels/where/cnnl_kernel.cc
@@ -16,13 +16,24 @@ namespace refactor::kernel {
 #ifndef USE_BANG
         return nullptr;
 #endif
-        return std::make_unique<K>(decltype(info) {
-            inputs[1].get().dataType,
-            inputs[0].get().shape,
-            inputs[1].get().shape,
-            inputs[2].get().shape,
-            outputs[0].get().shape,
-        });
+        std::vector<int> cDim(inputs[0].get().shape.begin(), inputs[0].get().shape.end()),
+            xDim(inputs[1].get().shape.begin(), inputs[1].get().shape.end()),
+            yDim(inputs[2].get().shape.begin(), inputs[2].get().shape.end()),
+            ansDim(outputs[0].get().shape.begin(), outputs[0].get().shape.end());
+        if (ansDim.size() == 0) {
+            ansDim.push_back(1);
+        }
+        if (xDim.size() == 0) {
+            xDim.push_back(1);
+        }
+        if (yDim.size() == 0) {
+            yDim.push_back(1);
+        }
+        if (cDim.size() == 0) {
+            cDim.push_back(1);
+        }
+        return std::make_unique<K>(decltype(info){
+            inputs[1].get().dataType, cDim, xDim, yDim, ansDim});
     }
     auto K::typeId() noexcept -> size_t {
         static uint8_t ID = 1;
@@ -44,11 +55,10 @@ namespace refactor::kernel {
 
         struct Descriptors {
             cnnlTensorDescriptor_t cond, x, y, ans;
-            bool f32;
 
-            explicit Descriptors(decltype(f32) f32_)
+            explicit Descriptors()
                 : cond(nullptr), x(nullptr), y(nullptr),
-                  ans(nullptr), f32(f32_) {
+                  ans(nullptr) {
                 CNNL_ASSERT(cnnlCreateTensorDescriptor(&cond));
                 CNNL_ASSERT(cnnlCreateTensorDescriptor(&x));
                 CNNL_ASSERT(cnnlCreateTensorDescriptor(&y));
@@ -64,29 +74,20 @@ namespace refactor::kernel {
             Descriptors(const Descriptors &) = delete;
             Descriptors(Descriptors &&) = delete;
         };
-        auto d = std::make_shared<Descriptors>(info.dataType != DT::F64);
-
-        std::vector<int> cDim(info.condDim.begin(), info.condDim.end()),
-            xDim(info.thenDim.begin(), info.thenDim.end()),
-            yDim(info.elseDim.begin(), info.elseDim.end()),
-            ansDim(info.outputDim.begin(), info.outputDim.end());
-
-        auto rightAlign = [](std::vector<int> &dim, uint32_t targetLength) {
-            if (dim.size() < targetLength) {
-                dim.insert(dim.begin(), targetLength - dim.size(), 1);
-            }
-        };
-        if (ansDim.size() == 0) {
-            ansDim.push_back(1);
-        }
-        rightAlign(cDim, ansDim.size());
-        rightAlign(xDim, ansDim.size());
-        rightAlign(yDim, ansDim.size());
-
-        CNNL_ASSERT(cnnlSetTensorDescriptor(d->cond, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(DT::Bool), cDim.size(), cDim.data()));
-        CNNL_ASSERT(cnnlSetTensorDescriptor(d->x, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), xDim.size(), xDim.data()));
-        CNNL_ASSERT(cnnlSetTensorDescriptor(d->y, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), yDim.size(), yDim.data()));
-        CNNL_ASSERT(cnnlSetTensorDescriptor(d->ans, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), ansDim.size(), ansDim.data()));
+        auto d = std::make_shared<Descriptors>();
+
+        CNNL_ASSERT(cnnlSetTensorDescriptor(
+            d->cond, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(DT::Bool),
+            info.condDim.size(), info.condDim.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(
+            d->x, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType),
+            info.thenDim.size(), info.thenDim.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(
+            d->y, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType),
+            info.elseDim.size(), info.elseDim.data()));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(
+            d->ans, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType),
+            info.outputDim.size(), info.outputDim.data()));
 
         auto handle = res.fetchOrStore<CnnlContext>()->handle;
         size_t workspaceSize;
@@ -94,19 +95,14 @@ namespace refactor::kernel {
         res.fetchOrStore<CnnlContext>();
         auto routine = [d = std::move(d), workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
-            // fetch cnnl handle from resources
-            auto handle = res.fetchOrStore<CnnlContext>()->handle;
-
-            auto cond = inputs[0],
-                 x = inputs[1],
-                 y = inputs[2];
-            auto ans = outputs[0];
-
             CNNL_ASSERT(cnnlSelectV2(
-                handle, d->cond, cond, d->x, x,
-                d->y, y, workspace, workspaceSize,
-                d->ans, ans));
+                res.fetchOrStore<CnnlContext>()->handle,
+                d->cond, inputs[0], d->x, inputs[1],
+                d->y, inputs[2], workspace, workspaceSize,
+                d->ans, outputs[0]));
 
-            cnrtQueueSync(res.fetchOrStore<CnnlContext>()->queue);
+            res.fetchOrStore<CnnlContext>()->queueSync();
         };
 
         return {std::move(routine), workspaceSize};
diff --git a/src/04kernel/src/kernels/where/cnnl_kernel.hh b/src/04kernel/src/kernels/where/cnnl_kernel.hh
index 6139b183..ffe39a87 100644
--- a/src/04kernel/src/kernels/where/cnnl_kernel.hh
+++ b/src/04kernel/src/kernels/where/cnnl_kernel.hh
@@ -7,12 +7,10 @@
 namespace refactor::kernel {
 
-    using Shape = absl::InlinedVector<dim_t, 4>;
-
     struct WhereCnnl final : public Kernel {
         struct {
             DataType dataType;
-            Shape condDim, thenDim, elseDim, outputDim;
+            std::vector<int> condDim, thenDim, elseDim, outputDim;
         } info;
 
         WhereCnnl(decltype(info)) noexcept;
diff --git a/src/04kernel/src/utilities/bang/cnnl_context.cc b/src/04kernel/src/utilities/bang/cnnl_context.cc
index 15cc1382..f2ad33ab 100644
--- a/src/04kernel/src/utilities/bang/cnnl_context.cc
+++ b/src/04kernel/src/utilities/bang/cnnl_context.cc
@@ -30,6 +30,15 @@ namespace refactor::kernel::cnnl {
         return "CnnlContext";
     }
 
+    void CnnlContext::copyFromCPU(void *dst, const void *src, size_t size) {
+        BANG_ASSERT(cnrtMemcpy(dst, const_cast<void *>(src), size,
+                               CNRT_MEM_TRANS_DIR_HOST2DEV));
+    }
+
+    void CnnlContext::queueSync() {
+        BANG_ASSERT(cnrtQueueSync(queue));
+    }
+
 }// namespace refactor::kernel::cnnl
 
 #endif
diff --git a/src/04kernel/src/utilities/bang/cnnl_context.hh b/src/04kernel/src/utilities/bang/cnnl_context.hh
index 7db40d3d..4743a0e4 100644
--- a/src/04kernel/src/utilities/bang/cnnl_context.hh
+++ b/src/04kernel/src/utilities/bang/cnnl_context.hh
@@ -22,6 +22,8 @@ namespace refactor::kernel::cnnl {
         size_t resourceTypeId() const noexcept final;
         std::string_view description() const noexcept final;
+        void copyFromCPU(void *dst, const void *src, size_t size);
+        void queueSync();
     };
 
 }// namespace refactor::kernel::cnnl
diff --git a/src/04kernel/src/utilities/bang/cnrt_functions.cc b/src/04kernel/src/utilities/bang/cnrt_functions.cc
new file mode 100644
index 00000000..2ea66194
--- /dev/null
+++ b/src/04kernel/src/utilities/bang/cnrt_functions.cc
@@ -0,0 +1,27 @@
+#ifdef USE_BANG
+#include "cnrt_functions.h"
+#include "cnnl_functions.h"
+#include
+#include
+
+namespace refactor::kernel::cnnl {
+
+    int currentDevice() {
+        int device;
+        BANG_ASSERT(cnrtGetDevice(&device));
+        return device;
+    }
+
+    void sync() {
+        BANG_ASSERT(cnrtSyncDevice());
+    }
+
+    void copyOut(void *dst, const void *src, size_t size) {
+        sync();
+        BANG_ASSERT(cnrtMemcpy(dst, const_cast<void *>(src), size,
+                               CNRT_MEM_TRANS_DIR_DEV2HOST));
+    }
+
+}// namespace refactor::kernel::cnnl
+
+#endif
diff --git a/src/04kernel/src/utilities/bang/cnrt_functions.h b/src/04kernel/src/utilities/bang/cnrt_functions.h
new file mode 100644
index 00000000..ef119819
--- /dev/null
+++ b/src/04kernel/src/utilities/bang/cnrt_functions.h
@@ -0,0 +1,16 @@
+#ifndef KERNEL_CNRT_FUNCTIONS_H
+#define KERNEL_CNRT_FUNCTIONS_H
+
+#include "common.h"
+
+namespace refactor::kernel::cnnl {
+
+    int currentDevice();
+
+    void sync();
+
+    void copyOut(void *dst, const void *src, size_t size);
+
+}// namespace refactor::kernel::cnnl
+
+#endif// KERNEL_CNRT_FUNCTIONS_H
diff --git a/src/04kernel/test/kernels/gather/test_gather_cnnl.cpp b/src/04kernel/test/kernels/gather/test_gather_cnnl.cpp
index a3f739cd..020b5f91 100644
--- a/src/04kernel/test/kernels/gather/test_gather_cnnl.cpp
+++ b/src/04kernel/test/kernels/gather/test_gather_cnnl.cpp
@@ -94,6 +94,51 @@ TEST(kernel, GatherCnnl) {
             EXPECT_FLOAT_EQ(c[i], result[i]);
         }
     }
+
+    // Case axis = 1, indexType= int32
+    {
+        // Create Tensor and build kernels
+        auto data = Tensor::share(DataType::F32, Shape{32, 16}, LayoutType::NCHW);
+        auto indices = Tensor::share(DataType::I64, Shape{1, 4}, LayoutType::NCHW);
+        auto output = Tensor::share(DataType::F32, Shape{1, 4, 16}, LayoutType::NCHW);
+        GatherInfo info(0, *data, *indices);
+        auto cnnlKernel = GatherCnnl::build(0, *data, *indices, *output);
+        auto cpuKernel = GatherCpu::build(info);
+        ASSERT_TRUE(cnnlKernel && cpuKernel);
+        auto res = runtime::Resources();
+        auto [cnnlRoutine, workspaceSize] = cnnlKernel->lower(res);
+        auto cpuRoutine = cpuKernel->lower(res).routine;
+        // Init inputs and outputs
+        std::vector<float> a;
+        for (auto i = 0; i < data->elementsSize(); i++) {
+            a.push_back(i + 0.1f);
+        }
+        std::vector<int64_t> b(indices->elementsSize(), 0);
+        std::vector<float> c(output->elementsSize());
+        auto workspace = dev.malloc(workspaceSize),
+             aMLU = dev.malloc(data->bytesSize()),
+             bMLU = dev.malloc(indices->bytesSize()),
+             cMLU = dev.malloc(output->bytesSize());
+        aMLU->copyFromHost(a.data(), data->bytesSize());
+        bMLU->copyFromHost(b.data(), indices->bytesSize());
+        // Compute
+        {
+            void const *inputs[]{*aMLU, *bMLU};
+            void *outputs[]{*cMLU};
+            cnnlRoutine(res, *workspace, inputs, outputs);
+        }
+        {
+            void const *inputs[]{a.data(), b.data()};
+            void *outputs[]{c.data()};
+            cpuRoutine(res, nullptr, inputs, outputs);
+        }
+        // Compare
+        std::vector<float> result(output->elementsSize());
+        cMLU->copyToHost(result.data(), output->bytesSize());
+        for (auto i : range0_(c.size())) {
+            EXPECT_FLOAT_EQ(c[i], result[i]);
+        }
+    }
 }
 
 #endif
diff --git a/src/09python_ffi/CMakeLists.txt b/src/09python_ffi/CMakeLists.txt
index ccce34d3..09567c9d 100644
--- a/src/09python_ffi/CMakeLists.txt
+++ b/src/09python_ffi/CMakeLists.txt
@@ -10,6 +10,10 @@ pybind11_add_module(python_ffi SHARED ${PYFFI_SRC})
 target_link_libraries(python_ffi PRIVATE onnx llm communication)
 target_include_directories(python_ffi PRIVATE include)
 
+if(USE_BANG)
+    target_include_directories(python_ffi PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../04kernel/src/utilities/bang)
+endif()
+
 # EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a
 # define (VERSION_INFO) here.
 # target_compile_definitions(python_ffi
diff --git a/src/09python_ffi/src/executor.cc b/src/09python_ffi/src/executor.cc
index c6a20cb9..947410cc 100644
--- a/src/09python_ffi/src/executor.cc
+++ b/src/09python_ffi/src/executor.cc
@@ -7,6 +7,10 @@
 #include "kernel/cuda/functions.cuh"
 #endif// USE_CUDA
 
+#ifdef USE_BANG
+#include "cnrt_functions.h"
+#endif// USE_BANG
+
 namespace refactor::python_ffi {
 
     Executor::Executor(computation::Graph graph, runtime::Stream stream)
@@ -70,9 +74,13 @@ namespace refactor::python_ffi {
     void Executor::bench(bool sync) {
 #ifdef USE_CUDA
         auto ans = _stream.bench(sync ? kernel::cuda::sync : nullptr);
+#else
+    #ifdef USE_BANG
+        auto ans = _stream.bench(sync ? kernel::cnnl::sync : nullptr);
 #else
         auto ans = _stream.bench(nullptr);
-#endif// USE_CUDA
+    #endif
+#endif
         auto const &nodes = _graph.internal().contiguous().nodes;
         for (auto i : range0_(nodes.size())) {
             fmt::println("{} {} {}",
@@ -213,6 +221,9 @@ namespace refactor::python_ffi {
 #ifdef USE_CUDA
                     kernel::cuda::copyOut(buffer.data(), addresses[idx], size);
 #endif
+#ifdef USE_BANG
+                    kernel::cnnl::copyOut(buffer.data(), addresses[idx], size);
+#endif
 
                     auto file = path / fmt::format("data{:06}.{}", dataIdx++, format);
                     fs::remove(file);