Accelerate some adjustment for mixed precision (#1009)
* Use accelerator.autocast when computing loss

  According to the accelerate docs, loss computation should be performed within the `accelerator.autocast` context manager: https://huggingface.co/docs/accelerate/v0.21.0/en/quicktour#mixed-precision-training

  I tested whether this makes a difference by running the following notebook with fp16 precision: https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Hugging_Face_Finetuning.ipynb

  I found no difference at all: the runtime was practically the same and the losses were identical. Still, I think it's better to have this than not, as it is recommended by the accelerate docs.

* Update LR scheduler callback to work w/ accelerate

  According to the accelerate docs (https://huggingface.co/docs/accelerate/quicktour#mixed-precision-training), the LR scheduler step should sometimes be skipped when using mixed precision training, because accelerate may skip update steps internally. Therefore, I updated the LR scheduler callback to check whether the net has an accelerator and, if it does, whether a step is necessary.

  This is actually quite hard to test, because the necessity of stepping depends on accelerate's internal logic, which we don't want to test and which might change in the future. Therefore, the added test just runs training with accelerate, mixed precision, and some LR schedulers, verifying that there is no error. When running these tests plus the normal LR scheduler tests locally on a machine that supports fp16, I get 100% line coverage of lr_scheduler.py. I think this is good enough.

* Non-functional clean-ups related to LR schedulers

  While working on the fixes in this PR, I also cleaned up some LR scheduler code. These clean-ups are non-functional:

  1. We imported CyclicLR as TorchCyclicLR. I'm not sure why, but it is somehow related to very old PyTorch versions we no longer support, so I removed this.
  2. Fixed some indentation in conditional checks to improve readability.

* Reviewer comment: simplify conditional code
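The two changes above can be sketched in a minimal, dependency-free way. The stub class below stands in for the real `accelerate.Accelerator`; only the member names `autocast` and `optimizer_step_was_skipped` are taken from the accelerate API mentioned in the commit message, while everything else (the stub itself, the `skip_pattern` argument, the toy training loop) is hypothetical and exists only to show the control flow:

```python
# Sketch of the two changes: compute the loss inside autocast, and only step
# the LR scheduler when accelerate did not skip the optimizer step.
from contextlib import contextmanager


class StubAccelerator:
    """Stand-in for accelerate.Accelerator exposing only the members used here."""

    def __init__(self, skip_pattern):
        # Pretend accelerate internally skips some optimizer steps, e.g. when
        # the fp16 grad scaler detects inf/nan gradients.
        self._skips = iter(skip_pattern)
        self.optimizer_step_was_skipped = False

    @contextmanager
    def autocast(self):
        # The real accelerate enables mixed-precision autocasting here.
        yield

    def step(self, optimizer_step):
        # Either apply or skip the optimizer step, and record which happened.
        self.optimizer_step_was_skipped = next(self._skips)
        if not self.optimizer_step_was_skipped:
            optimizer_step()


def train(accelerator, n_batches):
    weights = 0.0   # toy "model parameter"
    lr_steps = 0    # how often the LR scheduler was stepped

    def optimizer_step():
        nonlocal weights
        weights += 1.0

    for _ in range(n_batches):
        # Change 1: loss computation happens inside accelerator.autocast(),
        # as recommended by the accelerate docs.
        with accelerator.autocast():
            loss = weights * 0.5  # dummy loss computation
        accelerator.step(optimizer_step)
        # Change 2: only step the LR scheduler when the optimizer step
        # was not skipped by accelerate.
        if not accelerator.optimizer_step_was_skipped:
            lr_steps += 1
    return weights, lr_steps


acc = StubAccelerator(skip_pattern=[False, True, False, False])
print(train(acc, 4))  # → (3.0, 3): one of the four steps was skipped
```

In the actual callback, the `if not ... optimizer_step_was_skipped` guard is what keeps the LR schedule aligned with the number of optimizer updates actually applied.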
Commit 312daaa (1 parent: 07fc260) — 4 changed files with 102 additions and 29 deletions.