From 477a21e33612b1e2868a331b1fca6899f40a6870 Mon Sep 17 00:00:00 2001 From: Sean Freeman Date: Mon, 4 Dec 2023 15:51:10 -0600 Subject: [PATCH] add ability to save in kwarg whether an iris->xarray conversion happened --- tobac/tests/test_convert.py | 14 +- tobac/utils/bulk_statistics.py | 2 +- tobac/utils/decorators.py | 595 +++++++++++++++++---------------- tobac/utils/general.py | 2 +- tobac/utils/internal/basic.py | 2 +- 5 files changed, 321 insertions(+), 294 deletions(-) diff --git a/tobac/tests/test_convert.py b/tobac/tests/test_convert.py index d5b6bbe8..ca2228fd 100644 --- a/tobac/tests/test_convert.py +++ b/tobac/tests/test_convert.py @@ -174,8 +174,9 @@ def test_function_kwarg(test_input, kwarg=None): def test_function_tuple_output(test_input, kwarg=None): return (test_input, test_input) - decorated_function_kwarg = decorator(test_function_kwarg) - decorated_function_tuple = decorator(test_function_tuple_output) + decorator_i = decorator() + decorated_function_kwarg = decorator_i(test_function_kwarg) + decorated_function_tuple = decorator_i(test_function_tuple_output) if input_types[0] == xarray.DataArray: data = xarray.DataArray.from_iris(tobac.testing.make_simple_sample_data_2D()) @@ -227,7 +228,8 @@ def test_xarray_workflow(): data_xarray = xarray.DataArray.from_iris(deepcopy(data)) # Testing the get_spacings utility - get_spacings_xarray = xarray_to_iris(tobac.utils.get_spacings) + xarray_to_iris_i = xarray_to_iris() + get_spacings_xarray = xarray_to_iris_i(tobac.utils.get_spacings) dxy, dt = tobac.utils.get_spacings(data) dxy_xarray, dt_xarray = get_spacings_xarray(data_xarray) @@ -235,7 +237,7 @@ def test_xarray_workflow(): assert dt == dt_xarray # Testing feature detection - feature_detection_xarray = xarray_to_iris( + feature_detection_xarray = xarray_to_iris_i( tobac.feature_detection.feature_detection_multithreshold ) features = tobac.feature_detection.feature_detection_multithreshold( @@ -246,7 +248,7 @@ def test_xarray_workflow(): assert_frame_equal(features, features_xarray) # Testing the segmentation - segmentation_xarray = xarray_to_iris(tobac.segmentation.segmentation) + segmentation_xarray = xarray_to_iris_i(tobac.segmentation.segmentation) mask, features = tobac.segmentation.segmentation(features, data, dxy, threshold=1.0) mask_xarray, features_xarray = segmentation_xarray( features_xarray, data_xarray, dxy_xarray, threshold=1.0 @@ -255,7 +257,7 @@ def test_xarray_workflow(): assert (mask.data == mask_xarray.to_iris().data).all() # testing tracking - tracking_xarray = xarray_to_iris(tobac.tracking.linking_trackpy) + tracking_xarray = xarray_to_iris_i(tobac.tracking.linking_trackpy) track = tobac.tracking.linking_trackpy(features, data, dt, dxy, v_max=100.0) track_xarray = tracking_xarray( features_xarray, data_xarray, dt_xarray, dxy_xarray, v_max=100.0 diff --git a/tobac/utils/bulk_statistics.py b/tobac/utils/bulk_statistics.py index f27f9fd5..55fe17fc 100644 --- a/tobac/utils/bulk_statistics.py +++ b/tobac/utils/bulk_statistics.py @@ -147,7 +147,7 @@ def get_statistics( return features -@decorators.iris_to_xarray +@decorators.iris_to_xarray() def get_statistics_from_mask( features: pd.DataFrame, segmentation_mask: xr.DataArray, diff --git a/tobac/utils/decorators.py b/tobac/utils/decorators.py index 33b012f7..bd9315d2 100644 --- a/tobac/utils/decorators.py +++ b/tobac/utils/decorators.py @@ -5,344 +5,369 @@ import warnings -def iris_to_xarray(func): - """Decorator that converts all input of a function that is in the form of - Iris cubes into xarray DataArrays and converts all outputs with type - xarray DataArrays back into Iris cubes. - - Parameters - ---------- - func : function - Function to be decorated - - Returns - ------- - wrapper : function - Function including decorator - """ - - import iris - import xarray - - @functools.wraps(func) - def wrapper(*args, **kwargs): - # print(kwargs) - if any([type(arg) == iris.cube.Cube for arg in args]) or any( - [type(arg) == iris.cube.Cube for arg in kwargs.values()] - ): - # print("converting iris to xarray and back") - args = tuple( - [ - xarray.DataArray.from_iris(arg) - if type(arg) == iris.cube.Cube - else arg - for arg in args - ] - ) - kwargs_new = dict( - zip( - kwargs.keys(), +def iris_to_xarray(): + def iris_to_xarray_i(func): + """Decorator that converts all input of a function that is in the form of + Iris cubes into xarray DataArrays and converts all outputs with type + xarray DataArrays back into Iris cubes. + + Parameters + ---------- + func : function + Function to be decorated + + Returns + ------- + wrapper : function + Function including decorator + """ + + import iris + import xarray + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # print(kwargs) + if any([type(arg) == iris.cube.Cube for arg in args]) or any( + [type(arg) == iris.cube.Cube for arg in kwargs.values()] + ): + # print("converting iris to xarray and back") + args = tuple( [ xarray.DataArray.from_iris(arg) if type(arg) == iris.cube.Cube else arg - for arg in kwargs.values() - ], - ) - ) - # print(args) - # print(kwargs) - output = func(*args, **kwargs_new) - if type(output) == tuple: - output = tuple( - [ - xarray.DataArray.to_iris(output_item) - if type(output_item) == xarray.DataArray - else output_item - for output_item in output + for arg in args ] ) - elif type(output) == xarray.DataArray: - output = xarray.DataArray.to_iris(output) - # if output is neither tuple nor an xr.DataArray + kwargs_new = dict( + zip( + kwargs.keys(), + [ + xarray.DataArray.from_iris(arg) + if type(arg) == iris.cube.Cube + else arg + for arg in kwargs.values() + ], + ) + ) + # print(args) + # print(kwargs) + output = func(*args, **kwargs_new) + if type(output) == tuple: + output = tuple( + [ + xarray.DataArray.to_iris(output_item) + if type(output_item) == xarray.DataArray + else output_item + for output_item in output + ] + ) + elif type(output) == xarray.DataArray: + output = xarray.DataArray.to_iris(output) + # if output is neither tuple nor an xr.DataArray + else: + output = func(*args, **kwargs) + else: output = func(*args, **kwargs) + return output - else: - output = func(*args, **kwargs) - return output - - return wrapper + return wrapper + return iris_to_xarray_i -def xarray_to_iris(func): - """Decorator that converts all input of a function that is in the form of - xarray DataArrays into Iris cubes and converts all outputs with type - Iris cubes back into xarray DataArrays. - Parameters - ---------- - func : function - Function to be decorated. +def xarray_to_iris(): + def xarray_to_iris_i(func): + """Decorator that converts all input of a function that is in the form of + xarray DataArrays into Iris cubes and converts all outputs with type + Iris cubes back into xarray DataArrays. - Returns - ------- - wrapper : function - Function including decorator. + Parameters + ---------- + func : function + Function to be decorated. - Examples - -------- - >>> segmentation_xarray = xarray_to_iris(segmentation) + Returns + ------- + wrapper : function + Function including decorator. - This line creates a new function that can process xarray fields and - also outputs fields in xarray format, but otherwise works just like - the original function: + Examples + -------- + >>> segmentation_xarray = xarray_to_iris(segmentation) - >>> mask_xarray, features = segmentation_xarray( - features, data_xarray, dxy, threshold - ) - """ + This line creates a new function that can process xarray fields and + also outputs fields in xarray format, but otherwise works just like + the original function: - import iris - import xarray - - @functools.wraps(func) - def wrapper(*args, **kwargs): - # print(args) - # print(kwargs) - if any([type(arg) == xarray.DataArray for arg in args]) or any( - [type(arg) == xarray.DataArray for arg in kwargs.values()] - ): - # print("converting xarray to iris and back") - args = tuple( - [ - xarray.DataArray.to_iris(arg) - if type(arg) == xarray.DataArray - else arg - for arg in args - ] + >>> mask_xarray, features = segmentation_xarray( + features, data_xarray, dxy, threshold ) - if kwargs: - kwargs_new = dict( - zip( - kwargs.keys(), - [ - xarray.DataArray.to_iris(arg) - if type(arg) == xarray.DataArray - else arg - for arg in kwargs.values() - ], - ) - ) - else: - kwargs_new = kwargs + """ + + import iris + import xarray + + @functools.wraps(func) + def wrapper(*args, **kwargs): # print(args) # print(kwargs) - output = func(*args, **kwargs_new) - if type(output) == tuple: - output = tuple( + if any([type(arg) == xarray.DataArray for arg in args]) or any( + [type(arg) == xarray.DataArray for arg in kwargs.values()] + ): + # print("converting xarray to iris and back") + args = tuple( [ - xarray.DataArray.from_iris(output_item) - if type(output_item) == iris.cube.Cube - else output_item - for output_item in output + xarray.DataArray.to_iris(arg) + if type(arg) == xarray.DataArray + else arg + for arg in args ] ) - else: - if type(output) == iris.cube.Cube: - output = xarray.DataArray.from_iris(output) - - else: - output = func(*args, **kwargs) - # print(output) - return output - - return wrapper - - -def irispandas_to_xarray(func): - """Decorator that converts all input of a function that is in the form of - Iris cubes/pandas Dataframes into xarray DataArrays/xarray Datasets and - converts all outputs with the type xarray DataArray/xarray Dataset - back into Iris cubes/pandas Dataframes. + if kwargs: + kwargs_new = dict( + zip( + kwargs.keys(), + [ + xarray.DataArray.to_iris(arg) + if type(arg) == xarray.DataArray + else arg + for arg in kwargs.values() + ], + ) + ) + else: + kwargs_new = kwargs + # print(args) + # print(kwargs) + output = func(*args, **kwargs_new) + if type(output) == tuple: + output = tuple( + [ + xarray.DataArray.from_iris(output_item) + if type(output_item) == iris.cube.Cube + else output_item + for output_item in output + ] + ) + else: + if type(output) == iris.cube.Cube: + output = xarray.DataArray.from_iris(output) - Parameters - ---------- - func : function - Function to be decorated. + else: + output = func(*args, **kwargs) + # print(output) + return output + + return wrapper + + return xarray_to_iris_i + + +def irispandas_to_xarray(save_iris_info: bool = False): + def irispandas_to_xarray_i(func): + """Decorator that converts all input of a function that is in the form of + Iris cubes/pandas Dataframes into xarray DataArrays/xarray Datasets and + converts all outputs with the type xarray DataArray/xarray Dataset + back into Iris cubes/pandas Dataframes. + + Parameters + ---------- + func : function + Function to be decorated. + + Returns + ------- + wrapper : function + Function including decorator. + """ + import iris + import iris.cube + import xarray + import pandas as pd + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # pass if we did an iris conversion. + if save_iris_info: + if any([(type(arg) == iris.cube.Cube) for arg in args]) or any( + [(type(arg) == iris.cube.Cube) for arg in kwargs.values()] + ): + kwargs["converted_from_iris"] = True + else: + kwargs["converted_from_iris"] = False - Returns - ------- - wrapper : function - Function including decorator. - """ - import iris - import xarray - import pandas as pd - - @functools.wraps(func) - def wrapper(*args, **kwargs): - # print(kwargs) - if any( - [(type(arg) == iris.cube.Cube or type(arg) == pd.DataFrame) for arg in args] - ) or any( - [ - (type(arg) == iris.cube.Cube or type(arg) == pd.DataFrame) - for arg in kwargs.values() - ] - ): - # print("converting iris to xarray and back") - args = tuple( + # print(kwargs) + if any( [ - xarray.DataArray.from_iris(arg) - if type(arg) == iris.cube.Cube - else arg.to_xarray() - if type(arg) == pd.DataFrame - else arg + (type(arg) == iris.cube.Cube or type(arg) == pd.DataFrame) for arg in args ] - ) - kwargs = dict( - zip( - kwargs.keys(), + ) or any( + [ + (type(arg) == iris.cube.Cube or type(arg) == pd.DataFrame) + for arg in kwargs.values() + ] + ): + # print("converting iris to xarray and back") + args = tuple( [ xarray.DataArray.from_iris(arg) if type(arg) == iris.cube.Cube else arg.to_xarray() if type(arg) == pd.DataFrame else arg - for arg in kwargs.values() - ], - ) - ) - - output = func(*args, **kwargs) - if type(output) == tuple: - output = tuple( - [ - xarray.DataArray.to_iris(output_item) - if type(output_item) == xarray.DataArray - else output_item.to_dataframe() - if type(output_item) == xarray.Dataset - else output_item - for output_item in output + for arg in args ] ) - else: - if type(output) == xarray.DataArray: - output = xarray.DataArray.to_iris(output) - elif type(output) == xarray.Dataset: - output = output.to_dataframe() - - else: - output = func(*args, **kwargs) - return output - - return wrapper - - -def xarray_to_irispandas(func): - """Decorator that converts all input of a function that is in the form of - DataArrays/xarray Datasets into xarray Iris cubes/pandas Dataframes and - converts all outputs with the type Iris cubes/pandas Dataframes back into - xarray DataArray/xarray Dataset. - - Parameters - ---------- - func : function - Function to be decorated. - - Returns - ------- - wrapper : function - Function including decorator. - - Examples - -------- - >>> linking_trackpy_xarray = xarray_to_irispandas( - linking_trackpy - ) - - This line creates a new function that can process xarray inputs and - also outputs in xarray formats, but otherwise works just like the - original function: - - >>> track_xarray = linking_trackpy_xarray( - features_xarray, field_xarray, dt, dx - ) - """ - import iris - import xarray - import pandas as pd - - @functools.wraps(func) - def wrapper(*args, **kwargs): - # print(args) - # print(kwargs) - if any( - [ - (type(arg) == xarray.DataArray or type(arg) == xarray.Dataset) - for arg in args - ] - ) or any( - [ - (type(arg) == xarray.DataArray or type(arg) == xarray.Dataset) - for arg in kwargs.values() - ] - ): - # print("converting xarray to iris and back") - args = tuple( - [ - xarray.DataArray.to_iris(arg) - if type(arg) == xarray.DataArray - else arg.to_dataframe() - if type(arg) == xarray.Dataset - else arg - for arg in args - ] - ) - if kwargs: - kwargs_new = dict( + kwargs = dict( zip( kwargs.keys(), [ - xarray.DataArray.to_iris(arg) - if type(arg) == xarray.DataArray - else arg.to_dataframe() - if type(arg) == xarray.Dataset + xarray.DataArray.from_iris(arg) + if type(arg) == iris.cube.Cube + else arg.to_xarray() + if type(arg) == pd.DataFrame else arg for arg in kwargs.values() ], ) ) + + output = func(*args, **kwargs) + if type(output) == tuple: + output = tuple( + [ + xarray.DataArray.to_iris(output_item) + if type(output_item) == xarray.DataArray + else output_item.to_dataframe() + if type(output_item) == xarray.Dataset + else output_item + for output_item in output + ] + ) + else: + if type(output) == xarray.DataArray: + output = xarray.DataArray.to_iris(output) + elif type(output) == xarray.Dataset: + output = output.to_dataframe() + else: - kwargs_new = kwargs + output = func(*args, **kwargs) + return output + + return wrapper + + return irispandas_to_xarray_i + + +def xarray_to_irispandas(): + def xarray_to_irispandas_i(func): + """Decorator that converts all input of a function that is in the form of + DataArrays/xarray Datasets into xarray Iris cubes/pandas Dataframes and + converts all outputs with the type Iris cubes/pandas Dataframes back into + xarray DataArray/xarray Dataset. + + Parameters + ---------- + func : function + Function to be decorated. + + Returns + ------- + wrapper : function + Function including decorator. + + Examples + -------- + >>> linking_trackpy_xarray = xarray_to_irispandas( + linking_trackpy + ) + + This line creates a new function that can process xarray inputs and + also outputs in xarray formats, but otherwise works just like the + original function: + + >>> track_xarray = linking_trackpy_xarray( + features_xarray, field_xarray, dt, dx + ) + """ + import iris + import xarray + import pandas as pd + + @functools.wraps(func) + def wrapper(*args, **kwargs): # print(args) # print(kwargs) - output = func(*args, **kwargs_new) - if type(output) == tuple: - output = tuple( + if any( + [ + (type(arg) == xarray.DataArray or type(arg) == xarray.Dataset) + for arg in args + ] + ) or any( + [ + (type(arg) == xarray.DataArray or type(arg) == xarray.Dataset) + for arg in kwargs.values() + ] + ): + # print("converting xarray to iris and back") + args = tuple( [ - xarray.DataArray.from_iris(output_item) - if type(output_item) == iris.cube.Cube - else output_item.to_xarray() - if type(output_item) == pd.DataFrame - else output_item - for output_item in output + xarray.DataArray.to_iris(arg) + if type(arg) == xarray.DataArray + else arg.to_dataframe() + if type(arg) == xarray.Dataset + else arg + for arg in args ] ) + if kwargs: + kwargs_new = dict( + zip( + kwargs.keys(), + [ + xarray.DataArray.to_iris(arg) + if type(arg) == xarray.DataArray + else arg.to_dataframe() + if type(arg) == xarray.Dataset + else arg + for arg in kwargs.values() + ], + ) + ) + else: + kwargs_new = kwargs + # print(args) + # print(kwargs) + output = func(*args, **kwargs_new) + if type(output) == tuple: + output = tuple( + [ + xarray.DataArray.from_iris(output_item) + if type(output_item) == iris.cube.Cube + else output_item.to_xarray() + if type(output_item) == pd.DataFrame + else output_item + for output_item in output + ] + ) + else: + if type(output) == iris.cube.Cube: + output = xarray.DataArray.from_iris(output) + elif type(output) == pd.DataFrame: + output = output.to_xarray() + else: - if type(output) == iris.cube.Cube: - output = xarray.DataArray.from_iris(output) - elif type(output) == pd.DataFrame: - output = output.to_xarray() + output = func(*args, **kwargs) + # print(output) + return output - else: - output = func(*args, **kwargs) - # print(output) - return output + return wrapper - return wrapper + return xarray_to_irispandas_i def njit_if_available(func, **kwargs): diff --git a/tobac/utils/general.py b/tobac/utils/general.py index 44f0177a..66a31813 100644 --- a/tobac/utils/general.py +++ b/tobac/utils/general.py @@ -637,7 +637,7 @@ def combine_feature_dataframes( return combined_sorted -@internal_utils.irispandas_to_xarray +@internal_utils.irispandas_to_xarray() def transform_feature_points( features, new_dataset, diff --git a/tobac/utils/internal/basic.py b/tobac/utils/internal/basic.py index 8dc041de..0efd28a5 100644 --- a/tobac/utils/internal/basic.py +++ b/tobac/utils/internal/basic.py @@ -285,7 +285,7 @@ def find_axis_from_coord( raise ValueError("variable_arr must be Iris Cube or Xarray DataArray") -@irispandas_to_xarray +@irispandas_to_xarray() def detect_latlon_coord_name( in_dataset: Union[xr.DataArray, iris.cube.Cube], latitude_name: Union[str, None] = None,