From e4d311eeb0bd52accce03054a67680304b9ab1cb Mon Sep 17 00:00:00 2001 From: MaxMax2016 <525942103@qq.com> Date: Mon, 12 Jun 2023 21:51:17 +0800 Subject: [PATCH] steps for save --- README.md | 4 ++-- configs/maxgan.yaml | 4 ++-- utils/train.py | 26 +++++++++++++------------- utils/writer.py | 4 ++-- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 1590029..29c9918 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Necessary pre-processing: - 1 accompaniment separation - 2 band extension - 3 sound quality improvement -- 4 cut audio, less than 30 seconds for whisper💗 +- 4 cut audio, less than 30 seconds for whisper then put the dataset into the dataset_raw directory according to the following file structure ```shell @@ -88,7 +88,7 @@ dataset_raw > python prepare/preprocess_train.py - 8, training file debugging - > python prepare/preprocess_zzz.py + > python prepare/preprocess_zzz.py -c configs/maxgan.yaml ```shell data_svc/ diff --git a/configs/maxgan.yaml b/configs/maxgan.yaml index 9694d3e..0702345 100644 --- a/configs/maxgan.yaml +++ b/configs/maxgan.yaml @@ -52,8 +52,8 @@ dist_config: ############################# log: info_interval: 100 - eval_interval: 1 - save_interval: 5 + eval_interval: 1000 + save_interval: 1000 num_audio: 6 pth_dir: 'chkpt' log_dir: 'logs' diff --git a/utils/train.py b/utils/train.py index ca6f4af..78a0814 100644 --- a/utils/train.py +++ b/utils/train.py @@ -204,16 +204,16 @@ def train(rank, args, chkpt_path, hp, hp_str): with torch.no_grad(): validate(hp, args, model_g, model_d, valloader, stft, writer, step, device) - if rank == 0 and epoch % hp.log.save_interval == 0: - save_path = os.path.join(pt_dir, '%s_%04d.pt' - % (args.name, epoch)) - torch.save({ - 'model_g': (model_g.module if args.num_gpus > 1 else model_g).state_dict(), - 'model_d': (model_d.module if args.num_gpus > 1 else model_d).state_dict(), - 'optim_g': optim_g.state_dict(), - 'optim_d': optim_d.state_dict(), - 'step': step, - 'epoch': epoch, - 'hp_str': hp_str, - }, save_path) - logger.info("Saved checkpoint to: %s" % save_path) + if rank == 0 and step % hp.log.save_interval == 0: + save_path = os.path.join(pth_dir, '%s_%08d.pt' + % (args.name, step)) + torch.save({ + 'model_g': (model_g.module if args.num_gpus > 1 else model_g).state_dict(), + 'model_d': (model_d.module if args.num_gpus > 1 else model_d).state_dict(), + 'optim_g': optim_g.state_dict(), + 'optim_d': optim_d.state_dict(), + 'step': step, + 'epoch': epoch, + 'hp_str': hp_str, + }, save_path) + logger.info("Saved checkpoint to: %s" % save_path) diff --git a/utils/writer.py b/utils/writer.py index e9766f7..b2697f0 100644 --- a/utils/writer.py +++ b/utils/writer.py @@ -21,8 +21,8 @@ def log_training(self, g_loss, d_loss, mel_loss, stft_loss, score_loss, step): def log_validation(self, mel_loss, generator, discriminator, step): self.add_scalar('validation/mel_loss', mel_loss, step) - self.log_histogram(generator, step) - self.log_histogram(discriminator, step) + # self.log_histogram(generator, step) + # self.log_histogram(discriminator, step) if self.is_first: self.is_first = False