diff --git a/src/idpi/cli.py b/src/idpi/cli.py index 7ac23850..04b58c1b 100644 --- a/src/idpi/cli.py +++ b/src/idpi/cli.py @@ -69,9 +69,9 @@ def handle_vector_fields(ds): for u_name, v_name in pairs: click.echo(f"Rotating vector field components {u_name}, {v_name} to geolatlon") u, v = ds[u_name], ds[v_name] - if u.origin["x"] != 0.0: + if u.origin_x != 0.0: u = destagger.destagger(u, "x") - if v.origin["y"] != 0.0: + if v.origin_y != 0.0: v = destagger.destagger(v, "y") if u.vref == "native" and v.vref == "native": u_g, v_g = gis.vref_rot2geolatlon(u, v) diff --git a/src/idpi/grib_decoder.py b/src/idpi/grib_decoder.py index 4bf49f49..f77f3e8b 100644 --- a/src/idpi/grib_decoder.py +++ b/src/idpi/grib_decoder.py @@ -16,7 +16,7 @@ import xarray as xr # Local -from . import data_source, tasking +from . import data_source, metadata, tasking logger = logging.getLogger(__name__) @@ -26,13 +26,6 @@ "step": "time", } INV_DIM_MAP = {v: k for k, v in DIM_MAP.items()} -VCOORD_TYPE = { - "generalVertical": ("model_level", -0.5), - "generalVerticalLayer": ("model_level", 0.0), - "hybrid": ("hybrid", 0.0), - "isobaricInPa": ("pressure", 0.0), - "surface": ("surface", 0.0), -} Request = str | tuple | dict @@ -139,23 +132,6 @@ def _load_pv(self, pv_param: Request): for field in fs: return field.metadata("pv") - def _construct_metadata(self, field: typing.Any): - metadata: dict[str, typing.Any] = field.metadata( - namespace=["geography", "parameter"] - ) - # https://codes.ecmwf.int/grib/format/grib2/ctables/3/3/ - [vref_flag] = get_code_flag(field.metadata("resolutionAndComponentFlags"), [5]) - level_type: str = field.metadata("typeOfLevel") - vcoord_type, zshift = VCOORD_TYPE.get(level_type, (level_type, 0.0)) - - metadata |= { - "vref": "native" if vref_flag else "geo", - "vcoord_type": vcoord_type, - "origin": {"z": zshift}, - "message": field.message(), - } - return metadata - def _load_param( self, req: Request, @@ -164,7 +140,7 @@ def _load_param( fs = self.data_source.retrieve(req) hcoords: dict[str, xr.DataArray] = {} - metadata: dict[str, typing.Any] = {} + metadata_values: dict[str, typing.Any] = {} time_meta: dict[int, dict] = {} dims: tuple[str, ...] | None = None field_map: dict[tuple[int, ...], np.ndarray] = {} @@ -186,8 +162,11 @@ def _load_param( if not dims: dims = tuple(DIM_MAP[d] for d in dim_keys) + ("y", "x") - if not metadata: - metadata = self._construct_metadata(field) + if not metadata_values: + metadata_values = { + "message": field.message(), + **metadata.extract(field.metadata()), + } if not hcoords: hcoords = { @@ -205,7 +184,7 @@ def _load_param( np.array([field_map.pop(key) for key in sorted(field_map)]).reshape(shape), coords=coords | hcoords | tcoords, dims=dims, - attrs=metadata, + attrs=metadata_values, ) return ( @@ -262,19 +241,6 @@ def load_fieldnames( return self.load(reqs, extract_pv) -def _get_type_of_level(field): - if field.vcoord_type == "model_level": - if field.origin["z"] == 0.0: - return "generalVerticalLayer" - elif field.origin["z"] == -0.5: - return "generalVertical" - else: - raise ValueError(f"Unsupported field origin in z: {field.origin['z']}") - else: - mapping = {vc: name for name, (vc, _) in VCOORD_TYPE.items()} - return mapping.get(field.vcoord_type, field.vcoord_type) - - def save(field: xr.DataArray, file_handle: io.BufferedWriter, bits_per_value: int = 16): """Write field to file in GRIB format. @@ -307,8 +273,7 @@ def save(field: xr.DataArray, file_handle: io.BufferedWriter, bits_per_value: in } def to_grib(loc: dict[str, xr.DataArray]): - result = {INV_DIM_MAP[key]: value.item() for key, value in loc.items()} - return result | {"typeOfLevel": _get_type_of_level(field)} + return {INV_DIM_MAP[key]: value.item() for key, value in loc.items()} for idx_slice in product(*idx.values()): loc = {dim: value for dim, value in zip(idx.keys(), idx_slice)} diff --git a/src/idpi/metadata.py b/src/idpi/metadata.py index 333f4d87..76ef34ce 100644 --- a/src/idpi/metadata.py +++ b/src/idpi/metadata.py @@ -11,12 +11,38 @@ import xarray as xr from earthkit.data.writers import write # type: ignore +# Local +from . import grib_decoder + +VCOORD_TYPE = { + "generalVertical": ("model_level", -0.5), + "generalVerticalLayer": ("model_level", 0.0), + "isobaricInPa": ("pressure", 0.0), +} + + +def extract(metadata): + [vref_flag] = grib_decoder.get_code_flag( + metadata.get("resolutionAndComponentFlags"), [5] + ) + level_type = metadata.get("typeOfLevel") + vcoord_type, zshift = VCOORD_TYPE.get(level_type, (level_type, 0.0)) + + return { + "parameter": metadata.as_namespace("parameter"), + "geography": metadata.as_namespace("geography"), + "vref": "native" if vref_flag else "geo", + "vcoord_type": vcoord_type, + "origin_z": zshift, + } + def override(message: bytes, **kwargs: typing.Any) -> dict[str, typing.Any]: """Override GRIB metadata contained in message. Note that no special consideration is made for maintaining consistency when overriding template definition keys such as productDefinitionTemplateNumber. + Note that the origin components in x and y will be unset. Parameters ---------- @@ -40,8 +66,7 @@ def override(message: bytes, **kwargs: typing.Any) -> dict[str, typing.Any]: return { "message": out.getvalue(), - "geography": md.as_namespace("geography"), - "parameter": md.as_namespace("parameter"), + **extract(md), } @@ -112,8 +137,8 @@ def compute_origin(ref_grid: Grid, field: xr.DataArray) -> dict[str, float]: y0_key = "latitudeOfFirstGridPointInDegrees" return { - "x": np.round((geo[x0_key] % 360 - x0) / dx, 1), - "y": np.round((geo[y0_key] - y0) / dy, 1), + "origin_x": np.round((geo[x0_key] % 360 - x0) / dx, 1), + "origin_y": np.round((geo[y0_key] - y0) / dy, 1), } @@ -138,4 +163,4 @@ def set_origin_xy(ds: dict[str, xr.DataArray], ref_param: str) -> None: ref_grid = load_grid_reference(ds[ref_param].message) for field in ds.values(): - field.attrs["origin"] |= compute_origin(ref_grid, field) + field.attrs |= compute_origin(ref_grid, field) diff --git a/src/idpi/operators/destagger.py b/src/idpi/operators/destagger.py index e95931b9..27746272 100644 --- a/src/idpi/operators/destagger.py +++ b/src/idpi/operators/destagger.py @@ -96,6 +96,15 @@ def _update_grid(field: xr.DataArray, dim: Literal["x", "y"]) -> dict[str, Any]: ) +def _update_vertical(field) -> dict[str, Any]: + if field.vcoord_type != "model_level": + raise ValueError("Field.vcoord_type must model_level") + return metadata.override( + field.message, + typeOfLevel="generalVerticalLayer", + ) + + def destagger( field: xr.DataArray, dim: Literal["x", "y", "z"], @@ -128,7 +137,7 @@ def destagger( """ dims = list(field.sizes.keys()) if dim == "x" or dim == "y": - if field.origin[dim] != 0.5: + if field.attrs[f"origin_{dim}"] != 0.5: raise ValueError return ( xr.apply_ufunc( @@ -140,10 +149,10 @@ def destagger( keep_attrs=True, ) .transpose(*dims) - .assign_attrs(origin=field.origin | {dim: 0.0}, **_update_grid(field, dim)) + .assign_attrs({f"origin_{dim}": 0.0}, **_update_grid(field, dim)) ) elif dim == "z": - if field.origin[dim] != -0.5: + if field.origin_z != -0.5: raise ValueError return ( xr.apply_ufunc( @@ -155,7 +164,7 @@ def destagger( keep_attrs=True, ) .transpose(*dims) - .assign_attrs(origin=field.origin | {dim: 0.0}) + .assign_attrs({f"origin_{dim}": 0.0}, **_update_vertical(field)) ) raise ValueError("Unknown dimension", dim) diff --git a/src/idpi/operators/gis.py b/src/idpi/operators/gis.py index e9f03a92..7de1c079 100644 --- a/src/idpi/operators/gis.py +++ b/src/idpi/operators/gis.py @@ -211,8 +211,7 @@ def vref_rot2geolatlon( x and y components of the vector field w.r.t. the geo lat lon coords. """ - valid_origin = {d: 0.0 for d in tuple("xyz")} - if u.origin != valid_origin or v.origin != valid_origin: + if u.origin_x != 0.0 or v.origin_y != 0.0: raise ValueError("The vector fields must be destaggered.") grid = get_grid(u.geography) diff --git a/src/idpi/operators/vertical_interpolation.py b/src/idpi/operators/vertical_interpolation.py index de7403bc..18709b4b 100644 --- a/src/idpi/operators/vertical_interpolation.py +++ b/src/idpi/operators/vertical_interpolation.py @@ -102,7 +102,7 @@ def interpolate_k2p( ) # Check that dimensions are the same for field and p_field if ( - field.origin != p_field.origin + field.origin_z != p_field.origin_z or field.dims != p_field.dims or field.size != p_field.size ): @@ -253,16 +253,16 @@ def interpolate_k2theta( ) # Check that typeOfLevel is supported and equal for field, th_field, and h_field - if field.vcoord_type != "model_level" or field.origin["z"] != 0.0: + if field.vcoord_type != "model_level" or field.origin_z != 0.0: raise RuntimeError( "interpolate_k2theta: field to interpolate must " "be defined on model_level layers" ) - if th_field.vcoord_type != "model_level" or th_field.origin["z"] != 0.0: + if th_field.vcoord_type != "model_level" or th_field.origin_z != 0.0: raise RuntimeError( "interpolate_k2theta: theta field must be defined on model_level layers" ) - if h_field.vcoord_type != "model_level" or h_field.origin["z"] != 0.0: + if h_field.vcoord_type != "model_level" or h_field.origin_z != 0.0: raise RuntimeError( "interpolate_k2theta: height field must be defined on model_level layers" ) @@ -355,7 +355,7 @@ def interpolate_k2any( raise ValueError(f"Unsupported mode: {mode}") for f in (field, tc_field, h_field): - if f.vcoord_type != "model_level" or f.origin["z"] != 0.0: + if f.vcoord_type != "model_level" or f.origin_z != 0.0: raise ValueError("Input fields must be defined on full model levels") # ... tc values outside range of meaningful values of height, diff --git a/src/idpi/operators/vertical_reduction.py b/src/idpi/operators/vertical_reduction.py index 842374e4..a57e0896 100644 --- a/src/idpi/operators/vertical_reduction.py +++ b/src/idpi/operators/vertical_reduction.py @@ -95,10 +95,7 @@ def minmax_k(field, operator, mode, height, h_bounds, hsurf=None): # levels included in the height interval, and at the interval boundaries # after linear interpolation wrt height; f and auxiliary height fields # must either both be defined on full levels or half levels - if ( - field.vcoord_type != height.vcoord_type - or field.origin["z"] != height.origin["z"] - ): + if field.vcoord_type != height.vcoord_type or field.origin_z != height.origin_z: raise RuntimeError( "minmax_k: height is not defined for the same level type as field." ) @@ -256,7 +253,7 @@ def integrate_k(field, operator, mode, height, h_bounds, hsurf=None): "integrate_k: field must be defined for level type " "generalVertical or generalVerticalLayer" ) - if field.origin["z"] != 0.0: + if field.origin_z != 0.0: field_on_fl = destagger(field, "z") else: field_on_fl = field diff --git a/src/idpi/operators/wind.py b/src/idpi/operators/wind.py index 2d71142b..b03ee747 100644 --- a/src/idpi/operators/wind.py +++ b/src/idpi/operators/wind.py @@ -35,9 +35,7 @@ def speed(u: xr.DataArray, v: xr.DataArray) -> xr.DataArray: the horizontal wind speed [m/s]. """ - centered = {dim: 0.0 for dim in tuple("xyz")} - - if u.origin != centered or v.origin != centered: + if u.origin_x != 0.0 or v.origin_y != 0.0: raise ValueError("The wind components should not be staggered.") name = {"U": "SP", "U_10M": "SP_10M"}[u.parameter["shortName"]] diff --git a/tests/test_idpi/test_destagger.py b/tests/test_idpi/test_destagger.py index 13131460..2ab91a02 100644 --- a/tests/test_idpi/test_destagger.py +++ b/tests/test_idpi/test_destagger.py @@ -27,6 +27,6 @@ def test_destagger(data_dir, fieldextra): assert_allclose(fs_ds["V"], v, rtol=1e-12, atol=1e-9) assert_allclose(fs_ds["HFL"], hfl, rtol=1e-12, atol=1e-9) - assert u.origin["x"] == 0.0 - assert v.origin["y"] == 0.0 - assert hfl.origin["z"] == 0.0 + assert u.origin_x == 0.0 + assert v.origin_y == 0.0 + assert hfl.origin_z == 0.0