Skip to content

Commit 26c516a

Browse files
Merge pull request #23 from tplr-ai/feat/validator_sync
Feat/validator sync
2 parents fe1bc1a + f4115cf commit 26c516a

File tree

11 files changed

+878
-359
lines changed

11 files changed

+878
-359
lines changed

docker/compose.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ services:
3535
image: containrrr/watchtower
3636
volumes:
3737
- /var/run/docker.sock:/var/run/docker.sock
38-
- ${HOME}/.docker/config.json:/config.json:ro
3938
command: --interval 30 --cleanup --label-enable
4039
restart: unless-stopped
4140
environment:

docker/docker-compose-test.yml

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
networks:
2+
test:
3+
driver: bridge
4+
15
services:
26
miner1:
37
build:
48
context: ..
59
dockerfile: docker/Dockerfile
610
container_name: templar-miner-M111
11+
networks:
12+
- test
713
volumes:
814
- ~/.bittensor/wallets:/root/.bittensor/wallets
915
- ../logs:/app/logs
@@ -27,22 +33,24 @@ services:
2733
reservations:
2834
devices:
2935
- driver: nvidia
30-
device_ids: [ '0', '1', '2' ]
36+
device_ids: [ '3' ]
3137
capabilities: [ gpu ]
3238

3339
miner2:
3440
build:
3541
context: ..
3642
dockerfile: docker/Dockerfile
3743
container_name: templar-miner-M222
44+
networks:
45+
- test
3846
volumes:
3947
- ~/.bittensor/wallets:/root/.bittensor/wallets
4048
- ../logs:/app/logs
4149
environment:
4250
NODE_TYPE: miner
4351
WALLET_NAME: Bistro
4452
WALLET_HOTKEY: M222
45-
CUDA_DEVICE: cuda:1
53+
CUDA_DEVICE: cuda:0
4654
NETWORK: test
4755
DEBUG: 'true'
4856
WANDB_API_KEY: ${WANDB_API_KEY}
@@ -58,36 +66,39 @@ services:
5866
reservations:
5967
devices:
6068
- driver: nvidia
61-
device_ids: [ '0', '1', '2' ]
69+
device_ids: [ '1' ]
6270
capabilities: [ gpu ]
6371

6472
validator:
6573
build:
6674
context: ..
6775
dockerfile: docker/Dockerfile
6876
container_name: templar-validator-V11
77+
networks:
78+
- test
6979
volumes:
7080
- ~/.bittensor/wallets:/root/.bittensor/wallets
7181
- ../logs:/app/logs
7282
environment:
7383
NODE_TYPE: validator
7484
WALLET_NAME: Bistro
7585
WALLET_HOTKEY: V11
76-
CUDA_DEVICE: cuda:2
86+
CUDA_DEVICE: cuda:0
7787
NETWORK: test
7888
DEBUG: 'true'
7989
WANDB_API_KEY: ${WANDB_API_KEY}
8090
NETUID: 268
81-
HOST_CUDA_VERSION : 12.6
91+
HOST_CUDA_VERSION: 12.6
8292
R2_ACCOUNT_ID: ${R2_ACCOUNT_ID}
8393
R2_READ_ACCESS_KEY_ID: ${R2_READ_ACCESS_KEY_ID}
8494
R2_READ_SECRET_ACCESS_KEY: ${R2_READ_SECRET_ACCESS_KEY}
85-
R2_WRITE_ACCESS_KEY_ID : ${R2_WRITE_ACCESS_KEY_ID}
86-
R2_WRITE_SECRET_ACCESS_KEY : ${R2_WRITE_SECRET_ACCESS_KEY}
95+
R2_WRITE_ACCESS_KEY_ID: ${R2_WRITE_ACCESS_KEY_ID}
96+
R2_WRITE_SECRET_ACCESS_KEY: ${R2_WRITE_SECRET_ACCESS_KEY}
97+
restart: always
8798
deploy:
8899
resources:
89100
reservations:
90101
devices:
91102
- driver: nvidia
92-
device_ids: [ '0', '1', '2' ]
103+
device_ids: [ '2' ]
93104
capabilities: [ gpu ]

hparams.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,8 @@
2424
"warmup_steps": 250,
2525
"alpha_f": 0.1,
2626
"t_max": 20000,
27-
"validator_offset": 4
27+
"validator_offset": 4,
28+
"checkpoint_frequency": 50,
29+
"topk_peers": 20,
30+
"minimum_peers": 5
2831
}

neurons/miner.py

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def __init__(self):
109109
)
110110
cosine_scheduler = CosineAnnealingWarmRestarts(
111111
self.optimizer,
112-
T_0=1000,
112+
T_0=10000,
113113
T_mult=2,
114114
eta_min=self.hparams.learning_rate * 0.1,
115115
)
@@ -137,6 +137,11 @@ def __init__(self):
137137
hparams=self.hparams,
138138
)
139139

