forked from jchelly/SOAP
-
Notifications
You must be signed in to change notification settings - Fork 3
/
io_test.py
98 lines (75 loc) · 2.58 KB
/
io_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/bin/env python
import time
import matplotlib.pyplot as plt
import numpy as np
from mpi4py import MPI
# World communicator used by the test below. The index of this process within
# it and the total number of processes are looked up once, at import time.
comm = MPI.COMM_WORLD
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()
import swift_cells
import shared_mesh
import pytest
@pytest.mark.mpi
def test_io():
    """Time a region read into shared memory, then exercise the shared mesh.

    Steps, executed by every rank unless noted otherwise:

    1. Open the snapshot as a ``swift_cells.SWIFTCellGrid``.
    2. Mask the region from (0, 0, 0) to (50, 50, 50) in ``snap_length``
       units and read ``Coordinates``, ``Velocities`` and ``Masses`` for
       ``PartType0`` and ``PartType1`` via
       ``read_masked_cells_to_shared_memory``.
    3. Rank 0 prints the read rate, computed from the byte size of the
       ``.full`` arrays and the wall-clock time between the two barriers
       (opening the snapshot is included in the timed span).
    4. Build a ``shared_mesh.SharedMesh`` over the ``PartType1`` coordinates.
    5. Rank 0 plots every ``PartType1`` particle plus the ones returned by
       ``query_radius_periodic`` for a sphere of radius 10 centred on
       (30, 30, 30), and saves the figure as ``io_test.png``.
    6. Free the shared arrays and the mesh.

    If the snapshot cannot be opened (``FileNotFoundError``), rank 0 prints a
    message and the function returns without doing anything else.
    """
    comm.barrier()
    t0 = time.time()

    # Open the snapshot
    fname = "/cosma8/data/dp004/flamingo/Runs/L1000N0900/HYDRO_FIDUCIAL/snapshots/flamingo_0037/flamingo_0037.{file_nr}.hdf5"
    try:
        cellgrid = swift_cells.SWIFTCellGrid(fname)
    except FileNotFoundError:
        # NOTE(review): under pytest this is reported as a pass, not a skip.
        # Kept as a plain return so running the file as a script still exits
        # cleanly when the snapshot is unavailable.
        if comm_rank == 0:
            print("File not found for running io_test")
        return

    # Quantities to read
    property_names = {
        "PartType0": ("Coordinates", "Velocities", "Masses"),
        "PartType1": ("Coordinates", "Velocities", "Masses"),
    }

    # Specify region to read
    pos_min = np.asarray((0.0, 0.0, 0.0)) * cellgrid.get_unit("snap_length")
    pos_max = np.asarray((50.0, 50.0, 50.0)) * cellgrid.get_unit("snap_length")

    # Read in the region
    mask = cellgrid.empty_mask()
    cellgrid.mask_region(mask, pos_min, pos_max)
    # NOTE(review): the meaning of the trailing 8 is defined by
    # read_masked_cells_to_shared_memory, which is not visible here - confirm.
    data = cellgrid.read_masked_cells_to_shared_memory(property_names, mask, comm, 8)

    # From here on shared allocations exist, so everything runs under
    # try/finally: a failure in any later step must not skip the free() calls.
    mesh = None
    try:
        comm.barrier()
        t1 = time.time()

        # Find read rate
        nbytes = 0
        for ptype, datasets in property_names.items():
            for dataset in datasets:
                arr = data[ptype][dataset].full
                nbytes += arr.data.nbytes
        elapsed = t1 - t0
        if comm_rank == 0:
            rate = nbytes / elapsed / (1024 ** 3)
            print("Read at %.2f GB/sec on %d ranks" % (rate, comm_size))

        # Build the shared mesh
        mesh = shared_mesh.SharedMesh(
            comm, pos=data["PartType1"]["Coordinates"], resolution=256
        )
        if comm_rank == 0:
            print("Built mesh")

        comm.barrier()

        if comm_rank == 0:
            # Plot all particles
            pos = data["PartType1"]["Coordinates"]
            plt.plot(pos.full[:, 0], pos.full[:, 1], "k,", alpha=0.05)
            plt.gca().set_aspect("equal")

            # Try selecting a sphere
            centre = np.asarray((30, 30, 30)) * cellgrid.get_unit("snap_length")
            radius = 10 * cellgrid.get_unit("snap_length")
            idx = mesh.query_radius_periodic(centre, radius, pos, cellgrid.boxsize)
            plt.plot(pos.full[idx, 0], pos.full[idx, 1], "g,")
            plt.xlim(0, 150)
            plt.ylim(0, 150)
            plt.savefig("io_test.png", dpi=300)
            plt.close()
    finally:
        # Free the shared particle data, then the mesh (same order as on the
        # success path). The mesh is only freed if it was actually built.
        for ptype in data:
            for name in data[ptype]:
                data[ptype][name].free()
        if mesh is not None:
            mesh.free()
# Allow the test to be run directly as a script (e.g. under an MPI launcher)
# without going through pytest.
if __name__ == "__main__":
    test_io()