Commit f103998
skip whatever does not work
xadupre committed Sep 16, 2024
1 parent f3db8d0 commit f103998
Showing 2 changed files with 28 additions and 12 deletions.
4 changes: 2 additions & 2 deletions onnx_extended/ortops/tutorial/cuda/custom_gemm.cu
@@ -177,14 +177,14 @@ void CustomGemmKernel::SetParams(const std::vector<int64_t> &shape_A,
 }
 }
 
-void check_device(const Ort::ConstValue &input, const char *name) {
+static void check_device(const Ort::ConstValue &input, const char *name) {
   EXT_ENFORCE(input.HasValue(), "Input '", name, "' is not empty.");
   auto mem = input.GetTensorMemoryInfo();
   EXT_ENFORCE(mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
               "Input '", name, "' is not on CUDA");
 }
 
-void check_device(const Ort::UnownedValue &output, const char *name) {
+static void check_device(const Ort::UnownedValue &output, const char *name) {
   auto mem = output.GetTensorMemoryInfo();
   EXT_ENFORCE(mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
               "Output '", name, "' is not on CUDA");
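Note: the only change in this file marks both check_device overloads static. The same helpers are defined with identical signatures in matx_matmul.cu below, and two non-static definitions of the same function in one shared library collide at link time, so internal linkage is presumably the point of this part of the commit. A minimal sketch of the distinction (file names hypothetical):

    // a.cu
    static void helper() { /* ... */ } // internal linkage: visible only inside a.cu

    // b.cu
    static void helper() { /* ... */ } // a distinct function; no clash with a.cu

    // Without `static`, a.cu and b.cu would both export the symbol `helper`,
    // and linking the two objects into one library would fail with a
    // multiple-definition error.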
36 changes: 26 additions & 10 deletions onnx_extended/ortops/tutorial/cuda/matx_matmul.cu
@@ -1,10 +1,14 @@
#include "cuda/common_kernels_cuda.h"
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
#include "matx.h"
#endif
#include "matx_matmul.h"
#include <cublasLt.h>
#include <cublas_v2.h>

#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
using namespace matx;
#endif

namespace ortops {

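Note: CUDA_VERSION, defined by cuda.h, encodes the toolkit version as major * 1000 + minor * 10, so the guard requires CUDA 12.8 or newer and skips MatX entirely on older toolkits, in line with the commit message. A small sketch of the same gating pattern (the HAS_MATX macro is illustrative and not part of this diff):

    #include <cuda.h> // defines CUDA_VERSION, e.g. 12080 for CUDA 12.8

    #if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
    #include "matx.h" // optional dependency, only compiled with a recent toolkit
    #define HAS_MATX 1
    #else
    #define HAS_MATX 0 // callers can fall back to a cuBLAS path, for instance
    #endif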
@@ -46,14 +50,14 @@ MatXMatMulOp::GetOutputCharacteristic(std::size_t index) const {

 MatXMatMulKernel::MatXMatMulKernel(const OrtApi &api, const OrtKernelInfo *info) {}
 
-void check_device(const Ort::ConstValue &input, const char *name) {
+static void check_device(const Ort::ConstValue &input, const char *name) {
   EXT_ENFORCE(input.HasValue(), "Input '", name, "' is not empty.");
   auto mem = input.GetTensorMemoryInfo();
   EXT_ENFORCE(mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
               "Input '", name, "' is not on CUDA");
 }
 
-void check_device(const Ort::UnownedValue &output, const char *name) {
+static void check_device(const Ort::UnownedValue &output, const char *name) {
   auto mem = output.GetTensorMemoryInfo();
   EXT_ENFORCE(mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
               "Output '", name, "' is not on CUDA");
@@ -74,11 +78,24 @@ ONNXTensorElementDataType GetTypeAndShape(const TValue &input, std::vector<int64
 template <typename T>
 void ComputeMatMul(const std::vector<int64_t> &shape_A, const T *ptr_A,
                    const std::vector<int64_t> &shape_B, const T *ptr_B,
-                   const std::vector<int64_t> &shape_D, const T *ptr_D, cudaExecutor &exec) {
-  auto matx_ta = make_tensor(ptr_A, shape_A);
-  auto matx_tb = make_tensor(ptr_B, shape_B);
-  // auto matx_td = make_tensor(ptr_D, shape_D);
-  auto matx_td = matmul(matx_ta, matx_tb).run(exec);
+                   const std::vector<int64_t> &shape_D, const T *ptr_D, cudaStream_t &stream) {
+  // MatX only supports tensors with a known rank.
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
+  if (shape_A.size() == 2 && shape_B.size() == 2) {
+    auto matx_ta = make_tensor(ptr_A, {shape_A[0], shape_A[1]});
+    auto matx_tb = make_tensor(ptr_B, {shape_B[0], shape_B[1]});
+    auto matx_td = make_tensor<T>({shape_D[0], shape_D[1]});
+    (matx_td = matmul(matx_ta, matx_tb)).run(stream);
+    CUDA_THROW_IF_ERROR(cudaMemcpyAsync((void *)ptr_D, (void *)matx_td.data(),
+                                        sizeof(T) * shape_D[0] * shape_D[1],
+                                        cudaMemcpyDeviceToDevice));
+  } else {
+    EXT_THROW("ComputeMatMul not implemented when ranks are ", shape_A.size(), " and ",
+              shape_B.size(), ".");
+  }
+#else
+  EXT_THROW("ComputeMatMul not implemented with CUDA_VERSION=", CUDA_VERSION, ".");
+#endif
 }
 
 void MatXMatMulKernel::Compute(OrtKernelContext *context) {
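Note: the new body wraps the input pointers in non-owning rank-2 MatX views, computes the product into a temporary tensor on the kernel's stream, then copies it into the ORT output with cudaMemcpyAsync (issued on the default stream, since no stream argument is passed). The commented-out line removed above hints at a variant that writes into the output directly and skips the copy; a sketch of that alternative, assuming a mutable output pointer as at the call site (the helper name is hypothetical):

    template <typename T>
    static void ComputeMatMul2D(const T *ptr_A, const std::vector<int64_t> &shape_A,
                                const T *ptr_B, const std::vector<int64_t> &shape_B,
                                T *ptr_D, const std::vector<int64_t> &shape_D,
                                cudaStream_t stream) {
      // Non-owning views over existing device buffers, as in the diff above.
      auto ta = matx::make_tensor(ptr_A, {shape_A[0], shape_A[1]});
      auto tb = matx::make_tensor(ptr_B, {shape_B[0], shape_B[1]});
      // View over the ORT-allocated output: matmul writes into it in place,
      // removing both the temporary tensor and the device-to-device copy.
      auto td = matx::make_tensor(ptr_D, {shape_D[0], shape_D[1]});
      (td = matx::matmul(ta, tb)).run(stream);
    }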
@@ -106,14 +123,13 @@ void MatXMatMulKernel::Compute(OrtKernelContext *context) {
     shape_D[i] = shape_A[i];
   shape_D[shape_D.size() - 1] = shape_B[shape_B.size() - 1];
   Ort::UnownedValue output = ctx.GetOutput(0, shape_D);
-
-  cudaExecutor exec{stream};
   check_device(output, "Y");
 
   switch (dtype_A) {
   case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
     ComputeMatMul(shape_A, input_A.GetTensorData<float>(), shape_B,
                   input_B.GetTensorData<float>(), shape_D, output.GetTensorMutableData<float>(),
-                  exec);
+                  stream);
     break;
   default:
     EXT_THROW("Not implemented for type: ", (uint64_t)dtype_A, ".");
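Note: Compute only handles ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; every other element type hits the EXT_THROW default. Since ComputeMatMul is a template, supporting another type is one more case in the switch above, sketched here for double (assuming the op's declared type constraints and MatX's matmul both allow it):

    case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
      ComputeMatMul(shape_A, input_A.GetTensorData<double>(), shape_B,
                    input_B.GetTensorData<double>(), shape_D,
                    output.GetTensorMutableData<double>(), stream);
      break;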
