Skip to content

Commit 4d54d24

Browse files
committed
Refactors to address @w-k-jones comments
1 parent 8679d44 commit 4d54d24

File tree

1 file changed

+26
-24
lines changed

1 file changed

+26
-24
lines changed

tobac/utils/internal/xarray_utils.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ def find_axis_from_dim_coord(
3737
Returns ValueError if there are more than one matching dimension name or
3838
if the dimension/coordinate isn't found.
3939
"""
40-
41-
dim_axis = find_axis_from_dim(in_da, dim_coord_name)
40+
try:
41+
dim_axis = find_axis_from_dim(in_da, dim_coord_name)
42+
except ValueError:
43+
dim_axis = None
4244

4345
try:
4446
coord_axes = find_axis_from_coord(in_da, dim_coord_name)
@@ -96,7 +98,7 @@ def find_axis_from_dim(in_da: xr.DataArray, dim_name: str) -> Union[int, None]:
9698
"More than one matching dimension. Need to specify which axis number or rename "
9799
"your dimensions."
98100
)
99-
return None
101+
raise ValueError("Dimension not found. ")
100102

101103

102104
def find_axis_from_coord(in_da: xr.DataArray, coord_name: str) -> tuple[int]:
@@ -135,18 +137,18 @@ def find_axis_from_coord(in_da: xr.DataArray, coord_name: str) -> tuple[int]:
135137

136138
if len(all_matching_coords) > 1:
137139
raise ValueError("Too many matching coords")
138-
return tuple()
140+
raise ValueError("No matching coords")
139141

140142

141143
def find_vertical_coord_name(
142-
variable_cube: xr.DataArray,
144+
variable_da: xr.DataArray,
143145
vertical_coord: Union[str, None] = None,
144146
) -> str:
145147
"""Function to find the vertical coordinate in the iris cube
146148
147149
Parameters
148150
----------
149-
variable_cube: xarray.DataArray
151+
variable_da: xarray.DataArray
150152
Input variable cube, containing a vertical coordinate.
151153
vertical_coord: str
152154
Vertical coordinate name. If None, this function tries to auto-detect.
@@ -162,7 +164,7 @@ def find_vertical_coord_name(
162164
Raised if the vertical coordinate isn't found in the cube.
163165
"""
164166

165-
list_coord_names = variable_cube.coords
167+
list_coord_names = variable_da.coords
166168

167169
if vertical_coord is None or vertical_coord == "auto":
168170
# find the intersection
@@ -347,14 +349,20 @@ def add_coordinates_to_features(
347349
hdim2_name_original = variable_da.dims[hdim2_axis]
348350

349351
# generate random names for the new coordinates that are based on i, j, k values
350-
hdim1_name_new = "".join(
351-
random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits)
352-
for _ in range(16)
353-
)
354-
hdim2_name_new = "".join(
355-
random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits)
356-
for _ in range(16)
357-
)
352+
hdim1_name_new = "__temp_hdim1_name"
353+
hdim2_name_new = "__temp_hdim2_name"
354+
vdim_name_new = "__temp_vdim_name"
355+
356+
if (
357+
hdim1_name_new in variable_da.dims
358+
or hdim2_name_new in variable_da.dims
359+
or vdim_name_new in variable_da.dims
360+
):
361+
raise ValueError(
362+
"Cannot have dimensions named {0}, {1}, or {2}".format(
363+
hdim1_name_new, hdim2_name_new, vdim_name_new
364+
)
365+
)
358366

359367
dim_new_names = {
360368
hdim1_name_original: hdim1_name_new,
@@ -367,12 +375,6 @@ def add_coordinates_to_features(
367375

368376
if is_3d:
369377
vdim_name_original = variable_da.dims[vertical_axis]
370-
vdim_name_new = "".join(
371-
random.choice(
372-
string.ascii_uppercase + string.ascii_lowercase + string.digits
373-
)
374-
for _ in range(16)
375-
)
376378
dim_interp_coords[vdim_name_new] = xr.DataArray(
377379
return_feat_df["vdim"].values, dims="features"
378380
)
@@ -383,9 +385,9 @@ def add_coordinates_to_features(
383385
# dataset
384386
renamed_dim_da = variable_da.swap_dims(dim_new_names)
385387
interpolated_df = renamed_dim_da.interp(coords=dim_interp_coords)
386-
interpolated_df = interpolated_df.drop([hdim1_name_new, hdim2_name_new])
387-
if is_3d:
388-
interpolated_df = interpolated_df.drop([vdim_name_new])
388+
interpolated_df = interpolated_df.drop_vars(
389+
[hdim1_name_new, hdim2_name_new, vdim_name_new], errors="ignore"
390+
)
389391
return_feat_df[time_dim_name] = variable_da[time_dim_name].values[
390392
return_feat_df["frame"]
391393
]

0 commit comments

Comments
 (0)