From f02c03edd5f5de1dd24bfbb12baf699ff9c37365 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Wed, 17 Jan 2024 17:31:42 +0100 Subject: [PATCH 01/17] start nearest method by finding indices between like for surrounding method --- polytope/datacube/datacube_axis.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/polytope/datacube/datacube_axis.py b/polytope/datacube/datacube_axis.py index 5c78c22f6..1119be943 100644 --- a/polytope/datacube/datacube_axis.py +++ b/polytope/datacube/datacube_axis.py @@ -147,7 +147,7 @@ def find_indices_between(index_ranges, low, up, datacube, method=None): update_range() indexes_between_ranges = [] - if method != "surrounding": + if method != "surrounding" and method != "nearest": return old_find_indices_between(index_ranges, low, up, datacube, method) else: for indexes in index_ranges: @@ -264,7 +264,7 @@ def find_indices_between(index_ranges, low, up, datacube, method=None): transformation = transform if cls.name in transformation._mapped_axes(): for idxs in index_ranges: - if method == "surrounding": + if method == "surrounding" or method == "nearest": axis_reversed = transform._axis_reversed[cls.name] if not axis_reversed: start = bisect.bisect_left(idxs, low) @@ -349,7 +349,7 @@ def find_indices_between(index_ranges, low, up, datacube, method=None): transformation = transform if cls.name in transformation._mapped_axes(): for indexes in index_ranges: - if method == "surrounding": + if method == "surrounding" or method == "nearest": start = indexes.index(low) end = indexes.index(up) start = max(start - 1, 0) @@ -403,7 +403,7 @@ def find_indices_between(index_ranges, low, up, datacube, method=None): # https://pandas.pydata.org/docs/reference/api/pandas.Index.searchsorted.html # Assumes the indexes are already sorted (could sort to be sure) and monotonically # increasing - if method == "surrounding": + if method == "surrounding" or method == "nearest": start = indexes.searchsorted(low, "left") end = indexes.searchsorted(up, "right") start = max(start - 1, 0) @@ -416,7 +416,7 @@ def find_indices_between(index_ranges, low, up, datacube, method=None): indexes_between = indexes[start:end].to_list() indexes_between_ranges.append(indexes_between) else: - if method == "surrounding": + if method == "surrounding" or method == "nearest": start = indexes.index(low) end = indexes.index(up) start = max(start - 1, 0) @@ -486,7 +486,7 @@ def find_indices_between(index_ranges, low, up, datacube, method=None): transformation = transform if cls.name == transformation.name: for indexes in index_ranges: - if method == "surrounding": + if method == "surrounding" or method == "nearest": start = indexes.index(low) end = indexes.index(up) start = max(start - 1, 0) @@ -598,7 +598,7 @@ def find_indices_between(self, index_ranges, low, up, datacube, method=None): # Find the range of indexes between lower and upper # https://pandas.pydata.org/docs/reference/api/pandas.Index.searchsorted.html # Assumes the indexes are already sorted (could sort to be sure) and monotonically increasing - if method == "surrounding": + if method == "surrounding" or method == "nearest": start = indexes.searchsorted(low, "left") end = indexes.searchsorted(up, "right") start = max(start - 1, 0) @@ -611,7 +611,7 @@ def find_indices_between(self, index_ranges, low, up, datacube, method=None): indexes_between = indexes[start:end].to_list() indexes_between_ranges.append(indexes_between) else: - if method == "surrounding": + if method == "surrounding" or method == "nearest": start = indexes.index(low) end = indexes.index(up) start = max(start - 1, 0) From cf4a1997d2ea1d1d5e323263d41f836d80fb4db5 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Thu, 18 Jan 2024 11:19:26 +0100 Subject: [PATCH 02/17] make nearest method return only nearest point for end lat/lon grid --- polytope/datacube/backends/fdb.py | 34 ++++++++++++++++ polytope/datacube/backends/xarray.py | 2 + polytope/engine/hullslicer.py | 2 + polytope/shapes.py | 12 ++++-- polytope/utility/geometry.py | 18 +++++++++ tests/test_fdb_datacube.py | 3 +- tests/test_point_nearest.py | 60 ++++++++++++++++++++++++++++ tests/test_point_shape.py | 7 ++++ 8 files changed, 133 insertions(+), 5 deletions(-) create mode 100644 tests/test_point_nearest.py diff --git a/polytope/datacube/backends/fdb.py b/polytope/datacube/backends/fdb.py index cbc76893f..9e9bc1a4a 100644 --- a/polytope/datacube/backends/fdb.py +++ b/polytope/datacube/backends/fdb.py @@ -2,6 +2,7 @@ import pygribjump as pygj +from ...utility.geometry import nearest_pt from .datacube import Datacube, IndexTree @@ -15,6 +16,7 @@ def __init__(self, config={}, axis_options={}): self.blocked_axes = [] self.fake_axes = [] self.unwanted_path = {} + self.nearest_search = {} partial_request = config # Find values in the level 3 FDB datacube @@ -61,6 +63,38 @@ def get(self, requests: IndexTree, leaf_path={}): def get_2nd_last_values(self, requests, leaf_path={}): # In this function, we recursively loop over the last two layers of the tree and store the indices of the # request ranges in those layers + first_ax_name = requests.children[0].axis.name + second_ax_name = requests.children[0].children[0].axis.name + nearest_pts = [ + [lat_val, lon_val] + for (lat_val, lon_val) in zip(self.nearest_search[first_ax_name][0], self.nearest_search[second_ax_name][0]) + ] + # TODO: here find nearest point first before retrieving etc + if len(self.nearest_search) != 0: + # first collect the lat lon points found + found_latlon_pts = [] + for lat_child in requests.children: + for lon_child in lat_child.children: + found_latlon_pts.append([lat_child.value, lon_child.value]) + # now find the nearest lat lon to the points requested + nearest_latlons = [] + for pt in nearest_pts: + nearest_latlon = nearest_pt(found_latlon_pts, pt) + nearest_latlons.append(nearest_latlon) + # TODO: now combine with the rest of the function.... + # TODO: need to remove the branches that do not fit + copy_requests = deepcopy(requests) + for i in range(len(copy_requests.children)): + lat_child = copy_requests.children[i] + lat_child = [child for child in requests.children if child.value == lat_child.value][0] + if lat_child.value not in [latlon[0] for latlon in nearest_latlons]: + lat_child.remove_branch() + else: + possible_lons = [latlon[1] for latlon in nearest_latlons if latlon[0] == lat_child.value] + for lon_child in lat_child.children: + if lon_child.value not in possible_lons: + lon_child.remove_branch() + lat_length = len(requests.children) range_lengths = [False] * lat_length current_start_idxs = [False] * lat_length diff --git a/polytope/datacube/backends/xarray.py b/polytope/datacube/backends/xarray.py index f8ca1c2e2..34cd8066f 100644 --- a/polytope/datacube/backends/xarray.py +++ b/polytope/datacube/backends/xarray.py @@ -17,6 +17,7 @@ def __init__(self, dataarray: xr.DataArray, axis_options={}): self.complete_axes = [] self.blocked_axes = [] self.fake_axes = [] + self.nearest_search = None for name, values in dataarray.coords.variables.items(): if name in dataarray.dims: @@ -52,6 +53,7 @@ def get(self, requests: IndexTree): for key in path_copy: axis = self._axes[key] (path, unmapped_path) = axis.unmap_to_datacube(path, unmapped_path) + # TODO: here do nearest point search path = self.fit_path(path) subxarray = self.dataarray.sel(path, method="nearest") subxarray = subxarray.sel(unmapped_path) diff --git a/polytope/engine/hullslicer.py b/polytope/engine/hullslicer.py index e6c8e3eb0..f95f71e2e 100644 --- a/polytope/engine/hullslicer.py +++ b/polytope/engine/hullslicer.py @@ -50,6 +50,8 @@ def _build_sliceable_child(self, polytope, ax, node, datacube, lower, upper, nex upper = ax.from_float(upper + tol) flattened = node.flatten() method = polytope.method + if method == "nearest": + datacube.nearest_search[ax.name] = polytope.points # TODO: this hashing doesn't work because we need to know the latitude val for finding longitude values # TODO: Maybe create a coupled_axes list inside of datacube and add to it during axis formation, then here diff --git a/polytope/shapes.py b/polytope/shapes.py index 6af698ee2..c0378af06 100644 --- a/polytope/shapes.py +++ b/polytope/shapes.py @@ -80,9 +80,15 @@ def __init__(self, axes, values, method=None): self.values = values self.method = method self.polytopes = [] - for i in range(len(axes)): - polytope_points = [v[i] for v in self.values] - self.polytopes.append(ConvexPolytope([axes[i]], [polytope_points], self.method)) + # if method != "nearest": + if True: + # if the method is surrounding, need to treat as 1D polytopes + for i in range(len(axes)): + polytope_points = [v[i] for v in self.values] + # TODO: IS THIS RIGHT? FOR MULTIPLE POINTS, DOES IT CREATE A LINE SEGMENT INSTEAD? + self.polytopes.extend([ConvexPolytope([axes[i]], [[point]], self.method) for point in polytope_points]) + # if method == "nearest": + # self.polytopes.extend([ConvexPolytope(axes, [v], self.method) for v in self.values]) def axes(self): return self._axes diff --git a/polytope/utility/geometry.py b/polytope/utility/geometry.py index bbbb75152..2c88d9655 100644 --- a/polytope/utility/geometry.py +++ b/polytope/utility/geometry.py @@ -1,4 +1,22 @@ +import math + + def lerp(a, b, value): direction = [a - b for a, b in zip(a, b)] intersect = [b + value * d for b, d in zip(b, direction)] return intersect + + +def nearest_pt(pts_list, pt): + nearest_pt = pts_list[0] + distance = l2_norm(pts_list[0], pt) + for new_pt in pts_list[1:]: + new_distance = l2_norm(new_pt, pt) + if new_distance < distance: + distance = new_distance + nearest_pt = new_pt + return nearest_pt + + +def l2_norm(pt1, pt2): + return math.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0]) + (pt1[1] - pt2[1]) * (pt1[1] - pt2[1])) diff --git a/tests/test_fdb_datacube.py b/tests/test_fdb_datacube.py index a7c54816a..bf4452a35 100644 --- a/tests/test_fdb_datacube.py +++ b/tests/test_fdb_datacube.py @@ -20,7 +20,7 @@ def setup_method(self, method): "step": {"type_change": "int"}, "number": {"type_change": "int"}, } - self.config = {"class": "od", "expver": "0001", "levtype": "sfc"} + self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper"} self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options) self.slicer = HullSlicer() self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options) @@ -38,7 +38,6 @@ def test_fdb_datacube(self): Select("class", ["od"]), Select("stream", ["oper"]), Select("type", ["an"]), - Select("number", [1]), Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), ) result = self.API.retrieve(request) diff --git a/tests/test_point_nearest.py b/tests/test_point_nearest.py new file mode 100644 index 000000000..9166fd794 --- /dev/null +++ b/tests/test_point_nearest.py @@ -0,0 +1,60 @@ +import pandas as pd +import pytest + +from polytope.engine.hullslicer import HullSlicer +from polytope.polytope import Polytope, Request +from polytope.shapes import Point, Select + + +class TestSlicingFDBDatacube: + def setup_method(self, method): + from polytope.datacube.backends.fdb import FDBDatacube + + # Create a dataarray with 3 labelled axes using different index types + self.options = { + "values": {"mapper": {"type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]}}, + "date": {"merge": {"with": "time", "linkers": ["T", "00"]}}, + "step": {"type_change": "int"}, + "number": {"type_change": "int"}, + } + self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper"} + self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options) + self.slicer = HullSlicer() + self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options) + + # Testing different shapes + @pytest.mark.fdb + def test_fdb_datacube(self): + request = Request( + Select("step", [0]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20230625T120000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["an"]), + Point(["latitude", "longitude"], [[0.16, 0.176]], method="nearest"), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 1 + + @pytest.mark.fdb + def test_fdb_datacube_true_point(self): + request = Request( + Select("step", [0]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20230625T120000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["an"]), + Point(["latitude", "longitude"], [[0.175746921078, 0.210608424337]], method="nearest"), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 1 diff --git a/tests/test_point_shape.py b/tests/test_point_shape.py index 0bc203d61..7dbbd9b52 100644 --- a/tests/test_point_shape.py +++ b/tests/test_point_shape.py @@ -30,6 +30,13 @@ def test_point(self): assert len(result.leaves) == 1 assert result.leaves[0].axis.name == "level" + def test_multiple_points(self): + request = Request(Point(["step", "level"], [[3, 10], [3, 12]]), Select("date", ["2000-01-01"])) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 2 + assert result.leaves[0].axis.name == "level" + def test_point_surrounding_step(self): request = Request(Point(["step", "level"], [[2, 10]], method="surrounding"), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) From 0a522dbe67f52adbd2ad4a1b6260e716e80affb5 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Thu, 18 Jan 2024 13:45:14 +0100 Subject: [PATCH 03/17] fix small bug --- polytope/datacube/backends/fdb.py | 14 ++++++++------ tests/test_ecmwf_oper_data_fdb.py | 2 +- tests/test_slice_date_range_fdb.py | 1 - tests/test_slice_date_range_fdb_v2.py | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/polytope/datacube/backends/fdb.py b/polytope/datacube/backends/fdb.py index 9e9bc1a4a..f682b671f 100644 --- a/polytope/datacube/backends/fdb.py +++ b/polytope/datacube/backends/fdb.py @@ -63,14 +63,16 @@ def get(self, requests: IndexTree, leaf_path={}): def get_2nd_last_values(self, requests, leaf_path={}): # In this function, we recursively loop over the last two layers of the tree and store the indices of the # request ranges in those layers - first_ax_name = requests.children[0].axis.name - second_ax_name = requests.children[0].children[0].axis.name - nearest_pts = [ - [lat_val, lon_val] - for (lat_val, lon_val) in zip(self.nearest_search[first_ax_name][0], self.nearest_search[second_ax_name][0]) - ] # TODO: here find nearest point first before retrieving etc if len(self.nearest_search) != 0: + first_ax_name = requests.children[0].axis.name + second_ax_name = requests.children[0].children[0].axis.name + nearest_pts = [ + [lat_val, lon_val] + for (lat_val, lon_val) in zip( + self.nearest_search[first_ax_name][0], self.nearest_search[second_ax_name][0] + ) + ] # first collect the lat lon points found found_latlon_pts = [] for lat_child in requests.children: diff --git a/tests/test_ecmwf_oper_data_fdb.py b/tests/test_ecmwf_oper_data_fdb.py index 7ec4a4659..553407eab 100644 --- a/tests/test_ecmwf_oper_data_fdb.py +++ b/tests/test_ecmwf_oper_data_fdb.py @@ -16,7 +16,7 @@ def setup_method(self, method): "date": {"merge": {"with": "time", "linkers": ["T", "00"]}}, "step": {"type_change": "int"}, } - self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "step": 0, "type": "fc"} + self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper", "type": "fc"} self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options) self.slicer = HullSlicer() self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options) diff --git a/tests/test_slice_date_range_fdb.py b/tests/test_slice_date_range_fdb.py index 624a77fe2..ec14ba48c 100644 --- a/tests/test_slice_date_range_fdb.py +++ b/tests/test_slice_date_range_fdb.py @@ -36,7 +36,6 @@ def test_fdb_datacube(self): Select("class", ["od"]), Select("stream", ["oper"]), Select("type", ["an"]), - Select("number", [1]), Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), ) result = self.API.retrieve(request) diff --git a/tests/test_slice_date_range_fdb_v2.py b/tests/test_slice_date_range_fdb_v2.py index 24ae1a9a5..63ce5e678 100644 --- a/tests/test_slice_date_range_fdb_v2.py +++ b/tests/test_slice_date_range_fdb_v2.py @@ -16,7 +16,7 @@ def setup_method(self, method): "date": {"merge": {"with": "time", "linkers": ["T", "00"]}}, "step": {"type_change": "int"}, } - self.config = {"class": "ea", "expver": "0001", "levtype": "pl", "step": 0} + self.config = {"class": "ea", "expver": "0001", "levtype": "pl", "stream": "enda"} self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options) self.slicer = HullSlicer() self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options) From ee500212c1803a6eceab81b676be0c6515bc5a9a Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Thu, 18 Jan 2024 13:49:07 +0100 Subject: [PATCH 04/17] make sure nearest is only called on one point --- polytope/shapes.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/polytope/shapes.py b/polytope/shapes.py index c0378af06..0c170b3fc 100644 --- a/polytope/shapes.py +++ b/polytope/shapes.py @@ -80,15 +80,11 @@ def __init__(self, axes, values, method=None): self.values = values self.method = method self.polytopes = [] - # if method != "nearest": - if True: - # if the method is surrounding, need to treat as 1D polytopes - for i in range(len(axes)): - polytope_points = [v[i] for v in self.values] - # TODO: IS THIS RIGHT? FOR MULTIPLE POINTS, DOES IT CREATE A LINE SEGMENT INSTEAD? - self.polytopes.extend([ConvexPolytope([axes[i]], [[point]], self.method) for point in polytope_points]) - # if method == "nearest": - # self.polytopes.extend([ConvexPolytope(axes, [v], self.method) for v in self.values]) + if method == "nearest": + assert len(self.values) == 1 + for i in range(len(axes)): + polytope_points = [v[i] for v in self.values] + self.polytopes.extend([ConvexPolytope([axes[i]], [[point]], self.method) for point in polytope_points]) def axes(self): return self._axes From ad5c029a327557090c9d14e0e9a65c62352ad324 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Thu, 18 Jan 2024 14:55:34 +0100 Subject: [PATCH 05/17] fix bug --- polytope/datacube/backends/fdb.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/polytope/datacube/backends/fdb.py b/polytope/datacube/backends/fdb.py index f682b671f..4d7b7aa57 100644 --- a/polytope/datacube/backends/fdb.py +++ b/polytope/datacube/backends/fdb.py @@ -85,10 +85,10 @@ def get_2nd_last_values(self, requests, leaf_path={}): nearest_latlons.append(nearest_latlon) # TODO: now combine with the rest of the function.... # TODO: need to remove the branches that do not fit - copy_requests = deepcopy(requests) - for i in range(len(copy_requests.children)): - lat_child = copy_requests.children[i] - lat_child = [child for child in requests.children if child.value == lat_child.value][0] + lat_children_values = [child.value for child in requests.children] + for i in range(len(lat_children_values)): + lat_child_val = lat_children_values[i] + lat_child = [child for child in requests.children if child.value == lat_child_val][0] if lat_child.value not in [latlon[0] for latlon in nearest_latlons]: lat_child.remove_branch() else: @@ -182,7 +182,10 @@ def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length): sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0]) original_indices, sorted_request_ranges = zip(*sorted_list) fdb_requests.append(tuple((path, sorted_request_ranges))) + print("REQUEST TO FDB") + print(fdb_requests) output_values = self.fdb.extract(fdb_requests) + print(output_values) return (output_values, original_indices) def datacube_natural_indexes(self, axis, subarray): From f42f37d89d3780aaaa24c8a7c696991620f3263f Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Fri, 19 Jan 2024 10:07:13 +0100 Subject: [PATCH 06/17] make tests work --- examples/healpix_grid_box_example.py | 6 +----- tests/test_point_nearest.py | 22 ++++++++++++++++++++-- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/examples/healpix_grid_box_example.py b/examples/healpix_grid_box_example.py index 8e676fa3b..6cfbb5e78 100644 --- a/examples/healpix_grid_box_example.py +++ b/examples/healpix_grid_box_example.py @@ -14,11 +14,7 @@ def setup_method(self, method): ds = data.from_source("file", "./tests/data/healpix.grib") self.latlon_array = ds.to_xarray().isel(step=0).isel(time=0).isel(isobaricInhPa=0).z self.xarraydatacube = XArrayDatacube(self.latlon_array) - self.options = { - "values": { - "transformation": {"mapper": {"type": "healpix", "resolution": 32, "axes": ["latitude", "longitude"]}} - } - } + self.options = {"values": {"mapper": {"type": "healpix", "resolution": 32, "axes": ["latitude", "longitude"]}}} self.slicer = HullSlicer() self.API = Polytope(datacube=self.latlon_array, engine=self.slicer, axis_options=self.options) diff --git a/tests/test_point_nearest.py b/tests/test_point_nearest.py index 9166fd794..c1ae7dbd8 100644 --- a/tests/test_point_nearest.py +++ b/tests/test_point_nearest.py @@ -38,7 +38,7 @@ def test_fdb_datacube(self): Point(["latitude", "longitude"], [[0.16, 0.176]], method="nearest"), ) result = self.API.retrieve(request) - result.pprint() + # result.pprint() assert len(result.leaves) == 1 @pytest.mark.fdb @@ -56,5 +56,23 @@ def test_fdb_datacube_true_point(self): Point(["latitude", "longitude"], [[0.175746921078, 0.210608424337]], method="nearest"), ) result = self.API.retrieve(request) - result.pprint() + # result.pprint() assert len(result.leaves) == 1 + + @pytest.mark.fdb + def test_fdb_datacube_true_point_2(self): + request = Request( + Select("step", [21, 22, 23]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20231102T000000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["fc"]), + Point(["latitude", "longitude"], [[0.035149384216, 0.0]], method="nearest"), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 3 From e2161ca801751eac5baf48003b4f57b458f55e38 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Fri, 19 Jan 2024 10:37:01 +0100 Subject: [PATCH 07/17] fix tree merging for multiple params --- polytope/datacube/index_tree.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/polytope/datacube/index_tree.py b/polytope/datacube/index_tree.py index 9054cd0c5..f5a3106d1 100644 --- a/polytope/datacube/index_tree.py +++ b/polytope/datacube/index_tree.py @@ -84,13 +84,16 @@ def __eq__(self, other): else: if other.value == self.value: return True - if other.value - 2 * other.axis.tol <= self.value <= other.value + 2 * other.axis.tol: - return True - elif self.value - 2 * self.axis.tol <= other.value <= self.value + 2 * self.axis.tol: - return True else: - return False - # return (self.axis.name, self.value) == (other.axis.name, other.value) + if isinstance(self.value, str): + return False + else: + if other.value - 2 * other.axis.tol <= self.value <= other.value + 2 * other.axis.tol: + return True + elif self.value - 2 * self.axis.tol <= other.value <= self.value + 2 * self.axis.tol: + return True + else: + return False def __lt__(self, other): return (self.axis.name, self.value) < (other.axis.name, other.value) From 721c53adbefd5061e7769141cee17cf8990a0fa5 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Fri, 19 Jan 2024 10:44:23 +0100 Subject: [PATCH 08/17] nicer fix --- polytope/datacube/index_tree.py | 4 ++-- tests/test_multiple_param_fdb.py | 41 ++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 tests/test_multiple_param_fdb.py diff --git a/polytope/datacube/index_tree.py b/polytope/datacube/index_tree.py index f5a3106d1..303bd4d84 100644 --- a/polytope/datacube/index_tree.py +++ b/polytope/datacube/index_tree.py @@ -3,7 +3,7 @@ from sortedcontainers import SortedList -from .datacube_axis import IntDatacubeAxis +from .datacube_axis import IntDatacubeAxis, UnsliceableDatacubeAxis class DatacubePath(OrderedDict): @@ -85,7 +85,7 @@ def __eq__(self, other): if other.value == self.value: return True else: - if isinstance(self.value, str): + if isinstance(self.axis, UnsliceableDatacubeAxis): return False else: if other.value - 2 * other.axis.tol <= self.value <= other.value + 2 * other.axis.tol: diff --git a/tests/test_multiple_param_fdb.py b/tests/test_multiple_param_fdb.py new file mode 100644 index 000000000..04b8e7127 --- /dev/null +++ b/tests/test_multiple_param_fdb.py @@ -0,0 +1,41 @@ +import pandas as pd +import pytest + +from polytope.engine.hullslicer import HullSlicer +from polytope.polytope import Polytope, Request +from polytope.shapes import Box, Select + + +class TestSlicingFDBDatacube: + def setup_method(self, method): + from polytope.datacube.backends.fdb import FDBDatacube + + # Create a dataarray with 3 labelled axes using different index types + self.options = { + "values": {"mapper": {"type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]}}, + "date": {"merge": {"with": "time", "linkers": ["T", "00"]}}, + "step": {"type_change": "int"}, + } + self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper", "type": "fc"} + self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options) + self.slicer = HullSlicer() + self.API = Polytope(datacube=self.fdbdatacube, engine=self.slicer, axis_options=self.options) + + # Testing different shapes + @pytest.mark.fdb + def test_fdb_datacube(self): + request = Request( + Select("step", [0]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20240118T000000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["49", "167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["fc"]), + Box(["latitude", "longitude"], [0, 0], [0.2, 0.2]), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 18 From db04526d8b5e484100d1efd2242854f03dd0bdd1 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Mon, 22 Jan 2024 21:16:06 +0100 Subject: [PATCH 09/17] fix cyclic nearest method bug --- polytope/datacube/backends/fdb.py | 8 +++--- tests/test_point_nearest.py | 41 +++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/polytope/datacube/backends/fdb.py b/polytope/datacube/backends/fdb.py index 4d7b7aa57..661e52330 100644 --- a/polytope/datacube/backends/fdb.py +++ b/polytope/datacube/backends/fdb.py @@ -93,7 +93,10 @@ def get_2nd_last_values(self, requests, leaf_path={}): lat_child.remove_branch() else: possible_lons = [latlon[1] for latlon in nearest_latlons if latlon[0] == lat_child.value] - for lon_child in lat_child.children: + lon_children_values = [child.value for child in lat_child.children] + for j in range(len(lon_children_values)): + lon_child_val = lon_children_values[j] + lon_child = [child for child in lat_child.children if child.value == lon_child_val][0] if lon_child.value not in possible_lons: lon_child.remove_branch() @@ -182,10 +185,7 @@ def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length): sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0]) original_indices, sorted_request_ranges = zip(*sorted_list) fdb_requests.append(tuple((path, sorted_request_ranges))) - print("REQUEST TO FDB") - print(fdb_requests) output_values = self.fdb.extract(fdb_requests) - print(output_values) return (output_values, original_indices) def datacube_natural_indexes(self, axis, subarray): diff --git a/tests/test_point_nearest.py b/tests/test_point_nearest.py index c1ae7dbd8..081050dad 100644 --- a/tests/test_point_nearest.py +++ b/tests/test_point_nearest.py @@ -16,6 +16,7 @@ def setup_method(self, method): "date": {"merge": {"with": "time", "linkers": ["T", "00"]}}, "step": {"type_change": "int"}, "number": {"type_change": "int"}, + "longitude": {"cyclic": [0, 360]}, } self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper"} self.fdbdatacube = FDBDatacube(self.config, axis_options=self.options) @@ -76,3 +77,43 @@ def test_fdb_datacube_true_point_2(self): result = self.API.retrieve(request) result.pprint() assert len(result.leaves) == 3 + + @pytest.mark.fdb + def test_fdb_datacube_true_point_3(self): + request = Request( + Select("step", [21]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20231102T000000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["fc"]), + Point(["latitude", "longitude"], [[0.035149384216, -0.01]], method="nearest"), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 1 + assert result.leaves[0].value == 0 + assert result.leaves[0].axis.name == "longitude" + + @pytest.mark.fdb + def test_fdb_datacube_true_point_4(self): + request = Request( + Select("step", [21]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20231102T000000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["fc"]), + Point(["latitude", "longitude"], [[0.035149384216, 359.97]], method="nearest"), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 1 + assert result.leaves[0].value == 359.929906542056 + assert result.leaves[0].axis.name == "longitude" From 740e91766aae31fe42b5115c7a17d9bae7aff1ba Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Tue, 23 Jan 2024 09:54:52 +0100 Subject: [PATCH 10/17] black setup.py --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 0bcb9ac6c..005e76dfd 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ io.open("polytope/version.py", encoding="utf_8_sig").read(), ).group(1) -with open('requirements.txt') as f: +with open("requirements.txt") as f: requirements = f.read().splitlines() setup( @@ -25,5 +25,5 @@ packages=find_packages(), zip_safe=False, include_package_data=True, - install_requires=requirements + install_requires=requirements, ) From aba829540c3a0cf15d3835f24d9aa74c715c0867 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Tue, 23 Jan 2024 10:05:04 +0100 Subject: [PATCH 11/17] fix CI --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 005e76dfd..e152f4899 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ io.open("polytope/version.py", encoding="utf_8_sig").read(), ).group(1) -with open("requirements.txt") as f: +with open(".requirements.txt") as f: requirements = f.read().splitlines() setup( From 4f5667dfc1b863aab86654bcc16c1e27f44ffa15 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Tue, 23 Jan 2024 10:07:21 +0100 Subject: [PATCH 12/17] fix CI --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e152f4899..18241b407 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ io.open("polytope/version.py", encoding="utf_8_sig").read(), ).group(1) -with open(".requirements.txt") as f: +with open("./requirements.txt") as f: requirements = f.read().splitlines() setup( From 2ae23da43384b8704723942e0275a4ab95e4b15b Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Tue, 23 Jan 2024 10:14:03 +0100 Subject: [PATCH 13/17] fix CI --- MANIFEST.in | 1 + setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 MANIFEST.in diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..540b72040 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include requirements.txt \ No newline at end of file diff --git a/setup.py b/setup.py index 18241b407..005e76dfd 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ io.open("polytope/version.py", encoding="utf_8_sig").read(), ).group(1) -with open("./requirements.txt") as f: +with open("requirements.txt") as f: requirements = f.read().splitlines() setup( From 78531228f7dc96323f558edee223a3174d0418bf Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Tue, 23 Jan 2024 10:17:27 +0100 Subject: [PATCH 14/17] fix CI --- MANIFEST.in | 1 + 1 file changed, 1 insertion(+) create mode 100644 MANIFEST.in diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..540b72040 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include requirements.txt \ No newline at end of file From 56b10ad5453728da7be770429fa72043f3d042f9 Mon Sep 17 00:00:00 2001 From: majh Date: Wed, 24 Jan 2024 00:16:05 +0000 Subject: [PATCH 15/17] logging --- polytope/datacube/index_tree.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/polytope/datacube/index_tree.py b/polytope/datacube/index_tree.py index 303bd4d84..b15e67fe2 100644 --- a/polytope/datacube/index_tree.py +++ b/polytope/datacube/index_tree.py @@ -1,5 +1,6 @@ import json from typing import OrderedDict +import logging from sortedcontainers import SortedList @@ -168,10 +169,12 @@ def intersect(self, other): def pprint(self, level=0): if self.axis.name == "root": - print("\n") - print("\t" * level + "\u21b3" + str(self)) + logging.debug("\n") + logging.debug("\t" * level + "\u21b3" + str(self)) for child in self.children: child.pprint(level + 1) + if len(self.children) == 0: + logging.debug("\t" * (level + 1) + "\u21b3" + str(self.result)) def remove_branch(self): if not self.is_root(): From 49838d05d04e8462404b189df9cefc13287c0fd8 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Wed, 24 Jan 2024 10:21:45 +0100 Subject: [PATCH 16/17] better batching to fdb backend --- polytope/datacube/backends/fdb.py | 156 ++++++++++++++++++++++++++++-- 1 file changed, 149 insertions(+), 7 deletions(-) diff --git a/polytope/datacube/backends/fdb.py b/polytope/datacube/backends/fdb.py index d0f4264ff..8e9300128 100644 --- a/polytope/datacube/backends/fdb.py +++ b/polytope/datacube/backends/fdb.py @@ -39,7 +39,14 @@ def __init__(self, config={}, axis_options={}): val = self._axes[name].type self._check_and_add_axes(options, name, val) - def get(self, requests: IndexTree, leaf_path={}): + def get(self, requests: IndexTree): + fdb_requests = [] + fdb_requests_decoding_info = [] + self.get_fdb_requests(requests, fdb_requests, fdb_requests_decoding_info) + output_values = self.fdb.extract(fdb_requests) + self.assign_fdb_output_to_nodes(output_values, fdb_requests_decoding_info) + + def old_get(self, requests: IndexTree, leaf_path={}): # First when request node is root, go to its children if requests.axis.name == "root": for c in requests.children: @@ -61,6 +68,39 @@ def get(self, requests: IndexTree, leaf_path={}): for c in requests.children: self.get(c, leaf_path) + def get_fdb_requests(self, requests: IndexTree, fdb_requests=[], fdb_requests_decoding_info=[], leaf_path={}): + # First when request node is root, go to its children + if requests.axis.name == "root": + for c in requests.children: + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info) + # If request node has no children, we have a leaf so need to assign fdb values to it + else: + key_value_path = {requests.axis.name: requests.value} + ax = requests.axis + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + leaf_path.update(key_value_path) + if len(requests.children[0].children[0].children) == 0: + # remap this last key + # TODO: here, find the fdb_requests and associated nodes to which to add results + + (path, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) = self.get_2nd_last_values( + requests, leaf_path + ) + (original_indices, sorted_request_ranges) = self.sort_fdb_request_ranges( + range_lengths, current_start_idxs, lat_length + ) + fdb_requests.append(tuple((path, sorted_request_ranges))) + fdb_requests_decoding_info.append( + tuple((original_indices, fdb_node_ranges, lat_length, range_lengths, current_start_idxs)) + ) + + # Otherwise remap the path for this key and iterate again over children + else: + for c in requests.children: + self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info, leaf_path) + def get_2nd_last_values(self, requests, leaf_path={}): # In this function, we recursively loop over the last two layers of the tree and store the indices of the # request ranges in those layers @@ -68,6 +108,75 @@ def get_2nd_last_values(self, requests, leaf_path={}): if len(self.nearest_search) != 0: first_ax_name = requests.children[0].axis.name second_ax_name = requests.children[0].children[0].axis.name + # TODO: throw error if first_ax_name or second_ax_name not in self.nearest_search.keys() + nearest_pts = [ + [lat_val, lon_val] + for (lat_val, lon_val) in zip( + self.nearest_search[first_ax_name][0], self.nearest_search[second_ax_name][0] + ) + ] + # first collect the lat lon points found + found_latlon_pts = [] + for lat_child in requests.children: + for lon_child in lat_child.children: + found_latlon_pts.append([lat_child.value, lon_child.value]) + # now find the nearest lat lon to the points requested + nearest_latlons = [] + for pt in nearest_pts: + nearest_latlon = nearest_pt(found_latlon_pts, pt) + nearest_latlons.append(nearest_latlon) + # TODO: now combine with the rest of the function.... + # TODO: need to remove the branches that do not fit + lat_children_values = [child.value for child in requests.children] + for i in range(len(lat_children_values)): + lat_child_val = lat_children_values[i] + lat_child = [child for child in requests.children if child.value == lat_child_val][0] + if lat_child.value not in [latlon[0] for latlon in nearest_latlons]: + lat_child.remove_branch() + else: + possible_lons = [latlon[1] for latlon in nearest_latlons if latlon[0] == lat_child.value] + lon_children_values = [child.value for child in lat_child.children] + for j in range(len(lon_children_values)): + lon_child_val = lon_children_values[j] + lon_child = [child for child in lat_child.children if child.value == lon_child_val][0] + if lon_child.value not in possible_lons: + lon_child.remove_branch() + + lat_length = len(requests.children) + range_lengths = [False] * lat_length + current_start_idxs = [False] * lat_length + fdb_node_ranges = [False] * lat_length + for i in range(len(requests.children)): + lat_child = requests.children[i] + lon_length = len(lat_child.children) + range_lengths[i] = [1] * lon_length + current_start_idxs[i] = [None] * lon_length + fdb_node_ranges[i] = [[IndexTree.root] * lon_length] * lon_length + range_length = deepcopy(range_lengths[i]) + current_start_idx = deepcopy(current_start_idxs[i]) + fdb_range_nodes = deepcopy(fdb_node_ranges[i]) + key_value_path = {lat_child.axis.name: lat_child.value} + ax = lat_child.axis + (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( + key_value_path, leaf_path, self.unwanted_path + ) + leaf_path.update(key_value_path) + (range_lengths[i], current_start_idxs[i], fdb_node_ranges[i]) = self.get_last_layer_before_leaf( + lat_child, leaf_path, range_length, current_start_idx, fdb_range_nodes + ) + # TODO: do we need to return all of this? + leaf_path_copy = deepcopy(leaf_path) + leaf_path_copy.pop("values") + return (leaf_path_copy, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) + + def old_get_2nd_last_values(self, requests, leaf_path={}): + # In this function, we recursively loop over the last two layers of the tree and store the indices of the + # request ranges in those layers + # TODO: here find nearest point first before retrieving etc + if len(self.nearest_search) != 0: + first_ax_name = requests.children[0].axis.name + second_ax_name = requests.children[0].children[0].axis.name + # TODO: throw error if first_ax_name or second_ax_name not in self.nearest_search.keys() nearest_pts = [ [lat_val, lon_val] for (lat_val, lon_val) in zip( @@ -155,6 +264,30 @@ def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, current_idx[i] = current_start_idx return (range_l, current_idx, fdb_range_n) + def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info): + for k in range(len(output_values)): + request_output_values = output_values[k] + ( + original_indices, + fdb_node_ranges, + lat_length, + range_lengths, + current_start_idxs, + ) = fdb_requests_decoding_info[k] + new_fdb_range_nodes = [] + new_range_lengths = [] + for j in range(lat_length): + for i in range(len(range_lengths[j])): + if current_start_idxs[j][i] is not None: + new_fdb_range_nodes.append(fdb_node_ranges[j][i]) + new_range_lengths.append(range_lengths[j][i]) + sorted_fdb_range_nodes = [new_fdb_range_nodes[i] for i in original_indices] + sorted_range_lengths = [new_range_lengths[i] for i in original_indices] + for i in range(len(sorted_fdb_range_nodes)): + for j in range(sorted_range_lengths[i]): + n = sorted_fdb_range_nodes[i][j] + n.result = request_output_values[0][i][0][j] + def give_fdb_val_to_node(self, leaf_path, range_lengths, current_start_idx, fdb_range_nodes, lat_length): (output_values, original_indices) = self.find_fdb_values( leaf_path, range_lengths, current_start_idx, lat_length @@ -169,9 +302,21 @@ def give_fdb_val_to_node(self, leaf_path, range_lengths, current_start_idx, fdb_ sorted_fdb_range_nodes = [new_fdb_range_nodes[i] for i in original_indices] sorted_range_lengths = [new_range_lengths[i] for i in original_indices] for i in range(len(sorted_fdb_range_nodes)): - for k in range(sorted_range_lengths[i]): - n = sorted_fdb_range_nodes[i][k] - n.result = output_values[0][0][i][0][k] + for j in range(sorted_range_lengths[i]): + n = sorted_fdb_range_nodes[i][j] + n.result = output_values[0][0][i][0][j] + + def sort_fdb_request_ranges(self, range_lengths, current_start_idx, lat_length): + interm_request_ranges = [] + for i in range(lat_length): + for j in range(len(range_lengths[i])): + if current_start_idx[i][j] is not None: + current_request_ranges = (current_start_idx[i][j], current_start_idx[i][j] + range_lengths[i][j]) + interm_request_ranges.append(current_request_ranges) + request_ranges_with_idx = list(enumerate(interm_request_ranges)) + sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0]) + original_indices, sorted_request_ranges = zip(*sorted_list) + return (original_indices, sorted_request_ranges) def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length): path.pop("values") @@ -186,10 +331,7 @@ def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length): sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0]) original_indices, sorted_request_ranges = zip(*sorted_list) fdb_requests.append(tuple((path, sorted_request_ranges))) - print("REQUEST TO FDB") - print(fdb_requests) output_values = self.fdb.extract(fdb_requests) - print(output_values) return (output_values, original_indices) def datacube_natural_indexes(self, axis, subarray): From b95c500dfcb4a40dff5ca4f51cbf9c6cca2cb578 Mon Sep 17 00:00:00 2001 From: Mathilde Leuridan Date: Wed, 24 Jan 2024 10:23:53 +0100 Subject: [PATCH 17/17] clean up unnecessary functions --- polytope/datacube/backends/fdb.py | 121 ------------------------------ 1 file changed, 121 deletions(-) diff --git a/polytope/datacube/backends/fdb.py b/polytope/datacube/backends/fdb.py index 8e9300128..7b2d19295 100644 --- a/polytope/datacube/backends/fdb.py +++ b/polytope/datacube/backends/fdb.py @@ -46,28 +46,6 @@ def get(self, requests: IndexTree): output_values = self.fdb.extract(fdb_requests) self.assign_fdb_output_to_nodes(output_values, fdb_requests_decoding_info) - def old_get(self, requests: IndexTree, leaf_path={}): - # First when request node is root, go to its children - if requests.axis.name == "root": - for c in requests.children: - self.get(c) - # If request node has no children, we have a leaf so need to assign fdb values to it - else: - key_value_path = {requests.axis.name: requests.value} - ax = requests.axis - (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( - key_value_path, leaf_path, self.unwanted_path - ) - leaf_path.update(key_value_path) - if len(requests.children[0].children[0].children) == 0: - # remap this last key - self.get_2nd_last_values(requests, leaf_path) - - # Otherwise remap the path for this key and iterate again over children - else: - for c in requests.children: - self.get(c, leaf_path) - def get_fdb_requests(self, requests: IndexTree, fdb_requests=[], fdb_requests_decoding_info=[], leaf_path={}): # First when request node is root, go to its children if requests.axis.name == "root": @@ -169,71 +147,6 @@ def get_2nd_last_values(self, requests, leaf_path={}): leaf_path_copy.pop("values") return (leaf_path_copy, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) - def old_get_2nd_last_values(self, requests, leaf_path={}): - # In this function, we recursively loop over the last two layers of the tree and store the indices of the - # request ranges in those layers - # TODO: here find nearest point first before retrieving etc - if len(self.nearest_search) != 0: - first_ax_name = requests.children[0].axis.name - second_ax_name = requests.children[0].children[0].axis.name - # TODO: throw error if first_ax_name or second_ax_name not in self.nearest_search.keys() - nearest_pts = [ - [lat_val, lon_val] - for (lat_val, lon_val) in zip( - self.nearest_search[first_ax_name][0], self.nearest_search[second_ax_name][0] - ) - ] - # first collect the lat lon points found - found_latlon_pts = [] - for lat_child in requests.children: - for lon_child in lat_child.children: - found_latlon_pts.append([lat_child.value, lon_child.value]) - # now find the nearest lat lon to the points requested - nearest_latlons = [] - for pt in nearest_pts: - nearest_latlon = nearest_pt(found_latlon_pts, pt) - nearest_latlons.append(nearest_latlon) - # TODO: now combine with the rest of the function.... - # TODO: need to remove the branches that do not fit - lat_children_values = [child.value for child in requests.children] - for i in range(len(lat_children_values)): - lat_child_val = lat_children_values[i] - lat_child = [child for child in requests.children if child.value == lat_child_val][0] - if lat_child.value not in [latlon[0] for latlon in nearest_latlons]: - lat_child.remove_branch() - else: - possible_lons = [latlon[1] for latlon in nearest_latlons if latlon[0] == lat_child.value] - lon_children_values = [child.value for child in lat_child.children] - for j in range(len(lon_children_values)): - lon_child_val = lon_children_values[j] - lon_child = [child for child in lat_child.children if child.value == lon_child_val][0] - if lon_child.value not in possible_lons: - lon_child.remove_branch() - - lat_length = len(requests.children) - range_lengths = [False] * lat_length - current_start_idxs = [False] * lat_length - fdb_node_ranges = [False] * lat_length - for i in range(len(requests.children)): - lat_child = requests.children[i] - lon_length = len(lat_child.children) - range_lengths[i] = [1] * lon_length - current_start_idxs[i] = [None] * lon_length - fdb_node_ranges[i] = [[IndexTree.root] * lon_length] * lon_length - range_length = deepcopy(range_lengths[i]) - current_start_idx = deepcopy(current_start_idxs[i]) - fdb_range_nodes = deepcopy(fdb_node_ranges[i]) - key_value_path = {lat_child.axis.name: lat_child.value} - ax = lat_child.axis - (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( - key_value_path, leaf_path, self.unwanted_path - ) - leaf_path.update(key_value_path) - (range_lengths[i], current_start_idxs[i], fdb_node_ranges[i]) = self.get_last_layer_before_leaf( - lat_child, leaf_path, range_length, current_start_idx, fdb_range_nodes - ) - self.give_fdb_val_to_node(leaf_path, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) - def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, fdb_range_n): i = 0 for c in requests.children: @@ -288,24 +201,6 @@ def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info): n = sorted_fdb_range_nodes[i][j] n.result = request_output_values[0][i][0][j] - def give_fdb_val_to_node(self, leaf_path, range_lengths, current_start_idx, fdb_range_nodes, lat_length): - (output_values, original_indices) = self.find_fdb_values( - leaf_path, range_lengths, current_start_idx, lat_length - ) - new_fdb_range_nodes = [] - new_range_lengths = [] - for j in range(lat_length): - for i in range(len(range_lengths[j])): - if current_start_idx[j][i] is not None: - new_fdb_range_nodes.append(fdb_range_nodes[j][i]) - new_range_lengths.append(range_lengths[j][i]) - sorted_fdb_range_nodes = [new_fdb_range_nodes[i] for i in original_indices] - sorted_range_lengths = [new_range_lengths[i] for i in original_indices] - for i in range(len(sorted_fdb_range_nodes)): - for j in range(sorted_range_lengths[i]): - n = sorted_fdb_range_nodes[i][j] - n.result = output_values[0][0][i][0][j] - def sort_fdb_request_ranges(self, range_lengths, current_start_idx, lat_length): interm_request_ranges = [] for i in range(lat_length): @@ -318,22 +213,6 @@ def sort_fdb_request_ranges(self, range_lengths, current_start_idx, lat_length): original_indices, sorted_request_ranges = zip(*sorted_list) return (original_indices, sorted_request_ranges) - def find_fdb_values(self, path, range_lengths, current_start_idx, lat_length): - path.pop("values") - fdb_requests = [] - interm_request_ranges = [] - for i in range(lat_length): - for j in range(len(range_lengths[i])): - if current_start_idx[i][j] is not None: - current_request_ranges = (current_start_idx[i][j], current_start_idx[i][j] + range_lengths[i][j]) - interm_request_ranges.append(current_request_ranges) - request_ranges_with_idx = list(enumerate(interm_request_ranges)) - sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0]) - original_indices, sorted_request_ranges = zip(*sorted_list) - fdb_requests.append(tuple((path, sorted_request_ranges))) - output_values = self.fdb.extract(fdb_requests) - return (output_values, original_indices) - def datacube_natural_indexes(self, axis, subarray): indexes = subarray[axis.name] return indexes