`utils.trainer.TrainLoop` runs the training loop and is compatible with `torch.distributed.run`.
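
Compatibility with `torch.distributed.run` (the `torchrun` launcher) usually means that each worker process reads its position in the job from environment variables set by the launcher, such as `RANK`, `LOCAL_RANK`, and `WORLD_SIZE`. Below is a minimal sketch of reading them, assuming the standard torchrun variable names. `DistInfo`, `get_dist_info`, and `is_main_process` are hypothetical helpers, not functions from this repository, and `TrainLoop` may handle this differently.

```python
import os
from typing import NamedTuple


class DistInfo(NamedTuple):
    # Hypothetical container (not part of this repository).
    rank: int        # global index of this process across all nodes
    local_rank: int  # index of this process on its own node (often the GPU id)
    world_size: int  # total number of processes in the job


def get_dist_info() -> DistInfo:
    """Read the process layout exported by torch.distributed.run (torchrun).

    Falls back to a single-process layout when the variables are absent,
    e.g. when the script is started with plain `python3`.
    """
    return DistInfo(
        rank=int(os.environ.get("RANK", 0)),
        local_rank=int(os.environ.get("LOCAL_RANK", 0)),
        world_size=int(os.environ.get("WORLD_SIZE", 1)),
    )


def is_main_process() -> bool:
    # Typically only rank 0 writes logs and checkpoints.
    return get_dist_info().rank == 0
```

In a real distributed run, `torch.distributed.init_process_group(backend="nccl")` falls back to its default `env://` init method, which reads these same variables together with `MASTER_ADDR` and `MASTER_PORT`.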
- Complete the `TrainSettings` (or `YourSettings`) class in `config/train.py`.
  - This settings class is compatible with both argparse and JSON.
- Complete the `load_data_from_args` function in `data/__init__.py`.
- Complete the `model` package.
- Complete the `create_model_from_config` function in `utils/initialization.py`.
- Complete the following methods of the `TrainLoop` class in `utils/trainer.py`:
  - `log_loss_dict`: logs a dict of loss values.
  - `compute_losses`: calculates `losses` from `micro_batch` and the `TrainLoop` attributes.
  - `backward_from_losses`: combines `losses` into a single `loss` and runs `loss.backward()`.
  - `__init__`: adds your extra values to the `TrainLoop` attributes, if needed.
- Complete `run/train.py` so that it is consistent with every code signature you modified.
- Dump the default train settings to a JSON file with the following command, then edit the file as needed:

      python3 -c "from config.train import TrainSettings as T; print(T().json(indent=2))" > train_config.json
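
The settings class described above is compatible with both argparse and JSON. Below is a minimal sketch of that pattern using a standard-library dataclass, assuming that JSON values act as defaults and that CLI flags passed explicitly override them. `ExampleSettings`, its fields, `to_json`, and `parse_settings` are hypothetical. The real `TrainSettings` exposes `.json(indent=2)` and may be built quite differently.

```python
import argparse
import dataclasses
import json
from dataclasses import dataclass, fields
from typing import List, Optional


@dataclass
class ExampleSettings:
    # Hypothetical fields for illustration; the real TrainSettings defines its own.
    lr: float = 1e-4
    batch_size: int = 64
    seed: int = 102
    checkpoint_path: str = "checkpoints"

    def to_json(self, indent: int = 2) -> str:
        # Analogous to `T().json(indent=2)`: serialize every field with its value.
        return json.dumps(dataclasses.asdict(self), indent=indent)


def parse_settings(argv: Optional[List[str]] = None) -> ExampleSettings:
    """Hypothetical loader: JSON file values first, then explicit CLI flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_json", default=None)
    for f in fields(ExampleSettings):
        # default=None distinguishes "flag not given" from an explicit value.
        # Note: type=bool would treat the string "False" as True, so bool
        # fields would need a custom converter.
        parser.add_argument(f"--{f.name}", type=f.type, default=None)
    args = parser.parse_args(argv)

    values = {}
    if args.config_json:
        with open(args.config_json) as fh:
            values.update(json.load(fh))  # JSON values replace the dataclass defaults
    for f in fields(ExampleSettings):
        cli_value = getattr(args, f.name)
        if cli_value is not None:
            values[f.name] = cli_value  # explicit CLI flags win over the JSON file
    return ExampleSettings(**values)
```

For example, `parse_settings(["--lr", "0.5"])` returns settings with `lr=0.5` and every other field left at its default.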
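
The `TrainLoop` hooks in the checklist above divide one step into three parts: computing named losses, reducing them to a single scalar for `loss.backward()`, and logging them. Below is a minimal sketch of the reduction and logging parts, assuming that losses arrive as a dict of scalar tensors or plain floats. `combine_losses`, the free-function `backward_from_losses`, and `LossAverager` are hypothetical stand-ins. The per-term weighting is also an assumption and is not necessarily what the repository does.

```python
from collections import defaultdict
from typing import Any, Dict, Mapping, Optional


def combine_losses(losses: Mapping[str, Any],
                   weights: Optional[Mapping[str, float]] = None) -> Any:
    """Hypothetical reduction: weighted sum of named loss terms.

    A missing weight defaults to 1.0. Works for plain floats and for
    scalar torch tensors alike, since both support `*` and `+`.
    """
    weights = weights or {}
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


def backward_from_losses(losses: Mapping[str, Any],
                         weights: Optional[Mapping[str, float]] = None) -> Any:
    # Assumes `losses` is non-empty and holds scalar tensors that require grad;
    # plain floats have no .backward(), so this step needs real tensors.
    loss = combine_losses(losses, weights)
    loss.backward()
    return loss


class LossAverager:
    """Hypothetical helper behind a log_loss_dict-style hook: running mean per loss name."""

    def __init__(self) -> None:
        self._sums: Dict[str, float] = defaultdict(float)
        self._counts: Dict[str, int] = defaultdict(int)

    def log_loss_dict(self, losses: Mapping[str, Any]) -> None:
        for name, value in losses.items():
            # float() accepts Python numbers and one-element tensors.
            self._sums[name] += float(value)
            self._counts[name] += 1

    def averages(self) -> Dict[str, float]:
        return {name: self._sums[name] / self._counts[name] for name in self._sums}
```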

After completion, you can run the training script with:

    python3 -m run.train --distributed --config_json train_config.json

@inproceedings{gong2022diffuseq,
author = {Gong, Shansan and Li, Mukai and Feng, Jiangtao and Wu, Zhiyong and Kong, Lingpeng},
booktitle = {International Conference on Learning Representations, ICLR},
title = {{DiffuSeq}: Sequence to Sequence Text Generation with Diffusion Models},
year = 2023
}