-
Notifications
You must be signed in to change notification settings - Fork 8
/
patch_main.diff
237 lines (204 loc) · 10.2 KB
/
patch_main.diff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
--- original/main_pretrain.py 2022-05-18 15:39:18.838208217 +0900
+++ main_pretrain.py 2022-05-17 23:13:22.034457577 +0900
@@ -15,44 +15,57 @@
import os
import time
from pathlib import Path
+import re
+import subprocess
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
-import torchvision.transforms as transforms
-import torchvision.datasets as datasets
+from util import datasets
import timm
-assert timm.__version__ == "0.3.2" # version check
+# assert timm.__version__ == "0.3.2" # version check
import timm.optim.optim_factory as optim_factory
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
-import models_mae
+from msm_mae import models_mae
-from engine_pretrain import train_one_epoch
+from msm_mae.engine_pretrain import train_one_epoch
def get_args_parser():
parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
+ parser.add_argument('output_dir', type=str,
+ help='path where to save, empty for no saving')
parser.add_argument('--batch_size', default=64, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
- parser.add_argument('--epochs', default=400, type=int)
+ parser.add_argument('--epochs', default=100, type=int) # original default=800
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+ parser.add_argument('--save_freq', default=20, type=int)
+ parser.add_argument('--eval_freq', default=20, type=int)
# Model parameters
- parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL',
+ parser.add_argument('--model', default='mae_vit_base_patch16x16', type=str, metavar='MODEL',
help='Name of model to train')
- parser.add_argument('--input_size', default=224, type=int,
+ parser.add_argument('--input_size', default='80x208', type=str,
help='images input size')
parser.add_argument('--mask_ratio', default=0.75, type=float,
help='Masking ratio (percentage of removed patches).')
+ parser.add_argument('--no_cls_token', action='store_true',
+ help='Do not use [CLS] token if set.')
+ parser.set_defaults(no_cls_token=False)
+
+ parser.add_argument('--dec_pos_2d', action='store_true',
+ help='Use a 2-D positional embeddings on decoder. It is a 1-D by default.')
+ parser.set_defaults(dec_pos_2d=False)
+
parser.add_argument('--norm_pix_loss', action='store_true',
help='Use (per-patch) normalized pixels as targets for computing loss')
parser.set_defaults(norm_pix_loss=False)
@@ -63,21 +76,23 @@
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
- parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
- help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
+ parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
+ help='base learning rate: absolute_lr = base_lr * total_batch_size / 128')
parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
- parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
- help='epochs to warmup LR')
+ parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
+ help='epochs to warmup LR') # original default=40
# Dataset parameters
- parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
+ parser.add_argument('--data_path', default='audio_lms', type=str,
help='dataset path')
+ parser.add_argument('--dataset', default='trainingfiles', type=str,
+ help='dataset definition')
+ parser.add_argument('--norm_stats', default='None', type=str,
+ help='dataset normalization stats')
- parser.add_argument('--output_dir', default='./output_dir',
- help='path where to save, empty for no saving')
- parser.add_argument('--log_dir', default='./output_dir',
+ parser.add_argument('--log_dir', default='',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
@@ -104,6 +119,32 @@
return parser
+import matplotlib.pyplot as plt
+
+
+def visualize_reconstruction(args, device, model, save_path):
+ ds, files = datasets.build_viz_dataset(args)
+ if len(ds) == 0:
+ print(f'(Skipped visualization which require samples in {args.data_path}/vis_samples folder.)')
+ return
+ batch = torch.stack([ds[i] for i in range(len(ds))])
+ model.eval()
+ with torch.no_grad():
+ _, recons, _, masks = model.forward_viz(batch.to(device))
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ for i, file in enumerate(files):
+ # as .npy
+ np.save(f"{save_path}/recon_{Path(file).name}", recons[i].cpu().numpy())
+ # as .png
+ fig = plt.figure(figsize=[12, 8 if batch[0].shape[-1] < 310 else 6])
+ for j, img in enumerate([batch[i][0], recons[i][0], masks[i]]):
+ ax = fig.add_subplot(3, 1, j + 1)
+ ax.imshow(img.cpu().numpy(), origin='lower')
+ plt.margins(x=0, y=0)
+ fig.savefig(f'{save_path}/recon_{Path(file).stem}.png', bbox_inches = 'tight')
+
+
def main(args):
misc.init_distributed_mode(args)
@@ -119,13 +160,7 @@
cudnn.benchmark = True
- # simple augmentation
- transform_train = transforms.Compose([
- transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic
- transforms.RandomHorizontalFlip(),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
- dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
+ dataset_train = datasets.build_dataset(args)
print(dataset_train)
if True: # args.distributed:
@@ -138,6 +173,8 @@
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
+ eff_batch_size = args.batch_size * args.accum_iter * torch.cuda.device_count()
+
if global_rank == 0 and args.log_dir is not None:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=args.log_dir)
@@ -146,34 +183,31 @@
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
- batch_size=args.batch_size,
+ batch_size=eff_batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
# define the model
- model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)
+ model = models_mae.__dict__[args.model](img_size=args.input_size, norm_pix_loss=args.norm_pix_loss,
+ use_cls_token=(not args.no_cls_token), use_2d_dec_pos_embd=args.dec_pos_2d)
model.to(device)
model_without_ddp = model
print("Model = %s" % str(model_without_ddp))
- eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
-
if args.lr is None: # only base_lr is specified
- args.lr = args.blr * eff_batch_size / 256
+ args.lr = args.blr * eff_batch_size / 128
- print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
+ print("base lr: %.2e" % (args.lr * 128 / eff_batch_size))
print("actual lr: %.2e" % args.lr)
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)
- if args.distributed:
- model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
- model_without_ddp = model.module
+ model = torch.nn.DataParallel(model).to(device)
# following timm: set wd as 0 for bias and norm layers
param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
@@ -186,6 +220,7 @@
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
+ epoch1 = epoch + 1
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
@@ -194,10 +229,17 @@
log_writer=log_writer,
args=args
)
- if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs):
+ if args.output_dir and (epoch1 % args.save_freq == 0 or epoch1 == args.epochs):
misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
- loss_scaler=loss_scaler, epoch=epoch)
+ loss_scaler=loss_scaler, epoch=epoch1)
+ # visualize reconstructions
+ out_dir = Path(args.output_dir)/str(epoch)
+ visualize_reconstruction(args, device, model_without_ddp, out_dir)
+ # run the external evaluator
+ if epoch1 % args.eval_freq == 0 or epoch1 == args.epochs:
+ abspath = Path(f'{args.output_dir}/checkpoint-{epoch1}.pth').absolute()
+ subprocess.Popen(['/bin/bash', './quick_eval.sh', abspath])
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,}
@@ -218,4 +260,10 @@
args = args.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ if not args.log_dir:
+ args.log_dir = args.output_dir
+ args.input_size = [int(x) for x in args.input_size.split('x')]
+ args.patch_size = [int(x) for x in re.match(r'.+\_patch([0-9]+x[0-9]+)$', str(args.model)).group(1).split('x')]
+ args.norm_stats = eval(args.norm_stats) if args.norm_stats else None
+ print(args)
main(args)