@@ -38,7 +38,8 @@ namespace {
 static TensorDim calcCol2ImOutputDim(const TensorDim &out,
                                      const TensorDim &kdim) {
 
-  return TensorDim({kdim.getFeatureLen(), out.width() * out.height()});
+  return TensorDim({kdim.getFeatureLen(), out.width() * out.height()},
+                   out.getTensorType());
 }
 
 /**
@@ -56,7 +57,10 @@ static void col2im(const Tensor &col_matrix, const TensorDim &kdim,
                    const std::array<props::Stride, CONV2D_DIM> &mstride,
                    const std::array<props::Dilation, CONV2D_DIM> &dilation,
                    Tensor &image) {
-  auto [pt, pb, pl, pr] = padding;
+  auto pt = padding[0];
+  auto pb = padding[1];
+  auto pl = padding[2];
+  auto pr = padding[3];
 
   unsigned k_height = kdim.height();
   unsigned k_width = kdim.width();
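
Note on the padding change above: the structured binding auto [pt, pb, pl, pr] = padding; is unpacked into ordinary locals, plausibly because the loop body below now lives inside a lambda and C++17 does not allow lambdas to capture structured bindings. A minimal illustration, using a hypothetical std::array that is not part of this patch:

    #include <array>

    int main() {
      std::array<unsigned, 4> padding{1, 1, 2, 2};
      auto [pt, pb, pl, pr] = padding;   // structured binding
      auto pt2 = padding[0];             // ordinary local variable
      // auto f = [&]() { return pt; };  // ill-formed in C++17: structured bindings cannot be captured
      auto g = [&]() { return pt2; };    // fine
      return static_cast<int>(g());
    }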
@@ -84,32 +88,48 @@ static void col2im(const Tensor &col_matrix, const TensorDim &kdim,
   int h_stride_end = im_eff_height - eff_k_height - pt;
   int w_stride_end = im_eff_width - eff_k_width - pl;
 
-  unsigned col_w = 0;
-  for (int hs = -pt; hs <= h_stride_end; hs += hstride) {
-    for (int ws = -pl; ws <= w_stride_end; ws += wstride) {
-      unsigned col_h = 0;
-      int patch_height_end = hs + eff_k_height;
-      int patch_width_end = ws + eff_k_width;
-      for (unsigned c = 0; c < im_channel; c++) {
-        for (int h = hs; h < patch_height_end; h += hdilation) {
-          if (h < 0 || im_height <= h) {
-            col_h += k_width;
-            continue;
-          }
-          for (int w = ws; w < patch_width_end; w += wdilation) {
-            if (w < 0 || im_width <= w) {
-              col_h++;
+  auto apply_data = [&]<typename T>(T *val) {
+    unsigned col_w = 0;
+    for (int hs = -pt; hs <= h_stride_end; hs += hstride) {
+      for (int ws = -pl; ws <= w_stride_end; ws += wstride) {
+        unsigned col_h = 0;
+        int patch_height_end = hs + eff_k_height;
+        int patch_width_end = ws + eff_k_width;
+        for (unsigned c = 0; c < im_channel; c++) {
+          for (int h = hs; h < patch_height_end; h += hdilation) {
+            if (h < 0 || im_height <= h) {
+              col_h += k_width;
               continue;
             }
-
-            float *val = image.getAddress<float>(0, c, h, w);
-            *val += col_matrix.getValue<float>(0, 0, col_h, col_w);
-            col_h++;
+            for (int w = ws; w < patch_width_end; w += wdilation) {
+              if (w < 0 || im_width <= w) {
+                col_h++;
+                continue;
+              }
+
+              val = image.getAddress<T>(0, c, h, w);
+              *val += col_matrix.getValue<T>(0, 0, col_h, col_w);
+              col_h++;
+            }
           }
         }
+        col_w++;
       }
-      col_w++;
     }
+  };
+
+  if (image.getDataType() == nntrainer::Tdatatype::FP32) {
+    float val;
+    apply_data(&val);
+  }
+#ifdef ENABLE_FP16
+  else if (image.getDataType() == nntrainer::Tdatatype::FP16) {
+    _FP16 val;
+    apply_data(&val);
+  }
+#endif
+  else {
+    throw std::runtime_error("Not supported datatype");
   }
 }
 
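
The FP32/FP16 handling introduced above follows one pattern throughout this patch: the existing loop body becomes a generic lambda with an explicit template parameter (a C++20 feature), and the tensor's runtime data type picks which instantiation gets called. A stripped-down sketch of that dispatch, with a hypothetical enum and raw buffers standing in for nntrainer's Tensor and Tdatatype:

    #include <cstddef>
    #include <cstdint>
    #include <stdexcept>

    enum class DataType { FP32, FP16 };   // stand-in for nntrainer::Tdatatype
    using half_t = std::uint16_t;         // placeholder; nntrainer uses _FP16 behind ENABLE_FP16

    // Doubles every element of a buffer whose element type is only known at runtime.
    void scale_buffer(void *data, DataType dtype, std::size_t len) {
      // The loop is written once; T is fixed at each call site.
      auto apply_data = [&]<typename T>(T *ptr) {
        for (std::size_t i = 0; i < len; ++i)
          ptr[i] = static_cast<T>(ptr[i] * 2);
      };

      if (dtype == DataType::FP32)
        apply_data(static_cast<float *>(data));
      else if (dtype == DataType::FP16)
        apply_data(static_cast<half_t *>(data));
      else
        throw std::runtime_error("Not supported datatype");
    }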
@@ -179,7 +199,10 @@ static void im2col(const Tensor &in, const TensorDim &kdim,
   // }
   */
 
-  auto [pt, pb, pl, pr] = padding;
+  auto pt = padding[0];
+  auto pb = padding[1];
+  auto pl = padding[2];
+  auto pr = padding[3];
 
   unsigned int channel = in.channel();
   int in_height = in.height();
@@ -198,46 +221,62 @@ static void im2col(const Tensor &in, const TensorDim &kdim,
   unsigned int out_width = (width - eff_k_width) / mstride[1] + 1;
 
   out.reshape(
-    TensorDim({out_height * out_width, in.channel() * k_height * k_width}));
-  float *out_data = out.getData();
-
-  int h_stride_end = height - eff_k_height - pt;
-  int w_stride_end = width - eff_k_width - pl;
-
-  /// get a patch, size of kernel
-  /// hs is height_strided, ws is width_strided
-  unsigned int owidth = out.width();
-  unsigned int base_im_w = 0;
-  for (int hs = -pt; hs <= h_stride_end; hs += mstride[0]) {
-    unsigned int base_im_h = 0;
-    int patch_height_end = eff_k_height + hs;
-    /// map the patch to a single line looping through channel
-    for (unsigned int c = 0; c < channel; ++c) {
-      for (int h = hs; h < patch_height_end; h += dilation[0]) {
-        if (h < 0 || in_height <= h) {
-          base_im_h += k_width;
-          continue;
-        }
-
-        unsigned int im_w = base_im_w;
-        for (int ws = -pl; ws <= w_stride_end; ws += mstride[1]) {
-          unsigned int im_h = base_im_h;
-          int patch_width_end = eff_k_width + ws;
+    TensorDim({out_height * out_width, in.channel() * k_height * k_width},
+              in.getTensorType()));
+
+  auto apply_data = [&]<typename T>(T *out_data) {
+    int h_stride_end = height - eff_k_height - pt;
+    int w_stride_end = width - eff_k_width - pl;
+
+    /// get a patch, size of kernel
+    /// hs is height_strided, ws is width_strided
+    unsigned int owidth = out.width();
+    unsigned int base_im_w = 0;
+    for (int hs = -pt; hs <= h_stride_end; hs += mstride[0]) {
+      unsigned int base_im_h = 0;
+      int patch_height_end = eff_k_height + hs;
+      /// map the patch to a single line looping through channel
+      for (unsigned int c = 0; c < channel; ++c) {
+        for (int h = hs; h < patch_height_end; h += dilation[0]) {
+          if (h < 0 || in_height <= h) {
+            base_im_h += k_width;
+            continue;
+          }
 
-        for (int w = ws; w < patch_width_end; w += dilation[1]) {
-          if (w < 0 || in_width <= w) {
+          unsigned int im_w = base_im_w;
+          for (int ws = -pl; ws <= w_stride_end; ws += mstride[1]) {
+            unsigned int im_h = base_im_h;
+            int patch_width_end = eff_k_width + ws;
+
+            for (int w = ws; w < patch_width_end; w += dilation[1]) {
+              if (w < 0 || in_width <= w) {
+                im_h++;
+                continue;
+              }
+              out_data[im_w * owidth + im_h] = in.getValue<T>(0, c, h, w);
               im_h++;
-              continue;
             }
-            out_data[im_w * owidth + im_h] = in.getValue<float>(0, c, h, w);
-            im_h++;
+            im_w++;
           }
-          im_w++;
+          base_im_h += k_width;
         }
-        base_im_h += k_width;
       }
+      base_im_w += out_width;
     }
-    base_im_w += out_width;
+  };
+
+  if (out.getDataType() == nntrainer::Tdatatype::FP32) {
+    float *out_data = out.getData<float>();
+    apply_data(out_data);
+  }
+#ifdef ENABLE_FP16
+  else if (out.getDataType() == nntrainer::Tdatatype::FP16) {
+    _FP16 *out_data = out.getData<_FP16>();
+    apply_data(out_data);
+  }
+#endif
+  else {
+    throw std::runtime_error("Not supported datatype");
   }
 }
 
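
For reference on the reshape target above: each row of the column matrix holds one flattened patch, so out becomes (out_height * out_width) x (channel * k_height * k_width) and now carries the input's tensor type. A worked example with made-up sizes: a 3-channel 5x5 input, a 3x3 kernel, stride 1, padding 1 on every side, and dilation 1 give eff_k_height = eff_k_width = 3 and out_height = out_width = (5 + 1 + 1 - 3) / 1 + 1 = 5, so the matrix is reshaped to {25, 27}.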
@@ -279,9 +318,11 @@ void Conv2DLayer::finalize(InitLayerContext &context) {
   auto &dilation =
     std::get<std::array<props::Dilation, CONV2D_DIM>>(conv_props);
 
-  TensorDim kernel_dim =
-    TensorDim(filter_size, in_dim.channel(), kernel_size[0], kernel_size[1]);
-  TensorDim bias_dim = TensorDim(1, filter_size, 1, 1);
+  auto in_t_type = in_dim.getTensorType();
+  in_t_type.data_type = context.getWeightDataType();
+  TensorDim kernel_dim = TensorDim(filter_size, in_dim.channel(),
+                                   kernel_size[0], kernel_size[1], in_t_type);
+  TensorDim bias_dim = TensorDim(1, filter_size, 1, 1, in_t_type);
 
   padding = std::get<props::Padding2D>(conv_props)
               .compute(in_dim, kernel_dim, {stride[0], stride[1]},
@@ -309,6 +350,7 @@ void Conv2DLayer::finalize(InitLayerContext &context) {
   out_dim.channel(filter_size);
   out_dim.height((eff_in_height - eff_k_height) / stride[0] + 1);
   out_dim.width((eff_in_width - eff_k_width) / stride[1] + 1);
+  out_dim.setTensorType(in_dim.getTensorType());
   context.setOutputDimensions({out_dim});
 
   NNTR_THROW_IF(eff_in_height < kernel_size[0] || eff_in_width < kernel_size[1],