Fix precision configuration in Trainer class
stefanklut committed Feb 2, 2024
1 parent 22a9bfd · commit dc35006
Showing 1 changed file with 2 additions and 2 deletions.
core/trainer.py (2 additions, 2 deletions)

@@ -230,9 +230,9 @@ def __init__(self, cfg: CfgNode):
             "float16": torch.float16,
             "bfloat16": torch.bfloat16,
         }
-        precision = precision_converter.get(cfg.AMP_TRAIN.PRECISION, None)
+        precision = precision_converter.get(cfg.MODEL.AMP_TRAIN.PRECISION, None)
         if precision is None:
-            raise ValueError(f"Unrecognized precision: {cfg.AMP_TRAIN.PRECISION}")
+            raise ValueError(f"Unrecognized precision: {cfg.MODEL.AMP_TRAIN.PRECISION}")
         self._trainer.precision = precision
 
         self.scheduler = self.build_lr_scheduler(cfg, optimizer)
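The fix itself is a one-token change: the AMP precision setting is read from cfg.MODEL.AMP_TRAIN.PRECISION instead of cfg.AMP_TRAIN.PRECISION, matching where the option actually lives in the config tree. Below is a minimal, self-contained sketch of the same lookup-and-validate pattern; SimpleNamespace stands in for the project's CfgNode, and only the dtype entries visible in the hunk are included (the real mapping may define more).

import torch
from types import SimpleNamespace

# Stand-in for the repository's CfgNode (hypothetical, for this sketch only);
# the nesting mirrors what the fix assumes: AMP_TRAIN lives under MODEL.
cfg = SimpleNamespace(
    MODEL=SimpleNamespace(AMP_TRAIN=SimpleNamespace(PRECISION="bfloat16"))
)

# Map the config string to a torch dtype. Only the entries shown in the
# diff context appear here; the full dict in the repo may include others.
precision_converter = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}

# dict.get with a None default turns an unknown setting into an explicit
# ValueError that names the bad value, rather than a bare KeyError.
precision = precision_converter.get(cfg.MODEL.AMP_TRAIN.PRECISION, None)
if precision is None:
    raise ValueError(f"Unrecognized precision: {cfg.MODEL.AMP_TRAIN.PRECISION}")

print(precision)  # torch.bfloat16

Assuming the option is only defined under cfg.MODEL, the pre-fix code would have failed earlier with an AttributeError on cfg.AMP_TRAIN, before the ValueError guard could ever fire.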
