-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
92 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,91 @@ | ||
# 卷积算子优化思路介绍 | ||
# 卷积算子优化思路介绍 | ||
|
||
上一篇文章中,我们介绍了卷积算子的简易实现,它是直接模拟卷积操作的过程,这种实现方式的缺点是计算量大,效率低。在本文中,我们将介绍卷积算子的优化思路。 | ||
|
||
卷积算子的主要优化思路就是将卷积运算转换为矩阵乘法运算。进而卷积算子优化问题就转化为了矩阵乘法优化问题。这篇文章中我们主要介绍一下如何将卷积运算转换为矩阵乘法运算。 | ||
|
||
## 1. 卷积算法映射为矩阵乘法 | ||
|
||
首先我们先来回顾一下卷积算法的定义,假设输入的特征图为 $X$,卷积核为 $K$,输出特征图为 $Y$,$X$ 的大小为 $N \times C \times H \times W$,$K$ 的大小为 $M \times C \times K_h \times K_w$,$Y$ 的大小为 $N \times M \times H \times W$。那么卷积算法的定义如下: | ||
|
||
$$ | ||
Y[n,oc,oh,ow] = \sum_{ic}\sum_{fh}\sum_{fw}X[n,ic,ih,iw] \times K[oc,ic,fh,fw] | ||
$$ | ||
|
||
其中,$n$ 表示 batch 的索引,$m$ 表示输出特征图的索引,$i$ 和 $j$ 分别表示输出特征图的高和宽的索引。其中 ih, iw 等坐标计算如下: | ||
|
||
``` | ||
ih = oh * stride_h + fh - padding_h | ||
iw = ow * stride_w + fw - padding_w | ||
``` | ||
|
||
其中,$stride_h$ 和 $stride_w$ 分别表示卷积核的高和宽的步长,$padding_h$ 和 $padding_w$ 分别表示卷积核的高和宽的填充。想要把卷积算法映射为矩阵乘法算法,我们需要使用 `im2col` 算法将输入特征图转换为矩阵,然后使用 `gemm` 算法进行矩阵乘法运算。 | ||
|
||
## 2. im2col 算法 | ||
|
||
im2col 就是 image to column 的缩写,它是一种将输入特征图转换为矩阵的算法。im2col 算法的主要思想是将卷积核在输入特征图上滑动,每次滑动的步长为卷积核的步长,然后将卷积核覆盖的区域拉成一个列向量,最后将所有的列向量拼接在一起,就得到了一个矩阵。 | ||
|
||
下面我们通过一个简单的例子来说明 im2col 算法的原理。假设输入特征图的大小为 $4 \times 4$,卷积核的大小为 $3 \times 3$,步长为 1,填充为 0。那么 im2col 算法的过程如下图所示: | ||
|
||
![im2col](./images/im2col.jpg) | ||
|
||
从上图中可以看出,im2col 算法的过程就是将卷积核在输入特征图上滑动,每次滑动的步长为卷积核的步长,然后将卷积核覆盖的区域拉成一个列向量,最后将所有的列向量拼接在一起,就得到了一个矩阵。 | ||
|
||
如果我们把这个矩阵记为 $X_{col}$,那么卷积算法就可以表示为: | ||
|
||
$$ | ||
Y = K \times X_{col} | ||
$$ | ||
|
||
其中,$K$ 表示卷积核,$X_{col}$ 表示输入特征图转换得到的矩阵。 | ||
|
||
## 3. 隐式 gemm 算法 | ||
|
||
im2col 算法会把输入特征图转换为一个矩阵,然后保存在内存中。其实我们可以直接在计算的时候不用保存这个矩阵,而是直接计算坐标的偏移量,然后直接从输入特征图中读取数据。这种算法就是隐式 gemm 算法。 | ||
|
||
根据上面的讨论,我们可以把卷积的运算过程,写成一个隐式矩阵乘法 (Implicit GEMM) 的形式: | ||
|
||
``` | ||
GEMM_M = OC | ||
GEMM_N = N * OH * OW | ||
GEMM_K = IC * FH * FW | ||
For i=0 to GEMM_M | ||
oc = i | ||
For j=0 to GEMM_N | ||
accumulator = 0 | ||
n = j / (OH * OW) | ||
j_res = j % (OH * OW) | ||
oh = j_res / OW | ||
ow = j_res % OW | ||
For k=0 to GEMM_K | ||
ic = k / (FH * FW) | ||
k_res = k % (FH * FW) | ||
fh = k_res / FW | ||
fw = k_res % FW | ||
ih = oh * stride_h - pad_h + fh | ||
iw = ow * stride_w - pad_w + fw | ||
accumulator = accumulator + x(n, ic, ih, iw) * w(oc, ic, fh, fw) | ||
y(n, oc, oh, ow) = accumulator | ||
``` | ||
|
||
上面的代码中,`GEMM_M` 表示输出特征图的通道数,`GEMM_N` 表示输出特征图的大小,`GEMM_K` 表示输入特征图的大小。在这个算法中,我们直接计算坐标的偏移量,然后直接从输入特征图中读取数据,然后进行计算。 | ||
|
||
## 4. 总结 | ||
|
||
总的来说,卷积算子的优化思路就是将卷积运算转换为矩阵乘法运算。这样做的好处是可以利用矩阵乘法的高效实现来提高卷积算子的计算效率。在下一篇文章中,我们将动手实现基于 im2col 算法的卷积算子优化版本。 | ||
|
||
## References | ||
|
||
1. https://zhuanlan.zhihu.com/p/372973726 | ||
2. https://blog.csdn.net/m0_45388819/article/details/120757424 | ||
3. https://blog.csdn.net/dwyane12138/article/details/78449898 | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
1 change: 1 addition & 0 deletions
1
docs/12_convolution/02_intro_conv_optimize/images/im2col.drawio
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
<mxfile host="Electron" modified="2024-03-18T09:05:05.891Z" agent="5.0 (Macintosh; Intel Mac OS X 10_16_0) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/13.7.9 Chrome/85.0.4183.121 Electron/10.1.3 Safari/537.36" etag="BGvOS3bmMKx1kujGUsmh" version="13.7.9" type="device"><diagram id="rteP-x4H_hP2QaLaT4T0" name="第 1 页">7Z3bcuJGEIafhsut0gGdLm1vsklltyopX+RaETKoViAi5DWbp4+EDsAMBGc9rX/kHm/VGkYwyPNJdM8/Pd0z92G9/1TG29WXYpHmM8da7Gfux5njRG5Y/980fG8bwiBqG5Zltmib7GPDY/ZP2jVaXetztkh3Zy+siiKvsu15Y1JsNmlSnbXFZVm8nL/sqcjPP3UbL1Op4TGJc7n1z2xRrdpWz7Ks44Ff0my5qsQj67h/ddewW8WL4uWkyf1p5j6URVG1j9b7hzRvBq8fmPZ9P185OpxZmW6q17yh+vy5iH/d//bHdrO/293dP3xaJR+6Xr7F+XP3F88cP6/7u38q6m7rs66+d2Ph//1c9Ac+7A6k7uoX2OF2fzxYP1o2v62+m/p82p7a9m4ohk6dl1VWpY/bOGmev9TXUP2iVbXO62d2/TDebVusT9k+XTSfnuX5Q5EX5eHt7lOYpElSt++qsvianhz5K/TmntUc+ZpWyarrb4BgDyfzLS2rdH91UO0BVX2Np8U6rcrv9Uv6N/QXand92/Pu+cvxaumbVifXSd8Wd9fncuj5SLB+0EH8H0AdIqA2F6CeZkBdIqAOF6CRZkDnREDdsYB6zb9LQP3DT3e+J+3tDzVox9UMtE1168653LqicQ3RRKnuXY8LUdG6wol6RER9LkRF8won6hMRDYx91Yx0QEQ65HLvCvZ1kCpgREMiohEXop5uRCMiojYbhSnSDGn/+eqRjqYxTcTC4lFTycP2aOqTpqgl0+ujUZMJx9x1Kckmw1FT6VL2aMKUrqhFWw1HTSVY2aMpVpqilmw1HDWVksXF0XbDc6IeGiiVkMVlKVcEGqGBUulVoznTmgG1XTRRKr2KywqfRDRAE6XSq7is8IlEh0g+WPwMlVzFZYVPIop2jFwqVYrLuo9EFO0ZuVTiE5d1H5Goi/aMXDKNiet01EW7Ru4lZ1cY83SzuGui7utnSR7vdllyPuhnI9a+O11IEfg3h+jlNBRfHoK+rUzzuMq+nXd/aVy6T/i9yA6XZi/xXFt567vYFc9lknbvOo6u1JHr3+ioistlWkkdHTANf/YbyFE5taPdi7dFwOv34jV5UME96usW/T+ncna5aEa+buH/cypnl4to5GsX/0/l7I620AomGugW6D+ffKD/VOwrOuR0bjYAKLavcKJmA4Bi+wonOvkNALrZVzjRyQf6T8S+wgNO52YDgFr7iidqNgAoNrBwpP0lNuENAJpZWDxSE+hPdffqFujvUclP7AP9JdsLR20C/ceyyXDUJtB/LFsNR02lWHHxviKBKDqezaMSrEbzvTQDig5n86j0Ki4rfCJQeKC/R6VLcVn3kYiio9k8Kl2Ky7qPSBQe6N/7ZWbdRxVRtGPkU6lSXFYDJKJoz8gnE5+4rAaISOGR/j6ZyMR1PgqP9PcviUnCmL+rSP9AVaR/CI70900SiLfdi6FwJcBdIJMEQi1QuAdkkkCoBQrXhnyTBEIxUbj/Y5JAqCUK14b6K8qofaqIoh2jwCSBUEwU7RkFVNoQF7VPJAqXhgKTBEI1UrRrZFsmlwDNXgwxI7SHjjMaTmi62oOuqD3tUE8+y4CuqEU9EY+ays3SKPwbglrMCK0B6snnJUDXXhFssm/BkZoEBGptrwZITQYCtTZWA6STT0EwERurAerJ5yZA372ijUVnc7JNFULVNhaP1GQhUG1k4UxZ1CHUw8riWU8/PwH6/hXNLDqpk+1QCVHjJSJAMxXtLJ7p9DMOoJmKdhbP1KQWIGIt2Vk8a5Nb4G1I5wLSOXzBnqyMIBfPSUTqoYMZbbJCglwcJwkpOprRJqskyGVdR0KKDme0yUoJclnXEZH66HhGm6yWIJcUAxJSuHtEVkyQi9gvIg3g7hFZNUE2Yr/EFO4f0dUT5DozDeAOksstzYBYUNAXx/ZHCwpKHRGnGbBdk2fgjdU9hUsBrxO5JtOAWqR4ncg1uQYUI8X7QSbZgGKkeDfIZBtQixSvE/U9G+lPFVK4e0RWh5dLnK+IFK8TkRXi5SL9SUjh7hFZJV420p/EFO4fkdXc5eLyDtqCNloDWdFdLi6viBSvNZBV3eXi8kpI8cb0uny028YbNS5v2xMTpHhbara2qUWqgdZgtrapZgr3j+gq7DpMmeLVBroSu2yidkWmcA+JrpYu17kpXm4gK5o7mjnVdFexWAkbnx+Lrmgud9TaJTgkK5o7mkOlK2rtEhySldPlnuBQrIStAWqqAKfR3C9NUYu2Gp9ny1TgVWyTNUBqNsiptb14pKYGr2IbqwHSyQc+TcXGwrNskVXn5bI6JNlYPFITEKXayOKZkslW3HUryfriWZMlZGIvXGmXEJGu1i+XFULJ/uKZUilUbFYIJfuLZ0olRY2XEFG3WS6eKVkUFROkkXaJD8lq/o7mOWmGFB9kTlb0l4uDJCGFh1CRVf3lYkslpPAIKrKyv1zWdUSk+CDzgEpy4rKuIyHFu0dUyhIXsV9Eig8xD8gEJC5iv8QU7x+RCUhcZ6b4EPPgkoAkDPq7SnwYqEp8GKITHwZUOhGXb9hQu8SHvfZopD9FSPE6UUilE3GR/iSkcD8opNKJuKyjSUjhblBIpRNxkf5EpHidKKTSibhIfxJSvHtEpRNxSXwoIsXrRCGVTsRF+pOQ4t0jMpmI68QULxOFJtf3G5U/7RIfhibXt1qkeK2h/5owLq8qpHBjGl2Xj0ziwx9CCrelEZV8xMXlFZHitYbIbG1TzRTuH0VkW9scpkzxakNEtoWNTdSudokPI7JII65zU7zcEFFJSFzmprYlbCnF6w0RlYTEZXIqMdVAcKDSkLjEYstM0fbU6bP8KFccTqen71lxkJmi7WlNhMpH4jI/laDCRQfHotKR2IRjy1DRXpJjkSlJXGaoElS47OBYZFISlymqDBXvJ5FpSWwnqXDhwbFuiklP7WgfsT0U6yypDzzGm13968vjKdFXXAPzS9fAY5VuX22FawKVQD7Plpv6cVKTSGum9w2nLInzu+7AOlssmrffl2l9KvFfh64a4Ntmu8phWL37mfex6eu5KtrTVcZd2DljX/iCvrQDyKHDflNvGg+7816x+33hg//A7o+L/aYkNR72V7tnU8M+VBDrZ8MXJk4jY7+pWo2H/dUO3NSwi3c7Ifb6aVk0Izsc+1T7QqsvxSJtXvEv</diagram></mxfile> |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.