Skip to content

Commit

Permalink
MOBT-689 (Pt 1): Move functionality out of the spot-extract CLI into …
Browse files Browse the repository at this point in the history
…a plugin (#1996)

* Move neighbour_finding_method_name function out of the NeighbourSelection plugin into a new spot data utilities file. This is reused in the spot-extraction CLI, soon to be in the plugin, and should be made more common.

* Move check_grid_match into metadata utilities along with the hash generation code which is more logical. This removes imports from the spot_extract file for other plugins that use this functionality.

* Move all functionality from the spot-extract CLI to a wrapper class that can be invoked by the CLI. Still needs doc-strings, tidying up, any potential rationalisation, and unit tests for the new class.

* Add unit tests for the SpotManipulation plugin which has taken on logic and calling of other plugins that was previously done in the CLI layer.

* Fix up dz_rescaling code and apply_height_adjustment CLI which use the neighbour_finding_method_name method which was previously part of the NeighbourSelection plugin but which is now a standalone utility. Update the plugin which the spot_extraction CLI calls.

* Resolves doc-string error highlighted by sphinx.

* Change test parameterisation to help with testing spot forecast subsetting in a subsequent PR.

* Additional tests for realization collapse and for percentile extraction directly from realization data. Modifies the test in spot_manipulation for selecting the percentile method to check if data is actually masked, rather than just of masked type. This will enable the fast method to be used in more cases.

* Rename neighbour_finding_method_name to get_neighbour_finding_method_name.

* Review changes.
  • Loading branch information
bayliffe authored May 28, 2024
1 parent 689cded commit dee4f9b
Show file tree
Hide file tree
Showing 14 changed files with 793 additions and 316 deletions.
7 changes: 3 additions & 4 deletions improver/calibration/dz_rescaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from improver.calibration.utilities import filter_non_matching_cubes
from improver.constants import SECONDS_IN_HOUR
from improver.metadata.constants.time_types import TIME_COORDS
from improver.spotdata.neighbour_finding import NeighbourSelection
from improver.spotdata.utilities import get_neighbour_finding_method_name


class EstimateDzRescaling(PostProcessingPlugin):
Expand Down Expand Up @@ -74,10 +74,9 @@ def __init__(
# Please see numpy.polynomial.polynomial.Polynomial.fit for further information.
self.polyfit_deg = 1

self.neighbour_selection_method = NeighbourSelection(
self.neighbour_selection_method = get_neighbour_finding_method_name(
land_constraint=land_constraint, minimum_dz=similar_altitude
).neighbour_finding_method_name()

)
self.site_id_coord = site_id_coord

def _fit_polynomial(self, forecasts: Cube, truths: Cube, dz: Cube) -> float:
Expand Down
6 changes: 3 additions & 3 deletions improver/cli/apply_height_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ def process(
height
"""
from improver.spotdata.height_adjustment import SpotHeightAdjustment
from improver.spotdata.neighbour_finding import NeighbourSelection
from improver.spotdata.utilities import get_neighbour_finding_method_name

neighbour_selection_method = NeighbourSelection(
neighbour_selection_method = get_neighbour_finding_method_name(
land_constraint=land_constraint, minimum_dz=similar_altitude
).neighbour_finding_method_name()
)

result = SpotHeightAdjustment(neighbour_selection_method)(spot_cube, neighbour)
return result
130 changes: 21 additions & 109 deletions improver/cli/spot_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ def process(
And the neighbour cube is a cube of spot-data neighbours and
the spot site information.
apply_lapse_rate_correction (bool):
Use to apply a lapse-rate correction to screen temperature data so
that the data are a better match the altitude of the spot site for
which they have been extracted. This lapse rate will be applied for
a fixed orographic difference between the site and gridpoint
altitude. Differences in orography in excess of this fixed limit
will use the Environmental Lapse Rate (also known as the Standard
Atmosphere Lapse Rate).
Use to apply a lapse-rate correction to screen temperature
forecasts so that they better represent the altitude of the
spot site for which they have been extracted. This lapse rate
will be applied for a fixed orographic difference between the
site and grid point altitude. Differences in orography in
excess of this fixed limit will use the Environmental Lapse
Rate (also known as the Standard Atmosphere Lapse Rate).
fixed_lapse_rate (float):
If provided, use this fixed value as a lapse-rate for adjusting
the forecast values if apply_lapse_rate_correction is True. This
Expand Down Expand Up @@ -105,106 +105,18 @@ def process(
Returns:
iris.cube.Cube:
Cube of spot data.
Warns:
warning:
If diagnostic cube is not a known probabilistic type.
warning:
If a lapse rate cube was not provided, but the option to apply
the lapse rate correction was enabled.
"""

import warnings

import iris
import numpy as np
from iris.exceptions import CoordinateNotFoundError

from improver.ensemble_copula_coupling.ensemble_copula_coupling import (
ConvertProbabilitiesToPercentiles,
ResamplePercentiles,
)
from improver.metadata.probabilistic import find_percentile_coordinate
from improver.percentile import PercentileConverter
from improver.spotdata.apply_lapse_rate import SpotLapseRateAdjust
from improver.spotdata.neighbour_finding import NeighbourSelection
from improver.spotdata.spot_extraction import SpotExtraction
from improver.utilities.cube_extraction import extract_subcube
from improver.utilities.cube_manipulation import collapse_realizations

neighbour_cube = cubes[-1]
cube = cubes[0]

if realization_collapse:
cube = collapse_realizations(cube)
neighbour_selection_method = NeighbourSelection(
land_constraint=land_constraint, minimum_dz=similar_altitude
).neighbour_finding_method_name()
result = SpotExtraction(neighbour_selection_method=neighbour_selection_method)(
neighbour_cube, cube, new_title=new_title
)

# If a probability or percentile diagnostic cube is provided, extract
# the given percentile if available. This is done after the spot-extraction
# to minimise processing time; usually there are far fewer spot sites than
# grid points.
if extract_percentiles:
extract_percentiles = [np.float32(x) for x in extract_percentiles]
try:
perc_coordinate = find_percentile_coordinate(result)
except CoordinateNotFoundError:
if "probability_of_" in result.name():
result = ConvertProbabilitiesToPercentiles(
ecc_bounds_warning=ignore_ecc_bounds_exceedance,
skip_ecc_bounds=skip_ecc_bounds,
)(result, percentiles=extract_percentiles)
result = iris.util.squeeze(result)
elif result.coords("realization", dim_coords=True):
fast_percentile_method = not np.ma.isMaskedArray(result.data)
result = PercentileConverter(
"realization",
percentiles=extract_percentiles,
fast_percentile_method=fast_percentile_method,
)(result)
else:
msg = (
"Diagnostic cube is not a known probabilistic type. "
"The {} percentile could not be extracted. Extracting "
"data from the cube including any leading "
"dimensions.".format(extract_percentiles)
)
if not suppress_warnings:
warnings.warn(msg)
else:
if set(extract_percentiles).issubset(perc_coordinate.points):
constraint = [
"{}={}".format(perc_coordinate.name(), extract_percentiles)
]
result = extract_subcube(result, constraint)
else:
result = ResamplePercentiles()(result, percentiles=extract_percentiles)

# Check whether a lapse rate cube has been provided
if apply_lapse_rate_correction:
if len(cubes) == 3:
plugin = SpotLapseRateAdjust(
neighbour_selection_method=neighbour_selection_method
)
result = plugin(result, neighbour_cube, cubes[-2])
elif fixed_lapse_rate is not None:
plugin = SpotLapseRateAdjust(
neighbour_selection_method=neighbour_selection_method,
fixed_lapse_rate=fixed_lapse_rate,
)
result = plugin(result, neighbour_cube)
elif not suppress_warnings:
warnings.warn(
"A lapse rate cube or fixed lapse rate was not provided, but the "
"option to apply the lapse rate correction was enabled. No lapse rate "
"correction could be applied."
)

# Remove the internal model_grid_hash attribute if present.
result.attributes.pop("model_grid_hash", None)
return result
from improver.spotdata.spot_manipulation import SpotManipulation

return SpotManipulation(
apply_lapse_rate_correction,
fixed_lapse_rate,
land_constraint,
similar_altitude,
extract_percentiles,
ignore_ecc_bounds_exceedance,
skip_ecc_bounds,
new_title,
suppress_warnings,
realization_collapse,
)(cubes)
39 changes: 38 additions & 1 deletion improver/metadata/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import iris
import numpy as np
from cf_units import Unit
from iris.cube import Cube
from iris.cube import Cube, CubeList
from numpy import ndarray
from numpy.ma.core import MaskedArray

Expand Down Expand Up @@ -177,6 +177,43 @@ def create_coordinate_hash(cube: Cube) -> str:
return generate_hash(hashable_data)


def check_grid_match(cubes: Union[List[Cube], CubeList]) -> None:
    """
    Verify that all the supplied cubes are on, or originate from, the same
    coordinate grid.

    For each cube a grid hash is obtained: an existing 'model_grid_hash'
    attribute is preferred, as this can encode the source grid for cubes that
    carry no grid coordinates themselves (e.g. spotdata cubes); otherwise a
    hash is computed from the cube's own coordinates. Every hash must match
    that of the first cube, otherwise the cubes cannot safely be used
    together.

    Args:
        cubes:
            A list of cubes to check for grid compatibility.
    Raises:
        ValueError: Raised if the cubes are not on matching grids as
            identified by the model_grid_hash.
    """

    def _hash_for(cube):
        # Prefer the pre-computed hash attribute; fall back to hashing the
        # cube's coordinate grid directly.
        if "model_grid_hash" in cube.attributes:
            return cube.attributes["model_grid_hash"]
        return create_coordinate_hash(cube)

    iterator = iter(cubes)
    reference_hash = _hash_for(next(iterator))

    # Lazily compare each remaining cube against the reference, raising on
    # the first mismatch.
    for candidate in iterator:
        if _hash_for(candidate) != reference_hash:
            raise ValueError(
                "Cubes do not share or originate from the same "
                "grid, so cannot be used together."
            )


def get_model_id_attr(cubes: List[Cube], model_id_attr: str) -> str:
"""
Gets the specified model ID attribute from a list of input cubes, checking
Expand Down
3 changes: 2 additions & 1 deletion improver/spotdata/apply_lapse_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from improver import PostProcessingPlugin
from improver.lapse_rate import compute_lapse_rate_adjustment
from improver.metadata.probabilistic import is_probability
from improver.spotdata.spot_extraction import SpotExtraction, check_grid_match
from improver.metadata.utilities import check_grid_match
from improver.spotdata.spot_extraction import SpotExtraction


class SpotLapseRateAdjust(PostProcessingPlugin):
Expand Down
22 changes: 5 additions & 17 deletions improver/spotdata/neighbour_finding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from improver.spotdata.build_spotdata_cube import build_spotdata_cube
from improver.utilities.cube_manipulation import enforce_coordinate_ordering

from .utilities import get_neighbour_finding_method_name


class NeighbourSelection(BasePlugin):
"""
Expand Down Expand Up @@ -114,22 +116,6 @@ def __repr__(self) -> str:
self.node_limit,
)

def neighbour_finding_method_name(self) -> str:
    """
    Create a name to describe the neighbour method based on the constraints
    provided.
    Returns:
        A string that describes the neighbour finding method employed.
        This is essentially a concatenation of the options.
    """
    # Build the name as "nearest" plus a suffix per active constraint.
    parts = ["nearest"]
    if self.land_constraint:
        parts.append("_land")
    if self.minimum_dz:
        parts.append("_minimum_dz")
    return "".join(parts)

def _transform_sites_coordinate_system(
self, x_points: ndarray, y_points: ndarray, target_crs: CRS
) -> ndarray:
Expand Down Expand Up @@ -610,7 +596,9 @@ def process(
)

# Construct a name to describe the neighbour finding method employed
method_name = self.neighbour_finding_method_name()
method_name = get_neighbour_finding_method_name(
self.land_constraint, self.minimum_dz
)

# Create an array of indices and displacements to return
data = np.stack(
Expand Down
41 changes: 2 additions & 39 deletions improver/spotdata/spot_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import iris
import numpy as np
from iris.coords import AuxCoord, DimCoord
from iris.cube import Cube, CubeList
from iris.cube import Cube
from numpy import ndarray

from improver import BasePlugin
from improver.metadata.constants.attributes import MANDATORY_ATTRIBUTE_DEFAULTS
from improver.metadata.constants.mo_attributes import MOSG_GRID_ATTRIBUTES
from improver.metadata.utilities import create_coordinate_hash
from improver.metadata.utilities import check_grid_match
from improver.spotdata.build_spotdata_cube import build_spotdata_cube
from improver.utilities.cube_manipulation import enforce_coordinate_ordering

Expand Down Expand Up @@ -341,40 +341,3 @@ def process(
spotdata_cube.cell_methods = diagnostic_cube.cell_methods

return spotdata_cube


def check_grid_match(cubes: Union[List[Cube], CubeList]) -> None:
    """
    Checks that cubes are on, or originate from, compatible coordinate grids.
    Each cube is first checked for an existing 'model_grid_hash' which can be
    used to encode coordinate information on cubes that do not themselves
    contain a coordinate grid (e.g. spotdata cubes). If this is not found a new
    hash is generated to enable comparison. If the cubes are not compatible, an
    exception is raised to prevent the use of unmatched cubes.
    Args:
        cubes:
            A list of cubes to check for grid compatibility.
    Raises:
        ValueError: Raised if the cubes are not on matching grids as
            identified by the model_grid_hash.
    """

    def _get_grid_hash(cube):
        # Prefer a pre-existing hash attribute; otherwise derive one from the
        # cube's own coordinates so grid and spot cubes can be compared.
        try:
            cube_hash = cube.attributes["model_grid_hash"]
        except KeyError:
            cube_hash = create_coordinate_hash(cube)
        return cube_hash

    # The first cube's hash is the reference all others must match.
    cubes = iter(cubes)
    reference_hash = _get_grid_hash(next(cubes))

    for cube in cubes:
        cube_hash = _get_grid_hash(cube)
        if cube_hash != reference_hash:
            raise ValueError(
                "Cubes do not share or originate from the same "
                "grid, so cannot be used together."
            )
Loading

0 comments on commit dee4f9b

Please sign in to comment.