From aeaf74d20d67ed648a7d6d16e661068d63629f3a Mon Sep 17 00:00:00 2001
From: Tim Moon <tmoon@nvidia.com>
Date: Thu, 2 Mar 2023 02:14:25 -0800
Subject: [PATCH] Hack to load old distopt checkpoints

Handles checkpoints generated before https://github.com/NVIDIA/apex/pull/1551
changed the per-bucket optimizer state. Buckets loaded from such checkpoints
may be missing the bucket_size, shard_size, filled_size, and
contiguous_buffer_offset attributes, so init_params() now backfills them with
reasonable defaults.
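
Illustrative usage (a minimal sketch, not part of the patch; it assumes
torch.distributed is already initialized, that `model` is an existing
torch.nn.Module, and that the checkpoint path is a placeholder):

    import torch
    from apex.contrib.optimizers.distributed_fused_adam import (
        DistributedFusedAdam,
    )

    # `model` and the checkpoint path below are placeholders for this sketch.
    optimizer = DistributedFusedAdam(model.parameters())

    # State saved by apex before #1551; its bucket entries may lack
    # bucket_size, shard_size, filled_size, and contiguous_buffer_offset.
    optimizer.load_state_dict(torch.load('old_distopt_checkpoint.pt'))

    # With this patch, init_params() backfills the missing bucket attributes
    # the next time parameters are initialized, instead of failing on the
    # absent fields.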
---
 .../optimizers/distributed_fused_adam.py       | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py
index 8b25e9bee..6edc1ef96 100644
--- a/apex/contrib/optimizers/distributed_fused_adam.py
+++ b/apex/contrib/optimizers/distributed_fused_adam.py
@@ -843,6 +843,24 @@ def init_params(self, params=None):
         elif isinstance(params, torch.Tensor):
             params = [params]
 
+        # Hack to load old checkpoint files
+        # Note: Handles the optimizer state changes made in
+        # https://github.com/NVIDIA/apex/pull/1551
+        if 'buckets' in self.state:
+            offset = 0
+            for bucket in self.state['buckets']:
+                def maybe_setattr(name, val):
+                    if not hasattr(bucket, name):
+                        setattr(bucket, name, val)
+                maybe_setattr('bucket_size', self.default_shard_size * self.distributed_size)
+                maybe_setattr('shard_size', self.default_shard_size)
+                maybe_setattr('filled_size', bucket.bucket_size)
+                maybe_setattr('contiguous_buffer_offset', offset)
+                offset = max(
+                    offset,
+                    bucket.contiguous_buffer_offset + bucket.bucket_size,
+                )
+
         # Ignore parameters that have already been initialized
         params = [
             param