From 6ea15a051130cebfa882381860ba0ab24e0c1c74 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Wed, 31 Jan 2024 14:33:00 +0800 Subject: [PATCH] omp and loop fusion --- CMakeLists.txt | 8 ++++++++ .../kernels/mat_mul_common/cpu_template.hpp | 18 +++++++++--------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5521ed55..5561321e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,6 +10,13 @@ set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_POSITION_INDEPENDENT_CODE ON) +find_package(OpenMP) +if(OpenMP_C_FOUND) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") +endif() +if(OpenMP_CXX_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") +endif() if(USE_CUDA) add_compile_definitions(USE_CUDA) @@ -45,6 +52,7 @@ add_compile_options(-march=native) # this will cause error in some machine add_compile_options(-mtune=native) add_compile_options(-Wall) + add_subdirectory(3rd-party/backward-cpp) include_directories(3rd-party/fmt/include) diff --git a/src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp b/src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp index 766279b3..cac35ca8 100644 --- a/src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp +++ b/src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp @@ -16,16 +16,16 @@ namespace refactor::kernel { * Assume bias C has been broadcast to Y already. Beta should be 0 in the absence of bias. 
*/ void matrixMultiply(TI const *a, TI const *b, TO *y) const noexcept { - // #pragma omp parallel for - for (size_t i = 0; i < M; i++) { - for (size_t j = 0; j < N; j++) { - TO sum = 0; - // #pragma omp simd reduction(+ : sum) - for (size_t k = 0; k < K; k++) { - sum += static_cast<TO>(a[i * strideA0 + k * strideA1] * b[k * strideB0 + j * strideB1]); - } - y[i * N + j] = beta * y[i * N + j] + alpha * sum; +#pragma omp parallel for simd + for (size_t ind = 0; ind < M * N; ind++) { + size_t i = ind / N; + size_t j = ind % N; + TO sum = 0; + + for (size_t k = 0; k < K; k++) { + sum += static_cast<TO>(a[i * strideA0 + k * strideA1] * b[k * strideB0 + j * strideB1]); } + y[ind] = beta * y[ind] + alpha * sum; } } };