Skip to content

Commit a987435

Browse files
authored
Merge pull request #452 from tobac-project/min_distance_3D_fix
Fix minimum distance filtering for varying vertical coordinates
2 parents e8ccdcc + bea540c commit a987435

File tree

2 files changed

+125
-92
lines changed

2 files changed

+125
-92
lines changed

tobac/feature_detection.py

Lines changed: 87 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,39 +1121,34 @@ def feature_detection_multithreshold_timestep(
11211121
elif i_threshold == 0:
11221122
regions_old = regions_i
11231123

1124-
if statistic:
1125-
# reconstruct the labeled regions based on the regions dict
1126-
labels = np.zeros(track_data.shape)
1127-
labels = labels.astype(int)
1128-
for key in regions_old.keys():
1129-
labels.ravel()[regions_old[key]] = key
1130-
# apply function to get statistics based on labeled regions and functions provided by the user
1131-
# the feature dataframe is updated by appending a column for each metric
1132-
if statistics_unsmoothed:
1133-
features_thresholds = get_statistics(
1134-
features_thresholds,
1135-
labels,
1136-
data_i.core_data(),
1137-
statistic=statistic,
1138-
index=np.unique(labels[labels > 0]),
1139-
id_column="idx",
1140-
)
1141-
else:
1142-
features_thresholds = get_statistics(
1143-
features_thresholds,
1144-
labels,
1145-
track_data,
1146-
statistic=statistic,
1147-
index=np.unique(labels[labels > 0]),
1148-
id_column="idx",
1149-
)
1150-
11511124
logging.debug(
11521125
"Finished feature detection for threshold "
11531126
+ str(i_threshold)
11541127
+ " : "
11551128
+ str(threshold_i)
11561129
)
1130+
1131+
if statistic:
1132+
# reconstruct the labeled regions based on the regions dict
1133+
labels = np.zeros(track_data.shape)
1134+
labels = labels.astype(int)
1135+
for key in regions_old.keys():
1136+
labels.ravel()[regions_old[key]] = key
1137+
# apply function to get statistics based on labeled regions and functions provided by the user
1138+
# the feature dataframe is updated by appending a column for each metric
1139+
1140+
# select which data to use according to statistics_unsmoothed option
1141+
stats_data = data_i.core_data() if statistics_unsmoothed else track_data
1142+
1143+
features_thresholds = get_statistics(
1144+
features_thresholds,
1145+
labels,
1146+
stats_data,
1147+
statistic=statistic,
1148+
index=np.unique(labels[labels > 0]),
1149+
id_column="idx",
1150+
)
1151+
11571152
return features_thresholds
11581153

11591154

@@ -1296,6 +1291,9 @@ def feature_detection_multithreshold(
12961291
if detect_subset is not None and ndim_time in detect_subset:
12971292
raise NotImplementedError("Cannot subset on time")
12981293

1294+
# Remember whether dz was provided by the caller, so that min distance filtering uses dz rather than the vertical coordinate
1295+
use_dz_for_filtering = dz is not None
1296+
12991297
if is_3D:
13001298
# We need to determine the time axis so that we can determine the
13011299
# vertical axis in each timestep if vertical_axis is not none.
@@ -1395,30 +1393,7 @@ def feature_detection_multithreshold(
13951393
statistic=statistic,
13961394
statistics_unsmoothed=statistics_unsmoothed,
13971395
)
1398-
# check if list of features is not empty, then merge features from different threshold
1399-
# values into one DataFrame and append to list for individual timesteps:
1400-
if not features_thresholds.empty:
1401-
hdim1_ax, hdim2_ax = internal_utils.find_hdim_axes_3D(
1402-
field_in, vertical_coord=vertical_coord
1403-
)
1404-
hdim1_max = field_in.shape[hdim1_ax] - 1
1405-
hdim2_max = field_in.shape[hdim2_ax] - 1
1406-
# Loop over DataFrame to remove features that are closer than distance_min to each
1407-
# other:
1408-
if min_distance > 0:
1409-
features_thresholds = filter_min_distance(
1410-
features_thresholds,
1411-
dxy=dxy,
1412-
dz=dz,
1413-
min_distance=min_distance,
1414-
z_coordinate_name=vertical_coord,
1415-
target=target,
1416-
PBC_flag=PBC_flag,
1417-
min_h1=0,
1418-
max_h1=hdim1_max,
1419-
min_h2=0,
1420-
max_h2=hdim2_max,
1421-
)
1396+
14221397
list_features_timesteps.append(features_thresholds)
14231398

14241399
logging.debug(
@@ -1440,9 +1415,41 @@ def feature_detection_multithreshold(
14401415
)
14411416
else:
14421417
features = add_coordinates(features, field_in)
1418+
1419+
# Loop over DataFrame to remove features that are closer than min_distance to each
1420+
# other:
1421+
filtered_features = []
1422+
if min_distance > 0:
1423+
hdim1_ax, hdim2_ax = internal_utils.find_hdim_axes_3D(
1424+
field_in, vertical_coord=vertical_coord
1425+
)
1426+
hdim1_max = field_in.shape[hdim1_ax] - 1
1427+
hdim2_max = field_in.shape[hdim2_ax] - 1
1428+
1429+
for _, features_frame in features.groupby("frame"):
1430+
filtered_features.append(
1431+
filter_min_distance(
1432+
features_frame,
1433+
dxy=dxy,
1434+
dz=dz if use_dz_for_filtering else None,
1435+
min_distance=min_distance,
1436+
z_coordinate_name=(
1437+
None if use_dz_for_filtering else vertical_coord
1438+
),
1439+
target=target,
1440+
PBC_flag=PBC_flag,
1441+
min_h1=0,
1442+
max_h1=hdim1_max,
1443+
min_h2=0,
1444+
max_h2=hdim2_max,
1445+
)
1446+
)
1447+
features = pd.concat(filtered_features, ignore_index=True)
1448+
14431449
else:
14441450
features = None
14451451
logging.debug("No features detected")
1452+
14461453
logging.debug("feature detection completed")
14471454
return features
14481455

@@ -1512,47 +1519,47 @@ def filter_min_distance(
15121519
pandas DataFrame
15131520
features after filtering
15141521
"""
1522+
# Optional coordinate names are not yet implemented, set to defaults here:
15151523
if dxy is None:
15161524
raise NotImplementedError("dxy currently must be set.")
15171525

1518-
# if PBC_flag != "none":
1519-
# raise NotImplementedError("We haven't yet implemented PBCs into this.")
1520-
1521-
# if we are 3D, the vertical dimension is in features. if we are 2D, there
1522-
# is no vertical dimension in features.
1523-
is_3D = "vdim" in features
1524-
1525-
if is_3D and dz is None:
1526-
z_coordinate_name = internal_utils.find_dataframe_vertical_coord(
1527-
features, z_coordinate_name
1528-
)
1529-
15301526
# Check if both dxy and their coordinate names are specified.
15311527
# If they are, warn that we will use dxy.
1532-
if dxy is not None and (
1533-
x_coordinate_name in features and y_coordinate_name in features
1534-
):
1528+
elif x_coordinate_name in features and y_coordinate_name in features:
15351529
warnings.warn(
15361530
"Both " + x_coordinate_name + "/" + y_coordinate_name + " and dxy "
15371531
"set. Using constant dxy. Set dxy to None if you want to use the "
15381532
"interpolated coordinates, or set `x_coordinate_name` and "
15391533
"`y_coordinate_name` to None to use a constant dxy."
15401534
)
1535+
y_coordinate_name = "hdim_1"
1536+
x_coordinate_name = "hdim_2"
1537+
# If dxy only, use hdim_1, hdim_2 as default horizontal dimensions
1538+
else:
1539+
y_coordinate_name = "hdim_1"
1540+
x_coordinate_name = "hdim_2"
15411541

1542-
# Check and if both dz is specified and altitude is available, warn that we will use dz.
1543-
if is_3D and (dz is not None and z_coordinate_name in features):
1544-
warnings.warn(
1545-
"Both "
1546-
+ z_coordinate_name
1547-
+ " and dz available to filter_min_distance; using constant dz. "
1548-
"Set dz to none if you want to use altitude or set `z_coordinate_name` to None to use "
1549-
"constant dz."
1550-
)
1551-
1552-
# As optional coordinate names are not yet implemented, set to defaults here:
1553-
z_coordinate_name = "vdim"
1554-
y_coordinate_name = "hdim_1"
1555-
x_coordinate_name = "hdim_2"
1542+
# if we are 3D, the vertical dimension is in features
1543+
is_3D = "vdim" in features
1544+
if is_3D:
1545+
if dz is None:
1546+
# Find vertical coord name and set dz to 1
1547+
z_coordinate_name = internal_utils.find_dataframe_vertical_coord(
1548+
variable_dataframe=features, vertical_coord=z_coordinate_name
1549+
)
1550+
dz = 1
1551+
else:
1552+
# Use dz, warn if both are set
1553+
if z_coordinate_name is not None:
1554+
warnings.warn(
1555+
"Both "
1556+
+ z_coordinate_name
1557+
+ " and dz available to filter_min_distance; using constant dz. "
1558+
"Set dz to none if you want to use altitude or set `z_coordinate_name` to None to use "
1559+
"constant dz.",
1560+
UserWarning,
1561+
)
1562+
z_coordinate_name = "vdim"
15561563

15571564
if target not in ["minimum", "maximum"]:
15581565
raise ValueError(

tobac/tests/test_feature_detection.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,34 @@ def test_feature_detection_position(position_threshold):
319319
True,
320320
False,
321321
),
322+
( # Test using z coord name
323+
(0, 0, 0, 4, 1),
324+
(1, 1, 1, 4, 1),
325+
1000,
326+
None,
327+
1,
328+
"maximum",
329+
False,
330+
False,
331+
True,
332+
"none",
333+
True,
334+
True,
335+
),
336+
( # Test using z coord name
337+
(0, 0, 0, 5, 1),
338+
(1, 1, 1, 4, 1),
339+
1,
340+
None,
341+
101,
342+
"maximum",
343+
False,
344+
False,
345+
True,
346+
"none",
347+
True,
348+
False,
349+
),
322350
],
323351
)
324352
def test_filter_min_distance(
@@ -415,18 +443,6 @@ def test_filter_min_distance(
415443

416444
feat_combined = pd.concat([feat_1_interp, feat_2_interp], ignore_index=True)
417445

418-
filter_dist_opts = dict()
419-
420-
if add_x_coords:
421-
feat_combined[x_coord_name] = feat_combined["hdim_2"] * assumed_dxy
422-
filter_dist_opts["x_coordinate_name"] = x_coord_name
423-
if add_y_coords:
424-
feat_combined[y_coord_name] = feat_combined["hdim_1"] * assumed_dxy
425-
filter_dist_opts["y_coordinate_name"] = y_coord_name
426-
if add_z_coords and is_3D:
427-
feat_combined[z_coord_name] = feat_combined["vdim"] * assumed_dz
428-
filter_dist_opts["z_coordinate_name"] = z_coord_name
429-
430446
filter_dist_opts = {
431447
"features": feat_combined,
432448
"dxy": dxy,
@@ -439,6 +455,16 @@ def test_filter_min_distance(
439455
"min_h2": 0,
440456
"max_h2": 100,
441457
}
458+
if add_x_coords:
459+
feat_combined[x_coord_name] = feat_combined["hdim_2"] * assumed_dxy
460+
filter_dist_opts["x_coordinate_name"] = x_coord_name
461+
if add_y_coords:
462+
feat_combined[y_coord_name] = feat_combined["hdim_1"] * assumed_dxy
463+
filter_dist_opts["y_coordinate_name"] = y_coord_name
464+
if add_z_coords and is_3D:
465+
feat_combined[z_coord_name] = feat_combined["vdim"] * assumed_dz
466+
filter_dist_opts["z_coordinate_name"] = z_coord_name
467+
442468
if target not in ["maximum", "minimum"]:
443469
with pytest.raises(ValueError):
444470
out_feats = feat_detect.filter_min_distance(**filter_dist_opts)

0 commit comments

Comments
 (0)