Skip to content
This repository has been archived by the owner on May 2, 2024. It is now read-only.

Drive vcoord_type and origin_z from grib message #120

Merged
merged 5 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/idpi/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def handle_vector_fields(ds):
for u_name, v_name in pairs:
click.echo(f"Rotating vector field components {u_name}, {v_name} to geolatlon")
u, v = ds[u_name], ds[v_name]
if u.origin["x"] != 0.0:
if u.origin_x != 0.0:
u = destagger.destagger(u, "x")
if v.origin["y"] != 0.0:
if v.origin_y != 0.0:
v = destagger.destagger(v, "y")
if u.vref == "native" and v.vref == "native":
u_g, v_g = gis.vref_rot2geolatlon(u, v)
Expand Down
53 changes: 9 additions & 44 deletions src/idpi/grib_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import xarray as xr

# Local
from . import data_source, tasking
from . import data_source, metadata, tasking

logger = logging.getLogger(__name__)

Expand All @@ -26,13 +26,6 @@
"step": "time",
}
INV_DIM_MAP = {v: k for k, v in DIM_MAP.items()}
VCOORD_TYPE = {
"generalVertical": ("model_level", -0.5),
"generalVerticalLayer": ("model_level", 0.0),
"hybrid": ("hybrid", 0.0),
"isobaricInPa": ("pressure", 0.0),
"surface": ("surface", 0.0),
}

Request = str | tuple | dict

Expand Down Expand Up @@ -139,23 +132,6 @@ def _load_pv(self, pv_param: Request):
for field in fs:
return field.metadata("pv")

def _construct_metadata(self, field: typing.Any):
metadata: dict[str, typing.Any] = field.metadata(
namespace=["geography", "parameter"]
)
# https://codes.ecmwf.int/grib/format/grib2/ctables/3/3/
[vref_flag] = get_code_flag(field.metadata("resolutionAndComponentFlags"), [5])
level_type: str = field.metadata("typeOfLevel")
vcoord_type, zshift = VCOORD_TYPE.get(level_type, (level_type, 0.0))

metadata |= {
"vref": "native" if vref_flag else "geo",
"vcoord_type": vcoord_type,
"origin": {"z": zshift},
"message": field.message(),
}
return metadata

