forked from jbornschein/mpi4py-examples
-
Notifications
You must be signed in to change notification settings - Fork 1
/
08-matrix-matrix-product.py
executable file
·83 lines (60 loc) · 2.41 KB
/
08-matrix-matrix-product.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python
from __future__ import print_function
import numpy as np
from mpi4py import MPI
from time import time
# ============================================================================
# Dimensions of the tile each process owns: every rank allocates its local
# A and B as my_N x my_M arrays.
my_N = 6000
my_M = 6000
# ============================================================================
# Positions of the four grid neighbours inside the 4-element neighbour and
# request lists used by the benchmark (0, 1, 2, 3 respectively).
NORTH, SOUTH, EAST, WEST = range(4)
def pprint(string, comm=MPI.COMM_WORLD):
    """Print *string* exactly once per communicator.

    Only the process whose rank in *comm* is 0 writes the message; every
    other rank returns without output.  Defaults to ``MPI.COMM_WORLD``.
    """
    if comm.rank != 0:
        return
    print(string)
def grid_shape(size):
    """Return the ``(rows, cols)`` shape of the process grid for *size* ranks.

    ``rows`` is ``floor(sqrt(size))`` and ``cols`` is ``size // rows``, so the
    grid is as close to square as this scheme allows and never needs more
    than *size* ranks.  When *size* is not exactly ``rows * cols``, the
    remaining ranks are left out of the grid.

    >>> grid_shape(4)
    (2, 2)
    >>> grid_shape(5)
    (2, 2)
    >>> grid_shape(6)
    (2, 3)

    Raises:
        ValueError: if *size* is smaller than 1.
    """
    if size < 1:
        raise ValueError('need at least one process, got {:d}'.format(size))
    rows = int(np.floor(np.sqrt(size)))
    return rows, size // rows


def _benchmark(ccomm, mpi_rows, mpi_cols):
    """Run the timed shift-and-multiply loop on the Cartesian grid *ccomm*.

    Each grid rank holds one my_N x my_M float32 tile of A and of B.  For
    ``mpi_rows`` steps, it adds ``dot(tile_A, tile_B)`` into a local
    accumulator.  Meanwhile the A tiles move one position toward the lower
    column index (WEST) and the B tiles move toward the higher row index
    (SOUTH), wrapping around the periodic grid.  Rank 0 of *ccomm* then
    prints the timing next to a single-tile serial baseline.  The matrices
    are random, so the numerical result is not checked.
    """
    # Cartcomm.Shift returns (source, dest) = ranks at coordinate -1 / +1.
    neigh = [0, 0, 0, 0]
    neigh[NORTH], neigh[SOUTH] = ccomm.Shift(0, 1)
    neigh[WEST], neigh[EAST] = ccomm.Shift(1, 1)

    # dot((N, M), (N, M)) is only defined when my_N == my_M.
    my_A = np.random.normal(size=(my_N, my_M)).astype(np.float32)
    my_B = np.random.normal(size=(my_N, my_M)).astype(np.float32)
    my_C = np.zeros_like(my_A)

    # Double buffering: compute on tile_* while the next tiles arrive in
    # tile_*_, then swap.  my_A / my_B are reused as receive buffers, so they
    # no longer hold the original local tiles after the first step.
    tile_A, tile_B = my_A, my_B
    tile_A_, tile_B_ = np.empty_like(my_A), np.empty_like(my_B)

    # Distinct tags keep A and B messages apart even if two neighbours
    # coincide (e.g. both are this rank on a 1 x 1 grid).
    tag_A, tag_B = 0, 1

    ccomm.Barrier()  # start the clock with every grid rank in step
    t0 = time()
    for _ in range(mpi_rows):
        req = [None, None, None, None]
        req[WEST] = ccomm.Isend(tile_A, neigh[WEST], tag=tag_A)
        req[EAST] = ccomm.Irecv(tile_A_, neigh[EAST], tag=tag_A)
        req[SOUTH] = ccomm.Isend(tile_B, neigh[SOUTH], tag=tag_B)
        req[NORTH] = ccomm.Irecv(tile_B_, neigh[NORTH], tag=tag_B)

        # Overlap the local product with the pending transfers.
        my_C += np.dot(tile_A, tile_B)

        MPI.Request.Waitall(req)
        # The freshly received tiles become the current ones; the old
        # buffers (whose sends have completed) receive the next step's tiles.
        tile_A, tile_A_ = tile_A_, tile_A
        tile_B, tile_B_ = tile_B_, tile_B
    ccomm.Barrier()
    t_total = time() - t0

    t_start = time()
    np.dot(tile_A, tile_B)
    t_serial = time() - t_start

    # Each grid rank performs mpi_rows tile products, so ideal parallel time
    # is mpi_rows serial products (== rows*rows*cols*t_serial / grid size).
    t_expected = mpi_rows * t_serial

    pprint(78 * '=', comm=ccomm)
    pprint('computed (serial) {:d} x {:d} x {:d} in {:6.2f} seconds'.format(my_M, my_M, my_N, t_serial), comm=ccomm)
    pprint('... expecting parallel computation to take {:6.2f} seconds'.format(t_expected), comm=ccomm)
    pprint('computed (parallel) {:d} x {:d} x {:d} in {:6.2f} seconds'.format(mpi_rows * my_M, mpi_rows * my_M, mpi_cols * my_N, t_total), comm=ccomm)


def main():
    """Build a periodic 2-D process grid from COMM_WORLD and run the benchmark.

    Ranks that do not fit into the grid receive ``MPI.COMM_NULL`` from
    ``Create_cart``.  They skip the benchmark but still join the final
    COMM_WORLD barrier.
    """
    comm = MPI.COMM_WORLD

    mpi_rows, mpi_cols = grid_shape(comm.size)
    pprint('Creating a {:d} x {:d} processor grid...'.format(mpi_rows, mpi_cols))

    ccomm = comm.Create_cart((mpi_rows, mpi_cols), periods=(True, True), reorder=True)
    if ccomm != MPI.COMM_NULL:
        _benchmark(ccomm, mpi_rows, mpi_cols)
        ccomm.Free()

    comm.barrier()


if __name__ == '__main__':
    main()