diff --git a/CHANGELOG.md b/CHANGELOG.md index 61d8852a..56a0ea89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,11 +15,15 @@ Keep it human-readable, your future self will thank you! - Fix metadata serialization handling of numpy.integer (#140) - Fix negative variance for constant variables (#148) - Fix cutout slicing of grid dimension (#145) +- Use cKDTree instead of KDTree +- Implement 'complement' feature +- Add ability to patch xarrays (#160) ### Added - Call filters from anemoi-transform -- make test optional when adls is not installed Pull request #110 +- Make test optional when adls is not installed Pull request #110 +- Add wz_to_w, orog_to_z, and sum filters (#149) ## [0.5.8](https://github.com/ecmwf/anemoi-datasets/compare/0.5.7...0.5.8) - 2024-10-26 diff --git a/docs/building/filters.rst b/docs/building/filters.rst index f26467e6..84f4809e 100644 --- a/docs/building/filters.rst +++ b/docs/building/filters.rst @@ -15,8 +15,11 @@ Filters are used to modify the data or metadata in a dataset. :maxdepth: 1 filters/select + filters/orog_to_z filters/rename filters/rotate_winds + filters/sum filters/unrotate_winds + filters/wz_to_w filters/noop filters/empty diff --git a/docs/building/filters/orog_to_z.rst b/docs/building/filters/orog_to_z.rst new file mode 100644 index 00000000..8c30cd39 --- /dev/null +++ b/docs/building/filters/orog_to_z.rst @@ -0,0 +1,17 @@ +########### + orog_to_z +########### + +The ``orog_to_z`` filter converts orography (in meters) to surface +geopotential height (m^2/s^2) using the equation: + +.. math:: + + z &= g \cdot \textrm{orog}\\ + g &= 9.80665\ m \cdot s^{-1} + +This filter needs to follow a source that provides orography, which is +replaced by surface geopotential height. + +.. 
literalinclude:: yaml/orog_to_z.yaml + :language: yaml diff --git a/docs/building/filters/sum.rst b/docs/building/filters/sum.rst new file mode 100644 index 00000000..34901296 --- /dev/null +++ b/docs/building/filters/sum.rst @@ -0,0 +1,13 @@ +##### + sum +##### + +The ``sum`` filter computes the sum over multiple variables. This can be +useful for computing total precipitation from its components (snow, +rain) or summing the components of total column integrated water. This +filter needs to follow a source that provides the list of variables to +be summed up. These variables are removed by the filter and replaced by +a single summed variable. + +.. literalinclude:: yaml/sum.yaml + :language: yaml diff --git a/docs/building/filters/wz_to_w.rst b/docs/building/filters/wz_to_w.rst new file mode 100644 index 00000000..846995bb --- /dev/null +++ b/docs/building/filters/wz_to_w.rst @@ -0,0 +1,12 @@ +######### + wz_to_w +######### + +The ``wz_to_w`` filter converts geometric vertical velocity (provided in +m/s) to vertical velocity in pressure coordinates (Pa/s). This filter +needs to follow a source that provides geometric vertical velocity. +Geometric vertical velocity is removed by the filter and pressure +vertical velocity is added. + +.. literalinclude:: yaml/wz_to_w.yaml + :language: yaml diff --git a/docs/building/filters/yaml/orog_to_z.yaml b/docs/building/filters/yaml/orog_to_z.yaml new file mode 100644 index 00000000..ceeed636 --- /dev/null +++ b/docs/building/filters/yaml/orog_to_z.yaml @@ -0,0 +1,10 @@ +input: + pipe: + - source: # mars, grib, netcdf, etc. + # source attributes here + # ... 
+ # Must load an orography variable + + - orog_to_z: + orog: orog # Name of orography (input) variable + z: z # Name of z (output) variable diff --git a/docs/building/filters/yaml/sum.yaml b/docs/building/filters/yaml/sum.yaml new file mode 100644 index 00000000..ef1e1465 --- /dev/null +++ b/docs/building/filters/yaml/sum.yaml @@ -0,0 +1,13 @@ +input: + pipe: + - source: # mars, grib, netcdf, etc. + # source attributes here + # ... + # Must load the variables to be summed + + - sum: + params: # List of input variables + variable1 + variable2 + variable3 + output: variable_total # Name of output variable diff --git a/docs/building/filters/yaml/wz_to_w.yaml b/docs/building/filters/yaml/wz_to_w.yaml new file mode 100644 index 00000000..6ae9e2d9 --- /dev/null +++ b/docs/building/filters/yaml/wz_to_w.yaml @@ -0,0 +1,10 @@ +input: + pipe: + - source: # mars, grib, netcdf, etc. + # source attributes here + # ... + # Must load geometric vertical velocity + + - wz_to_w: + wz: wz # Name of geometric vertical velocity (input) variable + x: z # Name of pressure vertical velocity (output) variable diff --git a/docs/building/introduction.rst b/docs/building/introduction.rst index c2f7c8bd..65ef921f 100644 --- a/docs/building/introduction.rst +++ b/docs/building/introduction.rst @@ -10,7 +10,7 @@ file, which is a YAML file that describes sources of meteorological fields as well as the operations to perform on them, before they are written to a zarr file. The input of the process is a range of dates and some options to control the layout of the output. Statistics will be -computed as the dataset is build, and stored in the metadata, with other +computed as the dataset is built, and stored in the metadata, with other information such as the the locations of the grid points, the list of variables, etc. @@ -24,35 +24,35 @@ variables, etc. date Throughout this document, the term `date` refers to a date and time, - not just a date. 
A training dataset is covers a continuous range of + not just a date. A training dataset covers a continuous range of dates with a given frequency. Missing dates are still part of the - dataset, but the data are missing and marked as such using NaNs. - Dates are always in UTC, and refer to date at which the data is - valid. For accumulations and fluxes, that would be the end of the - accumulation period. + dataset, but missing data are marked as such using NaNs. Dates are + always in UTC, and refer to date at which the data is valid. For + accumulations and fluxes, that would be the end of the accumulation + period. variable - A `variable` is meteorological parameter, such as temperature, wind, - etc. Multilevel parameters are treated as separate variables, one for - each level. For example, temperature at 850 hPa and temperature at - 500 hPa will be treated as two separate variables (`t_850` and - `t_500`). + A `variable` is a meteorological parameter, such as temperature, + wind, etc. Multilevel parameters are treated as separate variables, + one for each level. For example, temperature at 850 hPa and + temperature at 500 hPa will be treated as two separate variables + (`t_850` and `t_500`). field - A `field` is a variable at a given date. It is represented by a array - of values at each grid point. + A `field` is a variable at a given date. It is represented by an + array of values at each grid point. source - The `source` is a software component that given a list of dates and - variables will return the corresponding fields. A example of source + The `source` is a software component that, given a list of dates and + variables will return the corresponding fields. An example of source is ECMWF's MARS archive, a collection of GRIB or NetCDF files, a database, etc. See :ref:`sources` for more information. 
filter A `filter` is a software component that takes as input the output of - a source or the output of another filter can modify the fields and/or - their metadata. For example, typical filters are interpolations, - renaming of variables, etc. See :ref:`filters` for more information. + a source or another filter and can modify the fields and/or their + metadata. For example, typical filters are interpolations, renaming + of variables, etc. See :ref:`filters` for more information. ************ Operations @@ -62,19 +62,20 @@ In order to build a training dataset, sources and filters are combined using the following operations: join - The join is the process of combining several sources data. Each - source is expected to provide different variables at the same dates. + The join is the process of combining several sources of data. Each + source is expected to provide different variables for the same of + dates. pipe The pipe is the process of transforming fields using filters. The - first step of a pipe is typically a source, a join or another pipe. - The following steps are filters. + first step of a pipe is typically a source, a join, or another pipe. + This can subsequently followed by more filters. concat The concatenation is the process of combining different sets of - operation that handle different dates. This is typically used to - build a dataset that spans several years, when the several sources - are involved, each providing a different period. + operations that handle different dates. This is typically used to + build a dataset that spans several years, when several sources are + involved, each providing data for different period. Each operation is considered as a :ref:`source `, therefore operations can be combined to build complex datasets. @@ -87,7 +88,7 @@ First recipe ============ The simplest `recipe` file must contain a ``dates`` section and an -``input`` section. The latter must contain a `source` In that case, the +``input`` section. 
The latter must contain a `source`. In that case, the source is ``mars`` .. literalinclude:: yaml/building1.yaml @@ -132,15 +133,15 @@ This will build the following dataset: Adding some forcing variables ============================= -When training a data-driven models, some forcing variables may be +When training a data-driven model, some forcing variables may be required such as the solar radiation, the time of day, the day in the year, etc. -These are provided by the ``forcings`` source. In that example, we add a -few of them. The `template` option is used to point to another source, -in that case the first instance of ``mars``. This source is used to get -information about the grid points, as some of the forcing variables are -grid dependent. +These are provided by the ``forcings`` source. Let us add a few of them +to the above example. The `template` option is used to point to another +source, in that case the first instance of ``mars``. This source is used +to get information about the grid points, as some of the forcing +variables are grid dependent. .. 
literalinclude:: yaml/building3.yaml :language: yaml diff --git a/docs/building/sources/yaml/accumulations1.yaml b/docs/building/sources/yaml/accumulations1.yaml index 1735bff2..163f46e1 100644 --- a/docs/building/sources/yaml/accumulations1.yaml +++ b/docs/building/sources/yaml/accumulations1.yaml @@ -1,6 +1,6 @@ input: accumulations: - accumulations_period: 6 + accumulation_period: 6 class: ea param: [tp, cp, sf] levtype: sfc diff --git a/docs/building/sources/yaml/accumulations2.yaml b/docs/building/sources/yaml/accumulations2.yaml index 16250606..81f699a4 100644 --- a/docs/building/sources/yaml/accumulations2.yaml +++ b/docs/building/sources/yaml/accumulations2.yaml @@ -1,6 +1,6 @@ input: accumulations: - accumulations_period: [6, 12] + accumulation_period: [6, 12] class: od param: [tp, cp, sf] levtype: sfc diff --git a/docs/index.rst b/docs/index.rst index 270509a9..616cf39f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -45,6 +45,7 @@ datasets `. - :doc:`using/subsetting` - :doc:`using/combining` - :doc:`using/selecting` +- :doc:`using/ensembles` - :doc:`using/grids` - :doc:`using/zip` - :doc:`using/statistics` @@ -65,6 +66,7 @@ datasets `. 
using/subsetting using/combining using/selecting + using/ensembles using/grids using/zip using/statistics diff --git a/docs/using/code/complement1_.py b/docs/using/code/complement1_.py new file mode 100644 index 00000000..ad77234d --- /dev/null +++ b/docs/using/code/complement1_.py @@ -0,0 +1,6 @@ +open_dataset( + complement=dataset1, + source=dataset2, + what="variables", + interpolate="nearest", +) diff --git a/docs/using/code/complement2_.py b/docs/using/code/complement2_.py new file mode 100644 index 00000000..22ad7804 --- /dev/null +++ b/docs/using/code/complement2_.py @@ -0,0 +1,12 @@ +open_dataset( + cutout=[ + { + "complement": lam_dataset, + "source": global_dataset, + "interpolate": "nearest", + }, + { + "dataset": global_dataset, + }, + ] +) diff --git a/docs/using/code/complement3_.py b/docs/using/code/complement3_.py new file mode 100644 index 00000000..e20b30a3 --- /dev/null +++ b/docs/using/code/complement3_.py @@ -0,0 +1,4 @@ +open_dataset( + complement=dataset1, + source=dataset2, +) diff --git a/docs/using/code/number1_.py b/docs/using/code/number1_.py new file mode 100644 index 00000000..089e8b57 --- /dev/null +++ b/docs/using/code/number1_.py @@ -0,0 +1,4 @@ +ds = open_dataset( + dataset, + number=1, +) diff --git a/docs/using/code/number2_.py b/docs/using/code/number2_.py new file mode 100644 index 00000000..3fd8d808 --- /dev/null +++ b/docs/using/code/number2_.py @@ -0,0 +1,4 @@ +ds = open_dataset( + dataset, + number=[1, 3, 5], +) diff --git a/docs/using/combining.rst b/docs/using/combining.rst index 86f1082f..1d6acd1d 100644 --- a/docs/using/combining.rst +++ b/docs/using/combining.rst @@ -182,3 +182,32 @@ The difference can be seen at the boundary between the two grids: To debug the combination, you can pass `plot=True` to the `cutout` function (when running from a Notebook), of use `plot="prefix"` to save the plots to series of PNG files in the current directory. + +.. 
_complement: + +************ + complement +************ + +That feature will interpolate the variables of `dataset2` that are not +in `dataset1` to the grid of `dataset1` , add them to the list of +variable of `dataset1` and return the result. + +.. literalinclude:: code/complement1_.py + +Currently ``what`` can only be ``variables`` and can be omitted. + +The value for ``interpolate`` can be one of ``none`` (default) or +``nearest``. In the case of ``none``, the grids of the two datasets must +match. + +This feature was originally designed to be used in conjunction with +``cutout``, where `dataset1` is the lam, and `dataset2` is the global +dataset. + +.. literalinclude:: code/complement2_.py + +Another use case is to simply bring all non-overlapping variables of a +dataset into an other: + +.. literalinclude:: code/complement3_.py diff --git a/docs/using/ensembles.rst b/docs/using/ensembles.rst new file mode 100644 index 00000000..ac109538 --- /dev/null +++ b/docs/using/ensembles.rst @@ -0,0 +1,27 @@ +.. _selecting-members: + +################### + Selecting members +################### + +This section describes how to subset data that are part of an ensemble. +To combine ensembles, see :ref:`ensembles` in the +:ref:`combining-datasets` section. + +.. _number: + +If a dataset is an ensemble, you can select one or more specific members +using the `number` option. You can also use ``numbers`` (which is an +alias for ``number``), and ``member`` (or ``members``). The difference +between the two is that ``number`` is **1-based**, while ``member`` is +**0-based**. + +Select a single element: + +.. literalinclude:: code/number1_.py + :language: python + +... or a list: + +.. 
literalinclude:: code/number2_.py + :language: python diff --git a/docs/using/selecting.rst b/docs/using/selecting.rst index ef21d9a3..cf3cf6d1 100644 --- a/docs/using/selecting.rst +++ b/docs/using/selecting.rst @@ -67,6 +67,28 @@ You can also rename variables: This will be useful when you join datasets and do not want variables from one dataset to override the ones from the other. +******** + number +******** + +If a dataset is an ensemble, you can select one or more specific members +using the `number` option. You can also use ``numbers`` (which is an +alias for ``number``), and ``member`` (or ``members``). The difference +between the two is that ``number`` is **1-based**, while ``member`` is +**0-based**. + +Select a single element: + +.. literalinclude:: code/number1_.py + :language: python + +... or a list: + +.. literalinclude:: code/number2_.py + :language: python + +.. _rescale: + ********* rescale ********* @@ -87,7 +109,9 @@ rescale the data. .. warning:: When providing units, the library assumes that the mapping between - them is a linear transformation. No check is does to ensure this is + them is a linear transformation. No check is done to ensure this is the case. .. _cfunits: https://github.com/NCAS-CMS/cfunits + +.. 
_number: diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index adf5b79f..d623ade2 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -622,10 +622,14 @@ def check_shape(cube, dates, dates_in_data): check_shape(cube, dates, dates_in_data) - def check_dates_in_data(lst, lst2): - lst2 = [np.datetime64(_) for _ in lst2] - lst = [np.datetime64(_) for _ in lst] - assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2) + def check_dates_in_data(dates_in_data, requested_dates): + requested_dates = [np.datetime64(_) for _ in requested_dates] + dates_in_data = [np.datetime64(_) for _ in dates_in_data] + assert dates_in_data == requested_dates, ( + "Dates in data are not the requested ones:", + dates_in_data, + requested_dates, + ) check_dates_in_data(dates_in_data, dates) diff --git a/src/anemoi/datasets/create/functions/filters/orog_to_z.py b/src/anemoi/datasets/create/functions/filters/orog_to_z.py new file mode 100644 index 00000000..ddfc3cc2 --- /dev/null +++ b/src/anemoi/datasets/create/functions/filters/orog_to_z.py @@ -0,0 +1,58 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ + +from collections import defaultdict + +from earthkit.data.indexing.fieldlist import FieldArray + + +class NewDataField: + def __init__(self, field, data, new_name): + self.field = field + self.data = data + self.new_name = new_name + + def to_numpy(self, *args, **kwargs): + return self.data + + def metadata(self, key=None, **kwargs): + if key is None: + return self.field.metadata(**kwargs) + + value = self.field.metadata(key, **kwargs) + if key == "param": + return self.new_name + return value + + def __getattr__(self, name): + return getattr(self.field, name) + + +def execute(context, input, orog, z="z"): + """Convert orography [m] to z (geopotential height)""" + result = FieldArray() + + processed_fields = defaultdict(dict) + + for f in input: + key = f.metadata(namespace="mars") + param = key.pop("param") + if param == orog: + key = tuple(key.items()) + + if param in processed_fields[key]: + raise ValueError(f"Duplicate field {param} for {key}") + + output = f.to_numpy(flatten=True) * 9.80665 + result.append(NewDataField(f, output, z)) + else: + result.append(f) + + return result diff --git a/src/anemoi/datasets/create/functions/filters/sum.py b/src/anemoi/datasets/create/functions/filters/sum.py new file mode 100644 index 00000000..083c9967 --- /dev/null +++ b/src/anemoi/datasets/create/functions/filters/sum.py @@ -0,0 +1,71 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ + +from collections import defaultdict + +from earthkit.data.indexing.fieldlist import FieldArray + + +class NewDataField: + def __init__(self, field, data, new_name): + self.field = field + self.data = data + self.new_name = new_name + + def to_numpy(self, *args, **kwargs): + return self.data + + def metadata(self, key=None, **kwargs): + if key is None: + return self.field.metadata(**kwargs) + + value = self.field.metadata(key, **kwargs) + if key == "param": + return self.new_name + return value + + def __getattr__(self, name): + return getattr(self.field, name) + + +def execute(context, input, params, output): + """Computes the sum over a set of variables""" + result = FieldArray() + + needed_fields = defaultdict(dict) + + for f in input: + key = f.metadata(namespace="mars") + param = key.pop("param") + if param in params: + key = tuple(key.items()) + + if param in needed_fields[key]: + raise ValueError(f"Duplicate field {param} for {key}") + + needed_fields[key][param] = f + else: + result.append(f) + + for keys, values in needed_fields.items(): + + if len(values) != len(params): + raise ValueError("Missing fields") + + s = None + for k, v in values.items(): + c = v.to_numpy(flatten=True) + if s is None: + s = c + else: + s += c + result.append(NewDataField(values[list(values.keys())[0]], s, output)) + + return result diff --git a/src/anemoi/datasets/create/functions/filters/wz_to_w.py b/src/anemoi/datasets/create/functions/filters/wz_to_w.py new file mode 100644 index 00000000..b108035a --- /dev/null +++ b/src/anemoi/datasets/create/functions/filters/wz_to_w.py @@ -0,0 +1,79 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from collections import defaultdict + +from earthkit.data.indexing.fieldlist import FieldArray + + +class NewDataField: + def __init__(self, field, data, new_name): + self.field = field + self.data = data + self.new_name = new_name + + def to_numpy(self, *args, **kwargs): + return self.data + + def metadata(self, key=None, **kwargs): + if key is None: + return self.field.metadata(**kwargs) + + value = self.field.metadata(key, **kwargs) + if key == "param": + return self.new_name + return value + + def __getattr__(self, name): + return getattr(self.field, name) + + +def execute(context, input, wz, t, w="w"): + """Convert geometric vertical velocity (m/s) to vertical velocity (Pa / s)""" + result = FieldArray() + + params = (wz, t) + pairs = defaultdict(dict) + + for f in input: + key = f.metadata(namespace="mars") + param = key.pop("param") + if param in params: + key = tuple(key.items()) + + if param in pairs[key]: + raise ValueError(f"Duplicate field {param} for {key}") + + pairs[key][param] = f + if param == t: + result.append(f) + else: + result.append(f) + + for keys, values in pairs.items(): + + if len(values) != 2: + raise ValueError("Missing fields") + + wz_pl = values[wz].to_numpy(flatten=True) + t_pl = values[t].to_numpy(flatten=True) + pressure = keys[4][1] * 100 # TODO: REMOVE HARDCODED INDICES + + w_pl = wz_to_w(wz_pl, t_pl, pressure) + result.append(NewDataField(values[wz], w_pl, w)) + + return result + + +def wz_to_w(wz, t, pressure): + g = 9.81 + Rd = 287.058 + + return -wz * g * pressure / (t * Rd) diff --git a/src/anemoi/datasets/create/functions/sources/accumulations.py b/src/anemoi/datasets/create/functions/sources/accumulations.py index 1fc22fea..7b583078 100644 --- a/src/anemoi/datasets/create/functions/sources/accumulations.py +++ 
b/src/anemoi/datasets/create/functions/sources/accumulations.py @@ -379,6 +379,7 @@ def accumulations(context, dates, **request): KWARGS = { ("od", "oper"): dict(patch=_scda), ("od", "elda"): dict(base_times=(6, 18)), + ("od", "enfo"): dict(base_times=(0, 6, 12, 18)), ("ea", "oper"): dict(data_accumulation_period=1, base_times=(6, 18)), ("ea", "enda"): dict(data_accumulation_period=3, base_times=(6, 18)), ("rr", "oper"): dict(base_times=(0, 3, 6, 9, 12, 15, 18, 21)), diff --git a/src/anemoi/datasets/create/functions/sources/xarray/__init__.py b/src/anemoi/datasets/create/functions/sources/xarray/__init__.py index fda14c1f..4a3229a0 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/__init__.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/__init__.py @@ -29,7 +29,7 @@ def check(what, ds, paths, **kwargs): raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, {what}s={paths})") -def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs): +def load_one(emoji, context, dates, dataset, *, options={}, flavour=None, patch=None, **kwargs): import xarray as xr """ @@ -54,10 +54,10 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs) else: data = xr.open_dataset(dataset, **options) - fs = XarrayFieldList.from_xarray(data, flavour) + fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch) if len(dates) == 0: - return fs.sel(**kwargs) + result = fs.sel(**kwargs) else: result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates]) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/field.py b/src/anemoi/datasets/create/functions/sources/xarray/field.py index 257e2932..05cfed72 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/field.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/field.py @@ -92,6 +92,10 @@ def _metadata(self): def grid_points(self): return self.owner.grid_points() + def to_latlon(self, 
flatten=True): + assert flatten + return dict(lat=self.latitudes, lon=self.longitudes) + @property def resolution(self): return None @@ -120,6 +124,6 @@ def forecast_reference_time(self): def __repr__(self): return repr(self._metadata) - def _values(self): + def _values(self, dtype=None): # we don't use .values as this will download the data return self.selection diff --git a/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py b/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py index 716f7b6b..305d01a4 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py @@ -16,6 +16,7 @@ from .field import EmptyFieldList from .flavour import CoordinateGuesser +from .patch import patch_dataset from .time import Time from .variable import FilteredVariable from .variable import Variable @@ -49,7 +50,11 @@ def __getitem__(self, i): raise IndexError(k) @classmethod - def from_xarray(cls, ds, flavour=None): + def from_xarray(cls, ds, *, flavour=None, patch=None): + + if patch is not None: + ds = patch_dataset(ds, patch) + variables = [] if isinstance(flavour, str): @@ -83,6 +88,8 @@ def _skip_attr(v, attr_name): _skip_attr(variable, "bounds") _skip_attr(variable, "grid_mapping") + LOG.debug("Xarray data_vars: %s", ds.data_vars) + # Select only geographical variables for name in ds.data_vars: @@ -97,6 +104,7 @@ def _skip_attr(v, attr_name): c = guess.guess(ds[coord], coord) assert c, f"Could not guess coordinate for {coord}" if coord not in variable.dims: + LOG.debug("%s: coord=%s (not a dimension): dims=%s", variable, coord, variable.dims) c.is_dim = False coordinates.append(c) @@ -104,6 +112,7 @@ def _skip_attr(v, attr_name): assert grid_coords <= 2 if grid_coords < 2: + LOG.debug("Skipping %s (not 2D): %s", variable, [(c, c.is_grid, c.is_dim) for c in coordinates]) continue v = Variable( diff --git a/src/anemoi/datasets/create/functions/sources/xarray/metadata.py 
b/src/anemoi/datasets/create/functions/sources/xarray/metadata.py index 6744ace9..ca574001 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/metadata.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/metadata.py @@ -24,6 +24,7 @@ class _MDMapping: def __init__(self, variable): self.variable = variable self.time = variable.time + # Aliases self.mapping = dict(param="variable") for c in variable.coordinates: for v in c.mars_names: @@ -34,7 +35,6 @@ def _from_user(self, key): return self.mapping.get(key, key) def from_user(self, kwargs): - print("from_user", kwargs, self) return {self._from_user(k): v for k, v in kwargs.items()} def __repr__(self): @@ -81,22 +81,16 @@ def _base_datetime(self): def _valid_datetime(self): return self._get("valid_datetime") - def _get(self, key, **kwargs): + def get(self, key, astype=None, **kwargs): if key in self._d: + if astype is not None: + return astype(self._d[key]) return self._d[key] - if key.startswith("mars."): - key = key[5:] - if key not in self.MARS_KEYS: - if kwargs.get("raise_on_missing", False): - raise KeyError(f"Invalid key '{key}' in namespace='mars'") - else: - return kwargs.get("default", None) - key = self._mapping._from_user(key) - return super()._get(key, **kwargs) + return super().get(key, astype=astype, **kwargs) class XArrayFieldGeography(Geography): diff --git a/src/anemoi/datasets/create/functions/sources/xarray/patch.py b/src/anemoi/datasets/create/functions/sources/xarray/patch.py new file mode 100644 index 00000000..dbe2b59c --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/xarray/patch.py @@ -0,0 +1,44 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging + +LOG = logging.getLogger(__name__) + + +def patch_attributes(ds, attributes): + for name, value in attributes.items(): + variable = ds[name] + variable.attrs.update(value) + + return ds + + +def patch_coordinates(ds, coordinates): + for name in coordinates: + ds = ds.assign_coords({name: ds[name]}) + + return ds + + +PATCHES = { + "attributes": patch_attributes, + "coordinates": patch_coordinates, +} + + +def patch_dataset(ds, patch): + for what, values in patch.items(): + if what not in PATCHES: + raise ValueError(f"Unknown patch type {what!r}") + + ds = PATCHES[what](ds, values) + + return ds diff --git a/src/anemoi/datasets/create/functions/sources/xarray/time.py b/src/anemoi/datasets/create/functions/sources/xarray/time.py index dcee1d3f..533408ad 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/time.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/time.py @@ -62,12 +62,18 @@ def from_coordinates(cls, coordinates): raise NotImplementedError(f"{len(date_coordinate)=} {len(time_coordinate)=} {len(step_coordinate)=}") + def select_valid_datetime(self, variable): + raise NotImplementedError(f"{self.__class__.__name__}.select_valid_datetime()") + class Constant(Time): def fill_time_metadata(self, coords_values, metadata): return None + def select_valid_datetime(self, variable): + return None + class Analysis(Time): @@ -83,6 +89,9 @@ def fill_time_metadata(self, coords_values, metadata): return valid_datetime + def select_valid_datetime(self, variable): + return self.time_coordinate_name + class ForecastFromValidTimeAndStep(Time): @@ -116,6 +125,9 @@ def fill_time_metadata(self, coords_values, metadata): return valid_datetime + def select_valid_datetime(self, variable): + return self.time_coordinate_name + class 
ForecastFromValidTimeAndBaseTime(Time): @@ -138,6 +150,9 @@ def fill_time_metadata(self, coords_values, metadata): return valid_datetime + def select_valid_datetime(self, variable): + return self.time_coordinate_name + class ForecastFromBaseTimeAndDate(Time): diff --git a/src/anemoi/datasets/create/functions/sources/xarray/variable.py b/src/anemoi/datasets/create/functions/sources/xarray/variable.py index 7765c61f..e8086c5e 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/variable.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/variable.py @@ -37,7 +37,7 @@ def __init__( self.coordinates = coordinates self._metadata = metadata.copy() - self._metadata.update({"variable": variable.name}) + self._metadata.update({"variable": variable.name, "param": variable.name}) self.time = time @@ -45,6 +45,9 @@ def __init__( self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid} self.by_name = {c.variable.name: c for c in coordinates} + # We need that alias for the time dimension + self._aliases = dict(valid_datetime="time") + self.length = math.prod(self.shape) @property @@ -96,15 +99,28 @@ def sel(self, missing, **kwargs): k, v = kwargs.popitem() + user_provided_k = k + + if k == "valid_datetime": + # Ask the Time object to select the valid datetime + k = self.time.select_valid_datetime(self) + if k is None: + return None + c = self.by_name.get(k) + # assert c is not None, f"Could not find coordinate {k} in {self.variable.name} {self.coordinates} {list(self.by_name)}" + if c is None: missing[k] = v return self.sel(missing, **kwargs) i = c.index(v) if i is None: - LOG.warning(f"Could not find {k}={v} in {c}") + if k != user_provided_k: + LOG.warning(f"Could not find {user_provided_k}={v} in {c} (alias of {k})") + else: + LOG.warning(f"Could not find {k}={v} in {c}") return None coordinates = [x.reduced(i) if c is x else x for x in self.coordinates] diff --git a/src/anemoi/datasets/create/utils.py 
b/src/anemoi/datasets/create/utils.py index 578e31a1..ce6df19e 100644 --- a/src/anemoi/datasets/create/utils.py +++ b/src/anemoi/datasets/create/utils.py @@ -54,6 +54,10 @@ def to_datetime(*args, **kwargs): def make_list_int(value): + # Convert a string like "1/2/3" or "1/to/3" or "1/to/10/by/2" to a list of integers. + # Moved to anemoi.utils.humanize + # replace with from anemoi.utils.humanize import make_list_int + # when anemoi-utils is released and pyproject.toml is updated if isinstance(value, str): if "/" not in value: return [value] diff --git a/src/anemoi/datasets/data/complement.py b/src/anemoi/datasets/data/complement.py new file mode 100644 index 00000000..ee324dd0 --- /dev/null +++ b/src/anemoi/datasets/data/complement.py @@ -0,0 +1,164 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+
+
+import logging
+from functools import cached_property
+
+from ..grids import nearest_grid_points
+from .debug import Node
+from .forwards import Combined
+from .indexing import apply_index_to_slices_changes
+from .indexing import index_to_slices
+from .indexing import update_tuple
+from .misc import _auto_adjust
+from .misc import _open
+
+LOG = logging.getLogger(__name__)
+
+
+class Complement(Combined):
+
+    def __init__(self, target, source, what="variables", interpolation="nearest"):
+        super().__init__([target, source])
+
+        # We add the variables of dataset[1] to dataset[0],
+        # interpolated on the grid of dataset[0]
+
+        self.target = target
+        self.source = source
+
+        self._variables = []
+
+        # Keep the same order as the original dataset
+        for v in self.source.variables:
+            if v not in self.target.variables:
+                self._variables.append(v)
+
+        if not self._variables:
+            raise ValueError("Complement: no missing variables")
+
+    @property
+    def variables(self):
+        return self._variables
+
+    @property
+    def name_to_index(self):
+        return {v: i for i, v in enumerate(self.variables)}
+
+    @property
+    def shape(self):
+        shape = self.target.shape
+        return (shape[0], len(self._variables)) + shape[2:]
+
+    @property
+    def variables_metadata(self):
+        return {k: v for k, v in self.source.variables_metadata.items() if k in self._variables}
+
+    def check_same_variables(self, d1, d2):
+        pass
+
+    @cached_property
+    def missing(self):
+        missing = self.source.missing.copy()
+        missing = missing | self.target.missing
+        return set(missing)
+
+    def tree(self):
+        """Generates a hierarchical tree structure for the `Complement` instance and
+        its associated datasets.
+
+        Returns:
+            Node: A `Node` object representing the `Complement` instance as the root
+            node, with the target and source datasets represented as child
+            nodes.
+ """ + return Node(self, [d.tree() for d in (self.target, self.source)]) + + def __getitem__(self, index): + if isinstance(index, (int, slice)): + index = (index, slice(None), slice(None), slice(None)) + return self._get_tuple(index) + + +class ComplementNone(Complement): + + def __init__(self, target, source): + super().__init__(target, source) + + def _get_tuple(self, index): + index, changes = index_to_slices(index, self.shape) + result = self.source[index] + return apply_index_to_slices_changes(result, changes) + + +class ComplementNearest(Complement): + + def __init__(self, target, source): + super().__init__(target, source) + + self._nearest_grid_points = nearest_grid_points( + self.source.latitudes, + self.source.longitudes, + self.target.latitudes, + self.target.longitudes, + ) + + def check_compatibility(self, d1, d2): + pass + + def _get_tuple(self, index): + variable_index = 1 + index, changes = index_to_slices(index, self.shape) + index, previous = update_tuple(index, variable_index, slice(None)) + source_index = [self.source.name_to_index[x] for x in self.variables[previous]] + source_data = self.source[index[0], source_index, index[2], ...] 
+        target_data = source_data[..., self._nearest_grid_points]
+
+        result = target_data[..., index[3]]
+
+        return apply_index_to_slices_changes(result, changes)
+
+
+def complement_factory(args, kwargs):
+    from .select import Select
+
+    assert len(args) == 0, args
+
+    source = kwargs.pop("source")
+    target = kwargs.pop("complement")
+    what = kwargs.pop("what", "variables")
+    interpolation = kwargs.pop("interpolation", "none")
+
+    if what != "variables":
+        raise NotImplementedError(f"Complement what={what} not implemented")
+
+    if interpolation not in ("none", "nearest"):
+        raise NotImplementedError(f"Complement method={interpolation} not implemented")
+
+    source = _open(source)
+    target = _open(target)
+    # `select` is the same as `variables`
+    (source, target), kwargs = _auto_adjust([source, target], kwargs, exclude=["select"])
+
+    Class = {
+        None: ComplementNone,
+        "none": ComplementNone,
+        "nearest": ComplementNearest,
+    }[interpolation]
+
+    complement = Class(target=target, source=source)._subset(**kwargs)
+
+    # Will join the datasets along the variables axis
+    reorder = source.variables
+    complemented = _open([target, complement])
+    ordered = (
+        Select(complemented, complemented._reorder_to_columns(reorder), {"reorder": reorder})._subset(**kwargs).mutate()
+    )
+    return ordered
diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py
index 0cb42959..6036e630 100644
--- a/src/anemoi/datasets/data/dataset.py
+++ b/src/anemoi/datasets/data/dataset.py
@@ -168,6 +168,16 @@ def __subset(self, **kwargs):
             bbox = kwargs.pop("area")
             return Cropping(self, bbox)._subset(**kwargs).mutate()
 
+        if "number" in kwargs or "numbers" in kwargs or "member" in kwargs or "members" in kwargs:
+            from .ensemble import Number
+
+            members = {}
+            for key in ["number", "numbers", "member", "members"]:
+                if key in kwargs:
+                    members[key] = kwargs.pop(key)
+
+            return Number(self, **members)._subset(**kwargs).mutate()
+
         if "set_missing_dates" in kwargs:
             from .missing import
MissingDates
@@ -251,13 +261,19 @@ def _drop_to_columns(self, vars):
 
         return sorted([v for k, v in self.name_to_index.items() if k not in vars])
 
     def _reorder_to_columns(self, vars):
+        if isinstance(vars, str) and vars == "sort":
+            # Sorting the variables alphabetically.
+            # This is crucial for pre-training then transfer learning in combination with
+            # cutout and adjust = 'all'
+
+            indices = [self.name_to_index[k] for k, v in sorted(self.name_to_index.items(), key=lambda x: x[0])]
+            assert set(indices) == set(range(len(self.name_to_index)))
+            return indices
+
         if isinstance(vars, (list, tuple)):
             vars = {k: i for i, k in enumerate(vars)}
 
-        indices = []
-
-        for k, v in sorted(vars.items(), key=lambda x: x[1]):
-            indices.append(self.name_to_index[k])
+        indices = [self.name_to_index[k] for k, v in sorted(vars.items(), key=lambda x: x[1])]
 
         # Make sure we don't forget any variables
         assert set(indices) == set(range(len(self.name_to_index)))
@@ -502,3 +518,50 @@ def _compute_constant_fields_from_statistics(self):
             result.append(v)
 
         return result
+
+    def plot(self, date, variable, member=0, **kwargs):
+        """For debugging purposes, plot a field.
+
+        Parameters
+        ----------
+        date : int or datetime.datetime or numpy.datetime64 or str
+            The date to plot.
+        variable : int or str
+            The variable to plot.
+        member : int, optional
+            The ensemble member to plot.
+ + **kwargs: + Additional arguments to pass to matplotlib.pyplot.tricontourf + + + Returns + ------- + matplotlib.pyplot.Axes + """ + + from anemoi.utils.devtools import plot_values + from earthkit.data.utils.dates import to_datetime + + if not isinstance(date, int): + date = np.datetime64(to_datetime(date)).astype(self.dates[0].dtype) + index = np.where(self.dates == date)[0] + if len(index) == 0: + raise ValueError( + f"Date {date} not found in the dataset {self.dates[0]} to {self.dates[-1]} by {self.frequency}" + ) + date_index = index[0] + else: + date_index = date + + if isinstance(variable, int): + variable_index = variable + else: + if variable not in self.variables: + raise ValueError(f"Unknown variable {variable} (available: {self.variables})") + + variable_index = self.name_to_index[variable] + + values = self[date_index, variable_index, member] + + return plot_values(values, self.latitudes, self.longitudes, **kwargs) diff --git a/src/anemoi/datasets/data/ensemble.py b/src/anemoi/datasets/data/ensemble.py index 45cb07ca..460923db 100644 --- a/src/anemoi/datasets/data/ensemble.py +++ b/src/anemoi/datasets/data/ensemble.py @@ -10,13 +10,68 @@ import logging +import numpy as np + from .debug import Node +from .forwards import Forwards from .forwards import GivenAxis +from .indexing import apply_index_to_slices_changes +from .indexing import index_to_slices +from .indexing import update_tuple from .misc import _auto_adjust from .misc import _open LOG = logging.getLogger(__name__) +OFFSETS = dict(number=1, numbers=1, member=0, members=0) + + +class Number(Forwards): + def __init__(self, forward, **kwargs): + super().__init__(forward) + + self.members = [] + for key, values in kwargs.items(): + if not isinstance(values, (list, tuple)): + values = [values] + self.members.extend([int(v) - OFFSETS[key] for v in values]) + + self.members = sorted(set(self.members)) + for n in self.members: + if not (0 <= n < forward.shape[2]): + raise ValueError(f"Member {n} is 
out of range. `number(s)` is one-based, `member(s)` is zero-based.") + + self.mask = np.array([n in self.members for n in range(forward.shape[2])], dtype=bool) + self._shape, _ = update_tuple(forward.shape, 2, len(self.members)) + + @property + def shape(self): + return self._shape + + def __getitem__(self, index): + if isinstance(index, int): + result = self.forward[index] + result = result[:, self.mask, :] + return result + + if isinstance(index, slice): + result = self.forward[index] + result = result[:, :, self.mask, :] + return result + + index, changes = index_to_slices(index, self.shape) + result = self.forward[index] + result = result[:, :, self.mask, :] + return apply_index_to_slices_changes(result, changes) + + def tree(self): + return Node(self, [self.forward.tree()], numbers=[n + 1 for n in self.members]) + + def metadata_specific(self): + return { + "numbers": [n + 1 for n in self.members], + } + class Ensemble(GivenAxis): def tree(self): diff --git a/src/anemoi/datasets/data/join.py b/src/anemoi/datasets/data/join.py index 6b7de3e6..13c64e19 100644 --- a/src/anemoi/datasets/data/join.py +++ b/src/anemoi/datasets/data/join.py @@ -118,6 +118,7 @@ def variables(self): def variables_metadata(self): result = {} variables = [v for v in self.variables if not (v.startswith("(") and v.endswith(")"))] + for d in self.datasets: md = d.variables_metadata for v in variables: @@ -130,8 +131,6 @@ def variables_metadata(self): if v not in result: LOG.error("Missing metadata for %r.", v) - raise ValueError("Some variables are missing metadata.") - return result @cached_property diff --git a/src/anemoi/datasets/data/merge.py b/src/anemoi/datasets/data/merge.py index 6921c2be..f9d4dbc3 100644 --- a/src/anemoi/datasets/data/merge.py +++ b/src/anemoi/datasets/data/merge.py @@ -134,6 +134,9 @@ def check_compatibility(self, d1, d2): def tree(self): return Node(self, [d.tree() for d in self.datasets], allow_gaps_in_dates=self.allow_gaps_in_dates) + def 
metadata_specific(self): + return {"allow_gaps_in_dates": self.allow_gaps_in_dates} + @debug_indexing def __getitem__(self, n): if isinstance(n, tuple): diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index aad751f0..2d2493a3 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -194,7 +194,7 @@ def _open(a): raise NotImplementedError(f"Unsupported argument: {type(a)}") -def _auto_adjust(datasets, kwargs): +def _auto_adjust(datasets, kwargs, exclude=None): if "adjust" not in kwargs: return datasets, kwargs @@ -214,6 +214,9 @@ def _auto_adjust(datasets, kwargs): for a in adjust_list: adjust_set.update(ALIASES.get(a, [a])) + if exclude is not None: + adjust_set -= set(exclude) + extra = set(adjust_set) - set(ALIASES["all"]) if extra: raise ValueError(f"Invalid adjust keys: {extra}") @@ -335,6 +338,12 @@ def _open_dataset(*args, **kwargs): assert not sets, sets return cutout_factory(args, kwargs).mutate() + if "complement" in kwargs: + from .complement import complement_factory + + assert not sets, sets + return complement_factory(args, kwargs).mutate() + for name in ("datasets", "dataset"): if name in kwargs: datasets = kwargs.pop(name) diff --git a/src/anemoi/datasets/grids.py b/src/anemoi/datasets/grids.py index 9e3a10b6..e5d58f0c 100644 --- a/src/anemoi/datasets/grids.py +++ b/src/anemoi/datasets/grids.py @@ -152,7 +152,7 @@ def cutout_mask( plot=None, ): """Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]""" - from scipy.spatial import KDTree + from scipy.spatial import cKDTree # TODO: transform min_distance from lat/lon to xyz @@ -195,13 +195,13 @@ def cutout_mask( min_distance = min_distance_km / 6371.0 else: points = {"lam": lam_points, "global": global_points, None: global_points}[min_distance_km] - distances, _ = KDTree(points).query(points, k=2) + distances, _ = cKDTree(points).query(points, k=2) min_distance = np.min(distances[:, 1]) 
LOG.info(f"cutout_mask using min_distance = {min_distance * 6371.0} km") - # Use a KDTree to find the nearest points - distances, indices = KDTree(lam_points).query(global_points, k=neighbours) + # Use a cKDTree to find the nearest points + distances, indices = cKDTree(lam_points).query(global_points, k=neighbours) # Centre of the Earth zero = np.array([0.0, 0.0, 0.0]) @@ -255,7 +255,7 @@ def thinning_mask( cropping_distance=2.0, ): """Return the list of points in [lats, lons] closest to [global_lats, global_lons]""" - from scipy.spatial import KDTree + from scipy.spatial import cKDTree assert global_lats.ndim == 1 assert global_lons.ndim == 1 @@ -291,20 +291,20 @@ def thinning_mask( xyx = latlon_to_xyz(lats, lons) points = np.array(xyx).transpose() - # Use a KDTree to find the nearest points - _, indices = KDTree(points).query(global_points, k=1) + # Use a cKDTree to find the nearest points + _, indices = cKDTree(points).query(global_points, k=1) return np.array([i for i in indices]) def outline(lats, lons, neighbours=5): - from scipy.spatial import KDTree + from scipy.spatial import cKDTree xyx = latlon_to_xyz(lats, lons) grid_points = np.array(xyx).transpose() - # Use a KDTree to find the nearest points - _, indices = KDTree(grid_points).query(grid_points, k=neighbours) + # Use a cKDTree to find the nearest points + _, indices = cKDTree(grid_points).query(grid_points, k=neighbours) # Centre of the Earth zero = np.array([0.0, 0.0, 0.0]) @@ -379,6 +379,19 @@ def serialise_mask(mask): return result +def nearest_grid_points(source_latitudes, source_longitudes, target_latitudes, target_longitudes): + from scipy.spatial import cKDTree + + source_xyz = latlon_to_xyz(source_latitudes, source_longitudes) + source_points = np.array(source_xyz).transpose() + + target_xyz = latlon_to_xyz(target_latitudes, target_longitudes) + target_points = np.array(target_xyz).transpose() + + _, indices = cKDTree(source_points).query(target_points, k=1) + return indices + + if __name__ == 
"__main__": global_lats, global_lons = np.meshgrid( np.linspace(90, -90, 90), diff --git a/tests/xarray/test_samples.py b/tests/xarray/test_samples.py index 8eb3542e..b4786b2c 100644 --- a/tests/xarray/test_samples.py +++ b/tests/xarray/test_samples.py @@ -17,7 +17,7 @@ from anemoi.datasets.create.functions.sources.xarray import XarrayFieldList from anemoi.datasets.testing import assert_field_list -URL = "https://object-store.os-api.cci1.ecmwf.int/ml-tests/test-data/samples/" +URL = "https://object-store.os-api.cci1.ecmwf.int/ml-tests/test-data/samples/anemoi-datasets/" SAMPLES = list(range(23)) SKIP = [0, 1, 2, 3, 4, 22] diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index 5847063c..e3988542 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -73,7 +73,7 @@ def test_weatherbench(): "levtype": "pl", } - fs = XarrayFieldList.from_xarray(ds, flavour) + fs = XarrayFieldList.from_xarray(ds, flavour=flavour) assert_field_list( fs, @@ -116,7 +116,7 @@ def test_noaa_replay(): "levtype": "pl", } - fs = XarrayFieldList.from_xarray(ds, flavour) + fs = XarrayFieldList.from_xarray(ds, flavour=flavour) assert_field_list( fs, @@ -141,7 +141,7 @@ def test_planetary_computer_conus404(): }, } - fs = XarrayFieldList.from_xarray(ds, flavour) + fs = XarrayFieldList.from_xarray(ds, flavour=flavour) assert_field_list( fs,