
Commit cc99d2a

Merge pull request #32 from tplr-ai/feat/vali_fix
Feat/vali fix
2 parents: 0c53fce + 036ea74

7 files changed: 498 additions & 351 deletions

7 files changed

+498
-351
lines changed

hparams.json

Lines changed: 3 additions & 3 deletions
@@ -2,10 +2,10 @@
 "spec_version": 5,
 "project": "dough",
 "sequence_length": 2048,
-"pages_per_window": 5,
+"pages_per_window": 10,
 "batch_size": 6,
 "learning_rate": 4e-4,
-"blocks_per_window": 4,
+"blocks_per_window": 6,
 "windows_per_sync": 100,
 "windows_per_weights": 100,
 "momentum_decay": 0.999,
@@ -24,7 +24,7 @@
 "warmup_steps": 250,
 "alpha_f": 0.1,
 "t_max": 20000,
-"validator_offset": 2,
+"validator_offset": 4,
 "checkpoint_frequency": 50,
 "topk_peers": 20,
 "minimum_peers": 5,

neurons/miner.py

Lines changed: 46 additions & 72 deletions
@@ -168,50 +168,6 @@ def __init__(self):
 
     # Main training loop.
     async def run(self):
-        # Try to load latest checkpoint
-        result = await self.comms.get_latest_checkpoint()
-        if result:
-            checkpoint_data, window = result
-            try:
-                # Load state dicts from checkpoint data
-                self.model.load_state_dict({k: v.to(self.config.device) for k,v in checkpoint_data['model_state_dict'].items()})
-                self.model.to(self.config.device)
-
-                # Load optimizer state
-                for state in self.optimizer.state.values():
-                    for k, v in state.items():
-                        if torch.is_tensor(v):
-                            state[k] = v.to(self.config.device)
-                self.optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])
-
-                # Load scheduler state
-                self.scheduler.load_state_dict(checkpoint_data['scheduler_state_dict'])
-
-                # Load momentum and global_step
-                self.momentum = checkpoint_data['momentum']
-                self.global_step = checkpoint_data['global_step']
-
-                # Adjust scheduler to catch up with current window
-                checkpoint_window = checkpoint_data.get('checkpoint_window', None)
-                if checkpoint_window is not None:
-                    window_difference = self.current_window - checkpoint_window
-                    if window_difference > 0:
-                        for _ in range(window_difference):
-                            self.scheduler.step()
-                        tplr.logger.info(f"Stepped scheduler {window_difference} times to catch up with current window {self.current_window}")
-                else:
-                    tplr.logger.warning("Checkpoint does not contain 'checkpoint_window'; cannot adjust scheduler")
-
-                tplr.logger.info(f"Loaded checkpoint from window {window}, global_step={self.global_step}")
-            except KeyError as e:
-                tplr.logger.error(f"Invalid checkpoint format: missing key {e}")
-            except Exception as e:
-                tplr.logger.error(f"Failed to load checkpoint: {e}")
-        else:
-            tplr.logger.info("No valid checkpoints found, starting from scratch")
-            self.global_step = 0
-            self.model.to(self.config.device)
-
         # Load Peers
         if not self.config.peers:
             self.peers = self.comms.peers
@@ -221,6 +177,31 @@ async def run(self):
         if self.uid not in self.peers:
             self.peers.append(self.uid)
 
+        self.comms.commitments = self.comms.get_commitments_sync()
+        self.comms.update_peers_with_buckets()
+        tplr.logger.info(f"Loaded commitments: {self.comms.commitments.keys()}")
+
+        success, loaded_momentum, loaded_global_step = await self.comms.load_checkpoint(
+            model=self.model,
+            optimizer=self.optimizer,
+            scheduler=self.scheduler,
+            transformer=self.transformer,
+            compressor=self.compressor,
+            current_window=self.current_window,
+            device=self.config.device,
+            peers=self.peers,
+            uid=self.uid
+        )
+        if success:
+            self.momentum = loaded_momentum
+            self.global_step = loaded_global_step
+            tplr.logger.info(f"Loaded checkpoint with global_step={self.global_step}")
+        else:
+            tplr.logger.info("Starting from scratch")
+            self.global_step = 0
+            self.momentum = {n: torch.zeros_like(p) for n, p in self.model.named_parameters()}
+        self.model.to(self.config.device)
+
         # Start background block listener
         self.loop = asyncio.get_running_loop()
         self.listener = threading.Thread(
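
The checkpoint-restoration logic deleted from run() above is now delegated to comms.load_checkpoint. A rough sketch of the behaviour the miner relies on, reconstructed from the removed block; the real helper lives in tplr's comms module and its internals may differ:

async def load_checkpoint(self, model, optimizer, scheduler, transformer,
                          compressor, current_window, device, peers, uid):
    # Sketch only: transformer, compressor, peers and uid are part of the real
    # signature but are not exercised here.
    result = await self.get_latest_checkpoint()
    if not result:
        return False, None, None
    checkpoint, checkpoint_window = result
    model.load_state_dict(
        {k: v.to(device) for k, v in checkpoint['model_state_dict'].items()})
    model.to(device)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # Fast-forward the scheduler so the learning rate matches the current window.
    for _ in range(max(0, current_window - checkpoint_window)):
        scheduler.step()
    return True, checkpoint['momentum'], checkpoint['global_step']

On a failed load the new code also explicitly reinitialises momentum to zero tensors, which the removed inline block left untouched.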
@@ -233,13 +214,13 @@ async def run(self):
         self.comms.start_background_tasks()
 
         while True:
+            # 1. Initialize window and update peers
             step_window = self.current_window
             tplr.logger.info(f"\n{'-' * 40} Window: {step_window} {'-' * 40}")
-            # self.comms.update_peers_with_buckets()
-            # Update local references
+            self.comms.update_peers_with_buckets()
             self.peers = self.comms.peers
 
-            # Get the pages for this window.
+            # 2. Load training data for this window
             pages = await tplr.dataset.DatasetLoader.next_pages(
                 offset = step_window,
                 n_pages = self.hparams.pages_per_window,
@@ -253,7 +234,7 @@ async def run(self):
             )
             tplr.logger.info(f"Pages: {[p[1] for p in pages]} for Window: {step_window}")
 
-            # Accumulate gradient
+            # 3. Accumulate gradients over batches
             start_time = time.time()
             tplr.logger.info("Start accumulating...")
             self.optimizer.zero_grad()
@@ -272,26 +253,27 @@ async def run(self):
                 total_loss += outputs.loss.item()
                 outputs.loss.backward()
 
-                # Track tokens
                 batch_tokens += (labels != -100).sum().item()
-
+                # TODO: INCREASE LENGHT OF THE WINDOW
                 tplr.logger.info(f'loss: {outputs.loss.item()}')
                 if self.current_window != step_window:
                     tplr.logger.info('<Exhausted window>')
                     break
+
+            # 4. Wait for next window
+            tplr.logger.info("Wait for next window...")
+            while self.current_window == step_window:
+                await asyncio.sleep(0.1)
             tplr.logger.info(f"Stopped accumulating: {i+1} batches with {(i+1) * self.hparams.batch_size * self.hparams.sequence_length} tokens")
 
-            # Calculate processing metrics
+            # 5. Calculate and log metrics
             duration = time.time() - start_time
             self.batch_times.append(duration)
             self.total_tokens_processed += batch_tokens
 
-            # Log gradient metrics
             grad_norms = [p.grad.norm().item() for p in self.model.parameters() if p.grad is not None]
             weight_norms = [p.norm().item() for p in self.model.parameters()]
             momentum_norms = [m.norm().item() for m in self.momentum.values()]
-
-            # Enhanced wandb logging with all metrics
             self.wandb.log({
                 # Training metrics
                 "miner/loss": total_loss/(i+1),
@@ -321,7 +303,7 @@ async def run(self):
                 "miner/mean_momentum_norm": sum(momentum_norms) / len(momentum_norms),
             }, step=self.global_step)
 
-            # Reduce gradient using DeMo.
+            # 6. Prepare gradients for sharing using DeMo compression
             gradient = {}
             xshapes = {}
             totalks = {}
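
The compression loop that fills these three dicts is unchanged and therefore not part of the diff. Conceptually, each parameter's gradient is reduced to a small set of indices and values plus the shape metadata (xshape, totalk) needed to rebuild a dense tensor on the receiving side. A schematic of that idea, not the actual tplr transformer/compressor API (DeMo works in a transformed domain rather than directly on the raw gradient):

import torch

def topk_compress(grad: torch.Tensor, k: int):
    # Schematic: keep only the k largest-magnitude entries of the flattened gradient.
    flat = grad.flatten()
    k = min(k, flat.numel())
    idxs = torch.topk(flat.abs(), k).indices
    return idxs, flat[idxs], tuple(grad.shape), flat.numel()  # idxs, vals, xshape, totalk

def topk_decompress(idxs, vals, xshape, totalk, device="cpu"):
    # Schematic inverse: scatter the kept values back into a dense tensor.
    flat = torch.zeros(totalk, device=device)
    flat[idxs.to(device)] = vals.to(device)
    return flat.reshape(xshape)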
@@ -351,35 +333,35 @@ async def run(self):
                 xshapes[n] = xshape
                 totalks[n] = totalk
 
-            # Gather gradients from peers
+            # 7. Gather and process peer gradients
             tplr.logger.info(f"Start gather: {self.peers}")
             gather_result = await self.comms.gather(
                 state_dict=gradient,
                 my_uid=self.uid,
                 uids=self.peers,
                 window=step_window,
                 key='gradient',
-                timeout=5,
+                timeout=30,
                 device=self.config.device,
                 local=False,
-                stale_retention=10,
+                stale_retention=100,
                 global_step=self.global_step,
             )
 
             if gather_result is None:
                 tplr.logger.error("Failed to gather gradients from peers. Waiting for next window.")
-                # Wait for next window
                 while self.current_window == step_window:
                     await asyncio.sleep(0.1)
-                continue # Proceed to the next window
+                continue
 
-            # Update self.global_step based on the maximum global_step received
+            # 8. Update global step based on peer information
             max_global_step = max(gather_result.global_steps + [self.global_step])
+            tplr.logger.info(f"Gather global steps : {gather_result.global_steps}")
             if max_global_step > self.global_step:
                 tplr.logger.info(f"Updating global_step from {self.global_step} to {max_global_step}")
                 self.global_step = max_global_step
 
-            # Decompress state and apply to grad.
+            # 9. Apply gathered gradients
             for n, p in self.model.named_parameters():
                 idxs_key = n + 'idxs'
                 vals_key = n + 'vals'
@@ -406,27 +388,19 @@ async def run(self):
                         p.grad = new_grad
                     else:
                         p.grad.copy_(new_grad)
-                        # Sign-SGD
+                    # Sign-SGD
                     p.grad.sign_()
                 else:
                     tplr.logger.info(f"Gradient data missing for parameter {n}, skipping.")
 
-
-
-            # Apply optimizer step
+            # 10. Optimization step
             tplr.logger.info("Finish and step.")
             self.optimizer.step()
             self.scheduler.step()
             self.global_step += 1
             self.window_step += 1
             tplr.logger.info(f"Total optimization steps: {self.global_step}")
 
-            # Wait for next window
-            tplr.logger.info("Wait for next window...")
-            while self.current_window == step_window:
-                await asyncio.sleep(0.1)
-            self.window_step = 0
-
     # Listens for new blocks and sets self.current_block and self.current_window
     def block_listener(self, loop):
         def handler(event, _u, _s):
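
For context on steps 9 and 10: each gathered, compressed gradient is decompressed into new_grad, written into p.grad, and reduced to its sign before the optimizer step (sign-SGD). A condensed, runnable sketch of that application, reusing the schematic topk_compress/topk_decompress above; the model, optimizer, and gathered dict below are hypothetical stand-ins for the miner's state and for gather_result.state_dict, which batches per-peer idxs/vals and differs in detail:

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
device = "cpu"

# Pretend these came from the compression step and from comms.gather().
xshapes, totalks, gathered = {}, {}, {}
for n, p in model.named_parameters():
    idxs, vals, xshapes[n], totalks[n] = topk_compress(torch.randn_like(p), k=2)
    gathered[n + 'idxs'], gathered[n + 'vals'] = idxs, vals

# Step 9: decompress each gathered gradient, write it into p.grad, keep only its sign.
for n, p in model.named_parameters():
    new_grad = topk_decompress(gathered[n + 'idxs'], gathered[n + 'vals'],
                               xshapes[n], totalks[n], device)
    if p.grad is None:
        p.grad = new_grad
    else:
        p.grad.copy_(new_grad)
    p.grad.sign_()  # sign-SGD: only the gradient's sign drives the update

# Step 10: optimization step (the miner also steps its LR scheduler here).
optimizer.step()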
