From aeaf74d20d67ed648a7d6d16e661068d63629f3a Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Thu, 2 Mar 2023 02:14:25 -0800
Subject: [PATCH] Hack to load old distopt checkpoints

Handles checkpoints generated before
https://github.com/NVIDIA/apex/pull/1551.
---
 .../optimizers/distributed_fused_adam.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py
index 8b25e9bee..6edc1ef96 100644
--- a/apex/contrib/optimizers/distributed_fused_adam.py
+++ b/apex/contrib/optimizers/distributed_fused_adam.py
@@ -843,6 +843,24 @@ def init_params(self, params=None):
         elif isinstance(params, torch.Tensor):
             params = [params]
 
+        # Hack to load old checkpoint files
+        # Note: Handles updates to optimizer state in
+        # https://github.com/NVIDIA/apex/pull/1551
+        if 'buckets' in self.state:
+            offset = 0
+            for bucket in self.state['buckets']:
+                def maybe_setattr(name, val):
+                    if not hasattr(bucket, name):
+                        setattr(bucket, name, val)
+                maybe_setattr('bucket_size', self.default_shard_size * self.distributed_size)
+                maybe_setattr('shard_size', self.default_shard_size)
+                maybe_setattr('filled_size', bucket.bucket_size)
+                maybe_setattr('contiguous_buffer_offset', offset)
+                offset = max(
+                    offset,
+                    bucket.contiguous_buffer_offset + bucket.bucket_size,
+                )
+
         # Ignore parameters that have already been initialized
         params = [
             param
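
Note (not part of the patch): below is a minimal, self-contained sketch of the backfill logic added above. It uses SimpleNamespace stand-ins for the optimizer's internal bucket objects and hypothetical values for default_shard_size and distributed_size; the real attributes live on the DistributedFusedAdam instance. It illustrates how attributes missing from a pre-#1551 checkpoint get defaults, and how contiguous_buffer_offset accumulates across buckets.

    # Illustrative sketch only; Bucket objects and sizes are stand-ins,
    # not the real apex classes or values.
    from types import SimpleNamespace

    default_shard_size = 4   # hypothetical per-rank shard size
    distributed_size = 2     # hypothetical number of ranks

    # Two buckets as they might appear in an old checkpoint: the
    # size/offset attributes introduced by PR #1551 are absent.
    buckets = [SimpleNamespace(), SimpleNamespace()]

    offset = 0
    for bucket in buckets:
        def maybe_setattr(name, val):
            # Only fill in attributes the checkpoint did not provide.
            if not hasattr(bucket, name):
                setattr(bucket, name, val)
        maybe_setattr('bucket_size', default_shard_size * distributed_size)
        maybe_setattr('shard_size', default_shard_size)
        maybe_setattr('filled_size', bucket.bucket_size)
        maybe_setattr('contiguous_buffer_offset', offset)
        # Next bucket starts after the end of this one in the contiguous buffer.
        offset = max(offset, bucket.contiguous_buffer_offset + bucket.bucket_size)

    print([vars(b) for b in buckets])
    # First bucket gets contiguous_buffer_offset 0, the second gets 8
    # (the bucket_size of the first), so old checkpoints pick up a
    # consistent contiguous-buffer layout even though they never stored one.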