2
2
3
3
import argparse
4
4
import time
5
- from mpi4py import MPI
6
5
from configparser import ConfigParser
7
6
8
7
from lya_2pt .interface import Interface
@@ -16,42 +15,39 @@ def main():
16
15
parser .add_argument ('-i' , '--config' , type = str , default = None ,
17
16
help = ('Path to config file' ))
18
17
18
+ parser .add_argument ('-r' , '--rank' , type = int , default = 0 , help = ('Rank of the job' ))
19
+ parser .add_argument ('-s' , '--size' , type = int , default = 1 , help = ('Number of jobs' ))
20
+
19
21
args = parser .parse_args ()
20
22
23
+ if args .rank >= args .size :
24
+ raise MPIError (f"Rank { args .rank } is greater than the number of MPI processes { args .size } " )
25
+
21
26
config = ConfigParser ()
22
27
config .read (args .config )
23
28
24
- # Initilize MPI objects
25
- mpi_comm = MPI .COMM_WORLD
26
- mpi_rank = mpi_comm .Get_rank ()
27
- mpi_size = mpi_comm .Get_size ()
28
-
29
29
lya2pt = Interface (config )
30
30
31
- if len (lya2pt .files ) < mpi_size :
31
+ if len (lya2pt .files ) < args . size :
32
32
raise MPIError (f"Less files than MPI processes. "
33
33
f"Found { len (lya2pt .files )} healpix files and running "
34
- f"{ mpi_size } MPI processes. This is wasteful. "
34
+ f"{ args . size } MPI processes. This is wasteful. "
35
35
"Please lower the numper of MPI processes." )
36
36
37
- if mpi_rank == 0 :
38
- total_t1 = time . time ()
39
- print (' Reading tracers...' , flush = True )
40
- lya2pt .read_tracers (mpi_rank = mpi_rank )
37
+ total_t1 = time . time ()
38
+
39
+ print (f'Rank { args . rank } : Reading tracers...' , flush = True )
40
+ lya2pt .read_tracers (mpi_rank = args . rank )
41
41
42
- if mpi_rank == 0 :
43
- print ('Starting computation...' , flush = True )
44
- lya2pt .run (mpi_size = mpi_size , mpi_rank = mpi_rank )
42
+ print (f'Rank { args .rank } : Starting computation...' , flush = True )
43
+ lya2pt .run (mpi_size = args .size , mpi_rank = args .rank )
45
44
46
- if mpi_rank == 0 :
47
- print ('Writing results...' , flush = True )
48
- lya2pt .write_results (mpi_rank = mpi_rank )
45
+ print (f'Rank { args .rank } : Writing results...' , flush = True )
46
+ lya2pt .write_results (mpi_rank = args .rank )
49
47
50
- mpi_comm .Barrier ()
51
- if mpi_rank == 0 :
52
- total_t2 = time .time ()
53
- print (f'Total time: { (total_t2 - total_t1 ):.3f} sec' , flush = True )
54
- print ('Done' , flush = True )
48
+ total_t2 = time .time ()
49
+ print (f'Rank { args .rank } : Total time: { (total_t2 - total_t1 ):.3f} sec' , flush = True )
50
+ print (f'Rank { args .rank } : Done' , flush = True )
55
51
56
52
57
53
if __name__ == '__main__' :
0 commit comments