relu.cu
#include "relu.hh"

// Element-wise forward pass: A[i] = max(Z[i], 0).
// Z is a flattened Z_x_dim x Z_y_dim array; one thread handles one element.
__global__ void reluActivationForward(float* Z, float* A,
                                       int Z_x_dim, int Z_y_dim)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < Z_x_dim * Z_y_dim)
    {
        A[idx] = fmaxf(Z[idx], 0.0f);
    }
}

// Element-wise backward pass: dZ[i] = dA[i] where Z[i] > 0, otherwise 0.
__global__ void reluActivationBackprop(float* Z, float* dA, float* dZ,
                                        int Z_x_dim, int Z_y_dim)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < Z_x_dim * Z_y_dim)
    {
        if (Z[idx] > 0)
        {
            dZ[idx] = dA[idx];
        }
        else
        {
            dZ[idx] = 0.0f;
        }
    }
}
ReLU::ReLU(std::string name)
{
    this->name = name;
}

ReLU::~ReLU()
{
    // do nothing
}
host_vec& ReLU::forward(host_vec& Z, Shape& Z_shape)
{
    // Cache the input on host and device; the backward pass needs Z.
    this->Z = Z;
    Z_device = Z;
    this->Z_shape = Z_shape;

    // Allocate the output buffers on first use; ReLU preserves the input shape.
    if (A.size() == 0)
    {
        host_vec temp_A(Z_shape.x * Z_shape.y, 0);
        A = temp_A;
        A_device = A;
        A_shape = Z_shape;
    }

    // One thread per element, 256 threads per block.
    dim3 block_size(256);
    dim3 num_of_blocks((Z_shape.x * Z_shape.y + block_size.x - 1) / block_size.x);

    float* Z_device_ptr = thrust::raw_pointer_cast(Z_device.data());
    float* A_device_ptr = thrust::raw_pointer_cast(A_device.data());
    reluActivationForward<<<num_of_blocks, block_size>>>(Z_device_ptr, A_device_ptr,
                                                         Z_shape.x, Z_shape.y);

    // Copy the result back to the host and report the output shape to the caller.
    A = A_device;
    Z_shape = A_shape;
    return A;
}
host_vec& ReLU::backprop(host_vec& dA, float learning_rate, int mb_size)
{
    // learning_rate and mb_size are unused: ReLU has no trainable parameters.
    device_vec dA_device = dA;

    // Allocate the gradient buffer on first use.
    if (dZ.size() == 0)
    {
        host_vec temp_dZ(Z_shape.x * Z_shape.y, 0);
        dZ = temp_dZ;
        dZ_device = dZ;
        this->dZ_shape = Z_shape;
    }

    dim3 block_size(256);
    dim3 num_of_blocks((Z_shape.x * Z_shape.y + block_size.x - 1) / block_size.x);

    float* Z_device_ptr = thrust::raw_pointer_cast(Z_device.data());
    float* dA_device_ptr = thrust::raw_pointer_cast(dA_device.data());
    float* dZ_device_ptr = thrust::raw_pointer_cast(dZ_device.data());
    reluActivationBackprop<<<num_of_blocks, block_size>>>(Z_device_ptr, dA_device_ptr,
                                                          dZ_device_ptr,
                                                          Z_shape.x, Z_shape.y);

    // Copy the gradient back to the host.
    dZ = dZ_device;
    return dZ;
}
void ReLU::update_weights_bias(float lr)
{
    // ReLU has no weights or biases to update.
}
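
/*
 * Usage sketch (illustrative only), assuming relu.hh defines host_vec as
 * thrust::host_vector<float>, device_vec as thrust::device_vector<float>,
 * and Shape as a struct with int members x and y; adjust to the actual
 * declarations if they differ.
 *
 *   Shape shape;
 *   shape.x = 4;                                       // e.g. batch size
 *   shape.y = 3;                                       // e.g. units per sample
 *   host_vec Z(shape.x * shape.y, -0.5f);              // flattened pre-activations
 *
 *   ReLU relu("relu_1");
 *   host_vec& A = relu.forward(Z, shape);              // A[i] = max(Z[i], 0)
 *
 *   host_vec dA(shape.x * shape.y, 1.0f);              // upstream gradient
 *   host_vec& dZ = relu.backprop(dA, 0.01f, shape.x);  // dZ[i] = Z[i] > 0 ? dA[i] : 0
 */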