Skip to content

Commit c431542

Browse files
committed
fix netcdf tests
1 parent 316c6bc commit c431542

File tree

8 files changed

+58
-16
lines changed

8 files changed

+58
-16
lines changed

climetlab/readers/grib/index/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(self, *args, **kwargs):
4444

4545
@classmethod
4646
def new_mask_index(self, *args, **kwargs):
47-
return MaskFieldSet(*args, **kwargs)
47+
return GribMaskFieldSet(*args, **kwargs)
4848

4949
@property
5050
def availability_path(self):
@@ -53,7 +53,7 @@ def availability_path(self):
5353
@classmethod
5454
def merge(cls, sources):
5555
assert all(isinstance(_, GribFieldSet) for _ in sources)
56-
return MultiFieldSet(sources)
56+
return GribMultiFieldSet(sources)
5757

5858
def available(self, request, as_list_of_dicts=False):
5959
from climetlab.utils.availability import Availability
@@ -152,12 +152,12 @@ def _normalize_kwargs_names(self, **kwargs):
152152
return kwargs
153153

154154

155-
class MaskFieldSet(GribFieldSet, MaskIndex):
155+
class GribMaskFieldSet(GribFieldSet, MaskIndex):
156156
def __init__(self, *args, **kwargs):
157157
MaskIndex.__init__(self, *args, **kwargs)
158158

159159

160-
class MultiFieldSet(GribFieldSet, MultiIndex):
160+
class GribMultiFieldSet(GribFieldSet, MultiIndex):
161161
def __init__(self, *args, **kwargs):
162162
MultiIndex.__init__(self, *args, **kwargs)
163163

climetlab/readers/grib/output.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ def update_metadata(self, handle, metadata, compulsary):
199199
if "number" in metadata:
200200
compulsary += ("numberOfForecastsInEnsemble",)
201201
productDefinitionTemplateNumber = {"tp": 11}
202-
metadata["productDefinitionTemplateNumber"] = (
203-
productDefinitionTemplateNumber.get(handle.get("shortName"), 1)
204-
)
202+
metadata[
203+
"productDefinitionTemplateNumber"
204+
] = productDefinitionTemplateNumber.get(handle.get("shortName"), 1)
205205

206206
if metadata.get("type") in ("pf", "cf"):
207207
metadata.setdefault("typeOfGeneratingProcess", 4)

climetlab/readers/grib/reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import logging
1111

1212
from climetlab.readers import Reader
13-
from climetlab.readers.grib.index import MultiFieldSet
13+
from climetlab.readers.grib.index import GribMultiFieldSet
1414
from climetlab.readers.grib.index.file import FieldSetInOneFile
1515

1616
LOG = logging.getLogger(__name__)
@@ -31,7 +31,7 @@ def merge(cls, readers):
3131
assert all(isinstance(s, GRIBReader) for s in readers), readers
3232
assert len(readers) > 1
3333

34-
return MultiFieldSet(readers)
34+
return GribMultiFieldSet(readers)
3535

3636
def mutate_source(self):
3737
# A GRIBReader is a source itself

climetlab/readers/netcdf/fieldset.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from functools import cached_property
1111
from itertools import product
1212

13+
from climetlab.core.index import MaskIndex, MultiIndex
1314
from climetlab.indexing.fieldset import FieldSet
1415
from climetlab.utils.bbox import BoundingBox
1516
from climetlab.utils.dates import to_datetime
@@ -24,6 +25,10 @@ def __init__(self, path):
2425
self.path = path
2526
self.opendap = path.startswith("http")
2627

28+
@classmethod
29+
def new_mask_index(self, *args, **kwargs):
30+
return NetCDFMaskFieldSet(*args, **kwargs)
31+
2732
def __repr__(self):
2833
return "NetCDFReader(%s)" % (self.path,)
2934

@@ -40,6 +45,9 @@ def __getitem__(self, n):
4045
def dataset(self):
4146
import xarray as xr
4247

48+
if ".zarr" in self.path:
49+
return xr.open_zarr(self.path)
50+
4351
if self.opendap:
4452
return xr.open_dataset(self.path)
4553
else:
@@ -146,7 +154,7 @@ def _get_fields(self, ds): # noqa C901
146154
def to_xarray(self, **kwargs):
147155
import xarray as xr
148156

