Update train_net.py #574

Open: wants to merge 2 commits into master
55 changes: 49 additions & 6 deletions tools/train_net.py
@@ -14,16 +14,15 @@
this file as an example of how to use the library.
You may want to write your own script with your datasets and other customizations.
"""

import logging
import os
from collections import OrderedDict
import torch
from torch.nn.parallel import DistributedDataParallel

import detectron2.utils.comm as comm
from detectron2.data import MetadataCatalog, build_detection_train_loader
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
from detectron2.engine import AMPTrainer, SimpleTrainer, TrainerBase
from detectron2.utils.events import EventStorage
from detectron2.evaluation import (
    COCOEvaluator,
@@ -36,8 +35,9 @@
)
from detectron2.modeling import GeneralizedRCNNWithTTA
from detectron2.utils.logger import setup_logger

from adet.data.dataset_mapper import DatasetMapperWithBasis
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.env import TORCH_VERSION
from adet.data.my_dataset_mapper import DatasetMapperWithBasis
from adet.data.fcpose_dataset_mapper import FCPoseDatasetMapper
from adet.config import get_cfg
from adet.checkpoint import AdetCheckpointer
@@ -49,6 +49,41 @@ class Trainer(DefaultTrainer):
    This is the same Trainer except that we rewrite the
    `build_train_loader`/`resume_or_load` methods.
    """
    def __init__(self, cfg):
        """
        Use the custom checkpointer, which loads other backbone models
        with matching heuristics.

        Args:
            cfg (CfgNode):
        """
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

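        # wrap the model with DistributedDataParallel when training runs in more than one process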
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=True
            )

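        # call TrainerBase.__init__ directly (not super().__init__), since this method
        # re-implements DefaultTrainer's construction logic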
        TrainerBase.__init__(self)
        # use AMPTrainer when mixed-precision training is enabled, otherwise SimpleTrainer
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )

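        # LR scheduler and checkpointer, as in DefaultTrainer; the checkpointer also
        # saves/loads optimizer and scheduler state alongside the model weights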
        self.scheduler = self.build_lr_scheduler(cfg, optimizer)  # init lr_scheduler
        self.checkpointer = DetectionCheckpointer(
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        # build and register the default hooks (timing, LR stepping, checkpointing,
        # evaluation, logging); TrainerBase runs them before/after each training step
        self.register_hooks(self.build_hooks())

    def build_hooks(self):
        """
        Replace `DetectionCheckpointer` with `AdetCheckpointer`.
@@ -67,11 +102,17 @@ def build_hooks(self):
                )
                ret[i] = hooks.PeriodicCheckpointer(self.checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD)
        return ret

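    # resume_or_load(): load cfg.MODEL.WEIGHTS or, when resume=True and a checkpoint exists
    # in cfg.OUTPUT_DIR, restore the latest checkpoint and its iteration counter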
    def resume_or_load(self, resume=True):
        checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        if resume and self.checkpointer.has_checkpoint():
            self.start_iter = checkpoint.get("iteration", -1) + 1
        if isinstance(self.model, DistributedDataParallel):
            # broadcast loaded data/model from the first rank, because other
            # machines may not have access to the checkpoint file
            if TORCH_VERSION >= (1, 7):
                self.model._sync_params_and_buffers()
            self.start_iter = comm.all_gather(self.start_iter)[0]

    def train_loop(self, start_iter: int, max_iter: int):
        """
@@ -100,6 +141,7 @@ def train(self):
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        self.train_loop(self.start_iter, self.max_iter)

        if hasattr(self, "_last_eval_results") and comm.is_main_process():
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results
@@ -203,7 +245,7 @@ def main(args):
        AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model) # d2 defaults.py
        res = Trainer.test(cfg, model)  # d2 defaults.py
        if comm.is_main_process():
            verify_results(cfg, res)
        if cfg.TEST.AUG.ENABLED:
@@ -220,6 +262,7 @@ def main(args):
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    # trainer.checkpointer()
    return trainer.train()
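For context (not shown in this diff), `tools/train_net.py` is normally wired into detectron2's `launch` helper at the bottom of the file; a minimal sketch of that entry point, assuming the stock `default_argument_parser` flags, looks roughly like this:

```python
if __name__ == "__main__":
    # parse the standard detectron2 flags (--config-file, --num-gpus, --resume,
    # --eval-only, ...) and spawn one worker process per GPU
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
```

With that entry point, a single-machine run is started as, e.g., `python tools/train_net.py --config-file <config> --num-gpus 2 OUTPUT_DIR <dir>`, and passing `--resume` exercises the checkpoint handling added in `resume_or_load` above.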

