
Commit

Merge pull request #354 from freemansw1/sfreeman_xarray_work_fd_julia_avgs

Switch feature detection to use xarray internally
freemansw1 authored Jul 25, 2024
2 parents 5f82097 + bbb6b67 commit 0ebee82
Showing 11 changed files with 1,245 additions and 297 deletions.
4 changes: 2 additions & 2 deletions tobac/analysis/spatial.py
@@ -11,7 +11,7 @@
from iris.analysis.cartography import area_weights

from tobac.utils.bulk_statistics import get_statistics_from_mask
from tobac.utils.internal.basic import find_vertical_axis_from_coord
from tobac.utils.internal.basic import find_vertical_coord_name
from tobac.utils import decorators

__all__ = (
@@ -381,7 +381,7 @@ def calculate_area(features, mask, method_area=None, vertical_coord=None):
mask_slice = next(mask.slices_over("time"))
is_3d = len(mask_slice.core_data().shape) == 3
if is_3d:
vertical_coord_name = find_vertical_axis_from_coord(mask_slice, vertical_coord)
vertical_coord_name = find_vertical_coord_name(mask_slice, vertical_coord)
# Need to get var_name as xarray uses this to label dims
collapse_dim = mask_slice.coords(vertical_coord_name)[0].var_name
else:
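The change above is the rename from `find_vertical_axis_from_coord` to `find_vertical_coord_name` that recurs throughout this commit. A minimal sketch of the renamed helper, assuming it accepts an iris cube and that `model_level_number` is among the vertical coordinate names it recognises (the synthetic field and names below are illustrative, not from this commit):

```python
# Hedged sketch: auto-detect the vertical coordinate name with the renamed helper.
# Assumptions: find_vertical_coord_name accepts an iris cube, and
# "model_level_number" is among the vertical coordinate names it recognises.
import numpy as np
import xarray as xr

from tobac.utils.internal.basic import find_vertical_coord_name

# Small synthetic 3D mask slice (model_level_number, y, x), built via xarray and
# converted to iris, similar to tobac.testing.make_dataset_from_arr further down.
mask_slice = xr.DataArray(
    np.zeros((5, 4, 4)),
    dims=("model_level_number", "y", "x"),
    coords={"model_level_number": np.arange(5)},
    name="segmentation_mask",
).to_iris()

vertical_coord_name = find_vertical_coord_name(mask_slice, vertical_coord=None)
# xarray labels dimensions by var_name, hence the extra lookup in calculate_area
collapse_dim = mask_slice.coords(vertical_coord_name)[0].var_name
print(vertical_coord_name, collapse_dim)
```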
60 changes: 36 additions & 24 deletions tobac/feature_detection.py
@@ -39,6 +39,10 @@
from tobac.utils import get_statistics
import warnings

# from typing_extensions import Literal
import iris
import iris.cube


def feature_position(
hdim1_indices: list[int],
@@ -886,8 +890,9 @@ def feature_detection_threshold(
return features_threshold, regions


@internal_utils.irispandas_to_xarray()
def feature_detection_multithreshold_timestep(
data_i: np.array,
data_i: xr.DataArray,
i_time: int,
threshold: list[float] = None,
min_num: int = 0,
@@ -916,7 +921,7 @@ def feature_detection_multithreshold_timestep(
Parameters
----------
data_i : iris.cube.Cube
data_i : iris.cube.Cube or xarray.DataArray
3D field to perform the feature detection (single timestep) on.
i_time : int
@@ -995,7 +1000,7 @@ def feature_detection_multithreshold_timestep(
)

# get actual numpy array and make a copy so as not to change the data in the iris cube
track_data = data_i.core_data().copy()
track_data = data_i.values.copy()

track_data = gaussian_filter(
track_data, sigma=sigma_threshold
@@ -1127,9 +1132,9 @@ def feature_detection_multithreshold_timestep(
return features_thresholds


@decorators.xarray_to_iris()
@decorators.irispandas_to_xarray(save_iris_info=True)
def feature_detection_multithreshold(
field_in: iris.cube.Cube,
field_in: xr.DataArray,
dxy: float = None,
threshold: list[float] = None,
min_num: int = 0,
@@ -1150,15 +1155,17 @@ def feature_detection_multithreshold(
dz: Union[float, None] = None,
strict_thresholding: bool = False,
statistic: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None,
preserve_iris_datetime_types: bool = True,
**kwargs,
) -> pd.DataFrame:
"""Perform feature detection based on contiguous regions.
The regions are above/below a threshold.
Parameters
----------
field_in : iris.cube.Cube
2D field to perform the tracking on (needs to have coordinate
field_in : iris.cube.Cube or xarray.DataArray
2D or 3D field to perform the tracking on (needs to have coordinate
'time' along one of its dimensions),
dxy : float
Expand Down Expand Up @@ -1230,6 +1237,11 @@ def feature_detection_multithreshold(
If True, a feature can only be detected if all previous thresholds have been met.
Default is False.
preserve_iris_datetime_types: bool, optional, default: True
If True, for iris input, preserve the original datetime type (typically
`cftime.DatetimeGregorian`) where possible. For xarray input, this parameter has no
effect.
Returns
-------
features : pandas.DataFrame
@@ -1238,12 +1250,10 @@ def feature_detection_multithreshold(
"""
from .utils import add_coordinates, add_coordinates_3D

time_var_name: str = "time"
logging.debug("start feature detection based on thresholds")

if "time" not in [coord.name() for coord in field_in.coords()]:
raise ValueError(
"input to feature detection step must include a dimension named 'time'"
)
ndim_time = internal_utils.find_axis_from_coord(field_in, time_var_name)

# Check whether we need to run 2D or 3D feature detection
if field_in.ndim == 3:
@@ -1255,8 +1265,6 @@ def feature_detection_multithreshold(
else:
raise ValueError("Feature detection only works with 2D or 3D data")

ndim_time = field_in.coord_dims("time")[0]

if detect_subset is not None:
raise NotImplementedError("Subsetting feature detection not yet supported.")

@@ -1275,7 +1283,7 @@ def feature_detection_multithreshold(
if vertical_axis is None:
# We need to determine vertical axis.
# first, find the name of the vertical axis
vertical_axis_name = internal_utils.find_vertical_axis_from_coord(
vertical_axis_name = internal_utils.find_vertical_coord_name(
field_in, vertical_coord=vertical_coord
)
# then find our axis number.
@@ -1298,9 +1306,6 @@ def feature_detection_multithreshold(
# create empty list to store features for all timesteps
list_features_timesteps = []

# loop over timesteps for feature identification:
data_time = field_in.slices_over("time")

# if single threshold is put in as a single value, turn it into a list
if type(threshold) in [int, float]:
threshold = [threshold]
@@ -1339,8 +1344,8 @@ def feature_detection_multithreshold(
"given in meter."
)

for i_time, data_i in enumerate(data_time):
time_i = data_i.coord("time").units.num2date(data_i.coord("time").points[0])
for i_time, time_i in enumerate(field_in.coords[time_var_name]):
data_i = field_in.isel({time_var_name: i_time})

features_thresholds = feature_detection_multithreshold_timestep(
data_i,
Expand Down Expand Up @@ -1387,9 +1392,7 @@ def feature_detection_multithreshold(
)
list_features_timesteps.append(features_thresholds)

logging.debug(
"Finished feature detection for " + time_i.strftime("%Y-%m-%d_%H:%M:%S")
)
logging.debug("Finished feature detection for %s", time_i)

logging.debug("feature detection: merging DataFrames")
# Check if features are detected and then concatenate features from different timesteps into
@@ -1402,10 +1405,19 @@ def feature_detection_multithreshold(
# features_filtered.drop(columns=['idx','num','threshold_value'],inplace=True)
if "vdim" in features:
features = add_coordinates_3D(
features, field_in, vertical_coord=vertical_coord
features,
field_in,
vertical_coord=vertical_coord,
preserve_iris_datetime_types=kwargs["converted_from_iris"]
& preserve_iris_datetime_types,
)
else:
features = add_coordinates(features, field_in)
features = add_coordinates(
features,
field_in,
preserve_iris_datetime_types=kwargs["converted_from_iris"]
& preserve_iris_datetime_types,
)
else:
features = None
logging.debug("No features detected")
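Taken together, `feature_detection_multithreshold` is now decorated with `irispandas_to_xarray(save_iris_info=True)`, loops over time with `field_in.isel({time_var_name: i_time})` instead of iris `slices_over`, and gains the `preserve_iris_datetime_types` flag. A hedged usage sketch on a synthetic iris cube; the testing helper and the parameter values below are assumptions for illustration, not part of this commit:

```python
# Hedged usage sketch of the refactored entry point. The testing helper and the
# dxy/threshold values are illustrative assumptions.
import tobac
import tobac.testing

# Synthetic 2D (+ time) field as an iris cube
sample = tobac.testing.make_simple_sample_data_2D(data_type="iris")

features = tobac.feature_detection_multithreshold(
    sample,
    dxy=1000.0,             # grid spacing in m (assumed for this sample field)
    threshold=[3, 5, 10],   # multiple thresholds, low to high (assumed values)
    target="maximum",
    preserve_iris_datetime_types=True,  # new flag: keep cftime datetimes for iris input
)
print(features.head())
```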
2 changes: 1 addition & 1 deletion tobac/segmentation.py
@@ -459,7 +459,7 @@ def segmentation_timestep(
hdim_1_axis = 0
hdim_2_axis = 1
elif field_in.ndim == 3:
vertical_axis = internal_utils.find_vertical_axis_from_coord(
vertical_axis = internal_utils.find_vertical_coord_name(
field_in, vertical_coord=vertical_coord
)
ndim_vertical = field_in.coord_dims(vertical_axis)
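`segmentation_timestep` uses the same renamed helper and then maps the coordinate name to a dimension index via iris's `coord_dims`. A short, self-contained sketch of that step, under the same assumptions as the sketch further up:

```python
# Hedged sketch: from vertical coordinate name to dimension index, as in
# segmentation_timestep (same synthetic-cube assumptions as the earlier sketch).
import numpy as np
import xarray as xr

from tobac.utils.internal.basic import find_vertical_coord_name

field_in = xr.DataArray(
    np.zeros((5, 4, 4)),
    dims=("model_level_number", "y", "x"),
    coords={"model_level_number": np.arange(5)},
    name="w",
).to_iris()

vertical_axis = find_vertical_coord_name(field_in, vertical_coord=None)
ndim_vertical = field_in.coord_dims(vertical_axis)  # tuple of dimension indices, here (0,)
print(vertical_axis, ndim_vertical)
```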
28 changes: 23 additions & 5 deletions tobac/testing.py
@@ -520,6 +520,7 @@ def make_dataset_from_arr(
import xarray as xr
import iris

time_dim_name = "time"
has_time = time_dim_num is not None

is_3D = z_dim_num is not None
@@ -530,10 +531,29 @@
if has_time:
time_min = datetime.datetime(2022, 1, 1)
time_num = in_arr.shape[time_dim_num]
time_vals = pd.date_range(start=time_min, periods=time_num).values.astype(
"datetime64[s]"
)

if data_type == "xarray":
# add dimension and coordinates
if is_3D:
output_arr = output_arr.rename(
new_name_or_name_dict={"dim_" + str(z_dim_num): z_dim_name}
)
output_arr = output_arr.assign_coords(
{z_dim_name: (z_dim_name, np.arange(0, z_max))}
)
# add dimension and coordinates
if has_time:
output_arr = output_arr.rename(
new_name_or_name_dict={"dim_" + str(time_dim_num): time_dim_name}
)
output_arr = output_arr.assign_coords(
{time_dim_name: (time_dim_name, time_vals)}
)
return output_arr
elif data_type == "iris":
if data_type == "iris":
out_arr_iris = output_arr.to_iris()

if is_3D:
@@ -544,10 +564,8 @@
if has_time:
out_arr_iris.add_dim_coord(
iris.coords.DimCoord(
pd.date_range(start=time_min, periods=time_num)
.values.astype("datetime64[s]")
.astype(int),
standard_name="time",
time_vals.astype(int),
standard_name=time_dim_name,
units="seconds since epoch",
),
time_dim_num,
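The new xarray branch in `make_dataset_from_arr` follows one pattern: rename the positional `dim_*` dimensions, then attach coordinates, with the time values built once as `datetime64[s]` and shared with the iris branch. A standalone sketch of that pattern (array shape and dimension names are illustrative assumptions):

```python
# Hedged sketch of the xarray branch added to make_dataset_from_arr: rename
# positional dims, then assign coordinates; time values are datetime64[s].
import datetime

import numpy as np
import pandas as pd
import xarray as xr

in_arr = np.zeros((2, 5, 10, 10))  # (time, z, y, x) -- assumed shape
time_dim_num, z_dim_num, z_dim_name = 0, 1, "altitude"
time_dim_name = "time"

time_vals = pd.date_range(
    start=datetime.datetime(2022, 1, 1), periods=in_arr.shape[time_dim_num]
).values.astype("datetime64[s]")

output_arr = xr.DataArray(in_arr)
output_arr = output_arr.rename({"dim_" + str(z_dim_num): z_dim_name}).assign_coords(
    {z_dim_name: (z_dim_name, np.arange(in_arr.shape[z_dim_num]))}
)
output_arr = output_arr.rename({"dim_" + str(time_dim_num): time_dim_name}).assign_coords(
    {time_dim_name: (time_dim_name, time_vals)}
)
print(output_arr.dims)  # ('time', 'altitude', 'dim_2', 'dim_3')
```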
