Skip to content

Commit

Permalink
add ability to save in kwarg whether an iris->xarray conversion happened
Browse files Browse the repository at this point in the history
  • Loading branch information
freemansw1 committed Dec 4, 2023
1 parent 7ded724 commit 477a21e
Show file tree
Hide file tree
Showing 5 changed files with 321 additions and 294 deletions.
14 changes: 8 additions & 6 deletions tobac/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,9 @@ def test_function_kwarg(test_input, kwarg=None):
def test_function_tuple_output(test_input, kwarg=None):
return (test_input, test_input)

decorated_function_kwarg = decorator(test_function_kwarg)
decorated_function_tuple = decorator(test_function_tuple_output)
decorator_i = decorator()
decorated_function_kwarg = decorator_i(test_function_kwarg)
decorated_function_tuple = decorator_i(test_function_tuple_output)

if input_types[0] == xarray.DataArray:
data = xarray.DataArray.from_iris(tobac.testing.make_simple_sample_data_2D())
Expand Down Expand Up @@ -227,15 +228,16 @@ def test_xarray_workflow():
data_xarray = xarray.DataArray.from_iris(deepcopy(data))

# Testing the get_spacings utility
get_spacings_xarray = xarray_to_iris(tobac.utils.get_spacings)
xarray_to_iris_i = xarray_to_iris()
get_spacings_xarray = xarray_to_iris_i(tobac.utils.get_spacings)
dxy, dt = tobac.utils.get_spacings(data)
dxy_xarray, dt_xarray = get_spacings_xarray(data_xarray)

assert dxy == dxy_xarray
assert dt == dt_xarray

# Testing feature detection
feature_detection_xarray = xarray_to_iris(
feature_detection_xarray = xarray_to_iris_i(
tobac.feature_detection.feature_detection_multithreshold
)
features = tobac.feature_detection.feature_detection_multithreshold(
Expand All @@ -246,7 +248,7 @@ def test_xarray_workflow():
assert_frame_equal(features, features_xarray)

# Testing the segmentation
segmentation_xarray = xarray_to_iris(tobac.segmentation.segmentation)
segmentation_xarray = xarray_to_iris_i(tobac.segmentation.segmentation)
mask, features = tobac.segmentation.segmentation(features, data, dxy, threshold=1.0)
mask_xarray, features_xarray = segmentation_xarray(
features_xarray, data_xarray, dxy_xarray, threshold=1.0
Expand All @@ -255,7 +257,7 @@ def test_xarray_workflow():
assert (mask.data == mask_xarray.to_iris().data).all()

# testing tracking
tracking_xarray = xarray_to_iris(tobac.tracking.linking_trackpy)
tracking_xarray = xarray_to_iris_i(tobac.tracking.linking_trackpy)
track = tobac.tracking.linking_trackpy(features, data, dt, dxy, v_max=100.0)
track_xarray = tracking_xarray(
features_xarray, data_xarray, dt_xarray, dxy_xarray, v_max=100.0
Expand Down
2 changes: 1 addition & 1 deletion tobac/utils/bulk_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def get_statistics(
return features


@decorators.iris_to_xarray
@decorators.iris_to_xarray()
def get_statistics_from_mask(
features: pd.DataFrame,
segmentation_mask: xr.DataArray,
Expand Down
Loading

0 comments on commit 477a21e

Please sign in to comment.