Skip to content
This repository has been archived by the owner on May 2, 2024. It is now read-only.

Commit

Permalink
Drive vcoord_type and origin_z from grib message (#120)
Browse files Browse the repository at this point in the history
## Purpose

Use the `typeOfLevel` key as the source of truth for the `vcoord_type` attribute. An `extract` function is added to the metadata module to enable this workflow. The `override` function now updates all attributes previously extracted by the `grib_decoder` module. This does not include the `origin_x` and `origin_y` attributes as they need to be derived from a reference field.

## Code changes

- Added `metadata.extract`
- `metadata.override` will call the `extract` function to update attributes in the return value
  • Loading branch information
cfkanesan authored Mar 21, 2024
1 parent 20ba8c3 commit 2328d98
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 73 deletions.
4 changes: 2 additions & 2 deletions src/idpi/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def handle_vector_fields(ds):
for u_name, v_name in pairs:
click.echo(f"Rotating vector field components {u_name}, {v_name} to geolatlon")
u, v = ds[u_name], ds[v_name]
if u.origin["x"] != 0.0:
if u.origin_x != 0.0:
u = destagger.destagger(u, "x")
if v.origin["y"] != 0.0:
if v.origin_y != 0.0:
v = destagger.destagger(v, "y")
if u.vref == "native" and v.vref == "native":
u_g, v_g = gis.vref_rot2geolatlon(u, v)
Expand Down
53 changes: 9 additions & 44 deletions src/idpi/grib_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import xarray as xr

# Local
from . import data_source, tasking
from . import data_source, metadata, tasking

logger = logging.getLogger(__name__)

Expand All @@ -26,13 +26,6 @@
"step": "time",
}
INV_DIM_MAP = {v: k for k, v in DIM_MAP.items()}
VCOORD_TYPE = {
"generalVertical": ("model_level", -0.5),
"generalVerticalLayer": ("model_level", 0.0),
"hybrid": ("hybrid", 0.0),
"isobaricInPa": ("pressure", 0.0),
"surface": ("surface", 0.0),
}

Request = str | tuple | dict

Expand Down Expand Up @@ -139,23 +132,6 @@ def _load_pv(self, pv_param: Request):
for field in fs:
return field.metadata("pv")

def _construct_metadata(self, field: typing.Any):
metadata: dict[str, typing.Any] = field.metadata(
namespace=["geography", "parameter"]
)
# https://codes.ecmwf.int/grib/format/grib2/ctables/3/3/
[vref_flag] = get_code_flag(field.metadata("resolutionAndComponentFlags"), [5])
level_type: str = field.metadata("typeOfLevel")
vcoord_type, zshift = VCOORD_TYPE.get(level_type, (level_type, 0.0))

metadata |= {
"vref": "native" if vref_flag else "geo",
"vcoord_type": vcoord_type,
"origin": {"z": zshift},
"message": field.message(),
}
return metadata

