Merge pull request #399 from emfdavid/grib_tree
Add grib_tree method
martindurant authored Dec 4, 2023
2 parents 7047d14 + 229c3da commit 37d7526
Showing 9 changed files with 1,845 additions and 6 deletions.
1 change: 1 addition & 0 deletions ci/environment-py310.yml
@@ -7,6 +7,7 @@ dependencies:
   - dask
   - zarr
   - xarray
+  - xarray-datatree
   - h5netcdf
   - h5py<3.9
   - pandas
1 change: 1 addition & 0 deletions ci/environment-py38.yml
@@ -7,6 +7,7 @@ dependencies:
   - dask
   - zarr
   - xarray
+  - xarray-datatree
   - h5netcdf
   - h5py<3.9
   - pandas
1 change: 1 addition & 0 deletions ci/environment-py39.yml
@@ -7,6 +7,7 @@ dependencies:
   - dask
   - zarr
   - xarray
+  - xarray-datatree
   - h5netcdf
   - h5py<3.9
   - pandas
9 changes: 7 additions & 2 deletions kerchunk/combine.py
@@ -311,8 +311,13 @@ def store_coords(self):
                     for _ in v
                 ]
             ).ravel()
-            if "fill_value" not in kw and data.dtype.kind == "i":
-                kw["fill_value"] = None
+            if "fill_value" not in kw:
+                if data.dtype.kind == "i":
+                    kw["fill_value"] = None
+                elif k in z:
+                    # Fall back to existing fill value
+                    kw["fill_value"] = z[k].fill_value
+
             arr = group.create_dataset(
                 name=k,
                 data=data,
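The change above makes store_coords reuse a fill value that is already recorded in the output store when re-writing a coordinate, instead of leaving it unset (integer coordinates still get none). A minimal sketch of the fallback, using an in-memory zarr group; the coordinate name and values here are illustrative, not taken from the diff:

import zarr

z = zarr.open_group()  # stands in for the output store being written to
z.create_dataset("level", data=[1000.0], fill_value=-9999.0)

kw = {}  # keyword arguments being assembled for create_dataset
k = "level"
dtype_kind = "f"  # a float coordinate, so the integer branch does not apply

if "fill_value" not in kw:
    if dtype_kind == "i":
        kw["fill_value"] = None  # integer coordinates get no fill value
    elif k in z:
        # fall back to the fill value already stored for this coordinate
        kw["fill_value"] = z[k].fill_value

print(kw)  # {'fill_value': -9999.0}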
230 changes: 228 additions & 2 deletions kerchunk/grib2.py
@@ -1,5 +1,10 @@
 import base64
+import copy
 import logging
+from collections import defaultdict
+from typing import Iterable, List, Dict, Set
+
 import ujson
+
 try:
     import cfgrib
@@ -13,10 +18,12 @@

 import fsspec
 import zarr
+import xarray
 import numpy as np

 from kerchunk.utils import class_factory, _encode_for_JSON
 from kerchunk.codecs import GRIBCodec
+from kerchunk.combine import MultiZarrToZarr, drop


 # cfgrib copies over certain GRIB attributes
@@ -328,8 +335,6 @@ def example_combine(
     ... "consolidated": False,
     ... "storage_options": {"fo": tot, "remote_options": {"anon": True}}})
     """
-    from kerchunk.combine import MultiZarrToZarr, drop
-
     files = [
         "s3://noaa-hrrr-bdp-pds/hrrr.20190101/conus/hrrr.t22z.wrfsfcf01.grib2",
         "s3://noaa-hrrr-bdp-pds/hrrr.20190101/conus/hrrr.t23z.wrfsfcf01.grib2",
@@ -354,3 +359,224 @@ def example_combine(
         identical_dims=["heightAboveGround", "latitude", "longitude"],
     )
     return mzz.translate()
+
+
+def grib_tree(
+    message_groups: Iterable[Dict],
+    remote_options=None,
+) -> Dict:
+    """
+    Build a hierarchical data model from a set of scanned grib messages.
+
+    The iterable input groups should be a collection of results from scan_grib. Multiple grib
+    files can be processed together to produce an FMRC-like collection.
+    The time (reference_time) and step coordinates will be used as concat_dims in the
+    MultiZarrToZarr aggregation. Each variable name becomes a group with nested subgroups
+    representing the grib step type and grib level. The resulting hierarchy can be opened as a
+    zarr group or an xarray datatree.
+    Grib message variable names that decode as "unknown" are dropped.
+    Grib typeOfLevel attributes that decode as "unknown" are treated as a single group.
+    Grib steps that are missing due to WrongStepUnitError are patched with NaT.
+    The input message_groups are not modified by this method.
+
+    Parameters
+    ----------
+    message_groups: iterable[dict]
+        a collection of zarr-store-like dictionaries as produced by scan_grib
+    remote_options: dict
+        remote options to pass to MultiZarrToZarr
+
+    Returns
+    -------
+    dict: a new zarr-store-like dictionary for use as a reference filesystem mapper with zarr
+    or xarray datatree
+    """
+    # Hard code the filters in the correct order for the group hierarchy
+    filters = ["stepType", "typeOfLevel"]
+
+    # TODO allow passing a LazyReferenceMapper as output?
+    zarr_store = {}
+    zroot = zarr.open_group(store=zarr_store)
+
+    aggregations: Dict[str, List] = defaultdict(list)
+    aggregation_dims: Dict[str, Set] = defaultdict(set)
+
+    unknown_counter = 0
+    for msg_ind, group in enumerate(message_groups):
+        assert group["version"] == 1
+
+        gattrs = ujson.loads(group["refs"][".zattrs"])
+        coordinates = gattrs["coordinates"].split(" ")
+
+        # Find the data variable
+        vname = None
+        for key, entry in group["refs"].items():
+            name = key.split("/")[0]
+            if name not in [".zattrs", ".zgroup"] and name not in coordinates:
+                vname = name
+                break
+
+        if vname is None:
+            raise RuntimeError(
+                f"Cannot find a data var for msg# {msg_ind} in {group['refs'].keys()}"
+            )
+
+        if vname == "unknown":
+            # To resolve unknown variables add custom grib tables:
+            # https://confluence.ecmwf.int/display/UDOC/Creating+your+own+local+definitions+-+ecCodes+GRIB+FAQ
+            # If you process the groups from a single file in order, you can use the msg# to
+            # compare with the idx file. The idx file's message index is 1-based, while the
+            # grib_tree message count is zero-based.
+            logger.warning(
+                "Dropping unknown variable in msg# %d. Compare with the grib idx file to help identify it"
+                " and build an ecCodes local grib definitions file to fix it.",
+                msg_ind,
+            )
+            unknown_counter += 1
+            continue
+
+        logger.debug("Processing vname: %s", vname)
+        dattrs = ujson.loads(group["refs"][f"{vname}/.zattrs"])
+        # filter order matters - it determines the hierarchy
+        gfilters = {}
+        for key in filters:
+            attr_val = dattrs.get(f"GRIB_{key}")
+            if attr_val is None:
+                continue
+            if attr_val == "unknown":
+                logger.warning(
+                    "Found 'unknown' attribute value for key %s in var %s of msg# %s",
+                    key,
+                    vname,
+                    msg_ind,
+                )
+                # Use unknown as a group or drop it?
+
+            gfilters[key] = attr_val
+
+        zgroup = zroot.require_group(vname)
+        if "name" not in zgroup.attrs:
+            zgroup.attrs["name"] = dattrs.get("GRIB_name")
+
+        for key, value in gfilters.items():
+            if value:  # Ignore empty string and None
+                # name the group after the attribute values: surface, instant, etc
+                zgroup = zgroup.require_group(value)
+                # Add an attribute to give context
+                zgroup.attrs[key] = value
+
+        # Set the coordinates attribute for the group
+        zgroup.attrs["coordinates"] = " ".join(coordinates)
+        # add to the list of groups to multi-zarr
+        aggregations[zgroup.path].append(group)
+
+        # keep track of the level coordinate variables and their values
+        for key, entry in group["refs"].items():
+            name = key.split("/")[0]
+            if name == gfilters.get("typeOfLevel") and key.endswith("0"):
+                if isinstance(entry, list):
+                    entry = tuple(entry)
+                aggregation_dims[zgroup.path].add(entry)
+
+    concat_dims = ["time", "step"]
+    identical_dims = ["longitude", "latitude"]
+    for path in aggregations.keys():
+        # TODO: parallelize this step!
+        catdims = concat_dims.copy()
+        idims = identical_dims.copy()
+
+        level_dimension_value_count = len(aggregation_dims.get(path, ()))
+        level_group_name = path.split("/")[-1]
+        if level_dimension_value_count == 0:
+            logger.debug(
+                "Path %s has no coordinate value associated with the level name %s",
+                path,
+                level_group_name,
+            )
+        elif level_dimension_value_count == 1:
+            idims.append(level_group_name)
+        elif level_dimension_value_count > 1:
+            # The level name should be the last element in the path
+            catdims.insert(3, level_group_name)
+
+        logger.info(
+            "%s calling MultiZarrToZarr with idims %s and catdims %s",
+            path,
+            idims,
+            catdims,
+        )
+
+        mzz = MultiZarrToZarr(
+            aggregations[path],
+            remote_options=remote_options,
+            concat_dims=catdims,
+            identical_dims=idims,
+        )
+        group = mzz.translate()
+
+        for key, value in group["refs"].items():
+            if key not in [".zattrs", ".zgroup"]:
+                zarr_store[f"{path}/{key}"] = value
+
+    # Force all stored values to decode as string, not bytes. String should be correct;
+    # ujson will reject bytes values by default. Writing with 'reject_bytes=False' would
+    # succeed, but the value would then fail an equality check on read.
+    zarr_store = {
+        key: (val.decode() if isinstance(val, bytes) else val)
+        for key, val in zarr_store.items()
+    }
+    # TODO handle other kerchunk reference spec versions?
+    result = dict(refs=zarr_store, version=1)
+
+    return result
+
+
+def correct_hrrr_subhf_step(group: Dict) -> Dict:
+    """
+    Overrides the definition of the "step" variable.
+
+    Sets the value equal to `valid_time - time`, expressed in hours as a floating point
+    value. This fixes issues with the HRRR SubHF grib2 step as read by cfgrib via scan_grib.
+    The result is a deep copy; the original data is unmodified.
+
+    Parameters
+    ----------
+    group: dict
+        the zarr group store for a single grib message
+
+    Returns
+    -------
+    dict: a new zarr-store-like dictionary for use as a reference filesystem mapper with zarr
+    or xarray datatree
+    """
+    group = copy.deepcopy(group)
+    group["refs"]["step/.zarray"] = (
+        '{"chunks":[],"compressor":null,"dtype":"<f8","fill_value":"NaN","filters":null,"order":"C",'
+        '"shape":[],"zarr_format":2}'
+    )
+    group["refs"]["step/.zattrs"] = (
+        '{"_ARRAY_DIMENSIONS":[],"long_name":"time since forecast_reference_time",'
+        '"standard_name":"forecast_period","units":"hours"}'
+    )
+
+    # add step to coords
+    attrs = ujson.loads(group["refs"][".zattrs"])
+    if "step" not in attrs["coordinates"]:
+        attrs["coordinates"] += " step"
+    group["refs"][".zattrs"] = ujson.dumps(attrs)
+
+    fo = fsspec.filesystem("reference", fo=group, mode="r")
+    xd = xarray.open_dataset(fo.get_mapper(), engine="zarr", consolidated=False)
+
+    correct_step = xd.valid_time.values - xd.time.values
+
+    assert correct_step.shape == ()
+    step_float = correct_step.astype("timedelta64[s]").astype("float") / 3600.0
+    step_bytes = step_float.tobytes()
+    try:
+        encoded_val = step_bytes.decode("ascii")
+    except UnicodeDecodeError:
+        encoded_val = (b"base64:" + base64.b64encode(step_bytes)).decode("ascii")
+
+    group["refs"]["step/0"] = encoded_val
+
+    return group
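A minimal usage sketch for the new grib_tree method, reusing the HRRR sample files from example_combine; opening the result as a datatree assumes the xarray-datatree package added to the CI environments above:

import fsspec
import datatree
from kerchunk.grib2 import scan_grib, grib_tree

files = [
    "s3://noaa-hrrr-bdp-pds/hrrr.20190101/conus/hrrr.t22z.wrfsfcf01.grib2",
    "s3://noaa-hrrr-bdp-pds/hrrr.20190101/conus/hrrr.t23z.wrfsfcf01.grib2",
]
groups = []
for f in files:
    # scan_grib emits one reference set per grib message in the file
    groups.extend(scan_grib(f, storage_options={"anon": True}))

tree = grib_tree(groups, remote_options={"anon": True})

# Open the hierarchical result as an xarray datatree via a reference filesystem
fs = fsspec.filesystem("reference", fo=tree, remote_options={"anon": True})
dt = datatree.open_datatree(fs.get_mapper(), engine="zarr", consolidated=False)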
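Similarly, a hedged sketch of where correct_hrrr_subhf_step fits: it is applied to each scanned message group before aggregation. The subhourly product path below is hypothetical:

from kerchunk.grib2 import scan_grib, grib_tree, correct_hrrr_subhf_step

# hypothetical subhourly HRRR file; the fix targets SubHF step values
groups = scan_grib("hrrr.t00z.wrfsubhf01.grib2")
fixed = [correct_hrrr_subhf_step(g) for g in groups]
tree = grib_tree(fixed)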