-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Doc] Add im2col + gemm 实现 卷积算子 (#32)
- Loading branch information
Showing
9 changed files
with
1,141 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Builds the conv2d demo with nvcc: every ./src/*.cu file is compiled to an
# object file and the objects are linked into $(EXECUTABLE).
CC=nvcc

CXXFLAGS += -DNDEBUG -DUSE_DEFAULT_STDLIB -g

INCLUDES += -I./include

# GPU code-generation targets. Despite its name this variable is passed to
# both the compile rule and the link rule below.
LDFLAGS = -gencode arch=compute_75,code=sm_75 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70

# Collect the .cu files under ./src into CUR_SOURCE
CUR_SOURCE=${wildcard ./src/*.cu}

# Map each .cu file name to its .o name and store the list in CUR_OBJS
CUR_OBJS=${patsubst %.cu, %.o, $(CUR_SOURCE)}

EXECUTABLE=conv2ddemo

# all and clean are commands, not files; keep them runnable even if a file
# with one of these names ever appears in the directory.
.PHONY: all clean

all: $(EXECUTABLE)

$(EXECUTABLE): $(CUR_OBJS)
	$(CC) $(CUR_OBJS) $(LDFLAGS) -o $(EXECUTABLE)

%.o: %.cu
	$(CC) -c $< $(CXXFLAGS) $(INCLUDES) -o $@ -Xptxas -v -lineinfo --std=c++11 ${LDFLAGS}

clean:
	rm -f $(EXECUTABLE)
	rm -f ./src/*.o
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#ifndef __CONV2D_FWD_HEADER__ | ||
#define __CONV2D_FWD_HEADER__ | ||
|
||
#define __in__ | ||
#define __out__ | ||
#define __in_out__ | ||
|
||
/* Description of one 2-D convolution problem: data buffers plus dimensions. */
typedef struct
{
    float *in;      /* input data address */
    float *weight;  /* weight data address */
    float *out;     /* output data address */
    unsigned int n; /* batch size, default value 1 */
    unsigned int c; /* channel number, default value 32 */
    unsigned int h; /* data height, default value 32 */
    unsigned int w; /* data width, default value 32 */
    unsigned int k; /* number of convolution kernels, default value 32 */
    unsigned int r; /* convolution kernel height, default value 1 */
    unsigned int s; /* convolution kernel width, default value 1 */
    unsigned int u; /* convolution stride along the height, default value 1 */
    unsigned int v; /* convolution stride along the width, default value 1 */
    unsigned int p; /* convolution padding along the height, default value 0 */
    unsigned int q; /* convolution padding along the width, default value 0 */
} problem_t;
|
||
/* Launch configuration for one kernel: grid/block extents, dynamic LDS size
 * and the kernel entry pointer. */
typedef struct
{
    unsigned int blockx;        /* blockx number */
    unsigned int blocky;        /* blocky number */
    unsigned int blockz;        /* blockz number */
    unsigned int threadx;       /* threadx number per block */
    unsigned int thready;       /* thready number per block */
    unsigned int threadz;       /* threadz number per block */
    unsigned int dynmicLdsSize; /* size of the dynamically allocated LDS; 0 when dynamic LDS is not used */
    void *kernelPtr;            /* kernel ptr */
} kernelInfo_t;
|
||
int getParamsize(__in__ problem_t *problem, __out__ int *paramSize); | ||
int getkernelInfo(__in__ problem_t *problem, __out__ kernelInfo_t *kernelInfo, __in_out__ void *param); | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
#ifndef __VERFIY_HEADER__ | ||
#define __VERFIY_HEADER__ | ||
|
||
/*
 * Return the absolute tolerance to use when comparing a result against the
 * reference value tmp. The tolerance scales with the decimal magnitude of tmp:
 *
 *   tmp == 0        -> 1e-6
 *   |tmp| >= 1      -> 1e-6 * 10^d, d = number of decimal digits in the
 *                      integer part of |tmp|   (5.0 -> 1e-5, 123.9 -> 1e-3)
 *   0 < |tmp| < 1   -> 1e-5 / 10^m, m = number of times |tmp| must be
 *                      multiplied by 10 to reach 1 or more
 *                      (0.5 -> 1e-6, 0.05 -> 1e-7)
 *   NaN, +/-inf     -> 1e-6 (no meaningful magnitude)
 *
 * The magnitude is measured in floating point rather than through an int
 * cast, so values outside the int range and non-finite values are handled
 * without undefined behaviour. The function is static inline because it is
 * defined in a header and would otherwise be multiply defined when the header
 * is included from more than one translation unit.
 */
static inline float getPrecision(float tmp)
{
    float eNum = 1.0e-6f;

    /* x - x is 0 for every finite x and NaN for NaN and +/-infinity. */
    if (!(tmp - tmp == 0.0f))
        return eNum;

    float mag = (tmp < 0.0f) ? -tmp : tmp;

    if (mag >= 1.0f)
    {
        /* One step per decimal digit of the integer part: count the powers
         * of ten that do not exceed |tmp|. The bound is kept in double so it
         * cannot overflow for any finite float. */
        double bound = 1.0;
        while ((double)mag >= bound)
        {
            bound *= 10.0;
            eNum *= 10;
        }
        return eNum;
    }

    if (mag == 0.0f)
        return eNum;

    /* 0 < |tmp| < 1: shrink the tolerance once per leading decimal place. */
    eNum = 1.0e-5f;
    while (mag < 1.0f)
    {
        mag *= 10;
        eNum /= 10;
    }
    return eNum;
}
|
||
/*
 * Reference 2-D convolution (cross-correlation, no kernel flip) on the CPU.
 *
 * Indexing used by this function:
 *   pin  : [n][c][h][w]   input
 *   pwei : [k][c][r][s]   weights
 *   pout : [n][k][oh][ow] output, oh = (h + 2p - r) / u + 1,
 *                                 ow = (w + 2q - s) / v + 1
 * u / v are the strides and p / q the zero paddings along height / width.
 * Taps that fall into the padding contribute nothing to the sum.
 *
 * Each output element is accumulated in double (the products are formed in
 * double as well) and stored as float.
 *
 * Nothing is written when a stride is not positive or when the kernel does
 * not fit into the padded input, since no valid output position exists.
 *
 * Flat offsets are computed in long long so large tensors do not overflow
 * int arithmetic. The function is static inline because it is defined in a
 * header.
 */
static inline void conv2dcpu(const float *pin, const float *pwei, float *pout, int n, int c, int h, int w, int k, int r, int s, int u, int v, int p, int q)
{
    /* A zero stride would divide by zero below. */
    if (u <= 0 || v <= 0)
        return;

    /* Truncating division of a negative numerator would otherwise report a
     * spurious output row/column. */
    if (h + 2 * p < r || w + 2 * q < s)
        return;

    const int oh = (h + 2 * p - r) / u + 1;
    const int ow = (w + 2 * q - s) / v + 1;

    const long long in_plane = (long long)h * w;
    const long long wei_plane = (long long)r * s;
    const long long out_plane = (long long)oh * ow;

    for (int nNum = 0; nNum < n; nNum++)
    {
        for (int kNum = 0; kNum < k; kNum++)
        {
            const float *wei_k = pwei + (long long)kNum * c * wei_plane;
            float *out_k = pout + ((long long)nNum * k + kNum) * out_plane;

            for (int i = 0; i < oh; i++)
            {
                for (int j = 0; j < ow; j++)
                {
                    double sum = 0.0;
                    /* Top-left corner of the window in input coordinates;
                     * negative inside the padding. */
                    const int posh = i * u - p;
                    const int posw = j * v - q;

                    for (int cNum = 0; cNum < c; cNum++)
                    {
                        const float *in_c = pin + ((long long)nNum * c + cNum) * in_plane;
                        const float *wei_c = wei_k + cNum * wei_plane;

                        for (int khNum = 0; khNum < r; khNum++)
                        {
                            const int posh_ori = posh + khNum;
                            if (posh_ori < 0 || posh_ori >= h)
                                continue;

                            for (int kwNum = 0; kwNum < s; kwNum++)
                            {
                                const int posw_ori = posw + kwNum;
                                if (posw_ori < 0 || posw_ori >= w)
                                    continue;

                                sum += (double)in_c[(long long)posh_ori * w + posw_ori] *
                                       (double)wei_c[khNum * s + kwNum];
                            }
                        }
                    }

                    out_k[(long long)i * ow + j] = (float)sum;
                }
            }
        }
    }
}
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#!/bin/bash
# Rebuild the demo from scratch, then run it on a set of convolution sizes.
make clean
# `make clean` removed the old binary, so if the build fails there is nothing
# to run; stop instead of emitting one error per invocation below.
make || exit 1

# The eleven arguments presumably follow the problem_t field order
# (n c h w k r s u v p q) -- confirm against the demo's argument parsing.
./conv2ddemo 128 3 225 225 32 3 3 2 2 0 0
./conv2ddemo 49 128 35 35 384 3 3 2 2 0 0
./conv2ddemo 16 128 105 105 256 3 3 2 2 0 0
./conv2ddemo 128 3 230 230 64 7 7 2 2 0 0
./conv2ddemo 2 3 838 1350 64 7 7 2 2 0 0
./conv2ddemo 256 256 28 28 256 2 2 2 2 0 0
./conv2ddemo 128 3 225 225 32 3 3 1 1 0 0
Oops, something went wrong.