Skip to content

Commit 3d6b37c

Browse files
committed
make All shape work
1 parent 94c61b6 commit 3d6b37c

File tree

3 files changed

+80
-1
lines changed

3 files changed

+80
-1
lines changed

polytope/datacube/datacube_axis.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
from abc import ABC, abstractmethod
23
from copy import deepcopy
34
from typing import Any, List
@@ -18,6 +19,10 @@ def update_range():
1819

1920
def to_intervals(range):
2021
update_range()
22+
if range[0] == -math.inf:
23+
range[0] = cls.range[0]
24+
if range[1] == math.inf:
25+
range[1] = cls.range[1]
2126
axis_lower = cls.range[0]
2227
axis_upper = cls.range[1]
2328
axis_range = axis_upper - axis_lower

polytope/shapes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __repr__(self):
8989
class Span(Shape):
9090
"""1-D range along a single axis"""
9191

92-
def __init__(self, axis, lower=None, upper=None):
92+
def __init__(self, axis, lower=-math.inf, upper=math.inf):
9393
assert not isinstance(lower, list)
9494
assert not isinstance(upper, list)
9595
self.axis = axis

tests/test_shapes.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
import xarray as xr
5+
6+
from polytope.datacube.backends.FDB_datacube import FDBDatacube
7+
from polytope.datacube.backends.xarray import XArrayDatacube
8+
from polytope.engine.hullslicer import HullSlicer
9+
from polytope.polytope import Polytope, Request
10+
from polytope.shapes import All, Select, Span
11+
12+
13+
class TestSlicing3DXarrayDatacube:
14+
def setup_method(self, method):
15+
# Create a dataarray with 3 labelled axes using different index types
16+
array = xr.DataArray(
17+
np.random.randn(3, 6, 129, 360),
18+
dims=("date", "step", "level", "longitude"),
19+
coords={
20+
"date": pd.date_range("2000-01-01", "2000-01-03", 3),
21+
"step": [0, 3, 6, 9, 12, 15],
22+
"level": range(1, 130),
23+
"longitude": range(0, 360),
24+
},
25+
)
26+
self.xarraydatacube = XArrayDatacube(array)
27+
self.options = {"longitude": {"transformation": {"cyclic": [0, 360]}}}
28+
self.slicer = HullSlicer()
29+
self.API = Polytope(datacube=array, engine=self.slicer, axis_options=self.options)
30+
31+
def test_all(self):
32+
request = Request(Select("step", [3]), Select("date", ["2000-01-01"]), All("level"), Select("longitude", [1]))
33+
result = self.API.retrieve(request)
34+
assert len(result.leaves) == 129
35+
36+
def test_all_cyclic(self):
37+
request = Request(Select("step", [3]), Select("date", ["2000-01-01"]), Select("level", [1]), All("longitude"))
38+
result = self.API.retrieve(request)
39+
# result.pprint()
40+
assert len(result.leaves) == 360
41+
42+
@pytest.mark.skip(reason="can't install fdb branch on CI")
43+
def test_all_mapper_cyclic(self):
44+
self.options = {
45+
"values": {
46+
"transformation": {
47+
"mapper": {"type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]}
48+
}
49+
},
50+
"date": {"transformation": {"merge": {"with": "time", "linkers": ["T", "00"]}}},
51+
"step": {"transformation": {"type_change": "int"}},
52+
"longitude": {"transformation": {"cyclic": [0, 360]}},
53+
}
54+
self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "step": 11}
55+
self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options)
56+
self.slicer = HullSlicer()
57+
self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options)
58+
59+
request = Request(
60+
Select("step", [11]),
61+
Select("levtype", ["sfc"]),
62+
Select("date", [pd.Timestamp("20230710T120000")]),
63+
Select("domain", ["g"]),
64+
Select("expver", ["0001"]),
65+
Select("param", ["151130"]),
66+
Select("class", ["od"]),
67+
Select("stream", ["oper"]),
68+
Select("type", ["fc"]),
69+
Span("latitude", 89.9, 90),
70+
All("longitude"),
71+
)
72+
result = self.API.retrieve(request)
73+
# result.pprint()
74+
assert len(result.leaves) == 20

0 commit comments

Comments
 (0)