Skip to content

Commit 4a2265a

Browse files
committed
Remove MPI
1 parent 8a8fea8 commit 4a2265a

File tree

1 file changed

+19
-23
lines changed

1 file changed

+19
-23
lines changed

lya_2pt/scripts/run_mpi.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import argparse
44
import time
5-
from mpi4py import MPI
65
from configparser import ConfigParser
76

87
from lya_2pt.interface import Interface
@@ -16,42 +15,39 @@ def main():
1615
parser.add_argument('-i', '--config', type=str, default=None,
1716
help=('Path to config file'))
1817

18+
parser.add_argument('-r', '--rank', type=int, default=0, help=('Rank of the job'))
19+
parser.add_argument('-s', '--size', type=int, default=1, help=('Number of jobs'))
20+
1921
args = parser.parse_args()
2022

23+
if args.rank >= args.size:
24+
raise MPIError(f"Rank {args.rank} is greater than the number of MPI processes {args.size}")
25+
2126
config = ConfigParser()
2227
config.read(args.config)
2328

24-
# Initilize MPI objects
25-
mpi_comm = MPI.COMM_WORLD
26-
mpi_rank = mpi_comm.Get_rank()
27-
mpi_size = mpi_comm.Get_size()
28-
2929
lya2pt = Interface(config)
3030

31-
if len(lya2pt.files) < mpi_size:
31+
if len(lya2pt.files) < args.size:
3232
raise MPIError(f"Less files than MPI processes. "
3333
f"Found {len(lya2pt.files)} healpix files and running "
34-
f"{mpi_size} MPI processes. This is wasteful. "
34+
f"{args.size} MPI processes. This is wasteful. "
3535
"Please lower the numper of MPI processes.")
3636

37-
if mpi_rank == 0:
38-
total_t1 = time.time()
39-
print('Reading tracers...', flush=True)
40-
lya2pt.read_tracers(mpi_rank=mpi_rank)
37+
total_t1 = time.time()
38+
39+
print(f'Rank {args.rank}: Reading tracers...', flush=True)
40+
lya2pt.read_tracers(mpi_rank=args.rank)
4141

42-
if mpi_rank == 0:
43-
print('Starting computation...', flush=True)
44-
lya2pt.run(mpi_size=mpi_size, mpi_rank=mpi_rank)
42+
print(f'Rank {args.rank}: Starting computation...', flush=True)
43+
lya2pt.run(mpi_size=args.size, mpi_rank=args.rank)
4544

46-
if mpi_rank == 0:
47-
print('Writing results...', flush=True)
48-
lya2pt.write_results(mpi_rank=mpi_rank)
45+
print(f'Rank {args.rank}: Writing results...', flush=True)
46+
lya2pt.write_results(mpi_rank=args.rank)
4947

50-
mpi_comm.Barrier()
51-
if mpi_rank == 0:
52-
total_t2 = time.time()
53-
print(f'Total time: {(total_t2-total_t1):.3f} sec', flush=True)
54-
print('Done', flush=True)
48+
total_t2 = time.time()
49+
print(f'Rank {args.rank}: Total time: {(total_t2-total_t1):.3f} sec', flush=True)
50+
print(f'Rank {args.rank}: Done', flush=True)
5551

5652

5753
if __name__ == '__main__':

0 commit comments

Comments
 (0)