Skip to content

Commit

Permalink
feat(vtk): include all arrays on pathline input (#2161)
Browse files Browse the repository at this point in the history
Instead of hardcoding which arrays to include in vtk pathline export, include all numeric arrays present in the pathline data
  • Loading branch information
wpbonelli authored Apr 19, 2024
1 parent ff82488 commit 43cbe47
Showing 1 changed file with 26 additions and 19 deletions.
45 changes: 26 additions & 19 deletions flopy/export/vtk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,8 +1098,8 @@ def add_pathline_points(self, pathlines, timeseries=False):
animation or as a single static vtk file. Default is false.
"""

mpx_keys = ["particleid", "time", "k"]
prt_keys = ["imdl", "iprp", "irpt", "trelease", "ilay"]
mpx_fields = ["particleid", "time", "k"]
prt_fields = ["imdl", "iprp", "irpt", "trelease", "ilay"]

if isinstance(pathlines, list):
if len(pathlines) == 0:
Expand All @@ -1112,23 +1112,32 @@ def add_pathline_points(self, pathlines, timeseries=False):
)
for pl in pathlines
]
if all(k in pathlines[0].dtype.names for k in mpx_keys):
keys = mpx_keys
elif all(k in pathlines[0].dtype.names for k in prt_keys):
keys = prt_keys
else:
fields = pathlines[0].dtype.names
arr_fields = {
n: pathlines[0].dtype[n]
for n in fields
if np.issubdtype(pathlines[0].dtype[n], np.number)
}
if not (
all(k in fields for k in mpx_fields)
or all(k in fields for k in prt_fields)
):
raise ValueError("Unrecognized pathline dtype")
elif isinstance(pathlines, (np.recarray, np.ndarray, pd.DataFrame)):
if isinstance(pathlines, pd.DataFrame):
pathlines = pathlines.to_records(index=False)
if all(k in pathlines.dtype.names for k in mpx_keys):
keys = mpx_keys
fields = pathlines.dtype.names
arr_fields = {
n: pathlines.dtype[n]
for n in fields
if np.issubdtype(pathlines.dtype[n], np.number)
}
if all(k in pathlines.dtype.names for k in mpx_fields):
pids = np.unique(pathlines.particleid)
pathlines = [
pathlines[pathlines.particleid == pid] for pid in pids
]
elif all(k in pathlines.dtype.names for k in prt_keys):
keys = prt_keys
elif all(k in pathlines.dtype.names for k in prt_fields):
pls = []
for imdl in np.unique(pathlines.imdl):
for iprp in np.unique(pathlines.iprp):
Expand All @@ -1150,7 +1159,7 @@ def add_pathline_points(self, pathlines, timeseries=False):
)

if not timeseries:
arrays = {key: [] for key in keys}
arrays = {f: [] for f in arr_fields}
points = []
lines = []
for recarray in pathlines:
Expand All @@ -1160,8 +1169,8 @@ def add_pathline_points(self, pathlines, timeseries=False):
t = tuple(rec[["x", "y", "z"]])
line.append(t)
points.append(t)
for key in keys:
arrays[key].append(rec[key])
for f in arr_fields:
arrays[f].append(rec[f])
lines.append(line)

self._set_particle_track_data(points, lines, arrays)
Expand All @@ -1176,14 +1185,12 @@ def add_pathline_points(self, pathlines, timeseries=False):
time = rec["time"]
if time not in points:
points[time] = [tuple(rec[["x", "y", "z"]])]
t = {key: [] for key in keys}
timeseries_data[time] = t

timeseries_data[time] = {f: [] for f in arr_fields}
else:
points[time].append(tuple(rec[["x", "y", "z"]]))

for key in keys:
timeseries_data[time][key].append(rec[key])
for f in arr_fields:
timeseries_data[time][f].append(rec[f])

self.__pathline_transient_data = timeseries_data
self._pathline_points = points
Expand Down

0 comments on commit 43cbe47

Please sign in to comment.