diff --git a/docs/12_convolution/01_naive_conv/README.md b/docs/12_convolution/01_naive_conv/README.md index a935c2c..98d193a 100644 --- a/docs/12_convolution/01_naive_conv/README.md +++ b/docs/12_convolution/01_naive_conv/README.md @@ -54,7 +54,7 @@ int w; // 数据宽 int k; // 卷积核数量 int r; // 卷积核高 int s; // 卷积核宽 -int u; // 卷积在高方向上的步长 +int u; // 卷积在高方向上的步长 int v; // 卷积在宽方向上的步长 int p; // 卷积在高方向上的补边 int q; // 卷积在宽方向上的补边 diff --git a/docs/12_convolution/02_intro_conv_optimize/README.md b/docs/12_convolution/02_intro_conv_optimize/README.md index d51242b..15ebee7 100644 --- a/docs/12_convolution/02_intro_conv_optimize/README.md +++ b/docs/12_convolution/02_intro_conv_optimize/README.md @@ -2,7 +2,7 @@ 上一篇文章中,我们介绍了卷积算子的简易实现,它是直接模拟卷积操作的过程,这种实现方式的缺点是计算量大,效率低。在本文中,我们将介绍卷积算子的优化思路。 -卷积算子的主要优化思路就是将卷积运算转换为矩阵乘法运算。进而卷积算子优化问题就转化为了矩阵乘法优化问题。这篇文章中我们主要介绍一下如何将卷积运算转换为矩阵乘法运算。 +卷积算子的主要优化思路就是将卷积运算转换为矩阵乘法运算。进而卷积算子优化问题就转化为了矩阵乘法优化问题卷积算子的主要优化思路就是将卷积运算转换为矩阵乘法运算。进而卷积算子优化问题就转化为了矩阵乘法优化问题。。这篇文章中我们主要介绍一下如何将卷积运算转换为矩阵乘法运算。 ## 1. 卷积算法映射为矩阵乘法 diff --git a/docs/12_convolution/03_im2col_conv/README.md b/docs/12_convolution/03_im2col_conv/README.md index 9ae5e5a..a60c8b7 100644 --- a/docs/12_convolution/03_im2col_conv/README.md +++ b/docs/12_convolution/03_im2col_conv/README.md @@ -1 +1,394 @@ -# im2col + gemm 实现卷积 \ No newline at end of file +# im2col + gemm 实现卷积 + +本文让我们使用 im2col 和 gemm 来实现卷积操作。 + +## 1. im2col 算子实现 + +首先让我们来实现 im2col 算子,这个算子的作用是将输入的图像转换为矩阵,这样我们就可以使用矩阵乘法来实现卷积操作。这个算子本身并没有太多需要能够优化的地方,我们需要做的就是按照上一篇文章中给出的 im2col 的定义来实现这个算子。 + +先让我们简单回顾一下 im2col 操作,假设我们有一个大小为 [$B$, $C_{in}$, $H$, $W$] 的输入张量,其中 $B$ 是批大小,$C_{in}$ 是通道数,$H$ 和 $W$ 是图像的高度和宽度,我们需要将这个输入张量转换为一个大小为 [$B$, $C$, $K$, $K$, $H'$, $W'$] 的输出张量,其中 $C$ 是通道数,$K$ 是卷积核的大小,$H'$ 和 $W'$ 是输出图像的高度和宽度。下图是 im2col 的示意图: + +![im2col](./images/im2col.jpg) + +简单来说就是将输入图像中每个卷积核对应的位置的像素值放到一个矩阵中。输入图像一个通道的大小是 `H * W`, im2col 对应的输出矩阵的大小是 `KH * KW * H' * W'`。我们这个 Kernel 的思路就是让每个线程负责一次将卷积核对应的像素值放到输出矩阵中,也就是输出一条矩阵的这个过程。在了解了这个过程之后,我们就可以开始实现这个算子了。 + +首先我们先来定义这个算子的接口,该算子的输入是一个大小为 [$B$, $C_{in}$, $H$, $W$] 的输入张量,输出是一个大小为 [$B$, $H'$, $W'$, $C_{in}$, $KH$, $KW$] 的输出张量。我们可以使用下面的代码来定义这个算子的接口: + +:::tip + +注意,这里为什么输出的 shape 不是 [$B$, $C_{in}$, $KH$, $KW$, $H'$, $W'$] 而是 [$B$, $H'$, $W'$, $C_{in}$, $KH$, $KW$] ?我们可以保留这个疑问,后面我们会解释这个问题。 + +::: + +```cpp +template +__global__ void im2col_kernel(const int n, + T *data_x, + T *data_y, + const int batches, // 批大小 + const int inner_size_x, // 每个样本(或批次)中单个通道的输入数据的大小 + const int inner_size_y, // 每个样本(或批次)中单个通道的输出数据的大小 + const int x_height, // 输入图像的高度 + const int x_width, // 输入图像的宽度 + const int kernel_height, // 卷积核的高度 + const int kernel_width, // 卷积核的宽度 + const int pad_height, // 填充的高度 + const int pad_width, // 填充的宽度 + const int stride_height, // 步长的高度 + const int stride_width, // 步长的宽度 + const int dilation_height, // 膨胀的高度 + const int dilation_width, // 膨胀的宽度 + const int y_height, // 输出图像的高度 + const int y_width, // 输出图像的宽度 + const int inner_size_c // 每个批次中通道数乘以输出张量的大小 + ); +``` + +接下来我们来实现这个算子的代码,首先为了方便后面的计算,我们先定义一些变量: + +```cpp +int batch = index / inner_size_c, idx = index % inner_size_c; +int w_out = idx % y_width, id = idx / y_width; +int h_out = id % y_height, channel_in = id / y_height; +int channel_out = channel_in * kernel_height * kernel_width; +``` + +接下来计算输入图像中对应输出位置的起始位置,并设置输入和输出数据指针。 + +```cpp +for (int i = 0; i < kernel_height; ++i) +{ + for (int j = 0; j < kernel_width; ++j) + { + int h = h_in + i * dilation_height; // 计算输入图像的高度 + int w = w_in + j * dilation_width; // 计算输入图像的宽度 + // 计算输入图像的索引, 如果索引超出了输入图像的大小,则设置为 0 + *out = (h >= 0 && w >= 0 && h < x_height && w < x_width) + ? in[i * dilation_height * x_width + j * dilation_width] + : static_cast(0); + // 更新输入和输出数据指针 + out += y_height * y_width * batches; + } +} +``` + +为了更加方便的使用 im2col_kernel 函数,我们还需要定义一个包装函数: + +```cpp +template +void cuda_im2col(const int batches, // 批大小 + const int x_channel, // 输入图像的通道数 + const int x_height, // 输入图像的高度 + const int x_width, // 输入图像的宽度 + const int y_out_plane, // 输出图像的通道数 + const int y_height, // 输出图像的高度 + const int y_width, // 输出图像的宽度 + const int kernel_height, // 卷积核的高度 + const int kernel_width, // 卷积核的宽度 + const int stride_height, // 步长的高度 + const int stride_width, // 步长的宽度 + const int dilation_height, // 膨胀的高度 + const int dilation_width, // 膨胀的宽度 + const int pad_height, // 填充的高度 + const int pad_width, // 填充的宽度 + T *x, // 输入数据 + T *y // 输出数据) +{ + // 计算输入和输出数据的大小 + const int inner_size_y = y_out_plane * y_height * y_width; + const int inner_size_x = x_channel * x_height * x_width; + const int inner_size_c = x_channel * y_height * y_width; + // 计算总的卷积核的数量 + const int num_kernels = batches * inner_size_c; + + // 设置线程块的大小和数量 + const int blockSize = std::max(std::min(MaxBlockSize, num_kernels), static_cast(1)); + // 计算线程块的数量 + const int gridSize = (num_kernels + blockSize - 1) / blockSize; + + // 调用 im2col_kernel 函数 + im2col_kernel<<>>(num_kernels, + x, + y, + batches, + inner_size_x, + inner_size_y, + x_height, + x_width, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + y_height, + y_width, + inner_size_c); +} +``` + + +## 2. gemm 算子实现 + +接下来我们来实现 gemm 算子,这个算子的作用是将两个矩阵相乘,我们这里使用之前文章中介绍的[二维 Thread Tile 并行优化](https://cuda.keter.top/gemm_optimize/tiled2d/) 矩阵乘法优化方法来实现这个算子。 + +在用这个算子之前,我们需要先考虑一个问题:矩阵乘的输入矩阵和结果矩阵的大小是多少? 输入 im2col 的矩阵大小是 [$B$, $C_{in}$, $H$, $W$] ,输出 im2col 的矩阵大小是 [$B$, $H'$, $W'$, $C_{in}$, $KH$, $KW$],卷积核的大小是 [$C_{out}$, $C_in$, $KH$, $KW$]。 + +:::tip + +这里也就解释了为什么我们的 im2col 的输出矩阵的大小是 [$B$, $H'$, $W'$, $C_{in}$, $KH$, $KW$] 而不是 [$B$, $C_{in}$, $KH$, $KW$, $H'$, $W'$],因为我们的卷积核是 [$C_{out}$, $C_in$, $KH$, $KW$],如果我们的输出矩阵的大小是 [$B$, $C_{in}$, $KH$, $KW$, $H'$, $W'$],那么我们在计算卷积的时候就需要将卷积核转置,这样会增加计算的复杂度,所以我们选择了 [$B$, $H'$, $W'$, $C_{in}$, $KH$, $KW$] 这种方式。 + +::: + +我们可以发现卷积核和 im2col 输出的矩阵相乘后的大小是 [$C_{out}$, $B$, $H'$, $W'$],这个大小和我们平时使用的卷积操作的输出大小不太一样,所以我们需要对这个结果进行转置,转置后的大小就是 [$B$, $C_{out}$, $H'$, $W'$],这个大小和我们平时使用的卷积操作的输出大小是一样的。 + +已经忘了矩阵乘法的实现方法的同学可以参考之前的文章 [二维 Thread Tile 并行优化](https://cuda.keter.top/gemm_optimize/tiled2d/),它的主要优化思路是将输入矩阵和输出矩阵分块,然后使用线程块中的线程来计算这些块的结果,这样可以减少全局内存的访问次数,提高计算效率。 + +但是这个 Kernel 并不能直接用到我们的卷积操作中,我们需要做如下的修改: + +1. 添加判断条件,防止越界访问; +2. 将输出矩阵在保存的时候进行转置。 + +下面我们一起来实现这个代码,首先我们需要定义一些常量方便后续使用: + +```cpp + +// 我们在这个线程块中要计算的输出块 +const uint c_row = blockIdx.y; +const uint c_col = blockIdx.x; + +// 一个线程块中的线程负责计算 TM*TN 个元素 +const uint num_threads_block_tile = (BM * BN) / (TM * TN); + +// 计算线程的位置 +const uint thread_row = threadIdx.x / (BN / TN); +const uint thread_col = threadIdx.x % (BN / TN); + +// 用于避免越界访问 +int global_c_index = c_row * BM * N + c_col * BN; // 计算C矩阵的全局索引位置 +int global_m_pos = c_row * BM * K; // A矩阵全局位置 +int global_n_pos = c_col * BN; // B矩阵全局位置 +const uint m_size = M * K; // A矩阵的大小 +const uint n_size = N * K; // B矩阵的大小 + +assert((BM * BN) / (TM * TN) == blockDim.x); + +// 计算输出矩阵的位置 +const uint A_inner_row = threadIdx.x / BK; // A矩阵内部行索引 +const uint A_inner_col = threadIdx.x % BK; // A矩阵内部列索引 +const uint stride_a = num_threads_block_tile / BK; // A矩阵的跨步 +const uint B_inner_row = threadIdx.x / BN; // B矩阵内部行索引 +const uint B_inner_col = threadIdx.x % BN; // B矩阵内部列索引 +const uint stride_b = num_threads_block_tile / BN; // B矩阵的跨步 +``` + +这些代码除了添加了用于避免越界访问的代码外,其他的代码和之前的实现是一样的。这里不再多做解释。 + +接下来我们需要定义一些共享内存,线程块的结果和寄存器变量: + +```cpp +// 申请共享内存 +__shared__ float A_shared[BM * BK]; +__shared__ float B_shared[BK * BN]; + +// 用于保存线程块的结果 +float thread_results[TM * TN] = {0.0}; +float reg_m[TM] = {0.0}; +float reg_n[TN] = {0.0}; + +// 外层循环 +for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) +{ + ... // 每个线程的具体逻辑 +} +``` + +然后我们需要在内核的外层循环中,将矩阵 A 和矩阵 B 的数据加载到共享内存中: + +```cpp +// 加载数据到共享内存 +for (uint load_offset = 0; load_offset < BM; load_offset += stride_a) +{ + A_shared[(A_inner_row + load_offset) * BK + A_inner_col] = + (global_m_pos + (A_inner_row + load_offset) * K + A_inner_col < m_size) ? A[(A_inner_row + load_offset) * K + A_inner_col] : 0.0f; +} +for (uint load_offset = 0; load_offset < BK; load_offset += stride_b) +{ + B_shared[(B_inner_row + load_offset) * BN + B_inner_col] = + (global_n_pos + (B_inner_row + load_offset) * N + B_inner_col < n_size) ? B[(B_inner_row + load_offset) * N + B_inner_col] : 0.0f; +} + +__syncthreads(); + +// 移动数据指针 +A += BK; +B += BK * N; +global_m_pos += BK; +global_n_pos += BK * N; +``` + +请注意,这里我们在加载数据的时候,我们需要判断是否越界,如果越界则设置为 0。因为卷积运算经常会遇到越界的情况,所以我们需要在这里进行处理。以前我们主要是学习矩阵乘法,测试的例子都是 512 * 512 或者 1024 * 1024 这种大小的矩阵,这种情况下越界的情况比较少,所以我们没有处理这种情况。 + +下一步我们需要计算矩阵乘法的结果, 我们需要计算 BM * BN 个结果, 并将这一步的结果累加到 thread_results 中。 + +```cpp +// 计算矩阵乘法的结果 +for (uint dot_idx = 0; dot_idx < BK; dot_idx++) +{ + // 将数据加载到寄存器中 + for (uint i = 0; i < TM; i++) + { + reg_m[i] = A_shared[(thread_row * TM + i) * BK + dot_idx]; + } + for (uint i = 0; i < TN; i++) + { + reg_n[i] = B_shared[dot_idx * BN + thread_col * TN + i]; + } + + // 计算矩阵乘法的结果 + for (uint res_idx_m = 0; res_idx_m < TM; res_idx_m++) + { + for (uint res_idx_n = 0; res_idx_n < TN; res_idx_n++) + { + thread_results[res_idx_m * TN + res_idx_n] += reg_m[res_idx_m] * reg_n[res_idx_n]; + } + } +} + +// 移动数据指针 +__syncthreads(); +``` + +最后我们需要将 thread_results 中的结果保存到全局内存中,在保存的时候我们需要将结果转置,这里的转置坐标计算看起来难以理解,建议读者画图来理解这个过程(因为我当时就是对着图写出来的,过几天让我自己去看我不对着图也看不懂(x))。 + +```cpp +int inner_y_size = y_height * y_width; // 计算Y矩阵的内部尺寸 +int res_inner_index, g_index, batch_id, channel_id, inner_offset; + +int conv_idx; + +if (global_c_index >= M * N) // 如果全局索引超出范围,直接返回 +{ + return; +} + +for (uint res_idx_m = 0; res_idx_m < TM; res_idx_m++) +{ + for (uint res_idx_n = 0; res_idx_n < TN; res_idx_n++) + { + if (c_row * BM + thread_row * TM + res_idx_m < M && c_col * BN + thread_col * TN + res_idx_n < N) + { + // 计算结果在C矩阵中的内部索引 + res_inner_index = (thread_row * TM + res_idx_m) * N + thread_col * TN + res_idx_n; + // 计算全局索引 + g_index = global_c_index + res_inner_index; + // 计算内部偏移 + inner_offset = g_index % inner_y_size; + // 计算 batch ID + batch_id = (g_index % (inner_y_size * batch)) / inner_y_size; + // 计算 channel ID + channel_id = g_index / (inner_y_size * batch); + // 根据batch ID、channel ID和内部偏移计算卷积索引 + conv_idx = batch_id * (kernel_num * y_height * y_width) + channel_id * (y_height * y_width) + inner_offset; + // 将计算结果写入C矩阵 + C[conv_idx] = thread_results[res_idx_m * TN + res_idx_n]; + } + } +} +``` + +同样为了方便使用,我们还需要定义一个包装函数: + +```cpp +void cuda_gemm(float *A, + float *B, + float *C, + int M, + int N, + int K, + int batch, + int kernel_num, + int y_height, + int y_width) +{ + const uint BK = 8; + const uint TM = 8; + const uint TN = 8; + if (M >= 128 && N >= 128) + { + const uint BM = 128; + const uint BN = 128; + dim3 grid_size(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 block_size((BM * BN) / (TM * TN)); + sgemm_blocktiling_2d_kernel + <<>>(A, B, C, M, N, K, batch, kernel_num, y_height, y_width); + } + else + { + const uint BM = 64; + const uint BN = 64; + dim3 grid_size(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 block_size((BM * BN) / (TM * TN)); + sgemm_blocktiling_2d_kernel + <<>>(A, B, C, M, N, K, batch, kernel_num, y_height, y_width); + } +} +``` + +这样我们就实现了 gemm 算子的代码,接下来我们就可以使用这两个算子来实现卷积操作了。 + +## 3. 卷积操作实现 + +首先我们需要调用 im2col 算子将输入图像转换为矩阵,然后调用 gemm 算子将转换后的矩阵和卷积核进行矩阵乘法,最后将结果保存到输出张量中。 + +```cpp +cuda_im2col(batch_size, + x_channel, + x_height, + x_width, + y_out_plane, + y_height, + y_width, + kernel_height, + kernel_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + pad_height, + pad_width, + pIn_device, + pInCol_device); +KernelErrChk(); + +cudaDeviceSynchronize(); + +cuda_gemm(pWeight_device, + pInCol_device, + pOut_device, + kernel_numbers, + batch_size * y_height * y_width, + x_channel * kernel_height * kernel_width, + batch_size, + kernel_numbers, + y_height, y_width); +``` + +## 4. 编译和运行 + +这个代码我们提供了 Makefile 文件,可以直接使用 make 命令来编译代码,编译完成后运行 `bash job.sh` 可以自动输入样例数据并运行代码。 + +## 5. 总结 + +本文我们实现了 im2col 和 gemm 算子,然后使用这两个算子来实现卷积操作。下一篇文章我们会介绍如何把 im2col 给优化掉,让我们的卷积操作更加高效。 + +## References + +1. https://siboehm.com/articles/22/CUDA-MMM +2. https://space.keter.top/docs/high_performance/GEMM%E4%BC%98%E5%8C%96%E4%B8%93%E9%A2%98/%E5%85%B1%E4%BA%AB%E5%86%85%E5%AD%98%E7%BC%93%E5%AD%98%E5%9D%97 +3. https://space.keter.top/docs/high_performance/GEMM%E4%BC%98%E5%8C%96%E4%B8%93%E9%A2%98/%E4%B8%80%E7%BB%B4Thread%20Tile%E5%B9%B6%E8%A1%8C%E4%BC%98%E5%8C%96 +4. https://github.com/AndSonder/UNIVERSAL_SGEMM_CUDA + + diff --git a/docs/12_convolution/03_im2col_conv/codes/Makefile b/docs/12_convolution/03_im2col_conv/codes/Makefile new file mode 100644 index 0000000..32eaf6d --- /dev/null +++ b/docs/12_convolution/03_im2col_conv/codes/Makefile @@ -0,0 +1,27 @@ +CC=nvcc + +CXXFLAGS += -DNDEBUG -DUSE_DEFAULT_STDLIB -g + +INCLUDES += -I./include + +LDFLAGS = -gencode arch=compute_75,code=sm_75 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61, -gencode arch=compute_70,code=sm_70 + +# 获取当前目录下的cu文件集,放在变量CUR_SOURCE中 +CUR_SOURCE=${wildcard ./src/*.cu} + +# 将对应的cu文件名转为o文件后放在下面的CUR_OBJS变量中 +CUR_OBJS=${patsubst %.cu, %.o, $(CUR_SOURCE)} + +EXECUTABLE=conv2ddemo + +all: $(EXECUTABLE) + +$(EXECUTABLE): $(CUR_OBJS) + $(CC) $(CUR_OBJS) $(LDFLAGS) -o $(EXECUTABLE) + +%.o: %.cu + $(CC) -c $< $(CXXFLAGS) $(INCLUDES) -o $@ -Xptxas -v -lineinfo --std=c++11 ${LDFLAGS} + +clean: + rm -f $(EXECUTABLE) + rm -f ./src/*.o \ No newline at end of file diff --git a/docs/12_convolution/03_im2col_conv/codes/include/conv2d.h b/docs/12_convolution/03_im2col_conv/codes/include/conv2d.h new file mode 100644 index 0000000..0b1f4c0 --- /dev/null +++ b/docs/12_convolution/03_im2col_conv/codes/include/conv2d.h @@ -0,0 +1,41 @@ +#ifndef __CONV2D_FWD_HEADER__ +#define __CONV2D_FWD_HEADER__ + +#define __in__ +#define __out__ +#define __in_out__ + +typedef struct +{ + float *in; // 输入数据地址 + float *weight; // 权值数据地址 + float *out; // 输出数据地址 + unsigned int n; // batch szie default value 1 + unsigned int c; // channel number default value 32 + unsigned int h; // 数据高 default value 32 + unsigned int w; // 数据宽 default value 32 + unsigned int k; // 卷积核数量 default value 32 + unsigned int r; // 卷积核高 default value 1 + unsigned int s; // 卷积核宽 default value 1 + unsigned int u; // 卷积在高方向上的步长 default value 1 + unsigned int v; // 卷积在宽方向上的步长 default value 1 + unsigned int p; // 卷积在高方向上的补边 default value 0 + unsigned int q; // 卷积在宽方向上的补边 default value 0 +} problem_t; + +typedef struct +{ + unsigned int blockx; // blockx number + unsigned int blocky; // blocky number + unsigned int blockz; // blockz number + unsigned int threadx; // threadx number per block + unsigned int thready; // thready number per block + unsigned int threadz; // threadz number per block + unsigned int dynmicLdsSize; // 动态分配的lds大小,如果不使用动态分配的lds,则该值为0; + void *kernelPtr; // kernel ptr +} kernelInfo_t; + +int getParamsize(__in__ problem_t *problem, __out__ int *paramSize); +int getkernelInfo(__in__ problem_t *problem, __out__ kernelInfo_t *kernelInfo, __in_out__ void *param); + +#endif \ No newline at end of file diff --git a/docs/12_convolution/03_im2col_conv/codes/include/verfiy.h b/docs/12_convolution/03_im2col_conv/codes/include/verfiy.h new file mode 100644 index 0000000..04cd4c2 --- /dev/null +++ b/docs/12_convolution/03_im2col_conv/codes/include/verfiy.h @@ -0,0 +1,73 @@ +#ifndef __VERFIY_HEADER__ +#define __VERFIY_HEADER__ + +float getPrecision(float tmp) +{ + int tmpInt = (int)tmp; + float eNum = 1.0e-6; + if (abs(tmpInt) > 0) + { + while (tmpInt != 0) + { + tmpInt = (int)(tmpInt / 10); + eNum *= 10; + } + } + else + { + + if (tmp == 0) + return eNum; + + eNum = 1.0e-5; + + while (tmpInt == 0) + { + tmp *= 10; + tmpInt = (int)(tmp); + eNum /= 10; + } + } + return eNum; +} + +void conv2dcpu(float *pin, float *pwei, float *pout, int n, int c, int h, int w, int k, int r, int s, int u, int v, int p, int q) +{ + int oh = (h + 2 * p - r) / u + 1; + int ow = (w + 2 * q - s) / v + 1; + + for (int nNum = 0; nNum < n; nNum++) + { + for (int kNum = 0; kNum < k; kNum++) + { + for (int i = 0; i < oh; i++) + { + for (int j = 0; j < ow; j++) + { + double sum = 0.0; + int posh = i * u - p; + int posw = j * v - q; + + for (int cNum = 0; cNum < c; cNum++) + { + for (int khNum = 0; khNum < r; khNum++) + { + for (int kwNum = 0; kwNum < s; kwNum++) + { + int posh_ori = posh + khNum; + int posw_ori = posw + kwNum; + if (posw_ori >= 0 && posh_ori >= 0 && posw_ori < w && posh_ori < h) + { + sum += (double)(pin[nNum * c * h * w + cNum * (w * h) + posh_ori * w + posw_ori] * pwei[kNum * r * s * c + cNum * r * s + khNum * s + kwNum]); + } + } + } + } + + pout[nNum * k * oh * ow + kNum * oh * ow + i * ow + j] = (float)sum; + } + } + } + } +} +#endif \ No newline at end of file diff --git a/docs/12_convolution/03_im2col_conv/codes/job.sh b/docs/12_convolution/03_im2col_conv/codes/job.sh new file mode 100644 index 0000000..77fd12a --- /dev/null +++ b/docs/12_convolution/03_im2col_conv/codes/job.sh @@ -0,0 +1,11 @@ +#!/bin/bash +make clean +make + +./conv2ddemo 128 3 225 225 32 3 3 2 2 0 0 +./conv2ddemo 49 128 35 35 384 3 3 2 2 0 0 +./conv2ddemo 16 128 105 105 256 3 3 2 2 0 0 +./conv2ddemo 128 3 230 230 64 7 7 2 2 0 0 +./conv2ddemo 2 3 838 1350 64 7 7 2 2 0 0 +./conv2ddemo 256 256 28 28 256 2 2 2 2 0 0 +./conv2ddemo 128 3 225 225 32 3 3 1 1 0 0 diff --git a/docs/12_convolution/03_im2col_conv/codes/src/conv2d.cu b/docs/12_convolution/03_im2col_conv/codes/src/conv2d.cu new file mode 100644 index 0000000..7103def --- /dev/null +++ b/docs/12_convolution/03_im2col_conv/codes/src/conv2d.cu @@ -0,0 +1,593 @@ +#include +#include +#include +#include "conv2d.h" +#include "verfiy.h" // 包含用于验证的自定义头文件 + +#define KernelErrChk() \ + { \ + cudaError_t errSync = cudaGetLastError(); \ + cudaError_t errAsync = cudaDeviceSynchronize(); \ + if (errSync != cudaSuccess) \ + { \ + printf("Sync kernel error: %s\n", cudaGetErrorString(errSync)); \ + exit(EXIT_FAILURE); \ + } \ + if (errAsync != cudaSuccess) \ + { \ + printf("Async kernel error: %s\n", cudaGetErrorString(errAsync)); \ + exit(EXIT_FAILURE); \ + } \ + } + +template +__global__ void sgemm_kernel(int M, int N, int K, float alpha, float *A, float *B, float beta, float *C, int kernel_num, int y_height, int y_width) +{ + int bx = blockIdx.x; + int by = blockIdx.y; + + int block_row_thread = BN / TN; + int block_col_thread = BM / TM; + int thread_num = block_row_thread * block_col_thread; // 一个线程负责计算block中TM*TN个元素 + + int tx = (threadIdx.x % block_row_thread) * TN; + int ty = (threadIdx.x / block_row_thread) * TM; + + __shared__ float As[BM * BK]; + __shared__ float Bs[BK * BN]; + + // 移动到当前block + A = &A[by * BM * K]; + B = &B[bx * BN]; + C = &C[by * BM * N + bx * BN]; + + /* + 当前线程负责搬运全局内存中第a_tile_row行,第a_tile_col列元素至共享内存第a_tile_row行,第a_tile_col列 + a_tile_stride表示block中线程可搬运a_tile_stride行至共享内存; + + 若BM=64,BK=8,thread_num=512,则a_tile_stride=64,a_tile_stride=BM,表示每个线程搬运一轮即可完成所需元素的搬运; + 若BM=128,BK=8,thread_num=512,则a_tile_stride=64,表示每个线程搬运两轮即可完成所需元素的搬运; + */ + int a_tile_row = threadIdx.x / BK; + int a_tile_col = threadIdx.x % BK; + int a_tile_stride = thread_num / BK; + + int b_tile_row = threadIdx.x / BN; + int b_tile_col = threadIdx.x % BN; + int b_tile_stride = thread_num / BN; + + float tmp[TM][TN] = {0.}; // 每个线程负责TM*TN个元素,则需要申请TM*TN个寄存器保存累加值,额外的一个寄存器用于缓存; + float a_frag[TM] = {0.}; + float b_frag[TN] = {0.}; + +#pragma unroll + for (int k = 0; k < K; k += BK) + { +#pragma unroll + for (int i = 0; i < BM; i += a_tile_stride) + { + As[(a_tile_row + i) * BK + a_tile_col] = A[(a_tile_row + i) * K + a_tile_col]; + } +#pragma unroll + for (int i = 0; i < BK; i += b_tile_stride) + { + Bs[(b_tile_row + i) * BN + b_tile_col] = B[(b_tile_row + i) * N + b_tile_col]; + } + __syncthreads(); + A += BK; + B += BK * N; +#pragma unroll + for (int i = 0; i < BK; i++) + { +#pragma unroll + for (int j = 0; j < TM; j++) + { + a_frag[j] = As[(ty + j) * BK + i]; + } +#pragma unroll + for (int l = 0; l < TN; l++) + { + b_frag[l] = Bs[tx + l + i * BN]; + } +#pragma unroll + for (int j = 0; j < TM; j++) + { +#pragma unroll + for (int l = 0; l < TN; l++) + tmp[j][l] += a_frag[j] * b_frag[l]; + } + } + __syncthreads(); + } +#pragma unroll + for (int j = 0; j < TM; j++) + { + for (int l = 0; l < TN; l++) + { + int batch = l / (y_height * y_width); + int conv_idx = batch * (kernel_num * y_height * y_width) + ((ty + j) * N + tx) / (batch * y_height * y_width) * (y_height * y_width) + l % (y_height * y_width); + if ((ty + j) * N + tx + l < M * N) + { + C[conv_idx] = alpha * tmp[j][l] + beta * C[conv_idx]; + } + } + } +} + +template +__global__ void __launch_bounds__((BM * BN) / (TM * TN), 1) + sgemm_blocktiling_2d_kernel(float *A, + float *B, + float *C, + int M, + int N, + int K, + int batch, + int kernel_num, + int y_height, + int y_width) +{ + // the output block that we want to compute in this threadblock + const uint c_row = blockIdx.y; + const uint c_col = blockIdx.x; + + // // A thread is responsible for calculating TM*TN elements in the blocktile + const uint num_threads_block_tile = (BM * BN) / (TM * TN); + + // allocate shared memory for the input and output submatrices + __shared__ float A_shared[BM * BK]; + __shared__ float B_shared[BK * BN]; + + // the inner row & col that we're accessing in this thread + const uint thread_row = threadIdx.x / (BN / TN); + const uint thread_col = threadIdx.x % (BN / TN); + + // advance pointers to the starting positions + A += c_row * BM * K; + B += c_col * BN; + int global_c_index = c_row * BM * N + c_col * BN; + + // use to avoid out-of-bounds accesses + int global_m_pos = c_row * BM * K; + int global_n_pos = c_col * BN; + const uint m_size = M * K; + const uint n_size = N * K; + + assert((BM * BN) / (TM * TN) == blockDim.x); + + const uint A_inner_row = threadIdx.x / BK; // warp-level GMEM coalescing + const uint A_inner_col = threadIdx.x % BK; + const uint stride_a = num_threads_block_tile / BK; + const uint B_inner_row = threadIdx.x / BN; // warp-level GMEM coalescing + const uint B_inner_col = threadIdx.x % BN; + const uint stride_b = num_threads_block_tile / BN; + + // allocate thread-local cache for results in registerfile + float thread_results[TM * TN] = {0.0}; + float reg_m[TM] = {0.0}; + float reg_n[TN] = {0.0}; + + // outer loop over block tiles + for (uint bk_idx = 0; bk_idx < K; bk_idx += BK) + { + // load the next block of the input matrices into shared memory + for (uint load_offset = 0; load_offset < BM; load_offset += stride_a) + { + A_shared[(A_inner_row + load_offset) * BK + A_inner_col] = + (global_m_pos + (A_inner_row + load_offset) * K + A_inner_col < m_size) ? A[(A_inner_row + load_offset) * K + A_inner_col] : 0.0f; + } + for (uint load_offset = 0; load_offset < BK; load_offset += stride_b) + { + B_shared[(B_inner_row + load_offset) * BN + B_inner_col] = + (global_n_pos + (B_inner_row + load_offset) * N + B_inner_col < n_size) ? B[(B_inner_row + load_offset) * N + B_inner_col] : 0.0f; + } + + // wait for all threads to finish loading + __syncthreads(); + + // advance the pointers + A += BK; + B += BK * N; + global_m_pos += BK; + global_n_pos += BK * N; + + // compute the partial sum + for (uint dot_idx = 0; dot_idx < BK; dot_idx++) + { + // load relevant As & Bs entries into registers + for (uint i = 0; i < TM; i++) + { + reg_m[i] = A_shared[(thread_row * TM + i) * BK + dot_idx]; + } + for (uint i = 0; i < TN; i++) + { + reg_n[i] = B_shared[dot_idx * BN + thread_col * TN + i]; + } + + // perform outer product on register cache, accumulate + // into threadResults + for (uint res_idx_m = 0; res_idx_m < TM; res_idx_m++) + { + for (uint res_idx_n = 0; res_idx_n < TN; res_idx_n++) + { + thread_results[res_idx_m * TN + res_idx_n] += reg_m[res_idx_m] * reg_n[res_idx_n]; + } + } + } + + // wait for all threads to finish computing + __syncthreads(); + } + + int inner_y_size = y_height * y_width; + int res_inner_index, g_index, batch_id, channel_id, inner_offset; + + int conv_idx; + + if (global_c_index >= M * N) + { + return; + } + + for (uint res_idx_m = 0; res_idx_m < TM; res_idx_m++) + { + for (uint res_idx_n = 0; res_idx_n < TN; res_idx_n++) + { + if (c_row * BM + thread_row * TM + res_idx_m < M && c_col * BN + thread_col * TN + res_idx_n < N) + { + res_inner_index = (thread_row * TM + res_idx_m) * N + thread_col * TN + res_idx_n; + g_index = global_c_index + res_inner_index; + inner_offset = g_index % inner_y_size; + batch_id = (g_index % (inner_y_size * batch)) / inner_y_size; + channel_id = g_index / (inner_y_size * batch); + conv_idx = batch_id * (kernel_num * y_height * y_width) + channel_id * (y_height * y_width) + inner_offset; + C[conv_idx] = thread_results[res_idx_m * TN + res_idx_n]; + } + } + } +} + +#define CEIL_DIV(M, N) ((M) + (N)-1) / (N) + +void cuda_gemm(float *A, + float *B, + float *C, + int M, + int N, + int K, + int batch, + int kernel_num, + int y_height, + int y_width) +{ + const uint BK = 8; + const uint TM = 8; + const uint TN = 8; + if (M >= 128 && N >= 128) + { + const uint BM = 128; + const uint BN = 128; + dim3 grid_size(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 block_size((BM * BN) / (TM * TN)); + sgemm_blocktiling_2d_kernel + <<>>(A, B, C, M, N, K, batch, kernel_num, y_height, y_width); + } + else + { + const uint BM = 64; + const uint BN = 64; + dim3 grid_size(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 block_size((BM * BN) / (TM * TN)); + sgemm_blocktiling_2d_kernel + <<>>(A, B, C, M, N, K, batch, kernel_num, y_height, y_width); + } +} + +#define MAX_THREADS 1024 +template +static int FetchMaxBlokcSize(T cuda_kernel, const int share_memory_size = 0) +{ + int minGridSize{0}; + int blockSize{0}; + cudaOccupancyMaxPotentialBlockSize(&minGridSize, &blockSize, cuda_kernel, share_memory_size, MAX_THREADS); + return blockSize; +} + +#define MaxBlockSize 512 + +template +__global__ void im2col_kernel(const int n, + T *data_x, + T *data_y, + const int batches, + const int inner_size_x, + const int inner_size_y, + const int x_height, + const int x_width, + const int kernel_height, + const int kernel_width, + const int pad_height, + const int pad_width, + const int stride_height, + const int stride_width, + const int dilation_height, + const int dilation_width, + const int y_height, + const int y_width, + const int inner_size_c) +{ + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < n; index += blockDim.x * gridDim.x) + { + int batch = index / inner_size_c, idx = index % inner_size_c; + int w_out = idx % y_width, id = idx / y_width; + int h_out = id % y_height, channel_in = id / y_height; + int channel_out = channel_in * kernel_height * kernel_width; + int h_in = h_out * stride_height - pad_height; + int w_in = w_out * stride_width - pad_width; + T *out = data_y + batch * (y_height * y_width) + (channel_out * y_height * batches + h_out) * y_width + w_out; + T *in = data_x + batch * inner_size_x + (channel_in * x_height + h_in) * x_width + w_in; + for (int i = 0; i < kernel_height; ++i) + { + for (int j = 0; j < kernel_width; ++j) + { + int h = h_in + i * dilation_height; + int w = w_in + j * dilation_width; + *out = (h >= 0 && w >= 0 && h < x_height && w < x_width) + ? in[i * dilation_height * x_width + j * dilation_width] + : static_cast(0); + out += y_height * y_width * batches; + } + } + } +} + +template +void cuda_im2col(const int batches, + const int x_channel, + const int x_height, + const int x_width, + const int y_out_plane, + const int y_height, + const int y_width, + const int kernel_height, + const int kernel_width, + const int stride_height, + const int stride_width, + const int dilation_height, + const int dilation_width, + const int pad_height, + const int pad_width, + T *x, + T *y) +{ + const int inner_size_y = y_out_plane * y_height * y_width; + const int inner_size_x = x_channel * x_height * x_width; + const int inner_size_c = x_channel * y_height * y_width; + const int num_kernels = batches * inner_size_c; + + const int blockSize = std::max(std::min(MaxBlockSize, num_kernels), static_cast(1)); + const int gridSize = (num_kernels + blockSize - 1) / blockSize; + + im2col_kernel<<>>(num_kernels, + x, + y, + batches, + inner_size_x, + inner_size_y, + x_height, + x_width, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + y_height, + y_width, + inner_size_c); +} + +int main(int argc, char **argv) +{ + // 从命令行参数中获取输入参数 + int n = atoi(argv[1]); // 批大小 + int c = atoi(argv[2]); // 输入通道数 + int h = atoi(argv[3]); // 输入高度 + int w = atoi(argv[4]); // 输入宽度 + int k = atoi(argv[5]); // 卷积核数 + int r = atoi(argv[6]); // 卷积核高度 + int s = atoi(argv[7]); // 卷积核宽度 + int u = atoi(argv[8]); // 垂直步幅 + int v = atoi(argv[9]); // 水平步幅 + int p = atoi(argv[10]); // 垂直填充 + int q = atoi(argv[11]); // 水平填充 + + // 计算输出特征图的高度和宽度 + int outh = (h - r + 2 * p) / u + 1; + int outw = (w - s + 2 * q) / v + 1; + + // 分配并初始化输入、权重、输出和主机端输出数据的内存 + float *pIn = (float *)malloc(n * c * h * w * sizeof(float)); + float *pInCol = (float *)malloc(n * c * r * s * outh * outw * sizeof(float)); + float *pWeight = (float *)malloc(k * c * r * s * sizeof(float)); + float *pOut = (float *)malloc(n * k * outh * outw * sizeof(float)); + float *pOut_host = (float *)malloc(n * k * outh * outw * sizeof(float)); + + float *pIn_device, *pInCol_device, *pWeight_device, *pOut_device; + cudaMalloc(&pIn_device, n * c * h * w * sizeof(float)); + cudaMalloc(&pInCol_device, n * c * r * s * outh * outw * sizeof(float)); + cudaMalloc(&pWeight_device, k * c * r * s * sizeof(float)); + cudaMalloc(&pOut_device, n * k * outh * outw * sizeof(float)); + + // 随机初始化输入和权重数据 + for (int i = 0; i < n * c * h * w; i++) + { + pIn[i] = (rand() % 255) / 255.0; + } + + for (int i = 0; i < n * c * r * s * outh * outw; i++) + { + pInCol[i] = 0.0; + } + + for (int i = 0; i < k * c * r * s; i++) + { + pWeight[i] = (rand() % 255) / 255.0; + } + + for (int i = 0; i < n * k * outh * outw; i++) + { + pOut[i] = 0.0; + pOut_host[i] = 0.0; + } + + // 将输入、权重和输出数据从主机内存复制到设备内存 + cudaMemcpy(pIn_device, pIn, n * c * h * w * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(pInCol_device, pInCol, n * c * r * s * outh * outw * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(pWeight_device, pWeight, k * c * r * s * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy((void **)pOut_device, pOut, n * k * outh * outw * sizeof(float), cudaMemcpyHostToDevice); + + /*******************************warm up and get result************************************/ + // 计算 + int batch_size = n; + int x_channel = c; + int x_height = h; + int x_width = w; + int y_height = outh; + int y_width = outw; + int kernel_numbers = k; + int kernel_height = r; + int kernel_width = s; + int stride_height = u; + int stride_width = v; + int dilation_height = 1; + int dilation_width = 1; + int pad_height = p; + int pad_width = q; + int y_out_plane = x_channel * kernel_height * kernel_width; + + cuda_im2col(batch_size, + x_channel, + x_height, + x_width, + y_out_plane, + y_height, + y_width, + kernel_height, + kernel_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + pad_height, + pad_width, + pIn_device, + pInCol_device); + KernelErrChk(); + + cudaDeviceSynchronize(); + + std::cout << kernel_numbers << " " << batch_size * y_height * y_width << " " << x_channel * kernel_height * kernel_width << std::endl; + cuda_gemm(pWeight_device, + pInCol_device, + pOut_device, + kernel_numbers, + batch_size * y_height * y_width, + x_channel * kernel_height * kernel_width, + batch_size, + kernel_numbers, + y_height, y_width); + + KernelErrChk(); + + cudaDeviceSynchronize(); + + cudaMemcpy(pOut_host, pOut_device, n * k * outh * outw * sizeof(float), cudaMemcpyDeviceToHost); + + /*******************************cost time test************************************/ + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + cudaEventRecord(start, 0); + + float time_elapsed = 0.0; + + int iternum = 100; + for (int i = 0; i < iternum; i++) + { + cuda_im2col(batch_size, + x_channel, + x_height, + x_width, + y_out_plane, + y_height, + y_width, + kernel_height, + kernel_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + pad_height, + pad_width, + pIn_device, + pInCol_device); + KernelErrChk(); + + cudaDeviceSynchronize(); + + cuda_gemm(pWeight_device, + pInCol_device, + pOut_device, + kernel_numbers, + batch_size * y_height * y_width, + x_channel * kernel_height * kernel_width, + batch_size, + kernel_numbers, + y_height, y_width); + cudaDeviceSynchronize(); + } + + cudaEventRecord(stop, 0); + + cudaEventSynchronize(stop); + cudaEventElapsedTime(&time_elapsed, start, stop); + + printf("time: %f us\n", time_elapsed * 1000 / iternum); + cudaEventDestroy(start); + cudaEventDestroy(stop); + + printf("===================start verify===================\n"); + // 调用CPU上的卷积函数以验证GPU计算结果 + conv2dcpu(pIn, pWeight, pOut, n, c, h, w, k, r, s, u, v, p, q); + + int error = 0; + for (int i = 0; i < n * k * outh * outw; i++) + { + if (abs(pOut_host[i] - pOut[i]) > getPrecision(pOut[i])) + { + printf("error, position:%d, gpuvalue:%f, cpuvalue:%f\n", i, pOut_host[i], pOut[i]); + error++; + break; + } + } + printf("================finish,error:%d=========================\n", error); + + // 释放设备和主机内存 + cudaFree(pIn_device); + cudaFree(pWeight_device); + cudaFree(pOut_device); + + free(pIn); + free(pWeight); + free(pOut); + free(pOut_host); + + return 0; +} \ No newline at end of file diff --git a/docs/12_convolution/03_im2col_conv/images/im2col.jpg b/docs/12_convolution/03_im2col_conv/images/im2col.jpg new file mode 100644 index 0000000..bebdd77 Binary files /dev/null and b/docs/12_convolution/03_im2col_conv/images/im2col.jpg differ