diff --git a/polytope/datacube/backends/fdb.py b/polytope/datacube/backends/fdb.py index 19098624d..301787a4f 100644 --- a/polytope/datacube/backends/fdb.py +++ b/polytope/datacube/backends/fdb.py @@ -25,6 +25,7 @@ def __init__(self, config=None, axis_options=None): self.fake_axes = [] self.unwanted_path = {} self.nearest_search = {} + self.nearest_search = {} partial_request = config # Find values in the level 3 FDB datacube @@ -51,7 +52,14 @@ def __init__(self, config=None, axis_options=None): logging.info("Polytope created axes for: " + str(self._axes.keys())) - def get(self, requests: IndexTree, leaf_path=None): + def get(self, requests: IndexTree): + fdb_requests = [] + fdb_requests_decoding_info = [] + self.get_fdb_requests(requests, fdb_requests, fdb_requests_decoding_info) + output_values = self.fdb.extract(fdb_requests) + self.assign_fdb_output_to_nodes(output_values, fdb_requests_decoding_info) + + def get_fdb_requests(self, requests: IndexTree, fdb_requests=[], fdb_requests_decoding_info=[], leaf_path=None): if leaf_path is None: leaf_path = {} @@ -60,7 +68,7 @@ def get(self, requests: IndexTree, leaf_path=None): logging.info("Looking for data for the tree: " + str([leaf.flatten() for leaf in requests.leaves])) for c in requests.children: - self.get(c) + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info) # If request node has no children, we have a leaf so need to assign fdb values to it else: key_value_path = {requests.axis.name: requests.value} @@ -71,12 +79,23 @@ def get(self, requests: IndexTree, leaf_path=None): leaf_path.update(key_value_path) if len(requests.children[0].children[0].children) == 0: # remap this last key - self.get_2nd_last_values(requests, leaf_path) + # TODO: here, find the fdb_requests and associated nodes to which to add results + + (path, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) = self.get_2nd_last_values( + requests, leaf_path + ) + (original_indices, sorted_request_ranges) = self.sort_fdb_request_ranges( + range_lengths, current_start_idxs, lat_length + ) + fdb_requests.append(tuple((path, sorted_request_ranges))) + fdb_requests_decoding_info.append( + tuple((original_indices, fdb_node_ranges, lat_length, range_lengths, current_start_idxs)) + ) # Otherwise remap the path for this key and iterate again over children else: for c in requests.children: - self.get(c, leaf_path) + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info, leaf_path) def get_2nd_last_values(self, requests, leaf_path=None): if leaf_path is None: @@ -87,6 +106,7 @@ def get_2nd_last_values(self, requests, leaf_path=None): if len(self.nearest_search) != 0: first_ax_name = requests.children[0].axis.name second_ax_name = requests.children[0].children[0].axis.name + # TODO: throw error if first_ax_name or second_ax_name not in self.nearest_search.keys() nearest_pts = [ [lat_val, lon_val] for (lat_val, lon_val) in zip( @@ -113,7 +133,10 @@ def get_2nd_last_values(self, requests, leaf_path=None): lat_child.remove_branch() else: possible_lons = [latlon[1] for latlon in nearest_latlons if latlon[0] == lat_child.value] - for lon_child in lat_child.children: + lon_children_values = [child.value for child in lat_child.children] + for j in range(len(lon_children_values)): + lon_child_val = lon_children_values[j] + lon_child = [child for child in lat_child.children if child.value == lon_child_val][0] if lon_child.value not in possible_lons: lon_child.remove_branch() @@ -139,7 +162,10 @@ def get_2nd_last_values(self, requests, leaf_path=None): (range_lengths[i], current_start_idxs[i], fdb_node_ranges[i]) = self.get_last_layer_before_leaf( lat_child, leaf_path, range_length, current_start_idx, fdb_range_nodes ) - self.give_fdb_val_to_node(leaf_path, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) + # TODO: do we need to return all of this? + leaf_path_copy = deepcopy(leaf_path) + leaf_path_copy.pop("values") + return (leaf_path_copy, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, fdb_range_n): i = 0 @@ -171,27 +197,31 @@ def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, current_idx[i] = current_start_idx return (range_l, current_idx, fdb_range_n) - def give_fdb_val_to_node(self, leaf_path, range_lengths, current_start_idx, fdb_range_nodes, lat_length): - (output_values, original_indices) = self.find_fdb_values( - leaf_path, range_lengths, current_start_idx, lat_length - ) - new_fdb_range_nodes = [] - new_range_lengths = [] - for j in range(lat_length): - for i in range(len(range_lengths[j])): - if current_start_idx[j][i] is not None: - new_fdb_range_nodes.append(fdb_range_nodes[j][i]) - new_range_lengths.append(range_lengths[j][i]) - sorted_fdb_range_nodes = [new_fdb_range_nodes[i] for i in original_indices] - sorted_range_lengths = [new_range_lengths[i] for i in original_indices] - for i in range(len(sorted_fdb_range_nodes)): - for k in range(sorted_range_lengths[i]): - n = sorted_fdb_range_nodes[i][k] - n.result = output_values[0][0][i][0][k] - - def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length): - path.pop("values") - fdb_requests = [] + def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info): + for k in range(len(output_values)): + request_output_values = output_values[k] + ( + original_indices, + fdb_node_ranges, + lat_length, + range_lengths, + current_start_idxs, + ) = fdb_requests_decoding_info[k] + new_fdb_range_nodes = [] + new_range_lengths = [] + for j in range(lat_length): + for i in range(len(range_lengths[j])): + if current_start_idxs[j][i] is not None: + new_fdb_range_nodes.append(fdb_node_ranges[j][i]) + new_range_lengths.append(range_lengths[j][i]) + sorted_fdb_range_nodes = [new_fdb_range_nodes[i] for i in original_indices] + sorted_range_lengths = [new_range_lengths[i] for i in original_indices] + for i in range(len(sorted_fdb_range_nodes)): + for j in range(sorted_range_lengths[i]): + n = sorted_fdb_range_nodes[i][j] + n.result = request_output_values[0][i][0][j] + + def sort_fdb_request_ranges(self, range_lengths, current_start_idx, lat_length): interm_request_ranges = [] for i in range(lat_length): for j in range(len(range_lengths[i])): @@ -201,12 +231,7 @@ def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length): request_ranges_with_idx = list(enumerate(interm_request_ranges)) sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0]) original_indices, sorted_request_ranges = zip(*sorted_list) - fdb_requests.append(tuple((path, sorted_request_ranges))) - print("REQUEST TO FDB") - print(fdb_requests) - output_values = self.fdb.extract(fdb_requests) - print(output_values) - return (output_values, original_indices) + return (original_indices, sorted_request_ranges) def datacube_natural_indexes(self, axis, subarray): indexes = subarray[axis.name] diff --git a/polytope/datacube/index_tree.py b/polytope/datacube/index_tree.py index 7af7f5422..76ab48ab3 100644 --- a/polytope/datacube/index_tree.py +++ b/polytope/datacube/index_tree.py @@ -1,9 +1,10 @@ import json +import logging from typing import OrderedDict from sortedcontainers import SortedList -from .datacube_axis import IntDatacubeAxis +from .datacube_axis import IntDatacubeAxis, UnsliceableDatacubeAxis class DatacubePath(OrderedDict): @@ -85,7 +86,7 @@ def __eq__(self, other): if other.value == self.value: return True else: - if isinstance(self.value, str): + if isinstance(self.axis, UnsliceableDatacubeAxis): return False else: if other.value - 2 * other.axis.tol <= self.value <= other.value + 2 * other.axis.tol: @@ -94,7 +95,6 @@ def __eq__(self, other): return True else: return False - # return (self.axis.name, self.value) == (other.axis.name, other.value) def __lt__(self, other): return (self.axis.name, self.value) < (other.axis.name, other.value) @@ -169,10 +169,12 @@ def intersect(self, other): def pprint(self, level=0): if self.axis.name == "root": - print("\n") - print("\t" * level + "\u21b3" + str(self)) + logging.debug("\n") + logging.debug("\t" * level + "\u21b3" + str(self)) for child in self.children: child.pprint(level + 1) + if len(self.children) == 0: + logging.debug("\t" * (level + 1) + "\u21b3" + str(self.result)) def remove_branch(self): if not self.is_root(): diff --git a/tests/test_point_nearest.py b/tests/test_point_nearest.py index c1ae7dbd8..081050dad 100644 --- a/tests/test_point_nearest.py +++ b/tests/test_point_nearest.py @@ -16,6 +16,7 @@ def setup_method(self, method): "date": {"merge": {"with": "time", "linkers": ["T", "00"]}}, "step": {"type_change": "int"}, "number": {"type_change": "int"}, + "longitude": {"cyclic": [0, 360]}, } self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper"} self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options) @@ -76,3 +77,43 @@ def test_fdb_datacube_true_point_2(self): result = self.API.retrieve(request) result.pprint() assert len(result.leaves) == 3 + + @pytest.mark.fdb + def test_fdb_datacube_true_point_3(self): + request = Request( + Select("step", [21]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20231102T000000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["fc"]), + Point(["latitude", "longitude"], [[0.035149384216, -0.01]], method="nearest"), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 1 + assert result.leaves[0].value == 0 + assert result.leaves[0].axis.name == "longitude" + + @pytest.mark.fdb + def test_fdb_datacube_true_point_4(self): + request = Request( + Select("step", [21]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20231102T000000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["fc"]), + Point(["latitude", "longitude"], [[0.035149384216, 359.97]], method="nearest"), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 1 + assert result.leaves[0].value == 359.929906542056 + assert result.leaves[0].axis.name == "longitude" diff --git a/tests/test_slice_date_range_fdb.py b/tests/test_slice_date_range_fdb.py index f7e407ffa..fbbdff7f2 100644 --- a/tests/test_slice_date_range_fdb.py +++ b/tests/test_slice_date_range_fdb.py @@ -35,6 +35,7 @@ def test_fdb_datacube(self): Select("class", ["od"]), Select("stream", ["oper"]), Select("type", ["an"]), + Select("number", [1]), Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), ) result = self.API.retrieve(request)