diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 95208bae3..c9399e819 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -234,6 +234,7 @@ def make_train (iter_index, training_reuse_stop_batch = 400000 training_reuse_start_lr = jdata.get('training_reuse_start_lr', 1e-4) + training_reuse_decay_steps = jdata.get('training_reuse_decay_steps', None) training_reuse_start_pref_e = jdata.get('training_reuse_start_pref_e', 0.1) training_reuse_start_pref_f = jdata.get('training_reuse_start_pref_f', 100) model_devi_activation_func = jdata.get('model_devi_activation_func', None) @@ -382,6 +383,7 @@ def make_train (iter_index, if jinput['loss'].get('start_pref_f') is not None: jinput['loss']['start_pref_f'] = training_reuse_start_pref_f jinput['learning_rate']['start_lr'] = training_reuse_start_lr + jinput['learning_rate']['decay_steps'] = training_reuse_decay_steps for ii in range(numb_models) :