def _load_param(
self,
req: Request,
Expand All @@ -164,7 +140,7 @@ def _load_param(
fs = self.data_source.retrieve(req)

hcoords: dict[str, xr.DataArray] = {}
metadata: dict[str, typing.Any] = {}
metadata_values: dict[str, typing.Any] = {}
time_meta: dict[int, dict] = {}
dims: tuple[str, ...] | None = None
field_map: dict[tuple[int, ...], np.ndarray] = {}
Expand All @@ -186,8 +162,11 @@ def _load_param(
if not dims:
dims = tuple(DIM_MAP[d] for d in dim_keys) + ("y", "x")

if not metadata:
metadata = self._construct_metadata(field)
if not metadata_values:
metadata_values = {
"message": field.message(),
**metadata.extract(field.metadata()),
}

if not hcoords:
hcoords = {
Expand All @@ -205,7 +184,7 @@ def _load_param(
np.array([field_map.pop(key) for key in sorted(field_map)]).reshape(shape),
coords=coords | hcoords | tcoords,
dims=dims,
attrs=metadata,
attrs=metadata_values,
)

return (
Expand Down Expand Up @@ -262,19 +241,6 @@ def load_fieldnames(
return self.load(reqs, extract_pv)


def _get_type_of_level(field):
if field.vcoord_type == "model_level":
if field.origin["z"] == 0.0:
return "generalVerticalLayer"
elif field.origin["z"] == -0.5:
return "generalVertical"
else:
raise ValueError(f"Unsupported field origin in z: {field.origin['z']}")
else:
mapping = {vc: name for name, (vc, _) in VCOORD_TYPE.items()}
return mapping.get(field.vcoord_type, field.vcoord_type)


def save(field: xr.DataArray, file_handle: io.BufferedWriter, bits_per_value: int = 16):
"""Write field to file in GRIB format.

Expand Down Expand Up @@ -307,8 +273,7 @@ def save(field: xr.DataArray, file_handle: io.BufferedWriter, bits_per_value: in
}

def to_grib(loc: dict[str, xr.DataArray]):
result = {INV_DIM_MAP[key]: value.item() for key, value in loc.items()}
return result | {"typeOfLevel": _get_type_of_level(field)}
return {INV_DIM_MAP[key]: value.item() for key, value in loc.items()}

for idx_slice in product(*idx.values()):
loc = {dim: value for dim, value in zip(idx.keys(), idx_slice)}
Expand Down
35 changes: 30 additions & 5 deletions src/idpi/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,38 @@
import xarray as xr
from earthkit.data.writers import write # type: ignore

# Local
from . import grib_decoder

VCOORD_TYPE = {
"generalVertical": ("model_level", -0.5),
"generalVerticalLayer": ("model_level", 0.0),
"isobaricInPa": ("pressure", 0.0),
}
Comment on lines +17 to +21
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come surface and hybrid keys are removed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The typeOfLevel is the same as the vcoord_type for those so they're actually handled by the pass through

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok thanks



def extract(metadata):
[vref_flag] = grib_decoder.get_code_flag(
metadata.get("resolutionAndComponentFlags"), [5]
)
level_type = metadata.get("typeOfLevel")
vcoord_type, zshift = VCOORD_TYPE.get(level_type, (level_type, 0.0))

return {
"parameter": metadata.as_namespace("parameter"),
"geography": metadata.as_namespace("geography"),
"vref": "native" if vref_flag else "geo",
"vcoord_type": vcoord_type,
"origin_z": zshift,
}


def override(message: bytes, **kwargs: typing.Any) -> dict[str, typing.Any]:
"""Override GRIB metadata contained in message.

Note that no special consideration is made for maintaining consistency when
overriding template definition keys such as productDefinitionTemplateNumber.
Note that the origin components in x and y will be unset.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cant see where in override or extract origin_x or origin_y are unset. Does this happen in earthkit-data's GribMetadata override method?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this note is a bit outdated, in the current implementation, the origin_xy values just remain untouched. I'll update the note

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes thanks, since untouched and unset have different meanings it would be good to update


Parameters
----------
Expand All @@ -40,8 +66,7 @@ def override(message: bytes, **kwargs: typing.Any) -> dict[str, typing.Any]:

return {
"message": out.getvalue(),
"geography": md.as_namespace("geography"),
"parameter": md.as_namespace("parameter"),
**extract(md),
}


Expand Down Expand Up @@ -112,8 +137,8 @@ def compute_origin(ref_grid: Grid, field: xr.DataArray) -> dict[str, float]:
y0_key = "latitudeOfFirstGridPointInDegrees"

return {
"x": np.round((geo[x0_key] % 360 - x0) / dx, 1),
"y": np.round((geo[y0_key] - y0) / dy, 1),
"origin_x": np.round((geo[x0_key] % 360 - x0) / dx, 1),
"origin_y": np.round((geo[y0_key] - y0) / dy, 1),
}


Expand All @@ -138,4 +163,4 @@ def set_origin_xy(ds: dict[str, xr.DataArray], ref_param: str) -> None:

ref_grid = load_grid_reference(ds[ref_param].message)
for field in ds.values():
field.attrs["origin"] |= compute_origin(ref_grid, field)
field.attrs |= compute_origin(ref_grid, field)
17 changes: 13 additions & 4 deletions src/idpi/operators/destagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ def _update_grid(field: xr.DataArray, dim: Literal["x", "y"]) -> dict[str, Any]:
)


def _update_vertical(field) -> dict[str, Any]:
if field.vcoord_type != "model_level":
raise ValueError("Field.vcoord_type must model_level")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise ValueError("Field.vcoord_type must model_level")
raise ValueError("Vertical coordinate type must equal model_level")

return metadata.override(
field.message,
typeOfLevel="generalVerticalLayer",
)


def destagger(
field: xr.DataArray,
dim: Literal["x", "y", "z"],
Expand Down Expand Up @@ -128,7 +137,7 @@ def destagger(
"""
dims = list(field.sizes.keys())
if dim == "x" or dim == "y":
if field.origin[dim] != 0.5:
if field.attrs[f"origin_{dim}"] != 0.5:
raise ValueError
return (
xr.apply_ufunc(
Expand All @@ -140,10 +149,10 @@ def destagger(
keep_attrs=True,
)
.transpose(*dims)
.assign_attrs(origin=field.origin | {dim: 0.0}, **_update_grid(field, dim))
.assign_attrs({f"origin_{dim}": 0.0}, **_update_grid(field, dim))
)
elif dim == "z":
if field.origin[dim] != -0.5:
if field.origin_z != -0.5:
raise ValueError
return (
xr.apply_ufunc(
Expand All @@ -155,7 +164,7 @@ def destagger(
keep_attrs=True,
)
.transpose(*dims)
.assign_attrs(origin=field.origin | {dim: 0.0})
.assign_attrs({f"origin_{dim}": 0.0}, **_update_vertical(field))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does destaggering z enforce typeOfLevel="generalVerticalLayer" ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All staggered fields should be defined on the generalVertical levels as far as I understand

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But why then does it change from generalVertical to generalVerticalLayer during destaggering?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that's just the way that those level types are defined.

)

raise ValueError("Unknown dimension", dim)
3 changes: 1 addition & 2 deletions src/idpi/operators/gis.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ def vref_rot2geolatlon(
x and y components of the vector field w.r.t. the geo lat lon coords.

"""
valid_origin = {d: 0.0 for d in tuple("xyz")}
if u.origin != valid_origin or v.origin != valid_origin:
if u.origin_x != 0.0 or v.origin_y != 0.0:
raise ValueError("The vector fields must be destaggered.")

grid = get_grid(u.geography)
Expand Down
10 changes: 5 additions & 5 deletions src/idpi/operators/vertical_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def interpolate_k2p(
)
# Check that dimensions are the same for field and p_field
if (
field.origin != p_field.origin
field.origin_z != p_field.origin_z
or field.dims != p_field.dims
or field.size != p_field.size
):
Expand Down Expand Up @@ -253,16 +253,16 @@ def interpolate_k2theta(
)

# Check that typeOfLevel is supported and equal for field, th_field, and h_field
if field.vcoord_type != "model_level" or field.origin["z"] != 0.0:
if field.vcoord_type != "model_level" or field.origin_z != 0.0:
raise RuntimeError(
"interpolate_k2theta: field to interpolate must "
"be defined on model_level layers"
)
if th_field.vcoord_type != "model_level" or th_field.origin["z"] != 0.0:
if th_field.vcoord_type != "model_level" or th_field.origin_z != 0.0:
raise RuntimeError(
"interpolate_k2theta: theta field must be defined on model_level layers"
)
if h_field.vcoord_type != "model_level" or h_field.origin["z"] != 0.0:
if h_field.vcoord_type != "model_level" or h_field.origin_z != 0.0:
raise RuntimeError(
"interpolate_k2theta: height field must be defined on model_level layers"
)
Expand Down Expand Up @@ -355,7 +355,7 @@ def interpolate_k2any(
raise ValueError(f"Unsupported mode: {mode}")

for f in (field, tc_field, h_field):
if f.vcoord_type != "model_level" or f.origin["z"] != 0.0:
if f.vcoord_type != "model_level" or f.origin_z != 0.0:
raise ValueError("Input fields must be defined on full model levels")

# ... tc values outside range of meaningful values of height,
Expand Down
7 changes: 2 additions & 5 deletions src/idpi/operators/vertical_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,7 @@ def minmax_k(field, operator, mode, height, h_bounds, hsurf=None):
# levels included in the height interval, and at the interval boundaries
# after linear interpolation wrt height; f and auxiliary height fields
# must either both be defined on full levels or half levels
if (
field.vcoord_type != height.vcoord_type
or field.origin["z"] != height.origin["z"]
):
if field.vcoord_type != height.vcoord_type or field.origin_z != height.origin_z:
raise RuntimeError(
"minmax_k: height is not defined for the same level type as field."
)
Expand Down Expand Up @@ -256,7 +253,7 @@ def integrate_k(field, operator, mode, height, h_bounds, hsurf=None):
"integrate_k: field must be defined for level type "
"generalVertical or generalVerticalLayer"
)
if field.origin["z"] != 0.0:
if field.origin_z != 0.0:
field_on_fl = destagger(field, "z")
else:
field_on_fl = field
Expand Down
4 changes: 1 addition & 3 deletions src/idpi/operators/wind.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def speed(u: xr.DataArray, v: xr.DataArray) -> xr.DataArray:
the horizontal wind speed [m/s].

"""
centered = {dim: 0.0 for dim in tuple("xyz")}

if u.origin != centered or v.origin != centered:
if u.origin_x != 0.0 or v.origin_y != 0.0:
raise ValueError("The wind components should not be staggered.")

name = {"U": "SP", "U_10M": "SP_10M"}[u.parameter["shortName"]]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_idpi/test_destagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ def test_destagger(data_dir, fieldextra):
assert_allclose(fs_ds["V"], v, rtol=1e-12, atol=1e-9)
assert_allclose(fs_ds["HFL"], hfl, rtol=1e-12, atol=1e-9)

assert u.origin["x"] == 0.0
assert v.origin["y"] == 0.0
assert hfl.origin["z"] == 0.0
assert u.origin_x == 0.0
assert v.origin_y == 0.0
assert hfl.origin_z == 0.0
Loading