-
Notifications
You must be signed in to change notification settings - Fork 10
/
train_odin.py
612 lines (540 loc) · 24.5 KB
/
train_odin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
MaskFormer Training Script.
This script is a simplified version of the training script in detectron2/tools.
"""
import warnings
warnings.filterwarnings('ignore')
import copy
import itertools
import logging
import os
import gc
import weakref
import time
from collections import OrderedDict
from typing import Any, Dict, List, Set
import torch
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from torch.nn.parallel import DistributedDataParallel
from detectron2.config import get_cfg
from detectron2.engine import (
DefaultTrainer,
default_argument_parser,
default_setup,
launch,
AMPTrainer,
SimpleTrainer
)
from detectron2.evaluation import (
DatasetEvaluator,
COCOEvaluator,
inference_on_dataset,
)
from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.utils.logger import setup_logger
# MaskFormer
from odin.data_video.dataset_mapper_coco import COCOInstanceNewBaselineDatasetMapper
from odin import (
ScannetDatasetMapper,
Scannet3DEvaluator,
ScannetSemantic3DEvaluator,
COCOEvaluatorMemoryEfficient,
add_maskformer2_video_config,
add_maskformer2_config,
build_detection_train_loader,
build_detection_test_loader,
get_detection_dataset_dicts,
build_detection_train_loader_multi_task,
)
from odin.data_video.build import merge_datasets
from odin.global_vars import SCANNET_LIKE_DATASET
from torchinfo import summary
# Share tensors between worker processes through the file system.
torch.multiprocessing.set_sharing_strategy('file_system')

# Debug helper: calling `st()` drops into an interactive debugger. ipdb is
# preferred when installed, but it is a debugging convenience only, so a
# missing ipdb must not prevent the training script from being imported.
try:
    import ipdb
    st = ipdb.set_trace
except ImportError:
    import pdb
    st = pdb.set_trace
class OneCycleLr_D2(torch.optim.lr_scheduler.OneCycleLR):
    """
    ``torch.optim.lr_scheduler.OneCycleLR`` with a reduced ``state_dict``.

    Construction is inherited unchanged from ``OneCycleLR``. Only ``base_lrs``
    and ``last_epoch`` are exported by :meth:`state_dict`; every other
    scheduler attribute is left out of the checkpoint.
    """

    def state_dict(self):
        # NOTE(review): presumably trimmed so detectron2 checkpointing does not
        # have to serialize the remaining OneCycleLR attributes — confirm.
        exported = dict(base_lrs=self.base_lrs, last_epoch=self.last_epoch)
        return exported
def create_ddp_model(model, *, fp16_compression=False, find_unused_parameters=False, **kwargs):
    """
    Wrap ``model`` in DistributedDataParallel, but only when >1 processes run.

    Args:
        model: a torch.nn.Module
        fp16_compression: register the fp16 compression communication hook on
            the DDP object.
            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
        find_unused_parameters: forwarded to DistributedDataParallel.
        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
            ``device_ids`` defaults to ``[comm.get_local_rank()]``.

    Returns:
        ``model`` itself for a single-process run, otherwise the DDP wrapper.
    """  # noqa
    world_size = comm.get_world_size()
    if world_size == 1:
        return model

    ddp_kwargs = dict(kwargs)
    # Only query the local rank when the caller did not pin the devices.
    if "device_ids" not in ddp_kwargs:
        ddp_kwargs["device_ids"] = [comm.get_local_rank()]
    ddp_kwargs["find_unused_parameters"] = find_unused_parameters
    ddp = DistributedDataParallel(model, **ddp_kwargs)

    if not fp16_compression:
        return ddp
    from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
    ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
    return ddp
class Trainer(DefaultTrainer):
    """
    Extension of the Trainer class adapted to MaskFormer.

    Overrides detectron2's ``DefaultTrainer`` construction hooks (DDP wrapping,
    train/test loaders, evaluators, optimizer, LR scheduler) and its training
    step for ODIN's ScanNet-like / COCO and multi-task (2D + 3D) setups.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): experiment config.
        """
        # Initialise the parent of DefaultTrainer directly: DefaultTrainer's own
        # __init__ is skipped and its setup is reproduced below with this
        # file's `create_ddp_model` (NOTE(review): presumably to control
        # `find_unused_parameters` — confirm).
        super(DefaultTrainer, self).__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        model = create_ddp_model(
            model,
            broadcast_buffers=False,
            find_unused_parameters=cfg.MULTI_TASK_TRAINING or cfg.FIND_UNUSED_PARAMETERS,
        )
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            trainer=weakref.proxy(self),
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    @staticmethod
    def _is_scannet_like(dataset_name):
        """Return True if any entry of ``SCANNET_LIKE_DATASET`` is a substring of ``dataset_name``."""
        return any(key in dataset_name for key in SCANNET_LIKE_DATASET)

    @classmethod
    def build_evaluator(
        cls, cfg, dataset_name,
        output_folder=None, use_2d_evaluators_only=False,
        use_3d_evaluators_only=False,
    ):
        """
        Create evaluator(s) for a given dataset.

        3D evaluators (semantic and/or instance, per
        ``cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON`` / ``INSTANCE_ON``) are added
        when ``cfg.TEST.EVAL_3D`` and ``cfg.MODEL.DECODER_3D`` are set and
        ``use_2d_evaluators_only`` is False. One COCO-style 2D evaluator is
        added when ``cfg.TEST.EVAL_2D`` or ``cfg.EVAL_PER_IMAGE`` is set and
        ``use_3d_evaluators_only`` is False.

        Args:
            cfg (CfgNode): experiment config.
            dataset_name (str): registered dataset to evaluate.
            output_folder (str, optional): output directory for the evaluators;
                defaults to ``<cfg.OUTPUT_DIR>/inference``, which is created.
            use_2d_evaluators_only (bool): skip the 3D evaluators.
            use_3d_evaluators_only (bool): skip the 2D evaluator.

        Returns:
            list[DatasetEvaluator]: possibly empty.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
            os.makedirs(output_folder, exist_ok=True)
        evaluators = []
        if cfg.TEST.EVAL_3D and cfg.MODEL.DECODER_3D and not use_2d_evaluators_only:
            if cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON:
                evaluators.append(
                    ScannetSemantic3DEvaluator(
                        dataset_name,
                        output_dir=output_folder,
                        eval_sparse=cfg.TEST.EVAL_SPARSE,
                        cfg=cfg,
                    )
                )
            if cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
                evaluators.append(
                    Scannet3DEvaluator(
                        dataset_name,
                        output_dir=output_folder,
                        eval_sparse=cfg.TEST.EVAL_SPARSE,
                        cfg=cfg,
                    )
                )
        if (cfg.TEST.EVAL_2D or cfg.EVAL_PER_IMAGE) and not use_3d_evaluators_only:
            if cfg.INPUT.ORIGINAL_EVAL:
                print("Using original COCO Eval, potentially is RAM hungry")
                evaluators.append(
                    COCOEvaluator(dataset_name, output_dir=output_folder, use_fast_impl=False)
                )
            else:
                evaluators.append(
                    COCOEvaluatorMemoryEfficient(
                        dataset_name, output_dir=output_folder, use_fast_impl=False,
                        per_image_eval=cfg.EVAL_PER_IMAGE, evaluate_subset=cfg.EVALUATE_SUBSET,
                    )
                )
        return evaluators

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Build the training data loader.

        With ``cfg.MULTI_TASK_TRAINING`` a joint loader is built from
        ``cfg.DATASETS.TRAIN_3D`` (if ``cfg.TRAIN_3D``) and
        ``cfg.DATASETS.TRAIN_2D`` (if ``cfg.TRAIN_2D``). Otherwise a single
        loader is built for ``cfg.DATASETS.TRAIN[0]``.

        Raises:
            NotImplementedError: single-task training on a dataset that is
                neither ScanNet-like nor COCO.
        """
        if not cfg.MULTI_TASK_TRAINING:
            dataset_name = cfg.DATASETS.TRAIN[0]
            if cls._is_scannet_like(dataset_name):
                dataset_dict = get_detection_dataset_dicts(
                    dataset_name,
                    proposal_files=None,
                )
                mapper = ScannetDatasetMapper(
                    cfg, is_train=True, dataset_name=dataset_name, dataset_dict=dataset_dict
                )
                return build_detection_train_loader(cfg, mapper=mapper, dataset=dataset_dict)
            if 'coco' in dataset_name:
                mapper = COCOInstanceNewBaselineDatasetMapper(cfg, True, dataset_name=dataset_name)
                return build_detection_train_loader(cfg, mapper=mapper)
            raise NotImplementedError

        # ---- multi-task: 3D part ----
        if cfg.TRAIN_3D:
            if len(cfg.DATASETS.TRAIN_3D) > 1:
                dataset_dicts = [
                    get_detection_dataset_dicts(name, proposal_files=None)
                    for name in cfg.DATASETS.TRAIN_3D
                ]
                # Each mapper must be named after the dataset its dicts came
                # from, so the names are taken from TRAIN_3D (not TRAIN).
                mappers = [
                    ScannetDatasetMapper(
                        cfg, is_train=True, dataset_name=dataset_name, dataset_dict=dataset_dict
                    )
                    for dataset_name, dataset_dict in zip(cfg.DATASETS.TRAIN_3D, dataset_dicts)
                ]
                # The per-dataset mappers are handed to merge_datasets and no
                # single 3D mapper is used (NOTE(review): presumably
                # merge_datasets applies them itself — confirm).
                dataset_dict_3d = merge_datasets(dataset_dicts, mappers, balance=cfg.BALANCE_3D_DATASETS)
                mapper_3d = None
            else:
                dataset_dict_3d = get_detection_dataset_dicts(
                    cfg.DATASETS.TRAIN_3D,
                    proposal_files=None,
                )
                mapper_3d = ScannetDatasetMapper(
                    cfg, is_train=True,
                    dataset_name=cfg.DATASETS.TRAIN_3D[0],
                    dataset_dict=dataset_dict_3d,
                )
        else:
            dataset_dict_3d = None
            mapper_3d = None

        # ---- multi-task: 2D part ----
        if cfg.TRAIN_2D:
            dataset_dict_2d = get_detection_dataset_dicts(
                cfg.DATASETS.TRAIN_2D,
                proposal_files=None,
            )
            if 'coco' in cfg.DATASETS.TRAIN_2D[0]:
                mapper_2d = COCOInstanceNewBaselineDatasetMapper(
                    cfg, True, dataset_name=cfg.DATASETS.TRAIN_2D[0]
                )
            else:
                # NOTE(review): `force_decoder_2d` is fed from FORCE_DECODER_3D — confirm intended.
                mapper_2d = ScannetDatasetMapper(
                    cfg, is_train=True,
                    dataset_name=cfg.DATASETS.TRAIN_2D[0],
                    dataset_dict=dataset_dict_2d,
                    force_decoder_2d=cfg.FORCE_DECODER_3D,
                    frame_left=0,
                    frame_right=0,
                    decoder_3d=False,
                )
        else:
            dataset_dict_2d = None
            mapper_2d = None

        return build_detection_train_loader_multi_task(
            cfg, mapper_3d=mapper_3d, mapper_2d=mapper_2d,
            dataset_3d=dataset_dict_3d, dataset_2d=dataset_dict_2d,
        )

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Build the test loader for one ScanNet-like or COCO dataset.

        Proposal files are used when ``cfg.MODEL.LOAD_PROPOSALS`` is set, and
        the dataset is subsampled by ``cfg.TEST.SUBSAMPLE_DATA`` when it is
        listed in ``cfg.DATASETS.TEST_SUBSAMPLED``.

        Raises:
            NotImplementedError: the dataset is neither ScanNet-like nor COCO.
        """
        scannet_like = cls._is_scannet_like(dataset_name)
        if not scannet_like and 'coco' not in dataset_name:
            raise NotImplementedError

        dataset_dict = get_detection_dataset_dicts(
            [dataset_name],
            proposal_files=[
                cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]
            ]
            if cfg.MODEL.LOAD_PROPOSALS
            else None,
            subsample_data=cfg.TEST.SUBSAMPLE_DATA if dataset_name in cfg.DATASETS.TEST_SUBSAMPLED else None,
        )
        if scannet_like:
            mapper = ScannetDatasetMapper(
                cfg, is_train=False, dataset_name=dataset_name, dataset_dict=dataset_dict,
                decoder_3d=False if dataset_name in cfg.DATASETS.TEST_2D_ONLY else cfg.MODEL.DECODER_3D,
            )
        else:
            mapper = COCOInstanceNewBaselineDatasetMapper(cfg, is_train=False, dataset_name=dataset_name)
        return build_detection_test_loader(cfg, mapper=mapper, dataset=dataset_dict)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        Build the LR scheduler.

        ``cfg.SOLVER.LR_SCHEDULER_NAME == "onecyclelr"`` selects
        :class:`OneCycleLr_D2` with ``max_lr=cfg.SOLVER.BASE_LR`` over
        ``cfg.SOLVER.MAX_ITER`` steps. Any other name is delegated to
        ``detectron2.projects.deeplab.build_lr_scheduler``.
        """
        if cfg.SOLVER.LR_SCHEDULER_NAME == "onecyclelr":
            return OneCycleLr_D2(
                optimizer,
                max_lr=cfg.SOLVER.BASE_LR,
                total_steps=cfg.SOLVER.MAX_ITER,
            )
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Build an SGD or AdamW optimizer with one param group per trainable parameter.

        Per-parameter overrides of the defaults (``cfg.SOLVER.BASE_LR``,
        ``cfg.SOLVER.WEIGHT_DECAY``):
          * backbone parameters get ``lr * cfg.SOLVER.BACKBONE_MULTIPLIER``,
            except the PANet fusion layers, which keep the default lr;
          * ``relative_position_bias_table`` / ``absolute_pos_embed`` get no
            weight decay;
          * normalization layers use ``cfg.SOLVER.WEIGHT_DECAY_NORM`` and
            embeddings ``cfg.SOLVER.WEIGHT_DECAY_EMBED``.
        Gradient clipping is added either over the full model
        (``CLIP_TYPE == "full_model"``) or via detectron2's
        ``maybe_add_gradient_clipping``.

        Raises:
            NotImplementedError: unsupported backbone or optimizer type.
        """
        weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
        weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED

        defaults = {}
        defaults["lr"] = cfg.SOLVER.BASE_LR
        defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY

        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            # NaiveSyncBatchNorm inherits from BatchNorm2d
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
        )

        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        print(summary(model))

        panet_resnet_layers = ['cross_view_attn', 'res_to_trans', 'trans_to_res']
        panet_swin_layers = ['cross_view_attn', 'cross_layer_norm', 'res_to_trans', 'trans_to_res']
        if cfg.MODEL.BACKBONE.NAME == "build_resnet_backbone":
            backbone_panet_layers = panet_resnet_layers
        elif cfg.MODEL.BACKBONE.NAME == "D2SwinTransformer":
            backbone_panet_layers = panet_swin_layers
        else:
            raise NotImplementedError

        for module_name, module in model.named_modules():
            for module_param_name, value in module.named_parameters(recurse=False):
                if not value.requires_grad:
                    continue
                # Avoid duplicating parameters
                if value in memo:
                    continue
                memo.add(value)

                hyperparams = copy.copy(defaults)
                if "backbone" in module_name:
                    # panet layers are initialized from scratch, so they keep the
                    # default lr; all other backbone layers get the multiplier.
                    is_panet = any(panet_name in module_name for panet_name in backbone_panet_layers)
                    if not is_panet:
                        hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
                if (
                    "relative_position_bias_table" in module_param_name
                    or "absolute_pos_embed" in module_param_name
                ):
                    print(module_param_name)
                    hyperparams["weight_decay"] = 0.0
                if isinstance(module, norm_module_types):
                    hyperparams["weight_decay"] = weight_decay_norm
                if isinstance(module, torch.nn.Embedding):
                    hyperparams["weight_decay"] = weight_decay_embed
                params.append({"params": [value], **hyperparams})

        def maybe_add_full_model_gradient_clipping(optim):
            # detectron2 doesn't have full model gradient clipping now
            clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
            enable = (
                cfg.SOLVER.CLIP_GRADIENTS.ENABLED
                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
                and clip_norm_val > 0.0
            )

            class FullModelGradientClippingOptimizer(optim):
                def step(self, closure=None):
                    # Clip the global norm over every parameter in every group.
                    all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                    super().step(closure=closure)

            return FullModelGradientClippingOptimizer if enable else optim

        optimizer_type = cfg.SOLVER.OPTIMIZER
        if optimizer_type == "SGD":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
            )
        elif optimizer_type == "ADAMW":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                params, cfg.SOLVER.BASE_LR
            )
        else:
            raise NotImplementedError(f"no optimizer type {optimizer_type}")
        if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
            optimizer = maybe_add_gradient_clipping(cfg, optimizer)
        return optimizer

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Evaluate the given model. The given model is expected to already contain
        weights to evaluate.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                ``cfg.DATASETS.TEST``.

        Returns:
            dict: a dict of result metrics, flattened for the writer (see
            :meth:`_structure_results`).
        """
        from torch.cuda.amp import autocast

        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(
                        cfg,
                        dataset_name,
                        use_2d_evaluators_only=(
                            dataset_name in cfg.DATASETS.TEST_2D_ONLY
                            if cfg.MULTI_TASK_TRAINING else False
                        ),
                        use_3d_evaluators_only=(
                            dataset_name in cfg.DATASETS.TEST_3D_ONLY
                            if cfg.MULTI_TASK_TRAINING else False
                        ),
                    )
                except NotImplementedError:
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            with autocast():
                results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            # Release memory held by this dataset's evaluation before the next one.
            gc.collect()
            torch.cuda.empty_cache()

        return cls._structure_results(cfg, results)

    @staticmethod
    def _structure_results(cfg, results):
        """
        Flatten per-dataset ``results`` into one dict with train/val-prefixed keys.

        Args:
            cfg (CfgNode): experiment config (DATASETS.TEST*, MULTI_TASK_TRAINING,
                EVAL_PER_IMAGE are read).
            results (OrderedDict): dataset name -> metrics dict, in
                ``cfg.DATASETS.TEST`` order.
        """
        results_structured = {}

        if cfg.MULTI_TASK_TRAINING:
            for dataset_name in cfg.DATASETS.TEST_3D_ONLY:
                if dataset_name in results:
                    suffix = 'train_full' if 'train_eval' in dataset_name else 'val_full'
                    suffix += f'_{dataset_name.split("_")[0]}'
                    results_structured.update(
                        {f'{suffix}' + k: v for k, v in results[dataset_name].items()}
                    )
            for dataset_name in cfg.DATASETS.TEST_2D_ONLY:
                if dataset_name in results:
                    suffix = 'train' if 'train_eval' in dataset_name else 'val'
                    suffix += f'_{dataset_name.split("_")[0]}'
                    results_structured.update(
                        {f'{suffix}' + k: v for k, v in results[dataset_name].items()}
                    )
            return results_structured

        # format for writer
        if len(results) == 1:
            return list(results.values())[0]

        if len(results) == 2:
            # find a better way than hard-coding here
            # The tag comes from the *last* entry of DATASETS.TEST.
            tag = cfg.DATASETS.TEST[-1].split("_")[0]
            suffix = '_full' if 'single' in cfg.DATASETS.TEST[0] else ''
            suffix += f'_{tag}'

            results_val = {f'val{suffix}' + k: v for k, v in results[cfg.DATASETS.TEST[0]].items()}
            try:
                if cfg.EVAL_PER_IMAGE:
                    # pop-then-assign stays correct when old and new key coincide.
                    results_val[f'val_{tag}segm'] = results_val.pop(f'val{suffix}segm')
            except KeyError:
                print("Error in Logging")
                print(results_val.keys(), f'val{suffix}segm')

            results_train = {f'train{suffix}' + k: v for k, v in results[cfg.DATASETS.TEST[1]].items()}
            try:
                if cfg.EVAL_PER_IMAGE:
                    results_train[f'train_{tag}segm'] = results_train.pop(f'train{suffix}segm')
            except KeyError:
                print("Error in Logging")
                print(results_train.keys(), f'train{suffix}segm')

            results_structured.update(results_train)
            results_structured.update(results_val)
            return results_structured

        # Any other number of datasets: accumulate every dataset's metrics.
        for dataset_name in cfg.DATASETS.TEST:
            suffix = 'train_full' if 'train_eval' in dataset_name else 'val_full'
            prefix = f'{suffix}_{dataset_name.split("_")[0]}'
            results_structured.update({prefix + k: v for k, v in results[dataset_name].items()})
        return results_structured

    def run_step(self):
        """
        Implement the AMP training logic.

        Runs one forward/backward/step through ``self._trainer``'s grad scaler.
        If the model returns ``loss_3d`` / ``loss_2d`` entries, they are left
        out of the optimized sum and only ``loss_3d`` (else ``loss_2d``) is
        re-added to the logged metrics (NOTE(review): presumably these are
        logging-only aggregates — confirm).
        """
        self._trainer.iter = self.iter
        assert self._trainer.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        assert self.cfg.SOLVER.AMP.ENABLED

        start = time.perf_counter()
        data = next(self._trainer._data_loader_iter)
        data_time = time.perf_counter() - start

        with autocast(dtype=self._trainer.precision):
            loss_dict = self._trainer.model(data)
            if isinstance(loss_dict, torch.Tensor):
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                loss_custom = None
                if 'loss_3d' in loss_dict or 'loss_2d' in loss_dict:
                    loss_name = 'loss_3d' if 'loss_3d' in loss_dict else 'loss_2d'
                    loss_custom = loss_dict[loss_name]
                    loss_dict.pop('loss_3d', None)
                    loss_dict.pop('loss_2d', None)
                losses = sum(loss_dict.values())
                if loss_custom is not None:
                    loss_dict[loss_name] = loss_custom

        self._trainer.optimizer.zero_grad()
        self._trainer.grad_scaler.scale(losses).backward()
        self._trainer.after_backward()
        self._trainer._write_metrics(loss_dict, data_time)
        self._trainer.grad_scaler.step(self._trainer.optimizer)
        self._trainer.grad_scaler.update()
def setup(args):
    """
    Build the frozen experiment config from the parsed CLI arguments.

    Extra config keys are registered first, including the deeplab ones used for
    the poly lr schedule. Then the config file and the command-line overrides are
    merged in, and detectron2's default setup and the "odin" logger are
    initialised.

    Returns:
        CfgNode: the frozen config.
    """
    cfg = get_cfg()
    # Register extra keys before merging (deeplab: for poly lr schedule).
    extenders = (add_deeplab_config, add_maskformer2_config, add_maskformer2_video_config)
    for extend in extenders:
        extend(cfg)

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    default_setup(cfg, args)
    setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="odin")
    return cfg
def main(args):
    """
    Per-process entry point passed to ``launch``.

    With ``--eval-only`` the model is built, weights are loaded from
    ``cfg.MODEL.WEIGHTS`` (or resumed), and the evaluation metrics are
    returned. Otherwise a :class:`Trainer` is built, optionally resumed, and
    trained.

    Raises:
        NotImplementedError: eval-only run with ``cfg.TEST.AUG.ENABLED``.
    """
    cfg = setup(args)

    if args.eval_only:
        # Reject test-time augmentation up front instead of after a full
        # (and then discarded) evaluation pass.
        if cfg.TEST.AUG.ENABLED:
            raise NotImplementedError("Test-time augmentation (TEST.AUG) is not supported")
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        return Trainer.test(cfg, model)

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()
if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)

    # this is needed to prevent memory leak in conv2d layers
    # see: https://github.com/pytorch/pytorch/issues/98688#issuecomment-1869290827
    os.environ['TORCH_CUDNN_V8_API_DISABLED'] = '1'

    launch_options = dict(
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
    launch(main, args.num_gpus, **launch_options)