diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py
index 8b25e9bee..6edc1ef96 100644
--- a/apex/contrib/optimizers/distributed_fused_adam.py
+++ b/apex/contrib/optimizers/distributed_fused_adam.py
@@ -843,6 +843,24 @@ def init_params(self, params=None):
         elif isinstance(params, torch.Tensor):
             params = [params]
 
+        # Hack to load old checkpoint files
+        # Note: Handles updates to optimizer state in
+        # https://github.com/NVIDIA/apex/pull/1551
+        if 'buckets' in self.state:
+            offset = 0
+            for bucket in self.state['buckets']:
+                def maybe_setattr(name, val):
+                    if not hasattr(bucket, name):
+                        setattr(bucket, name, val)
+                maybe_setattr('bucket_size', self.default_shard_size * self.distributed_size)
+                maybe_setattr('shard_size', self.default_shard_size)
+                maybe_setattr('filled_size', bucket.bucket_size)
+                maybe_setattr('contiguous_buffer_offset', offset)
+                offset = max(
+                    offset,
+                    bucket.contiguous_buffer_offset + bucket.bucket_size,
+                )
+
         # Ignore parameters that have already been initialized
         params = [
             param
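
For reference, a minimal standalone sketch of the backfill idea in the hunk above. The `backfill_buckets` helper, the `DEFAULT_SHARD_SIZE`/`DISTRIBUTED_SIZE` constants, and the `SimpleNamespace` bucket stand-ins are all illustrative assumptions, not the optimizer's real classes or API; the point is only to show how attributes missing from an old checkpoint's buckets get filled in, and how `offset` tracks where the next bucket's contiguous buffer would start.

```python
# Hedged sketch of the checkpoint-backfill logic; hypothetical names throughout.
from types import SimpleNamespace

DEFAULT_SHARD_SIZE = 4   # assumed shard size for this sketch
DISTRIBUTED_SIZE = 2     # assumed world size for this sketch

def backfill_buckets(buckets):
    """Fill in attributes that an older optimizer state did not record."""
    offset = 0
    for bucket in buckets:
        def maybe_setattr(name, val):
            # Only set the attribute if the old checkpoint lacks it.
            if not hasattr(bucket, name):
                setattr(bucket, name, val)
        maybe_setattr('bucket_size', DEFAULT_SHARD_SIZE * DISTRIBUTED_SIZE)
        maybe_setattr('shard_size', DEFAULT_SHARD_SIZE)
        maybe_setattr('filled_size', bucket.bucket_size)
        maybe_setattr('contiguous_buffer_offset', offset)
        # The next bucket starts after the end of this one in the contiguous buffer.
        offset = max(offset, bucket.contiguous_buffer_offset + bucket.bucket_size)
    return buckets

# Two "old-format" buckets with no size/offset metadata recorded.
old_buckets = [SimpleNamespace(), SimpleNamespace()]
for b in backfill_buckets(old_buckets):
    print(b.bucket_size, b.shard_size, b.filled_size, b.contiguous_buffer_offset)
# Prints: "8 4 8 0" then "8 4 8 8"
```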