Skip to content

Commit 60ccce5

Browse files
committed
update
1 parent 1649475 commit 60ccce5

File tree

2 files changed

+76
-61
lines changed

2 files changed

+76
-61
lines changed

.github/workflows/amd-health.yml

Lines changed: 1 addition & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ on:
66
- cron: '0 2 * * *'
77
workflow_dispatch:
88
push:
9-
branches: [main]
109

1110
jobs:
1211
health-check:
@@ -28,64 +27,5 @@ jobs:
2827

2928
- name: Distributed Health Check
3029
run: |
31-
# Check how many GPUs are available
3230
python -c "import torch; print(f'Available GPUs: {torch.cuda.device_count()}')"
33-
34-
# Test process group initialization in a loop to debug hanging issues
35-
python -c "
36-
import torch
37-
import torch.distributed as dist
38-
import os
39-
import time
40-
import signal
41-
42-
def timeout_handler(signum, frame):
43-
print('✗ Process group initialization timed out after 30 seconds')
44-
exit(1)
45-
46-
# Set timeout for process group initialization
47-
signal.signal(signal.SIGALRM, timeout_handler)
48-
49-
num_gpus = torch.cuda.device_count()
50-
print(f'Testing process group initialization on {num_gpus} GPUs')
51-
52-
for attempt in range(3): # Try 3 times
53-
try:
54-
print(f'Attempt {attempt + 1}: Initializing process group...')
55-
56-
# Set environment variables
57-
os.environ['MASTER_ADDR'] = '127.0.0.1'
58-
os.environ['MASTER_PORT'] = str(12345 + attempt)
59-
os.environ['WORLD_SIZE'] = '1'
60-
os.environ['RANK'] = '0'
61-
62-
# Set 30 second timeout
63-
signal.alarm(30)
64-
65-
# Test single-process initialization first
66-
dist.init_process_group('nccl', rank=0, world_size=1)
67-
68-
# Cancel timeout
69-
signal.alarm(0)
70-
71-
print(f'✓ Attempt {attempt + 1}: Process group initialized successfully')
72-
73-
# Test basic tensor operations
74-
device = torch.device('cuda:0')
75-
tensor = torch.ones(10, device=device)
76-
print(f'✓ Tensor operations work: {tensor.sum().item()}')
77-
78-
dist.destroy_process_group()
79-
print(f'✓ Attempt {attempt + 1}: Process group destroyed successfully')
80-
break
81-
82-
except Exception as e:
83-
signal.alarm(0) # Cancel timeout
84-
print(f'✗ Attempt {attempt + 1} failed: {type(e).__name__}: {e}')
85-
if attempt == 2: # Last attempt
86-
print('✗ All initialization attempts failed')
87-
exit(1)
88-
time.sleep(2) # Wait before retry
89-
90-
print('✓ Distributed health check passed')
91-
"
31+
python scripts/test_distributed.py

scripts/test_distributed.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import torch
2+
import torch.distributed as dist
3+
import torch.multiprocessing as mp
4+
import os
5+
import signal
6+
import sys
7+
8+
def timeout_handler(signum, frame):
    """SIGALRM handler: report that the process hung, then abort with status 1."""
    message = '✗ TIMEOUT: Process hung'
    print(message)
    sys.exit(1)
11+
12+
def test_worker(rank, world_size, master_port):
    """Run one NCCL rendezvous plus an all_reduce on GPU `rank`.

    Spawned as one process per GPU by `main`. Exits non-zero on any
    failure, and uses SIGALRM watchdogs so a hang becomes a timed-out
    hard failure instead of blocking the parent forever.

    Args:
        rank: This worker's rank and CUDA device index.
        world_size: Total number of participating workers.
        master_port: TCP port for the rendezvous (fresh per round).
    """
    try:
        # Rendezvous configuration for this worker.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(master_port)
        os.environ['RANK'] = str(rank)
        os.environ['WORLD_SIZE'] = str(world_size)

        # Bind this process to its own GPU *before* NCCL init; without
        # this every rank may target cuda:0, a classic NCCL hang cause
        # (which is exactly what this script is meant to diagnose).
        torch.cuda.set_device(rank)

        # Watchdog: init_process_group can block indefinitely on a bad
        # interconnect; convert that hang into a reported failure.
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(30)

        print(f'Rank {rank}: Init NCCL...')
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        signal.alarm(0)

        device = torch.device(f'cuda:{rank}')
        tensor = torch.ones(100, device=device) * rank

        # Shorter watchdog for the collective itself.
        signal.alarm(15)
        dist.all_reduce(tensor)
        signal.alarm(0)

        print(f'✓ Rank {rank}: sum = {tensor[0].item()}')
        dist.destroy_process_group()

    except Exception as e:
        signal.alarm(0)  # cancel any pending watchdog before exiting
        # Best-effort cleanup: the original leaked the process group on
        # this path when a failure happened after a successful init.
        try:
            if dist.is_initialized():
                dist.destroy_process_group()
        except Exception:
            pass  # cleanup must not mask the real error
        print(f'✗ Rank {rank}: {e}')
        sys.exit(1)
40+
41+
def main():
    """Spawn one worker per GPU and run 4 independent NCCL rounds.

    Each round uses a fresh master port so stale sockets from a failed
    round cannot poison the next rendezvous. Exits non-zero as soon as
    any round fails or hangs.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        # A health check that "passes" with zero workers is meaningless;
        # the original silently reported all rounds PASSED in this case.
        print('✗ No GPUs available')
        sys.exit(1)

    print(f'Testing {num_gpus} GPUs - 4 rounds')

    # 'spawn' is required for CUDA in child processes; setting it once
    # is enough (the original re-set it every round).
    mp.set_start_method('spawn', force=True)

    for round_num in range(4):
        print(f'=== ROUND {round_num + 1} ===')
        master_port = 29500 + round_num  # fresh port per round

        processes = []
        for rank in range(num_gpus):
            p = mp.Process(target=test_worker, args=(rank, num_gpus, master_port))
            p.start()
            processes.append(p)

        for p in processes:
            p.join(timeout=60)
            # Check liveness BEFORE exitcode: a process still running
            # after join(timeout=) has exitcode None, and the original's
            # `exitcode != 0` test misclassified that hang as FAILED,
            # making its HUNG branch unreachable.
            if p.is_alive():
                print(f'✗ ROUND {round_num + 1} HUNG')
                for rp in processes:
                    if rp.is_alive():
                        rp.terminate()
                sys.exit(1)
            if p.exitcode != 0:
                print(f'✗ ROUND {round_num + 1} FAILED')
                for rp in processes:
                    if rp.is_alive():
                        rp.terminate()
                sys.exit(1)

        print(f'✓ ROUND {round_num + 1} PASSED')

    print('✓ ALL ROUNDS PASSED')


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)