diff --git a/.gitignore b/.gitignore
index 7b8e19f62..ca7e60eb6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,8 @@ polytope.egg-info
 .pytest_cache
 *.prof
 *.idx
+*.grib
+*.xml
+site
+.coverage
 *.grib
\ No newline at end of file
diff --git a/performance/fdb_performance.py b/performance/fdb_performance.py
new file mode 100644
index 000000000..78819d462
--- /dev/null
+++ b/performance/fdb_performance.py
@@ -0,0 +1,47 @@
+import time
+
+import pandas as pd
+
+from polytope.datacube.backends.fdb import FDBDatacube
+from polytope.engine.hullslicer import HullSlicer
+from polytope.polytope import Polytope, Request
+from polytope.shapes import Box, Select
+
+
+class TestSlicingFDBDatacube:
+    def setup_method(self, method):
+        # Create a dataarray with 3 labelled axes using different index types
+        self.options = {
+            "values": {
+                "transformation": {
+                    "mapper": {"type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]}
+                }
+            },
+            "date": {"transformation": {"merge": {"with": "time", "linkers": [" ", "00"]}}},
+            "step": {"transformation": {"type_change": "int"}},
+        }
+        self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "step": 0}
+        self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options)
+        self.slicer = HullSlicer()
+        self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options)
+
+    # Testing different shapes
+    # @pytest.mark.skip(reason="can't install fdb branch on CI")
+    def test_fdb_datacube(self):
+        request = Request(
+            Select("step", [0]),
+            Select("levtype", ["sfc"]),
+            Select("date", [pd.Timestamp("20230625T120000")]),
+            Select("domain", ["g"]),
+            Select("expver", ["0001"]),
+            Select("param", ["167"]),
+            Select("class", ["od"]),
+            Select("stream", ["oper"]),
+            Select("type", ["an"]),
+            Box(["latitude", "longitude"], [0, 0], [10, 10]),
+        )
+        time1 = time.time()
+        result = self.API.retrieve(request)
+        print("ENTIRE TIME")
+        print(time.time() - time1)
+        print(len(result.leaves))
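This script (and the 3D variant below) times a retrieval with time.time(), whose resolution can be coarse on some platforms. A minimal sketch of the same measurement using time.perf_counter, a monotonic high-resolution clock — illustrative only, not part of the patch; api and request stand for the Polytope instance and Request built above:

    import time

    def timed_retrieve(api, request):
        # perf_counter is monotonic and high-resolution, unlike time.time()
        start = time.perf_counter()
        result = api.retrieve(request)
        print(f"retrieve took {time.perf_counter() - start:.3f}s for {len(result.leaves)} leaves")
        return result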
["od"]), + Select("stream", ["oper"]), + Select("type", ["fc"]), + Box(["latitude", "longitude"], [0, 0], [3, 5]), + ) + time1 = time.time() + result = self.API.retrieve(request) + print("ENTIRE TIME") + print(time.time() - time1) + print(len(result.leaves)) diff --git a/performance/fdb_scalability_plot.py b/performance/fdb_scalability_plot.py new file mode 100644 index 000000000..7230fd47b --- /dev/null +++ b/performance/fdb_scalability_plot.py @@ -0,0 +1,16 @@ +import matplotlib.pyplot as plt + +fdb_time = [ + 7.6377081871032715 - 7.558288812637329, + 73.57192325592041 - 72.99611115455627, + 733.2706120014191 - 727.7059993743896, + 4808.3157522678375 - 4770.814565420151, +] +num_extracted_points = [1986, 19226, 191543, 1267134] + +# for the 1.3M points, we used 100 latitudes too...., maybe that's why it's not as linear... + +plt.plot(num_extracted_points, fdb_time, marker="o") +plt.xlabel("Number of extracted points") +plt.ylabel("Polytope extraction time (in s)") +plt.show() diff --git a/polytope/datacube/backends/FDB_datacube.py b/polytope/datacube/backends/FDB_datacube.py deleted file mode 100644 index 846ac840e..000000000 --- a/polytope/datacube/backends/FDB_datacube.py +++ /dev/null @@ -1,81 +0,0 @@ -from copy import deepcopy - -import pyfdb - -from .datacube import Datacube, IndexTree - - -def glue(path, unmap_path): - return {"t": 0} - - -def update_fdb_dataarray(fdb_dataarray): - fdb_dataarray["values"] = [0.0] - return fdb_dataarray - - -class FDBDatacube(Datacube): - def __init__(self, config={}, axis_options={}): - self.axis_options = axis_options - self.grid_mapper = None - self.axis_counter = 0 - self._axes = None - treated_axes = [] - self.non_complete_axes = [] - self.complete_axes = [] - self.blocked_axes = [] - self.transformation = None - self.fake_axes = [] - - partial_request = config - # Find values in the level 3 FDB datacube - # Will be in the form of a dictionary? 
diff --git a/polytope/datacube/backends/FDB_datacube.py b/polytope/datacube/backends/FDB_datacube.py
deleted file mode 100644
index 846ac840e..000000000
--- a/polytope/datacube/backends/FDB_datacube.py
+++ /dev/null
@@ -1,81 +0,0 @@
-from copy import deepcopy
-
-import pyfdb
-
-from .datacube import Datacube, IndexTree
-
-
-def glue(path, unmap_path):
-    return {"t": 0}
-
-
-def update_fdb_dataarray(fdb_dataarray):
-    fdb_dataarray["values"] = [0.0]
-    return fdb_dataarray
-
-
-class FDBDatacube(Datacube):
-    def __init__(self, config={}, axis_options={}):
-        self.axis_options = axis_options
-        self.grid_mapper = None
-        self.axis_counter = 0
-        self._axes = None
-        treated_axes = []
-        self.non_complete_axes = []
-        self.complete_axes = []
-        self.blocked_axes = []
-        self.transformation = None
-        self.fake_axes = []
-
-        partial_request = config
-        # Find values in the level 3 FDB datacube
-        # Will be in the form of a dictionary? {axis_name:values_available, ...}
-        fdb = pyfdb.FDB()
-        fdb_dataarray = fdb.axes(partial_request).as_dict()
-        dataarray = update_fdb_dataarray(fdb_dataarray)
-        self.dataarray = dataarray
-
-        for name, values in dataarray.items():
-            values.sort()
-            options = axis_options.get(name, {})
-            self._check_and_add_axes(options, name, values)
-            treated_axes.append(name)
-            self.complete_axes.append(name)
-
-        # add other options to axis which were just created above like "lat" for the mapper transformations for eg
-        for name in self._axes:
-            if name not in treated_axes:
-                options = axis_options.get(name, {})
-                val = self._axes[name].type
-                self._check_and_add_axes(options, name, val)
-
-    def get(self, requests: IndexTree):
-        for r in requests.leaves:
-            path = r.flatten()
-            path = self.remap_path(path)
-            if len(path.items()) == self.axis_counter:
-                # first, find the grid mapper transform
-                unmapped_path = {}
-                path_copy = deepcopy(path)
-                for key in path_copy:
-                    axis = self._axes[key]
-                    (path, unmapped_path) = axis.unmap_total_path_to_datacube(path, unmapped_path)
-                path = self.fit_path(path)
-                subxarray = glue(path, unmapped_path)
-                key = list(subxarray.keys())[0]
-                value = subxarray[key]
-                r.result = (key, value)
-            else:
-                r.remove_branch()
-
-    def datacube_natural_indexes(self, axis, subarray):
-        indexes = subarray[axis.name]
-        return indexes
-
-    def select(self, path, unmapped_path):
-        return self.dataarray
-
-    def ax_vals(self, name):
-        for _name, values in self.dataarray.items():
-            if _name == name:
-                return values
diff --git a/polytope/datacube/backends/__init__.py b/polytope/datacube/backends/__init__.py
new file mode 100644
index 000000000..63902115a
--- /dev/null
+++ b/polytope/datacube/backends/__init__.py
@@ -0,0 +1 @@
+from ..backends.datacube import *
diff --git a/polytope/datacube/backends/datacube.py b/polytope/datacube/backends/datacube.py
index f6ffb041f..10b5a6613 100644
--- a/polytope/datacube/backends/datacube.py
+++ b/polytope/datacube/backends/datacube.py
@@ -37,8 +37,11 @@ def _create_axes(self, name, values, transformation_type_key, transformation_opt
         )
         for blocked_axis in transformation.blocked_axes():
             self.blocked_axes.append(blocked_axis)
+        for unwanted_axis in transformation.unwanted_axes():
+            self.unwanted_axes.append(unwanted_axis)
         for axis_name in final_axis_names:
             self.complete_axes.append(axis_name)
+            self.fake_axes.append(axis_name)
         # if axis does not yet exist, create it
         # first need to change the values so that we have right type
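The new FDB backend that follows discovers its axes by querying pyfdb for the key/value combinations available under a partial request. The exact keys depend on the FDB schema in use; assuming the config from the performance scripts above, the returned structure is a plain dict of value lists, roughly of this shape (values illustrative; requires a working FDB and pyfdb installation):

    import pyfdb

    fdb = pyfdb.FDB()
    partial_request = {"class": "od", "expver": "0001", "levtype": "sfc", "step": 0}
    axes = fdb.axes(partial_request).as_dict()
    # e.g. {"class": ["od"], "date": ["20230625"], "time": ["1200"],
    #       "param": ["167"], "step": ["0"], ...}  -- keys depend on the schema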
diff --git a/polytope/datacube/backends/fdb.py b/polytope/datacube/backends/fdb.py
new file mode 100644
index 000000000..0679dbdd0
--- /dev/null
+++ b/polytope/datacube/backends/fdb.py
@@ -0,0 +1,167 @@
+from copy import deepcopy
+
+import pyfdb
+
+from .datacube import Datacube, IndexTree
+
+
+class FDBDatacube(Datacube):
+    def __init__(self, config={}, axis_options={}):
+        self.axis_options = axis_options
+        self.axis_counter = 0
+        self._axes = None
+        treated_axes = []
+        self.non_complete_axes = []
+        self.complete_axes = []
+        self.blocked_axes = []
+        self.unwanted_axes = []
+        self.fake_axes = []
+        self.unwanted_path = {}
+
+        partial_request = config
+        # Find values in the level 3 FDB datacube
+        # These will be in the form of a dictionary {axis_name: values_available, ...}
+        self.fdb = pyfdb.FDB()
+        self.fdb_coordinates = self.fdb.axes(partial_request).as_dict()
+        self.fdb_coordinates["values"] = []
+        for name, values in self.fdb_coordinates.items():
+            values.sort()
+            options = axis_options.get(name, {})
+            self._check_and_add_axes(options, name, values)
+            treated_axes.append(name)
+            self.complete_axes.append(name)
+
+        # add other options to axes which were just created above, e.g. "lat" for the mapper transformations
+        for name in self._axes:
+            if name not in treated_axes:
+                options = axis_options.get(name, {})
+                val = self._axes[name].type
+                self._check_and_add_axes(options, name, val)
+
+    def remove_unwanted_axes(self, leaf_path):
+        for axis in self.unwanted_axes:
+            leaf_path.pop(axis)
+        return leaf_path
+
+    def get(self, requests: IndexTree, leaf_path={}):
+        # First when request node is root, go to its children
+        if requests.axis.name == "root":
+            for c in requests.children:
+                self.get(c)
+        # If request node has no children, we have a leaf so need to assign fdb values to it
+        else:
+            key_value_path = {requests.axis.name: requests.value}
+            ax = requests.axis
+            (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key(
+                key_value_path, leaf_path, self.unwanted_path
+            )
+            leaf_path |= key_value_path
+            if len(requests.children[0].children[0].children) == 0:
+                # remap this last key
+                self.get_2nd_last_values(requests, leaf_path)
+
+            # Otherwise remap the path for this key and iterate again over children
+            else:
+                for c in requests.children:
+                    self.get(c, leaf_path)
+
+    def get_2nd_last_values(self, requests, leaf_path={}):
+        # In this function, we recursively loop over the last two layers of the tree and store the indices of the
+        # request ranges in those layers
+        lat_length = len(requests.children)
+        range_lengths = [False] * lat_length
+        current_start_idxs = [False] * lat_length
+        fdb_node_ranges = [False] * lat_length
+        for i in range(len(requests.children)):
+            lat_child = requests.children[i]
+            lon_length = len(lat_child.children)
+            range_lengths[i] = [1] * lon_length
+            current_start_idxs[i] = [None] * lon_length
+            fdb_node_ranges[i] = [[IndexTree.root] * lon_length] * lon_length
+            range_length = deepcopy(range_lengths[i])
+            current_start_idx = deepcopy(current_start_idxs[i])
+            fdb_range_nodes = deepcopy(fdb_node_ranges[i])
+            key_value_path = {lat_child.axis.name: lat_child.value}
+            ax = lat_child.axis
+            (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key(
+                key_value_path, leaf_path, self.unwanted_path
+            )
+            leaf_path |= key_value_path
+            (range_lengths[i], current_start_idxs[i], fdb_node_ranges[i]) = self.get_last_layer_before_leaf(
+                lat_child, leaf_path, range_length, current_start_idx, fdb_range_nodes
+            )
+        self.give_fdb_val_to_node(leaf_path, range_lengths, current_start_idxs, fdb_node_ranges, lat_length)
+
+    def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, fdb_range_n):
+        i = 0
+        for c in requests.children:
+            # now c are the leaves of the initial tree
+            key_value_path = {c.axis.name: c.value}
+            ax = c.axis
+            (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key(
+                key_value_path, leaf_path, self.unwanted_path
+            )
+            leaf_path |= key_value_path
+            last_idx = key_value_path["values"]
+            if current_idx[i] is None:
+                current_idx[i] = last_idx
+                fdb_range_n[i][range_l[i] - 1] = c
+            else:
+                if last_idx == current_idx[i] + range_l[i]:
+                    range_l[i] += 1
+                    fdb_range_n[i][range_l[i] - 1] = c
+                else:
+                    key_value_path = {c.axis.name: c.value}
+                    ax = c.axis
+                    (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key(
+                        key_value_path, leaf_path, self.unwanted_path
+                    )
+                    leaf_path |= key_value_path
+                    i += 1
+                    current_start_idx = key_value_path["values"]
+                    current_idx[i] = current_start_idx
+        return (range_l, current_idx, fdb_range_n)
+
+    def give_fdb_val_to_node(self, leaf_path, range_lengths, current_start_idx, fdb_range_nodes, lat_length):
+        (output_values, original_indices) = self.find_fdb_values(
+            leaf_path, range_lengths, current_start_idx, lat_length
+        )
+        new_fdb_range_nodes = []
+        new_range_lengths = []
+        for j in range(lat_length):
+            for i in range(len(range_lengths[j])):
+                if current_start_idx[j][i] is not None:
+                    new_fdb_range_nodes.append(fdb_range_nodes[j][i])
+                    new_range_lengths.append(range_lengths[j][i])
+        sorted_fdb_range_nodes = [new_fdb_range_nodes[i] for i in original_indices]
+        sorted_range_lengths = [new_range_lengths[i] for i in original_indices]
+        for i in range(len(sorted_fdb_range_nodes)):
+            for k in range(sorted_range_lengths[i]):
+                n = sorted_fdb_range_nodes[i][k]
+                n.result = output_values[0][0][0][i][k]
+
+    def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length):
+        path.pop("values")
+        fdb_requests = []
+        interm_request_ranges = []
+        for i in range(lat_length):
+            for j in range(len(range_lengths[i])):
+                if current_start_idx[i][j] is not None:
+                    current_request_ranges = (current_start_idx[i][j], current_start_idx[i][j] + range_lengths[i][j])
+                    interm_request_ranges.append(current_request_ranges)
+        request_ranges_with_idx = list(enumerate(interm_request_ranges))
+        sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0])
+        original_indices, sorted_request_ranges = zip(*sorted_list)
+        fdb_requests.append(tuple((path, sorted_request_ranges)))
+        output_values = self.fdb.extract(fdb_requests)
+        return (output_values, original_indices)
+
+    def datacube_natural_indexes(self, axis, subarray):
+        indexes = subarray[axis.name]
+        return indexes
+
+    def select(self, path, unmapped_path):
+        return self.fdb_coordinates
+
+    def ax_vals(self, name):
+        return self.fdb_coordinates.get(name, None)
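A large part of fdb.py above is bookkeeping to coalesce consecutive grid indices into (start, length) ranges, so each FDB access reads a contiguous run of values instead of single points. Stripped of the tree traversal, the grouping idea reduces to run-length encoding of sorted indices, as in this stand-alone sketch (not the class's actual code path):

    def group_consecutive(indices):
        # turn sorted indices like [4, 5, 6, 10, 11] into [(4, 3), (10, 2)]
        ranges = []
        start, length = indices[0], 1
        for prev, cur in zip(indices, indices[1:]):
            if cur == prev + 1:
                length += 1
            else:
                ranges.append((start, length))
                start, length = cur, 1
        ranges.append((start, length))
        return ranges

    assert group_consecutive([4, 5, 6, 10, 11]) == [(4, 3), (10, 2)]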
diff --git a/polytope/datacube/backends/xarray.py b/polytope/datacube/backends/xarray.py
index 321f0f644..650038e05 100644
--- a/polytope/datacube/backends/xarray.py
+++ b/polytope/datacube/backends/xarray.py
@@ -10,7 +10,6 @@ class XArrayDatacube(Datacube):
 
     def __init__(self, dataarray: xr.DataArray, axis_options={}):
         self.axis_options = axis_options
-        self.grid_mapper = None
         self.axis_counter = 0
         self._axes = None
         self.dataarray = dataarray
@@ -18,8 +17,8 @@ def __init__(self, dataarray: xr.DataArray, axis_options={}):
         self.non_complete_axes = []
         self.complete_axes = []
         self.blocked_axes = []
-        self.transformation = None
         self.fake_axes = []
+        self.unwanted_axes = []
         for name, values in dataarray.coords.variables.items():
             if name in dataarray.dims:
                 options = axis_options.get(name, {})
@@ -48,14 +47,13 @@ def __init__(self, dataarray: xr.DataArray, axis_options={}):
     def get(self, requests: IndexTree):
         for r in requests.leaves:
             path = r.flatten()
-            path = self.remap_path(path)
             if len(path.items()) == self.axis_counter:
                 # first, find the grid mapper transform
                 unmapped_path = {}
                 path_copy = deepcopy(path)
                 for key in path_copy:
                     axis = self._axes[key]
-                    (path, unmapped_path) = axis.unmap_total_path_to_datacube(path, unmapped_path)
+                    (path, unmapped_path) = axis.unmap_to_datacube(path, unmapped_path)
                 path = self.fit_path(path)
                 subxarray = self.dataarray.sel(path, method="nearest")
                 subxarray = subxarray.sel(unmapped_path)
diff --git a/polytope/datacube/datacube_axis.py b/polytope/datacube/datacube_axis.py
index 625051f7c..a84bfd303 100644
--- a/polytope/datacube/datacube_axis.py
+++ b/polytope/datacube/datacube_axis.py
@@ -1,3 +1,4 @@
+import bisect
 import math
 from abc import ABC, abstractmethod
 from copy import deepcopy
@@ -6,6 +7,8 @@
 import numpy as np
 import pandas as pd
 
+from ..utility.list_tools import bisect_left_cmp, bisect_right_cmp
+
 
 def cyclic(cls):
     if cls.is_cyclic:
@@ -120,19 +123,17 @@ def remap(range: List):
        def find_indexes(path, datacube):
            return old_find_indexes(path, datacube)

-        old_unmap_total_path_to_datacube = cls.unmap_total_path_to_datacube
+        old_unmap_path_key = cls.unmap_path_key

-        def unmap_total_path_to_datacube(path, unmapped_path):
+        def unmap_path_key(key_value_path, leaf_path, unwanted_path):
+            value = key_value_path[cls.name]
            for transform in cls.transformations:
                if isinstance(transform, DatacubeAxisCyclic):
-                    transformation = transform
-                    if cls.name == transformation.name:
-                        old_val = path.get(cls.name, None)
-                        path.pop(cls.name, None)
-                        new_val = _remap_val_to_axis_range(old_val)
-                        path[cls.name] = new_val
-            (path, unmapped_path) = old_unmap_total_path_to_datacube(path, unmapped_path)
-            return (path, unmapped_path)
+                    if cls.name == transform.name:
+                        new_val = _remap_val_to_axis_range(value)
+                        key_value_path[cls.name] = new_val
+            key_value_path, leaf_path, unwanted_path = old_unmap_path_key(key_value_path, leaf_path, unwanted_path)
+            return (key_value_path, leaf_path, unwanted_path)

        old_unmap_to_datacube = cls.unmap_to_datacube

@@ -185,8 +186,8 @@ def offset(range):
        cls.offset = offset
        cls.find_indexes = find_indexes
        cls.unmap_to_datacube = unmap_to_datacube
-        cls.unmap_total_path_to_datacube = unmap_total_path_to_datacube
        cls.find_indices_between = find_indices_between
+        cls.unmap_path_key = unmap_path_key

        return cls
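The cyclic unmap_path_key above folds a requested value back into the axis's fundamental interval before lookup (via _remap_val_to_axis_range, defined elsewhere in this file). The underlying arithmetic is plain modular wrapping, sketched here for an assumed [0, 360) longitude axis:

    def remap_to_range(value, low=0.0, high=360.0):
        # fold any value into [low, high), e.g. 370 -> 10, -10 -> 350
        return (value - low) % (high - low) + low

    assert remap_to_range(370.0) == 10.0
    assert remap_to_range(-10.0) == 350.0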
@@ -213,64 +214,46 @@ def unmap_to_datacube(path, unmapped_path):
            (path, unmapped_path) = old_unmap_to_datacube(path, unmapped_path)
            for transform in cls.transformations:
                if isinstance(transform, DatacubeMapper):
-                    transformation = transform
-                    if cls.name == transformation._mapped_axes()[0]:
+                    if cls.name == transform._mapped_axes()[0]:
                        # if we are on the first axis, then need to add the first val to unmapped_path
                        first_val = path.get(cls.name, None)
                        path.pop(cls.name, None)
                        if cls.name not in unmapped_path:
                            # if for some reason, the unmapped_path already has the first axis val, then don't update
                            unmapped_path[cls.name] = first_val
-                    if cls.name == transformation._mapped_axes()[1]:
+                    if cls.name == transform._mapped_axes()[1]:
                        # if we are on the second axis, then the val of the first axis is stored
                        # inside unmapped_path so can get it from there
                        second_val = path.get(cls.name, None)
                        path.pop(cls.name, None)
-                        first_val = unmapped_path.get(transformation._mapped_axes()[0], None)
-                        unmapped_path.pop(transformation._mapped_axes()[0], None)
+                        first_val = unmapped_path.get(transform._mapped_axes()[0], None)
+                        unmapped_path.pop(transform._mapped_axes()[0], None)
                        # if the first_val was not in the unmapped_path, then it's still in path
                        if first_val is None:
-                            first_val = path.get(transformation._mapped_axes()[0], None)
-                            path.pop(transformation._mapped_axes()[0], None)
+                            first_val = path.get(transform._mapped_axes()[0], None)
+                            path.pop(transform._mapped_axes()[0], None)
                        if first_val is not None and second_val is not None:
-                            unmapped_idx = transformation.unmap(first_val, second_val)
-                            unmapped_path[transformation.old_axis] = unmapped_idx
+                            unmapped_idx = transform.unmap(first_val, second_val)
+                            unmapped_path[transform.old_axis] = unmapped_idx
            return (path, unmapped_path)

-        old_unmap_total_path_to_datacube = cls.unmap_total_path_to_datacube
+        old_unmap_path_key = cls.unmap_path_key

-        def unmap_total_path_to_datacube(path, unmapped_path):
-            (path, unmapped_path) = old_unmap_total_path_to_datacube(path, unmapped_path)
+        def unmap_path_key(key_value_path, leaf_path, unwanted_path):
+            key_value_path, leaf_path, unwanted_path = old_unmap_path_key(key_value_path, leaf_path, unwanted_path)
+            value = key_value_path[cls.name]
            for transform in cls.transformations:
                if isinstance(transform, DatacubeMapper):
-                    transformation = transform
-                    if cls.name == transformation._mapped_axes()[0]:
-                        # if we are on the first axis, then need to add the first val to unmapped_path
-                        first_val = path.get(cls.name, None)
-                        path.pop(cls.name, None)
-                        if unmapped_path is None:
-                            unmapped_path[cls.name] = first_val
-                        elif cls.name not in unmapped_path:
-                            # if for some reason, the unmapped_path already has the first axis val, then don't update
-                            unmapped_path[cls.name] = first_val
-                    if cls.name == transformation._mapped_axes()[1]:
-                        # if we are on the second axis, then the val of the first axis is stored
-                        # inside unmapped_path so can get it from there
-                        second_val = path.get(cls.name, None)
-                        path.pop(cls.name, None)
-                        first_val = unmapped_path.get(transformation._mapped_axes()[0], None)
-                        unmapped_path.pop(transformation._mapped_axes()[0], None)
-                        # if the first_val was not in the unmapped_path, then it's still in path
-                        if first_val is None:
-                            first_val = path.get(transformation._mapped_axes()[0], None)
-                            path.pop(transformation._mapped_axes()[0], None)
-                        if first_val is not None and second_val is not None:
-                            unmapped_idx = transformation.unmap(first_val, second_val)
-                            unmapped_path[transformation.old_axis] = unmapped_idx
-            return (path, unmapped_path)
-
-        def remap_to_requested(path, unmapped_path):
-            return (path, unmapped_path)
+                    if cls.name == transform._mapped_axes()[0]:
+                        unwanted_val = key_value_path[transform._mapped_axes()[0]]
+                        unwanted_path[cls.name] = unwanted_val
+                    if cls.name == transform._mapped_axes()[1]:
+                        first_val = unwanted_path[transform._mapped_axes()[0]]
+                        unmapped_idx = transform.unmap(first_val, value)
+                        leaf_path.pop(transform._mapped_axes()[0], None)
+                        key_value_path.pop(cls.name)
+                        key_value_path[transform.old_axis] = unmapped_idx
+            return (key_value_path, leaf_path, unwanted_path)

        def find_indices_between(index_ranges, low, up, datacube, method=None):
            # TODO: add method for snapping
@@ -288,21 +271,23 @@ def find_indices_between(index_ranges, low, up, datacube, method=None):
                    indexes_between = idxs[start:end]
                    indexes_between_ranges.append(indexes_between)
                else:
-                    indexes_between = [i for i in idxs if low <= i <= up]
+                    axis_reversed = transform._axis_reversed[cls.name]
+                    if not axis_reversed:
+                        lower_idx = bisect.bisect_left(idxs, low)
+                        upper_idx = bisect.bisect_right(idxs, up)
+                        indexes_between = idxs[lower_idx:upper_idx]
+                    else:
+                        # TODO: do the custom bisect
+                        end_idx = bisect_left_cmp(idxs, low, cmp=lambda x, y: x > y) + 1
+                        start_idx = bisect_right_cmp(idxs, up, cmp=lambda x, y: x > y)
+                        indexes_between = idxs[start_idx:end_idx]
                    indexes_between_ranges.append(indexes_between)
            return indexes_between_ranges

-        old_remap = cls.remap
-
-        def remap(range):
-            return old_remap(range)
-
-        cls.remap = remap
        cls.find_indexes = find_indexes
        cls.unmap_to_datacube = unmap_to_datacube
-        cls.remap_to_requested = remap_to_requested
        cls.find_indices_between = find_indices_between
-        cls.unmap_total_path_to_datacube = unmap_total_path_to_datacube
+        cls.unmap_path_key = unmap_path_key

        return cls
@@ -320,20 +305,19 @@ def find_indexes(path, datacube):
                if cls.name == transformation._first_axis:
                    return transformation.merged_values(datacube)

-        old_unmap_total_path_to_datacube = cls.unmap_total_path_to_datacube
+        old_unmap_path_key = cls.unmap_path_key

-        def unmap_total_path_to_datacube(path, unmapped_path):
-            (path, unmapped_path) = old_unmap_total_path_to_datacube(path, unmapped_path)
+        def unmap_path_key(key_value_path, leaf_path, unwanted_path):
+            key_value_path, leaf_path, unwanted_path = old_unmap_path_key(key_value_path, leaf_path, unwanted_path)
+            new_key_value_path = {}
+            value = key_value_path[cls.name]
            for transform in cls.transformations:
                if isinstance(transform, DatacubeAxisMerger):
-                    transformation = transform
-                    if cls.name == transformation._first_axis:
-                        old_val = path.get(cls.name, None)
-                        (first_val, second_val) = transformation.unmerge(old_val)
-                        path.pop(cls.name, None)
-                        path[transformation._first_axis] = first_val
-                        path[transformation._second_axis] = second_val
-            return (path, unmapped_path)
+                    if cls.name == transform._first_axis:
+                        (first_val, second_val) = transform.unmerge(value)
+                        new_key_value_path[transform._first_axis] = first_val
+                        new_key_value_path[transform._second_axis] = second_val
+            return (new_key_value_path, leaf_path, unwanted_path)

        old_unmap_to_datacube = cls.unmap_to_datacube

@@ -350,9 +334,6 @@ def unmap_to_datacube(path, unmapped_path):
                    path[transformation._second_axis] = second_val
            return (path, unmapped_path)

-        def remap_to_requested(path, unmapped_path):
-            return (path, unmapped_path)
-
        def find_indices_between(index_ranges, low, up, datacube, method=None):
            # TODO: add method for snapping
            indexes_between_ranges = []
@@ -369,7 +350,9 @@ def find_indices_between(index_ranges, low, up, datacube, method=None):
                    indexes_between = indexes[start:end]
                    indexes_between_ranges.append(indexes_between)
                else:
-                    indexes_between = [i for i in indexes if low <= i <= up]
+                    lower_idx = bisect.bisect_left(indexes, low)
+                    upper_idx = bisect.bisect_right(indexes, up)
+                    indexes_between = indexes[lower_idx:upper_idx]
                    indexes_between_ranges.append(indexes_between)
            return indexes_between_ranges

@@ -379,9 +362,8 @@ def remap(range):
        cls.remap = remap
        cls.find_indexes = find_indexes
        cls.unmap_to_datacube = unmap_to_datacube
-        cls.remap_to_requested = remap_to_requested
        cls.find_indices_between = find_indices_between
-        cls.unmap_total_path_to_datacube = unmap_total_path_to_datacube
+        cls.unmap_path_key = unmap_path_key

        return cls

@@ -401,12 +383,6 @@ def find_indexes(path, datacube):
                ordered_indices = unordered_indices
            return ordered_indices

-        def unmap_to_datacube(path, unmapped_path):
-            return (path, unmapped_path)
-
-        def remap_to_requested(path, unmapped_path):
-            return (path, unmapped_path)
-
        def find_indices_between(index_ranges, low, up, datacube, method=None):
            # TODO: add method for snapping
            indexes_between_ranges = []
@@ -441,7 +417,9 @@ def find_indices_between(index_ranges, low, up, datacube, method=None):
                    indexes_between = indexes[start:end]
                    indexes_between_ranges.append(indexes_between)
                else:
-                    indexes_between = [i for i in indexes if low <= i <= up]
+                    lower_idx = bisect.bisect_left(indexes, low)
+                    upper_idx = bisect.bisect_right(indexes, up)
+                    indexes_between = indexes[lower_idx:upper_idx]
                    indexes_between_ranges.append(indexes_between)
            return indexes_between_ranges
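The recurring change in these find_indices_between bodies swaps a linear scan for two binary searches on the sorted index list. Both produce the same slice; the bisect version simply avoids touching every element:

    import bisect

    indexes = [1, 3, 5, 7, 9, 11]
    low, up = 4, 9

    linear = [i for i in indexes if low <= i <= up]
    fast = indexes[bisect.bisect_left(indexes, low):bisect.bisect_right(indexes, up)]
    assert linear == fast == [5, 7, 9]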
@@ -450,8 +428,6 @@ def remap(range):

        cls.remap = remap
        cls.find_indexes = find_indexes
-        cls.unmap_to_datacube = unmap_to_datacube
-        cls.remap_to_requested = remap_to_requested
        cls.find_indices_between = find_indices_between

        return cls
@@ -471,20 +447,17 @@ def find_indexes(path, datacube):
            original_vals = old_find_indexes(path, datacube)
            return transformation.change_val_type(cls.name, original_vals)

-        old_unmap_total_path_to_datacube = cls.unmap_total_path_to_datacube
+        old_unmap_path_key = cls.unmap_path_key

-        def unmap_total_path_to_datacube(path, unmapped_path):
-            (path, unmapped_path) = old_unmap_total_path_to_datacube(path, unmapped_path)
+        def unmap_path_key(key_value_path, leaf_path, unwanted_path):
+            key_value_path, leaf_path, unwanted_path = old_unmap_path_key(key_value_path, leaf_path, unwanted_path)
+            value = key_value_path[cls.name]
            for transform in cls.transformations:
                if isinstance(transform, DatacubeAxisTypeChange):
-                    transformation = transform
-                    if cls.name == transformation.name:
-                        changed_val = path.get(cls.name, None)
-                        unchanged_val = transformation.make_str(changed_val)
-                        if cls.name in path:
-                            path.pop(cls.name, None)
-                        unmapped_path[cls.name] = unchanged_val
-            return (path, unmapped_path)
+                    if cls.name == transform.name:
+                        unchanged_val = transform.make_str(value)
+                        key_value_path[cls.name] = unchanged_val
+            return (key_value_path, leaf_path, unwanted_path)

        def unmap_to_datacube(path, unmapped_path):
            for transform in cls.transformations:
@@ -498,9 +471,6 @@ def unmap_to_datacube(path, unmapped_path):
                        unmapped_path[cls.name] = unchanged_val
            return (path, unmapped_path)

-        def remap_to_requested(path, unmapped_path):
-            return (path, unmapped_path)
-
        def find_indices_between(index_ranges, low, up, datacube, method=None):
            # TODO: add method for snapping
            indexes_between_ranges = []
@@ -517,7 +487,9 @@ def find_indices_between(index_ranges, low, up, datacube, method=None):
                    indexes_between = indexes[start:end]
                    indexes_between_ranges.append(indexes_between)
                else:
-                    indexes_between = [i for i in indexes if low <= i <= up]
+                    lower_idx = bisect.bisect_left(indexes, low)
+                    upper_idx = bisect.bisect_right(indexes, up)
+                    indexes_between = indexes[lower_idx:upper_idx]
                    indexes_between_ranges.append(indexes_between)
            return indexes_between_ranges

@@ -527,9 +499,32 @@ def remap(range):
        cls.remap = remap
        cls.find_indexes = find_indexes
        cls.unmap_to_datacube = unmap_to_datacube
-        cls.remap_to_requested = remap_to_requested
        cls.find_indices_between = find_indices_between
-        cls.unmap_total_path_to_datacube = unmap_total_path_to_datacube
+        cls.unmap_path_key = unmap_path_key
+
+    return cls
+
+
+def null(cls):
+    if cls.type_change:
+        old_find_indexes = cls.find_indexes
+
+        def find_indexes(path, datacube):
+            return old_find_indexes(path, datacube)
+
+        def find_indices_between(index_ranges, low, up, datacube, method=None):
+            indexes_between_ranges = []
+            for indexes in index_ranges:
+                indexes_between = [i for i in indexes if low <= i <= up]
+                indexes_between_ranges.append(indexes_between)
+            return indexes_between_ranges
+
+        def remap(range):
+            return [range]
+
+        cls.remap = remap
+        cls.find_indexes = find_indexes
+        cls.find_indices_between = find_indices_between

    return cls
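All of the decorators in this file share one mechanism: capture the class's current method in a closure, define a wrapper, and re-assign the class attribute, so decorators applied later wrap the behaviour installed by earlier ones. A toy illustration of the pattern (names hypothetical, not polytope's API):

    def shouty(cls):
        old_greet = cls.greet  # capture the current implementation

        def greet(self):
            return old_greet(self).upper()  # wrap it

        cls.greet = greet  # re-assign on the class
        return cls

    @shouty
    class Greeter:
        def greet(self):
            return "hello"

    assert Greeter().greet() == "HELLO"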
@@ -585,11 +580,8 @@ def find_indexes(self, path, datacube):
    def offset(self, value):
        return 0

-    def unmap_total_path_to_datacube(self, path, unmapped_path):
-        return (path, unmapped_path)
-
-    def remap_to_requeest(path, unmapped_path):
-        return (path, unmapped_path)
+    def unmap_path_key(self, key_value_path, leaf_path, unwanted_path):
+        return (key_value_path, leaf_path, unwanted_path)

    def find_indices_between(self, index_ranges, low, up, datacube, method=None):
        # TODO: add method for snapping
diff --git a/polytope/datacube/index_tree.py b/polytope/datacube/index_tree.py
index b7a37aff5..9054cd0c5 100644
--- a/polytope/datacube/index_tree.py
+++ b/polytope/datacube/index_tree.py
@@ -30,6 +30,7 @@ def __init__(self, axis=root, value=None):
        self._parent = None
        self.result = None
        self.axis = axis
+        self.ancestors = []

    @property
    def leaves(self):
@@ -37,10 +38,30 @@ def leaves(self):
        self._collect_leaf_nodes(leaves)
        return leaves

+    @property
+    def leaves_with_ancestors(self):
+        # TODO: could store ancestors directly in leaves? Change here
+        leaves = []
+        self._collect_leaf_nodes(leaves)
+        return leaves
+
+    def _collect_leaf_nodes_old(self, leaves):
+        if len(self.children) == 0:
+            leaves.append(self)
+        for n in self.children:
+            n._collect_leaf_nodes(leaves)
+
    def _collect_leaf_nodes(self, leaves):
+        # NOTE: leaves_and_ancestors is going to be a list of tuples, where first entry is leaf and second entry is a
+        # list of its ancestors
        if len(self.children) == 0:
            leaves.append(self)
+            self.ancestors.append(self)
        for n in self.children:
+            for ancestor in self.ancestors:
+                n.ancestors.append(ancestor)
+            if self.axis != IndexTree.root:
+                n.ancestors.append(self)
            n._collect_leaf_nodes(leaves)

    def __setitem__(self, key, value):
@@ -164,6 +185,13 @@ def flatten(self):
            path[ancestor.axis.name] = ancestor.value
        return path

+    def flatten_with_ancestors(self):
+        path = DatacubePath()
+        ancestors = self.ancestors
+        for ancestor in ancestors:
+            path[ancestor.axis.name] = ancestor.value
+        return path
+
    def get_ancestors(self):
        ancestors = []
        current_node = self
diff --git a/polytope/datacube/transformations/datacube_cyclic.py b/polytope/datacube/transformations/datacube_cyclic.py
index a48dedf3d..802285c32 100644
--- a/polytope/datacube/transformations/datacube_cyclic.py
+++ b/polytope/datacube/transformations/datacube_cyclic.py
@@ -20,3 +20,6 @@ def change_val_type(self, axis_name, values):

    def blocked_axes(self):
        return []
+
+    def unwanted_axes(self):
+        return []
diff --git a/polytope/datacube/transformations/datacube_mappers.py b/polytope/datacube/transformations/datacube_mappers.py
index bac6bcf92..f759a00bf 100644
--- a/polytope/datacube/transformations/datacube_mappers.py
+++ b/polytope/datacube/transformations/datacube_mappers.py
@@ -1,7 +1,9 @@
+import bisect
 import math
 from copy import deepcopy
 from importlib import import_module
 
+from ...utility.list_tools import bisect_left_cmp, bisect_right_cmp
 from .datacube_transformations import DatacubeAxisTransformation
 
 
@@ -14,6 +16,9 @@ def __init__(self, name, mapper_options):
        self.grid_resolution = mapper_options["resolution"]
        self.grid_axes = mapper_options["axes"]
        self.old_axis = name
+        self._final_transformation = self.generate_final_transformation()
+        self._final_mapped_axes = self._final_transformation._mapped_axes
+        self._axis_reversed = self._final_transformation._axis_reversed

    def generate_final_transformation(self):
        map_type = _type_to_datacube_mapper_lookup[self.grid_type]
@@ -25,9 +30,11 @@ def generate_final_transformation(self):
    def blocked_axes(self):
        return []

+    def unwanted_axes(self):
+        return [self._final_mapped_axes[0]]
+
    def transformation_axes_final(self):
-        final_transformation = self.generate_final_transformation()
-        final_axes = final_transformation._mapped_axes
+        final_axes = self._final_mapped_axes
        return final_axes

    # Needs to also implement its own methods
@@ -38,8 +45,7 @@ def change_val_type(self, axis_name, values):

    def _mapped_axes(self):
        # NOTE: Each of the mapper methods needs to call its sub-mapper method
-        final_transformation = self.generate_final_transformation()
-        final_axes = final_transformation._mapped_axes
+        final_axes = self._final_mapped_axes
        return final_axes

    def _base_axis(self):
@@ -49,24 +55,78 @@ def _resolution(self):
        pass

    def first_axis_vals(self):
-        final_transformation = self.generate_final_transformation()
-        return final_transformation.first_axis_vals()
+        return self._final_transformation.first_axis_vals()

    def second_axis_vals(self, first_val):
-        final_transformation = self.generate_final_transformation()
-        return final_transformation.second_axis_vals(first_val)
+        return self._final_transformation.second_axis_vals(first_val)

    def map_first_axis(self, lower, upper):
-        final_transformation = self.generate_final_transformation()
-        return final_transformation.map_first_axis(lower, upper)
+        return self._final_transformation.map_first_axis(lower, upper)

    def map_second_axis(self, first_val, lower, upper):
-        final_transformation = self.generate_final_transformation()
-        return final_transformation.map_second_axis(first_val, lower, upper)
+        return self._final_transformation.map_second_axis(first_val, lower, upper)
+
+    def find_second_idx(self, first_val, second_val):
+        return self._final_transformation.find_second_idx(first_val, second_val)
+
+    def unmap_first_val_to_start_line_idx(self, first_val):
+        return self._final_transformation.unmap_first_val_to_start_line_idx(first_val)

    def unmap(self, first_val, second_val):
-        final_transformation = self.generate_final_transformation()
-        return final_transformation.unmap(first_val, second_val)
+        return self._final_transformation.unmap(first_val, second_val)
+
+
+class RegularGridMapper(DatacubeMapper):
+    def __init__(self, base_axis, mapped_axes, resolution):
+        self._mapped_axes = mapped_axes
+        self._base_axis = base_axis
+        self._resolution = resolution
+        self.deg_increment = 90 / self._resolution
+        self._axis_reversed = {mapped_axes[0]: True, mapped_axes[1]: False}
+        self._first_axis_vals = self.first_axis_vals()
+
+    def first_axis_vals(self):
+        first_ax_vals = [90 - i * self.deg_increment for i in range(2 * self._resolution)]
+        return first_ax_vals
+
+    def map_first_axis(self, lower, upper):
+        axis_lines = self._first_axis_vals
+        return_vals = [val for val in axis_lines if lower <= val <= upper]
+        return return_vals
+
+    def second_axis_vals(self, first_val):
+        second_ax_vals = [i * self.deg_increment for i in range(4 * self._resolution)]
+        return second_ax_vals
+
+    def map_second_axis(self, first_val, lower, upper):
+        axis_lines = self.second_axis_vals(first_val)
+        return_vals = [val for val in axis_lines if lower <= val <= upper]
+        return return_vals
+
+    def axes_idx_to_regular_idx(self, first_idx, second_idx):
+        final_idx = first_idx * 4 * self._resolution + second_idx
+        return final_idx
+
+    def find_second_idx(self, first_val, second_val):
+        tol = 1e-10
+        second_axis_vals = self.second_axis_vals(first_val)
+        second_idx = bisect.bisect_left(second_axis_vals, second_val - tol)
+        return second_idx
+
+    def unmap_first_val_to_start_line_idx(self, first_val):
+        tol = 1e-8
+        first_val = [i for i in self._first_axis_vals if first_val - tol <= i <= first_val + tol][0]
+        first_idx = self._first_axis_vals.index(first_val)
+        return first_idx * 4 * self._resolution
+
+    def unmap(self, first_val, second_val):
+        tol = 1e-8
+        first_val = [i for i in self._first_axis_vals if first_val - tol <= i <= first_val + tol][0]
+        first_idx = self._first_axis_vals.index(first_val)
+        second_val = [i for i in self.second_axis_vals(first_val) if second_val - tol <= i <= second_val + tol][0]
+        second_idx = self.second_axis_vals(first_val).index(second_val)
+        final_index = self.axes_idx_to_regular_idx(first_idx, second_idx)
+        return final_index
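RegularGridMapper linearises a (latitude, longitude) pair with row-major arithmetic: 2·resolution latitude rows from 90° downwards, each carrying 4·resolution longitudes, giving first_idx * 4 * resolution + second_idx. A small worked check, assuming resolution 10 (9° spacing):

    resolution = 10
    deg_increment = 90 / resolution  # 9 degrees
    lats = [90 - i * deg_increment for i in range(2 * resolution)]
    lons = [i * deg_increment for i in range(4 * resolution)]

    first_idx = lats.index(81.0)   # second latitude line
    second_idx = lons.index(18.0)  # third longitude line
    flat = first_idx * 4 * resolution + second_idx
    assert (first_idx, second_idx, flat) == (1, 2, 42)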
 
 
 class HealpixGridMapper(DatacubeMapper):
@@ -74,6 +134,8 @@ def __init__(self, base_axis, mapped_axes, resolution):
        self._mapped_axes = mapped_axes
        self._base_axis = base_axis
        self._resolution = resolution
+        self._axis_reversed = {mapped_axes[0]: True, mapped_axes[1]: False}
+        self._first_axis_vals = self.first_axis_vals()

    def first_axis_vals(self):
        rad2deg = 180 / math.pi
@@ -91,18 +153,17 @@ def first_axis_vals(self):
            vals[4 * self._resolution - 1 - i] = -val
        # Equator
        vals[2 * self._resolution - 1] = 0
-
        return vals

    def map_first_axis(self, lower, upper):
-        axis_lines = self.first_axis_vals()
+        axis_lines = self._first_axis_vals
        return_vals = [val for val in axis_lines if lower <= val <= upper]
        return return_vals

    def second_axis_vals(self, first_val):
        tol = 1e-8
-        first_val = [i for i in self.first_axis_vals() if first_val - tol <= i <= first_val + tol][0]
-        idx = self.first_axis_vals().index(first_val)
+        first_val = [i for i in self._first_axis_vals if first_val - tol <= i <= first_val + tol][0]
+        idx = self._first_axis_vals.index(first_val)

        # Polar caps
        if idx < self._resolution - 1 or 3 * self._resolution - 1 < idx <= 4 * self._resolution - 2:
@@ -114,6 +175,8 @@ def second_axis_vals(self, first_val):
        if self._resolution - 1 <= idx < 2 * self._resolution - 1 or 2 * self._resolution <= idx < 3 * self._resolution:
            r_start = start * (2 - (((idx + 1) - self._resolution + 1) % 2))
            vals = [r_start + i * (360 / (4 * self._resolution)) for i in range(4 * self._resolution)]
+            if vals[-1] == 360:
+                vals[-1] = 0
            return vals
        # Equator
        temp_val = 1 if self._resolution % 2 else 0
@@ -134,22 +197,51 @@ def axes_idx_to_healpix_idx(self, first_idx, second_idx):
                idx += 4 * (i + 1)
            else:
                idx += second_idx
+                return idx
        for i in range(self._resolution - 1, 3 * self._resolution):
            if i != first_idx:
                idx += 4 * self._resolution
            else:
                idx += second_idx
+                return idx
        for i in range(3 * self._resolution, 4 * self._resolution - 1):
            if i != first_idx:
                idx += 4 * (4 * self._resolution - 1 - i + 1)
            else:
                idx += second_idx
-        return idx
+                return idx
+
+    def find_second_idx(self, first_val, second_val):
+        tol = 1e-10
+        second_axis_vals = self.second_axis_vals(first_val)
+        second_idx = bisect.bisect_left(second_axis_vals, second_val - tol)
+        return second_idx
+
+    def unmap_first_val_to_start_line_idx(self, first_val):
+        tol = 1e-8
+        first_val = [i for i in self._first_axis_vals if first_val - tol <= i <= first_val + tol][0]
+        first_idx = self._first_axis_vals.index(first_val)
+        idx = 0
+        for i in range(self._resolution - 1):
+            if i != first_idx:
+                idx += 4 * (i + 1)
+            else:
+                return idx
+        for i in range(self._resolution - 1, 3 * self._resolution):
+            if i != first_idx:
+                idx += 4 * self._resolution
+            else:
+                return idx
+        for i in range(3 * self._resolution, 4 * self._resolution - 1):
+            if i != first_idx:
+                idx += 4 * (4 * self._resolution - 1 - i + 1)
+            else:
+                return idx

    def unmap(self, first_val, second_val):
        tol = 1e-8
-        first_val = [i for i in self.first_axis_vals() if first_val - tol <= i <= first_val + tol][0]
-        first_idx = self.first_axis_vals().index(first_val)
+        first_val = [i for i in self._first_axis_vals if first_val - tol <= i <= first_val + tol][0]
+        first_idx = self._first_axis_vals.index(first_val)
        second_val = [i for i in self.second_axis_vals(first_val) if second_val - tol <= i <= second_val + tol][0]
        second_idx = self.second_axis_vals(first_val).index(second_val)
        healpix_index = self.axes_idx_to_healpix_idx(first_idx, second_idx)
@@ -161,6 +253,10 @@ def __init__(self, base_axis, mapped_axes, resolution):
        self._mapped_axes = mapped_axes
        self._base_axis = base_axis
        self._resolution = resolution
+        self._first_axis_vals = self.first_axis_vals()
+        self._first_idx_map = self.create_first_idx_map()
+        self._second_axis_spacing = {}
+        self._axis_reversed = {mapped_axes[0]: True, mapped_axes[1]: False}

    def gauss_first_guess(self):
        i = 0
@@ -228,6 +324,8 @@ def gauss_first_guess(self):

    def get_precomputed_values_N1280(self):
        lats = [0] * 2560
+        # lats = SortedList()
+        # lats = {}
        lats[0] = 89.946187715665616
        lats[1] = 89.876478353332288
        lats[2] = 89.806357319542244
@@ -2819,57 +2917,87 @@ def first_axis_vals(self):
        return new_vals

    def map_first_axis(self, lower, upper):
-        axis_lines = self.first_axis_vals()
-        return_vals = [val for val in axis_lines if lower <= val <= upper]
+        axis_lines = self._first_axis_vals
+        end_idx = bisect_left_cmp(axis_lines, lower, cmp=lambda x, y: x > y) + 1
+        start_idx = bisect_right_cmp(axis_lines, upper, cmp=lambda x, y: x > y)
+        return_vals = axis_lines[start_idx:end_idx]
        return return_vals

    def second_axis_vals(self, first_val):
-        first_axis_vals = self.first_axis_vals()
+        first_axis_vals = self._first_axis_vals
        tol = 1e-10
-        first_val = [val for val in first_axis_vals if first_val - tol < val < first_val + tol][0]
-        first_idx = first_axis_vals.index(first_val)
+        first_idx = bisect_left_cmp(first_axis_vals, first_val - tol, cmp=lambda x, y: x > y)
        if first_idx >= self._resolution:
            first_idx = (2 * self._resolution) - 1 - first_idx
        first_idx = first_idx + 1
        npoints = 4 * first_idx + 16
        second_axis_spacing = 360 / npoints
-        second_axis_start = 0
-        second_axis_vals = [second_axis_start + i * second_axis_spacing for i in range(int(npoints))]
+        second_axis_vals = [i * second_axis_spacing for i in range(npoints)]
        return second_axis_vals

+    def second_axis_spacing(self, first_val):
+        first_axis_vals = self._first_axis_vals
+        tol = 1e-10
+        _first_idx = bisect_left_cmp(first_axis_vals, first_val - tol, cmp=lambda x, y: x > y)
+        first_idx = _first_idx
+        if first_idx >= self._resolution:
+            first_idx = (2 * self._resolution) - 1 - first_idx
+        first_idx = first_idx + 1
+        npoints = 4 * first_idx + 16
+        second_axis_spacing = 360 / npoints
+        return (second_axis_spacing, _first_idx + 1)
+
    def map_second_axis(self, first_val, lower, upper):
-        second_axis_vals = self.second_axis_vals(first_val)
-        return_vals = [val for val in second_axis_vals if lower <= val <= upper]
+        second_axis_spacing, first_idx = self.second_axis_spacing(first_val)
+        start_idx = int(lower / second_axis_spacing)
+        end_idx = int(upper / second_axis_spacing) + 1
+        return_vals = [i * second_axis_spacing for i in range(start_idx, end_idx)]
        return return_vals
 
    def axes_idx_to_octahedral_idx(self, first_idx, second_idx):
-        octa_idx = 0
-        if first_idx == 1:
-            octa_idx = second_idx
-        else:
-            for i in range(first_idx - 1):
-                if i <= self._resolution - 1:
-                    octa_idx += 20 + 4 * i
-                else:
-                    i = i - self._resolution + 1
-                    if i == 1:
-                        octa_idx += 16 + 4 * self._resolution
-                    else:
-                        i = i - 1
-                        octa_idx += 16 + 4 * (self._resolution - i)
-            octa_idx += second_idx
+        # NOTE: for now this takes ~2e-4s per point, so taking significant time -> for 20k points, takes 4s
+        # Would it be better to store a dictionary of first_idx with cumulative number of points on that idx?
+        # Because this is what we are doing here, but we are calculating for each point...
+        # But then this would only work for special grid resolutions, so need to do like a O1280 version of this
+
+        # NOTE: OR somehow cache this for a given first_idx and then only modify the axis idx for second_idx when the
+        # first_idx changes
+        octa_idx = self._first_idx_map[first_idx - 1] + second_idx
        return octa_idx

+    def create_first_idx_map(self):
+        first_idx_list = {}
+        idx = 0
+        for i in range(2 * self._resolution):
+            first_idx_list[i] = idx
+            if i <= self._resolution - 1:
+                idx += 20 + 4 * i
+            else:
+                i = i - self._resolution + 1
+                if i == 1:
+                    idx += 16 + 4 * self._resolution
+                else:
+                    i = i - 1
+                    idx += 16 + 4 * (self._resolution - i)
+        return first_idx_list
+
+    def find_second_axis_idx(self, first_val, second_val):
+        (second_axis_spacing, first_idx) = self.second_axis_spacing(first_val)
+        tol = 1e-8
+        if second_val / second_axis_spacing > int(second_val / second_axis_spacing) + 1 - tol:
+            second_idx = int(second_val / second_axis_spacing) + 1
+        else:
+            second_idx = int(second_val / second_axis_spacing)
+        return (first_idx, second_idx)
+
    def unmap(self, first_val, second_val):
-        first_axis_vals = self.first_axis_vals()
-        tol = 1e-10
-        first_val = [val for val in first_axis_vals if first_val - tol < val < first_val + tol][0]
-        first_idx = first_axis_vals.index(first_val) + 1
-        second_axis_vals = self.second_axis_vals(first_val)
-        second_val = [val for val in second_axis_vals if second_val - tol < val < second_val + tol][0]
-        second_idx = second_axis_vals.index(second_val)
+        (first_idx, second_idx) = self.find_second_axis_idx(first_val, second_val)
        octahedral_index = self.axes_idx_to_octahedral_idx(first_idx, second_idx)
        return octahedral_index


-_type_to_datacube_mapper_lookup = {"octahedral": "OctahedralGridMapper", "healpix": "HealpixGridMapper"}
+_type_to_datacube_mapper_lookup = {
+    "octahedral": "OctahedralGridMapper",
+    "healpix": "HealpixGridMapper",
+    "regular": "RegularGridMapper",
+}
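create_first_idx_map above is a prefix sum over per-ring point counts: the cumulative number of points before each latitude line is computed once in __init__, so unmap becomes a dictionary lookup plus an offset instead of the per-point loop the NOTE measures at ~2e-4 s. The same idea in isolation, for an assumed generic list of ring sizes:

    import itertools

    ring_sizes = [20, 24, 28]  # points on each ring, outermost first
    # cumulative count of points before each ring: [0, 20, 44]
    first_idx_map = [0] + list(itertools.accumulate(ring_sizes))[:-1]

    def flat_index(ring_idx, point_idx):
        return first_idx_map[ring_idx] + point_idx

    assert flat_index(2, 3) == 47  # 20 + 24 points skipped, then the 4th point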
diff --git a/polytope/datacube/transformations/datacube_merger.py b/polytope/datacube/transformations/datacube_merger.py
index eb404498d..d60278671 100644
--- a/polytope/datacube/transformations/datacube_merger.py
+++ b/polytope/datacube/transformations/datacube_merger.py
@@ -15,16 +15,24 @@ def __init__(self, name, merge_options):
    def blocked_axes(self):
        return [self._second_axis]

+    def unwanted_axes(self):
+        return []
+
+    def _mapped_axes(self):
+        return self._first_axis
+
    def merged_values(self, datacube):
        first_ax_vals = datacube.ax_vals(self.name)
        second_ax_name = self._second_axis
        second_ax_vals = datacube.ax_vals(second_ax_name)
        linkers = self._linkers
        merged_values = []
-        for first_val in first_ax_vals:
-            for second_val in second_ax_vals:
+        for i in range(len(first_ax_vals)):
+            first_val = first_ax_vals[i]
+            for j in range(len(second_ax_vals)):
+                second_val = second_ax_vals[j]
                # TODO: check that the first and second val are strings
-                val_to_add = pd.to_datetime(first_val + linkers[0] + second_val + linkers[1])
+                val_to_add = pd.to_datetime("".join([first_val, linkers[0], second_val, linkers[1]]))
                val_to_add = val_to_add.to_numpy()
                val_to_add = val_to_add.astype("datetime64[s]")
                merged_values.append(val_to_add)
@@ -44,7 +52,12 @@ def unmerge(self, merged_val):
        first_linker_size = len(self._linkers[0])
        second_linked_size = len(self._linkers[1])
        second_val = merged_val[first_idx + first_linker_size : -second_linked_size]
+
+        # TODO: maybe replacing like this is too specific to time/dates?
+        first_val = str(first_val).replace("-", "")
+        second_val = second_val.replace(":", "")
        return (first_val, second_val)

    def change_val_type(self, axis_name, values):
-        return values
+        new_values = pd.to_datetime(values)
+        return new_values
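The merger now assembles each timestamp from string parts and, when unmerging, strips the separators pandas re-introduces ("-" in dates, ":" in times). A round-trip sketch using the linkers from the performance scripts above (" " and "00"):

    import pandas as pd

    first_val, linkers, second_val = "20230625", [" ", "00"], "1200"
    merged = pd.to_datetime("".join([first_val, linkers[0], second_val, linkers[1]]))
    assert merged == pd.Timestamp("2023-06-25 12:00:00")

    # unmerging: split on the linker and strip pandas' separators again
    date_part, time_part = str(merged).split(" ")
    assert date_part.replace("-", "") == first_val
    assert time_part.replace(":", "")[:4] == second_val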
diff --git a/polytope/datacube/transformations/datacube_null_transformation.py b/polytope/datacube/transformations/datacube_null_transformation.py
new file mode 100644
index 000000000..55c94277e
--- /dev/null
+++ b/polytope/datacube/transformations/datacube_null_transformation.py
@@ -0,0 +1,22 @@
+from .datacube_transformations import DatacubeAxisTransformation
+
+
+class DatacubeNullTransformation(DatacubeAxisTransformation):
+    def __init__(self, name, mapper_options):
+        self.name = name
+        self.transformation_options = mapper_options
+
+    def generate_final_transformation(self):
+        return self
+
+    def transformation_axes_final(self):
+        return [self.name]
+
+    def change_val_type(self, axis_name, values):
+        return values
+
+    def blocked_axes(self):
+        return []
+
+    def unwanted_axes(self):
+        return []
diff --git a/polytope/datacube/transformations/datacube_reverse.py b/polytope/datacube/transformations/datacube_reverse.py
index bda03c4a0..6a556907a 100644
--- a/polytope/datacube/transformations/datacube_reverse.py
+++ b/polytope/datacube/transformations/datacube_reverse.py
@@ -17,3 +17,6 @@ def change_val_type(self, axis_name, values):

    def blocked_axes(self):
        return []
+
+    def unwanted_axes(self):
+        return []
diff --git a/polytope/datacube/transformations/datacube_transformations.py b/polytope/datacube/transformations/datacube_transformations.py
index f4ce357a9..900ad16b6 100644
--- a/polytope/datacube/transformations/datacube_transformations.py
+++ b/polytope/datacube/transformations/datacube_transformations.py
@@ -50,6 +50,7 @@ def change_val_type(self, axis_name, values):
    "merge": "DatacubeAxisMerger",
    "reverse": "DatacubeAxisReverse",
    "type_change": "DatacubeAxisTypeChange",
+    "null": "DatacubeNullTransformation",
 }

 _type_to_transformation_file_lookup = {
@@ -58,6 +59,7 @@ def change_val_type(self, axis_name, values):
    "merge": "merger",
    "reverse": "reverse",
    "type_change": "type_change",
+    "null": "null_transformation",
 }

 has_transform = {
@@ -66,4 +68,5 @@ def change_val_type(self, axis_name, values):
    "merge": "has_merger",
    "reverse": "reorder",
    "type_change": "type_change",
+    "null": "null",
 }
diff --git a/polytope/datacube/transformations/datacube_type_change.py b/polytope/datacube/transformations/datacube_type_change.py
index e3e8a682c..cdc046b76 100644
--- a/polytope/datacube/transformations/datacube_type_change.py
+++ b/polytope/datacube/transformations/datacube_type_change.py
@@ -11,6 +11,7 @@ def __init__(self, name, type_options):
        self.name = name
        self.transformation_options = type_options
        self.new_type = type_options
+        self._final_transformation = self.generate_final_transformation()

    def generate_final_transformation(self):
        map_type = _type_to_datacube_type_change_lookup[self.new_type]
@@ -20,20 +21,22 @@ def generate_final_transformation(self):
        return transformation

    def transformation_axes_final(self):
-        final_transformation = self.generate_final_transformation()
-        return [final_transformation.axis_name]
+        return [self._final_transformation.axis_name]

    def change_val_type(self, axis_name, values):
-        transformation = self.generate_final_transformation()
-        return [transformation.transform_type(val) for val in values]
+        return_idx = [self._final_transformation.transform_type(val) for val in values]
+        return_idx.sort()
+        return return_idx

    def make_str(self, value):
-        transformation = self.generate_final_transformation()
-        return transformation.make_str(value)
+        return self._final_transformation.make_str(value)

    def blocked_axes(self):
        return []

+    def unwanted_axes(self):
+        return []
+

 class TypeChangeStrToInt(DatacubeAxisTypeChange):
    def __init__(self, axis_name, new_type):
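TypeChangeStrToInt lets an axis whose raw FDB keys are strings (such as step) be sliced numerically and converted back for the actual request; note that change_val_type now also sorts, since string order differs from numeric order. The contract is a simple round trip, sketched here (the class's method bodies are elided in this diff, so this shows the assumed behaviour):

    raw_steps = ["0", "12", "3"]                     # string keys as stored in FDB
    as_ints = sorted(int(v) for v in raw_steps)      # change_val_type: convert, then sort numerically
    assert as_ints == [0, 3, 12]                     # string sort would have given ["0", "12", "3"]
    assert [str(v) for v in as_ints] == ["0", "3", "12"]  # make_str: back to FDB keys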
diff --git a/polytope/engine/hullslicer.py b/polytope/engine/hullslicer.py
index f39f8b0c0..6b0306ed2 100644
--- a/polytope/engine/hullslicer.py
+++ b/polytope/engine/hullslicer.py
@@ -8,7 +8,7 @@
 from ..datacube.backends.datacube import Datacube, IndexTree
 from ..datacube.datacube_axis import UnsliceableDatacubeAxis
 from ..shapes import ConvexPolytope
-from ..utility.combinatorics import argmax, argmin, group, product, unique
+from ..utility.combinatorics import argmax, argmin, group, tensor_product, unique
 from ..utility.exceptions import UnsliceableShapeError
 from ..utility.geometry import lerp
 from .engine import Engine
@@ -19,7 +19,7 @@ def __init__(self):
        pass

    def _unique_continuous_points(self, p: ConvexPolytope, datacube: Datacube):
-        for i, ax in enumerate(p.axes()):
+        for i, ax in enumerate(p._axes):
            mapper = datacube.get_mapper(ax)
            if isinstance(mapper, UnsliceableDatacubeAxis):
                break
@@ -29,8 +29,8 @@ def _unique_continuous_points(self, p: ConvexPolytope, datacube: Datacube):
        # Remove duplicate points
        unique(p.points)

-    def _build_unsliceable_child(self, polytope, ax, node, datacube, lower, next_nodes):
-        if polytope.axes() != [ax.name]:
+    def _build_unsliceable_child(self, polytope, ax, node, datacube, lower, next_nodes, slice_axis_idx):
+        if polytope._axes != [ax.name]:
            raise UnsliceableShapeError(ax)
        path = node.flatten()
        if datacube.has_index(path, ax, lower):
@@ -42,20 +42,26 @@ def _build_unsliceable_child(self, polytope, ax, node, datacube, lower, next_nod
            # raise a value not found error
            raise ValueError()

-    def _build_sliceable_child(self, polytope, ax, node, datacube, lower, upper, next_nodes):
+    def _build_sliceable_child(self, polytope, ax, node, datacube, lower, upper, next_nodes, slice_axis_idx):
        tol = ax.tol
        lower = ax.from_float(lower - tol)
        upper = ax.from_float(upper + tol)
        flattened = node.flatten()
        method = polytope.method
-        for value in datacube.get_indices(flattened, ax, lower, upper, method):
+        values = datacube.get_indices(flattened, ax, lower, upper, method)
+
+        if len(values) == 0:
+            node.remove_branch()
+
+        for value in values:
            # convert to float for slicing
            fvalue = ax.to_float(value)
-            new_polytope = slice(polytope, ax.name, fvalue)
+            new_polytope = slice(polytope, ax.name, fvalue, slice_axis_idx)
            # store the native type
            remapped_val = value
            if ax.is_cyclic:
-                remapped_val = (ax.remap([value, value])[0][0] + ax.remap([value, value])[0][1]) / 2
+                remapped_val_interm = ax.remap([value, value])[0]
+                remapped_val = (remapped_val_interm[0] + remapped_val_interm[1]) / 2
                remapped_val = round(remapped_val, int(-math.log10(ax.tol)))
            child = node.create_child(ax, remapped_val)
            child["unsliced_polytopes"] = copy(node["unsliced_polytopes"])
@@ -66,13 +72,13 @@ def _build_branch(self, ax, node, datacube, next_nodes):
        for polytope in node["unsliced_polytopes"]:
-            if ax.name in polytope.axes():
-                lower, upper = polytope.extents(ax.name)
+            if ax.name in polytope._axes:
+                lower, upper, slice_axis_idx = polytope.extents(ax.name)
                # here, first check if the axis is an unsliceable axis and directly build node if it is
                if isinstance(ax, UnsliceableDatacubeAxis):
-                    self._build_unsliceable_child(polytope, ax, node, datacube, lower, next_nodes)
+                    self._build_unsliceable_child(polytope, ax, node, datacube, lower, next_nodes, slice_axis_idx)
                else:
-                    self._build_sliceable_child(polytope, ax, node, datacube, lower, upper, next_nodes)
+                    self._build_sliceable_child(polytope, ax, node, datacube, lower, upper, next_nodes, slice_axis_idx)
        del node["unsliced_polytopes"]
@@ -83,7 +89,7 @@ def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]):
        groups, input_axes = group(polytopes)
        datacube.validate(input_axes)
        request = IndexTree()
-        combinations = product(groups)
+        combinations = tensor_product(groups)

        for c in combinations:
            r = IndexTree()
@@ -127,9 +133,7 @@ def _reduce_dimension(intersects, slice_axis_idx):
    return temp_intersects


-def slice(polytope: ConvexPolytope, axis, value):
-    slice_axis_idx = polytope._axes.index(axis)
-
+def slice(polytope: ConvexPolytope, axis, value, slice_axis_idx):
    if len(polytope.points[0]) == 1:
        # Note that in this case, we do not need to do linear interpolation so we can save time
        if value in chain(*polytope.points):
@@ -145,7 +149,7 @@ def slice(polytope: ConvexPolytope, axis, value):
    # Reduce dimension of intersection points, removing slice axis
    intersects = _reduce_dimension(intersects, slice_axis_idx)

-    axes = [ax for ax in polytope.axes() if ax != axis]
+    axes = [ax for ax in polytope._axes if ax != axis]

    if len(intersects) < len(intersects[0]) + 1:
        return ConvexPolytope(axes, intersects)
@@ -162,7 +166,7 @@ def slice(polytope: ConvexPolytope, axis, value):
            vertices = hull.vertices

        except scipy.spatial.qhull.QhullError as e:
-            if "input is less than" or "simplex is flat" in str(e):
+            if "less than" in str(e) or "flat" in str(e):
                return ConvexPolytope(axes, intersects)
    # Sliced result is simply the convex hull
    return ConvexPolytope(axes, [intersects[i] for i in vertices])
diff --git a/polytope/shapes.py b/polytope/shapes.py
index cc77273a7..5653f521e 100644
--- a/polytope/shapes.py
+++ b/polytope/shapes.py
@@ -33,7 +33,7 @@ def extents(self, axis):
        axis_values = [point[slice_axis_idx] for point in self.points]
        lower = min(axis_values)
        upper = max(axis_values)
-        return (lower, upper)
+        return (lower, upper, slice_axis_idx)

    def __str__(self):
        return f"Polytope in {self.axes} with points {self.points}"
@@ -143,7 +143,6 @@ def __init__(self, axes, lower_corner=None, upper_corner=None):
                if i >> d & 1:
                    vertex[d] = upper_corner[d]
            self.vertices.append(vertex)
-        assert lower_corner in self.vertices
        assert upper_corner in self.vertices
        assert len(self.vertices) == 2**dimension
diff --git a/polytope/utility/combinatorics.py b/polytope/utility/combinatorics.py
index 1a4a24a5b..9dc641084 100644
--- a/polytope/utility/combinatorics.py
+++ b/polytope/utility/combinatorics.py
@@ -18,7 +18,7 @@ def group(polytopes: List[ConvexPolytope]):
    return groups, concatenation


-def product(groups):
+def tensor_product(groups):
    # Compute the tensor product of polytope groups
    return list(itertools.product(*groups.values()))
diff --git a/polytope/utility/list_tools.py b/polytope/utility/list_tools.py
new file mode 100644
index 000000000..2d18917c1
--- /dev/null
+++ b/polytope/utility/list_tools.py
@@ -0,0 +1,22 @@
+def bisect_left_cmp(arr, val, cmp):
+    left = -1
+    r = len(arr)
+    while r - left > 1:
+        e = (left + r) >> 1
+        if cmp(arr[e], val):
+            left = e
+        else:
+            r = e
+    return left
+
+
+def bisect_right_cmp(arr, val, cmp):
+    left = -1
+    r = len(arr)
+    while r - left > 1:
+        e = (left + r) >> 1
+        if cmp(arr[e], val):
+            left = e
+        else:
+            r = e
+    return r
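bisect_left_cmp and bisect_right_cmp generalise stdlib bisection to any total order expressed as a comparator, which is what the reversed-axis branches earlier in this patch need: latitudes are stored in descending order, so cmp=lambda x, y: x > y treats "greater" as "comes first". A usage check mirroring the find_indices_between call sites:

    from polytope.utility.list_tools import bisect_left_cmp, bisect_right_cmp

    descending = [90, 60, 30, 0, -30, -60, -90]
    low, up = -45, 45

    # same index arithmetic as the reversed-axis branch of find_indices_between
    end_idx = bisect_left_cmp(descending, low, cmp=lambda x, y: x > y) + 1
    start_idx = bisect_right_cmp(descending, up, cmp=lambda x, y: x > y)
    assert descending[start_idx:end_idx] == [30, 0, -30]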
= e + return left + + +def bisect_right_cmp(arr, val, cmp): + left = -1 + r = len(arr) + while r - left > 1: + e = (left + r) >> 1 + if cmp(arr[e], val): + left = e + else: + r = e + return r diff --git a/requirements_example.txt b/requirements_example.txt new file mode 100644 index 000000000..19717dacf --- /dev/null +++ b/requirements_example.txt @@ -0,0 +1 @@ +cfgrib==0.9.10.3 \ No newline at end of file diff --git a/tests/.DS_Store b/tests/.DS_Store index a3976a993..28aa9c59f 100644 Binary files a/tests/.DS_Store and b/tests/.DS_Store differ diff --git a/tests/test_combinatorics.py b/tests/test_combinatorics.py index fab2fb2ca..980ad8e2c 100644 --- a/tests/test_combinatorics.py +++ b/tests/test_combinatorics.py @@ -2,7 +2,7 @@ import pytest from polytope import ConvexPolytope -from polytope.utility.combinatorics import group, product, validate_axes +from polytope.utility.combinatorics import group, tensor_product, validate_axes from polytope.utility.exceptions import ( AxisNotFoundError, AxisOverdefinedError, @@ -28,7 +28,7 @@ def test_group_and_product(self): assert len(groups[("a", "b")]) == 4 assert len(all_axes) == 4 - combinations = product(groups) + combinations = tensor_product(groups) assert len(combinations) == 4 for c in combinations: diff --git a/tests/test_datacube_axes_init.py b/tests/test_datacube_axes_init.py index 9089de0c7..6f3d59bbf 100644 --- a/tests/test_datacube_axes_init.py +++ b/tests/test_datacube_axes_init.py @@ -23,7 +23,8 @@ def setup_method(self, method): "transformation": { "mapper": {"type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]} } - } + }, + # "latitude": {"transformation": {"reverse": {True}}}, } self.slicer = HullSlicer() self.API = Polytope(datacube=latlon_array, engine=self.slicer, axis_options=self.options) @@ -35,7 +36,7 @@ def test_created_axes(self): assert self.datacube._axes["longitude"].has_mapper assert isinstance(self.datacube._axes["longitude"], FloatDatacubeAxis) assert not ("values" in self.datacube._axes.keys()) - assert self.datacube._axes["latitude"].find_indexes({}, self.datacube)[:5] == [ + assert list(self.datacube._axes["latitude"].find_indexes({}, self.datacube)[:5]) == [ 89.94618771566562, 89.87647835333229, 89.80635731954224, diff --git a/tests/test_fdb_datacube.py b/tests/test_fdb_datacube.py index 7d364864e..ff56f6709 100644 --- a/tests/test_fdb_datacube.py +++ b/tests/test_fdb_datacube.py @@ -1,11 +1,14 @@ import pandas as pd import pytest -from polytope.datacube.backends.FDB_datacube import FDBDatacube +from polytope.datacube.backends.fdb import FDBDatacube from polytope.engine.hullslicer import HullSlicer from polytope.polytope import Polytope, Request from polytope.shapes import Box, Select +# import geopandas as gpd +# import matplotlib.pyplot as plt + class TestSlicingFDBDatacube: def setup_method(self, method): @@ -19,7 +22,7 @@ def setup_method(self, method): "date": {"transformation": {"merge": {"with": "time", "linkers": ["T", "00"]}}}, "step": {"transformation": {"type_change": "int"}}, } - self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "step": 11} + self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "step": 0} self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options) self.slicer = HullSlicer() self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options) @@ -28,17 +31,35 @@ def setup_method(self, method): @pytest.mark.skip(reason="can't install fdb branch on CI") def test_fdb_datacube(self): request = Request( 
- Select("step", [11]), + Select("step", [0]), Select("levtype", ["sfc"]), - Select("date", [pd.Timestamp("20230710T120000")]), + Select("date", [pd.Timestamp("20230625T120000")]), Select("domain", ["g"]), Select("expver", ["0001"]), - Select("param", ["151130"]), + Select("param", ["167"]), Select("class", ["od"]), Select("stream", ["oper"]), - Select("type", ["fc"]), + Select("type", ["an"]), Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), ) result = self.API.retrieve(request) result.pprint() assert len(result.leaves) == 9 + + # lats = [] + # lons = [] + # tol = 1e-8 + # for i in range(len(result.leaves)): + # cubepath = result.leaves[i].flatten() + # lat = cubepath["latitude"] + # lon = cubepath["longitude"] + # lats.append(lat) + # lons.append(lon) + + # worldmap = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres")) + # fig, ax = plt.subplots(figsize=(12, 6)) + # worldmap.plot(color="darkgrey", ax=ax) + + # plt.scatter(lons, lats, s=16, c="red", cmap="YlOrRd") + # plt.colorbar(label="Temperature") + # plt.show() diff --git a/tests/test_healpix_mapper.py b/tests/test_healpix_mapper.py index c80d51042..9014e5e3a 100644 --- a/tests/test_healpix_mapper.py +++ b/tests/test_healpix_mapper.py @@ -20,7 +20,8 @@ def setup_method(self, method): self.options = { "values": { "transformation": {"mapper": {"type": "healpix", "resolution": 32, "axes": ["latitude", "longitude"]}} - } + }, + "longitude": {"transformation": {"cyclic": [0, 360]}}, } self.slicer = HullSlicer() self.API = Polytope(datacube=self.latlon_array, engine=self.slicer, axis_options=self.options) @@ -49,7 +50,7 @@ def find_nearest_latlon(self, grib_file, target_lat, target_lon): return nearest_points @pytest.mark.internet - def test_octahedral_grid(self): + def test_healpix_grid(self): request = Request( Box(["latitude", "longitude"], [-2, -2], [10, 10]), Select("time", ["2022-12-14T12:00:00"]), @@ -58,7 +59,8 @@ def test_octahedral_grid(self): Select("valid_time", ["2022-12-14T13:00:00"]), ) result = self.API.retrieve(request) - assert len(result.leaves) == 35 + result.pprint() + assert len(result.leaves) == 40 lats = [] lons = [] @@ -78,4 +80,4 @@ def test_octahedral_grid(self): assert lat <= eccodes_lat + tol assert eccodes_lon - tol <= lon assert lon <= eccodes_lon + tol - assert len(eccodes_lats) == 35 + assert len(eccodes_lats) == 40 diff --git a/tests/test_hull_slicer.py b/tests/test_hull_slicer.py index 4e9404e23..e733612f7 100644 --- a/tests/test_hull_slicer.py +++ b/tests/test_hull_slicer.py @@ -20,38 +20,40 @@ def construct_nd_cube(self, dimension, lower=-1, upper=1): def test_3D(self): p3 = self.construct_nd_cube(3) print(p3) - p2 = polytope.engine.hullslicer.slice(p3, "c", 0.5) + p2 = polytope.engine.hullslicer.slice(p3, "c", 0.5, 2) print(p2) - p1 = polytope.engine.hullslicer.slice(p2, "b", 0.5) + p1 = polytope.engine.hullslicer.slice(p2, "b", 0.5, 1) print(p1) + @pytest.mark.skip(reason="This is too slow.") def test_4D(self): p = self.construct_nd_cube(4) print(p) while len(p.axes()) > 1: - p = polytope.engine.hullslicer.slice(p, p._axes[-1], 0.5) + p = polytope.engine.hullslicer.slice(p, p._axes[-1], 0.5, -1) print(p) + @pytest.mark.skip(reason="This is too slow.") def test_ND(self): with benchmark("4D"): p = self.construct_nd_cube(4) while len(p.axes()) > 1: - p = polytope.engine.hullslicer.slice(p, p._axes[-1], 0.5) + p = polytope.engine.hullslicer.slice(p, p._axes[-1], 0.5, -1) with benchmark("5D"): p = self.construct_nd_cube(5) while len(p.axes()) > 1: - p = polytope.engine.hullslicer.slice(p, 
p._axes[-1], 0.5) + p = polytope.engine.hullslicer.slice(p, p._axes[-1], 0.5, -1) with benchmark("6D"): p = self.construct_nd_cube(6) while len(p.axes()) > 1: - p = polytope.engine.hullslicer.slice(p, p._axes[-1], 0.5) + p = polytope.engine.hullslicer.slice(p, p._axes[-1], 0.5, -1) with benchmark("7D"): p = self.construct_nd_cube(7) while len(p.axes()) > 1: - p = polytope.engine.hullslicer.slice(p, p._axes[-1], 0.5) + p = polytope.engine.hullslicer.slice(p, p._axes[-1], 0.5, -1) # QHull is not performant above 7D as per its documentation # with benchmark("8D"): diff --git a/tests/test_incomplete_tree_fdb.py b/tests/test_incomplete_tree_fdb.py new file mode 100644 index 000000000..87d55aef1 --- /dev/null +++ b/tests/test_incomplete_tree_fdb.py @@ -0,0 +1,97 @@ +import pandas as pd +import pytest +from eccodes import codes_grib_find_nearest, codes_grib_new_from_file +from helper_functions import download_test_data + +from polytope.datacube.backends.fdb import FDBDatacube +from polytope.engine.hullslicer import HullSlicer +from polytope.polytope import Polytope, Request +from polytope.shapes import Select + + +class TestRegularGrid: + def setup_method(self, method): + nexus_url = "https://get.ecmwf.int/test-data/polytope/test-data/era5-levels-members.grib" + download_test_data(nexus_url, "era5-levels-members.grib") + self.options = { + "values": { + "transformation": {"mapper": {"type": "regular", "resolution": 30, "axes": ["latitude", "longitude"]}} + }, + "date": {"transformation": {"merge": {"with": "time", "linkers": [" ", "00"]}}}, + "step": {"transformation": {"type_change": "int"}}, + "number": {"transformation": {"type_change": "int"}}, + "longitude": {"transformation": {"cyclic": [0, 360]}}, + } + self.config = {"class": "ea", "expver": "0001", "levtype": "pl", "step": 0} + self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options) + self.slicer = HullSlicer() + self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options) + + def find_nearest_latlon(self, grib_file, target_lat, target_lon): + # Open the GRIB file + f = open(grib_file) + + # Load the GRIB messages from the file + messages = [] + while True: + message = codes_grib_new_from_file(f) + if message is None: + break + messages.append(message) + + # Find the nearest grid points + nearest_points = [] + for message in messages: + nearest_index = codes_grib_find_nearest(message, target_lat, target_lon) + nearest_points.append(nearest_index) + + # Close the GRIB file + f.close() + + return nearest_points + + @pytest.mark.internet + @pytest.mark.skip(reason="can't install fdb branch on CI") + def test_incomplete_fdb_branch(self): + request = Request( + Select("step", [0]), + Select("levtype", ["pl"]), + Select("date", [pd.Timestamp("20170102T120000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["129"]), + Select("class", ["ea"]), + Select("stream", ["enda"]), + Select("type", ["an"]), + Select("latitude", [0]), + Select("longitude", [1]), + Select("levelist", ["500"]), + Select("number", ["0"]), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 1 + assert result.is_root() + + @pytest.mark.internet + @pytest.mark.skip(reason="can't install fdb branch on CI") + def test_incomplete_fdb_branch_2(self): + request = Request( + Select("step", [0]), + Select("levtype", ["pl"]), + Select("date", [pd.Timestamp("20170102T120000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["129"]), + 
Select("class", ["ea"]), + Select("stream", ["enda"]), + Select("type", ["an"]), + Select("latitude", [1]), + Select("longitude", [0]), + Select("levelist", ["500"]), + Select("number", ["0"]), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 1 + assert result.is_root() diff --git a/tests/test_merge_cyclic_octahedral.py b/tests/test_merge_cyclic_octahedral.py index ac5d5abc5..6fcf16323 100644 --- a/tests/test_merge_cyclic_octahedral.py +++ b/tests/test_merge_cyclic_octahedral.py @@ -15,14 +15,14 @@ def setup_method(self, method): np.random.randn(1, 1, 4289589, 3), dims=("date", "time", "values", "step"), coords={ - "date": ["2000-01-01"], - "time": ["06:00"], + "date": ["20000101"], + "time": ["0600"], "values": list(range(4289589)), "step": [0, 1, 2], }, ) self.options = { - "date": {"transformation": {"merge": {"with": "time", "linkers": ["T", ":00"]}}}, + "date": {"transformation": {"merge": {"with": "time", "linkers": ["T", "00"]}}}, "values": { "transformation": { "mapper": {"type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]} @@ -34,7 +34,6 @@ def setup_method(self, method): self.slicer = HullSlicer() self.API = Polytope(datacube=self.array, engine=self.slicer, axis_options=self.options) - # @pytest.mark.skip(reason="Need date time to not be strings") def test_merge_axis(self): # NOTE: does not work because the date is a string in the merge option... date = np.datetime64("2000-01-01T06:00:00") diff --git a/tests/test_merge_transformation.py b/tests/test_merge_transformation.py index af3b0d79c..ac3abf5bb 100644 --- a/tests/test_merge_transformation.py +++ b/tests/test_merge_transformation.py @@ -15,17 +15,16 @@ def setup_method(self, method): np.random.randn(1, 1), dims=("date", "time"), coords={ - "date": ["2000-01-01"], - "time": ["06:00"], + "date": ["20000101"], + "time": ["0600"], }, ) - self.options = {"date": {"transformation": {"merge": {"with": "time", "linkers": [" ", ":00"]}}}} + self.options = {"date": {"transformation": {"merge": {"with": "time", "linkers": ["T", "00"]}}}} self.xarraydatacube = XArrayDatacube(self.array) self.slicer = HullSlicer() self.API = Polytope(datacube=self.array, engine=self.slicer, axis_options=self.options) def test_merge_axis(self): - request = Request(Select("date", [pd.Timestamp("2000-01-01T06:00:00")])) + request = Request(Select("date", [pd.Timestamp("20000101T060000")])) result = self.API.retrieve(request) - # assert result.leaves[0].flatten()["date"] == np.datetime64("2000-01-01T06:00:00") assert result.leaves[0].flatten()["date"] == pd.Timestamp("2000-01-01T06:00:00") diff --git a/tests/test_regular_grid.py b/tests/test_regular_grid.py new file mode 100644 index 000000000..1de8a8705 --- /dev/null +++ b/tests/test_regular_grid.py @@ -0,0 +1,104 @@ +import pandas as pd +import pytest +from eccodes import codes_grib_find_nearest, codes_grib_new_from_file +from helper_functions import download_test_data + +from polytope.datacube.backends.fdb import FDBDatacube +from polytope.engine.hullslicer import HullSlicer +from polytope.polytope import Polytope, Request +from polytope.shapes import Disk, Select + +# import geopandas as gpd +# import matplotlib.pyplot as plt + + +class TestRegularGrid: + def setup_method(self, method): + nexus_url = "https://get.ecmwf.int/test-data/polytope/test-data/era5-levels-members.grib" + download_test_data(nexus_url, "era5-levels-members.grib") + self.options = { + "values": { + "transformation": {"mapper": {"type": "regular", "resolution": 30, 
"axes": ["latitude", "longitude"]}} + }, + "date": {"transformation": {"merge": {"with": "time", "linkers": ["T", "00"]}}}, + "step": {"transformation": {"type_change": "int"}}, + "number": {"transformation": {"type_change": "int"}}, + "longitude": {"transformation": {"cyclic": [0, 360]}}, + } + self.config = {"class": "ea", "expver": "0001", "levtype": "pl", "step": 0} + self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options) + self.slicer = HullSlicer() + self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options) + + def find_nearest_latlon(self, grib_file, target_lat, target_lon): + # Open the GRIB file + f = open(grib_file) + + # Load the GRIB messages from the file + messages = [] + while True: + message = codes_grib_new_from_file(f) + if message is None: + break + messages.append(message) + + # Find the nearest grid points + nearest_points = [] + for message in messages: + nearest_index = codes_grib_find_nearest(message, target_lat, target_lon) + nearest_points.append(nearest_index) + + # Close the GRIB file + f.close() + + return nearest_points + + @pytest.mark.internet + @pytest.mark.skip(reason="can't install fdb branch on CI") + def test_regular_grid(self): + request = Request( + Select("step", [0]), + Select("levtype", ["pl"]), + Select("date", [pd.Timestamp("20170102T120000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["129"]), + Select("class", ["ea"]), + Select("stream", ["enda"]), + Select("type", ["an"]), + Disk(["latitude", "longitude"], [0, 0], [3, 3]), + Select("levelist", ["500"]), + Select("number", ["0", "1"]), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 10 + + lats = [] + lons = [] + eccodes_lats = [] + tol = 1e-8 + for i in range(len(result.leaves)): + cubepath = result.leaves[i].flatten() + lat = cubepath["latitude"] + lon = cubepath["longitude"] + lats.append(lat) + lons.append(lon) + nearest_points = self.find_nearest_latlon("./tests/data/era5-levels-members.grib", lat, lon) + eccodes_lat = nearest_points[0][0]["lat"] + eccodes_lon = nearest_points[0][0]["lon"] + eccodes_lats.append(eccodes_lat) + assert eccodes_lat - tol <= lat + assert lat <= eccodes_lat + tol + assert eccodes_lon - tol <= lon + assert lon <= eccodes_lon + tol + + # worldmap = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres")) + # fig, ax = plt.subplots(figsize=(12, 6)) + # worldmap.plot(color="darkgrey", ax=ax) + + # plt.scatter(lons, lats, s=16, c="red", cmap="YlOrRd") + # plt.colorbar(label="Temperature") + # plt.show() + + assert len(eccodes_lats) == 10 diff --git a/tests/test_shapes.py b/tests/test_shapes.py index ba355dd4c..ebc55b527 100644 --- a/tests/test_shapes.py +++ b/tests/test_shapes.py @@ -3,7 +3,7 @@ import pytest import xarray as xr -from polytope.datacube.backends.FDB_datacube import FDBDatacube +from polytope.datacube.backends.fdb import FDBDatacube from polytope.datacube.backends.xarray import XArrayDatacube from polytope.engine.hullslicer import HullSlicer from polytope.polytope import Polytope, Request diff --git a/tests/test_slice_date_range_fdb.py b/tests/test_slice_date_range_fdb.py new file mode 100644 index 000000000..567230687 --- /dev/null +++ b/tests/test_slice_date_range_fdb.py @@ -0,0 +1,44 @@ +import pandas as pd +import pytest + +from polytope.datacube.backends.fdb import FDBDatacube +from polytope.engine.hullslicer import HullSlicer +from polytope.polytope import Polytope, Request +from polytope.shapes import Box, Select, 
Span + + +class TestSlicingFDBDatacube: + def setup_method(self, method): + # Create a dataarray with 3 labelled axes using different index types + self.options = { + "values": { + "transformation": { + "mapper": {"type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]} + } + }, + "date": {"transformation": {"merge": {"with": "time", "linkers": ["T", "00"]}}}, + "step": {"transformation": {"type_change": "int"}}, + } + self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "step": 0} + self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options) + self.slicer = HullSlicer() + self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options) + + # Testing different shapes + @pytest.mark.skip(reason="can't install fdb branch on CI") + def test_fdb_datacube(self): + request = Request( + Select("step", [0]), + Select("levtype", ["sfc"]), + Span("date", pd.Timestamp("20230625T120000"), pd.Timestamp("20230626T120000")), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["an"]), + Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 9
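
The diff above threads a precomputed slice_axis_idx from ConvexPolytope.extents through to hullslicer.slice, rather than recomputing polytope._axes.index(axis) on every slice, and adds comparator-based bisection helpers in polytope/utility/list_tools.py. Below is a minimal usage sketch of both, assuming the two-argument ConvexPolytope(axes, points) constructor used in tests/test_hull_slicer.py; the unit square and the sorted list are illustrative, not taken from the diff.

import operator

import polytope.engine.hullslicer
from polytope import ConvexPolytope
from polytope.utility.list_tools import bisect_left_cmp, bisect_right_cmp

# extents() now returns the axis index alongside the bounds, so slice()
# can reuse it instead of looking the axis up again on every call.
square = ConvexPolytope(["a", "b"], [[0, 0], [0, 1], [1, 0], [1, 1]])
lower, upper, slice_axis_idx = square.extents("b")
assert (lower, upper, slice_axis_idx) == (0, 1, 1)
line = polytope.engine.hullslicer.slice(square, "b", 0.5, slice_axis_idx)
print(line)  # 1D polytope over the remaining axis "a"

# bisect_left_cmp returns the index of the last element satisfying cmp
# (or -1 if none does); bisect_right_cmp returns the first index that fails.
arr = [0, 10, 20, 30, 40]
assert bisect_left_cmp(arr, 25, operator.lt) == 2   # arr[2] = 20 < 25
assert bisect_right_cmp(arr, 25, operator.lt) == 3  # arr[3] = 30 is not < 25

Passing the index explicitly keeps the per-slice work in the hull slicer's hot loop constant, which is consistent with the timing harnesses added under performance/.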