
Commit eb7cf07

[Mixed Precision] Fix mixed precision to use Tensor V2
This PR includes fixes to use TensorV2.

Resolves:

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
1 parent a4a3750 commit eb7cf07

31 files changed (+139, -148 lines)

api/ccapi/include/model.h

Lines changed: 8 additions & 7 deletions

@@ -188,13 +188,14 @@ class Model {
    * @details This function accepts vector of properties in the format -
    * { std::string property_name, void * property_val, ...}
    */
-  virtual int train(const std::vector<std::string> &values = {},
-                    std::function<bool(void *)> stop_cb =
-                      [](void *stop_user_data) { return false; },
-                    void *stop_user_data = nullptr,
-                    std::function<void(void *)> epoch_complete_cb =
-                      [](void *epoch_user_data) { return false; },
-                    void *epoch_user_data = nullptr) = 0;
+  virtual int train(
+    const std::vector<std::string> &values = {},
+    std::function<bool(void *)> stop_cb =
+      [](void *stop_user_data) { return false; },
+    void *stop_user_data = nullptr,
+    std::function<void(void *)> epoch_complete_cb =
+      [](void *epoch_user_data) { return false; },
+    void *epoch_user_data = nullptr) = 0;
 
   /**
    * @brief Run Model train with callback function by user
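For reference, a minimal, hypothetical caller of the reformatted train() overload above. Only the callback signatures are taken from the declaration; the property strings, include path, and helper name are illustrative assumptions, not part of this commit.

#include <functional>
#include <memory>
#include <string>
#include <vector>
#include <model.h> // assumed ccapi include path

int run_training(std::unique_ptr<ml::train::Model> &model) {
  bool stop_requested = false;

  // stop_cb returns true to stop training; epoch_complete_cb runs after each
  // epoch. Both receive the user-data pointers passed right after them.
  return model->train(
    {"epochs=10", "batch_size=32"}, // example property strings
    [](void *stop_user_data) {
      return *static_cast<bool *>(stop_user_data); // stop when flag is set
    },
    &stop_requested,
    [](void *epoch_user_data) { /* e.g. log or checkpoint per epoch */ },
    nullptr);
}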

meson.build

Lines changed: 11 additions & 7 deletions

@@ -70,14 +70,18 @@ warning_c_flags = [
 
 arch = host_machine.cpu_family()
 
+target = target_machine.cpu_family()
+
 if get_option('enable-avx')
-  extra_defines += '-DUSE_AVX=1'
-  if get_option('platform') == 'tizen'
-    add_project_arguments(['-mavx2'], language: ['c','cpp'])
-  else
-    add_project_arguments(['-march=native'], language: ['c','cpp'])
-  endif
-  message('-march=native added for AVX hardware acceleration.')
+  if get_option('platform') != 'android'
+    if target == 'x86_64' or target == 'x86'
+      extra_defines += '-DUSE_AVX=1'
+      add_project_arguments(['-march=native'], language: ['c','cpp'])
+      add_project_arguments(['-mavx2'], language: ['c','cpp'])
+      message('-march=native added for AVX hardware acceleration.')
+    endif
+    message('This arch does not support avx2')
+  endif
 endif
 
 if get_option('enable-fp16')
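To illustrate how a -DUSE_AVX=1 define produced by the build logic above is typically consumed, here is a hedged C++ sketch; the function is hypothetical and not part of this commit, and it assumes -mavx2 is in effect on the AVX path.

#include <cstddef>
#ifdef USE_AVX
#include <immintrin.h>
#endif

// Hypothetical helper: element-wise add with an AVX2 fast path and a scalar
// fallback used on non-x86 targets (where USE_AVX is never defined above).
void add_vectors(const float *a, const float *b, float *out, std::size_t n) {
  std::size_t i = 0;
#ifdef USE_AVX
  // Process 8 floats per iteration (requires -mavx2).
  for (; i + 8 <= n; i += 8) {
    __m256 va = _mm256_loadu_ps(a + i);
    __m256 vb = _mm256_loadu_ps(b + i);
    _mm256_storeu_ps(out + i, _mm256_add_ps(va, vb));
  }
#endif
  // Scalar tail (and the whole loop when AVX is unavailable).
  for (; i < n; ++i)
    out[i] = a[i] + b[i];
}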

nntrainer/app_context.cpp

Lines changed: 3 additions & 1 deletion

@@ -559,6 +559,7 @@ AppContext::registerPluggableFromDirectory(const std::string &base_path) {
   struct dirent *entry;
 
   std::vector<int> keys;
+
   while ((entry = readdir(dir)) != NULL) {
     if (endswith(entry->d_name, solib_suffix)) {
       if (endswith(entry->d_name, layerlib_suffix)) {
@@ -581,7 +582,8 @@ AppContext::registerPluggableFromDirectory(const std::string &base_path) {
     }
   }
 
-  closedir(dir);
+  if (dir != NULL)
+    closedir(dir);
 
   return keys;
 }
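The guard matters because opendir() can return NULL (for example, when the plugin directory does not exist) and closing a NULL directory stream is undefined behavior. A minimal stand-alone sketch of the pattern, independent of the surrounding registerPluggableFromDirectory() logic, with a hypothetical function name:

#include <dirent.h>
#include <cstdio>

void list_plugins(const char *base_path) {
  DIR *dir = opendir(base_path); // may be NULL if the path is missing
  if (dir != NULL) {
    struct dirent *entry;
    while ((entry = readdir(dir)) != NULL)
      std::printf("%s\n", entry->d_name);
  }
  if (dir != NULL) // closedir(NULL) is undefined behavior
    closedir(dir);
}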

nntrainer/graph/graph_core.cpp

Lines changed: 1 addition & 1 deletion

@@ -98,7 +98,7 @@ void GraphCore::topologicalSort() {
   if (Sorted.size() != node_list.size())
     throw std::runtime_error("Internal error in topologicalSort");
   unsigned int idx = 0;
-  for (auto n : Sorted) {
+  for (auto &n : Sorted) {
     sorted_node_map[n->getName()] = idx;
     idx++;
   }

nntrainer/graph/network_graph.cpp

Lines changed: 5 additions & 8 deletions

@@ -341,6 +341,9 @@ void NetworkGraph::applyGradients(
   /**
    * @note the weights whose gradient are to be clipped by global norm will
    * be clipped at once at the end of iteration and applied then.
+   * For those weights where mixed precision is used, their gradient
+   * updates might be delayed until they confirm whether their loss scales
+   * are appropriate.
    */
   continue;
 }
@@ -438,7 +441,7 @@ bool NetworkGraph::backwarding(
    */
   float scale = (*iter_)->getRunContext().getLossScale();
 
-  NNTR_THROW_IF(scale == 1.0f, std::invalid_argument)
+  NNTR_THROW_IF(scale - 1.0f < 10e-6, std::invalid_argument)
     << "Loss Scale Factor is 1.0f";
 
   float s = scale > 1.5f ? scale * 0.5f : 1.0f;
@@ -487,18 +490,12 @@ bool NetworkGraph::backwarding(
     }
   }
   /** apply the gradient with the above global norm */
-  std::cout << "======================================= update gradient "
-            << std::endl;
   for (auto w : lazy_weights) {
-    std::cout << w->getName() << " : ";
     lazy_apply_grad_op(*w, iteration);
   }
   nan_count++;
 
-  std::cout << "====================================== update gradient finished"
-            << std::endl;
   /** @todo : handle as property : growth_interval : default --> 2000 */
-
   if (nan_count > 2000) {
     float scale = (*iter_)->getRunContext().getLossScale();
     /** @todo growth_factor : default --> 2.0 */
@@ -1647,7 +1644,7 @@ void NetworkGraph::requestOptimizerVariable(
     w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables(
       dims, w->getName(), ":opt", TensorLifespan::MAX_LIFESPAN,
       w->isGradientClipByGlobalNorm(), w->isMixedPrecision(),
-      Tensor::Initializer::ZEROS));
+      Initializer::ZEROS));
   }
 }
}
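The hunks above amount to a dynamic loss-scaling policy: when a non-finite gradient is detected the update is skipped and the scale is halved, and after a run of successful iterations (the growth_interval @todo, 2000 here) the scale grows again (growth_factor @todo, 2.0). A compact, hypothetical sketch of that policy follows; the struct and member names are illustrative, not NetworkGraph members, and the initial scale is an assumption.

// Stand-alone sketch of the dynamic loss-scale policy implied by the diff
// above: halve on overflow, grow after a stable interval.
struct LossScalePolicy {
  float scale = 65536.0f;              // initial loss scale (assumed)
  unsigned int good_steps = 0;         // iterations since the last overflow
  unsigned int growth_interval = 2000; // matches the @todo default above
  float growth_factor = 2.0f;          // matches the @todo default above

  // Call once per iteration; 'grad_has_nan' mirrors backwarding() returning
  // false. Returns true when the gradients may be applied this iteration.
  bool update(bool grad_has_nan) {
    if (grad_has_nan) {
      // Same reduction as in backwarding() above: s = scale > 1.5f ? scale*0.5f : 1.0f
      scale = scale > 1.5f ? scale * 0.5f : 1.0f;
      good_steps = 0;
      return false; // skip the update and redo the backward pass
    }
    if (++good_steps > growth_interval) {
      scale *= growth_factor;
      good_steps = 0;
    }
    return true;
  }
};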

nntrainer/graph/network_graph.h

Lines changed: 12 additions & 2 deletions

@@ -58,7 +58,12 @@ class NetworkGraph {
   /**
    * @brief Constructor of NeuralNetwork Graph Class
    * @param[in] enable_swap enable memory swap for tensor
+   * @param[in] mode execution mode (default ExecutionMode::TRAIN)
    * @param[in] swap_path memory swap file path when the swap is enabled
+   * @param[in] tensor_format define tensor format. One of NCHW and NHWC
+   * (default NCHW)
+   * @param[in] tensor_type weight type and activation type (default
+   * FP32-FP32)
    */
   NetworkGraph(bool enable_swap, ExecutionMode mode = ExecutionMode::TRAIN,
                const std::string &swap_path = "", unsigned int lookahead = 0,
@@ -207,8 +212,12 @@ class NetworkGraph {
   /**
    * @brief backwarding the network graph
    * @param[in] iteration current iteration number
+   * @param[in] forwarding_op operation for the forwarding
    * @param[in] backwarding_op operation for the backwarding
-   * @param[in] apply_grad_clip_op operation for applying the clip gradients
+   * @param[in] lazy_apply_grad_op operation for applying the lazy gradients
+   * @retval ret false when the gradient has a NaN value in mixed precision
+   * training. If so, we need to control the loss scale factor and compute
+   * the derivatives again.
    */
   bool backwarding(
     int iteration,
@@ -496,7 +505,8 @@ class NetworkGraph {
   std::unordered_map<std::string, int>
     profile_keys; /**< profile keys based on the layer type */
   std::vector<Weight *>
-    lazy_weights; /**< weights with global norm based clipping enabled */
+    lazy_weights; /**< weights with delayed grad update, e.g., gradient
+                       clipping, loss scaling */
  bool is_clip_grad;
 
  unsigned int nan_count;
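The new @retval documentation implies the training loop must react to a false return: the loss scale has already been reduced inside backwarding(), so the forward/backward pass is simply repeated. A hedged caller-side sketch; the helper and its retry limit are placeholders, not NeuralNetwork API.

#include <functional>

// Hypothetical training-loop helper: 'backward_step' stands for a call into
// NetworkGraph::backwarding(iteration, ...) and returns its bool result.
bool run_backward_with_retry(const std::function<bool(int)> &backward_step,
                             int iteration, int max_retry = 3) {
  for (int attempt = 0; attempt < max_retry; ++attempt) {
    if (backward_step(iteration))
      return true; // gradients were finite and have been applied
    // false: a NaN/Inf gradient was seen; the loss scale was lowered inside
    // backwarding(), so run the pass again with the smaller scale.
  }
  return false; // still overflowing after max_retry attempts
}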

nntrainer/layers/bn_layer.cpp

Lines changed: 9 additions & 9 deletions

@@ -118,12 +118,12 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
     1.0f, bias_decay, "beta", true);
 
   wt_idx[BNParams::mu_b] =
-    context.requestTensor(dim, "moviing_mean_backup", Tensor::Initializer::NONE,
-                          false, TensorLifespan::ITERATION_LIFESPAN);
+    context.requestTensor(dim, "moviing_mean_backup", Initializer::NONE, false,
+                          TensorLifespan::ITERATION_LIFESPAN);
 
-  wt_idx[BNParams::var_b] = context.requestTensor(
-    dim, "moviing_variance_backup", Tensor::Initializer::NONE, false,
-    TensorLifespan::ITERATION_LIFESPAN);
+  wt_idx[BNParams::var_b] =
+    context.requestTensor(dim, "moviing_variance_backup", Initializer::NONE,
+                          false, TensorLifespan::ITERATION_LIFESPAN);
 
   /**
    * caches the deviation -> input - avg(input)
@@ -137,8 +137,8 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
   }
 
   wt_idx[BNParams::deviation] =
-    context.requestTensor(in_dim_, "deviation", Tensor::Initializer::NONE,
-                          false, TensorLifespan::ITERATION_LIFESPAN);
+    context.requestTensor(in_dim_, "deviation", Initializer::NONE, false,
+                          TensorLifespan::ITERATION_LIFESPAN);
   /** caches the inverse standard deviation */
   wt_idx[BNParams::invstd] =
     context.requestTensor(dim, "invstd", Initializer::NONE, false,
@@ -150,8 +150,8 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
    * as the output of this layer need not be stored all the time.
    */
   wt_idx[BNParams::t_full] =
-    context.requestTensor(in_dim_, "tensor_full", Tensor::Initializer::NONE,
-                          false, TensorLifespan::CALC_DERIV_LIFESPAN);
+    context.requestTensor(in_dim_, "tensor_full", Initializer::NONE, false,
+                          TensorLifespan::CALC_DERIV_LIFESPAN);
   /**
    * caches variance + epsilon as well.
    */

nntrainer/layers/conv2d_layer.cpp

Lines changed: 2 additions & 0 deletions

@@ -242,6 +242,8 @@ static void im2col(const Tensor &in, const TensorDim &kdim,
   unsigned int base_im_h = 0;
   int patch_height_end = eff_k_height + hs;
   /// map the patch to a single line looping through channel
+  // We need to optimize this padding & copy. May be use multi threads, or
+  // SIMD
   for (unsigned int c = 0; c < channel; ++c) {
     for (int h = hs; h < patch_height_end; h += dilation[0]) {
       if (h < 0 || in_height <= h) {

nntrainer/layers/layer_context.h

Lines changed: 5 additions & 1 deletion

@@ -50,6 +50,10 @@ class InitLayerContext {
    * @param name name
    * @param prefix_ prefix
    * @param max_norm max norm
+   * @param tensor_type array including tensor format and weight, activation
+   * type.
+   * @param loss_scale loss scale value for mixed precision training
+   * @param mode execution mode.
    */
   InitLayerContext(
     const std::vector<TensorDim> &dim,
@@ -220,7 +224,7 @@ class InitLayerContext {
    * start from 0 and will always be incremental.
    */
  unsigned int requestWeight(const TensorDim &dim, const TensorDim &dim_g,
-                             const Tensor::Initializer init,
+                             const Initializer init,
                             const WeightRegularizer reg, const float reg_const,
                             const float decay, const std::string &name,
                             bool trainable = true, unsigned int out_axis = 3) {
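Most of the per-layer edits in this commit are the same mechanical rename: the nested Tensor::Initializer enum becomes the free-standing TensorV2 Initializer. A hedged sketch of how a layer's finalize() requests a temporary tensor with the new spelling; the include paths, tensor name, and dimensions are illustrative assumptions only.

#include <layer_context.h> // assumed include path
#include <tensor_dim.h>    // assumed include path

// Hypothetical finalize() fragment; a real layer stores the index in wt_idx[].
void finalize_example(nntrainer::InitLayerContext &context) {
  nntrainer::TensorDim dim(1, 1, 1, 8); // made-up shape
  unsigned int idx = context.requestTensor(
    dim, "scratch", nntrainer::Initializer::NONE,
    false, // same boolean flag position as in the calls in this commit
    nntrainer::TensorLifespan::ITERATION_LIFESPAN);
  (void)idx;
}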

nntrainer/layers/lstm.cpp

Lines changed: 12 additions & 12 deletions

@@ -512,16 +512,16 @@ void LSTMLayer::finalize(InitLayerContext &context) {
   const TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit,
                                    activation_tensor_type);
 
-  wt_idx[LSTMParams::hidden_state] = context.requestTensor(
-    hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN);
+  wt_idx[LSTMParams::hidden_state] =
+    context.requestTensor(hidden_state_dim, "hidden_state", Initializer::NONE,
+                          true, TensorLifespan::ITERATION_LIFESPAN);
   // cell_state_dim : [ batch_size, 1, max_timestep, unit ]
   const TensorDim cell_state_dim(batch_size, 1, max_timestep, unit,
                                  activation_tensor_type);
 
-  wt_idx[LSTMParams::cell_state] = context.requestTensor(
-    cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN);
+  wt_idx[LSTMParams::cell_state] =
+    context.requestTensor(cell_state_dim, "cell_state", Initializer::NONE, true,
+                          TensorLifespan::ITERATION_LIFESPAN);
 
   // ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
   const TensorDim ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit,
@@ -594,18 +594,18 @@ void LSTMLayer::finalize(InitLayerContext &context) {
     // reverse_ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
     const TensorDim reverse_ifgo_dim(batch_size, 1, max_timestep,
                                      NUM_GATE * unit, activation_tensor_type);
-    wt_idx[LSTMParams::reverse_ifgo] = context.requestTensor(
-      reverse_ifgo_dim, "reverse_ifgo", Tensor::Initializer::NONE, true,
-      TensorLifespan::ITERATION_LIFESPAN);
+    wt_idx[LSTMParams::reverse_ifgo] =
+      context.requestTensor(reverse_ifgo_dim, "reverse_ifgo", Initializer::NONE,
+                            true, TensorLifespan::ITERATION_LIFESPAN);
   }
 
   if (dropout_rate > epsilon) {
     // dropout_mask_dim = [ batch, 1, time_iteration, unit ]
     const TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit,
                                      activation_tensor_type);
-    wt_idx[LSTMParams::dropout_mask] = context.requestTensor(
-      dropout_mask_dim, "dropout_mask", Tensor::Initializer::NONE, false,
-      TensorLifespan::ITERATION_LIFESPAN);
+    wt_idx[LSTMParams::dropout_mask] =
+      context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE,
+                            false, TensorLifespan::ITERATION_LIFESPAN);
   }
 
   if (context.getActivationDataType() == TensorDim::DataType::FP32) {

nntrainer/layers/pooling2d_layer.cpp

Lines changed: 4 additions & 4 deletions

@@ -126,15 +126,15 @@ void Pooling2DLayer::finalize(InitLayerContext &context) {
     auto helper_dim = in_dim;
     helper_dim.setDataType(ml::train::TensorDim::DataType::FP32);
     pool_helper_idx =
-      context.requestTensor(helper_dim, "helper_idx", Tensor::Initializer::NONE,
-                            false, TensorLifespan::ITERATION_LIFESPAN);
+      context.requestTensor(helper_dim, "helper_idx", Initializer::NONE, false,
+                            TensorLifespan::ITERATION_LIFESPAN);
     pool_helper_size.resize(helper_dim.batch() * helper_dim.channel());
   } else {
     auto helper_dim = out_dim;
     helper_dim.setDataType(ml::train::TensorDim::DataType::FP32);
     pool_helper_idx =
-      context.requestTensor(helper_dim, "helper_idx", Tensor::Initializer::NONE,
-                            false, TensorLifespan::ITERATION_LIFESPAN);
+      context.requestTensor(helper_dim, "helper_idx", Initializer::NONE, false,
+                            TensorLifespan::ITERATION_LIFESPAN);
   }
 }

nntrainer/models/neuralnet.cpp

Lines changed: 0 additions & 1 deletion

@@ -1160,7 +1160,6 @@ int NeuralNetwork::train_run(
   auto epochs = getEpochs();
   ml_logd("[NNTrainer] Starts training. Current epoch: %d. Total epochs: %d.",
           epoch_idx + 1, getEpochs());
-  epoch_idx = 0;
   for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
     if (stop_cb(stop_user_data)) {
       --epoch_idx;

nntrainer/tensor/blas_interface.cpp

Lines changed: 1 addition & 2 deletions

@@ -874,8 +874,7 @@ void scopy(const unsigned int N, const float *X, const int incX, float *Y,
 #ifdef BLAS_NUM_THREADS
   openblas_set_num_threads(BLAS_NUM_THREADS);
 #endif
-  // cblas_scopy(N, (float*)(X), incX, (float*)(Y), incY);
-  // replace cblas scopy with raw temporary.
+  // cblas_scopy(N, X, incX, Y, incY);
   for (unsigned int i = 0; i < N; ++i)
     Y[i * incY] = X[i * incX];
 #else

nntrainer/tensor/char_tensor.h

Lines changed: 5 additions & 0 deletions

@@ -231,6 +231,11 @@ class CharTensor : public TensorBase {
    * @return std::string of tensor data type (QINT8)
    */
   std::string getStringDataType() const override { return "QINT8"; }
+
+  /**
+   * @copydoc Tensor::isValid()
+   */
+  bool isValid() const override { return true; }; // NYI
 };
 
 } // namespace nntrainer

nntrainer/tensor/float_tensor.cpp

Lines changed: 3 additions & 3 deletions

@@ -150,7 +150,7 @@ void FloatTensor::setZero() {
     // sscal(size(), 0, getData<float>(), 1);
     /// @note we cannot use sscal, when we set zero. if the data is inf or
     /// NaN, then the inf or NaN still remain.
-    memset(getData<float>(), 0, sizeof(float) * size());
+    memset((float *)getData(), 0, sizeof(float) * size());
   } else {
     /// @todo implement apply_i
     // apply_i<float>([](float val) -> float { return 0; });
@@ -1210,8 +1210,8 @@ void FloatTensor::apply_broadcast(
   return apply_broadcast_util(m, v_func, output, this->computeBroadcastInfo(m));
 }
 
-bool Tensor::isValid() const {
-  return is_valid(dim.getDataLen(), Tdatatype::FP32, getData<float>());
+bool FloatTensor::isValid() const {
+  return is_valid(dim.getDataLen(), Tdatatype::FP32, (float *)getData());
 }
 
 } // namespace nntrainer

nntrainer/tensor/float_tensor.h

Lines changed: 1 addition & 1 deletion

@@ -511,7 +511,7 @@ class FloatTensor : public TensorBase {
   /**
    * @copydoc Tensor::isValid()
   */
-  bool Tensor::isValid() const;
+  bool isValid() const override;
 };
 
 } // namespace nntrainer

nntrainer/tensor/half_tensor.cpp

Lines changed: 3 additions & 3 deletions

@@ -149,7 +149,7 @@ void HalfTensor::setZero() {
     // sscal(size(), 0, (_FP16 *)getData(), 1);
     /// @note we cannot use sscal, when we set zero. if the data is inf or
     /// NaN, then the inf or NaN still remain.
-    memset(getData<_FP16>(), 0, sizeof(_FP16) * size());
+    memset((_FP16 *)getData(), 0, sizeof(_FP16) * size());
   } else {
     /// @todo implement apply_i
     // apply_i<_FP16>([](_FP16 val) -> _FP16 { return 0; });
@@ -1176,8 +1176,8 @@ void HalfTensor::apply_broadcast_util(
   }
 }
 
-bool Tensor::isValid() const {
-  return is_valid(dim.getDataLen(), Tdatatype::FP16, getData<_FP16>());
+bool HalfTensor::isValid() const {
+  return is_valid(dim.getDataLen(), Tdatatype::FP16, (_FP16 *)getData());
 }
 
 } // namespace nntrainer

nntrainer/tensor/half_tensor.h

Lines changed: 1 addition & 1 deletion

@@ -502,7 +502,7 @@ class HalfTensor : public TensorBase {
   /**
    * @copydoc Tensor::isValid()
   */
-  bool Tensor::isValid() const;
+  bool isValid() const override;
 };
 
 } // namespace nntrainer
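The last four files move isValid() from a single Tensor method to per-type overrides on the TensorV2 class hierarchy (FloatTensor, HalfTensor, and a stubbed CharTensor). A simplified, hypothetical sketch of that virtual-dispatch shape, independent of the real TensorBase interface:

#include <cmath>
#include <vector>

// Stand-in for the TensorBase/FloatTensor/CharTensor split above; the real
// classes carry far more state, this only shows the isValid() hook.
struct TensorBaseSketch {
  virtual ~TensorBaseSketch() = default;
  virtual bool isValid() const = 0; // each data type validates its own buffer
};

struct FloatTensorSketch : TensorBaseSketch {
  std::vector<float> data;
  bool isValid() const override {
    for (float v : data)
      if (!std::isfinite(v)) // reject NaN/Inf, the check is_valid() presumably performs
        return false;
    return true;
  }
};

struct CharTensorSketch : TensorBaseSketch {
  bool isValid() const override { return true; } // NYI, mirroring the stub above
};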
