Skip to content

Commit

Permalink
steps for save
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxMax2016 committed Jun 12, 2023
1 parent 7f70a58 commit e4d311e
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Necessary pre-processing:
- 1 accompaniment separation
- 2 band extension
- 3 sound quality improvement
- 4 cut audio, less than 30 seconds for whisper💗
- 4 cut audio, less than 30 seconds for whisper

then put the dataset into the dataset_raw directory according to the following file structure
```shell
Expand Down Expand Up @@ -88,7 +88,7 @@ dataset_raw
> python prepare/preprocess_train.py
- 8, training file debugging
> python prepare/preprocess_zzz.py
> python prepare/preprocess_zzz.py -c configs/maxgan.yaml
```shell
data_svc/
Expand Down
4 changes: 2 additions & 2 deletions configs/maxgan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ dist_config:
#############################
log:
info_interval: 100
eval_interval: 1
save_interval: 5
eval_interval: 1000
save_interval: 1000
num_audio: 6
pth_dir: 'chkpt'
log_dir: 'logs'
26 changes: 13 additions & 13 deletions utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,16 +204,16 @@ def train(rank, args, chkpt_path, hp, hp_str):
with torch.no_grad():
validate(hp, args, model_g, model_d, valloader, stft, writer, step, device)

if rank == 0 and epoch % hp.log.save_interval == 0:
save_path = os.path.join(pt_dir, '%s_%04d.pt'
% (args.name, epoch))
torch.save({
'model_g': (model_g.module if args.num_gpus > 1 else model_g).state_dict(),
'model_d': (model_d.module if args.num_gpus > 1 else model_d).state_dict(),
'optim_g': optim_g.state_dict(),
'optim_d': optim_d.state_dict(),
'step': step,
'epoch': epoch,
'hp_str': hp_str,
}, save_path)
logger.info("Saved checkpoint to: %s" % save_path)
if rank == 0 and step % hp.log.save_interval == 0:
save_path = os.path.join(pth_dir, '%s_%08d.pt'
% (args.name, step))
torch.save({
'model_g': (model_g.module if args.num_gpus > 1 else model_g).state_dict(),
'model_d': (model_d.module if args.num_gpus > 1 else model_d).state_dict(),
'optim_g': optim_g.state_dict(),
'optim_d': optim_d.state_dict(),
'step': step,
'epoch': epoch,
'hp_str': hp_str,
}, save_path)
logger.info("Saved checkpoint to: %s" % save_path)
4 changes: 2 additions & 2 deletions utils/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def log_training(self, g_loss, d_loss, mel_loss, stft_loss, score_loss, step):
def log_validation(self, mel_loss, generator, discriminator, step):
self.add_scalar('validation/mel_loss', mel_loss, step)

self.log_histogram(generator, step)
self.log_histogram(discriminator, step)
# self.log_histogram(generator, step)
# self.log_histogram(discriminator, step)
if self.is_first:
self.is_first = False

Expand Down

0 comments on commit e4d311e

Please sign in to comment.