Skip to content

Commit

Permalink
omp and loop fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
xgqdut2016 authored and bitzyz committed Feb 5, 2024
1 parent d076c20 commit 8db2499
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ set(CMAKE_CXX_STANDARD 20)
# Treat the requested C++ standard as a hard requirement rather than a preference.
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Disable compiler-specific language extensions (e.g. -std=c++20 instead of -std=gnu++20).
set(CMAKE_CXX_EXTENSIONS OFF)
# Default every target created after this point to position-independent code.
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
# OpenMP is optional: the lookup is not REQUIRED, so the build proceeds
# without it when the toolchain has no OpenMP support.
find_package(OpenMP)

# For each language whose OpenMP support was detected, append that language's
# OpenMP flags to its global compiler flag string.
if(OpenMP_C_FOUND)
  string(APPEND CMAKE_C_FLAGS " ${OpenMP_C_FLAGS}")
endif()
if(OpenMP_CXX_FOUND)
  string(APPEND CMAKE_CXX_FLAGS " ${OpenMP_CXX_FLAGS}")
endif()

if(USE_CUDA)
add_compile_definitions(USE_CUDA)
Expand Down Expand Up @@ -78,6 +85,7 @@ add_compile_options(-march=native) # this may cause errors on some machines
# Directory-wide compiler options: tune for the build host's CPU and enable
# the common warning set.
add_compile_options(
  -mtune=native
  -Wall
)


# Configure the vendored backward-cpp project from its own CMakeLists.txt.
# NOTE(review): presumably used for stack-trace printing — confirm against its consumers.
add_subdirectory(3rd-party/backward-cpp)

# Add the vendored fmt headers to the include path of this directory, so
# targets defined here and in subdirectories added after this point see them.
include_directories(3rd-party/fmt/include)
Expand Down
18 changes: 9 additions & 9 deletions src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ namespace refactor::kernel {
* Assume bias C has been broadcast to Y already. Beta should be 0 in the absence of bias.
*/
void matrixMultiply(TI const *a, TI const *b, TO *y) const noexcept {
// #pragma omp parallel for
for (size_t i = 0; i < M; i++) {
for (size_t j = 0; j < N; j++) {
TO sum = 0;
// #pragma omp simd reduction(+ : sum)
for (size_t k = 0; k < K; k++) {
sum += static_cast<TO>(a[i * strideA0 + k * strideA1] * b[k * strideB0 + j * strideB1]);
}
y[i * N + j] = beta * y[i * N + j] + alpha * sum;
// Compute Y = beta * Y + alpha * (A x B) over a single flattened index so the
// whole M x N output is distributed across OpenMP threads / SIMD lanes.
// Each iteration writes exactly one distinct element y[ind], so iterations
// are independent of each other (assuming y does not alias a or b).
#pragma omp parallel for simd
for (size_t ind = 0; ind < M * N; ind++) {
    // Recover the (row, column) position in Y from the flattened index.
    size_t i = ind / N;
    size_t j = ind % N;
    TO sum = 0;

    // Dot product of row i of A and column j of B, honouring both operands' strides.
    for (size_t k = 0; k < K; k++) {
        sum += static_cast<TO>(a[i * strideA0 + k * strideA1] * b[k * strideB0 + j * strideB1]);
    }
    // Follow the BLAS GEMM convention: when beta is zero, Y is write-only.
    // Evaluating `0 * y[ind]` would turn a NaN/Inf left in an uninitialised
    // output buffer into a NaN result, so the old value is not read at all.
    y[ind] = beta == 0 ? alpha * sum : beta * y[ind] + alpha * sum;
}
}
};
Expand Down

0 comments on commit 8db2499

Please sign in to comment.