mpi_pytorch.py
import multiprocessing
import numpy as np
import os
import torch
from mpi4py import MPI
from mpi_tools import broadcast, mpi_avg, num_procs, proc_id

def setup_pytorch_for_mpi():
    """
    Avoid slowdowns caused by each separate process's PyTorch using
    more than its fair share of CPU resources.
    """
    #print('Proc %d: Reporting original number of Torch threads as %d.'%(proc_id(), torch.get_num_threads()), flush=True)
    if torch.get_num_threads()==1:
        return
    fair_num_threads = max(int(torch.get_num_threads() / num_procs()), 1)
    torch.set_num_threads(fair_num_threads)
    #print('Proc %d: Reporting new number of Torch threads as %d.'%(proc_id(), torch.get_num_threads()), flush=True)
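
# Usage sketch (not part of the original module): call setup_pytorch_for_mpi()
# once at the start of every MPI process, before building models or running
# any Torch ops, so the processes do not oversubscribe the CPU, e.g.
#
#   setup_pytorch_for_mpi()     # cap this process's Torch thread count
#   model = build_model(...)    # 'build_model' is a hypothetical constructor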

def mpi_avg_grads(module):
    """ Average contents of gradient buffers across MPI processes. """
    if num_procs()==1:
        return
    for p in module.parameters():
        if p.grad is not None:
            p_grad_numpy = p.grad.numpy()   # numpy view of tensor data
            avg_p_grad = mpi_avg(p.grad)
            p_grad_numpy[:] = avg_p_grad[:]
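
# Usage sketch (not part of the original module): mpi_avg_grads is intended to
# run after the backward pass and before the optimizer step, so every process
# applies the same averaged gradient, e.g.
#
#   loss.backward()
#   mpi_avg_grads(model)     # 'model' is a hypothetical nn.Module
#   optimizer.step()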

def sync_params(module):
    """ Sync all parameters of module across all MPI processes. """
    if num_procs()==1:
        return
    for p in module.parameters():
        p_numpy = p.data.numpy()
        broadcast(p_numpy)
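
# The block below is an illustrative end-to-end sketch, not part of the
# original module. It assumes the script is launched under MPI
# (e.g. `mpirun -np 4 python mpi_pytorch.py`) and that the mpi_tools helpers
# imported above are available; the model and data are dummies.
if __name__ == '__main__':
    import torch.nn as nn

    setup_pytorch_for_mpi()                  # share CPU threads fairly across processes

    model = nn.Linear(4, 1)                  # toy model
    sync_params(model)                       # start all processes from identical weights
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    for _ in range(5):
        x = torch.randn(8, 4)                # each process draws its own batch
        loss = (model(x) ** 2).mean()        # dummy objective
        optimizer.zero_grad()
        loss.backward()
        mpi_avg_grads(model)                 # average grads so all procs take the same step
        optimizer.step()

    print('proc %d finished; weight norm %.4f' % (proc_id(), model.weight.norm().item()))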