def _load_param(
self,
req: Request,
Expand All @@ -164,7 +140,7 @@ def _load_param(
fs = self.data_source.retrieve(req)

hcoords: dict[str, xr.DataArray] = {}
metadata: dict[str, typing.Any] = {}
metadata_values: dict[str, typing.Any] = {}
time_meta: dict[int, dict] = {}
dims: tuple[str, ...] | None = None
field_map: dict[tuple[int, ...], np.ndarray] = {}
Expand All @@ -186,8 +162,11 @@ def _load_param(
if not dims:
dims = tuple(DIM_MAP[d] for d in dim_keys) + ("y", "x")

if not metadata:
metadata = self._construct_metadata(field)
if not metadata_values:
metadata_values = {
"message": field.message(),
**metadata.extract(field.metadata()),
}

if not hcoords:
hcoords = {
Expand All @@ -205,7 +184,7 @@ def _load_param(
np.array([field_map.pop(key) for key in sorted(field_map)]).reshape(shape),
coords=coords | hcoords | tcoords,
dims=dims,
attrs=metadata,
attrs=metadata_values,
)

return (
Expand Down Expand Up @@ -262,19 +241,6 @@ def load_fieldnames(
return self.load(reqs, extract_pv)


def _get_type_of_level(field):
if field.vcoord_type == "model_level":
if field.origin["z"] == 0.0:
return "generalVerticalLayer"
elif field.origin["z"] == -0.5:
return "generalVertical"
else:
raise ValueError(f"Unsupported field origin in z: {field.origin['z']}")
else:
mapping = {vc: name for name, (vc, _) in VCOORD_TYPE.items()}
return mapping.get(field.vcoord_type, field.vcoord_type)


def save(field: xr.DataArray, file_handle: io.BufferedWriter, bits_per_value: int = 16):
"""Write field to file in GRIB format.
Expand Down Expand Up @@ -307,8 +273,7 @@ def save(field: xr.DataArray, file_handle: io.BufferedWriter, bits_per_value: in
}

def to_grib(loc: dict[str, xr.DataArray]):
result = {INV_DIM_MAP[key]: value.item() for key, value in loc.items()}
return result | {"typeOfLevel": _get_type_of_level(field)}
return {INV_DIM_MAP[key]: value.item() for key, value in loc.items()}

for idx_slice in product(*idx.values()):
loc = {dim: value for dim, value in zip(idx.keys(), idx_slice)}
Expand Down
35 changes: 30 additions & 5 deletions src/idpi/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,38 @@
import xarray as xr
from earthkit.data.writers import write # type: ignore

# Local
from . import grib_decoder

VCOORD_TYPE = {
"generalVertical": ("model_level", -0.5),
"generalVerticalLayer": ("model_level", 0.0),
"isobaricInPa": ("pressure", 0.0),
}


def extract(metadata):
[vref_flag] = grib_decoder.get_code_flag(
metadata.get("resolutionAndComponentFlags"), [5]
)
level_type = metadata.get("typeOfLevel")
vcoord_type, zshift = VCOORD_TYPE.get(level_type, (level_type, 0.0))

return {
"parameter": metadata.as_namespace("parameter"),
"geography": metadata.as_namespace("geography"),
"vref": "native" if vref_flag else "geo",
"vcoord_type": vcoord_type,
"origin_z": zshift,
}


def override(message: bytes, **kwargs: typing.Any) -> dict[str, typing.Any]:
"""Override GRIB metadata contained in message.
Note that no special consideration is made for maintaining consistency when
overriding template definition keys such as productDefinitionTemplateNumber.
Note that the origin components in x and y will be unset.
Parameters
----------
Expand All @@ -40,8 +66,7 @@ def override(message: bytes, **kwargs: typing.Any) -> dict[str, typing.Any]:

return {
"message": out.getvalue(),
"geography": md.as_namespace("geography"),
"parameter": md.as_namespace("parameter"),
**extract(md),
}


Expand Down Expand Up @@ -112,8 +137,8 @@ def compute_origin(ref_grid: Grid, field: xr.DataArray) -> dict[str, float]:
y0_key = "latitudeOfFirstGridPointInDegrees"

return {
"x": np.round((geo[x0_key] % 360 - x0) / dx, 1),
"y": np.round((geo[y0_key] - y0) / dy, 1),
"origin_x": np.round((geo[x0_key] % 360 - x0) / dx, 1),
"origin_y": np.round((geo[y0_key] - y0) / dy, 1),
}


Expand All @@ -138,4 +163,4 @@ def set_origin_xy(ds: dict[str, xr.DataArray], ref_param: str) -> None:

ref_grid = load_grid_reference(ds[ref_param].message)
for field in ds.values():
field.attrs["origin"] |= compute_origin(ref_grid, field)
field.attrs |= compute_origin(ref_grid, field)
17 changes: 13 additions & 4 deletions src/idpi/operators/destagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ def _update_grid(field: xr.DataArray, dim: Literal["x", "y"]) -> dict[str, Any]:
)


def _update_vertical(field) -> dict[str, Any]:
if field.vcoord_type != "model_level":
raise ValueError("Field.vcoord_type must model_level")
return metadata.override(
field.message,
typeOfLevel="generalVerticalLayer",
)


def destagger(
field: xr.DataArray,
dim: Literal["x", "y", "z"],
Expand Down Expand Up @@ -128,7 +137,7 @@ def destagger(
"""
dims = list(field.sizes.keys())
if dim == "x" or dim == "y":
if field.origin[dim] != 0.5:
if field.attrs[f"origin_{dim}"] != 0.5:
raise ValueError
return (
xr.apply_ufunc(
Expand All @@ -140,10 +149,10 @@ def destagger(
keep_attrs=True,
)
.transpose(*dims)
.assign_attrs(origin=field.origin | {dim: 0.0}, **_update_grid(field, dim))
.assign_attrs({f"origin_{dim}": 0.0}, **_update_grid(field, dim))
)
elif dim == "z":
if field.origin[dim] != -0.5:
if field.origin_z != -0.5:
raise ValueError
return (
xr.apply_ufunc(
Expand All @@ -155,7 +164,7 @@ def destagger(
keep_attrs=True,
)
.transpose(*dims)
.assign_attrs(origin=field.origin | {dim: 0.0})
.assign_attrs({f"origin_{dim}": 0.0}, **_update_vertical(field))
)

raise ValueError("Unknown dimension", dim)
3 changes: 1 addition & 2 deletions src/idpi/operators/gis.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ def vref_rot2geolatlon(
x and y components of the vector field w.r.t. the geo lat lon coords.
"""
valid_origin = {d: 0.0 for d in tuple("xyz")}
if u.origin != valid_origin or v.origin != valid_origin:
if u.origin_x != 0.0 or v.origin_y != 0.0:
raise ValueError("The vector fields must be destaggered.")

grid = get_grid(u.geography)
Expand Down
10 changes: 5 additions & 5 deletions src/idpi/operators/vertical_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def interpolate_k2p(
)
# Check that dimensions are the same for field and p_field
if (
field.origin != p_field.origin
field.origin_z != p_field.origin_z
or field.dims != p_field.dims
or field.size != p_field.size
):
Expand Down Expand Up @@ -253,16 +253,16 @@ def interpolate_k2theta(
)

# Check that typeOfLevel is supported and equal for field, th_field, and h_field
if field.vcoord_type != "model_level" or field.origin["z"] != 0.0:
if field.vcoord_type != "model_level" or field.origin_z != 0.0:
raise RuntimeError(
"interpolate_k2theta: field to interpolate must "
"be defined on model_level layers"
)
if th_field.vcoord_type != "model_level" or th_field.origin["z"] != 0.0:
if th_field.vcoord_type != "model_level" or th_field.origin_z != 0.0:
raise RuntimeError(
"interpolate_k2theta: theta field must be defined on model_level layers"
)
if h_field.vcoord_type != "model_level" or h_field.origin["z"] != 0.0:
if h_field.vcoord_type != "model_level" or h_field.origin_z != 0.0:
raise RuntimeError(
"interpolate_k2theta: height field must be defined on model_level layers"
)
Expand Down Expand Up @@ -355,7 +355,7 @@ def interpolate_k2any(
raise ValueError(f"Unsupported mode: {mode}")

for f in (field, tc_field, h_field):
if f.vcoord_type != "model_level" or f.origin["z"] != 0.0:
if f.vcoord_type != "model_level" or f.origin_z != 0.0:
raise ValueError("Input fields must be defined on full model levels")

# ... tc values outside range of meaningful values of height,
Expand Down
7 changes: 2 additions & 5 deletions src/idpi/operators/vertical_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,7 @@ def minmax_k(field, operator, mode, height, h_bounds, hsurf=None):
# levels included in the height interval, and at the interval boundaries
# after linear interpolation wrt height; f and auxiliary height fields
# must either both be defined on full levels or half levels
if (
field.vcoord_type != height.vcoord_type
or field.origin["z"] != height.origin["z"]
):
if field.vcoord_type != height.vcoord_type or field.origin_z != height.origin_z:
raise RuntimeError(
"minmax_k: height is not defined for the same level type as field."
)
Expand Down Expand Up @@ -256,7 +253,7 @@ def integrate_k(field, operator, mode, height, h_bounds, hsurf=None):
"integrate_k: field must be defined for level type "
"generalVertical or generalVerticalLayer"
)
if field.origin["z"] != 0.0:
if field.origin_z != 0.0:
field_on_fl = destagger(field, "z")
else:
field_on_fl = field
Expand Down
4 changes: 1 addition & 3 deletions src/idpi/operators/wind.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def speed(u: xr.DataArray, v: xr.DataArray) -> xr.DataArray:
the horizontal wind speed [m/s].
"""
centered = {dim: 0.0 for dim in tuple("xyz")}

if u.origin != centered or v.origin != centered:
if u.origin_x != 0.0 or v.origin_y != 0.0:
raise ValueError("The wind components should not be staggered.")

name = {"U": "SP", "U_10M": "SP_10M"}[u.parameter["shortName"]]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_idpi/test_destagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ def test_destagger(data_dir, fieldextra):
assert_allclose(fs_ds["V"], v, rtol=1e-12, atol=1e-9)
assert_allclose(fs_ds["HFL"], hfl, rtol=1e-12, atol=1e-9)

assert u.origin["x"] == 0.0
assert v.origin["y"] == 0.0
assert hfl.origin["z"] == 0.0
assert u.origin_x == 0.0
assert v.origin_y == 0.0
assert hfl.origin_z == 0.0

0 comments on commit 2328d98

Please sign in to comment.