149-
if self.opendap:
157+
if self.path.startswith("http"):
150158
return xr.open_dataset(self.path, **kwargs)
151159
return type(self).to_xarray_multi_from_paths([self.path], **kwargs)
152160

@@ -185,3 +193,31 @@ def to_datetime_list(self):
185193

186194
def to_bounding_box(self):
187195
return BoundingBox.multi_merge([s.to_bounding_box() for s in self.fields])
196+
197+
@classmethod
198+
def merge(cls, sources):
199+
assert len(sources) > 1
200+
assert all(isinstance(_, NetCDFFieldSet) for _ in sources)
201+
return NetCDFMultiFieldSet(sources)
202+
203+
204+
class NetCDFMaskFieldSet(NetCDFFieldSet, MaskIndex):
205+
def __init__(self, *args, **kwargs):
206+
MaskIndex.__init__(self, *args, **kwargs)
207+
208+
209+
class NetCDFMultiFieldSet(NetCDFFieldSet, MultiIndex):
210+
def __init__(self, *args, **kwargs):
211+
MultiIndex.__init__(self, *args, **kwargs)
212+
self.paths = [s.path for s in args[0]]
213+
214+
def to_xarray(self, **kwargs):
215+
import xarray as xr
216+
if not kwargs:
217+
kwargs = dict(combine="by_coords")
218+
return xr.open_mfdataset(self.paths, **kwargs)
219+
220+
221+
@cached_property
222+
def dataset(self):
223+
return self.to_xarray(combine="by_coords")

climetlab/sources/indexed_urls.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import warnings
1111

1212
from climetlab.indexing import PerUrlIndex
13-
from climetlab.readers.grib.index import MultiFieldSet
13+
from climetlab.readers.grib.index import GribMultiFieldSet
1414
from climetlab.readers.grib.index.sql import FieldsetInFilesWithSqlIndex
1515
from climetlab.sources.indexed import IndexedSource
1616
from climetlab.utils.patterns import Pattern
@@ -60,7 +60,7 @@ def __init__(
6060
# This is to avoid keeping them on the request
6161
request.pop(used)
6262

63-
index = MultiFieldSet(
63+
index = GribMultiFieldSet(
6464
FieldsetInFilesWithSqlIndex.from_url(
6565
get_index_url(url, substitute_extension, index_extension),
6666
selection=request,

climetlab/utils/patterns.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,14 @@ def substitute(self, value, name):
8888
return self.format % value
8989

9090

91-
TYPES = {"": Any, "int": Int, "float": Float, "date": Datetime, "strftime": Datetime, "enum": Enum}
91+
TYPES = {
92+
"": Any,
93+
"int": Int,
94+
"float": Float,
95+
"date": Datetime,
96+
"strftime": Datetime,
97+
"enum": Enum,
98+
}
9299

93100

94101
class Constant:

tests/readers/test_netcdf_reader.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def test_dummy_netcdf_4():
9797
@pytest.mark.long_test
9898
@pytest.mark.download
9999
@pytest.mark.skipif(NO_CDS, reason="No access to CDS")
100-
@pytest.mark.skipif(True, reason="Merging of netcdf files does not work yet")
101100
def test_multi():
102101
s1 = load_source(
103102
"cds",

tests/sources/test_merge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_nc_merge_custom(custom_merger):
7272
target2 = xr.open_mfdataset([s1.path, s2.path])
7373
assert target2.identical(merged)
7474

75-
@pytest.mark.skipif(True, reason="Merging of netcdf files does not work yet")
75+
7676
def test_nc_merge_var():
7777
s1 = load_source(
7878
"climetlab-testing",
@@ -125,7 +125,7 @@ def _merge_var_different_coords(kind1, kind2):
125125

126126
assert target.identical(merged)
127127

128-
@pytest.mark.skipif(True, reason="Merging of netcdf files does not work yet")
128+
129129
def test_nc_merge_var_different_coords():
130130
_merge_var_different_coords("netcdf", "netcdf")
131131

0 commit comments

Comments
 (0)