
Trainer


Introduction

[Figure: intuitive sketch of the two most used training strategies, BERT-style masking and forecasting.]

The training can be launched with:

srun python atmorep/core/train.py

Due to memory constraints, the full configuration with 6 fields requires different settings on the Juelich (JSC) systems. To train the coupled system we advise using:

srun python atmorep/core/train_multi.py

See the glossary page for details on each of the options.

Training strategies

The supported training strategies are declared in atmorep/training/bert.py:


if not BERT_strategy :
  BERT_strategy = cf.BERT_strategy

if BERT_strategy == 'BERT' :
  bert_f = prepare_batch_BERT_field
elif BERT_strategy == 'forecast' :
  bert_f = prepare_batch_BERT_forecast_field
elif BERT_strategy == 'temporal_interpolation' :
  bert_f = prepare_batch_BERT_temporal_field
else :
  assert False   # unsupported BERT_strategy

An intuitive sketch of the two most used strategies is shown in the figure above. Each supported strategy is briefly described below.

BERT

This is an adaptation of the BERT-style masked-token training protocol. Tokens are masked randomly within the loaded source cube according to the masking rates defined in the cf.fields[field][-1] parameter. The parameters of interest are [total masking rate, rate masking, rate noising, rate for multi-res distortion], which control:

  • the total masking rate controls the overall fraction of tokens in the source cube that are masked or distorted
  • the rate masking parameter controls the fraction of those tokens that are fully masked. Very high masking ratios, e.g. 90%, have been shown to produce more robust results.
  • the rate noising parameter controls the fraction of masked tokens that are masked using Gaussian noise
  • the rate for multi-res distortion parameter controls the fraction of masked tokens that are distorted by coarsening their resolution, up to cf.BERT_mr_max (see the sketch after this list)
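
To make the role of the four rates concrete, here is a minimal, hypothetical sketch of how the tokens of a source cube could be partitioned among the three distortion modes. Variable names, shapes and the rate values are illustrative, not the actual bert.py implementation:

import numpy as np

rng = np.random.default_rng(0)
num_tokens = 12 * 6 * 12                       # tokens in the source cube (time x lat x lon)
rate_total, rate_mask, rate_noise, rate_mr = 0.9, 0.7, 0.15, 0.15

# select the tokens that are masked or distorted at all
idx = rng.permutation(num_tokens)[: int(rate_total * num_tokens)]

# split the selected tokens among the three distortion modes
n_mask, n_noise = int(rate_mask * len(idx)), int(rate_noise * len(idx))
idx_mask, idx_noise, idx_mr = np.split(idx, [n_mask, n_mask + n_noise])
# idx_mask  -> replaced by the mask token
# idx_noise -> replaced by Gaussian noise
# idx_mr    -> replaced by a coarsened (multi-res) copy, up to cf.BERT_mr_max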

Example:

 cf.BERT_strategy = 'BERT'     
 cf.BERT_fields_synced = False   # apply synchronized / identical masking to all fields 
                                 # (fields need to have same BERT params for this to have effect)
 cf.BERT_mr_max = 2              # maximum reduction rate for resolution

This option is the default option used for training the AtmoRep core model.

Forecast

The forecast option optimises training for the forecasting task and has been used to fine-tune AtmoRep for forecasting. It is a special case of the BERT strategy in which the last time step(s) of the source cube are completely masked and must be predicted.

Parameters:

  • number of loaded tokens in time: cf.fields[i][3][0] (see the glossary for details)
  • number of forecasted tokens: cf.forecast_num_tokens (default: 2)

Important: the first parameter controls the total number of tokens loaded into the source; the second controls the number of masked tokens in time within the loaded cube. No roll-out is implemented at the moment, so forecast_num_tokens cannot be larger than cf.fields[i][3][0]! A minimal masking sketch follows below.
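
A minimal sketch of the forecast-style masking, assuming a boolean mask over a cube of (time, lat, lon) tokens; the spatial token counts here are illustrative, not the actual bert.py code:

import torch

num_tokens_time = 12          # cf.fields[i][3][0]
forecast_num_tokens = 2       # cf.forecast_num_tokens
assert forecast_num_tokens <= num_tokens_time   # no roll-out is implemented

mask = torch.zeros(num_tokens_time, 6, 12, dtype=torch.bool)
mask[-forecast_num_tokens:] = True   # completely mask the last time step(s)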

Temporal Interpolation

The temporal_interpolation option optimises training for the temporal interpolation task and has been used to fine-tune AtmoRep for temporal interpolation. It is again a special case of the BERT strategy, in which the intermediate tokens in time are masked. The masked time index is computed as:

idx_time_mask = int( np.floor(num_tokens[0] / 2.))  # TODO: masking of multiple time steps

where num_tokens[0] is the number of loaded tokens along the time dimension (12 by default); a short sketch of the resulting mask is given after the parameters below.
Important: be careful if you use this option; it might be outdated!

Parameters:

  • number of loaded tokens in time: cf.fields[i][3][0]
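
A short sketch of the resulting mask, mirroring the idx_time_mask line above; the shapes are illustrative:

import numpy as np

num_tokens = [12, 6, 12]                            # tokens per (time, lat, lon)
idx_time_mask = int(np.floor(num_tokens[0] / 2.))   # -> 6, the central time step
mask = np.zeros(num_tokens, dtype=bool)
mask[idx_time_mask] = True                          # only the central time step is masked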

Supported Losses

The loss computation is modular and defined by the following parameter in the config:

cf.losses = ['mse', 'stats']

The final loss in this case is the sum of the two terms. The supported losses are defined in trainer.py, in def loss(self, preds, batch_idx = 0).
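
As a rough illustration, the modular combination can be thought of as follows. This is a hypothetical simplification, assuming preds stacks the ensemble members along the first dimension; the actual logic lives in Trainer.loss in trainer.py:

import torch

def combined_loss(preds, target, losses=('mse', 'stats')):
  # sum one term per entry in cf.losses
  loss = torch.tensor(0.)
  if 'mse' in losses:
    loss = loss + torch.nn.functional.mse_loss(preds.mean(dim=0), target)
  # ... analogous terms for 'mse_ensemble', 'stats', 'stats_area', 'crps' ...
  return loss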

MSE Loss

  • option: mse
  • Description: the standard mean squared error between prediction and target.

MSE ensemble Loss

  • option: mse_ensemble
  • Description: the MSE loss computed for each ensemble member separately and then averaged over the members (see the sketch below).
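
A minimal sketch, assuming preds stacks the ensemble members along the first dimension:

import torch

def mse_ensemble(preds, target):
  # MSE per ensemble member, then averaged over the members
  return torch.stack([torch.nn.functional.mse_loss(p, target) for p in preds]).mean()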

Statistical Loss

  • option: stats
  • Description: statistical loss, a generalized cross-entropy loss for continuous distributions. Refer to the AtmoRep paper for a detailed explanation.

Statistical Loss - Area

  • option: stats_area
  • Description: based on torch.special.erf (the error function); see the implementation in trainer.py.

CRPS Loss

  • option: crps
  • Description: loss based on the continuous ranked probability score (CRPS); see Eq. A2 in S. Rasp and S. Lerch, "Neural networks for postprocessing ensemble weather forecasts", Monthly Weather Review, 146(11):3885–3900, 2018. A sketch is given below.
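
For reference, the closed-form CRPS of a Gaussian N(mu, sigma) evaluated at an observation y, as given in Eq. A2 of Rasp & Lerch (2018), can be sketched as follows; assuming mu and sigma are estimated from the ensemble:

import math
import torch

def crps_gauss(mu, sigma, y):
  # CRPS of N(mu, sigma) at observation y, Eq. A2 of Rasp & Lerch (2018)
  z = (y - mu) / sigma
  pdf = torch.exp(-0.5 * z**2) / math.sqrt(2. * math.pi)   # standard normal pdf
  cdf = 0.5 * (1. + torch.erf(z / math.sqrt(2.)))          # standard normal cdf
  return sigma * (z * (2. * cdf - 1.) + 2. * pdf - 1. / math.sqrt(math.pi))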

Output

The output of train.py is:

  1. A new directory inside atmorep/results, named with the wandb_id (Weights & Biases ID) of the run. The folder contains a JSON file with the configuration.
  2. A new directory inside atmorep/models, also named with the wandb_id of the run, containing the model weights.

(JSC specific) Slurm

Depending on the script you use, when you launch runs on Slurm you might get three additional files:

  3. An output file output/output_XXXXXX.txt with the log, where XXXXXX is the Slurm job ID (remember to note it down when you launch your job!). The wandb_id is saved within this file; just grep for it.
  4. Two log files, logs/*XXXXX*.err and logs/*XXXXX*.out, again named with the Slurm job ID and containing the errors and the output from the shell scripts.