Feature/to xarray from stream (#497)
* Enable converting GRIB streams to xarray
sandorkertesz authored Oct 25, 2024
1 parent 3781045 commit 5843ad1
Showing 20 changed files with 593 additions and 300 deletions.
2 changes: 1 addition & 1 deletion docs/examples/xarray_engine_field_dims.ipynb
@@ -35,7 +35,7 @@
  "id": "2f23c9c6-09b3-477f-8d2f-534312bc835f",
  "metadata": {
    "editable": true,
-   "raw_mimetype": "",
+   "raw_mimetype": "text/restructuredtext",
    "slideshow": {
      "slide_type": ""
    },
7 changes: 3 additions & 4 deletions src/earthkit/data/core/fieldlist.py
@@ -829,8 +829,8 @@ def from_fields(fields):
Parameters
----------
- fields: list
-     List of :obj:`Field` objects.
+ fields: iterable
+     Iterable of :obj:`Field` objects.
Returns
-------
@@ -839,7 +839,7 @@ def from_fields(fields):
"""
from earthkit.data.indexing.fieldlist import SimpleFieldList

-         return SimpleFieldList(fields)
+         return SimpleFieldList([f for f in fields])

@staticmethod
def from_numpy(array, metadata):
@@ -966,7 +966,6 @@ def _vals(f):
r[0] = vals
for i, f in enumerate(it, start=1):
r[i] = _vals(f)

return r

def to_numpy(self, **kwargs):
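With this change `from_fields` materialises any iterable, so a generator expression works as well as a list. A minimal sketch (the `shortName` filter is illustrative only; "test6.grib" is the sample file used in the `to_xarray` docstring below):

```python
import earthkit.data
from earthkit.data.core.fieldlist import FieldList

ds = earthkit.data.from_source("file", "test6.grib")

# The generator is consumed into a SimpleFieldList.
fl = FieldList.from_fields(f for f in ds if f.metadata("shortName") == "t")
print(len(fl))
```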
11 changes: 11 additions & 0 deletions src/earthkit/data/indexing/tensor.py
@@ -367,6 +367,17 @@ def from_fieldlist(

return cls(source, user_coords, field_coords, field_dims, flatten_values)

+     def clear(self):
+         self.source = None
+         self._user_coords = None
+         self._user_shape = None
+         self._user_dims = None
+         self._field_coords = None
+         self._field_shape = None
+         self._field_dims = None
+         self._full_shape = None
+         self.flatten_values = None

@flatten_arg
def to_numpy(self, index=None, **kwargs):
if index is not None:
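The new `clear` method drops every reference the tensor holds so the backing fields can be garbage collected once their values have been copied out. A toy, self-contained illustration of the pattern (not earthkit code):

```python
import numpy as np

class Holder:
    """Mimic the tensor's release pattern: copy values out, then drop references."""

    def __init__(self, source):
        self.source = source

    def to_numpy(self):
        return np.asarray(self.source)

    def clear(self):
        self.source = None  # stop pinning the (potentially large) source

h = Holder([1.0, 2.0, 3.0])
vals = h.to_numpy()  # data copied out
h.clear()            # the source is now collectable
print(vals)
```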
4 changes: 2 additions & 2 deletions src/earthkit/data/readers/grib/codes.py
@@ -302,7 +302,7 @@ def __repr__(self):
)

  def write(self, f, **kwargs):
-     r"""Writes the message to a file object.
+     r"""Write the message to a file object.
Parameters
----------
@@ -317,7 +317,7 @@ def write(self, f, **kwargs):
write(f, self, **kwargs)

  def message(self):
-     r"""Returns a buffer containing the encoded message.
+     r"""Return a buffer containing the encoded message.
Returns
-------
3 changes: 3 additions & 0 deletions src/earthkit/data/readers/grib/memory.py
@@ -155,6 +155,9 @@ def from_buffer(buf):
GribCodesHandle(handle, None, None), use_metadata_cache=get_use_grib_metadata_cache()
)

+     def _release(self):
+         self._handle = None

def copy(self, **kwargs):
return NewGribFieldInMemory(self, **kwargs)

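`_release` lets a GRIB field held entirely in memory drop its ecCodes handle once its values have been copied out; the field is unusable afterwards. A hypothetical sketch of the per-field loop that the new `release_source` option (documented in the next file) performs; `drain` is not an earthkit function:

```python
import numpy as np

def drain(fields):
    """Copy values out field by field, releasing each field immediately so
    the peak never holds two full copies of the data."""
    out = []
    for f in fields:
        out.append(f.to_numpy())    # copy the values out of the message
        if hasattr(f, "_release"):  # only in-memory GRIB fields support it
            f._release()            # drop the handle; the field becomes unusable
    return np.stack(out)
```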
24 changes: 21 additions & 3 deletions src/earthkit/data/readers/grib/xarray.py
@@ -192,7 +192,7 @@ def to_xarray(self, engine="earthkit", xarray_open_dataset_kwargs=None, **kwargs
"level" and "level_type" roles.
- "level_and_type": Use a single dimension for combined level and type of level.
* squeeze: bool, None
-         Remove dimensions which has only one valid values. Not applies to dimension in
+         Remove dimensions which have only one valid value. Does not apply to dimensions in
``ensure_dims``. Its default value (None) expands
to True unless the ``profile`` overwrites it.
* add_valid_time_coord: bool, None
@@ -245,6 +245,22 @@
A dictionary of attributes to rename. Default is None.
* remapping: dict, None
Define new metadata keys for indexing. Default is None.
+     * lazy_load: bool, None
+         If True, the resulting Dataset will load data lazily from the
+         underlying data source. If False, a Dataset holding all the data in memory
+         and decoupled from the backend source will be created.
+         Using ``lazy_load=False`` with ``release_source=True`` can provide optimised
+         memory usage in certain cases. The default value of ``lazy_load`` (None)
+         expands to True unless the ``profile`` overwrites it.
+     * release_source: bool, None
+         Only used when ``lazy_load=False``. If True, memory held in the input fields is
+         released as soon as their values are copied into the resulting Dataset. This is
+         done per field to avoid memory spikes. The release operation is currently
+         only supported for GRIB fields stored entirely in memory, e.g. when read from a
+         :ref:`stream <streams>`. When a field does not support the release operation, this
+         option is ignored. Once :obj:`to_xarray` has run, the input data becomes unusable,
+         so use this option carefully (see the usage sketch after this file's diff). The
+         default value of ``release_source`` (None) expands to False unless the
+         ``profile`` overwrites it.
* strict: bool, None
If True, perform stricter checks on hypercube consistency. Its default value (None) expands
to False unless the ``profile`` overwrites it.
@@ -287,9 +303,11 @@ def to_xarray(self, engine="earthkit", xarray_open_dataset_kwargs=None, **kwargs
--------
>>> import earthkit.data
>>> fs = earthkit.data.from_source("file", "test6.grib")
+         >>> ds = fs.to_xarray(time_dim_mode="forecast")
+         >>> # also possible to use the xarray_open_dataset_kwargs
>>> ds = fs.to_xarray(
... xarray_open_dataset_kwargs={
-         ...         "backend_kwargs": {"ignore_keys": ["number"]}
+         ...         "backend_kwargs": {"time_dim_mode": "forecast"}
... }
... )
@@ -349,7 +367,7 @@ def to_xarray_earthkit(self, user_kwargs):
# print(f"{kwargs=}")
# print(f"{xarray_open_dataset_kwargs=}")

-         from earthkit.data.utils.xarray.engine import from_earthkit
+         from earthkit.data.utils.xarray.builder import from_earthkit

return from_earthkit(self, **xarray_open_dataset_kwargs)

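A minimal usage sketch of the new options, following the docstring above ("test6.grib" is the docstring's sample file):

```python
import earthkit.data

fs = earthkit.data.from_source("file", "test6.grib")

# Default: a lazily loaded Dataset backed by the GRIB source.
ds_lazy = fs.to_xarray()

# Eager: all values are read into memory and decoupled from the source.
# release_source=True would additionally free in-memory input fields one by
# one, but it is ignored for file-backed fields like these.
ds_eager = fs.to_xarray(lazy_load=False)
```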
4 changes: 4 additions & 0 deletions src/earthkit/data/sources/array_list.py
@@ -85,6 +85,10 @@ def _metadata(self):
def handle(self):
return self._metadata._handle

+     def _release(self):
+         self._array = None
+         self.__metadata = None

def __getstate__(self) -> dict:
ret = {}
ret["_array"] = self._array
44 changes: 19 additions & 25 deletions src/earthkit/data/sources/stream.py
@@ -9,6 +9,7 @@

import itertools
import logging
+ from functools import cached_property

from earthkit.data.core.fieldlist import FieldList
from earthkit.data.readers import stream_reader
@@ -31,15 +32,13 @@ def __init__(self, stream, **kwargs):
if not isinstance(stream, Stream):
raise ValueError(f"Invalid stream={stream}")
self._stream = stream
-         self._reader_ = None

-     @property
+     @cached_property
      def _reader(self):
-         if self._reader_ is None:
-             self._reader_ = stream_reader(self, self._stream.stream, True, **self._kwargs)
-             if self._reader_ is None:
-                 raise TypeError(f"could not create reader for stream={self._stream}")
-         return self._reader_
+         reader = stream_reader(self, self._stream.stream, True, **self._kwargs)
+         if reader is None:
+             raise TypeError(f"could not create reader for stream={self._stream}")
+         return reader

def mutate(self):
source = self._reader.mutate_source()
@@ -52,9 +51,10 @@ def mutate(self):
class StreamSource(Source):
def __init__(self, stream, *, read_all=False, **kwargs):
super().__init__()
-         self._reader_ = None
self._stream = self._wrap_stream(stream)
self.memory = read_all

+         # TODO: remove this check in a future release
for k in ["group_by", "batch_size"]:
if k in kwargs:
raise ValueError(f"Invalid argument '{k}' for StreamSource. Deprecated since 0.8.0.")
@@ -74,13 +74,12 @@ def mutate(self):
return StreamFieldList(self._reader, **self._kwargs)
return self

-     @property
+     @cached_property
      def _reader(self):
-         if self._reader_ is None:
-             self._reader_ = stream_reader(self, self._stream.stream, False, **self._kwargs)
-             if self._reader_ is None:
-                 raise TypeError(f"could not create reader for stream={self._stream.stream}")
-         return self._reader_
+         reader = stream_reader(self, self._stream.stream, False, **self._kwargs)
+         if reader is None:
+             raise TypeError(f"could not create reader for stream={self._stream.stream}")
+         return reader

def batched(self, n):
"""Iterate through the stream in batches of ``n``.
Expand Down Expand Up @@ -133,13 +132,6 @@ def _wrap_stream(self, stream):

return stream

-     def _status(self):
-         """For testing purposes."""
-         return {
-             "reader": self._reader_ is not None,
-             "stream": self._stream._stream is not None,
-         }


class MultiStreamSource(Source):
def __init__(self, sources, read_all=False, **kwargs):
@@ -183,10 +175,6 @@ def _from_sources(self, sources):
raise TypeError(f"Invalid source={s}")
return r

-     def _status(self):
-         """For testing purposes."""
-         return [s._status() for s in self.sources]


class StreamFieldList(FieldList, Source):
def __init__(self, source, **kwargs):
@@ -208,6 +196,12 @@ def group_by(self, *keys, **kwargs):
def __getstate__(self):
raise NotImplementedError("StreamFieldList cannot be pickled")

+     def to_xarray(self, **kwargs):
+         from earthkit.data.core.fieldlist import FieldList
+
+         fields = [f for f in self]
+         return FieldList.from_fields(fields).to_xarray(**kwargs)


class Stream:
def __init__(self, stream=None, maker=None, **kwargs):
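Putting the pieces together, this is the workflow the commit enables: a GRIB stream is read into memory and converted via the new `StreamFieldList.to_xarray`. A sketch (the URL points at the earthkit-data example file; any GRIB byte stream works, and the exact location is an assumption):

```python
import earthkit.data

ds = earthkit.data.from_source(
    "url",
    "https://get.ecmwf.int/repository/test-data/earthkit-data/examples/test6.grib",
    stream=True,
)

# The stream is consumed into memory; release_source=True frees each
# in-memory field as soon as its values are copied into the Dataset.
xds = ds.to_xarray(lazy_load=False, release_source=True)
print(xds)
```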
67 changes: 34 additions & 33 deletions src/earthkit/data/utils/diag.py
@@ -11,11 +11,29 @@
from collections import defaultdict


- class TimeDiag:
+ class DiagCore:
      def __init__(self, name=""):
          self.name = name

+     def _build_label(self, label):
+         if label:
+             label = f"[{label:10}] "
+
+         if self.name:
+             name = f"[{self.name:10}] "
+             if label:
+                 label = f"{name}{label}"
+             else:
+                 label = f"{name}"
+
+         return label
+
+
+ class TimeDiag(DiagCore):
+     def __init__(self, name="", **kwargs):
          self.start = time.time()
          self.prev = self.start
+         super().__init__(name=name, **kwargs)

def elapsed(self):
return time.time() - self.start
@@ -24,14 +42,7 @@ def __call__(self, label="", as_str=False):
curr = time.time()
delta = curr - self.prev
self.prev = curr
-         if label:
-             label = f"[{label:10}] "
-         if self.name:
-             if label:
-                 label = f"[{self.name}] {label}"
-             else:
-                 label = f"[{self.name}]"
-
+         label = self._build_label(label)
s = f"{label}elapsed={self.elapsed():.3f}s delta={delta:.3f}s"

if as_str:
@@ -40,14 +51,13 @@
print(s)


- class MemoryDiag:
-     def __init__(self, name="", peak=False):
+ class MemoryDiag(DiagCore):
+     def __init__(self, name="", peak=False, **kwargs):
import os
import platform

import psutil

-         self.name = name
self.proc = psutil.Process()
self.prev = 0
self.scale = 1
@@ -59,6 +69,8 @@ def __init__(self, name="", peak=False):
except Exception:
pass

+         super().__init__(name=name, **kwargs)

def scale_to_mbytes(self, v):
return v * self.scale

@@ -85,14 +97,7 @@ def __call__(self, label="", delta=True, as_str=False):
m = self.current()
_delta = m - self.prev
self.prev = m
-         if label:
-             label = f"[{label:10}] "
-
-         if self.name:
-             if label:
-                 label = f"[{self.name}] {label}"
-             else:
-                 label = f"[{self.name}]"
+         label = self._build_label(label)

s = ""
if delta:
@@ -109,15 +114,17 @@
print(s)


- class Diag:
-     def __init__(self, name="", peak=False):
-         self.time = TimeDiag(name)
-         self.memory = MemoryDiag(name, peak=peak)
+ class Diag(DiagCore):
+     def __init__(self, name="", peak=False, **kwargs):
+         self.time = TimeDiag("")
+         self.memory = MemoryDiag("", peak=peak)
+         super().__init__(name=name, **kwargs)

      def __call__(self, label=""):
-         if label:
-             label = f"[{label:10}] "
-         return f"{label}{self.time(as_str=True)} {self.memory(as_str=True)}"
+         label = str(label)
+         label = self._build_label(label)
+         s = f"{label}{self.time(as_str=True)} {self.memory(as_str=True)}"
+         print(s)

def peak(self):
return self.memory.peak()
@@ -127,12 +134,6 @@ def metadata_cache_diag(fieldlist):
r = defaultdict(int)
for f in fieldlist:
collect_field_metadata_cache_diag(f, r)
-     # try:
-     #     md_cache = f._diag()
-     #     for k in ["metadata_cache_hits", "metadata_cache_misses", "metadata_cache_size"]:
-     #         r[k] += md_cache[k]
-     # except Exception:
-     #     pass
return r


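The diag refactor moves label formatting into the shared `DiagCore` base so all three helpers print consistently. A quick sketch of usage inferred from the diff (requires `psutil`):

```python
from earthkit.data.utils.diag import Diag

diag = Diag("to_xarray", peak=True)
# ... run the work being measured ...
diag("built")       # prints "[to_xarray ] [built     ] elapsed=... delta=..." plus memory
print(diag.peak())  # peak memory seen by the process
```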