@@ -168,50 +168,6 @@ def __init__(self):

    # Main training loop.
    async def run(self):
-        # Try to load latest checkpoint
-        result = await self.comms.get_latest_checkpoint()
-        if result:
-            checkpoint_data, window = result
-            try:
-                # Load state dicts from checkpoint data
-                self.model.load_state_dict({k: v.to(self.config.device) for k, v in checkpoint_data['model_state_dict'].items()})
-                self.model.to(self.config.device)
-
-                # Load optimizer state
-                for state in self.optimizer.state.values():
-                    for k, v in state.items():
-                        if torch.is_tensor(v):
-                            state[k] = v.to(self.config.device)
-                self.optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])
-
-                # Load scheduler state
-                self.scheduler.load_state_dict(checkpoint_data['scheduler_state_dict'])
-
-                # Load momentum and global_step
-                self.momentum = checkpoint_data['momentum']
-                self.global_step = checkpoint_data['global_step']
-
-                # Adjust scheduler to catch up with current window
-                checkpoint_window = checkpoint_data.get('checkpoint_window', None)
-                if checkpoint_window is not None:
-                    window_difference = self.current_window - checkpoint_window
-                    if window_difference > 0:
-                        for _ in range(window_difference):
-                            self.scheduler.step()
-                        tplr.logger.info(f"Stepped scheduler {window_difference} times to catch up with current window {self.current_window}")
-                else:
-                    tplr.logger.warning("Checkpoint does not contain 'checkpoint_window'; cannot adjust scheduler")
-
-                tplr.logger.info(f"Loaded checkpoint from window {window}, global_step={self.global_step}")
-            except KeyError as e:
-                tplr.logger.error(f"Invalid checkpoint format: missing key {e}")
-            except Exception as e:
-                tplr.logger.error(f"Failed to load checkpoint: {e}")
-        else:
-            tplr.logger.info("No valid checkpoints found, starting from scratch")
-            self.global_step = 0
-            self.model.to(self.config.device)
-
        # Load Peers
        if not self.config.peers:
            self.peers = self.comms.peers
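
The removed block above handled checkpoint restore inline, including the scheduler catch-up that fast-forwards the learning-rate schedule by the number of windows elapsed since the checkpoint was written; the comms.load_checkpoint call added in the next hunk presumably absorbs that responsibility. A minimal standalone sketch of the catch-up logic (helper name hypothetical, not part of tplr):

    # Hypothetical helper: advance the LR scheduler one step per elapsed window
    # so a restored run resumes on the same schedule it would have followed.
    def catch_up_scheduler(scheduler, checkpoint_window: int, current_window: int) -> None:
        window_difference = current_window - checkpoint_window
        for _ in range(max(0, window_difference)):
            scheduler.step()

    # Example: a checkpoint written at window 100 and restored at window 105
    # advances the scheduler by 5 extra steps.
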
@@ -221,6 +177,31 @@ async def run(self):
        if self.uid not in self.peers:
            self.peers.append(self.uid)

+        self.comms.commitments = self.comms.get_commitments_sync()
+        self.comms.update_peers_with_buckets()
+        tplr.logger.info(f"Loaded commitments: {self.comms.commitments.keys()}")
+
+        success, loaded_momentum, loaded_global_step = await self.comms.load_checkpoint(
+            model=self.model,
+            optimizer=self.optimizer,
+            scheduler=self.scheduler,
+            transformer=self.transformer,
+            compressor=self.compressor,
+            current_window=self.current_window,
+            device=self.config.device,
+            peers=self.peers,
+            uid=self.uid
+        )
+        if success:
+            self.momentum = loaded_momentum
+            self.global_step = loaded_global_step
+            tplr.logger.info(f"Loaded checkpoint with global_step={self.global_step}")
+        else:
+            tplr.logger.info("Starting from scratch")
+            self.global_step = 0
+            self.momentum = {n: torch.zeros_like(p) for n, p in self.model.named_parameters()}
+            self.model.to(self.config.device)
+
        # Start background block listener
        self.loop = asyncio.get_running_loop()
        self.listener = threading.Thread(
@@ -233,13 +214,13 @@ async def run(self):
        self.comms.start_background_tasks()

        while True:
+            # 1. Initialize window and update peers
            step_window = self.current_window
            tplr.logger.info(f"\n{'-' * 40} Window: {step_window} {'-' * 40}")
-            # self.comms.update_peers_with_buckets()
-            # Update local references
+            self.comms.update_peers_with_buckets()
            self.peers = self.comms.peers

-            # Get the pages for this window.
+            # 2. Load training data for this window
            pages = await tplr.dataset.DatasetLoader.next_pages(
                offset=step_window,
                n_pages=self.hparams.pages_per_window,
@@ -253,7 +234,7 @@ async def run(self):
            )
            tplr.logger.info(f"Pages: {[p[1] for p in pages]} for Window: {step_window}")

-            # Accumulate gradient
+            # 3. Accumulate gradients over batches
            start_time = time.time()
            tplr.logger.info("Start accumulating...")
            self.optimizer.zero_grad()
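
Between this hunk and the next, the unchanged code turns the pages fetched above into a batch loader that drives the accumulation loop. A rough sketch of that step, assuming a DatasetLoader.create factory with these keyword arguments (the exact signature is an assumption, not confirmed by this diff):

    # Sketch (assumed signature): build the per-window loader from the pages above.
    loader = await tplr.dataset.DatasetLoader.create(
        batch_size=self.hparams.batch_size,
        sequence_length=self.hparams.sequence_length,
        pages_info=pages,
        tokenizer=self.tokenizer,
    )
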
@@ -272,26 +253,27 @@ async def run(self):
                total_loss += outputs.loss.item()
                outputs.loss.backward()

-                # Track tokens
                batch_tokens += (labels != -100).sum().item()
-
+                # TODO: INCREASE LENGTH OF THE WINDOW
                tplr.logger.info(f'loss: {outputs.loss.item()}')
                if self.current_window != step_window:
                    tplr.logger.info('<Exhausted window>')
                    break
+
+            # 4. Wait for next window
+            tplr.logger.info("Wait for next window...")
+            while self.current_window == step_window:
+                await asyncio.sleep(0.1)
            tplr.logger.info(f"Stopped accumulating: {i + 1} batches with {(i + 1) * self.hparams.batch_size * self.hparams.sequence_length} tokens")

-            # Calculate processing metrics
+            # 5. Calculate and log metrics
            duration = time.time() - start_time
            self.batch_times.append(duration)
            self.total_tokens_processed += batch_tokens

-            # Log gradient metrics
            grad_norms = [p.grad.norm().item() for p in self.model.parameters() if p.grad is not None]
            weight_norms = [p.norm().item() for p in self.model.parameters()]
            momentum_norms = [m.norm().item() for m in self.momentum.values()]
-
-            # Enhanced wandb logging with all metrics
            self.wandb.log({
                # Training metrics
                "miner/loss": total_loss / (i + 1),
@@ -321,7 +303,7 @@ async def run(self):
                "miner/mean_momentum_norm": sum(momentum_norms) / len(momentum_norms),
            }, step=self.global_step)

-            # Reduce gradient using DeMo.
+            # 6. Prepare gradients for sharing using DeMo compression
            gradient = {}
            xshapes = {}
            totalks = {}
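
The step-6 comment above heads the per-parameter DeMo compression loop, most of which is unchanged and therefore omitted from this diff; only its last two lines appear at the top of the next hunk. A rough sketch of that loop, with the exact transformer/compressor method names treated as assumptions rather than confirmed tplr API:

    # Sketch (assumed API): fold each gradient into momentum, compress the
    # momentum, subtract the transmitted estimate, and record decode metadata.
    lr = self.scheduler.get_last_lr()[0]  # current learning rate (assumed source)
    for n, p in self.model.named_parameters():
        self.momentum[n].mul_(self.hparams.momentum_decay)
        self.momentum[n].add_(p.grad, alpha=lr)
        idxs, vals, xshape, totalk = self.compressor.compress(
            self.transformer.encode(self.momentum[n]), self.hparams.topk_compression
        )
        transmit_grad = self.transformer.decode(
            self.compressor.decompress(p, idxs, vals, xshape, totalk)
        )
        self.momentum[n].sub_(transmit_grad)  # keep only the untransmitted residual
        gradient[n + 'idxs'], gradient[n + 'vals'] = idxs, vals
        xshapes[n], totalks[n] = xshape, totalk
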
@@ -351,35 +333,35 @@ async def run(self):
                xshapes[n] = xshape
                totalks[n] = totalk

-            # Gather gradients from peers
+            # 7. Gather and process peer gradients
            tplr.logger.info(f"Start gather: {self.peers}")
            gather_result = await self.comms.gather(
                state_dict=gradient,
                my_uid=self.uid,
                uids=self.peers,
                window=step_window,
                key='gradient',
-                timeout=5,
+                timeout=30,
                device=self.config.device,
                local=False,
-                stale_retention=10,
+                stale_retention=100,
                global_step=self.global_step,
            )

            if gather_result is None:
                tplr.logger.error("Failed to gather gradients from peers. Waiting for next window.")
-                # Wait for next window
                while self.current_window == step_window:
                    await asyncio.sleep(0.1)
-                continue  # Proceed to the next window
+                continue

-            # Update self.global_step based on the maximum global_step received
+            # 8. Update global step based on peer information
            max_global_step = max(gather_result.global_steps + [self.global_step])
+            tplr.logger.info(f"Gather global steps: {gather_result.global_steps}")
            if max_global_step > self.global_step:
                tplr.logger.info(f"Updating global_step from {self.global_step} to {max_global_step}")
                self.global_step = max_global_step

-            # Decompress state and apply to grad.
+            # 9. Apply gathered gradients
            for n, p in self.model.named_parameters():
                idxs_key = n + 'idxs'
                vals_key = n + 'vals'
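
Just below these key definitions, the unchanged code (outside this diff) looks up each parameter's gathered idxs/vals and decodes them into the dense new_grad used in the next hunk. A rough sketch of that decode step, with the decompression call treated as an assumption:

    # Sketch (assumed API): decode gathered (idxs, vals) back into a dense
    # gradient using the shapes recorded at compression time.
    idxs = getattr(gather_result.state_dict, idxs_key, None)
    vals = getattr(gather_result.state_dict, vals_key, None)
    if idxs is not None and vals is not None:
        new_grad = self.transformer.decode(
            self.compressor.batch_decompress(
                p.to(self.config.device), idxs, vals, xshapes[n], totalks[n]
            )
        )
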
@@ -406,27 +388,19 @@ async def run(self):
                        p.grad = new_grad
                    else:
                        p.grad.copy_(new_grad)
-                        # Sign-SGD
+                    # Sign-SGD
                    p.grad.sign_()
                else:
                    tplr.logger.info(f"Gradient data missing for parameter {n}, skipping.")

-
-
-            # Apply optimizer step
+            # 10. Optimization step
            tplr.logger.info("Finish and step.")
            self.optimizer.step()
            self.scheduler.step()
            self.global_step += 1
            self.window_step += 1
            tplr.logger.info(f"Total optimization steps: {self.global_step}")

-            # Wait for next window
-            tplr.logger.info("Wait for next window...")
-            while self.current_window == step_window:
-                await asyncio.sleep(0.1)
-            self.window_step = 0
-
    # Listens for new blocks and sets self.current_block and self.current_window
    def block_listener(self, loop):
        def handler(event, _u, _s):