140+
self.bucket = self.comms.get_own_bucket()
141+
self.comms.try_commit(self.wallet, self.bucket)
142+
self.comms.fetch_commitments()
143+
144+
140145
# Init peers
141146
if not self.config.peers:
142147
self.peers = self.comms.peers
@@ -175,23 +180,49 @@ async def run(self):
175180
validator_uid, stake = self.comms.get_highest_stake_validator()
176181
if stake > 0:
177182
try:
178-
state_dict = await self.comms.get(
179-
uid=str(validator_uid),
180-
window=self.current_window,
181-
key='checkpoint',
182-
timeout=240,
183-
local=False,
184-
stale_retention=10
185-
)
186-
if state_dict is not None:
187-
self.model.load_state_dict(state_dict)
188-
tplr.logger.info(f"Loaded checkpoint from validator {validator_uid} at window {self.current_window}")
183+
# Calculate the most recent window that should have a checkpoint
184+
expected_checkpoint_window = (self.current_window // self.hparams.checkpoint_frequency) * self.hparams.checkpoint_frequency
185+
186+
# Try last few windows in case of missed checkpoints
187+
for window in range(expected_checkpoint_window, max(0, expected_checkpoint_window - 3 * self.hparams.checkpoint_frequency), -self.hparams.checkpoint_frequency):
188+
result = await self.comms.get(
189+
uid=str(validator_uid),
190+
window=window,
191+
key='checkpoint',
192+
timeout=240,
193+
local=False,
194+
stale_retention=10
195+
)
196+
if result is None:
197+
tplr.logger.debug(f"No checkpoint found for window {window}")
198+
continue
199+
200+
checkpoint_data, global_step = result
201+
try:
202+
# Load state dicts from dictionary
203+
self.model.load_state_dict(checkpoint_data['model_state_dict'])
204+
self.optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])
205+
self.scheduler.load_state_dict(checkpoint_data['scheduler_state_dict'])
206+
self.momentum = checkpoint_data['momentum']
207+
self.global_step = checkpoint_data['global_step']
208+
209+
# Update optimizer and scheduler steps to match
210+
self.optimizer._step_count = self.global_step
211+
self.scheduler.last_epoch = self.global_step
212+
213+
tplr.logger.info(f"Loaded checkpoint from validator {validator_uid} at window {window}, global_step={self.global_step}")
214+
break # Successfully loaded checkpoint, exit loop
215+
except KeyError as e:
216+
tplr.logger.error(f"Invalid checkpoint format: missing key {e}")
217+
except Exception as e:
218+
tplr.logger.error(f"Failed to load checkpoint: {e}")
189219
else:
190-
tplr.logger.info("No checkpoint found, starting from scratch")
220+
tplr.logger.info("No valid checkpoints found in recent windows")
191221
except Exception as e:
192222
tplr.logger.warning(f"Failed to load checkpoint: {e}")
193223
else:
194224
tplr.logger.info("No active validators found, starting from scratch")
225+
self.global_step = 0
195226

196227
# Start background block listener
197228
self.loop = asyncio.get_running_loop()
@@ -200,6 +231,7 @@ async def run(self):
200231
args=(self.loop,),
201232
daemon=True,
202233
).start()
234+
self.comms.start_commitment_fetcher()
203235

204236
while True:
205237
step_window = self.current_window
@@ -317,20 +349,37 @@ async def run(self):
317349
xshapes[n] = xshape
318350
totalks[n] = totalk
319351

320-
# All-gather share state from peers
352+
# Gather gradients from peers
321353
tplr.logger.info(f"Start gather: {self.peers}")
322354
gather_result = await self.comms.gather(
323355
state_dict=gradient,
324356
my_uid=self.uid,
325357
uids=self.peers,
326358
window=step_window,
327359
key='gradient',
328-
timeout=5,
360+
timeout=20,
329361
device=self.config.device,
330362
local=False,
331-
stale_retention=10
363+
stale_retention=10,
364+
global_step=self.global_step,
332365
)
333-
366+
367+
if gather_result is None:
368+
tplr.logger.error("Failed to gather gradients from peers. Waiting for next window.")
369+
# Wait for next window
370+
while self.current_window == step_window:
371+
await asyncio.sleep(0.1)
372+
continue # Proceed to the next window
373+
374+
# Update self.global_step based on the maximum global_step received
375+
max_global_step = max(gather_result.global_steps + [self.global_step])
376+
if max_global_step > self.global_step:
377+
tplr.logger.info(f"Updating global_step from {self.global_step} to {max_global_step}")
378+
self.global_step = max_global_step
379+
# Update optimizer and scheduler steps
380+
self.optimizer._step_count = self.global_step
381+
self.scheduler.last_epoch = self.global_step
382+
334383
# Decompress state and apply to grad.
335384
for n, p in self.model.named_parameters():
336385
idxs_key = n + 'idxs'

0 commit comments

Comments (0)