@@ -109,7 +109,7 @@ def __init__(self):
         )
         cosine_scheduler = CosineAnnealingWarmRestarts(
             self.optimizer,
-            T_0=1000,
+            T_0=10000,
             T_mult=2,
             eta_min=self.hparams.learning_rate * 0.1,
         )
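The change above lengthens the first cosine cycle tenfold. With `T_mult=2`, each restart period doubles, so restarts land at cumulative steps 10000, 30000, 70000, and so on. A minimal sketch of how these settings behave (the parameter and optimizer are stand-ins, not the miner's own):

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Stand-in parameter and optimizer; the real ones come from the model and hparams.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-3)

# Same settings as the diff: first cycle of 10k steps, each later cycle twice as
# long, annealing down to 10% of the base learning rate.
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10000, T_mult=2, eta_min=1e-3 * 0.1)

for step in range(1, 30001):
    optimizer.step()
    scheduler.step()
    if step in (10000, 30000):
        # The lr snaps back to the base rate at each restart boundary.
        print(step, optimizer.param_groups[0]["lr"])
```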
@@ -137,6 +137,11 @@ def __init__(self):
             hparams=self.hparams,
         )

+        self.bucket = self.comms.get_own_bucket()
+        self.comms.try_commit(self.wallet, self.bucket)
+        self.comms.fetch_commitments()
+
+
         # Init peers
         if not self.config.peers:
             self.peers = self.comms.peers
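These three lines wire the miner into peer discovery at startup: resolve its own storage bucket, commit that bucket on chain via the wallet, then pull the commitments other peers have published. A rough sketch of the assumed contract (method names are from the diff; the descriptions are assumptions, not the real `tplr` implementation):

```python
# Illustrative stubs only — assumptions about what the comms calls above do.
class CommsSketch:
    def get_own_bucket(self):
        """Return this node's object-storage bucket used to serve its gradients/checkpoints."""
        ...

    def try_commit(self, wallet, bucket):
        """Publish the bucket's access info on chain, signed by the wallet,
        so validators and peers can locate this node's data."""
        ...

    def fetch_commitments(self):
        """Read every peer's committed bucket info so later get()/gather() calls
        know where to download from."""
        ...
```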
@@ -175,23 +180,49 @@ async def run(self):
         validator_uid, stake = self.comms.get_highest_stake_validator()
         if stake > 0:
             try:
-                state_dict = await self.comms.get(
-                    uid=str(validator_uid),
-                    window=self.current_window,
-                    key='checkpoint',
-                    timeout=240,
-                    local=False,
-                    stale_retention=10
-                )
-                if state_dict is not None:
-                    self.model.load_state_dict(state_dict)
-                    tplr.logger.info(f"Loaded checkpoint from validator {validator_uid} at window {self.current_window}")
+                # Calculate the most recent window that should have a checkpoint
+                expected_checkpoint_window = (self.current_window // self.hparams.checkpoint_frequency) * self.hparams.checkpoint_frequency
+
+                # Try last few windows in case of missed checkpoints
+                for window in range(expected_checkpoint_window, max(0, expected_checkpoint_window - 3 * self.hparams.checkpoint_frequency), -self.hparams.checkpoint_frequency):
+                    result = await self.comms.get(
+                        uid=str(validator_uid),
+                        window=window,
+                        key='checkpoint',
+                        timeout=240,
+                        local=False,
+                        stale_retention=10
+                    )
+                    if result is None:
+                        tplr.logger.debug(f"No checkpoint found for window {window}")
+                        continue
+
+                    checkpoint_data, global_step = result
+                    try:
+                        # Load state dicts from dictionary
+                        self.model.load_state_dict(checkpoint_data['model_state_dict'])
+                        self.optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])
+                        self.scheduler.load_state_dict(checkpoint_data['scheduler_state_dict'])
+                        self.momentum = checkpoint_data['momentum']
+                        self.global_step = checkpoint_data['global_step']
+
+                        # Update optimizer and scheduler steps to match
+                        self.optimizer._step_count = self.global_step
+                        self.scheduler.last_epoch = self.global_step
+
+                        tplr.logger.info(f"Loaded checkpoint from validator {validator_uid} at window {window}, global_step={self.global_step}")
+                        break  # Successfully loaded checkpoint, exit loop
+                    except KeyError as e:
+                        tplr.logger.error(f"Invalid checkpoint format: missing key {e}")
+                    except Exception as e:
+                        tplr.logger.error(f"Failed to load checkpoint: {e}")
                 else:
-                    tplr.logger.info("No checkpoint found, starting from scratch")
+                    tplr.logger.info("No valid checkpoints found in recent windows")
             except Exception as e:
                 tplr.logger.warning(f"Failed to load checkpoint: {e}")
         else:
             tplr.logger.info("No active validators found, starting from scratch")
+            self.global_step = 0

         # Start background block listener
         self.loop = asyncio.get_running_loop()
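The back-off loop probes at most three checkpoint boundaries, newest first. For example, with a checkpoint_frequency of 100 and current_window 457, it rounds down to 400 and then tries windows 400, 300, and 200 before falling through to the for-else. A standalone sketch of that window arithmetic (a hypothetical helper mirroring the diff's formula):

```python
def checkpoint_windows_to_probe(current_window: int, checkpoint_frequency: int) -> list[int]:
    # Round down to the newest expected boundary, then step back one period at a
    # time, trying at most three boundaries and never going below window 0.
    expected = (current_window // checkpoint_frequency) * checkpoint_frequency
    return list(range(expected, max(0, expected - 3 * checkpoint_frequency), -checkpoint_frequency))

assert checkpoint_windows_to_probe(457, 100) == [400, 300, 200]
assert checkpoint_windows_to_probe(150, 100) == [100]
```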
@@ -200,6 +231,7 @@ async def run(self):
             args=(self.loop,),
             daemon=True,
         ).start()
+        self.comms.start_commitment_fetcher()

         while True:
             step_window = self.current_window
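`start_commitment_fetcher` presumably keeps the peer-commitment table fresh while the training loop runs. A minimal sketch of that background-refresh pattern (the interval, task handling, and method body are assumptions, not the real `tplr` code):

```python
import asyncio

class FetcherSketch:
    """Assumed shape of the refresher start_commitment_fetcher spawns."""

    def fetch_commitments(self) -> None:
        ...  # re-read peers' on-chain bucket commitments

    def start_commitment_fetcher(self, refresh_seconds: float = 60.0) -> None:
        # Must be called from a running event loop, as in run() above.
        asyncio.get_running_loop().create_task(self._refresh_forever(refresh_seconds))

    async def _refresh_forever(self, refresh_seconds: float) -> None:
        while True:
            self.fetch_commitments()
            await asyncio.sleep(refresh_seconds)
```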
@@ -317,20 +349,37 @@ async def run(self):
                 xshapes[n] = xshape
                 totalks[n] = totalk

-            # All-gather share state from peers
+            # Gather gradients from peers
             tplr.logger.info(f"Start gather: {self.peers}")
             gather_result = await self.comms.gather(
                 state_dict=gradient,
                 my_uid=self.uid,
                 uids=self.peers,
                 window=step_window,
                 key='gradient',
-                timeout=5,
+                timeout=20,
                 device=self.config.device,
                 local=False,
-                stale_retention=10
+                stale_retention=10,
+                global_step=self.global_step,
             )
-
+
+            if gather_result is None:
+                tplr.logger.error("Failed to gather gradients from peers. Waiting for next window.")
+                # Wait for next window
+                while self.current_window == step_window:
+                    await asyncio.sleep(0.1)
+                continue  # Proceed to the next window
+
+            # Update self.global_step based on the maximum global_step received
+            max_global_step = max(gather_result.global_steps + [self.global_step])
+            if max_global_step > self.global_step:
+                tplr.logger.info(f"Updating global_step from {self.global_step} to {max_global_step}")
+                self.global_step = max_global_step
+                # Update optimizer and scheduler steps
+                self.optimizer._step_count = self.global_step
+                self.scheduler.last_epoch = self.global_step
+
             # Decompress state and apply to grad.
             for n, p in self.model.named_parameters():
                 idxs_key = n + 'idxs'
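The gather now also exchanges each peer's global_step, and the miner fast-forwards to the highest step it sees so its optimizer and LR schedule stay aligned with the swarm. A toy illustration of that synchronization rule (SimpleNamespace stands in for the real gather result; only the field used here is modeled):

```python
from types import SimpleNamespace

# Stand-in for the object comms.gather() returns.
gather_result = SimpleNamespace(global_steps=[118, 121, 120])

global_step = 119
max_global_step = max(gather_result.global_steps + [global_step])
if max_global_step > global_step:
    # Adopt the most advanced step so the schedule matches the furthest-along peer.
    global_step = max_global_step

print(global_step)  # 121
```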